Skip to content

Commit 7a2b181

Browse files
authored
Merge pull request #3 from neu-pml/develop
Develop
2 parents 065e1eb + 74a8723 commit 7a2b181

File tree

5 files changed

+714
-11
lines changed

5 files changed

+714
-11
lines changed

discopy/cartesian.py

-2
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,6 @@
5555
from discopy.cat import factory
5656
from discopy.monoidal import Ty, assert_isatomic
5757

58-
5958
@factory
6059
class Diagram(symmetric.Diagram):
6160
"""
@@ -171,5 +170,4 @@ def __call__(self, other):
171170
return self.cod.ar.copy(self(other.dom), len(other.cod))
172171
return super().__call__(other)
173172

174-
175173
Id = Diagram.id

discopy/cat.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -567,7 +567,7 @@ def free_symbols(self) -> "set[sympy.Symbol]":
567567
def recursive_free_symbols(data):
568568
if isinstance(data, Mapping):
569569
data = data.values()
570-
if isinstance(data, Iterable):
570+
if not isinstance(data, str) and isinstance(data, Iterable):
571571
# Handles numpy 0-d arrays, which are actually not iterable.
572572
if not hasattr(data, "shape") or data.shape != ():
573573
return set().union(*map(recursive_free_symbols, data))

discopy/python.py

+67-8
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@
3333
from dataclasses import dataclass
3434

3535
from discopy.cat import Composable, assert_iscomposable
36-
from discopy.monoidal import Whiskerable
36+
from discopy.monoidal import PRO, Whiskerable
3737

3838
Ty = tuple[type, ...]
3939

@@ -82,6 +82,58 @@ def is_tuple(typ: type) -> bool:
8282
"""
8383
return getattr(typ, "__origin__", typ) is tuple
8484

85+
class product:
86+
def __init__(self, *functions):
87+
if not functions:
88+
raise TypeError(repr(type(self).__name__) +
89+
' needs at least one argument')
90+
91+
_functions = []
92+
doms = []
93+
for function, dom in functions:
94+
if not callable(function):
95+
raise TypeError(repr(type(self).__name__) +
96+
' arguments must be callable')
97+
if isinstance(function, product):
98+
_functions = _functions + function.__wrapped__
99+
doms = doms + function._doms
100+
else:
101+
_functions.append(function)
102+
doms.append(dom)
103+
self.__wrapped__ = _functions
104+
self._doms = doms
105+
106+
def __call__(self, *args):
107+
i = 0
108+
result = ()
109+
for func, dom in zip(self.__wrapped__, self._doms):
110+
val = tuplify(func(*args[i:i+len(dom)]))
111+
result = result + val
112+
i += len(dom)
113+
return untuplify(result)
114+
115+
class compose:
116+
def __init__(self, *functions):
117+
if not functions:
118+
raise TypeError(repr(type(self).__name__) +
119+
' needs at least one argument')
120+
121+
_functions = []
122+
for function in reversed(functions):
123+
if not callable(function):
124+
raise TypeError(repr(type(self).__name__) +
125+
' arguments must be callable')
126+
127+
if isinstance(function, compose):
128+
_functions = _functions + function.__wrapped__
129+
else:
130+
_functions.append(function)
131+
self.__wrapped__ = _functions
132+
133+
def __call__(self, *values):
134+
for func in self.__wrapped__:
135+
values = func(*tuplify(values))
136+
return values
85137

86138
@dataclass
87139
class Function(Composable[Ty], Whiskerable):
@@ -119,7 +171,7 @@ def id(dom: Ty) -> Function:
119171
The identity function on a given tuple of types :code:`dom`.
120172
121173
Parameters:
122-
dom (python.Ty) : The typle of types on which to take the identity.
174+
dom (python.Ty) : The tuple of types on which to take the identity.
123175
"""
124176
return Function(lambda *xs: untuplify(xs), dom, dom)
125177

@@ -131,8 +183,13 @@ def then(self, other: Function) -> Function:
131183
other : The other function to compose in sequence.
132184
"""
133185
assert_iscomposable(self, other)
134-
return Function(
135-
lambda *args: other(*tuplify(self(*args))), self.dom, other.cod)
186+
187+
if self.inside == untuplify:
188+
return other
189+
if other.inside == untuplify:
190+
return self
191+
function = compose(other.inside, self.inside)
192+
return Function(function, self.dom, other.cod)
136193

137194
def __call__(self, *xs):
138195
return self.inside(*xs)
@@ -144,10 +201,12 @@ def tensor(self, other: Function) -> Function:
144201
Parameters:
145202
other : The other function to compose in sequence.
146203
"""
147-
def inside(*xs):
148-
left, right = xs[:len(self.dom)], xs[len(self.dom):]
149-
return untuplify(tuplify(self(*left)) + tuplify(other(*right)))
150-
return Function(inside, self.dom + other.dom, self.cod + other.cod)
204+
if self.dom == PRO(0) and self.inside == untuplify:
205+
return other
206+
if other.dom == PRO(0) and other.inside == untuplify:
207+
return self
208+
prod = product((self.inside, self.dom), (other.inside, other.dom))
209+
return Function(prod, self.dom + other.dom, self.cod + other.cod)
151210

152211
@staticmethod
153212
def swap(x: Ty, y: Ty) -> Function:

0 commit comments

Comments
 (0)