33
33
from dataclasses import dataclass
34
34
35
35
from discopy .cat import Composable , assert_iscomposable
36
- from discopy .monoidal import Whiskerable
36
+ from discopy .monoidal import PRO , Whiskerable
37
37
38
38
Ty = tuple [type , ...]
39
39
@@ -82,6 +82,58 @@ def is_tuple(typ: type) -> bool:
82
82
"""
83
83
return getattr (typ , "__origin__" , typ ) is tuple
84
84
85
+ class product :
86
+ def __init__ (self , * functions ):
87
+ if not functions :
88
+ raise TypeError (repr (type (self ).__name__ ) +
89
+ ' needs at least one argument' )
90
+
91
+ _functions = []
92
+ doms = []
93
+ for function , dom in functions :
94
+ if not callable (function ):
95
+ raise TypeError (repr (type (self ).__name__ ) +
96
+ ' arguments must be callable' )
97
+ if isinstance (function , product ):
98
+ _functions = _functions + function .__wrapped__
99
+ doms = doms + function ._doms
100
+ else :
101
+ _functions .append (function )
102
+ doms .append (dom )
103
+ self .__wrapped__ = _functions
104
+ self ._doms = doms
105
+
106
+ def __call__ (self , * args ):
107
+ i = 0
108
+ result = ()
109
+ for func , dom in zip (self .__wrapped__ , self ._doms ):
110
+ val = tuplify (func (* args [i :i + len (dom )]))
111
+ result = result + val
112
+ i += len (dom )
113
+ return untuplify (result )
114
+
115
+ class compose :
116
+ def __init__ (self , * functions ):
117
+ if not functions :
118
+ raise TypeError (repr (type (self ).__name__ ) +
119
+ ' needs at least one argument' )
120
+
121
+ _functions = []
122
+ for function in reversed (functions ):
123
+ if not callable (function ):
124
+ raise TypeError (repr (type (self ).__name__ ) +
125
+ ' arguments must be callable' )
126
+
127
+ if isinstance (function , compose ):
128
+ _functions = _functions + function .__wrapped__
129
+ else :
130
+ _functions .append (function )
131
+ self .__wrapped__ = _functions
132
+
133
+ def __call__ (self , * values ):
134
+ for func in self .__wrapped__ :
135
+ values = func (* tuplify (values ))
136
+ return values
85
137
86
138
@dataclass
87
139
class Function (Composable [Ty ], Whiskerable ):
@@ -119,7 +171,7 @@ def id(dom: Ty) -> Function:
119
171
The identity function on a given tuple of types :code:`dom`.
120
172
121
173
Parameters:
122
- dom (python.Ty) : The typle of types on which to take the identity.
174
+ dom (python.Ty) : The tuple of types on which to take the identity.
123
175
"""
124
176
return Function (lambda * xs : untuplify (xs ), dom , dom )
125
177
@@ -131,8 +183,13 @@ def then(self, other: Function) -> Function:
131
183
other : The other function to compose in sequence.
132
184
"""
133
185
assert_iscomposable (self , other )
134
- return Function (
135
- lambda * args : other (* tuplify (self (* args ))), self .dom , other .cod )
186
+
187
+ if self .inside == untuplify :
188
+ return other
189
+ if other .inside == untuplify :
190
+ return self
191
+ function = compose (other .inside , self .inside )
192
+ return Function (function , self .dom , other .cod )
136
193
137
194
def __call__ (self , * xs ):
138
195
return self .inside (* xs )
@@ -144,10 +201,12 @@ def tensor(self, other: Function) -> Function:
144
201
Parameters:
145
202
other : The other function to compose in sequence.
146
203
"""
147
- def inside (* xs ):
148
- left , right = xs [:len (self .dom )], xs [len (self .dom ):]
149
- return untuplify (tuplify (self (* left )) + tuplify (other (* right )))
150
- return Function (inside , self .dom + other .dom , self .cod + other .cod )
204
+ if self .dom == PRO (0 ) and self .inside == untuplify :
205
+ return other
206
+ if other .dom == PRO (0 ) and other .inside == untuplify :
207
+ return self
208
+ prod = product ((self .inside , self .dom ), (other .inside , other .dom ))
209
+ return Function (prod , self .dom + other .dom , self .cod + other .cod )
151
210
152
211
@staticmethod
153
212
def swap (x : Ty , y : Ty ) -> Function :
0 commit comments