From 4ad4cc9c2ef1a96b095478b48a253d77796131cb Mon Sep 17 00:00:00 2001
From: richard-rogers <93153899+richard-rogers@users.noreply.github.com>
Date: Thu, 26 Sep 2024 00:36:19 +0100
Subject: [PATCH] Handle explicit segmentation keys (#1566)

## Description

Allow passing explicit `segment_key_values` to `why.log()` in addition to segmentation defined in the `DatasetSchema`, instead of raising an error when both are present. The explicit key/value pairs are appended to each segment's key (sorted by key name), and a `ValueError` is now raised only if an explicit key collides with a segmented column name.
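A minimal usage sketch based on the tests below (the column name, key names, and values are illustrative):

```python
import pandas as pd
import whylogs as why
from whylogs.core import DatasetSchema
from whylogs.core.segmentation_partition import segment_on_column

df = pd.DataFrame({"col3": ["x0", "x1", "x0"], "value": [1, 2, 3]})

# Segment on col3 *and* tag every segment with an explicit key/value pair.
# Explicit values are appended to each segment key, sorted by key name.
results = why.log(
    df,
    schema=DatasetSchema(segments=segment_on_column("col3")),
    segment_key_values={"ver": "1.0.0"},
)
for segment in results.segments():
    print(segment.key)  # e.g. ("x0", "1.0.0")
```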
- [ ] I have reviewed the [Guidelines for Contributing](CONTRIBUTING.md) and the [Code of Conduct](CODE_OF_CONDUCT.md).

---
 python/tests/api/logger/test_segments.py      | 89 +++++++++++++++++++
 .../api/writer/test_whylabs_integration.py    | 43 +++++++++
 python/whylogs/api/logger/logger.py           |  6 +-
 .../whylogs/api/logger/segment_processing.py  | 26 ++++--
 python/whylogs/core/segmentation_partition.py | 11 +++
 python/whylogs/migration/converters.py        |  3 +-
 6 files changed, 167 insertions(+), 11 deletions(-)

diff --git a/python/tests/api/logger/test_segments.py b/python/tests/api/logger/test_segments.py
index 979274fd30..6b48f57142 100644
--- a/python/tests/api/logger/test_segments.py
+++ b/python/tests/api/logger/test_segments.py
@@ -95,6 +95,60 @@ def test_single_column_segment() -> None:
     assert cardinality == 1.0
 
 
+def test_single_column_and_manual_segment() -> None:
+    input_rows = 100
+    segment_column = "col3"
+    number_of_segments = 5
+    d = {
+        "col1": [i for i in range(input_rows)],
+        "col2": [i * i * 1.1 for i in range(input_rows)],
+        segment_column: [f"x{str(i%number_of_segments)}" for i in range(input_rows)],
+    }
+
+    df = pd.DataFrame(data=d)
+    test_segments = segment_on_column("col3")
+    results: SegmentedResultSet = why.log(
+        df, schema=DatasetSchema(segments=test_segments), segment_key_values={"zzz": "foo", "ver": 1}
+    )
+    assert results.count == number_of_segments
+    partitions = results.partitions
+    assert len(partitions) == 1
+    partition = partitions[0]
+    segments = results.segments_in_partition(partition)
+    assert len(segments) == number_of_segments
+
+    first_segment = next(iter(segments))
+    assert first_segment.key == ("x0", "1", "foo")
+    first_segment_profile = results.profile(first_segment)
+    assert first_segment_profile is not None
+    assert first_segment_profile._columns["col1"]._schema.dtype == np.int64
+    assert first_segment_profile._columns["col2"]._schema.dtype == np.float64
+    assert first_segment_profile._columns["col3"]._schema.dtype.name == "object"
+    segment_cardinality: CardinalityMetric = (
+        first_segment_profile.view().get_column(segment_column).get_metric("cardinality")
+    )
+    cardinality = segment_cardinality.estimate
+    assert cardinality is not None
+    assert cardinality == 1.0
+
+
+def test_throw_on_duplicate_keys() -> None:
+    input_rows = 100
+    segment_column = "col3"
+    number_of_segments = 5
+    d = {
+        "col1": [i for i in range(input_rows)],
+        "col2": [i * i * 1.1 for i in range(input_rows)],
+        segment_column: [f"x{str(i%number_of_segments)}" for i in range(input_rows)],
+    }
+
+    df = pd.DataFrame(data=d)
+    test_segments = segment_on_column("col3")
+
+    with pytest.raises(ValueError):
+        why.log(df, schema=DatasetSchema(segments=test_segments), segment_key_values={segment_column: "foo"})
+
+
 def test_single_column_segment_with_trace_id() -> None:
     input_rows = 100
     segment_column = "col3"
@@ -312,6 +366,41 @@ def test_multi_column_segment() -> None:
     assert count == 1
 
 
+def test_multicolumn_and_manual_segment() -> None:
+    input_rows = 100
+    d = {
+        "col1": [i for i in range(input_rows)],
+        "col2": [i * i * 1.1 for i in range(input_rows)],
+        "col3": [f"x{str(i%5)}" for i in range(input_rows)],
+    }
+
+    df = pd.DataFrame(data=d)
+    segmentation_partition = SegmentationPartition(
+        name="col1,col3", mapper=ColumnMapperFunction(col_names=["col1", "col3"])
+    )
+    test_segments = {segmentation_partition.name: segmentation_partition}
+    results: SegmentedResultSet = why.log(
+        df, schema=DatasetSchema(segments=test_segments), segment_key_values={"ver": 42, "zzz": "bar"}
+    )
+    segments = results.segments()
+    last_segment = segments[-1]
+
+    # Note this segment is not useful as there is only one datapoint per segment, we have 100 rows and
+    # 100 segments. The segment value is a tuple of strings identifying this segment.
+    assert last_segment.key == ("99", "x4", "42", "bar")
+
+    last_segment_profile = results.profile(last_segment)
+
+    assert last_segment_profile._columns["col1"]._schema.dtype == np.int64
+    assert last_segment_profile._columns["col2"]._schema.dtype == np.float64
+    assert last_segment_profile._columns["col3"]._schema.dtype.name == "object"
+
+    segment_distribution: DistributionMetric = last_segment_profile.view().get_column("col1").get_metric("distribution")
+    count = segment_distribution.n
+    assert count is not None
+    assert count == 1
+
+
 def test_multi_column_segment_serialization_roundtrip_v0(tmp_path: Any) -> None:
     input_rows = 35
     d = {
diff --git a/python/tests/api/writer/test_whylabs_integration.py b/python/tests/api/writer/test_whylabs_integration.py
index 84b6430917..4a390cc6fd 100644
--- a/python/tests/api/writer/test_whylabs_integration.py
+++ b/python/tests/api/writer/test_whylabs_integration.py
@@ -424,6 +424,49 @@ def test_whylabs_writer_segmented(zipped: bool):
     assert deserialized_view.get_columns().keys() == data.keys()
 
 
+@pytest.mark.load
+def test_whylabs_writer_explicit_segmented():
+    ORG_ID = _get_org()
+    MODEL_ID = os.environ.get("WHYLABS_DEFAULT_DATASET_ID")
+    why.init(reinit=True, force_local=True)
+    schema = DatasetSchema(segments=segment_on_column("col1"))
+    data = {"col1": [1, 2, 1, 3, 2, 2], "col2": ["foo", "bar", "wat", "foo", "baz", "wat"]}
+    df = pd.DataFrame(data)
+    trace_id = str(uuid4())
+    profile = why.log(df, schema=schema, trace_id=trace_id, segment_key_values={"version": "1.0.0"})
+
+    assert profile.count == 3
+    partitions = profile.partitions
+    assert len(partitions) == 1
+    partition = partitions[0]
+    segments = profile.segments_in_partition(partition)
+    assert len(segments) == 3
+
+    first_segment = next(iter(segments))
+    assert first_segment.key == ("1", "1.0.0")
+
+    writer = WhyLabsWriter()
+    success, status = writer.write(profile)
+    assert success
+    time.sleep(SLEEP_TIME)  # platform needs time to become aware of the profile
+    dataset_api = DatasetProfileApi(writer._api_client)
+    response: ProfileTracesResponse = dataset_api.get_profile_traces(
+        org_id=ORG_ID,
+        dataset_id=MODEL_ID,
+        trace_id=trace_id,
+    )
+    assert len(response.get("traces")) == 3
+    for trace in response.get("traces"):
+        download_url = trace.get("download_url")
+        headers = {"Content-Type": "application/octet-stream"}
+        downloaded_profile = writer._s3_pool.request(
+            "GET", download_url, headers=headers, timeout=writer._timeout_seconds
+        )
+        deserialized_view = DatasetProfileView.deserialize(downloaded_profile.data)
+        assert deserialized_view._metadata["whylogs.tag.version"] == "1.0.0"
+        assert deserialized_view.get_columns().keys() == data.keys()
+
+
 @pytest.mark.load
 @pytest.mark.parametrize(
     "segmented,zipped",
diff --git a/python/whylogs/api/logger/logger.py b/python/whylogs/api/logger/logger.py
index 451444ca90..52605c581d 100644
--- a/python/whylogs/api/logger/logger.py
+++ b/python/whylogs/api/logger/logger.py
@@ -113,17 +113,13 @@ def log(
 
         # If segments are defined use segment_processing to return a SegmentedResultSet
         if active_schema and active_schema.segments:
-            if segment_key_values:
-                raise ValueError(
-                    f"using explicit `segment_key_values` {segment_key_values} is not compatible "
-                    f"with segmentation also defined in the DatasetSchema: {active_schema.segments}"
-                )
             segmented_results: SegmentedResultSet = segment_processing(
                 schema=active_schema,
                 obj=obj,
                 pandas=pandas,
                 row=row,
                 segment_cache=self._segment_cache,
+                segment_key_values=segment_key_values,
             )
             # Update the existing segmented_results metadata with the trace_id and other keys if not present
             _populate_common_profile_metadata(segmented_results.metadata, trace_id=trace_id, tags=tags)
diff --git a/python/whylogs/api/logger/segment_processing.py b/python/whylogs/api/logger/segment_processing.py
index 5a285c26ae..f4bd30dfbc 100644
--- a/python/whylogs/api/logger/segment_processing.py
+++ b/python/whylogs/api/logger/segment_processing.py
@@ -39,7 +39,7 @@ def _process_segment(
     segments[segment_key] = profile
 
 
-def _get_segment_from_group_key(group_key, partition_id) -> Tuple[str, ...]:
+def _get_segment_from_group_key(group_key, partition_id, explicit_keys: Tuple[str, ...] = ()) -> Tuple[str, ...]:
     if isinstance(group_key, str):
         segment_tuple_key: Tuple[str, ...] = (group_key,)
     elif isinstance(group_key, (List, Iterable, Iterator)):
@@ -47,7 +47,7 @@ def _get_segment_from_group_key(group_key, partition_id) -> Tuple[str, ...]:
     else:
         segment_tuple_key = (str(group_key),)
 
-    return Segment(segment_tuple_key, partition_id)
+    return Segment(segment_tuple_key + explicit_keys, partition_id)
 
 
 def _is_nan(x):
@@ -65,7 +65,11 @@ def _process_simple_partition(
     pandas: Optional[pd.DataFrame] = None,
     row: Optional[Mapping[str, Any]] = None,
     segment_cache: Optional[SegmentCache] = None,
+    segment_key_values: Optional[Dict[str, str]] = None,
 ):
+    explicit_keys = (
+        tuple(str(segment_key_values[k]) for k in sorted(segment_key_values.keys())) if segment_key_values else tuple()
+    )
     if pandas is not None:
         # simple means we can segment on column values
         grouped_data = pandas.groupby(columns)
@@ -81,11 +85,11 @@ def _process_simple_partition(
                 pandas_segment = pandas[mask]
             else:
                 pandas_segment = grouped_data.get_group(group)
-            segment_key = _get_segment_from_group_key(group, partition_id)
+            segment_key = _get_segment_from_group_key(group, partition_id, explicit_keys)
             _process_segment(pandas_segment, segment_key, segments, schema, segment_cache)
     elif row:
         # TODO: consider if we need to combine with the column names
-        segment_key = Segment(tuple(str(row[element]) for element in columns), partition_id)
+        segment_key = Segment(tuple(str(row[element]) for element in columns) + explicit_keys, partition_id)
         _process_segment(row, segment_key, segments, schema, segment_cache)
 
 
@@ -129,6 +133,7 @@ def _log_segment(
     pandas: Optional[pd.DataFrame] = None,
     row: Optional[Mapping[str, Any]] = None,
     segment_cache: Optional[SegmentCache] = None,
+    segment_key_values: Optional[Dict[str, str]] = None,
 ) -> Dict[Segment, Any]:
     segments: Dict[Segment, Any] = {}
     pandas, row = _pandas_or_dict(obj, pandas, row)
@@ -137,7 +142,13 @@ def _log_segment(
         if partition.simple:
            columns = partition.mapper.col_names if partition.mapper else None
             if columns:
-                _process_simple_partition(partition.id, schema, segments, columns, pandas, row, segment_cache)
+                _process_simple_partition(
+                    partition.id, schema, segments, columns, pandas, row, segment_cache, segment_key_values
+                )
+            else:
+                logger.error(
"Segmented DatasetSchema defines no segments; use an unsegmented DatasetSchema or specify columns to segment on." + ) else: raise NotImplementedError("custom mapped segments not yet implemented") return segments @@ -149,6 +160,7 @@ def segment_processing( pandas: Optional[pd.DataFrame] = None, row: Optional[Dict[str, Any]] = None, segment_cache: Optional[SegmentCache] = None, + segment_key_values: Optional[Dict[str, str]] = None, ) -> SegmentedResultSet: number_of_partitions = len(schema.segments) logger.info(f"The specified schema defines segments with {number_of_partitions} partitions.") @@ -160,6 +172,9 @@ def segment_processing( for partition_name in schema.segments: segment_partition = schema.segments[partition_name] + if segment_partition.mapper and segment_key_values: + segment_partition.mapper.set_explicit_names(segment_key_values.keys()) + logger.info(f"Processing partition with name({partition_name})") logger.debug(f"{partition_name}: is simple ({segment_partition.simple}), id ({segment_partition.id})") if segment_partition.filter: @@ -176,6 +191,7 @@ def segment_processing( pandas=pandas, row=row, segment_cache=segment_cache, + segment_key_values=segment_key_values, ) segmented_profiles[segment_partition.id] = partition_segments segment_partitions.append(segment_partition) diff --git a/python/whylogs/core/segmentation_partition.py b/python/whylogs/core/segmentation_partition.py index bbe640d8aa..ae087a6dbd 100644 --- a/python/whylogs/core/segmentation_partition.py +++ b/python/whylogs/core/segmentation_partition.py @@ -26,6 +26,17 @@ def __post_init__(self): column_string = ",".join(sorted(self.col_names)) segment_hash = hashlib.sha512(bytes(column_string + mapper_string, encoding="utf8")) self.id = segment_hash.hexdigest() + self.explicit_names = list() + + def set_explicit_names(self, key_names: List[str] = []) -> None: + if self.col_names: + for name in key_names: + if name in self.col_names: + raise ValueError( + f"Cannot have segmentation key {name} as both a column name and explicit segment key" + ) + + self.explicit_names = sorted(key_names) @dataclass diff --git a/python/whylogs/migration/converters.py b/python/whylogs/migration/converters.py index afae6a41cf..96c85f7e30 100644 --- a/python/whylogs/migration/converters.py +++ b/python/whylogs/migration/converters.py @@ -78,8 +78,9 @@ def _generate_segment_tags_metadata( segment_tags = [] col_names = partition.mapper.col_names + explicit_names = partition.mapper.explicit_names - for index, column_name in enumerate(col_names): + for index, column_name in enumerate(col_names + explicit_names): segment_tags.append(SegmentTag(key=_TAG_PREFIX + column_name, value=segment.key[index])) else: raise NotImplementedError(