diff --git a/irclib/parser.py b/irclib/parser.py index c06aa03..c431436 100644 --- a/irclib/parser.py +++ b/irclib/parser.py @@ -130,10 +130,10 @@ def __str__(self) -> str: return self.name @classmethod - def parse(cls, text: str) -> "Cap": + def parse(cls, text: str) -> Self: """Parse a CAP entity from a string""" name, _, value = text.partition(CAP_VALUE_SEP) - return Cap(name, value or None) + return cls(name, value or None) class CapList(Parseable, List[Cap]): @@ -161,7 +161,7 @@ def __str__(self) -> str: return CAP_SEP.join(map(str, self)) @classmethod - def parse(cls, text: str) -> "CapList": + def parse(cls, text: str) -> Self: """Parse a list of CAPs from a string""" if text.startswith(":"): text = text[1:] # Remove leading colon @@ -175,7 +175,7 @@ def parse(cls, text: str) -> "CapList": [] if not text else (Cap.parse(s) for s in stripped.split(CAP_SEP)) ) - return CapList(caps) + return cls(caps) class MessageTag(Parseable): @@ -259,7 +259,7 @@ def __ne__(self, other: object) -> bool: return NotImplemented @classmethod - def parse(cls, text: str) -> "MessageTag": + def parse(cls, text: str) -> Self: """ Parse a tag from a string @@ -270,7 +270,7 @@ def parse(cls, text: str) -> "MessageTag": if value: value = MessageTag.unescape(value) - return MessageTag(name, value, has_value=bool(sep)) + return cls(name, value, has_value=bool(sep)) class TagList(Parseable, Dict[str, MessageTag]): @@ -316,21 +316,19 @@ def __ne__(self, other: object) -> bool: return dict(self) != dict(obj) @classmethod - def parse(cls, text: str) -> "TagList": + def parse(cls, text: str) -> Self: """ Parse the list of tags from a string :param text: The string to parse :return: The parsed object """ - return TagList( - map(MessageTag.parse, filter(None, text.split(TAGS_SEP))) - ) + return cls(map(MessageTag.parse, filter(None, text.split(TAGS_SEP)))) - @staticmethod - def from_dict(tags: Dict[str, str]) -> "TagList": + @classmethod + def from_dict(cls, tags: Dict[str, str]) -> Self: """Create a TagList from a dict of tags""" - return TagList(MessageTag(k, v) for k, v in tags.items()) + return cls(MessageTag(k, v) for k, v in tags.items()) class Prefix(Parseable): @@ -414,7 +412,7 @@ def __ne__(self, other: object) -> bool: return NotImplemented @classmethod - def parse(cls, text: str) -> "Prefix": + def parse(cls, text: str) -> Self: """ Parse the prefix from a string @@ -422,7 +420,7 @@ def parse(cls, text: str) -> "Prefix": :return: Parsed Object """ if not text: - return Prefix() + return cls() match = PREFIX_RE.match(text) if not match: # pragma: no cover @@ -431,7 +429,7 @@ def parse(cls, text: str) -> "Prefix": raise ParseError(msg) nick, user, host = match.groups() - return Prefix(nick, user, host) + return cls(nick, user, host) class ParamList(Parseable, List[str]): @@ -481,11 +479,11 @@ def __ne__(self, other: object) -> bool: return NotImplemented - @staticmethod - def from_list(data: Sequence[str]) -> "ParamList": + @classmethod + def from_list(cls, data: Sequence[str]) -> Self: """Create a ParamList from a Sequence of strings.""" if not data: - return ParamList() + return cls() args = list(data[:-1]) if data[-1].startswith(TRAIL_SENTINEL) or not data[-1]: @@ -495,10 +493,10 @@ def from_list(data: Sequence[str]) -> "ParamList": has_trail = False args.append(data[-1]) - return ParamList(*args, has_trail=has_trail) + return cls(*args, has_trail=has_trail) @classmethod - def parse(cls, text: str) -> "ParamList": + def parse(cls, text: str) -> Self: """ Parse a list of parameters @@ -517,7 +515,7 @@ def parse(cls, text: str) -> "ParamList": if arg: args.append(arg) - return ParamList(*args, has_trail=has_trail) + return cls(*args, has_trail=has_trail) def _parse_tags( @@ -639,9 +637,12 @@ def __ne__(self, other: object) -> bool: return NotImplemented @classmethod - def parse(cls, text: Union[str, bytes]) -> "Message": + def parse(cls, text: Union[str, bytes]) -> Self: """Parse an IRC message in to objects""" - if isinstance(text, bytes): + if isinstance(text, memoryview): + text = text.tobytes().decode(errors="ignore") + + if isinstance(text, (bytes, bytearray)): text = text.decode(errors="ignore") tags = "" @@ -659,4 +660,4 @@ def parse(cls, text: Union[str, bytes]) -> "Message": prefix_obj = Prefix.parse(prefix[1:]) if prefix else None command = command.upper() param_obj = ParamList.parse(params) - return Message(tags_obj, prefix_obj, command, param_obj) + return cls(tags_obj, prefix_obj, command, param_obj) diff --git a/tests/parser_test.py b/tests/parser_test.py index a4e26f3..73542d6 100644 --- a/tests/parser_test.py +++ b/tests/parser_test.py @@ -707,6 +707,18 @@ def test_parse_bytes(self) -> None: assert line.command == "COMMAND" assert line.parameters == ["some", "params", "and stuff"] + def test_parse_bytearray(self) -> None: + """Test parsing bytearray""" + line = Message.parse(bytearray(b"COMMAND some params :and stuff")) + assert line.command == "COMMAND" + assert line.parameters == ["some", "params", "and stuff"] + + def test_parse_memoryview(self) -> None: + """Test parsing memoryview""" + line = Message.parse(memoryview(b"COMMAND some params :and stuff")) + assert line.command == "COMMAND" + assert line.parameters == ["some", "params", "and stuff"] + @pytest.mark.parametrize( ("obj", "text"), [