diff --git a/README.rst b/README.rst index dcd0a8b..102bca8 100644 --- a/README.rst +++ b/README.rst @@ -28,13 +28,14 @@ If you wish to download the source and install from GitHub: .. code-block:: bash git clone git@github.com:mailgun/expiringdict.git + cd expiringdict python setup.py install or to install with test dependencies (`Nose `_, `Mock `_, `coverage `_) run from the directory above: .. code-block:: bash - pip install -e expiringdict[test] + pip install -e .[test] To run tests with coverage: diff --git a/expiringdict/__init__.py b/expiringdict/__init__.py index 56ba71d..220f0eb 100755 --- a/expiringdict/__init__.py +++ b/expiringdict/__init__.py @@ -116,7 +116,7 @@ def ttl(self, key): Returns None for non-existent or expired keys. """ key_value, key_age = self.get(key, with_age=True) # type: Any, Union[None, float] - if key_age: + if key_age is not None: key_ttl = self.max_age - key_age if key_ttl > 0: return key_ttl @@ -244,3 +244,20 @@ def __copy_dict(self, items): def __copy_reduced_result(self, items): [self.__setitem__(key, value, set_time) for key, (value, set_time) in items[1]] + + +def memoize(max_len, max_age_seconds): + cache = ExpiringDict(max_len, max_age_seconds) + + def wrap(fn): + def wrapped_fn(*args, **kwargs): + key = (args, frozenset(kwargs.items())) + result = cache.get(key) + if result is None: + result = fn(*args) + cache[key] = result + return result + + return wrapped_fn + + return wrap diff --git a/setup.py b/setup.py index 17d166e..763df4a 100644 --- a/setup.py +++ b/setup.py @@ -40,5 +40,6 @@ "typing", ], extras_require={ + "test": tests_require, "tests": tests_require, }) diff --git a/tests/expiringdict_test.py b/tests/expiringdict_test.py old mode 100644 new mode 100755 index 78929f1..dcd6092 --- a/tests/expiringdict_test.py +++ b/tests/expiringdict_test.py @@ -3,7 +3,7 @@ from mock import Mock, patch from nose.tools import assert_raises, eq_, ok_ -from expiringdict import ExpiringDict +from expiringdict import ExpiringDict, memoize def test_create(): @@ -94,14 +94,15 @@ def test_ttl(): d['a'] = 'x' # existent non-expired key - ok_(0 < d.ttl('a') < 10) + # TTL can be 10 if the machine is very fast. + ok_(0 < d.ttl('a') <= 10) # non-existent key eq_(None, d.ttl('b')) # expired key with patch.object(ExpiringDict, '__getitem__', - Mock(return_value=('x', 10**9))): + Mock(return_value=('x', 10 ** 9))): eq_(None, d.ttl('a')) @@ -122,3 +123,48 @@ def test_not_implemented(): assert_raises(NotImplementedError, d.viewitems) assert_raises(NotImplementedError, d.viewkeys) assert_raises(NotImplementedError, d.viewvalues) + + +def test_memoize(): + @memoize(max_len=10, max_age_seconds=10) + def noargs(): + return 0 + + eq_(0, noargs()) + + @memoize(max_len=250, max_age_seconds=10) + def fib(n): + if n == 0: + return 0 + elif n == 1: + return 1 + return fib(n - 1) + fib(n - 2) + + eq_(280571172992510140037611932413038677189525, fib(200)) + + +def test_memoize_class(): + class A(object): + def __init__(self, value): + self.value = value + + @memoize(max_len=10, max_age_seconds=5) + def get_value(self, arg, kwarg=1): + return self.value + + # with no kwargs + original_value = 'val' + a = A(original_value) + ok_(original_value is a.get_value(0)) + a.value = 'new val' + ok_(original_value is a.get_value(0)) + + original_value = 'new A val' + new_a = A(original_value) + ok_(a.get_value(0) != new_a.get_value(0)) + eq_(new_a.value, new_a.get_value(0)) + + # with kwargs + ok_(original_value is new_a.get_value(0, kwarg=2)) + new_a.value = 'new A new val' + ok_(original_value is new_a.get_value(0, kwarg=2))