From dac061c07d21f945568bfd3572be7652929f8c09 Mon Sep 17 00:00:00 2001
From: Dmytro Krasun <krasun.net@gmail.com>
Date: Wed, 12 Jun 2024 16:22:58 +0300
Subject: [PATCH] =?UTF-8?q?Improve=20error=20handling=E2=80=94add=20error?=
 =?UTF-8?q?=20codes=20and=20HTTP=20status=20code=20to=20exceptions?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 README.md                |  42 +++-
 pyproject.toml           |  18 +-
 src/screenshotone/sdk.py | 400 +++++++++++++++++++++------------------
 3 files changed, 262 insertions(+), 198 deletions(-)

diff --git a/README.md b/README.md
index a214158..edffde2 100644
--- a/README.md
+++ b/README.md
@@ -2,7 +2,7 @@
 
 An official Python SDK for [ScreenshotOne.com API](https://screenshotone.com) to take screenshots of URLs, render HTML as images and PDF.
 
-It takes minutes to start taking screenshots. Just [sign up](https://screenshotone.com/) to get access and secret keys, import the client, and you are ready to go. 
+It takes minutes to start taking screenshots. Just [sign up](https://screenshotone.com/) to get access and secret keys, import the client, and you are ready to go.
 
 The SDK client is synchronized with the latest [screenshot API options](https://screenshotone.com/docs/options/).
 
@@ -14,12 +14,13 @@ pip install screenshotone
 
 ## Usage
 
-Generate a screenshot URL without executing the request. Or download the screenshot. It is up to you: 
+Generate a screenshot URL without executing the request. Or download the screenshot. It is up to you:
+
 ```python
 import shutil
 from screenshotone import Client, TakeOptions
 
-# create API client 
+# create API client
 client = Client('<your access key>', '<your secret key>')
 
 # set up options
@@ -42,10 +43,37 @@ with open('example.png', 'wb') as result_file:
     shutil.copyfileobj(image, result_file)
 ```
 
-## Release 
+### How to handle errors
+
+Read about [how to handle the ScreenshotOne API errors](https://screenshotone.com/docs/guides/how-to-handle-api-errors/), and that's how you can get the HTTP status code and the error code of the request:
+
+```python
+try:
+    # ...
+    # render a screenshot and download the image as stream
+    image = client.take(options)
+    # ...
+except InvalidRequestException as e:
+    print(f"Invalid request: {e}")
+    if e.http_status_code:
+        print(f"HTTP Status Code: {e.http_status_code}")
+    if e.error_code:
+        print(f"Error Code: {e.error_code}")
+except APIErrorException as e:
+    print(f"API Error: {e}")
+    if e.http_status_code:
+        print(f"HTTP Status Code: {e.http_status_code}")
+    if e.error_code:
+        print(f"Error Code: {e.error_code}")
+except Exception as e:
+    # handle any other exceptions
+    print(f"An unexpected error occurred: {e}")
+```
+
+## Release
 
-[Github Actions](https://github.com/screenshotone/pythonsdk/blob/main/.github/workflows/pypi-release.yml) is used to automate the release process and publishing to PyPI. Update the library version in `pyproject.toml` and [create a new release](https://github.com/screenshotone/pythonsdk/releases/new) to launch the `publish` workflow. 
+[Github Actions](https://github.com/screenshotone/pythonsdk/blob/main/.github/workflows/pypi-release.yml) is used to automate the release process and publishing to PyPI. Update the library version in `pyproject.toml` and [create a new release](https://github.com/screenshotone/pythonsdk/releases/new) to launch the `publish` workflow.
 
-## License 
+## License
 
-`screenshotone/pythonsdk` is released under [the MIT license](LICENSE).
\ No newline at end of file
+`screenshotone/pythonsdk` is released under [the MIT license](LICENSE).
diff --git a/pyproject.toml b/pyproject.toml
index 4898732..f4131aa 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -3,21 +3,17 @@ requires = ["hatchling"]
 build-backend = "hatchling.build"
 [project]
 name = "screenshotone"
-version = "0.0.12"
-authors = [
-  { name="Dmytro Krasun", email="support@screenshotone.com" },
-]
-dependencies = [
-    "requests >= 2.28.1"
-]
+version = "0.0.13"
+authors = [{ name = "Dmytro Krasun", email = "support@screenshotone.com" }]
+dependencies = ["requests >= 2.28.1"]
 description = "A Python SDK for ScreenshotOne.com API to take screenshots of URLs, render HTML as images and PDF"
 readme = "README.md"
 requires-python = ">=3.7"
 classifiers = [
-    "Programming Language :: Python :: 3",
-    "License :: OSI Approved :: MIT License",
-    "Operating System :: OS Independent"
+  "Programming Language :: Python :: 3",
+  "License :: OSI Approved :: MIT License",
+  "Operating System :: OS Independent",
 ]
 [project.urls]
 "Homepage" = "https://screenshotone.com/"
-"Docs" = "https://screenshotone.com/docs/getting-started/"
\ No newline at end of file
+"Docs" = "https://screenshotone.com/docs/getting-started/"
diff --git a/src/screenshotone/sdk.py b/src/screenshotone/sdk.py
index f895f6d..e66bcbc 100644
--- a/src/screenshotone/sdk.py
+++ b/src/screenshotone/sdk.py
@@ -6,368 +6,382 @@
 import requests
 import json
 
-API_BASE_URL = 'https://api.screenshotone.com'
-API_TAKE_PATH = '/take'
+API_BASE_URL = "https://api.screenshotone.com"
+API_TAKE_PATH = "/take"
+
 
 class InvalidRequestException(Exception):
-    pass
+    def __init__(self, message, http_status_code=None, error_code=None):
+        self.http_status_code = http_status_code
+        self.error_code = error_code
+
+        super().__init__(message)
+
+
 class APIErrorException(Exception):
-    pass
 
-class TakeOptions: 
-    options = OrderedDict()    
+    def __init__(self, message, http_status_code=None, error_code=None):
+        self.http_status_code = http_status_code
+        self.error_code = error_code
+
+        super().__init__(message)
+
+
+class TakeOptions:
+    options = OrderedDict()
 
     def __init__(self, defaults):
-        for key, value in defaults.items(): 
-            self.options[key] = value    
+        for key, value in defaults.items():
+            self.options[key] = value
 
-    def url(url): 
-        return TakeOptions({'url': url})
+    def url(url):
+        return TakeOptions({"url": url})
 
-    def html(html): 
-        return TakeOptions({'html': html})
+    def html(html):
+        return TakeOptions({"html": html})
 
-    def markdown(markdown): 
-        return TakeOptions({'markdown': markdown})
+    def markdown(markdown):
+        return TakeOptions({"markdown": markdown})
 
-    def signature(self, value): 	
-        self.options['signature'] = value
+    def signature(self, value):
+        self.options["signature"] = value
 
         return self
 
-    def selector(self, value): 	
-        self.options['selector'] = value
+    def selector(self, value):
+        self.options["selector"] = value
 
         return self
 
-    def error_on_selector_not_found(self, value): 	
-        self.options['error_on_selector_not_found'] = value
+    def error_on_selector_not_found(self, value):
+        self.options["error_on_selector_not_found"] = value
 
         return self
 
-    def response_type(self, value): 	
-        self.options['response_type'] = value
+    def response_type(self, value):
+        self.options["response_type"] = value
 
         return self
 
-    def openai_api_key(self, value): 	
-        self.options['openai_api_key'] = value
+    def openai_api_key(self, value):
+        self.options["openai_api_key"] = value
 
         return self
 
-    def vision_prompt(self, value): 	
-        self.options['vision_prompt'] = value
+    def vision_prompt(self, value):
+        self.options["vision_prompt"] = value
 
         return self
-    
-    def vision_max_tokens(self, value): 	
-        self.options['vision_max_tokens'] = value
+
+    def vision_max_tokens(self, value):
+        self.options["vision_max_tokens"] = value
 
         return self
 
-    def format(self, value): 	
-        self.options['format'] = value
+    def format(self, value):
+        self.options["format"] = value
 
         return self
 
-    def clip_x(self, value): 	
-        self.options['clip_x'] = value
+    def clip_x(self, value):
+        self.options["clip_x"] = value
 
         return self
 
-    def clip_y(self, value): 	
-        self.options['clip_y'] = value
+    def clip_y(self, value):
+        self.options["clip_y"] = value
 
         return self
 
-    def clip_width(self, value): 	
-        self.options['clip_width'] = value
+    def clip_width(self, value):
+        self.options["clip_width"] = value
 
         return self
 
-    def clip_height(self, value): 	
-        self.options['clip_height'] = value
+    def clip_height(self, value):
+        self.options["clip_height"] = value
 
         return self
 
-    def dark_mode(self, value): 	
-        self.options['dark_mode'] = value
+    def dark_mode(self, value):
+        self.options["dark_mode"] = value
 
         return self
 
-    def reduced_motion(self, value): 	
-        self.options['reduced_motion'] = value
+    def reduced_motion(self, value):
+        self.options["reduced_motion"] = value
 
         return self
 
-    def media_type(self, value): 	
-        self.options['media_type'] = value
+    def media_type(self, value):
+        self.options["media_type"] = value
 
         return self
 
-    def scripts(self, value): 	
-        self.options['scripts'] = value
+    def scripts(self, value):
+        self.options["scripts"] = value
 
         return self
 
-    def scripts_wait_until(self, value): 	
-        self.options['scripts_wait_until'] = value
+    def scripts_wait_until(self, value):
+        self.options["scripts_wait_until"] = value
 
         return self
 
-    def styles(self, value): 	
-        self.options['styles'] = value
+    def styles(self, value):
+        self.options["styles"] = value
 
         return self
 
-    def hide_selectors(self, values: List[str]): 	
-        self.options['hide_selectors'] = values
+    def hide_selectors(self, values: List[str]):
+        self.options["hide_selectors"] = values
 
         return self
 
-    def click(self, value): 	
-        self.options['click'] = value
+    def click(self, value):
+        self.options["click"] = value
 
         return self
 
-    def image_quality(self, value): 	
-        self.options['image_quality'] = value
+    def image_quality(self, value):
+        self.options["image_quality"] = value
 
         return self
 
-    def image_width(self, value): 	
-        self.options['image_width'] = value
+    def image_width(self, value):
+        self.options["image_width"] = value
 
         return self
 
-    def image_height(self, value): 	
-        self.options['image_height'] = value
+    def image_height(self, value):
+        self.options["image_height"] = value
 
         return self
 
-    def omit_background(self, value): 	
-        self.options['omit_background'] = value
+    def omit_background(self, value):
+        self.options["omit_background"] = value
 
         return self
 
-    def viewport_device(self, value): 	
-        self.options['viewport_device'] = value
+    def viewport_device(self, value):
+        self.options["viewport_device"] = value
 
         return self
 
-    def viewport_width(self, value): 	
-        self.options['viewport_width'] = value
+    def viewport_width(self, value):
+        self.options["viewport_width"] = value
 
         return self
 
-    def viewport_height(self, value): 	
-        self.options['viewport_height'] = value
+    def viewport_height(self, value):
+        self.options["viewport_height"] = value
 
         return self
 
-    def device_scale_factor(self, value): 	
-        self.options['device_scale_factor'] = value
+    def device_scale_factor(self, value):
+        self.options["device_scale_factor"] = value
 
         return self
 
-    def viewport_mobile(self, value): 	
-        self.options['viewport_mobile'] = value
+    def viewport_mobile(self, value):
+        self.options["viewport_mobile"] = value
 
         return self
 
-    def viewport_has_touch(self, value): 	
-        self.options['viewport_has_touch'] = value
+    def viewport_has_touch(self, value):
+        self.options["viewport_has_touch"] = value
 
         return self
 
-    def viewport_landscape(self, value): 	
-        self.options['viewport_landscape'] = value
+    def viewport_landscape(self, value):
+        self.options["viewport_landscape"] = value
 
         return self
 
-    def full_page(self, value): 	
-        self.options['full_page'] = value
+    def full_page(self, value):
+        self.options["full_page"] = value
 
         return self
 
-    def full_page_scroll(self, value): 	
-        self.options['full_page_scroll'] = value
+    def full_page_scroll(self, value):
+        self.options["full_page_scroll"] = value
 
         return self
-    
-    def fail_if_content_contains(self, value): 	
-        self.options['fail_if_content_contains'] = value
+
+    def fail_if_content_contains(self, value):
+        self.options["fail_if_content_contains"] = value
 
         return self
 
-    def full_page_scroll_delay(self, value): 	
-        self.options['full_page_scroll_delay'] = value
+    def full_page_scroll_delay(self, value):
+        self.options["full_page_scroll_delay"] = value
 
         return self
-    
-    def full_page_max_height(self, value): 	
-        self.options['full_page_max_height'] = value
+
+    def full_page_max_height(self, value):
+        self.options["full_page_max_height"] = value
 
         return self
 
-    def full_page_scroll_by(self, value): 	
-        self.options['full_page_scroll_by'] = value
+    def full_page_scroll_by(self, value):
+        self.options["full_page_scroll_by"] = value
 
         return self
 
-    def geolocation_latitude(self, value): 	
-        self.options['geolocation_latitude'] = value
+    def geolocation_latitude(self, value):
+        self.options["geolocation_latitude"] = value
 
         return self
 
-    def geolocation_longitude(self, value): 	
-        self.options['geolocation_longitude'] = value
+    def geolocation_longitude(self, value):
+        self.options["geolocation_longitude"] = value
 
         return self
 
-    def geolocation_accuracy(self, value): 	
-        self.options['geolocation_accuracy'] = value
+    def geolocation_accuracy(self, value):
+        self.options["geolocation_accuracy"] = value
 
         return self
 
-    def block_cookie_banners(self, value): 	
-        self.options['block_cookie_banners'] = value
+    def block_cookie_banners(self, value):
+        self.options["block_cookie_banners"] = value
 
         return self
 
-    def block_banners_by_heuristics(self, value): 	
-        self.options['block_banners_by_heuristics'] = value
+    def block_banners_by_heuristics(self, value):
+        self.options["block_banners_by_heuristics"] = value
 
         return self
 
-    def block_chats(self, value): 	
-        self.options['block_chats'] = value
+    def block_chats(self, value):
+        self.options["block_chats"] = value
 
         return self
 
-    def block_ads(self, value): 	
-        self.options['block_ads'] = value
+    def block_ads(self, value):
+        self.options["block_ads"] = value
 
         return self
 
-    def block_socials(self, value): 	
-        self.options['block_socials'] = value
+    def block_socials(self, value):
+        self.options["block_socials"] = value
 
         return self
 
-    def block_trackers(self, value): 	
-        self.options['block_trackers'] = value
+    def block_trackers(self, value):
+        self.options["block_trackers"] = value
 
         return self
 
-    def block_requests(self, values: List[str]): 	
-        self.options['block_requests'] = values
+    def block_requests(self, values: List[str]):
+        self.options["block_requests"] = values
 
         return self
 
-    def block_resources(self, values: List[str]): 	
-        self.options['block_resources'] = values
+    def block_resources(self, values: List[str]):
+        self.options["block_resources"] = values
 
         return self
 
-    def cache(self, value): 	
-        self.options['cache'] = value
+    def cache(self, value):
+        self.options["cache"] = value
 
         return self
 
-    def cache_ttl(self, value): 	
-        self.options['cache_ttl'] = value
+    def cache_ttl(self, value):
+        self.options["cache_ttl"] = value
 
         return self
 
-    def cache_key(self, value): 	
-        self.options['cache_key'] = value
+    def cache_key(self, value):
+        self.options["cache_key"] = value
 
         return self
 
-    def user_agent(self, value): 	
-        self.options['user_agent'] = value
+    def user_agent(self, value):
+        self.options["user_agent"] = value
 
         return self
 
-    def authorization(self, value): 	
-        self.options['authorization'] = value
+    def authorization(self, value):
+        self.options["authorization"] = value
 
         return self
 
-    def headers(self, values: List[str]): 	
-        self.options['headers'] = values
+    def headers(self, values: List[str]):
+        self.options["headers"] = values
 
         return self
 
-    def cookies(self, values: List[str]): 	
-        self.options['cookies'] = values
+    def cookies(self, values: List[str]):
+        self.options["cookies"] = values
 
         return self
 
-    def proxy(self, value): 	
-        self.options['proxy'] = value
+    def proxy(self, value):
+        self.options["proxy"] = value
 
         return self
 
-    def time_zone(self, value): 	
-        self.options['time_zone'] = value
+    def time_zone(self, value):
+        self.options["time_zone"] = value
 
         return self
 
-    def delay(self, value): 	
-        self.options['delay'] = value
+    def delay(self, value):
+        self.options["delay"] = value
 
         return self
 
-    def timeout(self, value): 	
-        self.options['timeout'] = value
+    def timeout(self, value):
+        self.options["timeout"] = value
 
         return self
 
-    def wait_until(self, values: List[str]): 	
-        self.options['wait_until'] = values
+    def wait_until(self, values: List[str]):
+        self.options["wait_until"] = values
 
         return self
 
-    def wait_for_selector(self, value): 	
-        self.options['wait_for_selector'] = value
+    def wait_for_selector(self, value):
+        self.options["wait_for_selector"] = value
 
         return self
 
-    def store(self, value): 	
-        self.options['store'] = value
+    def store(self, value):
+        self.options["store"] = value
 
         return self
 
-    def storage_bucket(self, value): 	
-        self.options['storage_bucket'] = value
+    def storage_bucket(self, value):
+        self.options["storage_bucket"] = value
 
         return self
 
-    def storage_path(self, value): 	
-        self.options['storage_path'] = value
+    def storage_path(self, value):
+        self.options["storage_path"] = value
 
         return self
 
-    def storage_class(self, value): 	
-        self.options['storage_class'] = value
+    def storage_class(self, value):
+        self.options["storage_class"] = value
 
         return self
 
-
-    def query(self): 
+    def query(self):
         return self.options
 
 
-class ScreenshotResultVision: 
+class ScreenshotResultVision:
     completion = None
 
-class ScreenshotResult: 
+
+class ScreenshotResult:
     screenshot = None
     vision = None
-    
-class Client: 
+
+
+class Client:
     access_key = None
     secret_key = None
 
@@ -375,33 +389,47 @@ def __init__(self, access_key: str, secret_key: str):
         self.access_key = access_key
         self.secret_key = secret_key
 
-    def with_keys(access_key: str, secret_key: str):         
+    def with_keys(access_key: str, secret_key: str):
         return Client(access_key, secret_key)
 
-    def generate_take_url(self, options: TakeOptions):                
+    def generate_take_url(self, options: TakeOptions):
         query = options.query()
-        query['access_key'] = self.access_key
+        query["access_key"] = self.access_key
+
+        query_string = urllib.parse.urlencode(query, doseq=True)
 
-        query_string = urllib.parse.urlencode(query, doseq=True) 
-    
-        signature = hmac.new(bytes(self.secret_key , 'utf-8'), msg = bytes(query_string , 'utf-8'), digestmod = hashlib.sha256).hexdigest()
+        signature = hmac.new(
+            bytes(self.secret_key, "utf-8"),
+            msg=bytes(query_string, "utf-8"),
+            digestmod=hashlib.sha256,
+        ).hexdigest()
 
-        return '%s%s?%s&signature=%s' % (API_BASE_URL, API_TAKE_PATH, query_string, signature) 
+        return "%s%s?%s&signature=%s" % (
+            API_BASE_URL,
+            API_TAKE_PATH,
+            query_string,
+            signature,
+        )
 
     def take(self, options):
         query = options.query()
-        query['access_key'] = self.access_key
+        query["access_key"] = self.access_key
 
-        url = '%s%s' % (API_BASE_URL, API_TAKE_PATH)
+        url = "%s%s" % (API_BASE_URL, API_TAKE_PATH)
         r = requests.post(url, json=query, stream=True)
-        
-        if r.status_code == 200: 
+
+        if r.status_code == 200:
             return r.raw
         elif r.status_code == 400:
             error_response = json.loads(r.text)
-            if not error_response.get('is_successful'):
-                error_messages = [detail['message'] for detail in error_response.get('error_details', [])]
-                error_message = f"Error: {error_response.get('error_message', 'Unknown error')}\n"
+            if not error_response.get("is_successful"):
+                error_messages = [
+                    detail["message"]
+                    for detail in error_response.get("error_details", [])
+                ]
+                error_message = (
+                    f"Error: {error_response.get('error_message', 'Unknown error')}\n"
+                )
                 error_message += "\n".join(error_messages)
 
                 raise InvalidRequestException(error_message)
@@ -411,32 +439,44 @@ def take(self, options):
             raise APIErrorException(error_message)
 
         return None
-    
+
     def take_with_metadata(self, options):
         query = options.query()
-        query['access_key'] = self.access_key
+        query["access_key"] = self.access_key
 
-        url = '%s%s' % (API_BASE_URL, API_TAKE_PATH)
+        url = "%s%s" % (API_BASE_URL, API_TAKE_PATH)
         r = requests.post(url, json=query, stream=True)
 
         vision = None
-        completion = r.headers.get('x-screenshotone-vision-completion')
+        completion = r.headers.get("x-screenshotone-vision-completion")
         if completion is not None:
             vision = ScreenshotResultVision(completion=completion)
 
-        if r.status_code == 200: 
+        if r.status_code == 200:
             return ScreenshotResult(screenshot=r.raw, vision=vision)
         elif r.status_code == 400:
             error_response = json.loads(r.text)
-            if not error_response.get('is_successful'):
-                error_messages = [detail['message'] for detail in error_response.get('error_details', [])]
-                error_message = f"Error: {error_response.get('error_message', 'Unknown error')}\n"
+            if not error_response.get("is_successful"):
+                error_messages = [
+                    detail["message"]
+                    for detail in error_response.get("error_details", [])
+                ]
+                error_message = (
+                    f"Error: {error_response.get('error_message', 'Unknown error')}\n"
+                )
                 error_message += "\n".join(error_messages)
+                error_code = error_response.get("error_code")
 
-                raise InvalidRequestException(error_message)
+                raise InvalidRequestException(
+                    error_message, http_status_code=r.status_code, error_code=error_code
+                )
         else:
-            # Handle other error status codes with a generic message
-            error_message = f"An error occurred while processing the request. Status code: {r.status_code}"
-            raise APIErrorException(error_message)
+            error_response = json.loads(r.text)
+            error_code = error_response.get("error_code")
+            error_message = f"An error occurred while processing the request. Status code: {r.status_code}, error code: {error_code}"
+
+            raise APIErrorException(
+                error_message, http_status_code=r.status_code, error_code=error_code
+            )
 
-        return None
\ No newline at end of file
+        return None