Skip to content

Commit

Permalink
refactor
Browse files Browse the repository at this point in the history
  • Loading branch information
fab-dlock committed Jul 23, 2023
1 parent 0b09bc5 commit acd8989
Show file tree
Hide file tree
Showing 3 changed files with 154 additions and 95 deletions.
235 changes: 141 additions & 94 deletions distributed_lock/sync.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

import datetime
import functools
import logging
import os
import time
Expand All @@ -18,7 +19,6 @@
NotAcquiredError,
NotAcquiredException,
NotReleasedError,
NotReleasedException,
)

logger = logging.getLogger("distributed-lock.sync")
Expand Down Expand Up @@ -49,20 +49,77 @@ def make_httpx_client() -> httpx.Client:
return httpx.Client(timeout=timeout)


def with_retry(service_wait: bool = False):
    """Build a decorator that retries a lock operation until ``wait`` is spent.

    The decorated method is called again after each failure, sleeping
    ``sleep_after_failure`` seconds between attempts. Once the elapsed time
    exceeds ``wait - sleep_after_failure`` the last caught exception is
    re-raised.

    Retry settings are read from the keyword arguments of each call
    (``wait``, ``automatic_retry``, ``sleep_after_failure``). When a setting
    is not passed, the default declared in the decorated method's own
    signature is used, so a method declaring ``wait: int = 30`` really
    waits 30 seconds.

    * ``DistributedLockError`` is retried only when ``automatic_retry`` is
      true; otherwise it propagates immediately.
    * ``DistributedLockException`` is always retried.

    Args:
        service_wait: when true, and the remaining time budget is smaller
            than ``self.service_wait``, the next call receives a
            ``_forced_service_wait`` keyword argument holding the remaining
            budget truncated to whole seconds (never less than 1).
    """
    import inspect

    def decorator(func):
        try:
            declared = inspect.signature(func).parameters
        except (TypeError, ValueError):
            # No introspectable signature: fall back to the generic defaults.
            declared = {}

        def declared_default(name, fallback):
            param = declared.get(name)
            if param is None or param.default is inspect.Parameter.empty:
                return fallback
            return param.default

        @functools.wraps(func)
        def wrapper(self, *args, **kwargs):
            wait = kwargs.get("wait", declared_default("wait", DEFAULT_WAIT))
            automatic_retry = kwargs.get(
                "automatic_retry", declared_default("automatic_retry", True)
            )
            sleep_after_failure = kwargs.get(
                "sleep_after_failure", declared_default("sleep_after_failure", 1.0)
            )
            forced_service_wait: float | None = None
            # A monotonic clock keeps the budget immune to wall-clock jumps.
            started = time.monotonic()
            while True:
                caught: Exception | None = None
                try:
                    if forced_service_wait is not None:
                        kwargs["_forced_service_wait"] = forced_service_wait
                    return func(self, *args, **kwargs)
                except DistributedLockError as e:
                    if not automatic_retry:
                        raise
                    caught = e
                except DistributedLockException as e:
                    caught = e
                elapsed = time.monotonic() - started
                if elapsed > wait - sleep_after_failure:
                    # Not enough budget left for another sleep + attempt.
                    raise caught
                logger.debug("wait %ss...", sleep_after_failure)
                time.sleep(sleep_after_failure)
                if (
                    service_wait
                    and elapsed + sleep_after_failure + self.service_wait > wait
                ):
                    # Shrink the server-side wait so the next attempt cannot
                    # overshoot the caller's overall budget.
                    forced_service_wait = max(
                        int(wait - elapsed - sleep_after_failure), 1
                    )

        return wrapper

    return decorator


@dataclass
class AcquiredRessource:
    """Description of a lock held on a resource, as reported by the service.

    ``created`` and ``expires`` default to the current naive UTC time.
    """

    resource: str
    lock_id: str
    tenant_id: str
    created: datetime.datetime = field(default_factory=datetime.datetime.utcnow)
    expires: datetime.datetime = field(default_factory=datetime.datetime.utcnow)
    user_agent: str = ""
    user_data: Any = ""

    @classmethod
    def from_dict(cls, d: dict) -> AcquiredRessource:
        """Build an instance from a service reply.

        All seven fields must be present; keys other than those are ignored
        so that an extended reply does not break parsing. ``created`` and
        ``expires`` given as strings are parsed as ISO 8601; a trailing
        ``Z`` (as produced by :meth:`to_dict`) is dropped first, which keeps
        the result a naive datetime and works on interpreters whose
        ``fromisoformat`` does not accept ``Z``.

        Raises:
            DistributedLockError: if a required field is missing.
        """
        required = (
            "lock_id",
            "resource",
            "tenant_id",
            "created",
            "expires",
            "user_agent",
            "user_data",
        )
        for f in required:
            if f not in d:
                raise DistributedLockError(f"bad reply from service, missing {f}")
        values = {name: d[name] for name in required}
        for name in ("created", "expires"):
            raw = values[name]
            if isinstance(raw, str):
                if raw.endswith("Z"):
                    raw = raw[:-1]
                values[name] = datetime.datetime.fromisoformat(raw)
        return cls(**values)

    def to_dict(self) -> dict:
        """Return the fields as a dict, with timestamps as ``YYYY-MM-DDTHH:MM:SSZ``."""
        d = asdict(self)
        for f in ("created", "expires"):
            # Keep second precision only and mark the value as UTC.
            d[f] = d[f].isoformat()[0:19] + "Z"
        return d


