diff --git a/agent_environment/__init__.py b/agent_environment/__init__.py index 46332a1..8b6ada3 100644 --- a/agent_environment/__init__.py +++ b/agent_environment/__init__.py @@ -23,6 +23,7 @@ LocalTmpFileOperator, ) from agent_environment.protocols import ( + DEFAULT_CHUNK_SIZE, InstructableResource, Resource, ResumableResource, @@ -40,6 +41,7 @@ from agent_environment.utils import generate_filetree __all__ = [ + "DEFAULT_CHUNK_SIZE", "DEFAULT_INSTRUCTIONS_MAX_DEPTH", "DEFAULT_INSTRUCTIONS_SKIP_DIRS", "BaseResource", diff --git a/agent_environment/file_operator.py b/agent_environment/file_operator.py index b6a30d8..da6dd45 100644 --- a/agent_environment/file_operator.py +++ b/agent_environment/file_operator.py @@ -5,14 +5,14 @@ """ import shutil -import tempfile from abc import ABC, abstractmethod +from collections.abc import AsyncIterator from pathlib import Path from xml.etree import ElementTree as ET import anyio -from agent_environment.protocols import TmpFileOperator +from agent_environment.protocols import DEFAULT_CHUNK_SIZE, TmpFileOperator from agent_environment.types import FileStat, TruncatedResult # Default directories to skip but mark in file tree @@ -70,11 +70,18 @@ async def read_bytes( length: int | None = None, ) -> bytes: resolved = self._resolve(path) - content = await anyio.Path(resolved).read_bytes() - if offset > 0 or length is not None: - end = None if length is None else offset + length - content = content[offset:end] - return content + # Optimize: use seek instead of reading entire file then slicing + if offset == 0 and length is None: + # Fast path: read entire file + return await anyio.Path(resolved).read_bytes() + + # Use seek for partial reads + async with await anyio.open_file(resolved, "rb") as f: + if offset > 0: + await f.seek(offset) + if length is None: + return await f.read() + return await f.read(length) async def write_file( self, @@ -119,6 +126,24 @@ async def list_dir(self, path: str) -> list[str]: entries = [entry.name async for entry in anyio.Path(resolved).iterdir()] return sorted(entries) + async def list_dir_with_types(self, path: str) -> list[tuple[str, bool]]: + """List directory contents with type information. + + More efficient than calling list_dir + is_dir for each entry. + + Args: + path: Directory path. + + Returns: + List of (name, is_dir) tuples, sorted alphabetically. + """ + resolved = self._resolve(path) + result: list[tuple[str, bool]] = [] + async for entry in anyio.Path(resolved).iterdir(): + is_dir = await entry.is_dir() + result.append((entry.name, is_dir)) + return sorted(result, key=lambda x: x[0]) + async def exists(self, path: str) -> bool: return await anyio.Path(self._resolve(path)).exists() @@ -163,6 +188,49 @@ async def glob(self, pattern: str) -> list[str]: matches.append(str(p)) return sorted(matches) + async def read_bytes_stream( + self, + path: str, + *, + chunk_size: int = DEFAULT_CHUNK_SIZE, + ) -> AsyncIterator[bytes]: + """Read file content as an async stream of bytes. + + Memory-efficient way to read large files. + + Args: + path: Path to file. + chunk_size: Size of each chunk in bytes (default: 64KB). + + Yields: + Chunks of bytes from the file. + """ + resolved = self._resolve(path) + async with await anyio.open_file(resolved, "rb") as f: + while True: + chunk = await f.read(chunk_size) + if not chunk: + break + yield chunk + + async def write_bytes_stream( + self, + path: str, + stream: AsyncIterator[bytes], + ) -> None: + """Write bytes stream to file. + + Memory-efficient way to write large files. 
+ + Args: + path: Path to file. + stream: Async iterator yielding bytes chunks. + """ + resolved = self._resolve(path) + async with await anyio.open_file(resolved, "wb") as f: + async for chunk in stream: + await f.write(chunk) + async def truncate_to_tmp( self, content: str, @@ -198,12 +266,19 @@ class FileOperator(ABC): Provides common initialization logic for path validation, instructions configuration, and transparent tmp file handling. + This class has no local system dependencies - it's designed to work + with both local and remote backends. Tmp file handling is optional + and must be explicitly configured. + Tmp File Handling: - When tmp_dir and tmp_file_operator are provided, operations on + When tmp_dir or tmp_file_operator is provided, operations on paths under tmp_dir are automatically delegated to tmp_file_operator. Subclasses only need to implement _xxx_impl methods and don't need to be aware of tmp handling. + If neither tmp_dir nor tmp_file_operator is provided, tmp handling + is disabled and cross-boundary operations will not be available. + Example: ```python # Environment assembles the operators @@ -233,6 +308,7 @@ def __init__( tmp_dir: Path | None = None, tmp_file_operator: TmpFileOperator | None = None, skip_instructions: bool = False, + default_chunk_size: int = DEFAULT_CHUNK_SIZE, ): """Initialize FileOperator. @@ -243,9 +319,16 @@ def __init__( default_path is always included in allowed_paths. instructions_skip_dirs: Directories to skip in file tree generation. instructions_max_depth: Maximum depth for file tree generation. - tmp_dir: Directory for temporary files. - tmp_file_operator: Operator for tmp file operations. + tmp_dir: Directory for temporary files. If provided without + tmp_file_operator, a LocalTmpFileOperator will be created. + tmp_file_operator: Operator for tmp file operations. Takes + precedence over tmp_dir if both are provided. skip_instructions: If True, get_context_instructions returns None. + default_chunk_size: Default chunk size for streaming operations (default: 64KB). + + Note: + If neither tmp_dir nor tmp_file_operator is provided, tmp handling + is disabled. Cross-boundary operations will not be available. """ self._default_path = default_path.resolve() @@ -263,15 +346,20 @@ def __init__( self._instructions_max_depth = instructions_max_depth self._skip_instructions = skip_instructions - # Auto-create LocalTmpFileOperator with tmp_dir or a random temp directory + # Default chunk size for streaming operations + self._default_chunk_size = default_chunk_size + + # Tmp file operator setup - no auto-creation to avoid local system dependency + # Environment or subclass is responsible for providing tmp_file_operator if needed self._owned_tmp_dir: Path | None = None # Track tmp_dir we created (for cleanup) if tmp_file_operator is not None: self._tmp_file_operator: TmpFileOperator | None = tmp_file_operator - else: - if tmp_dir is None: - tmp_dir = Path(tempfile.mkdtemp(prefix="pai_agent_")) - self._owned_tmp_dir = tmp_dir # We created it, we must clean it up + elif tmp_dir is not None: + # Only create LocalTmpFileOperator if tmp_dir is explicitly provided self._tmp_file_operator = LocalTmpFileOperator(tmp_dir) + else: + # No tmp handling - cross-boundary operations will not be available + self._tmp_file_operator = None def _is_tmp_path(self, path: str) -> tuple[bool, str]: """Delegate to tmp_file_operator to check if path is managed.""" @@ -347,6 +435,23 @@ async def _list_dir_impl(self, path: str) -> list[str]: """List directory contents. 
Implement in subclass.""" ... + async def _list_dir_with_types_impl(self, path: str) -> list[tuple[str, bool]]: + """List directory with type info. Override for efficiency. + + Default implementation calls list_dir + is_dir for each entry. + Subclasses can override for more efficient implementation. + + Returns: + List of (name, is_dir) tuples, sorted alphabetically. + """ + entries = await self._list_dir_impl(path) + result: list[tuple[str, bool]] = [] + for name in entries: + entry_path = f"{path}/{name}" if path != "." else name + is_dir = await self._is_dir_impl(entry_path) + result.append((name, is_dir)) + return sorted(result, key=lambda x: x[0]) + @abstractmethod async def _exists_impl(self, path: str) -> bool: """Check if path exists. Implement in subclass.""" @@ -387,6 +492,50 @@ async def _glob_impl(self, pattern: str) -> list[str]: """Find files matching glob pattern. Implement in subclass.""" ... + # Streaming methods - optional to override (default uses read_bytes/write_file) + + async def _read_bytes_stream_impl( + self, + path: str, + *, + chunk_size: int = DEFAULT_CHUNK_SIZE, + ) -> AsyncIterator[bytes]: + """Read file content as an async stream. Override for efficiency. + + Default implementation loads entire file into memory and yields as single chunk. + Subclasses should override this for true streaming with large files. + + Args: + path: Path to file. + chunk_size: Size of each chunk in bytes (default: 64KB). + + Yields: + Chunks of bytes from the file. + """ + # Default: read entire file and yield as single chunk + content = await self._read_bytes_impl(path) + yield content + + async def _write_bytes_stream_impl( + self, + path: str, + stream: AsyncIterator[bytes], + ) -> None: + """Write bytes stream to file. Override for efficiency. + + Default implementation collects all chunks and writes at once. + Subclasses should override this for true streaming with large files. + + Args: + path: Path to file. + stream: Async iterator yielding bytes chunks. + """ + # Default: collect all chunks and write at once + chunks = [] + async for chunk in stream: + chunks.append(chunk) + await self._write_file_impl(path, b"".join(chunks)) + # --- Public methods with tmp routing --- async def read_file( @@ -467,6 +616,22 @@ async def list_dir(self, path: str) -> list[str]: return await self._tmp_file_operator.list_dir(routed_path) # type: ignore[union-attr] return await self._list_dir_impl(path) + async def list_dir_with_types(self, path: str) -> list[tuple[str, bool]]: + """List directory contents with type information. + + More efficient than calling list_dir + is_dir for each entry. + + Args: + path: Directory path. + + Returns: + List of (name, is_dir) tuples, sorted alphabetically. 
+ """ + is_tmp, routed_path = self._is_tmp_path(path) + if is_tmp: # pragma: no cover + return await self._tmp_file_operator.list_dir_with_types(routed_path) # type: ignore[union-attr] + return await self._list_dir_with_types_impl(path) + async def exists(self, path: str) -> bool: """Check if path exists.""" is_tmp, routed_path = self._is_tmp_path(path) @@ -506,14 +671,16 @@ async def move(self, src: str, dst: str) -> None: # pragma: no cover # Neither in tmp: delegate to subclass await self._move_impl(src, dst) else: - # Cross-boundary move: read from source, write to dest, delete source + # Cross-boundary move: use streaming to avoid loading entire file into memory if src_is_tmp: - content = await self._tmp_file_operator.read_bytes(src_path) # type: ignore[union-attr] - await self._write_file_impl(dst, content) + stream = self._tmp_file_operator.read_bytes_stream( # type: ignore[union-attr] + src_path, chunk_size=self._default_chunk_size + ) + await self._write_bytes_stream_impl(dst, stream) await self._tmp_file_operator.delete(src_path) # type: ignore[union-attr] else: - content = await self._read_bytes_impl(src) - await self._tmp_file_operator.write_file(dst_path, content) # type: ignore[union-attr] + stream = self._read_bytes_stream_impl(src, chunk_size=self._default_chunk_size) + await self._tmp_file_operator.write_bytes_stream(dst_path, stream) # type: ignore[union-attr] await self._delete_impl(src) async def copy(self, src: str, dst: str) -> None: # pragma: no cover @@ -526,13 +693,15 @@ async def copy(self, src: str, dst: str) -> None: # pragma: no cover # Neither in tmp: delegate to subclass await self._copy_impl(src, dst) else: - # Cross-boundary copy: read from source, write to dest + # Cross-boundary copy: use streaming to avoid loading entire file into memory if src_is_tmp: - content = await self._tmp_file_operator.read_bytes(src_path) # type: ignore[union-attr] - await self._write_file_impl(dst, content) + stream = self._tmp_file_operator.read_bytes_stream( # type: ignore[union-attr] + src_path, chunk_size=self._default_chunk_size + ) + await self._write_bytes_stream_impl(dst, stream) else: - content = await self._read_bytes_impl(src) - await self._tmp_file_operator.write_file(dst_path, content) # type: ignore[union-attr] + stream = self._read_bytes_stream_impl(src, chunk_size=self._default_chunk_size) + await self._tmp_file_operator.write_bytes_stream(dst_path, stream) # type: ignore[union-attr] async def stat(self, path: str) -> FileStat: """Get file/directory status information.""" @@ -546,6 +715,53 @@ async def glob(self, pattern: str) -> list[str]: # Note: glob doesn't support tmp routing as patterns are relative to default_path return await self._glob_impl(pattern) + async def read_bytes_stream( + self, + path: str, + *, + chunk_size: int = DEFAULT_CHUNK_SIZE, + ) -> AsyncIterator[bytes]: + """Read file content as an async stream of bytes. + + Memory-efficient way to read large files. This is used internally + for cross-boundary copy/move operations. + + Args: + path: Path to file. + chunk_size: Size of each chunk in bytes (default: 64KB). + + Yields: + Chunks of bytes from the file. 
+        """
+        is_tmp, routed_path = self._is_tmp_path(path)
+        if is_tmp:  # pragma: no cover
+            return self._tmp_file_operator.read_bytes_stream(  # type: ignore[union-attr]
+                routed_path, chunk_size=chunk_size
+            )
+        return self._read_bytes_stream_impl(path, chunk_size=chunk_size)
+
+    async def write_bytes_stream(
+        self,
+        path: str,
+        stream: AsyncIterator[bytes],
+    ) -> None:
+        """Write bytes stream to file.
+
+        Memory-efficient way to write large files. This is used internally
+        for cross-boundary copy/move operations.
+
+        Args:
+            path: Path to file.
+            stream: Async iterator yielding bytes chunks.
+        """
+        is_tmp, routed_path = self._is_tmp_path(path)
+        if is_tmp:  # pragma: no cover
+            await self._tmp_file_operator.write_bytes_stream(  # type: ignore[union-attr]
+                routed_path, stream
+            )
+            return
+        await self._write_bytes_stream_impl(path, stream)
+
     async def truncate_to_tmp(
         self,
         content: str,
diff --git a/agent_environment/protocols.py b/agent_environment/protocols.py
index 7223c76..da7bcae 100644
--- a/agent_environment/protocols.py
+++ b/agent_environment/protocols.py
@@ -3,11 +3,15 @@
 This module defines runtime-checkable protocols for resources and operators.
 """
 
+from collections.abc import AsyncIterator
 from pathlib import Path
 from typing import Any, Protocol, runtime_checkable
 
 from agent_environment.types import FileStat, TruncatedResult
 
+# Default chunk size for streaming operations (64KB)
+DEFAULT_CHUNK_SIZE = 65536
+
 
 @runtime_checkable
 class Resource(Protocol):
@@ -176,6 +180,19 @@ async def delete(self, path: str) -> None: ...
 
     async def list_dir(self, path: str) -> list[str]: ...
 
+    async def list_dir_with_types(self, path: str) -> list[tuple[str, bool]]:
+        """List directory contents with type information.
+
+        More efficient than calling list_dir + is_dir for each entry.
+
+        Args:
+            path: Directory path.
+
+        Returns:
+            List of (name, is_dir) tuples, sorted alphabetically.
+        """
+        ...
+
     async def exists(self, path: str) -> bool: ...
 
     async def is_file(self, path: str) -> bool: ...
@@ -196,6 +213,43 @@ async def glob(self, pattern: str) -> list[str]:
         """Find files matching glob pattern."""
         ...
 
+    def read_bytes_stream(
+        self,
+        path: str,
+        *,
+        chunk_size: int = DEFAULT_CHUNK_SIZE,
+    ) -> AsyncIterator[bytes]:
+        """Read file content as an async stream of bytes.
+
+        Memory-efficient way to read large files.
+
+        Note: This method returns an async iterator directly (not a coroutine).
+        Call it without await: `stream = op.read_bytes_stream(path)`
+
+        Args:
+            path: Path to file.
+            chunk_size: Size of each chunk in bytes (default: 64KB).
+
+        Yields:
+            Chunks of bytes from the file.
+        """
+        ...
+
+    async def write_bytes_stream(
+        self,
+        path: str,
+        stream: AsyncIterator[bytes],
+    ) -> None:
+        """Write bytes stream to file.
+
+        Memory-efficient way to write large files.
+
+        Args:
+            path: Path to file.
+            stream: Async iterator yielding bytes chunks.
+        """
+        ...
+
     async def truncate_to_tmp(
         self,
         content: str,
diff --git a/agent_environment/resources.py b/agent_environment/resources.py
index 7592b30..e1ab1cd 100644
--- a/agent_environment/resources.py
+++ b/agent_environment/resources.py
@@ -411,20 +411,41 @@ def keys(self) -> list[str]:
         """Return list of resource keys."""
         return list(self._resources.keys())
 
-    async def close_all(self) -> None:
-        """Close all resources in reverse registration order.
+    async def close_all(self, *, parallel: bool = False) -> None:
+        """Close all resources.
+
+        Args:
+            parallel: If True, close resources concurrently using asyncio.gather.
+ If False (default), close in reverse registration order sequentially. Uses best-effort cleanup - continues even if individual resources fail to close. Handles both sync and async close(). Also clears registered factories. """ - for resource in reversed(list(self._resources.values())): - try: - result = resource.close() - if asyncio.iscoroutine(result): - await result - except Exception: # noqa: S110 - pass # Best effort cleanup + if parallel: + # Close all resources concurrently + async def _close_resource(resource: Resource) -> None: + try: + result = resource.close() + if asyncio.iscoroutine(result): + await result + except Exception: # noqa: S110 + pass # Best effort cleanup + + await asyncio.gather( + *[_close_resource(r) for r in self._resources.values()], + return_exceptions=True, + ) + else: + # Close in reverse registration order (sequential) + for resource in reversed(list(self._resources.values())): + try: + result = resource.close() + if asyncio.iscoroutine(result): + await result + except Exception: # noqa: S110 + pass # Best effort cleanup + self._resources.clear() self._factories.clear() diff --git a/agent_environment/utils.py b/agent_environment/utils.py index 8a33895..508d533 100644 --- a/agent_environment/utils.py +++ b/agent_environment/utils.py @@ -30,7 +30,7 @@ def _load_gitignore_spec(gitignore_content: str) -> pathspec.PathSpec | None: """Load .gitignore patterns from content.""" try: patterns = gitignore_content.splitlines() - return pathspec.PathSpec.from_lines("gitwildmatch", patterns) + return pathspec.PathSpec.from_lines("gitignore", patterns) except Exception: return None @@ -78,25 +78,19 @@ def _is_gitignored(rel_path: str, is_dir: bool) -> bool: path = rel_path + "/" if is_dir else rel_path return gitignore_spec.match_file(path) - async def _collect_paths(current_path: str, current_depth: int, path_prefix: str = "") -> list[str]: # noqa: C901 + async def _collect_paths(current_path: str, current_depth: int, path_prefix: str = "") -> list[str]: """Collect all file paths recursively, returning flat paths.""" result: list[str] = [] try: - entries = await file_op.list_dir(current_path) - # Sort: directories first, then files, alphabetically - dir_entries = [] - file_entries = [] - for name in entries: - entry_path = f"{current_path}/{name}" if current_path != "." else name - if await file_op.is_dir(entry_path): - dir_entries.append(name) - else: - file_entries.append(name) - dir_entries.sort() - file_entries.sort() + # Use list_dir_with_types to avoid N+1 is_dir calls + entries_with_types = await file_op.list_dir_with_types(current_path) + + # Separate directories and files + dir_entries = [(name, True) for name, is_dir in entries_with_types if is_dir] + file_entries = [(name, False) for name, is_dir in entries_with_types if not is_dir] # Process directories first - for name in dir_entries: + for name, _ in dir_entries: entry_path = f"{current_path}/{name}" if current_path != "." 
else name flat_path = f"{path_prefix}{name}" if path_prefix else name @@ -115,7 +109,7 @@ async def _collect_paths(current_path: str, current_depth: int, path_prefix: str result.extend(await _collect_paths(entry_path, current_depth + 1, f"{flat_path}/")) # Then files - for name in file_entries: + for name, _ in file_entries: flat_path = f"{path_prefix}{name}" if path_prefix else name should_skip, _ = _should_skip_hidden_item(name, False, skip_dirs) diff --git a/pytest.ini b/pytest.ini index 9ac67a5..daa752a 100644 --- a/pytest.ini +++ b/pytest.ini @@ -1,4 +1,4 @@ # pytest.ini [pytest] asyncio_mode = auto -testpaths = ["tests"] +testpaths = tests diff --git a/tests/test_file_operator.py b/tests/test_file_operator.py index 4be36f9..3b2ddeb 100644 --- a/tests/test_file_operator.py +++ b/tests/test_file_operator.py @@ -1,8 +1,12 @@ """Tests for LocalTmpFileOperator and FileOperator.""" +from collections.abc import AsyncIterator from pathlib import Path from agent_environment import ( + DEFAULT_CHUNK_SIZE, + FileOperator, + FileStat, LocalTmpFileOperator, ) @@ -105,6 +109,26 @@ async def test_local_tmp_file_operator_list_dir(tmp_path: Path) -> None: assert sorted(entries) == ["a.txt", "b.txt", "subdir"] +async def test_local_tmp_file_operator_list_dir_with_types(tmp_path: Path) -> None: + """LocalTmpFileOperator should list directory with type info.""" + op = LocalTmpFileOperator(tmp_path) + + await op.write_file("file1.txt", "content") + await op.write_file("file2.txt", "content") + await op.mkdir("dir1") + await op.mkdir("dir2") + + entries = await op.list_dir_with_types(".") + + # Should be sorted alphabetically + assert entries == [ + ("dir1", True), + ("dir2", True), + ("file1.txt", False), + ("file2.txt", False), + ] + + async def test_local_tmp_file_operator_move(tmp_path: Path) -> None: """LocalTmpFileOperator should support moving files.""" op = LocalTmpFileOperator(tmp_path) @@ -207,6 +231,24 @@ async def test_local_tmp_operator_read_bytes_with_offset(tmp_path: Path) -> None assert content == b"3456" +async def test_local_tmp_operator_read_bytes_seek_optimization(tmp_path: Path) -> None: + """Partial read should use seek instead of reading entire file.""" + op = LocalTmpFileOperator(tmp_path) + + # Create a larger file + large_content = b"x" * 100000 + b"TARGET" + b"y" * 100000 # ~200KB + await op.write_file("large.bin", large_content) + + # Read just the TARGET portion using offset/length + # This should only read 6 bytes, not the entire 200KB + content = await op.read_bytes("large.bin", offset=100000, length=6) + assert content == b"TARGET" + + # Verify offset-only read + content = await op.read_bytes("large.bin", offset=100000) + assert content == b"TARGET" + b"y" * 100000 + + async def test_local_tmp_operator_write_bytes(tmp_path: Path) -> None: """LocalTmpFileOperator should support writing bytes.""" op = LocalTmpFileOperator(tmp_path) @@ -245,3 +287,405 @@ async def test_local_tmp_operator_delete_dir(tmp_path: Path) -> None: await op.delete("empty_dir") assert not await op.exists("empty_dir") + + +# --- Streaming tests --- + + +async def test_local_tmp_operator_read_bytes_stream(tmp_path: Path) -> None: + """LocalTmpFileOperator should support streaming read.""" + op = LocalTmpFileOperator(tmp_path) + + # Create a file with known content + content = b"0123456789" * 100 # 1000 bytes + await op.write_file("test.bin", content) + + # Read as stream + stream = op.read_bytes_stream("test.bin", chunk_size=256) + chunks = [] + async for chunk in stream: + chunks.append(chunk) + + # 
Verify content + result = b"".join(chunks) + assert result == content + + # Verify chunking happened (with 256 byte chunks, should have multiple chunks) + assert len(chunks) > 1 + + +async def test_local_tmp_operator_read_bytes_stream_small_file(tmp_path: Path) -> None: + """LocalTmpFileOperator streaming should work for small files.""" + op = LocalTmpFileOperator(tmp_path) + + content = b"small content" + await op.write_file("small.bin", content) + + stream = op.read_bytes_stream("small.bin") + chunks = [] + async for chunk in stream: + chunks.append(chunk) + + assert b"".join(chunks) == content + + +async def test_local_tmp_operator_write_bytes_stream(tmp_path: Path) -> None: + """LocalTmpFileOperator should support streaming write.""" + op = LocalTmpFileOperator(tmp_path) + + # Create an async generator of chunks + async def chunk_generator(): + yield b"chunk1" + yield b"chunk2" + yield b"chunk3" + + await op.write_bytes_stream("test.bin", chunk_generator()) + + # Verify content + content = await op.read_bytes("test.bin") + assert content == b"chunk1chunk2chunk3" + + +async def test_local_tmp_operator_stream_roundtrip(tmp_path: Path) -> None: + """Streaming read/write should preserve data integrity.""" + op = LocalTmpFileOperator(tmp_path) + + # Create a larger file to test proper chunking + original = bytes(range(256)) * 1000 # 256KB + await op.write_file("original.bin", original) + + # Stream read from original + stream = op.read_bytes_stream("original.bin", chunk_size=4096) + + # Stream write to copy + await op.write_bytes_stream("copy.bin", stream) + + # Verify copy matches original + copy = await op.read_bytes("copy.bin") + assert copy == original + + +async def test_local_tmp_operator_stream_default_chunk_size(tmp_path: Path) -> None: + """Default chunk size should be DEFAULT_CHUNK_SIZE.""" + op = LocalTmpFileOperator(tmp_path) + + # Create a file larger than default chunk size + content = b"x" * (DEFAULT_CHUNK_SIZE + 1000) + await op.write_file("large.bin", content) + + stream = op.read_bytes_stream("large.bin") + chunks = [] + async for chunk in stream: + chunks.append(chunk) + + # Should have at least 2 chunks with default chunk size + assert len(chunks) >= 2 + assert b"".join(chunks) == content + + +# --- Cross-boundary streaming tests --- + + +class LocalFileOperator(FileOperator): + """A local filesystem FileOperator for testing cross-boundary operations.""" + + def __init__(self, default_path: Path, tmp_dir: Path, default_chunk_size: int = DEFAULT_CHUNK_SIZE) -> None: + super().__init__( + default_path=default_path, + allowed_paths=[default_path, tmp_dir], + tmp_dir=tmp_dir, + default_chunk_size=default_chunk_size, + ) + + async def _read_file_impl( + self, path: str, *, encoding: str = "utf-8", offset: int = 0, length: int | None = None + ) -> str: + import anyio + + resolved = self._default_path / path + content = await anyio.Path(resolved).read_text(encoding=encoding) + if offset > 0 or length is not None: + end = None if length is None else offset + length + content = content[offset:end] + return content + + async def _read_bytes_impl(self, path: str, *, offset: int = 0, length: int | None = None) -> bytes: + import anyio + + resolved = self._default_path / path + content = await anyio.Path(resolved).read_bytes() + if offset > 0 or length is not None: + end = None if length is None else offset + length + content = content[offset:end] + return content + + async def _write_file_impl(self, path: str, content: str | bytes, *, encoding: str = "utf-8") -> None: + import anyio 
+ + resolved = self._default_path / path + apath = anyio.Path(resolved) + if isinstance(content, bytes): + await apath.write_bytes(content) + else: + await apath.write_text(content, encoding=encoding) + + async def _append_file_impl(self, path: str, content: str | bytes, *, encoding: str = "utf-8") -> None: + import anyio + + resolved = self._default_path / path + existing = await anyio.Path(resolved).read_bytes() if await anyio.Path(resolved).exists() else b"" + new_content = existing + (content if isinstance(content, bytes) else content.encode(encoding)) + await anyio.Path(resolved).write_bytes(new_content) + + async def _delete_impl(self, path: str) -> None: + import anyio + + resolved = self._default_path / path + apath = anyio.Path(resolved) + if await apath.is_dir(): + await apath.rmdir() + else: + await apath.unlink() + + async def _list_dir_impl(self, path: str) -> list[str]: + import anyio + + resolved = self._default_path / path + entries = [entry.name async for entry in anyio.Path(resolved).iterdir()] + return sorted(entries) + + async def _exists_impl(self, path: str) -> bool: + import anyio + + return await anyio.Path(self._default_path / path).exists() + + async def _is_file_impl(self, path: str) -> bool: + import anyio + + return await anyio.Path(self._default_path / path).is_file() + + async def _is_dir_impl(self, path: str) -> bool: + import anyio + + return await anyio.Path(self._default_path / path).is_dir() + + async def _mkdir_impl(self, path: str, *, parents: bool = False) -> None: + import anyio + + await anyio.Path(self._default_path / path).mkdir(parents=parents, exist_ok=True) + + async def _move_impl(self, src: str, dst: str) -> None: + import shutil + + import anyio + + src_resolved = self._default_path / src + dst_resolved = self._default_path / dst + await anyio.to_thread.run_sync(shutil.move, src_resolved, dst_resolved) + + async def _copy_impl(self, src: str, dst: str) -> None: + import shutil + + import anyio + + src_resolved = self._default_path / src + dst_resolved = self._default_path / dst + await anyio.to_thread.run_sync(shutil.copy2, src_resolved, dst_resolved) + + async def _stat_impl(self, path: str) -> FileStat: + import anyio + + resolved = self._default_path / path + st = await anyio.Path(resolved).stat() + return FileStat( + size=st.st_size, + mtime=st.st_mtime, + is_file=await anyio.Path(resolved).is_file(), + is_dir=await anyio.Path(resolved).is_dir(), + ) + + async def _glob_impl(self, pattern: str) -> list[str]: + matches = [] + for p in self._default_path.glob(pattern): + try: + rel = p.relative_to(self._default_path) + matches.append(str(rel)) + except ValueError: + matches.append(str(p)) + return sorted(matches) + + # Override streaming for true streaming behavior + async def _read_bytes_stream_impl( + self, + path: str, + *, + chunk_size: int = DEFAULT_CHUNK_SIZE, + ) -> AsyncIterator[bytes]: + import anyio + + resolved = self._default_path / path + async with await anyio.open_file(resolved, "rb") as f: + while True: + chunk = await f.read(chunk_size) + if not chunk: + break + yield chunk + + async def _write_bytes_stream_impl( + self, + path: str, + stream: AsyncIterator[bytes], + ) -> None: + import anyio + + resolved = self._default_path / path + async with await anyio.open_file(resolved, "wb") as f: + async for chunk in stream: + await f.write(chunk) + + +async def test_cross_boundary_copy_from_main_to_tmp(tmp_path: Path) -> None: + """Copy from main filesystem to tmp should use streaming.""" + main_dir = tmp_path / "main" + tmp_dir 
= tmp_path / "tmp" + main_dir.mkdir() + tmp_dir.mkdir() + + op = LocalFileOperator(main_dir, tmp_dir) + + # Create a large file in main dir + content = b"x" * (DEFAULT_CHUNK_SIZE * 3) # ~192KB + await op.write_file("source.bin", content) + + # Copy to tmp (cross-boundary) + await op.copy("source.bin", str(tmp_dir / "dest.bin")) + + # Verify content in tmp + result = await op.read_bytes(str(tmp_dir / "dest.bin")) + assert result == content + + # Original should still exist + assert await op.exists("source.bin") + + +async def test_cross_boundary_copy_from_tmp_to_main(tmp_path: Path) -> None: + """Copy from tmp to main filesystem should use streaming.""" + main_dir = tmp_path / "main" + tmp_dir = tmp_path / "tmp" + main_dir.mkdir() + tmp_dir.mkdir() + + op = LocalFileOperator(main_dir, tmp_dir) + + # Create a large file in tmp dir + content = b"y" * (DEFAULT_CHUNK_SIZE * 2) # ~128KB + await op._tmp_file_operator.write_file("source.bin", content) + + # Copy to main (cross-boundary) + await op.copy(str(tmp_dir / "source.bin"), "dest.bin") + + # Verify content in main + result = await op.read_bytes("dest.bin") + assert result == content + + # Original should still exist in tmp + assert await op._tmp_file_operator.exists("source.bin") + + +async def test_cross_boundary_move_from_main_to_tmp(tmp_path: Path) -> None: + """Move from main filesystem to tmp should use streaming.""" + main_dir = tmp_path / "main" + tmp_dir = tmp_path / "tmp" + main_dir.mkdir() + tmp_dir.mkdir() + + op = LocalFileOperator(main_dir, tmp_dir) + + # Create a file in main dir + content = b"move_test" * 10000 # ~90KB + await op.write_file("source.bin", content) + + # Move to tmp (cross-boundary) + await op.move("source.bin", str(tmp_dir / "dest.bin")) + + # Verify content in tmp + result = await op.read_bytes(str(tmp_dir / "dest.bin")) + assert result == content + + # Original should be deleted + assert not await op.exists("source.bin") + + +async def test_cross_boundary_move_from_tmp_to_main(tmp_path: Path) -> None: + """Move from tmp to main filesystem should use streaming.""" + main_dir = tmp_path / "main" + tmp_dir = tmp_path / "tmp" + main_dir.mkdir() + tmp_dir.mkdir() + + op = LocalFileOperator(main_dir, tmp_dir) + + # Create a file in tmp dir + content = b"move_test_2" * 10000 + await op._tmp_file_operator.write_file("source.bin", content) + + # Move to main (cross-boundary) + await op.move(str(tmp_dir / "source.bin"), "dest.bin") + + # Verify content in main + result = await op.read_bytes("dest.bin") + assert result == content + + # Original should be deleted from tmp + assert not await op._tmp_file_operator.exists("source.bin") + + +async def test_file_operator_default_chunk_size(tmp_path: Path) -> None: + """FileOperator should use default_chunk_size in cross-boundary operations.""" + main_dir = tmp_path / "main" + tmp_dir = tmp_path / "tmp" + main_dir.mkdir() + tmp_dir.mkdir() + + # Use a custom small chunk size + custom_chunk_size = 256 + + # Track chunks read from main operator + chunks_read: list[int] = [] + + class TrackedFileOperator(LocalFileOperator): + """FileOperator that tracks chunk sizes used in streaming.""" + + async def _read_bytes_stream_impl( + self, + path: str, + *, + chunk_size: int = DEFAULT_CHUNK_SIZE, + ) -> AsyncIterator[bytes]: + import anyio + + chunks_read.append(chunk_size) + resolved = self._default_path / path + async with await anyio.open_file(resolved, "rb") as f: + while True: + chunk = await f.read(chunk_size) + if not chunk: + break + yield chunk + + op = 
TrackedFileOperator(main_dir, tmp_dir, default_chunk_size=custom_chunk_size) + + # Create a file in main dir + content = b"x" * 1024 # 1KB file + await op.write_file("source.bin", content) + + # Copy to tmp (cross-boundary) - should use custom chunk size + await op.copy("source.bin", str(tmp_dir / "dest.bin")) + + # Verify chunk size was used + assert len(chunks_read) > 0 + assert chunks_read[-1] == custom_chunk_size + + # Verify content is correct + result = await op.read_bytes(str(tmp_dir / "dest.bin")) + assert result == content diff --git a/tests/test_resources.py b/tests/test_resources.py index 45ef957..25cf630 100644 --- a/tests/test_resources.py +++ b/tests/test_resources.py @@ -584,3 +584,49 @@ async def test_registry_contains() -> None: assert "test" not in env.resources env.resources.set("test", SimpleResource()) assert "test" in env.resources + + +async def test_registry_close_all_parallel() -> None: + """close_all with parallel=True should close resources concurrently.""" + async with MockEnvironment() as env: + resource1 = SimpleResource() + resource2 = SimpleResource() + resource3 = MinimalBaseResource() + + env.resources.set("r1", resource1) + env.resources.set("r2", resource2) + env.resources.set("r3", resource3) + + # Close with parallel=True + await env.resources.close_all(parallel=True) + + assert resource1.closed + assert resource2.closed + assert resource3.closed + assert len(env.resources) == 0 + + +async def test_registry_close_all_parallel_with_exception() -> None: + """Parallel close should continue even if a resource fails.""" + + class FailingResource: + def __init__(self) -> None: + self.closed = False + + def close(self) -> None: + raise RuntimeError("Failed to close") + + async with MockEnvironment() as env: + good1 = SimpleResource() + bad = FailingResource() + good2 = SimpleResource() + + env.resources.set("good1", good1) + env.resources.set("bad", bad) + env.resources.set("good2", good2) + + # Should not raise + await env.resources.close_all(parallel=True) + + assert good1.closed + assert good2.closed
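
Usage sketch (illustrative, not part of the diff): the snippet below shows how a caller could drive the new streaming API introduced above. `LocalTmpFileOperator` and `DEFAULT_CHUNK_SIZE` come from the exports added in `agent_environment/__init__.py`; the temporary directory and file names are hypothetical.

```python
# Illustrative sketch only: exercises the streaming API added in this diff.
# The tmp directory and file names here are hypothetical.
import asyncio
import tempfile
from pathlib import Path

from agent_environment import DEFAULT_CHUNK_SIZE, LocalTmpFileOperator


async def main() -> None:
    op = LocalTmpFileOperator(Path(tempfile.mkdtemp(prefix="stream_demo_")))

    # Write a file a few chunks larger than the default chunk size.
    await op.write_file("big.bin", b"x" * (DEFAULT_CHUNK_SIZE * 3))

    # read_bytes_stream returns an async iterator directly (no await),
    # so the file is never held in memory as a single bytes object.
    total = 0
    async for chunk in op.read_bytes_stream("big.bin", chunk_size=DEFAULT_CHUNK_SIZE):
        total += len(chunk)
    assert total == DEFAULT_CHUNK_SIZE * 3

    # write_bytes_stream consumes any async iterator of bytes,
    # so a read stream can be piped straight into a copy.
    await op.write_bytes_stream("copy.bin", op.read_bytes_stream("big.bin"))
    assert await op.read_bytes("copy.bin") == b"x" * (DEFAULT_CHUNK_SIZE * 3)


asyncio.run(main())
```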