Skip to content

Commit f37439f

Browse files
committed
feat: integrate aiofiles for asynchronous file handling and update dependencies
1 parent 8ffa25a commit f37439f

File tree

6 files changed

+2505
-48
lines changed

6 files changed

+2505
-48
lines changed

decart_sdk/client.py

Lines changed: 40 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from typing import Any, Optional
2+
import aiohttp
23
from pydantic import BaseModel, Field, field_validator, ValidationError
34
from .errors import InvalidAPIKeyError, InvalidBaseURLError, InvalidInputError
45
from .models import ModelDefinition
@@ -39,6 +40,27 @@ def __init__(self, api_key: str, base_url: str = "https://api.decart.ai") -> Non
3940

4041
self.api_key = api_key
4142
self.base_url = base_url
43+
self._session: Optional[aiohttp.ClientSession] = None
44+
45+
async def _get_session(self) -> aiohttp.ClientSession:
46+
"""Get or create the aiohttp session."""
47+
if self._session is None or self._session.closed:
48+
timeout = aiohttp.ClientTimeout(total=300)
49+
self._session = aiohttp.ClientSession(timeout=timeout)
50+
return self._session
51+
52+
async def close(self) -> None:
53+
"""Close the HTTP session and cleanup resources."""
54+
if self._session and not self._session.closed:
55+
await self._session.close()
56+
57+
async def __aenter__(self):
58+
"""Async context manager entry."""
59+
return self
60+
61+
async def __aexit__(self, exc_type, exc_val, exc_tb):
62+
"""Async context manager exit."""
63+
await self.close()
4264

