Skip to content

Commit

Permalink
Oauth2 (#9)
Browse files Browse the repository at this point in the history
* Oauth2

* update readme

* endpoint fixes + me endpoint

* add doc string
  • Loading branch information
leon1995 authored Dec 23, 2023
1 parent 99c5e63 commit fe7bfc8
Show file tree
Hide file tree
Showing 5 changed files with 199 additions and 7 deletions.
5 changes: 5 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,11 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0)

## Unreleased

### Added

- oauth2 support
- me endpoint

### Changed

- Change repository structure to a src/package style
Expand Down
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ async with endpoints.NetworkHandler('<api_key>') as api:
## TODO

- [ ] tests
- [ ] oauth2 support
- [x] oauth2 support
- [ ] [Family situation endpoint](https://apidoc.factorialhr.com/reference/get_v1-payroll-family-situation)
- [ ] [Contract versions endpoint](https://apidoc.factorialhr.com/reference/get_v1-payroll-contract-versions)
- [ ] [Supplements endpoint](https://apidoc.factorialhr.com/reference/get_v1-payroll-supplements)
Expand Down
165 changes: 165 additions & 0 deletions src/factorialhr/auth.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,165 @@
"""Authorization for api calls."""

import queue
import threading
import time
import typing
import urllib.parse
import uuid
import webbrowser
from http.server import BaseHTTPRequestHandler, HTTPServer

import httpx

HTTP_UNAUTHORIZED = 401
HTTP_SERVER = "localhost"
HTTP_SERVER_PORT = 50101


def _get_authorization_code(target_url: str,
client_id: str,
redirect_uri: str,
scope: str,
timeout: int) -> str | None:
state_secret = str(uuid.uuid4())
code_queue = queue.Queue(1)

class GetRequestHandler(BaseHTTPRequestHandler):

def do_GET(self): # noqa: N802
path = urllib.parse.urlparse(self.path)
if path.path != '/authorize/authorization_code':
return
queries = urllib.parse.parse_qs(path.query)
states = queries.get('state', [])
if len(states) != 1 or states[0] != state_secret:
self.send_response(400)
self.send_header("Content-type", "text/html")
self.end_headers()
self.wfile.write(
bytes('<html><b>ERROR: invalid state parameter. Repeat login process</b></html>', 'utf-8'))
code_queue.put(None)
return
codes = queries.get('code', [])
if len(codes) != 1 or not (code := codes[0]):
self.send_response(400)
self.send_header("Content-type", "text/html")
self.end_headers()
self.wfile.write(bytes('<html><b>ERROR: authorization code is missing. '
'Repeat login process</b></html>', 'utf-8'))
code_queue.put(None)
return
self.send_response(200)
self.send_header("Content-type", "text/html")
self.end_headers()
self.wfile.write(bytes('<html><b>SUCCESS. You can close this window now</b></html>', 'utf-8'))
code_queue.put(code)

web_server = HTTPServer((HTTP_SERVER, HTTP_SERVER_PORT), GetRequestHandler)
t = threading.Thread(target=web_server.serve_forever, daemon=True)
t.start()
webbrowser.open(f'{target_url}/oauth/authorize?'
f'client_id={client_id}&'
f'redirect_uri={urllib.parse.quote(redirect_uri)}&'
'response_type=code&'
f'{scope}&'
f'state={state_secret}')
try:
return code_queue.get(timeout=timeout)
except queue.Empty: # code has not been delivered in time
return None
finally:
web_server.shutdown()
web_server.server_close()
t.join()


class ApiKeyAuth(httpx.Auth):
"""Authorization using an api key."""

def __init__(self, api_key: str):
self.api_key = api_key

def auth_flow(self, request: httpx.Request):
"""Implement the authentication flow."""
request.headers['x-api-key'] = self.api_key
yield request


class OAuth2Auth(httpx.Auth):
"""Authorization using oauth2 flow."""

requires_response_body = True

def __init__(self, client_id: str, client_secret: str, redirect_uri: str, scope: str):
self.client_id = client_id
self.client_secret = client_secret
self.redirect_uri = redirect_uri
self.scope = scope

self.authorization_code: str | None = None
self.access_token: str | None = None
self.token_type: str | None = None
self.expires_in: int | None = None
self.refresh_token: str | None = None
self.created_at: int | None = None

def auth_flow(self, request: httpx.Request) -> typing.Generator[httpx.Request, httpx.Response, None]:
"""Implement the authentication flow."""
if not self.authorization_code:
self.authorization_code = _get_authorization_code(
target_url=f'{request.url.scheme}://{request.url.host}',
client_id=self.client_id,
redirect_uri=self.redirect_uri,
scope=self.scope,
timeout=60)
if not self.authorization_code:
raise httpx.RequestError('Authorization code could not be obtained')
self.access_token = None # requires to be regenerated
if not self.access_token:
response = yield self.build_access_token_request(f'{request.url.scheme}://{request.url.host}')
self.update_access_token(response)

if self.created_at + self.expires_in <= time.time():
response = yield self.build_refresh_request(f'{request.url.scheme}://{request.url.host}')
self.update_access_token(response)

request.headers['Authorization'] = f'{self.token_type} {self.access_token}'
yield request

def build_access_token_request(self, target_url: str) -> httpx.Request:
"""Build a request to obtain a new access token."""
return httpx.Request(
"POST",
f"{target_url}/oauth/token",
data={
"client_id": self.client_id,
"client_secret": self.client_secret,
'code': self.authorization_code,
"grant_type": "authorization_code",
'redirect_uri': self.redirect_uri,
},
)

def build_refresh_request(self, target_url: str) -> httpx.Request:
"""Build a request to obtain a new access token."""
return httpx.Request(
"POST",
f"{target_url}/oauth/token",
data={
"client_id": self.client_id,
"client_secret": self.client_secret,
"grant_type": "refresh_token",
"refresh_token": self.refresh_token,
},
)

def update_access_token(self, response: httpx.Response):
"""Update the member variables."""
response.raise_for_status()
response_data = dict(response.json())
self.access_token = response_data.get('access_token')
self.token_type = response_data.get('token_type')
self.expires_in = response_data.get('expires_in')
self.refresh_token = response_data.get('refresh_token')
self.created_at = response_data.get('created_at')
17 changes: 15 additions & 2 deletions src/factorialhr/endpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,19 @@ async def delete(self, *, webhook_id: int, **kwargs) -> models.Webhook:
return models.Webhook(**await self.api.delete(f"{self._endpoint}/{webhook_id}", **kwargs))


class MeEndpoint:
def __init__(self, api: NetworkHandler):
self.api = api

@property
def _endpoint(self) -> str:
return "v1/me"

async def get(self, **kwargs) -> models.Me:
"""Implement https://apidoc.factorialhr.com/reference/get_v1-me."""
return models.Me(**await self.api.get(self._endpoint, **kwargs))


class LocationsEndpoint:
def __init__(self, api: NetworkHandler):
self.api = api
Expand Down Expand Up @@ -257,11 +270,11 @@ def _endpoint(self) -> str:

async def all(self, **kwargs) -> list[models.LegalEntity]:
"""Implement https://apidoc.factorialhr.com/reference/get_v1-core-legal-entities."""
return [models.LegalEntity(**le) for le in await self.api.put(self._endpoint, **kwargs)]
return [models.LegalEntity(**le) for le in await self.api.get(self._endpoint, **kwargs)]

async def get(self, *, entity_id: int, **kwargs) -> models.LegalEntity:
"""Implement https://apidoc.factorialhr.com/reference/get_v1-core-legal-entities-id."""
return models.LegalEntity(**await self.api.put(f"{self._endpoint}/{entity_id}", **kwargs))
return models.LegalEntity(**await self.api.get(f"{self._endpoint}/{entity_id}", **kwargs))


class KeysEndpoint:
Expand Down
17 changes: 13 additions & 4 deletions src/factorialhr/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,18 +56,27 @@ class Webhook(pydantic.BaseModel):
company_id: int | None


class Me(pydantic.BaseModel):
email: str
full_name: str
first_name: str
last_name: str
employee_id: int
role: str


class Location(pydantic.BaseModel):
id: int
name: str | None # TODO: check which ones are required
country: str | None
name: str
country: str
phone_number: str | None
state: str | None
city: str | None
address_line_1: str | None
address_line_2: str | None
postal_code: str | None
timezone: str | None
company_holidays_ids: list[int]
timezone: str
company_holiday_ids: list[int]


class CompanyHoliday(pydantic.BaseModel):
Expand Down

0 comments on commit fe7bfc8

Please sign in to comment.