diff --git a/README.rst b/README.rst
index dcd0a8b..102bca8 100644
--- a/README.rst
+++ b/README.rst
@@ -28,13 +28,14 @@ If you wish to download the source and install from GitHub:
.. code-block:: bash
git clone git@github.com:mailgun/expiringdict.git
+ cd expiringdict
python setup.py install
or to install with test dependencies (`Nose `_, `Mock `_, `coverage `_) run from the directory above:
.. code-block:: bash
- pip install -e expiringdict[test]
+ pip install -e .[test]
To run tests with coverage:
diff --git a/expiringdict/__init__.py b/expiringdict/__init__.py
index 56ba71d..220f0eb 100755
--- a/expiringdict/__init__.py
+++ b/expiringdict/__init__.py
@@ -116,7 +116,7 @@ def ttl(self, key):
Returns None for non-existent or expired keys.
"""
key_value, key_age = self.get(key, with_age=True) # type: Any, Union[None, float]
- if key_age:
+ if key_age is not None:
key_ttl = self.max_age - key_age
if key_ttl > 0:
return key_ttl
@@ -244,3 +244,20 @@ def __copy_dict(self, items):
def __copy_reduced_result(self, items):
[self.__setitem__(key, value, set_time) for key, (value, set_time) in items[1]]
+
+
+def memoize(max_len, max_age_seconds):
+ cache = ExpiringDict(max_len, max_age_seconds)
+
+ def wrap(fn):
+ def wrapped_fn(*args, **kwargs):
+ key = (args, frozenset(kwargs.items()))
+ result = cache.get(key)
+ if result is None:
+ result = fn(*args)
+ cache[key] = result
+ return result
+
+ return wrapped_fn
+
+ return wrap
diff --git a/setup.py b/setup.py
index 17d166e..763df4a 100644
--- a/setup.py
+++ b/setup.py
@@ -40,5 +40,6 @@
"typing",
],
extras_require={
+ "test": tests_require,
"tests": tests_require,
})
diff --git a/tests/expiringdict_test.py b/tests/expiringdict_test.py
old mode 100644
new mode 100755
index 78929f1..dcd6092
--- a/tests/expiringdict_test.py
+++ b/tests/expiringdict_test.py
@@ -3,7 +3,7 @@
from mock import Mock, patch
from nose.tools import assert_raises, eq_, ok_
-from expiringdict import ExpiringDict
+from expiringdict import ExpiringDict, memoize
def test_create():
@@ -94,14 +94,15 @@ def test_ttl():
d['a'] = 'x'
# existent non-expired key
- ok_(0 < d.ttl('a') < 10)
+ # TTL can be 10 if the machine is very fast.
+ ok_(0 < d.ttl('a') <= 10)
# non-existent key
eq_(None, d.ttl('b'))
# expired key
with patch.object(ExpiringDict, '__getitem__',
- Mock(return_value=('x', 10**9))):
+ Mock(return_value=('x', 10 ** 9))):
eq_(None, d.ttl('a'))
@@ -122,3 +123,48 @@ def test_not_implemented():
assert_raises(NotImplementedError, d.viewitems)
assert_raises(NotImplementedError, d.viewkeys)
assert_raises(NotImplementedError, d.viewvalues)
+
+
+def test_memoize():
+ @memoize(max_len=10, max_age_seconds=10)
+ def noargs():
+ return 0
+
+ eq_(0, noargs())
+
+ @memoize(max_len=250, max_age_seconds=10)
+ def fib(n):
+ if n == 0:
+ return 0
+ elif n == 1:
+ return 1
+ return fib(n - 1) + fib(n - 2)
+
+ eq_(280571172992510140037611932413038677189525, fib(200))
+
+
+def test_memoize_class():
+ class A(object):
+ def __init__(self, value):
+ self.value = value
+
+ @memoize(max_len=10, max_age_seconds=5)
+ def get_value(self, arg, kwarg=1):
+ return self.value
+
+ # with no kwargs
+ original_value = 'val'
+ a = A(original_value)
+ ok_(original_value is a.get_value(0))
+ a.value = 'new val'
+ ok_(original_value is a.get_value(0))
+
+ original_value = 'new A val'
+ new_a = A(original_value)
+ ok_(a.get_value(0) != new_a.get_value(0))
+ eq_(new_a.value, new_a.get_value(0))
+
+ # with kwargs
+ ok_(original_value is new_a.get_value(0, kwarg=2))
+ new_a.value = 'new A new val'
+ ok_(original_value is new_a.get_value(0, kwarg=2))