Created
April 6, 2017 15:27
-
-
Save bananos/101742827f47bfc7291243934bb9c03f to your computer and use it in GitHub Desktop.
Simple TTL-based memoization with support for coroutines
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| import asyncio | |
| import unittest | |
| from h.commons import memoized_ttl | |
# Number of times the body of ``wrapped`` has actually executed.
called = 0


@memoized_ttl(10)
def wrapped():
    """Record that the undecorated body ran, then return 10."""
    global called
    called = called + 1
    return 10
class MemoizeClass(object):
    """Fixture exercising ``memoized_ttl`` on classmethods and methods.

    Every memoized callable bumps a counter, so a test can tell whether the
    underlying body ran or the cached value was handed back instead.
    """

    # Invocation counters for the two memoized classmethods.
    cls_called = 0
    cls_async_called = 0

    def __init__(self):
        # Shared by both ``my_fun`` and ``my_async_fun``.
        self.called = 0

    @classmethod
    @memoized_ttl(10)
    def my_class_fun(cls):
        """Count the call on the class and return 20."""
        cls.cls_called = cls.cls_called + 1
        return 20

    @classmethod
    @memoized_ttl(10)
    async def my_async_classmethod(cls):
        """Coroutine classmethod: count the call on the class, return 40."""
        cls.cls_async_called = cls.cls_async_called + 1
        return 40

    @memoized_ttl(10)
    def my_fun(self):
        """Count the call on the instance and return 30."""
        self.called = self.called + 1
        return 30

    @memoized_ttl(10)
    async def my_async_fun(self):
        """Coroutine method: count the call on the instance, return 50."""
        self.called = self.called + 1
        return 50
class TestMemoize(unittest.TestCase):
    """Exercise ``memoized_ttl`` on plain functions, classmethods, instance
    methods and coroutines.

    Several tests assert that a module- or class-level counter starts at
    zero, so each memoized callable must be driven by exactly one test.
    """

    def setUp(self):
        # Each test gets a private event loop. Register close() as a cleanup
        # so the loop is released even when the test fails; previously it
        # was never closed and leaked (ResourceWarning: unclosed event loop).
        self.loop = asyncio.new_event_loop()
        self.addCleanup(self.loop.close)

    def test_memoize_fun(self):
        """It should work for a module level method"""
        self.assertEqual(called, 0)
        val = wrapped()
        self.assertEqual(val, 10)
        self.assertEqual(called, 1)
        # Second call must be served from the cache: counter unchanged.
        val = wrapped()
        self.assertEqual(val, 10)
        self.assertEqual(called, 1)

    def test_memoize_class_method(self):
        """It should work for a classmethod"""
        self.assertEqual(MemoizeClass.cls_called, 0)
        val = MemoizeClass.my_class_fun()
        self.assertEqual(val, 20)
        self.assertEqual(MemoizeClass.cls_called, 1)
        val = MemoizeClass.my_class_fun()
        self.assertEqual(val, 20)
        self.assertEqual(MemoizeClass.cls_called, 1)

    def test_memoize_instance_method(self):
        """It should work for an instance method"""
        mc = MemoizeClass()
        self.assertEqual(mc.called, 0)
        val = mc.my_fun()
        self.assertEqual(val, 30)
        self.assertEqual(mc.called, 1)
        val = mc.my_fun()
        self.assertEqual(val, 30)
        self.assertEqual(mc.called, 1)

    def test_memoize_async_classmethod(self):
        """It should work with an async coroutine as classmethod."""
        self.assertEqual(MemoizeClass.cls_async_called, 0)

        async def go():
            val_fut1 = await MemoizeClass.my_async_classmethod()
            val_fut2 = await MemoizeClass.my_async_classmethod()
            self.assertEqual(val_fut1, 40)
            self.assertEqual(val_fut2, 40)

        self.loop.run_until_complete(go())
        self.assertEqual(MemoizeClass.cls_async_called, 1)

    def test_memoize_async(self):
        """It should work with an async coroutine instance method."""
        mc = MemoizeClass()
        self.assertEqual(mc.called, 0)

        async def go():
            val_fut1 = await mc.my_async_fun()
            val_fut2 = await mc.my_async_fun()
            self.assertEqual(val_fut1, 50)
            self.assertEqual(val_fut2, 50)

        self.loop.run_until_complete(go())
        self.assertEqual(mc.called, 1)