@dataclass
Expand All @@ -83,130 +140,120 @@ def get_headers(self) -> dict[str, str]:
def __del__(self):
self.client.close()

def _request(
    self,
    method: str,
    url: str,
    headers: dict,
    body: dict | None,
    error_class,
    exception_class,
):
    """Send one HTTP request and translate failures into lock exceptions.

    Args:
        method: HTTP verb passed to ``self.client.request``.
        url: target URL.
        headers: request headers.
        body: JSON body, or ``None`` for no body.
        error_class: exception type raised for transport failures and for
            403, 429 and any other non-2xx status.
        exception_class: exception type raised for an HTTP 409 reply.

    Returns:
        The response object, only when the status code is 2xx.
    """
    try:
        r = self.client.request(method, url, json=body, headers=headers)
    except httpx.ConnectTimeout as e:
        raise error_class("timeout during connect") from e
    except httpx.ReadTimeout as e:
        raise error_class("timeout during read") from e
    except httpx.WriteTimeout as e:
        raise error_class("timeout during write") from e
    except httpx.PoolTimeout as e:
        raise error_class("timeout in connection pool") from e
    except httpx.HTTPError as e:
        raise error_class("generic http error") from e

    status = r.status_code
    if status == 409:
        raise exception_class("got a HTTP/409 Conflict")

    if status in (403, 429):
        # Extract the optional detail first, outside of any raise, so the
        # detailed error cannot be swallowed by the fallback handler.
        try:
            message = r.json()["message"]
        except (ValueError, LookupError, TypeError):
            message = None
        if status == 403:
            if message is None:
                raise error_class("got an HTTP/403 Forbidden with no detail")
            raise error_class(
                f"got a HTTP/403 Forbidden error with message: {message}"
            )
        # 429: a rate-limited reply is a failure, never a usable response.
        if message is None:
            raise error_class("got an HTTP/429 Rate limited error with no detail")
        logger.warning(
            f"got a HTTP/429 Rate limited error with message: {message}"
        )
        raise error_class(
            f"got a HTTP/429 Rate limited error with message: {message}"
        )

    if status < 200 or status >= 300:
        raise error_class(f"unexpected status code: {status}")
    return r

def _acquire(
    self,
    resource: str,
    lifetime: int = DEFAULT_LIFETIME,
    user_data: str | None = None,
    forced_service_wait: float | None = None,
) -> AcquiredRessource:
    """Make a single attempt to lock ``resource`` (no client-side retry).

    Args:
        resource: name of the resource to lock.
        lifetime: sent to the service as ``lifetime``.
        user_data: sent as ``user_data`` when truthy.
        forced_service_wait: when not ``None``, sent as ``wait`` instead of
            ``self.service_wait``.

    Returns:
        The lock description parsed from the service reply.

    Raises:
        NotAcquiredError: transport failure, unexpected status code, or a
            reply body that is not valid JSON.
        NotAcquiredException: the service answered HTTP 409.
    """
    if forced_service_wait is not None:
        service_wait = forced_service_wait
    else:
        service_wait = self.service_wait
    body: dict[str, Any] = {"wait": service_wait, "lifetime": lifetime}
    if self.user_agent:
        body["user_agent"] = self.user_agent
    if user_data:
        body["user_data"] = user_data
    url = self.get_resource_url(resource)
    logger.debug(f"Try to lock {resource} with url: {url}...")
    # Exactly one POST: the shared helper does the request and error mapping.
    r = self._request(
        "POST",
        url,
        headers=self.get_headers(),
        body=body,
        error_class=NotAcquiredError,
        exception_class=NotAcquiredException,
    )
    try:
        d = r.json()
    except ValueError as e:
        raise NotAcquiredError("bad reply from service, invalid JSON") from e
    # Validate the reply before announcing success.
    acquired = AcquiredRessource.from_dict(d)
    logger.info(f"Lock on {resource} acquired")
    return acquired

