Skip to content

Commit d775279

Browse files
committed
support for complex numbers
for more information, see https://pre-commit.ci
1 parent d9cee77 commit d775279

File tree

10 files changed

+140
-0
lines changed

10 files changed

+140
-0
lines changed

cbor2/_decoder.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -623,6 +623,21 @@ def decode_sharedref(self) -> Any:
623623
else:
624624
return shared
625625

626+
def decode_complex(self) -> complex:
627+
# Semantic tag 43000
628+
inputval = self._decode(immutable=True, unshared=True)
629+
try:
630+
value = complex(*inputval)
631+
except TypeError as exc:
632+
if not isinstance(inputval, tuple):
633+
raise CBORDecodeValueError(
634+
"error decoding complex: input value was not a tuple"
635+
) from None
636+
637+
raise CBORDecodeValueError("error decoding complex") from exc
638+
639+
return self.set_shareable(value)
640+
626641
def decode_rational(self) -> Fraction:
627642
# Semantic tag 30
628643
from fractions import Fraction
@@ -780,6 +795,7 @@ def decode_float64(self) -> float:
780795
260: CBORDecoder.decode_ipaddress,
781796
261: CBORDecoder.decode_ipnetwork,
782797
1004: CBORDecoder.decode_date_string,
798+
43000: CBORDecoder.decode_complex,
783799
55799: CBORDecoder.decode_self_describe_cbor,
784800
}
785801

cbor2/_encoder.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -614,6 +614,11 @@ def encode_float(self, value: float) -> None:
614614
else:
615615
self._fp_write(struct.pack(">Bd", 0xFB, value))
616616

617+
def encode_complex(self, value: complex) -> None:
618+
# Semantic tag 43000
619+
with self.disable_value_sharing():
620+
self.encode_semantic(CBORTag(43000, [value.real, value.imag]))
621+
617622
def encode_minimal_float(self, value: float) -> None:
618623
# Handle special values efficiently
619624
if math.isnan(value):
@@ -652,6 +657,7 @@ def encode_undefined(self, value: UndefinedType) -> None:
652657
str: CBOREncoder.encode_string,
653658
int: CBOREncoder.encode_int,
654659
float: CBOREncoder.encode_float,
660+
complex: CBOREncoder.encode_complex,
655661
("decimal", "Decimal"): CBOREncoder.encode_decimal,
656662
bool: CBOREncoder.encode_boolean,
657663
type(None): CBOREncoder.encode_none,

docs/usage.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,7 @@ Tag Semantics Python type(s)
9999
258 Set of unique items set
100100
260 Network address :class:`ipaddress.IPv4Address` (or IPv6)
101101
261 Network prefix :class:`ipaddress.IPv4Network` (or IPv6)
102+
43000 Single complex number complex
102103
55799 Self-Described CBOR object
103104
===== ======================================== ====================================================
104105

docs/versionhistory.rst

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,8 @@ This library adheres to `Semantic Versioning <https://semver.org/>`_.
99

