Skip to content
This repository was archived by the owner on Sep 28, 2022. It is now read-only.

Commit a87f379

Browse files
committed
Merge pull request #71 from brandicted/acl-refactor
Acl refactor
2 parents a58092c + 4fe11fb commit a87f379

File tree

4 files changed

+76
-83
lines changed

4 files changed

+76
-83
lines changed

ramses/acl.py

Lines changed: 33 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,9 @@
55
Allow, Deny,
66
Everyone, Authenticated,
77
ALL_PERMISSIONS)
8-
from nefertari.acl import SelfParamMixin
8+
from nefertari.acl import CollectionACL
9+
from nefertari.resource import PERMISSIONS
10+
from nefertari.elasticsearch import ES
911

1012
from .views import collection_methods, item_methods
1113
from .utils import resolve_to_callable, is_callable_tag
@@ -30,8 +32,7 @@ def methods_to_perms(perms, methods_map):
3032
the keyword 'all' into a set of valid Pyramid permissions.
3133
3234
:param perms: List or comma-separated string of HTTP methods, or 'all'
33-
:param methods_map: Map of HTTP methods to permission names (nefertari view
34-
methods)
35+
:param methods_map: Map of HTTP methods to nefertari view methods
3536
"""
3637
if isinstance(perms, six.string_types):
3738
perms = perms.split(',')
@@ -40,7 +41,7 @@ def methods_to_perms(perms, methods_map):
4041
return ALL_PERMISSIONS
4142
else:
4243
try:
43-
return [methods_map[p] for p in perms]
44+
return [PERMISSIONS[methods_map[p]] for p in perms]
4445
except KeyError:
4546
raise ValueError(
4647
'Unknown method name in permissions: {}. Valid methods: '
@@ -96,15 +97,12 @@ def parse_acl(acl_string, methods_map):
9697
return result_acl
9798

9899

99-
class BaseACL(SelfParamMixin):
100+
class BaseACL(CollectionACL):
100101
""" ACL Base class. """
101-
__context_class__ = None
102-
collection_acl = None
103-
item_acl = None
104102

105-
def __init__(self, request):
106-
super(BaseACL, self).__init__()
107-
self.request = request
103+
es_based = False
104+
_collection_acl = (ALLOW_ALL, )
105+
_item_acl = (ALLOW_ALL, )
108106

109107
def _apply_callables(self, acl, methods_map, obj=None):
110108
""" Iterate over ACEs from :acl: and apply callable principals if any.
@@ -138,42 +136,37 @@ def _apply_callables(self, acl, methods_map, obj=None):
138136
return new_acl
139137

140138
def __acl__(self):
141-
""" Apply callables to `self.collection_acl` and return result. """
139+
""" Apply callables to `self._collection_acl` and return result. """
142140
return self._apply_callables(
143-
acl=self.collection_acl,
141+
acl=self._collection_acl,
144142
methods_map=collection_methods)
145143

146-
def context_acl(self, obj):
147-
""" Apply callables to `self.item_acl` and return result. """
144+
def item_acl(self, item):
145+
""" Apply callables to `self._item_acl` and return result. """
148146
return self._apply_callables(
149-
acl=self.item_acl,
147+
acl=self._item_acl,
150148
methods_map=item_methods,
151-
obj=obj)
149+
obj=item)
150+
151+
def item_db_id(self, key):
152+
# ``self`` can be used for current authenticated user key
153+
if key != 'self':
154+
return key
155+
user = getattr(self.request, 'user', None)
156+
if user is None or not isinstance(user, self.item_model):
157+
return key
158+
return getattr(user, user.pk_field())
152159

153160
def __getitem__(self, key):
154161
""" Get item using method depending on value of `self.es_based` """
155-
key = self.resolve_self_key(key)
156-
if self.es_based:
157-
return self.getitem_es(key=key)
158-
else:
159-
return self.getitem_db(key=key)
160-
161-
def getitem_db(self, key):
162-
""" Get item with ID of :key: from database """
163-
pk_field = self.__context_class__.pk_field()
164-
obj = self.__context_class__.get_resource(
165-
**{pk_field: key})
166-
obj.__acl__ = self.context_acl(obj)
167-
obj.__parent__ = self
168-
obj.__name__ = key
169-
return obj
162+
if not self.es_based:
163+
return super(BaseACL, self).__getitem__(key)
164+
return self.getitem_es(self.item_db_id(key))
170165

171166
def getitem_es(self, key):
172-
""" Get item with ID of :key: from elasticsearch """
173-
from nefertari.elasticsearch import ES
174-
es = ES(self.__context_class__.__name__)
167+
es = ES(self.item_model.__name__)
175168
obj = es.get_resource(id=key)
176-
obj.__acl__ = self.context_acl(obj)
169+
obj.__acl__ = self.item_acl(obj)
177170
obj.__parent__ = self
178171
obj.__name__ = key
179172
return obj
@@ -182,7 +175,7 @@ def getitem_es(self, key):
182175
def generate_acl(model_cls, raml_resource, es_based=True):
183176
""" Generate an ACL.
184177
185-
Generated ACL class has a `__context_class__` attribute set to
178+
Generated ACL class has a `item_model` attribute set to
186179
:model_cls:.
187180
188181
ACLs used for collection and item access control are generated from a
@@ -216,12 +209,12 @@ def generate_acl(model_cls, raml_resource, es_based=True):
216209
methods_map=item_methods)
217210

218211
class GeneratedACL(BaseACL):
219-
__context_class__ = model_cls
212+
item_model = model_cls
220213

221214
def __init__(self, request, es_based=es_based):
222215
super(GeneratedACL, self).__init__(request=request)
223216
self.es_based = es_based
224-
self.collection_acl = collection_acl
225-
self.item_acl = item_acl
217+
self._collection_acl = collection_acl
218+
self._item_acl = item_acl
226219

227220
return GeneratedACL

ramses/views.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -148,8 +148,8 @@ def reload_context(self, es_based, **kwargs):
148148
kwargs['es_based'] = es_based
149149

150150
acl = self._factory(**kwargs)
151-
if acl.__context_class__ is None:
152-
acl.__context_class__ = self.Model
151+
if acl.item_model is None:
152+
acl.item_model = self.Model
153153

154154
self.context = acl[key]
155155

tests/test_acl.py

Lines changed: 40 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -24,9 +24,9 @@ def test_methods_to_perms_invalid_perm_name(self):
2424

2525
def test_methods_to_perms(self):
2626
perms = acl.methods_to_perms('get', self.methods_map)
27-
assert perms == ['index']
27+
assert perms == ['view']
2828
perms = acl.methods_to_perms('get,post', self.methods_map)
29-
assert sorted(perms) == ['create', 'index']
29+
assert sorted(perms) == ['create', 'view']
3030

3131
def test_parse_acl_no_string(self):
3232
perms = acl.parse_acl('', self.methods_map)
@@ -85,12 +85,12 @@ def test_no_security(self, mock_parse):
8585
model_cls='Foo',
8686
raml_resource=Mock(security_schemes=[]),
8787
es_based=True)
88-
assert acl_cls.__context_class__ == 'Foo'
88+
assert acl_cls.item_model == 'Foo'
8989
assert issubclass(acl_cls, acl.BaseACL)
9090
instance = acl_cls(request=None)
9191
assert instance.es_based
92-
assert instance.collection_acl == [acl.ALLOW_ALL]
93-
assert instance.item_acl == [acl.ALLOW_ALL]
92+
assert instance._collection_acl == [acl.ALLOW_ALL]
93+
assert instance._item_acl == [acl.ALLOW_ALL]
9494
assert not mock_parse.called
9595

9696
def test_wrong_security_scheme_type(self, mock_parse):
@@ -102,12 +102,12 @@ def test_wrong_security_scheme_type(self, mock_parse):
102102
raml_resource=raml_resource,
103103
es_based=False)
104104
assert not mock_parse.called
105-
assert acl_cls.__context_class__ == 'Foo'
105+
assert acl_cls.item_model == 'Foo'
106106
assert issubclass(acl_cls, acl.BaseACL)
107107
instance = acl_cls(request=None)
108108
assert not instance.es_based
109-
assert instance.collection_acl == [acl.ALLOW_ALL]
110-
assert instance.item_acl == [acl.ALLOW_ALL]
109+
assert instance._collection_acl == [acl.ALLOW_ALL]
110+
assert instance._item_acl == [acl.ALLOW_ALL]
111111

112112
def test_correct_security_scheme(self, mock_parse):
113113
raml_resource = Mock(security_schemes=[
@@ -122,18 +122,18 @@ def test_correct_security_scheme(self, mock_parse):
122122
call(acl_string=7, methods_map=acl.item_methods),
123123
])
124124
instance = acl_cls(request=None)
125-
assert instance.collection_acl == mock_parse()
126-
assert instance.item_acl == mock_parse()
125+
assert instance._collection_acl == mock_parse()
126+
assert instance._item_acl == mock_parse()
127127
assert not instance.es_based
128128

129129

130130
class TestBaseACL(object):
131131

132132
def test_init(self):
133133
obj = acl.BaseACL(request='Foo')
134-
assert obj.__context_class__ is None
135-
assert obj.collection_acl is None
136-
assert obj.item_acl is None
134+
assert obj.item_model is None
135+
assert obj._collection_acl == (acl.ALLOW_ALL,)
136+
assert obj._item_acl == (acl.ALLOW_ALL,)
137137
assert obj.request == 'Foo'
138138

139139
def test_apply_callables_no_callables(self):
@@ -199,11 +199,11 @@ def test_apply_callables_functional(self):
199199
acl=[(Deny, principal, ALL_PERMISSIONS)],
200200
methods_map=acl.item_methods,
201201
)
202-
assert new_acl == [(Allow, Everyone, ['show'])]
202+
assert new_acl == [(Allow, Everyone, ['view'])]
203203

204204
def test_magic_acl(self):
205205
obj = acl.BaseACL('req')
206-
obj.collection_acl = [(1, 2, 3)]
206+
obj._collection_acl = [(1, 2, 3)]
207207
obj._apply_callables = Mock()
208208
result = obj.__acl__()
209209
obj._apply_callables.assert_called_once_with(
@@ -212,11 +212,11 @@ def test_magic_acl(self):
212212
)
213213
assert result == obj._apply_callables()
214214

215-
def test_context_acl(self):
215+
def test_item_acl(self):
216216
obj = acl.BaseACL('req')
217-
obj.item_acl = [(1, 2, 3)]
217+
obj._item_acl = [(1, 2, 3)]
218218
obj._apply_callables = Mock()
219-
result = obj.context_acl(obj='foobar')
219+
result = obj.item_acl('foobar')
220220
obj._apply_callables.assert_called_once_with(
221221
acl=[(1, 2, 3)],
222222
methods_map=acl.item_methods,
@@ -226,50 +226,50 @@ def test_context_acl(self):
226226

227227
def test_magic_getitem_es_based(self):
228228
obj = acl.BaseACL('req')
229-
obj.resolve_self_key = Mock()
229+
obj.item_db_id = Mock(return_value=42)
230230
obj.getitem_es = Mock()
231231
obj.es_based = True
232232
obj.__getitem__(1)
233-
obj.resolve_self_key.assert_called_once_with(1)
234-
obj.getitem_es.assert_called_once_with(key=obj.resolve_self_key())
233+
obj.item_db_id.assert_called_once_with(1)
234+
obj.getitem_es.assert_called_once_with(42)
235235

236236
def test_magic_getitem_db_based(self):
237237
obj = acl.BaseACL('req')
238-
obj.resolve_self_key = Mock()
239-
obj.getitem_db = Mock()
238+
obj.item_db_id = Mock(return_value = 42)
239+
obj.item_model = Mock()
240+
obj.item_model.pk_field.return_value = 'id'
240241
obj.es_based = False
241242
obj.__getitem__(1)
242-
obj.resolve_self_key.assert_called_once_with(1)
243-
obj.getitem_db.assert_called_once_with(key=obj.resolve_self_key())
243+
obj.item_db_id.assert_called_once_with(1)
244244

245245
def test_getitem_db(self):
246246
obj = acl.BaseACL('req')
247-
obj.__context_class__ = Mock()
248-
obj.__context_class__.pk_field.return_value = 'myname'
249-
obj.context_acl = Mock()
250-
value = obj.getitem_db(key='varvar')
251-
obj.__context_class__.get_resource.assert_called_once_with(
252-
myname='varvar')
253-
obj.context_acl.assert_called_once_with(
254-
obj.__context_class__.get_resource())
255-
assert value.__acl__ == obj.context_acl()
247+
obj.item_model = Mock()
248+
obj.item_model.pk_field.return_value = 'myname'
249+
obj.item_acl = Mock()
250+
value = obj['varvar']
251+
obj.item_model.get.assert_called_once_with(
252+
__raise=True, myname='varvar')
253+
obj.item_acl.assert_called_once_with(
254+
obj.item_model.get())
255+
assert value.__acl__ == obj.item_acl()
256256
assert value.__parent__ is obj
257257
assert value.__name__ == 'varvar'
258258

259-
@patch('nefertari.elasticsearch.ES')
259+
@patch('ramses.acl.ES')
260260
def test_getitem_es(self, mock_es):
261261
found_obj = Mock()
262262
es_obj = Mock()
263263
es_obj.get_resource.return_value = found_obj
264264
mock_es.return_value = es_obj
265265
obj = acl.BaseACL('req')
266-
obj.__context_class__ = Mock(__name__='Foo')
267-
obj.__context_class__.pk_field.return_value = 'myname'
268-
obj.context_acl = Mock()
266+
obj.item_model = Mock(__name__='Foo')
267+
obj.item_model.pk_field.return_value = 'myname'
268+
obj.item_acl = Mock()
269269
value = obj.getitem_es(key='varvar')
270270
mock_es.assert_called_with('Foo')
271271
es_obj.get_resource.assert_called_once_with(id='varvar')
272-
obj.context_acl.assert_called_once_with(found_obj)
273-
assert value.__acl__ == obj.context_acl()
272+
obj.item_acl.assert_called_once_with(found_obj)
273+
assert value.__acl__ == obj.item_acl()
274274
assert value.__parent__ is obj
275275
assert value.__name__ == 'varvar'

tests/test_views.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -162,7 +162,7 @@ def test_parent_queryset(self):
162162

163163
def test_reload_context(self):
164164
class Factory(dict):
165-
__context_class__ = None
165+
item_model = None
166166

167167
def __getitem__(self, key):
168168
return key

0 commit comments

Comments
 (0)