-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathyax.py
1316 lines (1077 loc) · 45.5 KB
/
yax.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
# Copyright 2024 Daniel Bershatsky
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Module `yax` implements module tracing routines, module expression (MoX)
building, querying, and mutation.
"""
import re
from copy import copy
from dataclasses import dataclass, field, fields
from functools import partial, wraps
from io import StringIO
from json import dumps
from typing import (
IO, Any, ClassVar, Generic, ParamSpec, Self, Sequence, Type, TypeAlias,
TypeVar)
from xml.etree.ElementTree import (
Element, ElementTree, SubElement, indent as indent_etree)
import flax
import flax.linen as nn
import jax
import jax.extend as jex
import jax.extend.linear_util as lu
import jax.numpy as jnp
from flax.core.scope import LazyRng
from flax.linen.module import Callable, InterceptorContext, intercept_methods
from flax.typing import RNGSequences
from jax.core import (
AbstractValue, ClosedJaxpr as Jaxpr, ConcreteArray, MainTrace, ShapedArray,
Sublevel, Trace, Tracer, find_top_trace, new_main)
# TODO(@daskol): Make PR on reexporting PyTreeDef.
try:
from jax.tree import PyTreeDef # type: ignore[attr-defined]
except ImportError:
from jax.tree_util import PyTreeDef
# Names exported by `from yax import *`.
__all__ = ('Equation', 'Expr', 'Literal', 'Mox', 'Symbol', 'Var', 'dump_yson',
           'dump_xml', 'eval_mox', 'make_mox', 'map_mox', 'query', 'sub')
# TODO(@daskol): Python 3.12 introduced new type parameter syntax (PEP-0695)
# but some code quality tools (e.g. yapf) do not support this syntax.
# Parameter specification used to carry a traced function's call signature
# through `make_mox`.
Args = ParamSpec('Args')
class ModuleTracer(Tracer):
    """Tracer that boxes a value (concrete or abstract) seen during tracing.

    Instances are hashed and compared by identity, which lets them be used
    as dictionary keys.
    """

    __slots__ = 'value'

    def __init__(self, trace: Trace, value):
        super().__init__(trace)
        self.value = value

    def __eq__(self, other) -> bool:
        return self is other

    def __hash__(self) -> int:
        return id(self)

    @property
    def aval(self) -> AbstractValue:
        """Abstract value of the boxed payload."""
        # TODO(@daskol): Some abstract values should be less abstract. There
        # are cases when we should preserve original concrete values. For
        # example, arguments of a binary operation between a tracer and a
        # numerical literal like `x > 0`. In this case `x` should be evaluated
        # as `ShapedArray` and zero as `ConcreteArray`. Otherwise, these
        # constant values get lost in evaluation time.
        payload = self.value
        if not isinstance(payload, AbstractValue):
            return jax.api_util.shaped_abstractify(payload)
        return payload

    def full_lower(self):
        """Unwrap nested `ModuleTracer` boxes; otherwise return `self`."""
        # TODO(@daskol): How to properly implement lowering?
        #
        # The idea is that we need only tracers for abstract evaluation. Thus,
        # we should not lower a tracer to anything other than a Tracer
        # (actually, a ModuleTracer). It seems that we should `full_lower`
        # everywhere instead of `full_raise`.
        inner = self.value
        if isinstance(inner, ModuleTracer):
            return inner.full_lower()
        if isinstance(inner, Tracer):
            raise RuntimeError(
                'Unreachable executation path: full lowering expects '
                'value to be either ModuleTracer or an Array '
                f'implementation but actual type is {type(self.value)}.')
        return self
class ModuleTrace(Trace[ModuleTracer]):
    """Trace that records every bound primitive into a `MoxBuilder`.

    Primitives are only evaluated abstractly (via `abstract_eval`); each call
    is appended to the builder as an equation.
    """

    def __init__(self, main: MainTrace, sublevel: Sublevel, *,
                 builder: 'MoxBuilder', **kwargs) -> None:
        super().__init__(main, sublevel)
        self.builder: MoxBuilder = builder

    def pure(self, val) -> ModuleTracer:
        """Wrap value to monadic/functorial context.

        Our `pure` differs from the original implementation in a way to work
        not only with Arrays but with AbstractValues. The idea is to interpret
        concrete arrays as literals while shaped arrays are just variables.
        Thus, we abstractify input arrays (arguments) manually.
        """
        if isinstance(val, AbstractValue):
            aval = val
        else:
            aval = jax.core.get_aval(val)
        return ModuleTracer(self, aval)

    def lift(self, tracer: Tracer) -> ModuleTracer:
        """Box a tracer of another trace into a `ModuleTracer`."""
        return ModuleTracer(self, tracer)

    def sublift(self, tracer: Tracer) -> ModuleTracer:
        """Box a tracer of a lower sublevel into a `ModuleTracer`."""
        return ModuleTracer(self, tracer)

    def process_primitive(self, prim: jex.core.Primitive, tracers, params):
        """Abstractly evaluate `prim` and record it in the builder.

        Returns output tracers arranged in the same tree structure as the
        result of `prim.abstract_eval`.
        """
        result, _ = prim.abstract_eval(*[x.aval for x in tracers], **params)
        outs, out_tree = jax.tree.flatten(result)
        # The message is an f-string so that offending outputs are actually
        # reported when the assumption is violated.
        assert all(isinstance(x, ShapedArray) for x in outs if x), \
            f'Assumption on type of result is violated: {outs}.'
        out_flat_tracers = [ModuleTracer(self, x) for x in outs]
        out_tracers = jax.tree.unflatten(out_tree, out_flat_tracers)
        self.builder.append(prim, params, tracers, out_flat_tracers)
        return out_tracers

    def process_custom_jvp_call(self, primitive, fun, jvp, tracers, **kwargs):
        """Trace through the primal function of a custom-JVP call.

        The JVP rule is ignored; `fun` is called directly on `tracers`.
        """
        del primitive, jvp, kwargs
        # TODO(@daskol): Why do reference implementations in JAX create a
        # subtracer? Is it just a purely debugging thing? It seems that
        # sublevels help to find leaking tracers. This can be highly important
        # for partial evaluation but here we just forward tracers to our
        # `process_primitive` method.
        #
        # We would enforce subtracers here. Then one should uncomment the
        # following. The only issue is that it is unclear how to process
        # `out_tracers`. Should we `full_lower` them and wrap them up again
        # with the outer tracer?
        #
        #   with jax.core.new_sublevel():
        #       trace: ModuleTrace = self.main.with_cur_sublevel()
        #       in_tracers = [trace.sublift(t) for t in tracers]
        #       out_tracers = fun.call_wrapped(*in_tracers)
        #       return [trace.full_raise(t) for t in out_tracers]
        return fun.call_wrapped(*tracers)
@lu.transformation
def _lower(builder: 'MoxBuilder', main: MainTrace, *ins):
    """Box flat inputs into tracers, run the wrapped function on them, and
    register root inputs and outputs in `builder`.
    """
    trace: ModuleTrace = main.with_cur_sublevel()
    # NOTE We manually abstractify up to ShapedArray all input arguments.
    in_tracers = []
    for arg in ins:
        in_tracers.append(trace.pure(jax.api_util.shaped_abstractify(arg)))
    builder.set_inputs(in_tracers)
    results = yield in_tracers, {}
    # TODO(@daskol): We do not use output tracers. Should we remove them? This
    # issue is related to the proper implementation of (sub)lifting and
    # unboxing that still have quite vague semantics.
    out_tracers = [trace.full_raise(res) for res in results]
    builder.set_outputs(out_tracers)
    yield out_tracers
@lu.transformation
def _raise(builder, *ins):
    """Open a new main `ModuleTrace` bound to `builder` and pass it as the
    leading argument to the wrapped function.
    """
    with new_main(ModuleTrace, True, builder=builder) as main:
        call_args = (main, *ins)
        results = yield call_args, {}
        # Drop references to the main trace before leaving its context.
        del call_args, main
    yield results
def trace_modules(wf: lu.WrappedFun, builder: 'MoxBuilder') -> lu.WrappedFun:
    """Compose the `_lower` and `_raise` transformations around `wf`."""
    lowered = _lower(wf, builder)
    return _raise(lowered, builder)
SymbolValueType = TypeVar('SymbolValueType', bound='ShapedArray')
@dataclass(slots=True)
class Symbol(Generic[SymbolValueType]):
"""A value placeholder in evaluation time.
Derived dataclasses should not generate `__eq__` to preserve `__hash__`.
"""
value: SymbolValueType
def __eq__(self, other) -> bool:
return self is other
def __hash__(self) -> int:
return id(self)
@dataclass(slots=True, eq=False)
class Var(Symbol[ShapedArray]):
    """Symbol whose value is a `ShapedArray`."""
@dataclass(slots=True, eq=False)
class Literal(Symbol[ConcreteArray]):
    """Symbol whose value is a `ConcreteArray`."""

    @property
    def const(self):
        """Concrete value stored in the wrapped `ConcreteArray`."""
        concrete = self.value
        return concrete.val
# Check hashability of Symbol hierarchy. Building a set hashes one instance of
# every class at import time, so a class that lost its `__hash__` fails here
# with `TypeError`. The resulting set itself is not used.
_ = {
    Symbol(ShapedArray((), jnp.float32)),
    Var(ShapedArray((), jnp.float32)),
    Literal(ConcreteArray(jnp.float32, jnp.empty(()))),
}
@dataclass(slots=True)
class Expr:
    """A base type that represents a module tree structure.

    Both kinds of tree nodes (`Equation` and `Mox`) derive from it and share
    the fields below.
    """

    # Symbols read by the node.
    inputs: list[Symbol]
    # Symbols written by the node.
    outputs: list[Symbol]
    # Parameters of the node keyed by name.
    params: dict[str, Any]
@dataclass(slots=True)
class Equation(Expr):
    """Leaf node of a module tree: an application of a single JAX
    :class:`jax.extend.core.Primitive`.
    """

    prim: jex.core.Primitive

    def to_dict(self, recursively=True) -> dict[str, Any]:
        """Return equation parameters plus the primitive name under key
        `primitive`. Argument `recursively` is accepted but not used.
        """
        result = dict(self.params)
        result['primitive'] = self.prim.name
        return result
def default_treedef(default=()) -> PyTreeDef:
    """Return the tree definition of `default` (of an empty tuple if it is
    omitted).
    """
    return jax.tree.flatten(default)[1]
@dataclass(slots=True)
class Mox(Expr):
children: list[Expr] = field(default_factory=list)
module_ty: Type[nn.Module] | None = None
entrypoint: str | None = None
in_tree: PyTreeDef = field(default_factory=default_treedef)
out_tree: PyTreeDef = field(default_factory=default_treedef)
var_tree: PyTreeDef = field(default_factory=default_treedef)
rngs: RNGSequences = field(default_factory=dict)
@property
def is_ephemeral(self) -> bool:
"""Ephemeral module expression does not reflect any real
:class:`flax.linen.Module`.
"""
return self.module_ty is None or self.entrypoint is None
def __repr__(self) -> str:
buf = StringIO()
dump(self, buf)
return buf.getvalue()
def to_dict(self, recursively=True) -> dict[str, Any]:
res = {
**self.params,
'primitive': 'module_call',
'entrypoint': self.entrypoint,
'ephemeral': self.is_ephemeral,
}
if self.module_ty:
res['type'] = self.module_ty.__name__
if recursively:
children = []
for child in self.children:
if isinstance(child, Mox):
children.append(child.to_dict())
elif isinstance(child, Equation):
children.append({
**child.params, 'primitive':
child.prim.name
})
res['children'] = children
return res
def to_json(self, indent: int | None = None) -> str:
def default(obj):
if isinstance(obj, Jaxpr):
return obj.pretty_print(use_color=False)
return getattr(obj, '__name__', str(obj))
return dumps(self.to_dict(), ensure_ascii=False, indent=indent,
default=default)
def make_mox(fn: Callable[Args, Any]) -> Callable[Args, Mox]:
    """Make a tracing routine for `fn` to obtain its Module eXpression (MoX).

    The returned callable takes the same arguments as `fn` but returns the
    module expression built while tracing `fn` on them.

    >>> m = nn.Dense(10)
    >>> batch = jnp.empty((1, 10))
    >>> params = jax.jit(m.init)(jax.random.PRNGKey(42), batch)
    >>> mox = make_mox(m.apply)(params, batch)
    """

    @wraps(fn)
    def wrapper(*args: Args.args, **kwargs: Args.kwargs) -> Mox:
        flat_args, in_tree = jax.tree.flatten((args, kwargs))
        flat_fn, out_tree_thunk = jax.api_util.flatten_fun(
            lu.wrap_init(fn), in_tree)

        # Root module expression is incomplete in traversing due to missing
        # outputs. May be, it would be better to construct builder from both
        # `inputs` and `in_tree` at once.
        builder = MoxBuilder()
        builder.set_input_tree(in_tree)

        traced_fn = trace_modules(flat_fn, builder)
        with intercept_methods(builder.intercept):
            traced_fn.call_wrapped(*flat_args)
        builder.set_output_tree(out_tree_thunk())
        return builder.build()

    return wrapper
class MoxBuilder:
    """Incremental builder of a module expression tree.

    Method `intercept` opens a child `Mox` for every intercepted module
    method call while `append` stores equations in the innermost open block.
    """

    def __init__(self):
        self.root = Mox([], [], {})
        # Stack of module expressions under construction; the last one is the
        # innermost.
        self.block_stack: list[Mox] = [self.root]
        # Tracers seen so far and symbols assigned to them.
        self.symbols: dict[ModuleTracer, Symbol] = {}
        self.module_stack: list[InterceptorContext] = []  # get_module_path

    def build(self) -> Mox:
        """Return the root module expression."""
        return self.root

    def get_module_path(self) -> str:
        """Join type names of modules on the module stack with slashes."""
        names = [type(ctx.module).__qualname__ for ctx in self.module_stack]
        return ''.join(f'/{name}' for name in names)

    def intercept(self, fn: Callable[..., Any], args, kwargs,
                  context: InterceptorContext) -> Any:
        """A hook to intercept call of `flax.linen.Module` method."""
        # TODO(@daskol): Should we run nested tracer?
        # TODO(@daskol): How to flatten in abstract way?
        # TODO(@daskol): Do not ignore `setup` method.
        if context.method_name == 'setup':
            return fn(*args, **kwargs)

        module = context.module

        # It is important to access `__dict__` directly since `__getattr__` is
        # overriden.
        params = {}
        for attr in fields(module):
            if attr.name != 'parent':
                params[attr.name] = module.__dict__.get(attr.name)

        module_rngs = {}
        for name, rng in module.scope.rngs.items():
            [rng_sym] = self.to_symbols([rng.rng])
            module_rngs[name] = LazyRng(rng_sym, rng.suffix)

        # Push on stack partially built module expression object. It will be
        # finalized by the end of this routine.
        child = Mox(inputs=[], outputs=[], params=params, children=[],
                    module_ty=type(module), entrypoint=context.method_name,
                    rngs=module_rngs)
        self.block_stack[-1].children.append(child)
        self.block_stack.append(child)

        # Flax passes a module function.
        unbound_fn: Callable[..., Any]
        if isinstance(fn, partial):
            unbound_fn = fn.func
        else:
            unbound_fn = getattr(module, context.method_name).__func__

        # Assume that weights are already wrapped up.
        flat_vars, child.var_tree = jax.tree.flatten(module.variables)
        child.inputs.extend(self.to_symbols(flat_vars))

        # Flax assumes that the first argument is a weight dictionary (or
        # Module or Scope). Thus, we need to flatten this dictionary for
        # binding and unflatten it in evaluation time.
        args = (module, ) + args
        (scope, *in_args), child.in_tree = jax.tree.flatten((args, kwargs))
        trace = find_top_trace(flat_vars + in_args)  # TODO(@daskol): Dynamic!
        in_tracers = [trace.full_raise(arg) for arg in in_args]
        child.inputs.extend(self.to_symbols(in_tracers))  # XXX

        # Flatten function (inputs and outputs) for building intermediate
        # representation.
        flat_fn, out_tree_fn = jax.api_util.flatten_fun(
            lu.wrap_init(unbound_fn), child.in_tree)
        outs = flat_fn.call_wrapped(scope, *in_tracers)

        # TODO(@daskol): Remove code duplication with `_lower`.
        flat_outs, _ = jax.tree.flatten(outs)
        out_tracers = [trace.full_raise(out) for out in flat_outs]

        # TODO(@daskol): Should we introduce `module_call` primitive here? Or
        # just flatten outputs?
        child.out_tree = out_tree_fn()
        child.outputs.extend(self.to_symbols(out_tracers))
        self.block_stack.pop()
        return jax.tree.unflatten(child.out_tree, outs)

    def append(self, prim: jex.core.Primitive, params: dict[str, Any],
               in_tracers: list[ModuleTracer],
               out_tracers: list[ModuleTracer]):
        """Add an equation for `prim` to the innermost open block."""
        equation = Equation(self.to_symbols(in_tracers),
                            self.to_symbols(out_tracers), params, prim)
        self.block_stack[-1].children.append(equation)

    def set_input_tree(self, tree: PyTreeDef):
        self.root.in_tree = tree

    def set_inputs(self, tracers: Sequence[ModuleTracer]):
        """Replace inputs of the root with symbols of `tracers`."""
        self.root.inputs[:] = self.to_symbols(tracers)

    def set_output_tree(self, tree: PyTreeDef):
        self.root.out_tree = tree

    def set_outputs(self, tracers: Sequence[ModuleTracer] | ModuleTracer):
        """Replace outputs of the root with symbols of `tracers`; a lone
        tracer is treated as a sequence of one element.
        """
        if not isinstance(tracers, (list, tuple, Sequence)):
            tracers = [tracers]
        self.root.outputs[:] = self.to_symbols(tracers)

    def to_symbols(self, tracers: Sequence[ModuleTracer]) -> Sequence[Symbol]:
        """Map tracers to symbols and create symbols for unseen tracers.

        A new symbol is a `Literal` for a `ConcreteArray` abstract value and
        a `Var` for any other `ShapedArray`.

        Raises:
            RuntimeError: on an abstract value of any other type.
        """
        symbols = []
        for tracer in tracers:
            symbol = self.symbols.get(tracer)
            if symbol is None:
                aval = tracer.aval
                # The `ConcreteArray` test goes first, as in the original
                # matching order.
                if isinstance(aval, ConcreteArray):
                    symbol = Literal(aval)
                elif isinstance(aval, ShapedArray):
                    symbol = Var(aval)
                else:
                    raise RuntimeError('Unexpected abstract value of type '
                                       f'{type(aval)}.')
                self.symbols[tracer] = symbol
            symbols.append(symbol)
        return symbols
def fully_qualified_name(ty: Type[nn.Module] | None) -> str:
    """Return a dotted `module.qualname` name of type `ty`.

    The module prefix is omitted for builtins, and `'<none>'` is returned when
    `ty` is `None`.
    """
    if ty is None:
        return '<none>'
    module, qualname = ty.__module__, ty.__qualname__
    return qualname if module == 'builtins' else f'{module}.{qualname}'
def dump(node: Expr, fileobj: IO[str], *, depth=0):
indent = ' ' * depth
match node:
case Mox():
if node.is_ephemeral:
name_ty = 'Ephemeral'
name = '<none>'
else:
name_ty = fully_qualified_name(node.module_ty)
name = node.params['name']
print(f'{indent}inputs ={node.inputs}', file=fileobj)
print(f'{indent}outputs={node.outputs}', file=fileobj)
keys = (k for k in node.params.keys()
if k not in ('name', 'parent'))
try:
key = next(keys)
val = node.params[key]
attrs = f'{key}={val}'
except StopIteration:
attrs = ''
print(f'{indent}mod {name} : {name_ty}({attrs}) {{ # {depth}',
file=fileobj)
case Expr():
raise RuntimeError('Unexpected node of type {type(node)}.')
for child in node.children:
match child:
case Mox():
dump(child, fileobj, depth=depth + 1)
case Equation():
print(f'{indent} eq {child.prim.name}', file=fileobj)
print(f'{indent} inputs ={child.inputs}', file=fileobj)
print(f'{indent} outputs={child.outputs}', file=fileobj)
print(f'{indent} {child}', file=fileobj)
print(f'{indent}}}', file=fileobj)
def dump_yson(expr: Expr, fileobj: IO[bytes], indent: int = 2):
    """Serialize module expression `expr` as a YSON-formatted object to
    `fileobj`.

    Args:
        expr: expression (`Mox` or `Equation`) to serialize.
        fileobj: binary stream to write to.
        indent: indentation width; pretty-printing is requested only for
            positive values.

    Raises:
        RuntimeError: if YSON packages are missing or if an expression of
            unexpected type is met.
    """
    try:
        from yt.yson import YsonEntity, YsonList, YsonType, dump
    except ImportError as e:
        msg = (
            'Missing YSON packages. Try to install `ytsaurus-client` package '
            'for basic YSON support and `ytsaurus-yson` for fast serialization'
            '/deserialization.')
        raise RuntimeError(msg) from e

    def fmt(val):
        """Convert a parameter value to a YSON-serializable one."""
        if isinstance(val, bool | int | float):
            return val
        elif isinstance(val, Jaxpr):
            return val.pretty_print(use_color=False)
        elif callable(val):
            return f'{val.__module__}.{val.__name__}'
        elif hasattr(val, 'dtype'):
            return str(val.dtype)
        else:
            return str(val)

    def to_yson(expr: Expr) -> YsonType:
        """Convert an expression to a YSON node with attributes."""
        if isinstance(expr, Mox):
            root = YsonList([])
            root.attributes['primitive'] = 'module_call'
            root.attributes['ephemeral'] = fmt(expr.is_ephemeral)
            if expr.module_ty:
                root.attributes['type'] = expr.module_ty.__name__
            root.extend(to_yson(child) for child in expr.children)
        elif isinstance(expr, Equation):
            root = YsonEntity()
            root.attributes['primitive'] = expr.prim.name
        else:
            # The message is an f-string so that the actual type is reported.
            raise RuntimeError(f'Unexpected expression type {type(expr)}.')
        # Parameters without a value are not serialized.
        root.attributes.update({
            k: fmt(v)
            for k, v in expr.params.items() if v is not None
        })
        return root

    root = to_yson(expr)
    kwargs = {}
    if indent > 0:
        kwargs = dict(yson_format='pretty', indent=indent)
    dump(root, fileobj, **kwargs)
def dump_xml(expr: Expr, fileobj: IO[str], indent: int = 2):
    """Serialize module expression `expr` to XML representation.

    Args:
        expr: expression (`Mox` or `Equation`) to serialize.
        fileobj: text stream to write to.
        indent: indentation width; the tree is not indented if it is zero.

    Raises:
        RuntimeError: if an expression of unexpected type is met.
    """

    def fmt(val):
        """Convert a parameter value to an attribute string."""
        if isinstance(val, bool):
            return str(val).lower()
        elif isinstance(val, Jaxpr):
            return val.pretty_print(use_color=False)
        elif callable(val):
            return f'{val.__module__}.{val.__name__}'
        elif hasattr(val, 'dtype'):
            return str(val.dtype)
        else:
            return str(val)

    def build_subtree(parent: Element, node: Expr):
        """Append an element for `node` (and its subtree) to `parent`."""
        attrs = {}
        if isinstance(node, Mox):
            tag = 'module_call'
            # A nested ephemeral module expression has no module type.
            if node.module_ty:
                attrs['type'] = node.module_ty.__name__
            attrs['ephemeral'] = fmt(node.is_ephemeral)
        elif isinstance(node, Equation):
            tag = node.prim.name
        else:
            # Without this branch `tag` would be unbound below.
            raise RuntimeError(f'Unexpected expression type {type(node)}.')
        # Parameters without a value are not serialized.
        attrs.update({
            k: fmt(v)
            for k, v in node.params.items() if v is not None
        })
        elem = SubElement(parent, tag, attrs)
        if isinstance(node, Mox):
            for child in node.children:
                build_subtree(elem, child)

    # NOTE(review): The root element never carries parameters of `expr`, and
    # a root module expression is always written as ephemeral. That looks
    # like an assumption on the root made by `make_mox` -- confirm before
    # serializing subtrees.
    if isinstance(expr, Mox):
        root = Element('module_call', attrib={'ephemeral': 'true'})
        for child in expr.children:
            build_subtree(root, child)
    elif isinstance(expr, Equation):
        root = Element(expr.prim.name, attrib={})
    else:
        raise RuntimeError(f'Unexpected expression type {type(expr)}.')

    et = ElementTree(root)
    if indent:
        indent_etree(et, space=' ' * indent)
    et.write(fileobj, encoding='unicode', xml_declaration=True)
@dataclass(slots=True, frozen=True)
class Token:
kind: str
value: str = ''
@property
def is_empty(self) -> bool:
return self.kind == '' and self.value == ''
@classmethod
def empty(cls) -> Self:
return cls('', '')
def __repr__(self) -> str:
if self.is_empty:
return 'ε'
return f'<{self.kind}>{self.value}'
def __str__(self) -> str:
return self.value or self.kind
_NODE_TYPE: ClassVar = re.compile(
r'(comment|text|processing-instruction|node)')
_NAME_CHAR: ClassVar = re.compile(r'[A-z_][0-9A-z_]*')
_DOUBLE_QUOTED_STR: ClassVar = re.compile(r'"[^"]*"')
_SINGLE_QUOTED_STR: ClassVar = re.compile(r"'[^']*'")
_FULL_NUM: ClassVar = re.compile(r'\d+(\.(\d+)?)?')
_FRAC_NUM: ClassVar = re.compile(r'\.(\d+)?')
_AXIS_NAME: ClassVar = re.compile(
r'(ancestor|ancestor-or-self|attribute|child|descendant'
r'|descendant-or-self|following|following-sibling|namespace|parent'
r'|preceding|preceding-sibling|self)')
def tokenize_xpath(val: str):
    """Split XPath expression `val` into a stream of tokens.

    Spaces between tokens are skipped. Every yielded token is also passed as
    the preceding token to the tokenization of the next one.
    """
    token = Token.empty()
    # Stripping must consume spaces (the original `val[:1]` kept only the
    # first character and never made progress). The loop ends as soon as
    # nothing but spaces is left.
    while val := val.lstrip(' '):
        val, token = tokenize_expr_token(val, token)
        yield token
def tokenize_expr_token(val: str, last: Token):
    """Read a single token from the beginning of `val`.

    Args:
        val: a remaining part of an expression.
        last: the preceding token.

    Returns:
        A pair of the rest of `val` and the token read.

    Raises:
        NotImplementedError: if no tokenization rule applies.
    """
    two, one = val[:2], val[:1]
    if two in ('..', '::'):
        return val[2:], Token(two)
    if one in '()[]@,.':
        return val[1:], Token(one)

    parsers = (tokenize_name_test, tokenize_node_type, tokenize_node_operator,
               tokenize_function_name, tokenize_axis_name, tokenize_literal,
               tokenize_number)
    for parser in parsers:
        try:
            rest, token = parser(val, last)
        except TypeError:
            continue  # Parser failed to parse: None is parse result.
        return rest, token
    raise NotImplementedError(
        f'Perhaps, this production rule is not implemented at {val[:16]}.')
def tokenize_name_test(val: str, last: Token):
    """Try to read a name test (`*`, `NCName:*` or a qualified name).

    Returns:
        A pair of the rest of `val` and a token, or `None` if `val` does not
        start with a name test.
    """
    # If there is a preceding token and the preceding token is not one of @,
    # ::, (, [, , or an Operator, then a * must be recognized as a
    # MultiplyOperator and an NCName must be recognized as an OperatorName.
    if val[:1] == '*':
        name_context = ('@', '::', '(', '[', ',', 'Operator')
        if not last.is_empty and last.kind not in name_context:
            return val[1:], Token('Operator', val[:1])
        else:
            return val[1:], Token('NameTest', val[:1])

    match = Token._NAME_CHAR.match(val)
    if match is None:
        return None  # Not a name at all.
    prefix = match.group(0)
    tail = val[len(prefix):]

    # NCName ':' '*'
    if tail[:2] == ':*':
        return tail[2:], Token('NameTest', f'{prefix}:*')

    # QName := NCName : NCName
    if tail[:1] != ':':
        return tail, Token('NameTest', prefix)
    if (match := Token._NAME_CHAR.match(tail[1:])) is not None:
        suffix = match.group(0)
        offset = len(prefix) + 1 + len(suffix)
        return val[offset:], Token('NameTest', f'{prefix}:{suffix}')

    # A name followed by a colon without a local part (e.g. an axis name
    # before `::`) is left to other tokenization rules.
    return None
def tokenize_node_type(val: str, last: Token):
    """Try to read a node type; return `None` if there is none at the
    beginning of `val`.
    """
    match = Token._NODE_TYPE.match(val)
    if match is None:
        return None
    node_type = match.group(0)
    return val[len(node_type):], Token('NodeType', node_type)
def tokenize_node_operator(val: str, last: Token):
    """Try to read an operator; return `None` if there is none at the
    beginning of `val`.
    """
    # Longer operators are tried first. Operator `or` is two characters long
    # and could not be matched among three-character ones.
    for length, operators in ((3, ('and', 'mod', 'div')),
                              (2, ('or', '//', '!=', '<=', '>=')),
                              (1, ('*', '/', '|', '+', '-', '='))):
        if (head := val[:length]) in operators:
            return val[length:], Token('Operator', head)
    return None
def tokenize_function_name(val: str, last: Token):
    """Placeholder rule for function names: it never matches."""
    return None
def tokenize_axis_name(val: str, last: Token):
    """Try to read an axis name; return `None` if there is none at the
    beginning of `val`.
    """
    match = Token._AXIS_NAME.match(val)
    if match is None:
        return None
    axis = match.group(0)
    return val[len(axis):], Token('AxisName', axis)
def tokenize_literal(val: str, last: Token):
    """Try to read a quoted string literal (quotes are kept in the token
    value); return `None` if there is none at the beginning of `val`.
    """
    match = (Token._DOUBLE_QUOTED_STR.match(val)
             or Token._SINGLE_QUOTED_STR.match(val))
    if not match:
        return None
    literal = match.group(0)
    return val[len(literal):], Token('Literal', literal)
def tokenize_number(val: str, last: Token):
    """Try to read a number; return `None` if there is none at the beginning
    of `val`.
    """
    match = Token._FULL_NUM.match(val) or Token._FRAC_NUM.match(val)
    if not match:
        return None
    number = match.group(0)
    return val[len(number):], Token('Number', number)
class XPath:
    """Parsed expression of XML Path language: a tuple of location steps.

    [1]: https://www.w3.org/TR/xpath-10/
    [2]: https://www.w3.org/TR/xpath-20/
    [3]: https://www.w3.org/TR/xpath-30/
    [4]: https://www.w3.org/TR/xpath-31/
    """

    def __init__(self, xpath: str | Self):
        """Parse string `xpath` or copy location steps of another `XPath`.

        Raises:
            RuntimeError: if tokenization or parsing of a string fails.
        """
        if isinstance(xpath, XPath):
            self.locs = tuple(xpath.locs)
            return

        try:
            tokens = tuple(tokenize_xpath(xpath))
        except Exception as e:
            raise RuntimeError(f'Failed to tokenize XPath: {xpath}.') from e

        try:
            steps = tuple(parse_xpath(tokens))
        except Exception as e:
            raise RuntimeError(f'Failed to parse XPath: {xpath}.') from e
        self.locs: tuple[LocationStep, ...] = steps

    def __repr__(self) -> str:
        return ' '.join(map(repr, self.locs))

    def __str__(self) -> str:
        return '/'.join(map(str, self.locs))
# A predicate function takes a dictionary of attributes and tells whether they
# satisfy the predicate.
PredicateFn: TypeAlias = Callable[[dict[str, Any]], bool]


@dataclass(slots=True, frozen=True)
class LocationPredicate:
    """A predicate of a location step: a callable paired with a textual
    description which is used as its string form.
    """

    # Function which is applied to attributes on call.
    func: PredicateFn
    # Text returned by `__str__`.
    desc: str

    def __call__(self, attrs: dict[str, Any]) -> bool:
        return self.func(attrs)

    def __str__(self) -> str:
        return self.desc
@dataclass(slots=True, frozen=True)
class LocationStep:
axis: str
node: str
predicate: tuple[LocationPredicate, ...] = ()
def __str__(self) -> str:
pred = ''.join(str(p) for p in self.predicate)
return f'{self.axis}::{self.node}{pred}'
def parse_xpath(tokens: list[Token]):
    """Parse a sequence of tokens into a stream of location steps.

    An empty sequence and a lone `/` both produce a single `self::node()`
    step. Operator `/` only separates steps while `//` produces a
    `descendant-or-self::node()` step.

    Raises:
        RuntimeError: on an unexpected operator or on a token which cannot
            start a location step.
    """
    if len(tokens) == 0:
        yield LocationStep('self', 'node()')
        return
    if len(tokens) == 1:
        if tokens[0].kind == 'Operator' and tokens[0].value == '/':
            yield LocationStep('self', 'node()')
            return

    while tokens:
        head = tokens[0]
        if head.kind == 'Operator':
            if head.value == '/':
                tokens = tokens[1:]
            elif head.value == '//':
                yield LocationStep('descendant-or-self', 'node()')
                tokens = tokens[1:]
            else:
                raise RuntimeError(f'Unexpected operator: {head!r}.')
            continue

        loc, rest = parse_location_step(tokens)
        # A step which consumes nothing would be parsed over and over again,
        # so the lack of progress is reported as an error.
        if len(rest) >= len(tokens):
            raise RuntimeError(f'Unexpected token: {head!r}.')
        tokens = rest
        yield loc
def parse_location_step(tokens: list[Token]):
    """Parse a single location step from the beginning of `tokens`.

    Only predicates of form `[@name=value]` are supported where `value` is
    either a string literal or a number.

    Args:
        tokens: non-empty sequence of tokens.

    Returns:
        A pair of the location step and remaining tokens.

    Raises:
        RuntimeError: on a malformed or unsupported predicate.
    """
    # Abbreviated step first.
    if tokens[0].kind == '.':
        return LocationStep('self', 'node()'), tokens[1:]
    elif tokens[0].kind == '..':
        return LocationStep('parent', 'node()'), tokens[1:]

    if tokens[0].kind == '@':
        axis = 'attribute'
        tokens = tokens[1:]
    elif (len(tokens) > 1 and
          tokens[0].kind == 'AxisName' and tokens[1].kind == '::'):
        # The name is carried by the axis token; the value of the `::` token
        # is empty.
        axis = tokens[0].value
        tokens = tokens[2:]
    else:
        axis = 'child'  # Default axis specifier.

    if len(tokens) > 0 and tokens[0].kind == 'NameTest':
        node = tokens[0].value
        tokens = tokens[1:]
    elif (len(tokens) > 2 and tokens[0].kind == 'NodeType' and
          tokens[1].kind == '(' and tokens[2].kind == ')'):
        node = f'{tokens[0].value}()'
        tokens = tokens[3:]
    else:
        node = 'node()'

    preds = ()
    # A predicate takes exactly six tokens: [ @ name operator value ].
    while len(tokens) > 5 and tokens[0].kind == '[':
        kinds = tuple(t.kind for t in tokens[:4])
        if kinds != ('[', '@', 'NameTest', 'Operator'):
            raise RuntimeError('Failed to parse predicate of a step.')
        if tokens[3].value != '=':
            # Only equality is implemented; do not silently treat other
            # operators as equality.
            raise RuntimeError(
                f'Unsupported operator in predicate: {tokens[3]!r}.')
        if tokens[4].kind not in ('Literal', 'Number'):
            raise RuntimeError('Failed to parse predicate of a step.')
        if tokens[5].kind != ']':
            raise RuntimeError('Failed to parse predicate of a step.')

        key = tokens[2].value
        try:
            val = int(tokens[4].value)
        except ValueError:
            try:
                val = float(tokens[4].value)
            except ValueError:
                val = tokens[4].value[1:-1]  # Drop quotes of a literal.

        # Bind `key` and `val` as defaults: a closure would look them up at
        # call time and every predicate would see values of the last one.
        pred = LocationPredicate(
            lambda attrs, key=key, val=val: attrs[key] == val,
            ''.join(str(t) for t in tokens[:6]))
        preds += (pred, )
        tokens = tokens[6:]
    return LocationStep(axis, node, preds), tokens
def eval_module(read, write, mox: Mox):
    """Evaluate a concrete module expression `mox` by calling its module.

    Args:
        read: callable which returns a value of a symbol.
        write: callable which stores a value of a symbol.
        mox: a non-ephemeral module expression.

    Raises:
        NotImplementedError: if `mox` is ephemeral.
    """
    if mox.is_ephemeral:
        raise NotImplementedError('Only concrete modules allowed.')

    def read_safe(var: Symbol | flax.core.scope.Scope) -> Any:
        # Anything but a symbol (e.g. a scope) is read as `None`.
        return read(var) if isinstance(var, Symbol) else None

    def write_safe(var: Symbol, val: Any):
        # Outputs which are inputs at the same time are not overwritten.
        if var in mox.inputs:
            return
        write(var, val)

    # Weight and input symbol are all flattened and stored in single list. In
    # order to apply module func to weights and inputs we need to separate
    # them, restore weights, and restore inputs.
    num_vars = mox.var_tree.num_leaves
    weight_syms, arg_syms = mox.inputs[:num_vars], mox.inputs[num_vars:]
    var_vals = jax.tree.map(read, jax.tree.unflatten(mox.var_tree,
                                                     weight_syms))

    # Materialize RNGs states.
    rngs = {
        name: LazyRng(rng=read_safe(rng.rng), suffix=rng.suffix)
        for name, rng in mox.rngs.items()
    }
    rng_counters = dict.fromkeys(mox.rngs, 0)

    with flax.core.bind(var_vals).temporary() as scope:
        # Ad hoc patching of RNG states in scope.
        scope.rngs = rngs
        scope.rng_counters = rng_counters
        scoped_mod = mox.module_ty(**mox.params).clone(parent=scope)

        # See func:`bind` in flax/core/scope.py:1105.
        in_vals = jax.tree.map(read_safe, [scope] + arg_syms)
        (_, *in_args), in_kwargs = jax.tree.unflatten(mox.in_tree, in_vals)
        unbound_fn = getattr(mox.module_ty, mox.entrypoint)
        out_vals = unbound_fn(scoped_mod, *in_args, **in_kwargs)

    # Output symbols are stored as flattened. So, we need to flatten outputs
    # and write result back.
    outputs, out_tree = jax.tree.flatten(out_vals)
    assert out_tree == mox.out_tree, \
        f'Output tree mismatched in eval time: {mox.out_tree} -> {out_tree}.'
    jax.tree.map(write_safe, mox.outputs, outputs)
def eval_equation(read, write, eq: Equation):
    """Evaluate equation `eq`: read its inputs, bind the primitive, and
    write results to its outputs.
    """
    operands = list(map(read, eq.inputs))  # TODO(@daskol): Trees?
    subfuns, bind_params = eq.prim.get_bind_params(eq.params)
    results = eq.prim.bind(*subfuns, *operands, **bind_params)
    # A primitive with a single result returns it bare.
    outs = results if eq.prim.multiple_results else [results]
    for sym, val in zip(eq.outputs, outs):
        write(sym, val)
def eval_mox(tree: Mox, *args, **kwargs):
"""Evaluate a module expression `tree` with `args` and `kwargs`."""
env: dict[Symbol, Any] = {}
def read(var: Symbol) -> Any:
"""Read a symbol from execution context or take literal value."""
if (val := env.get(var)) is not None:
return val
if isinstance(var, Literal):
return var.const
raise KeyError(f'Variable {var} is undefined.')
def write(var: Symbol, val: Any):
assert var not in env, f'Variable {var} has been already defined.'
env[var] = val
def fn(node: Expr) -> Mox | None:
if isinstance(node, Mox):
if node.is_ephemeral:
return node
else:
return eval_module(read, write, node)
elif isinstance(node, Equation):
return eval_equation(read, write, node)
# Initialize execution context and execute.
flat_args, in_tree = jax.tree.flatten((args, kwargs))
assert in_tree == tree.in_tree, \
f'Arguments and input tree mismatched: {in_tree} vs {tree.in_tree}.'
jax.tree.map(write, tree.inputs, flat_args)
map_mox(fn, tree)
flatten_res = [read(x) for x in tree.outputs]
if len(flatten_res) == 1:
return flatten_res[0]
return flatten_res
def map_mox(fn: Callable[[Expr], Any], tree: Mox):
    """Apply `fn` to nodes of a module `tree` in depth-first order.

    Children of a node are visited (from first to last) only if `fn` returns
    a `Mox`; children of the returned module expression are taken then.
    """
    stack: list[Expr] = [tree]
    while stack:
        result = fn(stack.pop())
        if isinstance(result, Mox):
            # Children are pushed in reverse so that the first one is popped
            # first.
            stack.extend(reversed(result.children))
# A path to a node is a tuple of strings.
# NOTE(review): Presumably these are names of enclosing modules -- confirm
# against the code which builds paths.
ModulePath: TypeAlias = tuple[str, ...]
# SubFn takes a path `path` to a node `node` which has been selected with
# search query and returns a new node which substitutes original node `node`.
SubFn: TypeAlias = Callable[[ModulePath, Expr], Expr]
def sub(expr: str | XPath | Sequence[Equation | Mox],
repl: Mox | Equation | SubFn, mox: Mox) -> Mox:
"""Substitute a module expression `mox` with `repl` according to matching
pattern `expr`.
Args:
expr: XPath expression for nodes to substitute.
repl: replacement for selected nodes.
mox: a module expression to mutate.
Return: