Skip to content

Commit 8983373

Browse files
committed
implement __deepcopy__
1 parent 3244bb9 commit 8983373

File tree

2 files changed

+134
-1
lines changed

2 files changed

+134
-1
lines changed

src/msgspec/_core.c

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -519,6 +519,7 @@ typedef struct {
519519
#endif
520520
PyObject *astimezone;
521521
PyObject *re_compile;
522+
PyObject *copy_deepcopy;
522523
uint8_t gc_cycle;
523524
} MsgspecState;
524525

@@ -7938,6 +7939,66 @@ Struct_copy(PyObject *self, PyObject *args)
79387939
return NULL;
79397940
}
79407941

7942+
7943+
static PyObject* get_deepcopy_func() {
7944+
// lazily copy.deepcopy and cache in global state
7945+
PyObject *copy_mod, *deepcopy_func;
7946+
MsgspecState* st = msgspec_get_global_state();
7947+
deepcopy_func = st->copy_deepcopy;
7948+
if (deepcopy_func == NULL) {
7949+
copy_mod = PyImport_ImportModule("copy");
7950+
if (copy_mod == NULL) return NULL;
7951+
deepcopy_func = PyObject_GetAttrString(copy_mod, "deepcopy");
7952+
st->copy_deepcopy = deepcopy_func;
7953+
Py_DECREF(copy_mod);
7954+
if (st->copy_deepcopy == NULL) return NULL;
7955+
}
7956+
7957+
return deepcopy_func;
7958+
}
7959+
7960+
static PyObject *
7961+
Struct_deepcopy(PyObject *self, PyObject *args)
7962+
{
7963+
PyObject *memo;
7964+
PyObject *val = NULL, *res = NULL, *dc_val = NULL;
7965+
PyObject *deepcopy_func;
7966+
Py_ssize_t i, nfields;
7967+
7968+
if (!PyArg_ParseTuple(args, "O!:__deepcopy__", &PyDict_Type, &memo))
7969+
return NULL;
7970+
7971+
deepcopy_func = get_deepcopy_func();
7972+
7973+
res = Struct_alloc(Py_TYPE(self));
7974+
if (res == NULL)
7975+
return NULL;
7976+
7977+
nfields = StructMeta_GET_NFIELDS(Py_TYPE(self));
7978+
for (i = 0; i < nfields; i++) {
7979+
val = Struct_get_index(self, i);
7980+
if (val == NULL)
7981+
goto error;
7982+
7983+
dc_val = PyObject_CallFunctionObjArgs(deepcopy_func, val, memo, NULL);
7984+
if (dc_val == NULL)
7985+
goto error;
7986+
7987+
Struct_set_index(res, i, dc_val);
7988+
}
7989+
7990+
/* If self is tracked, then copy is tracked */
7991+
if (MS_OBJECT_IS_GC(self) && MS_IS_TRACKED(self))
7992+
PyObject_GC_Track(res);
7993+
7994+
return res;
7995+
7996+
error:
7997+
Py_DECREF(res);
7998+
return NULL;
7999+
}
8000+
8001+
79418002
static PyObject *
79428003
Struct_replace(
79438004
PyObject *self,
@@ -8360,6 +8421,7 @@ StructMixin_config(StructMetaObject *self, void *closure) {
83608421

83618422
static PyMethodDef Struct_methods[] = {
83628423
{"__copy__", Struct_copy, METH_NOARGS, "copy a struct"},
8424+
{"__deepcopy__", Struct_deepcopy, METH_VARARGS, "deepcopy a struct"},
83638425
{"__replace__", (PyCFunction) Struct_replace, METH_FASTCALL | METH_KEYWORDS, "create a new struct with replacements" },
83648426
{"__reduce__", Struct_reduce, METH_NOARGS, "reduce a struct"},
83658427
{"__rich_repr__", Struct_rich_repr, METH_NOARGS, "rich repr"},
@@ -22308,6 +22370,7 @@ msgspec_clear(PyObject *m)
2230822370
#endif
2230922371
Py_CLEAR(st->astimezone);
2231022372
Py_CLEAR(st->re_compile);
22373+
Py_CLEAR(st->copy_deepcopy);
2231122374
return 0;
2231222375
}
2231322376

@@ -22382,6 +22445,7 @@ msgspec_traverse(PyObject *m, visitproc visit, void *arg)
2238222445
#endif
2238322446
Py_VISIT(st->astimezone);
2238422447
Py_VISIT(st->re_compile);
22448+
Py_VISIT(st->copy_deepcopy);
2238522449
return 0;
2238622450
}
2238722451

@@ -22676,6 +22740,8 @@ PyInit__core(void)
2267622740
Py_DECREF(temp_module);
2267722741
if (st->re_compile == NULL) return NULL;
2267822742

22743+
st->copy_deepcopy = NULL;
22744+
2267922745
/* Initialize cached constant strings */
2268022746
#define CACHED_STRING(attr, str) \
2268122747
if ((st->attr = PyUnicode_InternFromString(str)) == NULL) return NULL

tests/unit/test_struct.py

Lines changed: 68 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -809,12 +809,65 @@ class Test(Struct):
809809
b: int
810810
a: int
811811

812-
x = copy.copy(Test(1, 2))
812+
o = Test(1, 2)
813+
x = copy.copy(o)
813814
assert type(x) is Test
815+
assert x is not o
814816
assert x.b == 1
815817
assert x.a == 2
816818

817819

820+
def test_struct_deepcopy():
821+
o = Struct()
822+
x = copy.deepcopy(Struct())
823+
assert type(x) is Struct
824+
assert x is not o
825+
826+
class Sub(Struct):
827+
one: str
828+
two: list[int]
829+
830+
class Test(Struct):
831+
a: int
832+
b: int
833+
c: list[str]
834+
sub: Sub
835+
836+
o = Test(
837+
a=1,
838+
b=2,
839+
c=["1", "2"],
840+
sub=Sub(one="hello", two=[3]),
841+
)
842+
x = copy.deepcopy(o)
843+
assert type(x) is Test
844+
assert x.a == 1
845+
assert x.b == 2
846+
assert x.c == ["1", "2"]
847+
assert x.c is not o.c
848+
assert x.sub is not o.sub
849+
assert x.sub.one == "hello"
850+
assert x.sub.two == [3]
851+
assert x.sub.two is not o.sub.two
852+
853+
854+
def test_struct_deepcopy_custom_impl():
855+
# ensure we respect custom __deepcopy__ methods
856+
class CustomThing:
857+
def __init__(self, value):
858+
self.value = value
859+
860+
def __deepcopy__(self, memo):
861+
return CustomThing(value=self.value + 1)
862+
863+
class TestWithCustom(Struct):
864+
custom: CustomThing
865+
866+
t = TestWithCustom(CustomThing(1))
867+
tc = copy.deepcopy(t)
868+
assert tc.custom.value == 2
869+
870+
818871
class FrozenPoint(Struct, frozen=True):
819872
x: int
820873
y: int
@@ -2663,6 +2716,20 @@ def __post_init__(self):
26632716
assert x1 == x2
26642717
assert count == 1
26652718

2719+
def test_post_init_not_called_on_deepcopy(self):
2720+
count = 0
2721+
2722+
class Ex(Struct):
2723+
def __post_init__(self):
2724+
nonlocal count
2725+
count += 1
2726+
2727+
x1 = Ex()
2728+
assert count == 1
2729+
x2 = copy.deepcopy(x1)
2730+
assert x1 == x2
2731+
assert count == 1
2732+
26662733
def test_post_init_called_on_replace(self, replace):
26672734
count = 0
26682735

0 commit comments

Comments
 (0)