diff --git a/src/motile_plugin/data_model/tracks.py b/src/motile_plugin/data_model/tracks.py index 8d0a549..a2d2598 100644 --- a/src/motile_plugin/data_model/tracks.py +++ b/src/motile_plugin/data_model/tracks.py @@ -163,6 +163,7 @@ def get_time(self, node: Node) -> int: return int(self.get_times([node])[0]) def set_times(self, nodes: Iterable[Node], times: Iterable[int]): + times = [int(t) for t in times] self._remove_from_seg_time_to_node(nodes) self._set_nodes_attr(nodes, self.time_attr, times) self._add_to_seg_time_to_node(nodes) @@ -176,7 +177,7 @@ def set_time(self, node: Any, time: int): time (int): The time to set """ - self.set_times([node], [time]) + self.set_times([node], [int(time)]) def get_seg_ids( self, nodes: Iterable[Node], required=False @@ -205,12 +206,13 @@ def set_seg_ids(self, nodes: Iterable[Node], seg_ids: Iterable[int]): node (Any): The node id to set the seg id of seg_id (int): The segmentation id to set for the node """ + seg_ids = [int(seg_id) for seg_id in seg_ids] self._remove_from_seg_time_to_node(nodes) self._set_nodes_attr(nodes, NodeAttr.SEG_ID.value, seg_ids) self._add_to_seg_time_to_node(nodes) def set_seg_id(self, node: Node, seg_id: int): - self.set_seg_ids([node], [seg_id]) + self.set_seg_ids([node], [int(seg_id)]) def add_nodes( self, @@ -493,8 +495,29 @@ def _save_graph(self, directory: Path): directory (Path): The directory in which to save the graph file. """ graph_file = directory / self.GRAPH_FILE + graph_data = nx.node_link_data(self.graph) + + def convert_np_types(data): + """Recursively convert numpy types to native Python types.""" + + if isinstance(data, dict): + return {key: convert_np_types(value) for key, value in data.items()} + elif isinstance(data, list): + return [convert_np_types(item) for item in data] + elif isinstance(data, np.ndarray): + return data.tolist() # Convert numpy arrays to Python lists + elif isinstance(data, np.integer): + return int(data) # Convert numpy integers to Python int + elif isinstance(data, np.floating): + return float(data) # Convert numpy floats to Python float + else: + return ( + data # Return the data as-is if it's already a native Python type + ) + + graph_data = convert_np_types(graph_data) with open(graph_file, "w") as f: - json.dump(nx.node_link_data(self.graph), f) + json.dump(graph_data, f) def _save_seg(self, directory: Path): """Save a segmentation as a numpy array using np.save. In the future, @@ -514,9 +537,7 @@ def _save_attrs(self, directory: Path): """ out_path = directory / self.ATTRS_FILE attrs_dict = { - "time_attr": self.time_attr - if not isinstance(self.time_attr, np.ndarray) - else self.time_attr.tolist(), + "time_attr": self.time_attr, "pos_attr": self.pos_attr if not isinstance(self.pos_attr, np.ndarray) else self.pos_attr.tolist(),