@with_retry(service_wait=True)
def acquire_exclusive_lock(
    self,
    resource: str,
    *,
    lifetime: int = DEFAULT_LIFETIME,
    wait: int = DEFAULT_WAIT,
    user_data: str | None = None,
    automatic_retry: bool = True,
    sleep_after_failure: float = 1.0,
    _forced_service_wait: float | None = None,
) -> AcquiredRessource:
    """Acquire an exclusive lock on ``resource``.

    Retrying is handled entirely by the ``with_retry`` decorator, which
    reads ``wait``, ``automatic_retry`` and ``sleep_after_failure`` from the
    keyword arguments of the call; this body performs one attempt. It never
    modifies ``self.service_wait``.

    Args:
        resource: name of the resource to lock.
        lifetime: forwarded to ``_acquire``.
        wait: retry setting consumed by the decorator.
        user_data: forwarded to ``_acquire``.
        automatic_retry: retry setting consumed by the decorator.
        sleep_after_failure: retry setting consumed by the decorator.
        _forced_service_wait: set by the decorator, not by callers;
            forwarded to ``_acquire`` as ``forced_service_wait``.

    Returns:
        The acquired lock description returned by ``_acquire``.
    """
    return self._acquire(
        resource=resource,
        lifetime=lifetime,
        user_data=user_data,
        forced_service_wait=_forced_service_wait,
    )

def _release(self, resource: str, lock_id: str):
    """Make a single attempt to release lock ``lock_id`` on ``resource``.

    Sends one DELETE to ``<resource url>/<lock_id>`` through ``_request``.
    Returns ``None`` when ``_request`` returns normally.

    Raises:
        NotReleasedError: whenever ``_request`` reports a failure; it is
            passed as both ``error_class`` and ``exception_class``, so an
            HTTP 409 reply raises it as well.
    """
    url = self.get_resource_url(resource) + "/" + lock_id
    logger.debug(f"Try to unlock {resource} with url: {url}...")
    # Exactly one DELETE: the shared helper does the request and error mapping.
    self._request(
        "DELETE",
        url,
        headers=self.get_headers(),
        body=None,
        error_class=NotReleasedError,
        exception_class=NotReleasedError,
    )

@with_retry()
def release_exclusive_lock(
    self,
    resource: str,
    lock_id: str,
    *,
    wait: int = 30,
    automatic_retry: bool = True,
    sleep_after_failure: float = 1.0,
):
    """Release the exclusive lock ``lock_id`` held on ``resource``.

    Retrying is handled entirely by the ``with_retry`` decorator, which
    reads ``wait``, ``automatic_retry`` and ``sleep_after_failure`` from the
    keyword arguments of the call; this body performs one attempt.

    Args:
        resource: name of the locked resource.
        lock_id: identifier of the lock to release.
        wait: retry setting consumed by the decorator.
        automatic_retry: retry setting consumed by the decorator.
        sleep_after_failure: retry setting consumed by the decorator.
    """
    return self._release(resource=resource, lock_id=lock_id)

@contextmanager
def exclusive_lock(
Expand All @@ -216,7 +263,7 @@ def exclusive_lock(
wait: int = DEFAULT_WAIT,
user_data: str | None = None,
automatic_retry: bool = True,
sleep_after_unsuccessful: float = 1.0,
sleep_after_failure: float = 1.0,
):
ar: AcquiredRessource | None = None
try:
Expand All @@ -226,7 +273,7 @@ def exclusive_lock(
wait=wait,
user_data=user_data,
automatic_retry=automatic_retry,
sleep_after_unsuccessful=sleep_after_unsuccessful,
sleep_after_failure=sleep_after_failure,
)
yield
finally:
Expand All @@ -236,5 +283,5 @@ def exclusive_lock(
lock_id=ar.lock_id,
wait=wait,
automatic_retry=automatic_retry,
sleep_after_unsuccessful=sleep_after_unsuccessful,
sleep_after_failure=sleep_after_failure,
)
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ ignore = [
"PLR2004",
"PLR0913",
"PLW0603",
"PLR0912",
"N805",
"N818"
]
Expand Down
13 changes: 12 additions & 1 deletion tests/test_sync.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,13 @@
"DLOCK_TENANT_ID": "tenant_id",
"DLOCK_TOKEN": "token",
}
AR = AcquiredRessource(lock_id="1234", resource="bar")
AR = AcquiredRessource(
lock_id="1234",
resource="bar",
tenant_id="tenant_id",
user_agent="foo/1.0",
user_data={"foo": "bar"},
)


@mock.patch.dict(os.environ, MOCKED_ENVIRON, clear=True)
Expand Down Expand Up @@ -52,6 +58,11 @@ def test_acquire(respx_mock):
assert ar is not None
assert ar.lock_id == AR.lock_id
assert ar.resource == AR.resource
assert ar.tenant_id == AR.tenant_id
assert ar.created.isoformat()[0:19] == AR.created.isoformat()[0:19]
assert ar.expires.isoformat()[0:19] == AR.expires.isoformat()[0:19]
assert ar.user_agent == AR.user_agent
assert ar.user_data == AR.user_data


@mock.patch.dict(os.environ, MOCKED_ENVIRON, clear=True)
Expand Down

0 comments on commit acd8989

Please sign in to comment.