4365
async def process(self, options: dict[str, Any]) -> bytes:
4466
"""
@@ -62,29 +84,33 @@ async def process(self, options: dict[str, Any]) -> bytes:
6284

6385
inputs = {k: v for k, v in options.items() if k not in ("model", "cancel_token")}
6486

65-
# Separate file inputs from other inputs
66-
file_fields = {"data", "start", "end"}
67-
file_inputs = {k: v for k, v in inputs.items() if k in file_fields}
68-
non_file_inputs = {k: v for k, v in inputs.items() if k not in file_fields}
87+
# File fields that need special handling (not validated by Pydantic)
88+
FILE_FIELDS = {"data", "start", "end"}
89+
90+
# Separate file inputs from regular inputs
91+
file_inputs = {k: v for k, v in inputs.items() if k in FILE_FIELDS}
92+
non_file_inputs = {k: v for k, v in inputs.items() if k not in FILE_FIELDS}
6993

70-
# Validate inputs using model's schema
71-
validation_inputs = non_file_inputs.copy()
72-
for field in file_fields:
73-
if field in file_inputs:
74-
validation_inputs[field] = b"" # Placeholder for validation
94+
# Validate non-file inputs and create placeholder for file fields
95+
validation_inputs = {
96+
**non_file_inputs,
97+
**{k: b"" for k in file_inputs.keys()} # Placeholder bytes for validation
98+
}
7599

76100
try:
77101
validated_inputs = model.input_schema(**validation_inputs)
78102
except ValidationError as e:
79103
raise InvalidInputError(f"Invalid inputs for {model.name}: {str(e)}") from e
80104

81-
# Merge validated inputs with file inputs
82-
processed_inputs = validated_inputs.model_dump(exclude_none=True)
83-
for field in file_fields:
84-
if field in file_inputs:
85-
processed_inputs[field] = file_inputs[field]
105+
# Build final inputs: validated non-file inputs + original file inputs
106+
processed_inputs = {
107+
**validated_inputs.model_dump(exclude_none=True),
108+
**file_inputs # Override placeholders with actual file data
109+
}
86110

111+
session = await self._get_session()
87112
response = await send_request(
113+
session=session,
88114
base_url=self.base_url,
89115
api_key=self.api_key,
90116
model=model,

decart_sdk/process/request.py

Lines changed: 73 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -1,39 +1,83 @@
11
import aiohttp
2+
import aiofiles
23
import asyncio
4+
from pathlib import Path
35
from typing import Any, Optional
46
from ..types import FileInput
57
from ..models import ModelDefinition
68
from ..errors import InvalidInputError, ProcessingError
79

810

9-
async def file_input_to_bytes(input_data: FileInput) -> tuple[bytes, str]:
11+
async def file_input_to_bytes(
12+
input_data: FileInput, session: aiohttp.ClientSession
13+
) -> tuple[bytes, str]:
14+
"""Convert various file input types to bytes asynchronously.
15+
16+
Args:
17+
input_data: The file input (bytes, Path, str, or file-like object)
18+
session: Reusable aiohttp session for URL fetching
19+
20+
Returns:
21+
Tuple of (content bytes, content type)
22+
23+
Raises:
24+
InvalidInputError: If input is invalid or processing fails
25+
"""
26+
1027
if isinstance(input_data, bytes):
1128
return input_data, "application/octet-stream"
1229

13-
if hasattr(input_data, "read"):
30+
if isinstance(input_data, Path):
31+
# Async file reading with aiofiles
32+
try:
33+
async with aiofiles.open(input_data, mode="rb") as f:
34+
content = await f.read()
35+
return content, "application/octet-stream"
36+
except FileNotFoundError:
37+
raise InvalidInputError(f"File not found: {input_data}")
38+
except Exception as e:
39+
raise InvalidInputError(f"Failed to read file {input_data}: {str(e)}")
40+
41+
if isinstance(input_data, str):
42+
# Check if it's a file path
43+
path = Path(input_data)
44+
if path.exists():
45+
try:
46+
async with aiofiles.open(path, mode="rb") as f:
47+
content = await f.read()
48+
return content, "application/octet-stream"
49+
except Exception as e:
50+
raise InvalidInputError(f"Failed to read file {input_data}: {str(e)}")
51+
52+
# Otherwise treat as URL
53+
if not input_data.startswith(("http://", "https://")):
54+
raise InvalidInputError(
55+
f"Input must be a URL (http:// or https://) or existing file path: {input_data}"
56+
)
57+
58+
# Use the provided session instead of creating a new one
59+
async with session.get(input_data) as response:
60+
if not response.ok:
61+
raise InvalidInputError(
62+
f"Failed to fetch file from URL: {response.status}"
63+
)
64+
content = await response.read()
65+
content_type = response.headers.get("Content-Type", "application/octet-stream")
66+
return content, content_type
67+
68+
from ..types import HasRead
69+
if isinstance(input_data, HasRead):
70+
# Sync file-like objects (for backwards compatibility)
1471
content = await asyncio.to_thread(input_data.read)
1572
if isinstance(content, str):
1673
content = content.encode()
1774
return content, "application/octet-stream"
1875

19-
if isinstance(input_data, str):
20-
if not input_data.startswith(("http://", "https://")):
21-
raise InvalidInputError("URL must start with http:// or https://")
22-
23-
async with aiohttp.ClientSession() as session:
24-
async with session.get(input_data) as response:
25-
if not response.ok:
26-
raise InvalidInputError(
27-
f"Failed to fetch file from URL: {response.status}"
28-
)
29-
content = await response.read()
30-
content_type = response.headers.get("Content-Type", "application/octet-stream")
31-
return content, content_type
32-
33-
raise InvalidInputError("Invalid file input type")
76+
raise InvalidInputError(f"Invalid file input type: {type(input_data)}")
3477

3578

3679
async def send_request(
80+
session: aiohttp.ClientSession,
3781
base_url: str,
3882
api_key: str,
3983
model: ModelDefinition,
@@ -45,28 +89,25 @@ async def send_request(
4589
for key, value in inputs.items():
4690
if value is not None:
4791
if key in ("data", "start", "end"):
48-
content, content_type = await file_input_to_bytes(value)
92+
content, content_type = await file_input_to_bytes(value, session)
4993
form_data.add_field(key, content, content_type=content_type)
5094
else:
5195
form_data.add_field(key, str(value))
5296

5397
endpoint = f"{base_url}{model.url_path}"
5498

55-
timeout = aiohttp.ClientTimeout(total=300)
56-
5799
async def make_request() -> bytes:
58-
async with aiohttp.ClientSession(timeout=timeout) as session:
59-
async with session.post(
60-
endpoint,
61-
headers={"X-API-KEY": api_key},
62-
data=form_data,
63-
) as response:
64-
if not response.ok:
65-
error_text = await response.text()
66-
raise ProcessingError(
67-
f"Processing failed: {response.status} - {error_text}"
68-
)
69-
return await response.read()
100+
async with session.post(
101+
endpoint,
102+
headers={"X-API-KEY": api_key},
103+
data=form_data,
104+
) as response:
105+
if not response.ok:
106+
error_text = await response.text()
107+
raise ProcessingError(
108+
f"Processing failed: {response.status} - {error_text}"
109+
)
110+
return await response.read()
70111

71112
if cancel_token:
72113
request_task = asyncio.create_task(make_request())

decart_sdk/types.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,16 @@
1-
from typing import BinaryIO, Union, Optional
1+
from typing import BinaryIO, Union, Optional, Protocol, runtime_checkable
2+
from pathlib import Path
23
from pydantic import BaseModel, Field
34

45

5-
FileInput = Union[BinaryIO, bytes, str]
6+
@runtime_checkable
7+
class HasRead(Protocol):
8+
"""Protocol for file-like objects with a read method."""
9+
def read(self) -> Union[bytes, str]:
10+
...
11+
12+
13+
FileInput = Union[HasRead, bytes, str, Path]
614

715

816
class Prompt(BaseModel):

0 commit comments

Comments
 (0)