diff --git a/noxfile.py b/noxfile.py index 3830ac54..d6b52ce7 100644 --- a/noxfile.py +++ b/noxfile.py @@ -9,6 +9,10 @@ nox.options.sessions = ["lint", "test"] +ALL_CPYTHON = [f"3.{minor}" for minor in range(6, 12 + 1)] +ALL_PYPY = [f"pypy3.{minor}" for minor in range(8, 10 + 1)] +ALL_PYTHON = ALL_CPYTHON + ALL_PYPY + @nox.session def lint(session: nox.Session) -> None: @@ -46,7 +50,7 @@ def remove_extension(session: nox.Session, in_place: bool = False) -> None: assert removed -@nox.session(python=["3.6", "3.7", "3.8", "3.9", "3.10"]) +@nox.session(python=ALL_PYTHON) def test(session: nox.Session) -> None: """Run tests.""" session.install("-r", "requirements-test.txt") @@ -54,11 +58,11 @@ def test(session: nox.Session) -> None: env = {"CIBUILDWHEEL": "1"} update_env_macos(session, env) session.install(".", env=env) - session.run("pytest", env=env) + session.run("pytest", *session.posargs, env=env) # run without extension as well env.pop("CIBUILDWHEEL") remove_extension(session) - session.run("pytest", env=env) + session.run("pytest", *session.posargs, env=env) @nox.session(python=["3.8", "3.11"]) diff --git a/src/pybase64/_fallback.py b/src/pybase64/_fallback.py index c30aa6b1..a0d1ef39 100644 --- a/src/pybase64/_fallback.py +++ b/src/pybase64/_fallback.py @@ -35,7 +35,10 @@ def _get_bytes(s: Any) -> Union[bytes, bytearray]: if isinstance(s, _bytes_types): return s try: - return memoryview(s).tobytes() + mv = memoryview(s) + if not mv.c_contiguous: + raise BufferError("memoryview: underlying buffer is not C-contiguous") + return mv.tobytes() except TypeError: raise TypeError( "argument should be a bytes-like object or ASCII " @@ -63,24 +66,25 @@ def b64decode(s: Any, altchars: Any = None, validate: bool = False) -> bytes: A :exc:`binascii.Error` is raised if ``s`` is incorrectly padded. """ + s = _get_bytes(s) + if altchars is not None: + altchars = _get_bytes(altchars) if validate: if len(s) % 4 != 0: raise BinAsciiError("Incorrect padding") - s = _get_bytes(s) - if altchars is not None: - altchars = _get_bytes(altchars) - assert len(altchars) == 2, repr(altchars) - map = bytes.maketrans(altchars, b"+/") - s = s.translate(map) result = builtin_decode(s, altchars, validate=False) # check length of result vs length of input - padding = 0 - if len(s) > 1 and s[-2] in (b"=", 61): - padding = padding + 1 - if len(s) > 0 and s[-1] in (b"=", 61): - padding = padding + 1 - if 3 * (len(s) / 4) - padding != len(result): + expected_len = 0 + if len(s) > 0: + padding = 0 + # len(s) % 4 != 0 implies len(s) >= 4 here + if s[-2] == 61: # 61 == ord("=") + padding += 1 + if s[-1] == 61: + padding += 1 + expected_len = 3 * (len(s) // 4) - padding + if expected_len != len(result): raise BinAsciiError("Non-base64 digit found") return result return builtin_decode(s, altchars, validate=False) @@ -122,9 +126,11 @@ def b64encode(s: Any, altchars: Any = None) -> bytes: The result is returned as a :class:`bytes` object. """ + mv = memoryview(s) + if not mv.c_contiguous: + raise BufferError("memoryview: underlying buffer is not C-contiguous") if altchars is not None: altchars = _get_bytes(altchars) - assert len(altchars) == 2, repr(altchars) return builtin_encode(s, altchars) @@ -151,4 +157,7 @@ def encodebytes(s: Any) -> bytes: The result is returned as a :class:`bytes` object. """ + mv = memoryview(s) + if not mv.c_contiguous: + raise BufferError("memoryview: underlying buffer is not C-contiguous") return builtin_encodebytes(s) diff --git a/src/pybase64/_pybase64.c b/src/pybase64/_pybase64.c index 63865e95..76bebe3d 100644 --- a/src/pybase64/_pybase64.c +++ b/src/pybase64/_pybase64.c @@ -24,6 +24,25 @@ static int libbase64_simd_flag = 0; static uint32_t active_simd_flag = 0U; static uint32_t simd_flags; + +/* returns 0 on success */ +static int get_buffer(PyObject* object, Py_buffer* buffer) +{ + if (PyObject_GetBuffer(object, buffer, PyBUF_RECORDS_RO | PyBUF_C_CONTIGUOUS) != 0) { + return -1; + } +#if defined(PYPY_VERSION) + /* PyPy does not respect PyBUF_C_CONTIGUOUS */ + if (!PyBuffer_IsContiguous(buffer, 'C')) { + PyBuffer_Release(buffer); + PyErr_Format(PyExc_BufferError, "%R: underlying buffer is not C-contiguous", Py_TYPE(object)); + return -1; + } +#endif + return 0; +} + + /* returns 0 on success */ static int parse_alphabet(PyObject* alphabetObject, char* alphabet, int* useAlphabet) { @@ -49,7 +68,7 @@ static int parse_alphabet(PyObject* alphabetObject, char* alphabet, int* useAlph Py_INCREF(alphabetObject); } - if (PyObject_GetBuffer(alphabetObject, &buffer, PyBUF_SIMPLE) < 0) { + if (get_buffer(alphabetObject, &buffer) != 0) { Py_DECREF(alphabetObject); return -1; } @@ -314,7 +333,7 @@ static PyObject* pybase64_encode_impl(PyObject* self, PyObject* args, PyObject * return NULL; } - if (PyObject_GetBuffer(in_object, &buffer, PyBUF_SIMPLE) < 0) { + if (get_buffer(in_object, &buffer) != 0) { return NULL; } @@ -434,7 +453,7 @@ static PyObject* pybase64_decode_impl(PyObject* self, PyObject* args, PyObject * Py_INCREF(in_object); } if (source == NULL) { - if (PyObject_GetBuffer(in_object, &buffer, PyBUF_SIMPLE) < 0) { + if (get_buffer(in_object, &buffer) != 0) { Py_DECREF(in_object); return NULL; } @@ -467,7 +486,7 @@ static PyObject* pybase64_decode_impl(PyObject* self, PyObject* args, PyObject * Py_DECREF(in_object); } in_object = translate_object; - if (PyObject_GetBuffer(in_object, &buffer, PyBUF_SIMPLE) < 0) { + if (get_buffer(in_object, &buffer) != 0) { Py_DECREF(in_object); return NULL; } @@ -605,10 +624,17 @@ static PyObject* pybase64_encodebytes(PyObject* self, PyObject* in_object) size_t out_len; PyObject* out_object; - if (PyObject_GetBuffer(in_object, &buffer, PyBUF_SIMPLE) < 0) { + if (get_buffer(in_object, &buffer) != 0) { return NULL; } - + if (((buffer.format[0] != 'c') && (buffer.format[0] != 'b') && (buffer.format[0] != 'B')) || buffer.format[1] != '\0' ) { + PyBuffer_Release(&buffer); + return PyErr_Format(PyExc_TypeError, "expected single byte elements, not '%s' from %R", buffer.format, Py_TYPE(in_object)); + } + if (buffer.ndim != 1) { + PyBuffer_Release(&buffer); + return PyErr_Format(PyExc_TypeError, "expected 1-D data, not %d-D data from %R", buffer.ndim, Py_TYPE(in_object)); + } if (buffer.len > (3 * (PY_SSIZE_T_MAX / 4))) { PyBuffer_Release(&buffer); return PyErr_NoMemory(); diff --git a/tests/test_pybase64.py b/tests/test_pybase64.py index 2d817891..13b14e32 100644 --- a/tests/test_pybase64.py +++ b/tests/test_pybase64.py @@ -24,6 +24,24 @@ _has_extension = False +def b64encode_as_string(s, altchars=None): + """helper returning bytes instead of string for tests""" + return pybase64.b64encode_as_string(s, altchars).encode("ascii") + + +def b64decode_as_bytearray(s, altchars=None, validate=False): + """helper returning bytes instead of bytearray for tests""" + return bytes(pybase64.b64decode_as_bytearray(s, altchars, validate)) + + +param_encode_functions = pytest.mark.parametrize( + "efn", [pybase64.b64encode, b64encode_as_string] +) +param_decode_functions = pytest.mark.parametrize( + "dfn", [pybase64.b64decode, b64decode_as_bytearray] +) + + STD = 0 URL = 1 ALT1 = 2 @@ -153,25 +171,7 @@ def simd_setup(simd_id): ) -param_encode_functions = pytest.mark.parametrize( - "efn, ecast", - [ - (pybase64.b64encode, lambda x: x), - (pybase64.b64encode_as_string, lambda x: x.encode("ascii")), - ], -) - - -param_decode_functions = pytest.mark.parametrize( - "dfn, dcast", - [ - (pybase64.b64decode, lambda x: x), - (pybase64.b64decode_as_bytearray, lambda x: bytes(x)), - ], -) - - -@pytest.fixture +@pytest.fixture() def simd(request): simd_setup(request.param) return request.param @@ -245,10 +245,10 @@ def test_encbytes(vector_id, simd): @param_vector @param_altchars @param_encode_functions -def test_enc(efn, ecast, altchars_id, vector_id, simd): +def test_enc(efn, altchars_id, vector_id, simd): vector = test_vectors_bin[altchars_id][vector_id] altchars = altchars_lut[altchars_id] - test = ecast(efn(vector, altchars)) + test = efn(vector, altchars) base = base64.b64encode(vector, altchars) assert test == base @@ -258,14 +258,14 @@ def test_enc(efn, ecast, altchars_id, vector_id, simd): @param_altchars @param_validate @param_decode_functions -def test_dec(dfn, dcast, altchars_id, vector_id, validate, simd): +def test_dec(dfn, altchars_id, vector_id, validate, simd): vector = test_vectors_b64[altchars_id][vector_id] altchars = altchars_lut[altchars_id] if validate: base = base64.b64decode(vector, altchars, validate) else: base = base64.b64decode(vector, altchars) - test = dcast(dfn(vector, altchars, validate)) + test = dfn(vector, altchars, validate) assert test == base @@ -274,7 +274,7 @@ def test_dec(dfn, dcast, altchars_id, vector_id, validate, simd): @param_altchars @param_validate @param_decode_functions -def test_dec_unicode(dfn, dcast, altchars_id, vector_id, validate, simd): +def test_dec_unicode(dfn, altchars_id, vector_id, validate, simd): vector = test_vectors_b64[altchars_id][vector_id] vector = str(vector, "utf-8") altchars = altchars_lut[altchars_id] @@ -286,7 +286,7 @@ def test_dec_unicode(dfn, dcast, altchars_id, vector_id, validate, simd): base = base64.b64decode(vector, altchars, validate) else: base = base64.b64decode(vector, altchars) - test = dcast(dfn(vector, altchars, validate)) + test = dfn(vector, altchars, validate) assert test == base @@ -296,11 +296,11 @@ def test_dec_unicode(dfn, dcast, altchars_id, vector_id, validate, simd): @param_validate @param_encode_functions @param_decode_functions -def test_rnd(dfn, dcast, efn, ecast, altchars_id, vector_id, validate, simd): +def test_rnd(dfn, efn, altchars_id, vector_id, validate, simd): vector = test_vectors_b64[altchars_id][vector_id] altchars = altchars_lut[altchars_id] - test = dcast(dfn(vector, altchars, validate)) - test = ecast(efn(test, altchars)) + test = dfn(vector, altchars, validate) + test = efn(test, altchars) assert test == vector @@ -310,11 +310,11 @@ def test_rnd(dfn, dcast, efn, ecast, altchars_id, vector_id, validate, simd): @param_validate @param_encode_functions @param_decode_functions -def test_rnd_unicode(dfn, dcast, efn, ecast, altchars_id, vector_id, validate, simd): +def test_rnd_unicode(dfn, efn, altchars_id, vector_id, validate, simd): vector = test_vectors_b64[altchars_id][vector_id] altchars = altchars_lut[altchars_id] - test = dcast(dfn(str(vector, "utf-8"), altchars, validate)) - test = ecast(efn(test, altchars)) + test = dfn(str(vector, "utf-8"), altchars, validate) + test = efn(test, altchars) assert test == vector @@ -323,7 +323,7 @@ def test_rnd_unicode(dfn, dcast, efn, ecast, altchars_id, vector_id, validate, s @param_altchars @param_validate @param_decode_functions -def test_invalid_padding_dec(dfn, dcast, altchars_id, vector_id, validate, simd): +def test_invalid_padding_dec(dfn, altchars_id, vector_id, validate, simd): vector = test_vectors_b64[altchars_id][vector_id][1:] if len(vector) > 0: altchars = altchars_lut[altchars_id] @@ -337,6 +337,7 @@ def test_invalid_padding_dec(dfn, dcast, altchars_id, vector_id, validate, simd) [b"-__", AssertionError], [3.0, TypeError], ["-€", ValueError], + [memoryview(b"- _")[::2], BufferError], ] params_invalid_altchars = pytest.mark.parametrize( "altchars,exception", @@ -348,7 +349,7 @@ def test_invalid_padding_dec(dfn, dcast, altchars_id, vector_id, validate, simd) @param_simd @params_invalid_altchars @param_encode_functions -def test_invalid_altchars_enc(efn, ecast, altchars, exception, simd): +def test_invalid_altchars_enc(efn, altchars, exception, simd): with pytest.raises(exception): efn(b"ABCD", altchars) @@ -356,7 +357,7 @@ def test_invalid_altchars_enc(efn, ecast, altchars, exception, simd): @param_simd @params_invalid_altchars @param_decode_functions -def test_invalid_altchars_dec(dfn, dcast, altchars, exception, simd): +def test_invalid_altchars_dec(dfn, altchars, exception, simd): with pytest.raises(exception): dfn(b"ABCD", altchars) @@ -364,7 +365,7 @@ def test_invalid_altchars_dec(dfn, dcast, altchars, exception, simd): @param_simd @params_invalid_altchars @param_decode_functions -def test_invalid_altchars_dec_validate(dfn, dcast, altchars, exception, simd): +def test_invalid_altchars_dec_validate(dfn, altchars, exception, simd): with pytest.raises(exception): dfn(b"ABCD", altchars, True) @@ -373,6 +374,7 @@ def test_invalid_altchars_dec_validate(dfn, dcast, altchars, exception, simd): [b"A@@@@FG", None, BinAsciiError], ["ABC€", None, ValueError], [3.0, None, TypeError], + [memoryview(b"ABCDEFGH")[::2], None, BufferError], ] params_invalid_data_validate = [ [b"\x00\x00\x00\x00", None, BinAsciiError], @@ -409,7 +411,7 @@ def test_invalid_altchars_dec_validate(dfn, dcast, altchars, exception, simd): @param_simd @params_invalid_data_novalidate @param_decode_functions -def test_invalid_data_dec(dfn, dcast, vector, altchars, exception, simd): +def test_invalid_data_dec(dfn, vector, altchars, exception, simd): with pytest.raises(exception): dfn(vector, altchars) @@ -417,8 +419,8 @@ def test_invalid_data_dec(dfn, dcast, vector, altchars, exception, simd): @param_simd @params_invalid_data_validate @param_decode_functions -def test_invalid_data_dec_skip(dfn, dcast, vector, altchars, exception, simd): - test = dcast(dfn(vector, altchars)) +def test_invalid_data_dec_skip(dfn, vector, altchars, exception, simd): + test = dfn(vector, altchars) base = base64.b64decode(vector, altchars) assert test == base @@ -426,25 +428,52 @@ def test_invalid_data_dec_skip(dfn, dcast, vector, altchars, exception, simd): @param_simd @params_invalid_data_all @param_decode_functions -def test_invalid_data_dec_validate(dfn, dcast, vector, altchars, exception, simd): +def test_invalid_data_dec_validate(dfn, vector, altchars, exception, simd): with pytest.raises(exception): dfn(vector, altchars, True) +params_invalid_data_enc = [ + ["this is a test", TypeError], + [memoryview(b"abcd")[::2], BufferError], +] +params_invalid_data_encodebytes = params_invalid_data_enc + [ + [memoryview(b"abcd").cast("B", (2, 2)), TypeError], + [memoryview(b"abcd").cast("I"), TypeError], +] +params_invalid_data_enc = pytest.mark.parametrize( + "vector,exception", + params_invalid_data_enc, + ids=[str(i) for i in range(len(params_invalid_data_enc))], +) +params_invalid_data_encodebytes = pytest.mark.parametrize( + "vector,exception", + params_invalid_data_encodebytes, + ids=[str(i) for i in range(len(params_invalid_data_encodebytes))], +) + + +@params_invalid_data_enc @param_encode_functions -def test_invalid_data_enc_0(efn, ecast): - with pytest.raises(TypeError): - efn("this is a test") +def test_invalid_data_enc(efn, vector, exception): + with pytest.raises(exception): + efn(vector) + + +@params_invalid_data_encodebytes +def test_invalid_data_encodebytes(vector, exception): + with pytest.raises(exception): + pybase64.encodebytes(vector) @param_encode_functions -def test_invalid_args_enc_0(efn, ecast): +def test_invalid_args_enc_0(efn): with pytest.raises(TypeError): efn() @param_decode_functions -def test_invalid_args_dec_0(dfn, dcast): +def test_invalid_args_dec_0(dfn): with pytest.raises(TypeError): dfn() @@ -460,3 +489,23 @@ def test_flags(request): "hsw": 1 | 2 | 4 | 8 | 16 | 32 | 64, # AVX2 "spr": 1 | 2 | 4 | 8 | 16 | 32 | 64 | 128, # AVX512VBMI }[cpu] == runtime_flags + + +@param_encode_functions +def test_enc_multi_dimensional(efn): + source = b"abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUV" + vector = memoryview(source).cast("B", (4, len(source) // 4)) + assert vector.c_contiguous + test = efn(vector, None) + base = base64.b64encode(source) + assert test == base + + +@param_decode_functions +def test_dec_multi_dimensional(dfn): + source = b"abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUV" + vector = memoryview(source).cast("B", (4, len(source) // 4)) + assert vector.c_contiguous + test = dfn(vector, None) + base = base64.b64decode(source) + assert test == base