Skip to content

Commit

Permalink
feat: add better type hints for impersonate and session (#359)
Browse files Browse the repository at this point in the history
* feat: add better type hints for impersonate and session

* style: ruff

* style: mypy

* fix: requests __all__

* fix: circular import

* fix: circular import

* fix: test case

* feat: preserve browser enum

* refactor(requests): standardize default browser variables to uppercase
  • Loading branch information
vvanglro authored Jul 30, 2024
1 parent 0cf03fe commit 8691907
Show file tree
Hide file tree
Showing 13 changed files with 199 additions and 106 deletions.
10 changes: 9 additions & 1 deletion curl_cffi/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,5 +21,13 @@
# This line includes _wrapper.so into the wheel
from ._wrapper import ffi, lib
from .aio import AsyncCurl
from .const import CurlECode, CurlHttpVersion, CurlInfo, CurlMOpt, CurlOpt, CurlWsFlag, CurlSslVersion
from .const import (
CurlECode,
CurlHttpVersion,
CurlInfo,
CurlMOpt,
CurlOpt,
CurlSslVersion,
CurlWsFlag,
)
from .curl import Curl, CurlError, CurlMime
6 changes: 3 additions & 3 deletions curl_cffi/_asyncio_selector.py
Original file line number Diff line number Diff line change
Expand Up @@ -321,15 +321,15 @@ def close(self) -> None:
self._selector.close()
self._real_loop.close()

def add_reader(
def add_reader( # type: ignore
self,
fd: "_FileDescriptorLike",
callback: Callable[..., None],
*args: Any, # type: ignore
*args: Any,
) -> None:
return self._selector.add_reader(fd, callback, *args)

def add_writer(
def add_writer( # type: ignore
self,
fd: "_FileDescriptorLike",
callback: Callable[..., None],
Expand Down
4 changes: 2 additions & 2 deletions curl_cffi/curl.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,7 @@ def setopt(self, option: CurlOpt, value: Any) -> int:
10000: "char*",
20000: "void*",
30000: "int64_t*", # offset type
40000: "void*", # blob type
40000: "void*", # blob type
}
# print("option", option, "value", value)

Expand Down Expand Up @@ -195,7 +195,7 @@ def setopt(self, option: CurlOpt, value: Any) -> int:
if option == CurlOpt.POSTFIELDS:
self._body_handle = c_value
else:
raise NotImplementedError("Option unsupported: %s" % option)
raise NotImplementedError(f"Option unsupported: {option}")

if option == CurlOpt.HTTPHEADER:
for header in value:
Expand Down
11 changes: 7 additions & 4 deletions curl_cffi/requests/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
"Session",
"AsyncSession",
"BrowserType",
"BrowserTypeLiteral",
"CurlWsFlag",
"request",
"head",
Expand All @@ -20,21 +21,23 @@
"WebSocketError",
"WsCloseCode",
"ExtraFingerprints",
"CookieTypes",
"HeaderTypes",
"ProxySpec",
]

from functools import partial
from io import BytesIO
from typing import Callable, Dict, List, Optional, Tuple, Union


from ..const import CurlHttpVersion, CurlWsFlag
from ..curl import CurlMime
from .cookies import Cookies, CookieTypes
from .errors import RequestsError
from .headers import Headers, HeaderTypes
from .impersonate import ExtraFingerprints, ExtraFpDict
from .impersonate import BrowserType, BrowserTypeLiteral, ExtraFingerprints, ExtraFpDict
from .models import Request, Response
from .session import AsyncSession, BrowserType, ProxySpec, Session, ThreadType
from .session import AsyncSession, ProxySpec, Session, ThreadType
from .websockets import WebSocket, WebSocketError, WsCloseCode


Expand All @@ -58,7 +61,7 @@ def request(
referer: Optional[str] = None,
accept_encoding: Optional[str] = "gzip, deflate, br, zstd",
content_callback: Optional[Callable] = None,
impersonate: Optional[Union[str, BrowserType]] = None,
impersonate: Optional[BrowserTypeLiteral] = None,
ja3: Optional[str] = None,
akamai: Optional[str] = None,
extra_fp: Optional[Union[ExtraFingerprints, ExtraFpDict]] = None,
Expand Down
17 changes: 4 additions & 13 deletions curl_cffi/requests/cookies.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,9 +185,7 @@ def update_cookies_from_curl(self, morsels: List[CurlMorsel]):
self.jar.set_cookie(cookie)
self.jar.clear_expired_cookies()

def set(
self, name: str, value: str, domain: str = "", path: str = "/", secure=False
) -> None:
def set(self, name: str, value: str, domain: str = "", path: str = "/", secure=False) -> None:
"""
Set a cookie value by name. May optionally include domain and path.
"""
Expand Down Expand Up @@ -268,18 +266,14 @@ def get( # type: ignore
return default
return value

def get_dict(
self, domain: Optional[str] = None, path: Optional[str] = None
) -> dict:
def get_dict(self, domain: Optional[str] = None, path: Optional[str] = None) -> dict:
"""
Cookies with the same name on different domains may overwrite each other,
do NOT use this function as a method of serialization.
"""
ret = {}
for cookie in self.jar:
if (domain is None or cookie.name == domain) and (
path is None or cookie.path == path
):
if (domain is None or cookie.name == domain) and (path is None or cookie.path == path):
ret[cookie.name] = cookie.value
return ret

Expand Down Expand Up @@ -350,10 +344,7 @@ def __bool__(self) -> bool:

def __repr__(self) -> str:
cookies_repr = ", ".join(
[
f"<Cookie {cookie.name}={cookie.value} for {cookie.domain} />"
for cookie in self.jar
]
[f"<Cookie {cookie.name}={cookie.value} for {cookie.domain} />" for cookie in self.jar]
)

return f"<Cookies[{cookies_repr}]>"
92 changes: 65 additions & 27 deletions curl_cffi/requests/impersonate.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,63 @@
from dataclasses import dataclass
from typing import List, Literal, Optional, TypedDict
import warnings
from dataclasses import dataclass
from enum import Enum
from typing import List, Literal, Optional, TypedDict

from ..const import CurlSslVersion, CurlOpt
from ..const import CurlOpt, CurlSslVersion

BrowserTypeLiteral = Literal[
# Edge
"edge99",
"edge101",
# Chrome
"chrome99",
"chrome100",
"chrome101",
"chrome104",
"chrome107",
"chrome110",
"chrome116",
"chrome119",
"chrome120",
"chrome123",
"chrome124",
"chrome99_android",
# Safari
"safari15_3",
"safari15_5",
"safari17_0",
"safari17_2_ios",
# alias
"chrome",
"edge",
"safari",
"safari_ios",
"chrome_android",
]

class BrowserType(str, Enum):
DEFAULT_CHROME = "chrome124"
DEFAULT_EDGE = "edge101"
DEFAULT_SAFARI = "safari17_0"
DEFAULT_SAFARI_IOS = "safari17_2_ios"
DEFAULT_CHROME_ANDROID = "chrome99_android"


def normalize_browser_type(item):
if item == "chrome": # noqa: SIM116
return DEFAULT_CHROME
elif item == "edge":
return DEFAULT_EDGE
elif item == "safari":
return DEFAULT_SAFARI
elif item == "safari_ios":
return DEFAULT_SAFARI_IOS
elif item == "chrome_android":
return DEFAULT_CHROME_ANDROID
else:
return item


class BrowserType(str, Enum): # todo: remove in version 1.x
edge99 = "edge99"
edge101 = "edge101"
chrome99 = "chrome99"
Expand All @@ -26,25 +77,6 @@ class BrowserType(str, Enum):
safari17_0 = "safari17_0"
safari17_2_ios = "safari17_2_ios"

chrome = "chrome124"
safari = "safari17_0"
safari_ios = "safari17_2_ios"

@classmethod
def has(cls, item):
return item in cls.__members__

@classmethod
def normalize(cls, item):
if item == "chrome": # noqa: SIM116
return cls.chrome
elif item == "safari":
return cls.safari
elif item == "safari_ios":
return cls.safari_ios
else:
return item


@dataclass
class ExtraFingerprints:
Expand Down Expand Up @@ -214,10 +246,10 @@ class ExtraFpDict(TypedDict, total=False):
# 64251-64767:"Unassigned
64768: "ech_outer_extensions",
# 64769-65036:"Unassigned
65037:"encrypted_client_hello",
65037: "encrypted_client_hello",
# 65038-65279:"Unassigned
# 65280:"Reserved for Private Use
65281:"renegotiation_info",
65281: "renegotiation_info",
# 65282-65535:"Reserved for Private Use
}

Expand All @@ -243,7 +275,11 @@ def toggle_extension(curl, extension_id: int, enable: bool):
# compress certificate
elif extension_id == 27:
if enable:
warnings.warn("Cert compression setting to brotli, you had better specify which to use: zlib/brotli")
warnings.warn(
"Cert compression setting to brotli, "
"you had better specify which to use: zlib/brotli",
stacklevel=1,
)
curl.setopt(CurlOpt.SSL_CERT_COMPRESSION, "brotli")
else:
curl.setopt(CurlOpt.SSL_CERT_COMPRESSION, "")
Expand Down Expand Up @@ -277,4 +313,6 @@ def toggle_extension(curl, extension_id: int, enable: bool):
elif extension_id == 21:
pass
else:
raise NotImplementedError(f"This extension({extension_id}) can not be toggled for now, it may be updated later.")
raise NotImplementedError(
f"This extension({extension_id}) can not be toggled for now, it may be updated later."
)
18 changes: 10 additions & 8 deletions curl_cffi/requests/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,10 +153,11 @@ def iter_lines(self, chunk_size=None, decode_unicode=False, delimiter=None):
if pending is not None:
chunk = pending + chunk
lines = chunk.split(delimiter) if delimiter else chunk.splitlines()
if lines and lines[-1] and chunk and lines[-1][-1] == chunk[-1]:
pending = lines.pop()
else:
pending = None
pending = (
lines.pop()
if lines and lines[-1] and chunk and lines[-1][-1] == chunk[-1]
else None
)

yield from lines

Expand Down Expand Up @@ -214,10 +215,11 @@ async def aiter_lines(self, chunk_size=None, decode_unicode=False, delimiter=Non
if pending is not None:
chunk = pending + chunk
lines = chunk.split(delimiter) if delimiter else chunk.splitlines()
if lines and lines[-1] and chunk and lines[-1][-1] == chunk[-1]:
pending = lines.pop()
else:
pending = None
pending = (
lines.pop()
if lines and lines[-1] and chunk and lines[-1][-1] == chunk[-1]
else None
)

for line in lines:
yield line
Expand Down
Loading

0 comments on commit 8691907

Please sign in to comment.