1010
- Dropped support for Python 3.8
1111
(#247 <https://github.com/agronholm/cbor2/pull/247>_; PR by @hugovk)
12+
- Added complex number support (tag 43000)
13+
(#249 <https://github.com/agronholm/cbor2/pull/249>_; PR by @chillenb)
1214

1315
**5.6.5** (2024-10-09)
1416

source/decoder.c

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,7 @@ static PyObject * CBORDecoder_decode_epoch_date(CBORDecoderObject *);
5555
static PyObject * CBORDecoder_decode_date_string(CBORDecoderObject *);
5656
static PyObject * CBORDecoder_decode_fraction(CBORDecoderObject *);
5757
static PyObject * CBORDecoder_decode_bigfloat(CBORDecoderObject *);
58+
static PyObject * CBORDecoder_decode_complex(CBORDecoderObject *);
5859
static PyObject * CBORDecoder_decode_rational(CBORDecoderObject *);
5960
static PyObject * CBORDecoder_decode_regexp(CBORDecoderObject *);
6061
static PyObject * CBORDecoder_decode_uuid(CBORDecoderObject *);
@@ -1172,6 +1173,7 @@ decode_semantic(CBORDecoderObject *self, uint8_t subtype)
11721173
case 260: ret = CBORDecoder_decode_ipaddress(self); break;
11731174
case 261: ret = CBORDecoder_decode_ipnetwork(self); break;
11741175
case 1004: ret = CBORDecoder_decode_date_string(self); break;
1176+
case 43000: ret = CBORDecoder_decode_complex(self); break;
11751177
case 55799: ret = CBORDecoder_decode_self_describe_cbor(self);
11761178
break;
11771179

@@ -1636,6 +1638,34 @@ CBORDecoder_decode_sharedref(CBORDecoderObject *self)
16361638
return ret;
16371639
}
16381640

1641+
// CBORDecoder.decode_complex(self)
1642+
static PyObject *
1643+
CBORDecoder_decode_complex(CBORDecoderObject *self)
1644+
{
1645+
// semantic type 43000
1646+
PyObject *real, *imag, *ret = NULL;
1647+
payload_t = decode(self, DECODE_IMMUTABLE | DECODE_UNSHARED);
1648+
if (payload_t) {
1649+
if (PyTuple_CheckExact(payload_t) && PyTuple_GET_SIZE(payload_t) == 2) {
1650+
real = PyTuple_GET_ITEM(payload_t, 0);
1651+
imag = PyTuple_GET_ITEM(payload_t, 1);
1652+
f(PyFloat_CheckExact(real) && PyFloat_CheckExact(imag)) {
1653+
ret = PyComplex_FromDoubles(PyFloat_AS_DOUBLE(real), PyFloat_AS_DOUBLE(imag));
1654+
} else {
1655+
PyErr_Format(
1656+
_CBOR2_CBORDecodeValueError,
1657+
"Incorrect tag 43000 payload: does not contain two floats");
1658+
}
1659+
} else {
1660+
PyErr_Format(
1661+
_CBOR2_CBORDecodeValueError,
1662+
"Incorrect tag 43000 payload: not an array of length 2");
1663+
}
1664+
Py_DECREF(payload_t);
1665+
}
1666+
set_shareable(self, ret);
1667+
return ret;
1668+
}
16391669

16401670
// CBORDecoder.decode_rational(self)
16411671
static PyObject *
@@ -2159,6 +2189,8 @@ static PyMethodDef CBORDecoder_methods[] = {
21592189
"decode a fractional number from the input"},
21602190
{"decode_rational", (PyCFunction) CBORDecoder_decode_rational, METH_NOARGS,
21612191
"decode a rational value from the input"},
2192+
{"decode_complex", (PyCFunction) CBORDecoder_decode_complex, METH_NOARGS,
2193+
"decode a complex value from the input"},
21622194
{"decode_bigfloat", (PyCFunction) CBORDecoder_decode_bigfloat, METH_NOARGS,
21632195
"decode a large floating-point value from the input"},
21642196
{"decode_regexp", (PyCFunction) CBORDecoder_decode_regexp, METH_NOARGS,

source/encoder.c

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1362,6 +1362,36 @@ CBOREncoder_encode_rational(CBOREncoderObject *self, PyObject *value)
13621362
return ret;
13631363
}
13641364

1365+
// CBOREncoder.encode_complex(self, value)
1366+
static PyObject *
1367+
CBOREncoder_encode_complex(CBOREncoderObject *self, PyObject *value)
1368+
{
1369+
// semantic type 43000
1370+
PyObject *tuple, *real, *imag, *ret = NULL;
1371+
bool sharing;
1372+
1373+
real = PyObject_GetAttr(value, _CBOR2_str_real);
1374+
if (real) {
1375+
imag = PyObject_GetAttr(value, _CBOR2_str_imag);
1376+
if (imag) {
1377+
tuple = PyTuple_Pack(2, real, imag);
1378+
if (tuple) {
1379+
sharing = self->value_sharing;
1380+
self->value_sharing = false;
1381+
if (encode_semantic(self, 43000, tuple) == 0) {
1382+
Py_INCREF(Py_None);
1383+
ret = Py_None;
1384+
}
1385+
self->value_sharing = sharing;
1386+
Py_DECREF(tuple);
1387+
}
1388+
Py_DECREF(imag);
1389+
}
1390+
Py_DECREF(real);
1391+
}
1392+
return ret;
1393+
}
1394+
13651395

13661396
// CBOREncoder.encode_regexp(self, value)
13671397
static PyObject *
@@ -2118,6 +2148,8 @@ static PyMethodDef CBOREncoder_methods[] = {
21182148
"encode the specified integer *value* to the output"},
21192149
{"encode_float", (PyCFunction) CBOREncoder_encode_float, METH_O,
21202150
"encode the specified floating-point *value* to the output"},
2151+
{"encode_complex", (PyCFunction) CBOREncoder_encode_complex, METH_O,
2152+
"encode the specified complex *value* to the output"},
21212153
{"encode_boolean", (PyCFunction) CBOREncoder_encode_boolean, METH_O,
21222154
"encode the specified boolean *value* to the output"},
21232155
{"encode_none", (PyCFunction) CBOREncoder_encode_none, METH_O,

source/module.c

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -621,6 +621,7 @@ PyObject *_CBOR2_str_FrozenDict = NULL;
621621
PyObject *_CBOR2_str_fromordinal = NULL;
622622
PyObject *_CBOR2_str_getvalue = NULL;
623623
PyObject *_CBOR2_str_groups = NULL;
624+
PyObject *_CBOR2_str_imag = NULL;
624625
PyObject *_CBOR2_str_ip_address = NULL;
625626
PyObject *_CBOR2_str_ip_network = NULL;
626627
PyObject *_CBOR2_str_is_infinite = NULL;
@@ -637,6 +638,7 @@ PyObject *_CBOR2_str_parsestr = NULL;
637638
PyObject *_CBOR2_str_pattern = NULL;
638639
PyObject *_CBOR2_str_prefixlen = NULL;
639640
PyObject *_CBOR2_str_read = NULL;
641+
PyObject *_CBOR2_str_real = NULL;
640642
PyObject *_CBOR2_str_s = NULL;
641643
PyObject *_CBOR2_str_timestamp = NULL;
642644
PyObject *_CBOR2_str_toordinal = NULL;
@@ -955,6 +957,7 @@ PyInit__cbor2(void)
955957
INTERN_STRING(fromordinal);
956958
INTERN_STRING(getvalue);
957959
INTERN_STRING(groups);
960+
INTERN_STRING(imag);
958961
INTERN_STRING(ip_address);
959962
INTERN_STRING(ip_network);
960963
INTERN_STRING(is_infinite);
@@ -971,6 +974,7 @@ PyInit__cbor2(void)
971974
INTERN_STRING(pattern);
972975
INTERN_STRING(prefixlen);
973976
INTERN_STRING(read);
977+
INTERN_STRING(real);
974978
INTERN_STRING(s);
975979
INTERN_STRING(timestamp);
976980
INTERN_STRING(toordinal);

source/module.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@ extern PyObject *_CBOR2_str_FrozenDict;
5353
extern PyObject *_CBOR2_str_fromordinal;
5454
extern PyObject *_CBOR2_str_getvalue;
5555
extern PyObject *_CBOR2_str_groups;
56+
extern PyObject *_CBOR2_str_imag;
5657
extern PyObject *_CBOR2_str_ip_address;
5758
extern PyObject *_CBOR2_str_ip_network;
5859
extern PyObject *_CBOR2_str_is_infinite;
@@ -69,6 +70,7 @@ extern PyObject *_CBOR2_str_parsestr;
6970
extern PyObject *_CBOR2_str_pattern;
7071
extern PyObject *_CBOR2_str_prefixlen;
7172
extern PyObject *_CBOR2_str_read;
73+
extern PyObject *_CBOR2_str_real;
7274
extern PyObject *_CBOR2_str_s;
7375
extern PyObject *_CBOR2_str_timestamp;
7476
extern PyObject *_CBOR2_str_toordinal;

tests/test_decoder.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -539,6 +539,35 @@ def test_bigfloat(impl):
539539
assert decoded == Decimal("1.5")
540540

541541

542+
@pytest.mark.parametrize(
543+
"payload, expected",
544+
[
545+
("d9a7f882f90000f90000", 0.0j),
546+
("d9a7f882fb0000000000000000fb0000000000000000", 0.0j),
547+
("d9a7f882f98000f98000", -0.0j),
548+
("d9a7f882f90000f93c00", 1.0j),
549+
("d9a7f882fb0000000000000000fb3ff199999999999a", 1.1j),
550+
("d9a7f882f93e00f93e00", 1.5 + 1.5j),
551+
("d9a7f882f97bfff97bff", 65504.0 + 65504.0j),
552+
("d9a7f882fa47c35000fa47c35000", 100000.0 + 100000.0j),
553+
("fa7f7fffff", 3.4028234663852886e38),
554+
("d9a7f882f90000fb7e37e43c8800759c", 1.0e300j),
555+
("d9a7f882f90000f90001", 5.960464477539063e-8j),
556+
("d9a7f882f90000f90400", 0.00006103515625j),
557+
("d9a7f882f90000f9c400", -4.0j),
558+
("d9a7f882f90000fbc010666666666666", -4.1j),
559+
("d9a7f882f90000f97c00", complex(0.0, float("inf"))),
560+
("d9a7f882f97c00f90000", complex(float("inf"), 0.0)),
561+
("d9a7f882f90000f9fc00", complex(0.0, float("-inf"))),
562+
("d9a7f882f90000fa7f800000", complex(0.0, float("inf"))),
563+
("d9a7f882f90000faff800000", complex(0.0, float("-inf"))),
564+
],
565+
)
566+
def test_complex(impl, payload, expected):
567+
decoded = impl.loads(unhexlify(payload))
568+
assert decoded == expected
569+
570+
542571
def test_rational(impl):
543572
decoded = impl.loads(unhexlify("d81e820205"))
544573
assert decoded == Fraction(2, 5)

tests/test_encoder.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -324,6 +324,22 @@ def test_decimal(impl, value, expected):
324324
assert impl.dumps(value) == expected
325325

326326

327+
@pytest.mark.parametrize(
328+
"value, expected",
329+
[
330+
(3.1 + 2.1j, "d9a7f882fb4008cccccccccccdfb4000cccccccccccd"),
331+
(1.0e300j, "d9a7f882fb0000000000000000fb7e37e43c8800759c"),
332+
(0.0j, "d9a7f882fb0000000000000000fb0000000000000000"),
333+
(complex(float("inf"), float("inf")), "d9a7f882f97c00f97c00"),
334+
(complex(float("inf"), 0.0), "d9a7f882f97c00fb0000000000000000"),
335+
(complex(float("nan"), float("inf")), "d9a7f882f97e00f97c00"),
336+
],
337+
)
338+
def test_complex(impl, value, expected):
339+
expected = unhexlify(expected)
340+
assert impl.dumps(value) == expected
341+
342+
327343
def test_rational(impl):
328344
expected = unhexlify("d81e820205")
329345
assert impl.dumps(Fraction(2, 5)) == expected

0 commit comments

Comments
 (0)