diff --git a/CHANGELOG.md b/CHANGELOG.md index 858449d..c95a0ac 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,6 +5,11 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0) ## Unreleased +### Added + +- oauth2 support +- me endpoint + ### Changed - Change repository structure to a src/package style diff --git a/README.md b/README.md index 82dd9e6..e5dcf7c 100644 --- a/README.md +++ b/README.md @@ -37,7 +37,7 @@ async with endpoints.NetworkHandler('') as api: ## TODO - [ ] tests -- [ ] oauth2 support +- [x] oauth2 support - [ ] [Family situation endpoint](https://apidoc.factorialhr.com/reference/get_v1-payroll-family-situation) - [ ] [Contract versions endpoint](https://apidoc.factorialhr.com/reference/get_v1-payroll-contract-versions) - [ ] [Supplements endpoint](https://apidoc.factorialhr.com/reference/get_v1-payroll-supplements) diff --git a/src/factorialhr/auth.py b/src/factorialhr/auth.py new file mode 100644 index 0000000..d244061 --- /dev/null +++ b/src/factorialhr/auth.py @@ -0,0 +1,165 @@ +"""Authorization for api calls.""" + +import queue +import threading +import time +import typing +import urllib.parse +import uuid +import webbrowser +from http.server import BaseHTTPRequestHandler, HTTPServer + +import httpx + +HTTP_UNAUTHORIZED = 401 +HTTP_SERVER = "localhost" +HTTP_SERVER_PORT = 50101 + + +def _get_authorization_code(target_url: str, + client_id: str, + redirect_uri: str, + scope: str, + timeout: int) -> str | None: + state_secret = str(uuid.uuid4()) + code_queue = queue.Queue(1) + + class GetRequestHandler(BaseHTTPRequestHandler): + + def do_GET(self): # noqa: N802 + path = urllib.parse.urlparse(self.path) + if path.path != '/authorize/authorization_code': + return + queries = urllib.parse.parse_qs(path.query) + states = queries.get('state', []) + if len(states) != 1 or states[0] != state_secret: + self.send_response(400) + self.send_header("Content-type", "text/html") + self.end_headers() + self.wfile.write( + bytes('ERROR: invalid state parameter. Repeat login process', 'utf-8')) + code_queue.put(None) + return + codes = queries.get('code', []) + if len(codes) != 1 or not (code := codes[0]): + self.send_response(400) + self.send_header("Content-type", "text/html") + self.end_headers() + self.wfile.write(bytes('ERROR: authorization code is missing. ' + 'Repeat login process', 'utf-8')) + code_queue.put(None) + return + self.send_response(200) + self.send_header("Content-type", "text/html") + self.end_headers() + self.wfile.write(bytes('SUCCESS. You can close this window now', 'utf-8')) + code_queue.put(code) + + web_server = HTTPServer((HTTP_SERVER, HTTP_SERVER_PORT), GetRequestHandler) + t = threading.Thread(target=web_server.serve_forever, daemon=True) + t.start() + webbrowser.open(f'{target_url}/oauth/authorize?' + f'client_id={client_id}&' + f'redirect_uri={urllib.parse.quote(redirect_uri)}&' + 'response_type=code&' + f'{scope}&' + f'state={state_secret}') + try: + return code_queue.get(timeout=timeout) + except queue.Empty: # code has not been delivered in time + return None + finally: + web_server.shutdown() + web_server.server_close() + t.join() + + +class ApiKeyAuth(httpx.Auth): + """Authorization using an api key.""" + + def __init__(self, api_key: str): + self.api_key = api_key + + def auth_flow(self, request: httpx.Request): + """Implement the authentication flow.""" + request.headers['x-api-key'] = self.api_key + yield request + + +class OAuth2Auth(httpx.Auth): + """Authorization using oauth2 flow.""" + + requires_response_body = True + + def __init__(self, client_id: str, client_secret: str, redirect_uri: str, scope: str): + self.client_id = client_id + self.client_secret = client_secret + self.redirect_uri = redirect_uri + self.scope = scope + + self.authorization_code: str | None = None + self.access_token: str | None = None + self.token_type: str | None = None + self.expires_in: int | None = None + self.refresh_token: str | None = None + self.created_at: int | None = None + + def auth_flow(self, request: httpx.Request) -> typing.Generator[httpx.Request, httpx.Response, None]: + """Implement the authentication flow.""" + if not self.authorization_code: + self.authorization_code = _get_authorization_code( + target_url=f'{request.url.scheme}://{request.url.host}', + client_id=self.client_id, + redirect_uri=self.redirect_uri, + scope=self.scope, + timeout=60) + if not self.authorization_code: + raise httpx.RequestError('Authorization code could not be obtained') + self.access_token = None # requires to be regenerated + if not self.access_token: + response = yield self.build_access_token_request(f'{request.url.scheme}://{request.url.host}') + self.update_access_token(response) + + if self.created_at + self.expires_in <= time.time(): + response = yield self.build_refresh_request(f'{request.url.scheme}://{request.url.host}') + self.update_access_token(response) + + request.headers['Authorization'] = f'{self.token_type} {self.access_token}' + yield request + + def build_access_token_request(self, target_url: str) -> httpx.Request: + """Build a request to obtain a new access token.""" + return httpx.Request( + "POST", + f"{target_url}/oauth/token", + data={ + "client_id": self.client_id, + "client_secret": self.client_secret, + 'code': self.authorization_code, + "grant_type": "authorization_code", + 'redirect_uri': self.redirect_uri, + }, + ) + + def build_refresh_request(self, target_url: str) -> httpx.Request: + """Build a request to obtain a new access token.""" + return httpx.Request( + "POST", + f"{target_url}/oauth/token", + data={ + "client_id": self.client_id, + "client_secret": self.client_secret, + "grant_type": "refresh_token", + "refresh_token": self.refresh_token, + }, + ) + + def update_access_token(self, response: httpx.Response): + """Update the member variables.""" + response.raise_for_status() + response_data = dict(response.json()) + self.access_token = response_data.get('access_token') + self.token_type = response_data.get('token_type') + self.expires_in = response_data.get('expires_in') + self.refresh_token = response_data.get('refresh_token') + self.created_at = response_data.get('created_at') diff --git a/src/factorialhr/endpoints.py b/src/factorialhr/endpoints.py index 598843c..375cab5 100644 --- a/src/factorialhr/endpoints.py +++ b/src/factorialhr/endpoints.py @@ -113,6 +113,19 @@ async def delete(self, *, webhook_id: int, **kwargs) -> models.Webhook: return models.Webhook(**await self.api.delete(f"{self._endpoint}/{webhook_id}", **kwargs)) +class MeEndpoint: + def __init__(self, api: NetworkHandler): + self.api = api + + @property + def _endpoint(self) -> str: + return "v1/me" + + async def get(self, **kwargs) -> models.Me: + """Implement https://apidoc.factorialhr.com/reference/get_v1-me.""" + return models.Me(**await self.api.get(self._endpoint, **kwargs)) + + class LocationsEndpoint: def __init__(self, api: NetworkHandler): self.api = api @@ -257,11 +270,11 @@ def _endpoint(self) -> str: async def all(self, **kwargs) -> list[models.LegalEntity]: """Implement https://apidoc.factorialhr.com/reference/get_v1-core-legal-entities.""" - return [models.LegalEntity(**le) for le in await self.api.put(self._endpoint, **kwargs)] + return [models.LegalEntity(**le) for le in await self.api.get(self._endpoint, **kwargs)] async def get(self, *, entity_id: int, **kwargs) -> models.LegalEntity: """Implement https://apidoc.factorialhr.com/reference/get_v1-core-legal-entities-id.""" - return models.LegalEntity(**await self.api.put(f"{self._endpoint}/{entity_id}", **kwargs)) + return models.LegalEntity(**await self.api.get(f"{self._endpoint}/{entity_id}", **kwargs)) class KeysEndpoint: diff --git a/src/factorialhr/models.py b/src/factorialhr/models.py index c7b5ccf..479c9b6 100644 --- a/src/factorialhr/models.py +++ b/src/factorialhr/models.py @@ -56,18 +56,27 @@ class Webhook(pydantic.BaseModel): company_id: int | None +class Me(pydantic.BaseModel): + email: str + full_name: str + first_name: str + last_name: str + employee_id: int + role: str + + class Location(pydantic.BaseModel): id: int - name: str | None # TODO: check which ones are required - country: str | None + name: str + country: str phone_number: str | None state: str | None city: str | None address_line_1: str | None address_line_2: str | None postal_code: str | None - timezone: str | None - company_holidays_ids: list[int] + timezone: str + company_holiday_ids: list[int] class CompanyHoliday(pydantic.BaseModel):