Skip to content

Commit 0930245

Browse files
committed
Add type hints and a test
1 parent 5818fb0 commit 0930245

File tree

2 files changed

+122
-4
lines changed

2 files changed

+122
-4
lines changed

scrapy_zyte_api/_session.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -982,23 +982,27 @@ class LocationSessionConfig(SessionConfig):
982982
as a parameter.
983983
"""
984984

985-
def params(self, request):
985+
def params(self, request: Request) -> Dict[str, Any]:
986986
if not (location := self.location(request)):
987987
return super().params(request)
988988
return self.location_params(request, location)
989989

990-
def check(self, response, request):
990+
def check(self, response: Response, request: Request) -> bool:
991991
if not (location := self.location(request)):
992992
return super().check(response, request)
993993
return self.location_check(response, request, location)
994994

995-
def location_params(self, request, location):
995+
def location_params(
996+
self, request: Request, location: Dict[str, Any]
997+
) -> Dict[str, Any]:
996998
"""Like :class:`SessionConfig.params
997999
<scrapy_zyte_api.SessionConfig.params>`, but it is only called when a
9981000
location it set, and gets that *location* as a parameter."""
9991001
return super().params(request)
10001002

1001-
def location_check(self, response, request, location):
1003+
def location_check(
1004+
self, response: Response, request: Request, location: Dict[str, Any]
1005+
) -> bool:
10021006
"""Like :class:`SessionConfig.check
10031007
<scrapy_zyte_api.SessionConfig.check>`, but it is only called when a
10041008
location it set, and gets that *location* as a parameter."""

tests/test_sessions.py

Lines changed: 114 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
from scrapy_zyte_api import (
1818
SESSION_AGGRESSIVE_RETRY_POLICY,
1919
SESSION_DEFAULT_RETRY_POLICY,
20+
LocationSessionConfig,
2021
SessionConfig,
2122
is_session_init_request,
2223
session_config,
@@ -2080,6 +2081,119 @@ class CustomSessionConfig(SessionConfig):
20802081
pass
20812082

20822083

2084+
@ensureDeferred
2085+
async def test_location_session_config(mockserver):
2086+
pytest.importorskip("web_poet")
2087+
2088+
@session_config(
2089+
[
2090+
"postal-code-10001.example",
2091+
"postal-code-10001-fail.example",
2092+
"postal-code-10001-alternative.example",
2093+
]
2094+
)
2095+
class CustomSessionConfig(LocationSessionConfig):
2096+
2097+
def location_params(
2098+
self, request: Request, location: Dict[str, Any]
2099+
) -> Dict[str, Any]:
2100+
assert location == {"postalCode": "10002"}
2101+
return {
2102+
"actions": [
2103+
{
2104+
"action": "setLocation",
2105+
"address": {"postalCode": "10001"},
2106+
}
2107+
]
2108+
}
2109+
2110+
def location_check(
2111+
self, response: Response, request: Request, location: Dict[str, Any]
2112+
) -> bool:
2113+
assert location == {"postalCode": "10002"}
2114+
domain = urlparse_cached(request).netloc
2115+
return "fail" not in domain
2116+
2117+
def pool(self, request: Request) -> str:
2118+
domain = urlparse_cached(request).netloc
2119+
if domain == "postal-code-10001-alternative.example":
2120+
return "postal-code-10001.example"
2121+
return domain
2122+
2123+
settings = {
2124+
"RETRY_TIMES": 0,
2125+
"ZYTE_API_URL": mockserver.urljoin("/"),
2126+
"ZYTE_API_SESSION_ENABLED": True,
2127+
# We set a location to force the location-specific methods of the
2128+
# session config class to be called, but we set the wrong location so
2129+
# that the test would not pass were it not for our custom
2130+
# implementation which ignores the input location and instead sets the
2131+
# right one.
2132+
"ZYTE_API_SESSION_LOCATION": {"postalCode": "10002"},
2133+
"ZYTE_API_SESSION_MAX_BAD_INITS": 1,
2134+
}
2135+
2136+
class TestSpider(Spider):
2137+
name = "test"
2138+
start_urls = [
2139+
"https://postal-code-10001.example",
2140+
"https://postal-code-10001-alternative.example",
2141+
"https://postal-code-10001-fail.example",
2142+
]
2143+
2144+
def start_requests(self):
2145+
for url in self.start_urls:
2146+
yield Request(
2147+
url,
2148+
meta={
2149+
"zyte_api_automap": {
2150+
"actions": [
2151+
{
2152+
"action": "setLocation",
2153+
"address": {"postalCode": "10001"},
2154+
}
2155+
]
2156+
},
2157+
},
2158+
)
2159+
2160+
def parse(self, response):
2161+
pass
2162+
2163+
crawler = await get_crawler(settings, spider_cls=TestSpider, setup_engine=False)
2164+
await crawler.crawl()
2165+
2166+
session_stats = {
2167+
k: v
2168+
for k, v in crawler.stats.get_stats().items()
2169+
if k.startswith("scrapy-zyte-api/sessions")
2170+
}
2171+
assert session_stats == {
2172+
"scrapy-zyte-api/sessions/pools/postal-code-10001.example/init/check-passed": 2,
2173+
"scrapy-zyte-api/sessions/pools/postal-code-10001.example/use/check-passed": 2,
2174+
"scrapy-zyte-api/sessions/pools/postal-code-10001-fail.example/init/check-failed": 1,
2175+
}
2176+
2177+
# Clean up the session config registry, and check it, otherwise we could
2178+
# affect other tests.
2179+
2180+
session_config_registry.__init__() # type: ignore[misc]
2181+
2182+
crawler = await get_crawler(settings, spider_cls=TestSpider, setup_engine=False)
2183+
await crawler.crawl()
2184+
2185+
session_stats = {
2186+
k: v
2187+
for k, v in crawler.stats.get_stats().items()
2188+
if k.startswith("scrapy-zyte-api/sessions")
2189+
}
2190+
assert session_stats == {
2191+
"scrapy-zyte-api/sessions/pools/postal-code-10001.example/init/failed": 1,
2192+
"scrapy-zyte-api/sessions/pools/postal-code-10001-alternative.example/init/failed": 1,
2193+
"scrapy-zyte-api/sessions/pools/postal-code-10001-fail.example/init/failed": 1,
2194+
}
2195+
2196+
20832197
@ensureDeferred
20842198
async def test_session_refresh(mockserver):
20852199
"""If a response does not pass a session validity check, the session is

0 commit comments

Comments
 (0)