if __name__ == "__main__":
    # Run the suite when this file is executed directly as a script.
    unittest.main()
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import asyncio
import functools
import time
class memoized_ttl(object):
    """
    Decorator that caches a function's return value for ``ttl`` seconds.

    If called within the TTL with the same arguments, the cached value is
    returned. If called after the TTL has elapsed, or with different
    arguments, a fresh value is computed and cached. A ``ttl`` of zero or
    less caches values indefinitely.

    Coroutines are supported: the *awaited* result is cached, and every
    call (hit or miss) returns a new awaitable for the caller to await.

    http://jonebird.com/2012/02/07/python-memoize-decorator-with-ttl-argument/
    asyncio-friendly version inspired by:
    https://gist.github.com/dlebech/c16a34f735c0c4e9b604

    TTL in seconds!

    Caveats:
    * Keys are built from ``repr()`` of the arguments. Objects whose repr is
      address-based (e.g. ``self`` without a custom ``__repr__``) can collide
      with a later object that happens to reuse the same address.
    * Expired entries are only overwritten when the same key is requested
      again; the cache is never pruned otherwise.
    * Overlapping calls for the same not-yet-cached coroutine each invoke
      the wrapped function; nothing is cached until one of them finishes.
    """

    def __init__(self, ttl):
        # key -> (value, last_update, value_came_from_a_coroutine)
        self.cache = {}
        self.ttl = ttl

    def _wrap_value_in_coroutine(self, val):
        """Return a coroutine that resolves to ``val`` when awaited."""
        async def wrapper():
            return val
        return wrapper()

    def _wrap_coroutine_storage(self, key, future, last_update):
        """Return a coroutine that awaits ``future`` and caches its result.

        The entry is stamped with ``last_update`` (the call time), not the
        completion time. Nothing is cached if ``future`` raises.
        """
        async def wrapper():
            val = await future
            self.cache[key] = (val, last_update, True)
            return val
        return wrapper()

    def _is_fresh(self, last_update, now):
        """Return True if an entry stamped ``last_update`` is valid at ``now``."""
        return self.ttl <= 0 or now - last_update <= self.ttl

    def __call__(self, f):
        # The per-function part of the key never changes; build it once.
        prefix = '{}#{}#'.format(getattr(f, '__module__', None),
                                 getattr(f, '__name__', repr(f)))

        @functools.wraps(f)
        def wrapped_f(*args, **kwargs):
            now = time.time()
            try:
                # Sort kwargs so f(a=1, b=2) and f(b=2, a=1) share an entry.
                # Dict-valued arguments are still repr'd in insertion order,
                # so equal dicts built in different orders miss each other.
                key = prefix + repr((args, sorted(kwargs.items())))
            except TypeError:
                # uncachable -- an argument whose repr() fails.
                # Better to not cache than to blow up entirely.
                return f(*args, **kwargs)

            entry = self.cache.get(key)
            if entry is not None:
                value, last_update, from_coroutine = entry
                if self._is_fresh(last_update, now):
                    # Hand back the same kind of object the miss path
                    # returned, so callers that ``await`` keep working.
                    if from_coroutine:
                        return self._wrap_value_in_coroutine(value)
                    return value

            # Miss or expired entry. ``f`` is called outside any except
            # block, so its exceptions are not chained to an internal one.
            value = f(*args, **kwargs)
            if asyncio.iscoroutine(value):
                # Cache the awaited result, not the one-shot coroutine.
                return self._wrap_coroutine_storage(key, value, now)
            self.cache[key] = (value, now, False)
            return value

        return wrapped_f
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment