3
3
import threading
4
4
from functools import partial , update_wrapper
5
5
from inspect import getfullargspec , isfunction , ismethod
6
+ from typing import Any , Callable , List , Optional , Tuple , Union
6
7
7
8
logger = logging .getLogger ("rules" )
8
9
9
10
10
- def assert_has_kwonlydefaults (fn , msg ) :
11
+ def assert_has_kwonlydefaults (fn : Callable [..., Any ], msg : str ) -> None :
11
12
argspec = getfullargspec (fn )
12
13
if hasattr (argspec , "kwonlyargs" ):
13
14
if not argspec .kwonlyargs :
@@ -19,21 +20,21 @@ def assert_has_kwonlydefaults(fn, msg):
19
20
20
21
21
22
class Context (dict ):
22
- def __init__ (self , args ) :
23
+ def __init__ (self , args : Tuple [ Any , ...]) -> None :
23
24
super (Context , self ).__init__ ()
24
25
self .args = args
25
26
26
27
27
28
class localcontext (threading .local ):
28
- def __init__ (self ):
29
- self .stack = []
29
+ def __init__ (self ) -> None :
30
+ self .stack : List [ Context ] = []
30
31
31
32
32
33
_context = localcontext ()
33
34
34
35
35
36
class NoValueSentinel (object ):
36
- def __bool__ (self ):
37
+ def __bool__ (self ) -> bool :
37
38
return False
38
39
39
40
__nonzero__ = __bool__ # python 2
@@ -45,7 +46,17 @@ def __bool__(self):
45
46
46
47
47
48
class Predicate (object ):
48
- def __init__ (self , fn , name = None , bind = False ):
49
+ fn : Callable [..., Any ]
50
+ num_args : int
51
+ var_args : bool
52
+ name : str
53
+
54
+ def __init__ (
55
+ self ,
56
+ fn : Union ["Predicate" , Callable [..., Any ]],
57
+ name : Optional [str ] = None ,
58
+ bind : bool = False ,
59
+ ) -> None :
49
60
# fn can be a callable with any of the following signatures:
50
61
# - fn(obj=None, target=None)
51
62
# - fn(obj=None)
@@ -98,13 +109,13 @@ def __init__(self, fn, name=None, bind=False):
98
109
self .name = name or fn .__name__
99
110
self .bind = bind
100
111
101
- def __repr__ (self ):
112
+ def __repr__ (self ) -> str :
102
113
return "<%s:%s object at %s>" % (type (self ).__name__ , str (self ), hex (id (self )))
103
114
104
- def __str__ (self ):
115
+ def __str__ (self ) -> str :
105
116
return self .name
106
117
107
- def __call__ (self , * args , ** kwargs ):
118
+ def __call__ (self , * args , ** kwargs ) -> Any :
108
119
# this method is defined as variadic in order to not mask the
109
120
# underlying callable's signature that was most likely decorated
110
121
# as a predicate. internally we consistently call ``_apply`` that
@@ -114,7 +125,7 @@ def __call__(self, *args, **kwargs):
114
125
return self .fn (* args , ** kwargs )
115
126
116
127
@property
117
- def context (self ):
128
+ def context (self ) -> Optional [ Context ] :
118
129
"""
119
130
The currently active invocation context. A new context is created as a
120
131
result of invoking ``test()`` and is only valid for the duration of
@@ -150,7 +161,7 @@ def context(self):
150
161
except IndexError :
151
162
return None
152
163
153
- def test (self , obj = NO_VALUE , target = NO_VALUE ):
164
+ def test (self , obj : Any = NO_VALUE , target : Any = NO_VALUE ) -> bool :
154
165
"""
155
166
The canonical method to invoke predicates.
156
167
"""
@@ -162,25 +173,25 @@ def test(self, obj=NO_VALUE, target=NO_VALUE):
162
173
finally :
163
174
_context .stack .pop ()
164
175
165
- def __and__ (self , other ):
176
+ def __and__ (self , other ) -> "Predicate" :
166
177
def AND (* args ):
167
178
return self ._combine (other , operator .and_ , args )
168
179
169
180
return type (self )(AND , "(%s & %s)" % (self .name , other .name ))
170
181
171
- def __or__ (self , other ):
182
+ def __or__ (self , other ) -> "Predicate" :
172
183
def OR (* args ):
173
184
return self ._combine (other , operator .or_ , args )
174
185
175
186
return type (self )(OR , "(%s | %s)" % (self .name , other .name ))
176
187
177
- def __xor__ (self , other ):
188
+ def __xor__ (self , other ) -> "Predicate" :
178
189
def XOR (* args ):
179
190
return self ._combine (other , operator .xor , args )
180
191
181
192
return type (self )(XOR , "(%s ^ %s)" % (self .name , other .name ))
182
193
183
- def __invert__ (self ):
194
+ def __invert__ (self ) -> "Predicate" :
184
195
def INVERT (* args ):
185
196
result = self ._apply (* args )
186
197
return None if result is None else not result
@@ -208,7 +219,7 @@ def _combine(self, other, op, args):
208
219
209
220
return op (self_result , other_result )
210
221
211
- def _apply (self , * args ):
222
+ def _apply (self , * args ) -> Optional [ bool ] :
212
223
# Internal method that is used to invoke the predicate with the
213
224
# proper number of positional arguments, inside the current
214
225
# invocation context.
@@ -268,12 +279,12 @@ def inner(fn):
268
279
always_deny = predicate (lambda : False , name = "always_deny" )
269
280
270
281
271
- def is_bool_like (obj ):
282
+ def is_bool_like (obj ) -> bool :
272
283
return hasattr (obj , "__bool__" ) or hasattr (obj , "__nonzero__" )
273
284
274
285
275
286
@predicate
276
- def is_authenticated (user ):
287
+ def is_authenticated (user ) -> bool :
277
288
if not hasattr (user , "is_authenticated" ):
278
289
return False # not a user model
279
290
if not is_bool_like (user .is_authenticated ): # pragma: no cover
@@ -283,27 +294,27 @@ def is_authenticated(user):
283
294
284
295
285
296
@predicate
286
- def is_superuser (user ):
297
+ def is_superuser (user ) -> bool :
287
298
if not hasattr (user , "is_superuser" ):
288
299
return False # swapped user model, doesn't support is_superuser
289
300
return user .is_superuser
290
301
291
302
292
303
@predicate
293
- def is_staff (user ):
304
+ def is_staff (user ) -> bool :
294
305
if not hasattr (user , "is_staff" ):
295
306
return False # swapped user model, doesn't support is_staff
296
307
return user .is_staff
297
308
298
309
299
310
@predicate
300
- def is_active (user ):
311
+ def is_active (user ) -> bool :
301
312
if not hasattr (user , "is_active" ):
302
313
return False # swapped user model, doesn't support is_active
303
314
return user .is_active
304
315
305
316
306
- def is_group_member (* groups ):
317
+ def is_group_member (* groups ) -> Callable [..., Any ] :
307
318
assert len (groups ) > 0 , "You must provide at least one group name"
308
319
309
320
if len (groups ) > 3 :
@@ -314,7 +325,7 @@ def is_group_member(*groups):
314
325
name = "is_group_member:%s" % "," .join (g )
315
326
316
327
@predicate (name )
317
- def fn (user ):
328
+ def fn (user ) -> bool :
318
329
if not hasattr (user , "groups" ):
319
330
return False # swapped user model, doesn't support groups
320
331
if not hasattr (user , "_group_names_cache" ): # pragma: no cover
0 commit comments