diff --git a/msgpack/_unpacker.pyx b/msgpack/_unpacker.pyx
index 40d12291..4bfbe064 100644
--- a/msgpack/_unpacker.pyx
+++ b/msgpack/_unpacker.pyx
@@ -317,9 +317,6 @@ cdef class Unpacker:
     cdef Py_ssize_t max_buffer_size
     cdef uint64_t stream_offset
 
-    def __cinit__(self):
-        self.buf = NULL
-
     def __dealloc__(self):
         unpack_clear(&self.ctx)
         PyMem_Free(self.buf)
@@ -338,6 +335,19 @@ cdef class Unpacker:
                  Py_ssize_t max_ext_len=-1):
         cdef const char *cerr=NULL
 
+        # __init__ can be called again on an existing instance: drop any
+        # state left in the parser context and release the old buffer
+        # before the fields below are set up afresh.
+        unpack_clear(&self.ctx)
+        unpack_init(&self.ctx)
+        # PyMem_Free(NULL) is a no-op, so no NULL guard is needed
+        # (__dealloc__ already calls it unguarded).
+        PyMem_Free(self.buf)
+        self.buf = NULL
+        # NOTE(review): if __init__ raises before a new buffer is allocated,
+        # self.buf stays NULL -- confirm the buffer size/offset fields cannot
+        # be left stale in that case.
+
         self.object_hook = object_hook
         self.object_pairs_hook = object_pairs_hook
         self.list_hook = list_hook
diff --git a/test/test_unpack.py b/test/test_unpack.py
index b17c3c53..705c16a6 100644
--- a/test/test_unpack.py
+++ b/test/test_unpack.py
@@ -1,4 +1,6 @@
+import gc
 import sys
+import weakref
 from io import BytesIO
 
 from pytest import mark, raises
@@ -87,3 +89,43 @@ def test_unpacker_tell_read_bytes():
         assert obj == unp
         assert pos == unpacker.tell()
         assert unpacker.read_bytes(n) == raw
+
+
+@mark.skipif(
+    Unpacker.__module__ == "msgpack.fallback",
+    reason="specific to C extension reinit leak",
+)
+def test_unpacker_reinit_clears_partial_state():
+    """Re-running ``__init__`` must drop half-parsed state and buffered input."""
+    refs = []
+
+    class Marker:
+        pass
+
+    def hook(code, data):
+        # Hand out a fresh object per ext value and remember it only weakly,
+        # so the test can observe when the unpacker lets go of it.
+        obj = Marker()
+        refs.append(weakref.ref(obj))
+        return obj
+
+    unpacker = Unpacker(ext_hook=hook, strict_map_key=False)
+    # Keep parser state mid-map with a live key object from ext_hook.
+    # Truncated input: [ {ExtType(1, b"a"): <value missing>} ].
+    unpacker.feed(b"\x91\x81\xd4\x01a")
+    with raises(OutOfData):
+        unpacker.unpack()
+    assert len(refs) == 1
+    assert refs[0]() is not None
+
+    unpacker.__init__()
+    gc.collect()
+    # The half-built map, and with it the key, must have been released.
+    assert refs[0]() is None
+    # Nothing is buffered any more, so there is no object to unpack.
+    with raises(OutOfData):
+        unpacker.unpack()
+
+    # The re-initialised unpacker is fully usable again.
+    unpacker.feed(packb({"a": 1}))
+    assert unpacker.unpack() == {"a": 1}