Be robust against invalid utf-8 byte sequences and surrogateescape them when en- or decoding #144

Merged (2 commits, Oct 3, 2024)
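
For context, a minimal sketch (not part of the diff) of the surrogateescape round-trip the changes below rely on: undecodable bytes are mapped to lone surrogate code points on decode, and mapped back to the original bytes on encode, so arbitrary byte sequences round-trip through str without raising.

    # Minimal sketch of the surrogateescape round-trip used throughout this PR.
    raw = b"R\xc3\xa9\xeamy"  # 0xea is not valid UTF-8 in this position

    text = raw.decode(errors="surrogateescape")
    assert text == "Ré\udceamy"  # the invalid byte 0xea became the lone surrogate U+DCEA

    assert text.encode(errors="surrogateescape") == raw  # the original bytes are restored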
2 changes: 1 addition & 1 deletion flow/record/adapter/elastic.py
@@ -106,7 +106,7 @@
         }

         if self.hash_record:
-            document["_id"] = hashlib.md5(document["_source"].encode()).hexdigest()
+            document["_id"] = hashlib.md5(document["_source"].encode(errors="surrogateescape")).hexdigest()

         return document

Codecov warning: added line flow/record/adapter/elastic.py#L109 was not covered by tests.
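
A brief sketch (hypothetical input, not taken from the adapter) of why the _id hashing needs the surrogateescape encode: once record values can carry lone surrogates, a plain .encode() raises UnicodeEncodeError, while encoding with errors="surrogateescape" still yields stable bytes to hash.

    import hashlib

    source = '{"name": "Ré\udceamy"}'  # hypothetical serialized record containing a lone surrogate

    try:
        source.encode()  # strict UTF-8 encoding rejects lone surrogates
    except UnicodeEncodeError:
        pass

    # the surrogateescape encode succeeds and gives a deterministic digest
    digest = hashlib.md5(source.encode(errors="surrogateescape")).hexdigest()
    assert len(digest) == 32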
2 changes: 1 addition & 1 deletion flow/record/adapter/line.py
@@ -69,7 +69,7 @@ def write(self, rec: Record) -> None:
         for key, value in rdict.items():
             if rdict_types:
                 key = f"{key} ({rdict_types[key]})"
-            self.fp.write(fmt.format(key, value).encode())
+            self.fp.write(fmt.format(key, value).encode(errors="surrogateescape"))

     def flush(self) -> None:
         if self.fp:
2 changes: 1 addition & 1 deletion flow/record/adapter/sqlite.py
@@ -187,7 +187,7 @@ def read_table(self, table_name: str) -> Iterator[Record]:
                 if value == 0:
                     row[idx] = None
                 elif isinstance(value, str):
-                    row[idx] = value.encode("utf-8")
+                    row[idx] = value.encode(errors="surrogateescape")
             yield descriptor_cls.init_from_dict(dict(zip(fnames, row)))

     def __iter__(self) -> Iterator[Record]:
2 changes: 1 addition & 1 deletion flow/record/adapter/text.py
@@ -41,7 +41,7 @@ def write(self, rec):
             buf = self.format_spec.format_map(DefaultMissing(rec._asdict()))
         else:
             buf = repr(rec)
-        self.fp.write(buf.encode() + b"\n")
+        self.fp.write(buf.encode(errors="surrogateescape") + b"\n")

         # because stdout is usually line buffered we force flush here if wanted
         if self.auto_flush:
4 changes: 2 additions & 2 deletions flow/record/adapter/xlsx.py
@@ -36,7 +36,7 @@
     elif isinstance(value, bytes):
         base64_encode = False
         try:
-            new_value = 'b"' + value.decode() + '"'
+            new_value = 'b"' + value.decode(errors="surrogateescape") + '"'
             if ILLEGAL_CHARACTERS_RE.search(new_value):
                 base64_encode = True
             else:
@@ -142,7 +142,7 @@
                 if field_types[idx] == "bytes":
                     if value[1] == '"':  # If so, we know this is b""
                         # Cut of the b" at the start and the trailing "
-                        value = value[2:-1].encode()
+                        value = value[2:-1].encode(errors="surrogateescape")
                     else:
                         # If not, we know it is base64 encoded (so we cut of the starting 'base64:')
                         value = b64decode(value[7:])

Codecov warning: added line flow/record/adapter/xlsx.py#L145 was not covered by tests.
4 changes: 2 additions & 2 deletions flow/record/base.py
@@ -61,7 +61,7 @@

 from collections import OrderedDict

-from .utils import to_native_str, to_str
+from .utils import to_str
 from .whitelist import WHITELIST, WHITELIST_TREE

 log = logging.getLogger(__package__)
@@ -513,7 +513,7 @@ def __init__(self, name: str, fields: Optional[Sequence[tuple[str, str]]] = None
             name, fields = parse_def(name)

         self.name = name
-        self._field_tuples = tuple([(to_native_str(k), to_str(v)) for k, v in fields])
+        self._field_tuples = tuple([(to_str(k), to_str(v)) for k, v in fields])
         self.recordType = _generate_record_class(name, self._field_tuples)
         self.recordType._desc = self
29 changes: 2 additions & 27 deletions flow/record/fieldtypes/__init__.py
@@ -28,7 +28,6 @@
 from flow.record.base import FieldType

 RE_NORMALIZE_PATH = re.compile(r"[\\/]+")
-NATIVE_UNICODE = isinstance("", str)

 UTC = timezone.utc

@@ -207,10 +206,7 @@
 class string(string_type, FieldType):
     def __new__(cls, value):
         if isinstance(value, bytes_type):
-            value = cls._decode(value, "utf-8")
-            if isinstance(value, bytes_type):
-                # Still bytes, so decoding failed (Python 2)
-                return bytes(value)
+            value = value.decode(errors="surrogateescape")
         return super().__new__(cls, value)

     def _pack(self):
@@ -221,27 +217,6 @@
             return defang(self)
         return str.__format__(self, spec)

-    @classmethod
-    def _decode(cls, data, encoding):
-        """Decode a byte-string into a unicode-string.
-
-        Python 3: When `data` contains invalid unicode characters a `UnicodeDecodeError` is raised.
-        Python 2: When `data` contains invalid unicode characters the original byte-string is returned.
-        """
-        if NATIVE_UNICODE:
-            # Raises exception on decode error
-            return data.decode(encoding)
-        try:
-            return data.decode(encoding)
-        except UnicodeDecodeError:
-            # Fallback to bytes (Python 2 only)
-            preview = data[:16].encode("hex_codec") + (".." if len(data) > 16 else "")
-            warnings.warn(
-                "Got binary data in string field (hex: {}). Compatibility is not guaranteed.".format(preview),
-                RuntimeWarning,
-            )
-            return data


 # Alias for backwards compatibility
 wstring = string
@@ -278,7 +253,7 @@
         if len(args) == 1 and not kwargs:
             arg = args[0]
             if isinstance(arg, bytes_type):
-                arg = arg.decode("utf-8")
+                arg = arg.decode(errors="surrogateescape")
             if isinstance(arg, string_type):
                 # If we are on Python 3.11 or newer, we can use fromisoformat() to parse the string (fast path)
                 #

Codecov warning: added line flow/record/fieldtypes/__init__.py#L256 was not covered by tests.
7 changes: 0 additions & 7 deletions flow/record/fieldtypes/net/ipv4.py
@@ -3,7 +3,6 @@
 import warnings

 from flow.record import FieldType
-from flow.record.utils import to_native_str


 def addr_long(s):
@@ -45,9 +44,6 @@ def __init__(self, addr, netmask=None):
             DeprecationWarning,
             stacklevel=5,
         )
-        if isinstance(addr, type("")):
-            addr = to_native_str(addr)
-
         if not isinstance(addr, str):
             raise TypeError("Subnet() argument 1 must be string, not {}".format(type(addr).__name__))

@@ -67,9 +63,6 @@ def __contains__(self, addr):
         if addr is None:
             return False

-        if isinstance(addr, type("")):
-            addr = to_native_str(addr)
-
         if isinstance(addr, str):
             addr = addr_long(addr)
6 changes: 1 addition & 5 deletions flow/record/jsonpacker.py
@@ -47,12 +47,8 @@ def pack_obj(self, obj):
         serial["_recorddescriptor"] = obj._desc.identifier

         for field_type, field_name in obj._desc.get_field_tuples():
-            # PYTHON2: Because "bytes" are also "str" we have to handle this here
-            if field_type == "bytes" and isinstance(serial[field_name], str):
-                serial[field_name] = base64.b64encode(serial[field_name]).decode()
-
             # Boolean field types should be cast to a bool instead of staying ints
-            elif field_type == "boolean" and isinstance(serial[field_name], int):
+            if field_type == "boolean" and isinstance(serial[field_name], int):
                 serial[field_name] = bool(serial[field_name])

         return serial
40 changes: 18 additions & 22 deletions flow/record/utils.py
@@ -3,13 +3,10 @@
 import base64
 import os
 import sys
+import warnings
 from functools import wraps
 from typing import BinaryIO, TextIO

-_native = str
-_unicode = type("")
-_bytes = type(b"")


 def get_stdout(binary: bool = False) -> TextIO | BinaryIO:
     """Return the stdout stream as binary or text stream.
@@ -50,33 +47,32 @@

 def to_bytes(value):
     """Convert a value to a byte string."""
-    if value is None or isinstance(value, _bytes):
+    if value is None or isinstance(value, bytes):
         return value
-    if isinstance(value, _unicode):
-        return value.encode("utf-8")
-    return _bytes(value)
+    if isinstance(value, str):
+        return value.encode(errors="surrogateescape")
+    return bytes(value)


 def to_str(value):
     """Convert a value to a unicode string."""
-    if value is None or isinstance(value, _unicode):
+    if value is None or isinstance(value, str):
         return value
-    if isinstance(value, _bytes):
-        return value.decode("utf-8")
-    return _unicode(value)
+    if isinstance(value, bytes):
+        return value.decode(errors="surrogateescape")
+    return str(value)


 def to_native_str(value):
     """Convert a value to a native `str`."""
-    if value is None or isinstance(value, _native):
-        return value
-    if isinstance(value, _unicode):
-        # Python 2: unicode -> str
-        return value.encode("utf-8")
-    if isinstance(value, _bytes):
-        # Python 3: bytes -> str
-        return value.decode("utf-8")
-    return _native(value)
+    warnings.warn(
+        (
+            "The to_native_str() function is deprecated, "
+            "this function will be removed in flow.record 3.20, "
+            "use to_str() instead"
+        ),
+        DeprecationWarning,
+    )
+    return to_str(value)


 def to_base64(value):

Codecov warnings: added lines flow/record/utils.py#L54, #L62, #L67 and #L75 were not covered by tests.
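
A short usage sketch of the reworked helpers above (behaviour as implied by the diff): to_str() now decodes arbitrary bytes with surrogateescape instead of raising, and to_bytes() reverses it.

    from flow.record.utils import to_bytes, to_str

    raw = b"R\xc3\xa9\xeamy"  # not valid UTF-8

    assert to_str(raw) == "Ré\udceamy"  # decoded with errors="surrogateescape"
    assert to_bytes(to_str(raw)) == raw  # encoding restores the original bytes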
29 changes: 29 additions & 0 deletions tests/test_adapter_line.py
@@ -0,0 +1,29 @@
from io import BytesIO

from flow.record import RecordDescriptor
from flow.record.adapter.line import LineWriter


def test_line_writer_write_surrogateescape():
output = BytesIO()

lw = LineWriter(
path=output,
fields="name",
)

TestRecord = RecordDescriptor(
"test/string",
[
("string", "name"),
],
)

# construct from 'bytes' but with invalid unicode bytes
record = TestRecord(b"R\xc3\xa9\xeamy")
lw.write(record)

output.seek(0)
data = output.read()

assert data == b"--[ RECORD 1 ]--\nname = R\xc3\xa9\xeamy\n"
28 changes: 28 additions & 0 deletions tests/test_adapter_text.py
@@ -0,0 +1,28 @@
from io import BytesIO

from flow.record import RecordDescriptor
from flow.record.adapter.text import TextWriter


def test_text_writer_write_surrogateescape():
output = BytesIO()

tw = TextWriter(
path=output,
)

TestRecord = RecordDescriptor(
"test/string",
[
("string", "name"),
],
)

# construct from 'bytes' but with invalid unicode bytes
record = TestRecord(b"R\xc3\xa9\xeamy")
tw.write(record)

output.seek(0)
data = output.read()

assert data == b"<test/string name='R\xc3\xa9\\udceamy'>\n"
11 changes: 2 additions & 9 deletions tests/test_fieldtypes.py
@@ -213,15 +213,8 @@ def test_string():
     assert r.name == "Rémy"

     # construct from 'bytes' but with invalid unicode bytes
-    if isinstance("", str):
-        # Python 3
-        with pytest.raises(UnicodeDecodeError):
-            TestRecord(b"R\xc3\xa9\xeamy")
-    else:
-        # Python 2
-        with pytest.warns(RuntimeWarning):
-            r = TestRecord(b"R\xc3\xa9\xeamy")
-        assert r.name
+    r = TestRecord(b"R\xc3\xa9\xeamy")
+    assert r.name == "Ré\udceamy"


 def test_wstring():
20 changes: 20 additions & 0 deletions tests/test_json_packer.py
@@ -90,3 +90,23 @@ def test_record_pack_bool_regression() -> None:

# pack the json string back to a record and make sure it is the same as before
assert packer.unpack(data) == record


def test_record_pack_surrogateescape() -> None:
TestRecord = RecordDescriptor(
"test/string",
[
("string", "name"),
],
)

record = TestRecord(b"R\xc3\xa9\xeamy")
packer = JsonRecordPacker()

data = packer.pack(record)

# pack to json string and check if the 3rd and 4th byte are properly surrogate escaped
assert data.startswith('{"name": "R\\u00e9\\udceamy",')

# pack the json string back to a record and make sure it is the same as before
assert packer.unpack(data) == record
29 changes: 25 additions & 4 deletions tests/test_record.py
@@ -1,4 +1,5 @@
 import importlib
+import inspect
 import os
 import sys
 from unittest.mock import patch
@@ -27,8 +28,6 @@
 from flow.record.exceptions import RecordDescriptorError
 from flow.record.stream import RecordFieldRewriter

-from . import utils_inspect as inspect
-

 def test_record_creation():
     TestRecord = RecordDescriptor(
@@ -288,8 +287,30 @@ def isatty():
     writer.write(record)

     out, err = capsys.readouterr()
-    modifier = "" if isinstance("", str) else "u"
-    expected = "<test/a a_string={u}'hello' common={u}'world' a_count=10>\n".format(u=modifier)
+    expected = "<test/a a_string='hello' common='world' a_count=10>\n"
     assert out == expected


+def test_record_printer_stdout_surrogateescape(capsys):
+    Record = RecordDescriptor(
+        "test/a",
+        [
+            ("string", "name"),
+        ],
+    )
+    record = Record(b"R\xc3\xa9\xeamy")
+
+    # fake capsys to be a tty.
+    def isatty():
+        return True
+
+    capsys._capture.out.tmpfile.isatty = isatty
+
+    writer = RecordPrinter(getattr(sys.stdout, "buffer", sys.stdout))
+    writer.write(record)
+
+    out, err = capsys.readouterr()
+    expected = "<test/a name='Ré\\udceamy'>\n"
+    assert out == expected