diff --git a/src/tracksdata/graph/_test/test_graph_backends.py b/src/tracksdata/graph/_test/test_graph_backends.py index fc52fe59..86c9693f 100644 --- a/src/tracksdata/graph/_test/test_graph_backends.py +++ b/src/tracksdata/graph/_test/test_graph_backends.py @@ -2288,7 +2288,7 @@ def _fill_mock_geff_graph(graph_backend: BaseGraph) -> None: ) graph_backend.add_node_attr_key(DEFAULT_ATTR_KEYS.MASK, pl.Object) graph_backend.add_node_attr_key(DEFAULT_ATTR_KEYS.TRACKLET_ID, pl.Int64) - graph_backend.add_node_attr_key("ndfeature", pl.Object) + graph_backend.add_node_attr_key("ndfeature", pl.Array(pl.Float64, (3, 1)), np.ones((3, 1))) graph_backend.add_edge_attr_key("weight", pl.Float64) diff --git a/src/tracksdata/metrics/_ctc_metrics.py b/src/tracksdata/metrics/_ctc_metrics.py index 39640508..876cb54a 100644 --- a/src/tracksdata/metrics/_ctc_metrics.py +++ b/src/tracksdata/metrics/_ctc_metrics.py @@ -104,7 +104,8 @@ def _match_single_frame( # loading original group ids and filtering by the matches _mapped_ref = ref_group[reference_graph_key][rows_id].to_list() _mapped_comp = comp_group[input_graph_key][cols_id].to_list() - _ious = weights[rows_id, cols_id].tolist() + _ious = weights[rows_id, cols_id] + _ious = _ious.tolist() if _ious.size > 0 else [] LOG.info("Done!")