diff --git a/xai_components/base.py b/xai_components/base.py index ad14a2f3..7738a55e 100644 --- a/xai_components/base.py +++ b/xai_components/base.py @@ -72,6 +72,31 @@ def __deepcopy__(self, memo): memo[id_self] = _copy return _copy + def _put_at_index(self, idx: int, obj): + v = self._value + if v is None: + v = [] + if isinstance(v, tuple): + v = list(v) + if idx >= len(v): + v.extend([None] * (idx + 1 - len(v))) + v[idx] = obj + self._value = v + + class _IndexProxy: + def __init__(self, parent, idx): + self._p = parent + self._i = idx + def connect(self, ref): + self._p._put_at_index(self._i, ref) + + def __getitem__(self, idx): + # Enables: port[i].connect(other_port) + return InArg._IndexProxy(self, idx) + + def __setitem__(self, idx, value): + # Enables: port[i] = + self._put_at_index(idx, value) class InCompArg(Generic[T]): def __init__(self, value: T = None, getter: Callable[[T], any] = lambda x: x) -> None: @@ -265,6 +290,10 @@ def getter(x): return [] return [item.value if isinstance(item, (InArg, OutArg)) else item for item in x] + @staticmethod + def initial_value(): + return [] + class dynatuple(tuple): def __init__(self, *args): @@ -282,6 +311,11 @@ def resolve(item): return item return tuple(resolve(item) for item in x) + @staticmethod + def initial_value(): + return tuple() + + def parse_bool(value): if value is None: return None diff --git a/xircuits/compiler/generator.py b/xircuits/compiler/generator.py index 79f20999..6ceb6649 100644 --- a/xircuits/compiler/generator.py +++ b/xircuits/compiler/generator.py @@ -199,42 +199,54 @@ def connect_args(target, source): # Handle dynamic connections dynaports = [p for p in node.ports if - p.direction == 'in' and p.type != 'triangle-link' and p.dataType in DYNAMIC_PORTS] - ports_by_varName = {} - - RefOrValue = namedtuple('RefOrValue', ['value', 'is_ref']) # Renamed to RefOrValue + p.direction == 'in' and p.type != 'triangle-link' and p.dataType in DYNAMIC_PORTS] - # Group ports by varName - for port in dynaports: - if port.varName not in ports_by_varName: - ports_by_varName[port.varName] = [] - ports_by_varName[port.varName].append(port) + # Capture index N from port names like 'parameter-dynalist-dlist-2' + _name_idx_re = re.compile(r'-(\d+)\s*$') - for varName, ports in ports_by_varName.items(): - dynaport_values = [] + # Map: varName -> {index: port} + ports_by_varName = {} - for port in ports: - if port.source.id not in named_nodes: - value = _get_value_from_literal_port(port) - dynaport_values.append(RefOrValue(value, False)) - else: - # Handle named node references - value = "%s.%s" % (named_nodes[port.source.id], port.sourceLabel) # Variable reference - dynaport_values.append(RefOrValue(value, True)) - - if ports[0].dataType == 'dynatuple': - tuple_elements = [item.value if item.is_ref else repr(item.value) for item in dynaport_values] - if len(tuple_elements) == 1: - assignment_value = '(' + tuple_elements[0] + ',)' + for port in dynaports: + var_name = port.varName + name = port.name + # Extract index from name; default to 0 if no '-N' suffix + m = _name_idx_re.search(name) + idx = int(m.group(1)) if m else 0 + + ports_by_varName.setdefault(var_name, {}) + ports_by_varName[var_name].setdefault(idx, port) + + # Emit code per element in numeric order + for var_name, mapping in ports_by_varName.items(): + for i in sorted(mapping.keys()): + port = mapping[i] + target_indexed = f"{named_nodes[port.target.id]}.{var_name}[{i}]" + + if port.source.id in named_nodes: + # Component reference -> connect + source_ref = f"{named_nodes[port.source.id]}.{port.sourceLabel}" + init_code.append(ast.parse(f"{target_indexed}.connect({source_ref})")) else: - assignment_value = '(' + ', '.join(tuple_elements) + ')' - else: - list_elements = [item.value if item.is_ref else repr(item.value) for item in dynaport_values] - assignment_value = '[' + ', '.join(list_elements) + ']' + # Regex: matches e.g. 'Argument (string): argsName' + pattern = re.compile(r'^Argument \((.+?)\): (.+)$') + if port.source.file is None and port.source.name.startswith("Argument "): + match = pattern.match(port.source.name) + arg_type = type_mapping.get(match.group(1), 'any') + arg_name = match.group(2) + + if arg_name not in existing_args: + args_code.append(ast.parse(f"{arg_name}: InArg[{arg_type}]").body[0]) + existing_args.add(arg_name) + + init_code.append(ast.parse(f"{target_indexed}.connect(self.{arg_name})")) + else: + # Literal -> `[i] = ` + lit_value = _get_value_from_literal_port(port) + assign = ast.parse(f"{target_indexed} = 0") + assign.body[0].value = ast.parse(repr(lit_value)).body[0].value + init_code.append(assign) - assignment_target = "%s.%s" % (named_nodes[ports[0].target.id], ports[0].varName) - tpl = set_value(assignment_target, assignment_value) - init_code.append(tpl) # Handle output connections for i, port in enumerate(p for p in finish_node.ports if p.dataType == 'dynalist'):