diff --git a/fastwarc/fastwarc/stream_io.pyi b/fastwarc/fastwarc/stream_io.pyi index ab4dbf48..f763700d 100644 --- a/fastwarc/fastwarc/stream_io.pyi +++ b/fastwarc/fastwarc/stream_io.pyi @@ -1,5 +1,13 @@ from types import TracebackType -from typing import ContextManager, IO, Optional, Type, Union, BinaryIO +from typing import ContextManager, Optional, Type, Union, BinaryIO, Protocol + +class _ReadableStream(Protocol): + def read(self, size: int) -> bytes: ... + def seek(self, offset: int) -> int: ... + +class _WritableStream(Protocol): + def write(self, data: bytes) -> int: ... + def flush(self) -> None: ... class IOStream(ContextManager[IOStream]): @@ -20,7 +28,7 @@ class IOStream(ContextManager[IOStream]): class BufferedReader: def __init__( - self, stream: Union[IOStream, BinaryIO], buf_size: int = 65536, negotiate_stream: bool = True + self, stream: Union[IOStream, BinaryIO, _ReadableStream], buf_size: int = 65536, negotiate_stream: bool = True ) -> None: ... def close(self) -> None: ... def consume(self, size: int = -1) -> int: ... @@ -45,20 +53,20 @@ class CompressingStream(IOStream): class BrotliStream(CompressingStream): def __init__( - self, raw_stream: Union[IOStream, BinaryIO], quality: int = 11, lgwin: int = 22, lgblock: int = 0 + self, raw_stream: Union[IOStream, BinaryIO, _ReadableStream, _WritableStream], quality: int = 11, lgwin: int = 22, lgblock: int = 0 ) -> None: ... class GZipStream(CompressingStream): def __init__( - self, raw_stream: Union[IOStream, BinaryIO], compression_level: int = 9, zlib: bool = False + self, raw_stream: Union[IOStream, BinaryIO, _ReadableStream, _WritableStream], compression_level: int = 9, zlib: bool = False ) -> None: ... class LZ4Stream(CompressingStream): def __init__( self, - raw_stream: Union[IOStream, BinaryIO], + raw_stream: Union[IOStream, BinaryIO, _ReadableStream, _WritableStream], compression_level: int = 12, favor_dec_speed: bool = True, ) -> None: ... @@ -66,7 +74,7 @@ class LZ4Stream(CompressingStream): class PythonIOStreamAdapter(IOStream): - def __init__(self, py_stream: BinaryIO) -> None: ... + def __init__(self, py_stream: Union[_ReadableStream, _WritableStream]) -> None: ... class FastWARCError(Exception): diff --git a/fastwarc/fastwarc/warc.pyi b/fastwarc/fastwarc/warc.pyi index 49f828e5..6f02fdd3 100644 --- a/fastwarc/fastwarc/warc.pyi +++ b/fastwarc/fastwarc/warc.pyi @@ -11,11 +11,20 @@ from typing import ( ValuesView, KeysView, BinaryIO, + Protocol, ) from enum import IntFlag from .stream_io import BufferedReader, IOStream +class _ReadableStream(Protocol): + def read(self, size: int) -> bytes: ... + def seek(self, offset: int) -> int: ... + +class _WritableStream(Protocol): + def write(self, data: bytes) -> int: ... + def flush(self) -> None: ... + class WarcRecordType(IntFlag): warcinfo = 2 @@ -48,7 +57,7 @@ class WarcHeaderMap: def items(self) -> Iterator[Tuple[str, str]]: ... def keys(self) -> KeysView[str]: ... def values(self) -> ValuesView[str]: ... - def write(self, stream: Union[IOStream, BinaryIO]) -> None: ... + def write(self, stream: IOStream) -> None: ... def __getitem__(self, item: str) -> str: ... def __iter__(self) -> Iterator[Tuple[str, str]]: ... def __len__(self) -> int: ... @@ -82,17 +91,18 @@ class WarcRecord: def verify_payload_digest(self, consume: bool = False) -> bool: ... def write( self, - stream: Union[IOStream, BinaryIO], + stream: Union[IOStream, BinaryIO, _WritableStream], checksum_data: bool = False, payload_digest: Optional[bytes] = None, chunk_size: int = 16384 ) -> int: ... + class ArchiveIterator(Iterable[WarcRecord]): def __init__( self, - stream: Union[IOStream, BinaryIO], + stream: Union[IOStream, BinaryIO, _ReadableStream], record_types: WarcRecordType = any_type, parse_http: bool = True, min_content_length: int = -1,