Skip to content

Commit

Permalink
Replace unnecessary parameter assignments produced by directed wires …
Browse files Browse the repository at this point in the history
…with symbol identifications
  • Loading branch information
GJHSimmons committed Jul 25, 2024
1 parent c43ea1e commit f1ecf78
Show file tree
Hide file tree
Showing 2 changed files with 146 additions and 73 deletions.
183 changes: 114 additions & 69 deletions psymple/ported_objects.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,12 @@ class DefaultParameterAssignment(ParameterAssignment):
"""
pass

class FunctionalAssignment(ParameterAssignment):
"""
A convenience class to identify parameters which have been constructed from the OutputPort
of a FunctionalPortedObject. These represent the core functional building blocks of a
System."""


class SymbolIdentification:
"""
Expand Down Expand Up @@ -217,11 +223,18 @@ def __init__(self, port, assignment):
self.assignment = deepcopy(assignment)

def substitute_symbol(self, old_symbol, new_symbol):
self.assignment.substitute_symbol(old_symbol, new_symbol)
# In case the symbol of this port itself was substituted,
# this will be reflected in the assignment, and we can pull
# the updated name from there.
self.name = self.assignment.name
if self.assignment:
self.assignment.substitute_symbol(old_symbol, new_symbol)
# In case the symbol of this port itself was substituted,
# this will be reflected in the assignment, and we can pull
# the updated name from there.
self.name = self.assignment.name
else:
assert isinstance(old_symbol, sym.Symbol)
assert isinstance(new_symbol, sym.Symbol)
assert isinstance(self.symbol, sym.Symbol)
if self.symbol == old_symbol:
self.name = str(new_symbol)

def __repr__(self):
return f"{type(self).__name__} {self.name} with {self.assignment}"
Expand All @@ -235,19 +248,13 @@ class CompiledOutputPort(CompiledPort):
pass



class CompiledInputPort(CompiledPort):
def __init__(self, port):
assert isinstance(port, InputPort)
super().__init__(port, None)
self.default_value = port.default_value

def substitute_symbol(self, old_symbol, new_symbol):
assert isinstance(old_symbol, sym.Symbol)
assert isinstance(new_symbol, sym.Symbol)
assert isinstance(self.symbol, sym.Symbol)
if self.symbol == old_symbol:
self.name = str(new_symbol)


class PortedObject(ABC):
def __init__(
Expand Down Expand Up @@ -277,7 +284,7 @@ def check_existing_port_names(self, port: Port):
f"Port with name '{port.name}' doubly defined in PortedObject '{self.name}'."
)

