diff --git a/fmt/fmt.py b/fmt/fmt.py index ce33f5e..019eab3 100644 --- a/fmt/fmt.py +++ b/fmt/fmt.py @@ -9,10 +9,18 @@ from functools import partial +PY3 = sys.version_info[0] == 3 +if PY3: + fmt_types = str, bytes +else: + fmt_types = basestring, unicode # noqa: F821 + + class Fmt(object): def __init__(self): self._g_ns = {} + self._nodes_cache = {} def register(self, name, value, update=False): if not update and name in self._g_ns: @@ -24,13 +32,21 @@ def mregister(self, ns, update=False): self.register(k, v, update) def __call__(self, f_str, *_args): + if not isinstance(f_str, fmt_types): + raise ValueError('Unsupported type as format ' + 'string: {}({})'.format(type(f_str), f_str)) frame = sys._getframe(1) # locals will cover globals ns = deepcopy(self._g_ns) ns.update(frame.f_globals) ns.update(frame.f_locals) - nodes = Parser(f_str).parse() + # cache nodes, if already parsed + nodes = self._nodes_cache.get(f_str, None) + if nodes is None: + nodes = Parser(f_str).parse() + self._nodes_cache[f_str] = nodes + try: return generate(nodes, ns) finally: @@ -45,6 +61,18 @@ def generate(nodes, namespace): class Node(object): + # flyweight pattern: cache instances + _instances = {} + + def __new__(cls, *args, **kwargs): + key = (cls, args, tuple(kwargs.items())) + instance = cls._instances.get(key, None) + if instance is None: + instance = super(Node, cls).__new__(cls) + instance.__init__(*args, **kwargs) + cls._instances[key] = instance + return instance + def generate(self, ns): raise NotImplementedError() diff --git a/tests/test_fmt.py b/tests/test_fmt.py index 5d49ee9..19033b8 100644 --- a/tests/test_fmt.py +++ b/tests/test_fmt.py @@ -1,8 +1,35 @@ # -*- coding: utf-8 -*- import sys -from datetime import datetime -import fmt as f +from datetime import datetime # noqa +from fmt.fmt import Fmt, Parser, Text +import fmt as f + + +def test_parsed_nodes_been_cached(monkeypatch): + call_count = [] + + def _parse(self): + call_count.append(1) + return [Text('baz')] + + monkeypatch.setattr(Parser, 'parse', _parse) + fmt = Fmt() + + count = 0 + for f_str in ('foo', 'bar'): + count += 1 + fmt(f_str) + assert f_str in fmt._nodes_cache + assert len(call_count) == count + + value = fmt._nodes_cache[f_str][0] + + for _ in range(5): + fmt(f_str) + assert f_str in fmt._nodes_cache + assert len(call_count) == count, f_str + assert id(value) == id(fmt._nodes_cache[f_str][0]) g_foo = 'global-foo' @@ -34,6 +61,7 @@ def test_namespace(): def test_closure_namespace(): def outer(x): y = 'yy' + def inner(): z = 'zz' globals_, locals_ = get_namesapce(x, y) @@ -54,7 +82,9 @@ def test_fmt(): l_bar = 'local-bar' ls_num = range(5) ls_ch = ['a', 'b', 'c', 'd', 'e'] + class Baz(object): + def __str__(self): return 'BAZ' @@ -77,19 +107,21 @@ def __str__(self): assert '1314' == f('{func(x, y)}') def outer(arg): + def inner(): assert '{ outer-arg }' == f('{{ {arg} }}', arg) return inner + outer('outer-arg')() - assert '[0, 1, 2, 3, 4]' == f('{list(range(5))}') # Py3 range return iterator + # Py3 range return iterator + assert '[0, 1, 2, 3, 4]' == f('{list(range(5))}') assert '[0, 1, 2, 3, 4]' == f('{[i for i in ls_num]}') if sys.version_info[0] == 2: assert 'set([0, 1, 2, 3, 4])' == f('{{i for i in ls_num}}') else: assert '{0, 1, 2, 3, 4}' == f('{{i for i in ls_num}}') - assert ("{0: 'a', 1: 'b', 2: 'c', 3: 'd', 4: 'e'}" == f('{{k:v for k,v in zip(ls_num, ls_ch)}}')) assert '[1, 2, 3, 4, 5]' == f('{list(map(lambda x: x+1, ls_num))}') diff --git a/tests/test_node.py b/tests/test_node.py index a590084..8abe23f 100644 --- a/tests/test_node.py +++ b/tests/test_node.py @@ -4,6 +4,20 @@ from fmt.fmt import Text, Constant, Expression +def test_flyweight(): + assert id(Text('foo', 'bar')) == id(Text('foo', 'bar')) + assert id(Constant('foo', 'bar')) == id(Constant('foo', 'bar')) + assert id(Expression('foo', 'bar')) == id(Expression('foo', 'bar')) + + assert id(Text('foo', 'bar')) != id(Text('bar', 'foo')) + assert id(Constant('foo', 'bar')) != id(Constant('bar', 'foo')) + assert id(Expression('foo', 'bar')) != id(Expression('bar', 'foo')) + + assert id(Text('foo', 'bar')) != id(Constant('foo', 'bar')) + assert id(Text('foo', 'bar')) != id(Expression('foo', 'bar')) + assert id(Constant('foo', 'bar')) != id(Expression('foo', 'bar')) + + def test_Text(): text = 'I am a good man, /(愒o愒)/~~' assert text == Text(text).generate(None)