def parse_port_entry(self, port_info: Port | dict | str, port_type: Port):
def parse_port_entry(self, port_info: Port | dict | tuple | str, port_type: Port):
if isinstance(port_info, port_type):
port = port_info
elif isinstance(port_info, dict):
Expand All @@ -288,6 +295,19 @@ def parse_port_entry(self, port_info: Port | dict | str, port_type: Port):
raise ValidationError(
f'The dictionary {port_info} must have a "name" entry.'
)
elif isinstance(port_info, tuple):
name = port_info[0]
if issubclass(port_type, InputPort):
port = port_type(
name=name,
default_value=port_info[1] if len(port_info) >= 2 else None,
description=port_info[2] if len(port_info) >=3 else None
)
else:
port = port_type(
name=name,
description=port_info[1] if len(port_info) >=2 else None,
)
elif isinstance(port_info, str):
port = port_type(name=port_info)
else:
Expand Down Expand Up @@ -358,7 +378,7 @@ def parse_assignment_entry(
elif isinstance(assignment_info, dict):
keys = assignment_info.keys()
if "expression" in keys:
if assignment_type is DifferentialAssignment:
if issubclass(assignment_type, DifferentialAssignment):
if "variable" in keys:
return assignment_type(
assignment_info["variable"], assignment_info["expression"]
Expand All @@ -367,7 +387,7 @@ def parse_assignment_entry(
raise ValidationError(
f'The dictionary {assignment_info} must contain a key "variable" to define a differential assignment'
)
if assignment_type is ParameterAssignment:
if issubclass(assignment_type, ParameterAssignment):
if "parameter" in keys:
return assignment_type(
assignment_info["parameter"], assignment_info["expression"]
Expand Down Expand Up @@ -572,6 +592,10 @@ class FunctionalPortedObject(PortedObject):
Note that function assignments whose expression references a parameter defined as
the function value of another expression are not allowed.
Methods:
add_assignments
compile
"""

#TODO: In the future, this should be a composite ported object that
Expand All @@ -595,21 +619,14 @@ def __init__(
Construct a FunctionalPortedObject from a list of assignments specifying functions.
Args:
name: a string which must be unique for each VariablePortedObject inside a common
name: a unique identifier for each VariablePortedObject inside a common
CompositePortedObject.
input_ports: list of input ports to expose. Elements should be of type InputPort,
input_ports: input ports to expose. Elements should be of type InputPort,
dict or str.
assignments: list of functional assignments. Elements should be of type
assignments: functional assignments. Elements should be of type
ParameterAssignment, dict or tuple.
create_input_ports: automatically expose all function arguments which aren't specified
in the list input_ports as input ports.
Example:
FunctionalPortedObject(
name = "A",
input_ports = ["x", "y"],
assignments = [("f", "2*x"), ("g", "3*y")],
)
"""
#TODO: Functional ported objects should take lists of assignments to a list of output port
super().__init__(name, input_ports=input_ports)
Expand All @@ -629,7 +646,7 @@ def add_assignments(self, *assignments: list[ParameterAssignment|dict|tuple], cr
"""
for assignment_info in assignments:
assignment = self.parse_assignment_entry(assignment_info, ParameterAssignment)
assignment = self.parse_assignment_entry(assignment_info, FunctionalAssignment)
parameter_name = str(assignment.parameter.symbol)
if parameter_name in self.assignments:
raise ValueError(
Expand All @@ -651,7 +668,7 @@ def add_assignments(self, *assignments: list[ParameterAssignment|dict|tuple], cr

def add_assignment(self, assignment_info: ParameterAssignment|dict|tuple, create_input_ports=True):
# DEPRECATE?
assignment = self.parse_assignment_entry(assignment_info, ParameterAssignment)
assignment = self.parse_assignment_entry(assignment_info, FunctionalAssignment)
parameter_name = str(assignment.parameter.symbol)
if parameter_name in self.assignments:
raise ValueError(
Expand Down Expand Up @@ -772,33 +789,33 @@ def add_wires(self, variable_wires: list = [], directed_wires: list = []):
elif isinstance(wire_info, tuple):
self.add_variable_aggregation_wiring(
child_ports=wire_info[0],
parent_port=wire_info[1] or None,
parent_port=wire_info[1] if len(wire_info)>=2 else None,
output_name=wire_info[2] if len(wire_info)==3 else None,
)
else:
raise ValidationError(f"The element {wire_info} is not a dictionary or tuple")
raise ValidationError(f"The information {wire_info} is not a dictionary or tuple")

for wire_info in directed_wires:
if isinstance(wire_info, dict):
keys = wire_info.keys()
if keys == {"source", "destinations"}:
self.add_directed_wire(wire_info["source"], wire_info["destinations"])
if keys == {"source", "destinations"} or keys == {"source", "destination"}:
self.add_directed_wire(*wire_info.values())
else:
raise ValidationError(
f'The dictionary {wire_info} must contain keys "source" and "destinations".'
f'The dictionary {wire_info} must contain keys "source" and either "destination" or "destinations".'
)
elif isinstance(wire_info, tuple):
self.add_directed_wire(wire_info[0], wire_info[1])
self.add_directed_wire(*wire_info)
else:
raise ValidationError(f"The element {wire_info} is not a dictionary or tuple")


def add_directed_wire(self, source_name: str, destination_names: list[str]):
def add_directed_wire(self, source_name: str, destination_names: str | list[str]):
source_port = self.get_port_by_name(source_name)
if source_port is None:
raise WiringError(
f"Incorrect wiring in '{self.name}'. "
f"Destination port '{source_name}' does not exist."
f"Source port '{source_name}' does not exist."
)
if (
source_name in self.output_ports
Expand All @@ -807,9 +824,12 @@ def add_directed_wire(self, source_name: str, destination_names: list[str]):
):
# Source must be: own input, or child output, or child variable
raise WiringError(
f"Incorrect wiring in '{self.name}'. Destination port '{source_name}' "
f"Incorrect wiring in '{self.name}'. Source port '{source_name}' "
"must be an input port or a child output/variable port."
)
# If a singular destination is specified, coerce it into a list
if isinstance(destination_names, str):
destination_names = [destination_names]
for destination_name in destination_names:
destination_port = self.get_port_by_name(destination_name)
if destination_port is None:
Expand Down Expand Up @@ -841,9 +861,6 @@ def add_variable_aggregation_wiring(
parent_port: str = None,
output_name: str = None,
):
# TODO: This should become reimplemented by using a composite
# ported object that contains building blocks modeling this behavior.
# These are names of ports
# All ports must be variable ports.
# Parent port (if provided) should be port of the object itself
if parent_port is not None:
Expand Down Expand Up @@ -929,49 +946,75 @@ def compile(self, prefix_names=False):
compiled.internal_parameter_assignments.update(
child.internal_parameter_assignments
)
# Pass forward assignments from output ports
# Pass forward assignments from output ports. Assignments may later be exposed
# at an output port by a directed wire.
# TODO: Unconnected output ports are an indication that something may be wrong
# If an output port is not connected, we could consider discarding it
for name, port in child.output_ports.items():
assg = port.assignment
compiled.internal_parameter_assignments[assg.name] = assg
if assg := port.assignment:
compiled.internal_parameter_assignments[assg.name] = assg

# Process directed wires. In the process, we check which child input
# ports don't have an incoming wire using unconnected_child_input_ports.
# Process directed wires. We first determine the port which produces the wire symbol,
# which depends on if the wire connects to output ports or not.
for wire in self.directed_wires:

# Directed wires connect:
# - an input port to child input ports, or;
# - a child output port to child input ports and at most one output port, or;
# - a child variable port to child input ports.
# We take cases on the number of output ports a directed wire connects to.
outputs = [port for port in self.output_ports if port in wire.destination_ports]
num_outputs = len(outputs)
if num_outputs > 1:
# No wire can point to more than one output port
raise WiringError(
f"Incorrect wiring in '{self.name}'. "
f"Directed wire from port {wire.source_port} "
"is connected to two different output ports. "
)
elif num_outputs == 1:
# A wire ending at an output port can only start at a child output port.
source = compiled.get_port_by_name(wire.source_port)
if type(source) is CompiledOutputPort:
wire_root = self.get_port_by_name(outputs[0])
else:
raise WiringError(
f"Incorrect wiring in '{self.name}'. "
"A DirectedWire pointing to an output port must start at "
f"a child OutputPort, not {wire.source_port} ."
)
else:
# The wire has only internal destinations.
wire_root = compiled.get_port_by_name(wire.source_port)

# Now we perform the identifications. In the process we check which child ports
# don't have an incoming wire using unconnected_child_input_ports.
for destination_port in wire.destination_ports:
if destination_port in unconnected_child_input_ports:
# Goes from own input or child output port to child input port.
# In all of these cases, the ports have been pre-compiled
source = compiled.get_port_by_name(wire.source_port)
destination = compiled.get_port_by_name(destination_port)
assert type(destination) is CompiledInputPort
# We're dropping the destination symbol in favor of the source
assg = SymbolIdentification(source.symbol, destination.symbol)
compiled.symbol_identifications.append(assg)
# Substitute the destination symbol for the wire symbol
symb_id = SymbolIdentification(wire_root.symbol, destination.symbol)
compiled.symbol_identifications.append(symb_id)
unconnected_child_input_ports.pop(destination_port)
elif destination_port in self.output_ports:
# We can only be in this case if the source is a child output port,
# which has already been compiled
source = compiled.get_port_by_name(wire.source_port)
destination = self.get_port_by_name(destination_port)
assert type(destination) is OutputPort
if self.is_own_port(wire.source_port):
# Goes from own input port to own output port.
assg_out = ParameterAssignment(destination.symbol, source.symbol)
compiled.output_ports[destination.name] = CompiledOutputPort(
destination, assg_out
)
# TODO: I don't see a use case for this
# Raising an error because I haven't tested the code above
raise ValueError(
"Why are you connecting an input port directly to an output port?"
)
else:
# Goes from child output/variable port to own output port.
# We create a compiled output port
assg = ParameterAssignment(destination.symbol, source.symbol)
compiled.output_ports[destination.name] = CompiledOutputPort(
destination, assg
)
# Substitute the source symbol for the output port symbol
symb_id = SymbolIdentification(wire_root.symbol, source.symbol)
compiled.symbol_identifications.append(symb_id)
# Pass forward the assignment at source, currently stored as an
# internal parameter assignment, to the output port.
source_assg = compiled.internal_parameter_assignments.pop(source.name)
compiled.output_ports[destination.name] = CompiledOutputPort(
destination,
source_assg,
)
else:
raise WiringError(
f"Incorrect wiring in '{self.name}'. "
Expand Down Expand Up @@ -1014,7 +1057,7 @@ def compile(self, prefix_names=False):
raise ValueError(
f"Inconsistent initial values for variable {wiring.parent_port}: {initial_values}."
)
elif initial_values:
elif initial_values:
initial_value = initial_values.pop()
else:
initial_value = None
Expand Down Expand Up @@ -1152,9 +1195,8 @@ def set_input_parameters(self, parameter_assignments=[]):
# - Process those ports with default values to DefaultParameterAssignments
# - Those input ports with no default should carry to the system, but not simulation
# - The ability to set or change parameters should move to a system property

default_input_ports = []
for name, port in self.input_ports.items():
default_input_ports = []
if port.default_value is not None:
new_assg = DefaultParameterAssignment(name, port.default_value)
self.internal_parameter_assignments[name] = new_assg
Expand All @@ -1178,6 +1220,9 @@ def set_input_parameters(self, parameter_assignments=[]):
self.input_ports = {}
"""

def get_free_inputs(self):
return self.input_ports.values()

def get_assignments(self):
# Should this get done on instantiation?
self.set_input_parameters()
Expand Down
Loading

0 comments on commit f1ecf78

Please sign in to comment.