Handle explicit segmentation keys (#1566)
## Description

Allow passing explicit `segment_key_values` in addition to the segmentation defined in a `DatasetSchema`.
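
A minimal usage sketch of the new behavior, based on the tests added in this commit (the column name, key/value pairs, and import paths are illustrative assumptions):

```python
import pandas as pd

import whylogs as why
from whylogs.core.schema import DatasetSchema
from whylogs.core.segmentation_partition import segment_on_column

# Illustrative data: segment on "col1" and attach an explicit "version" key.
df = pd.DataFrame({"col1": [1, 2, 1], "col2": ["foo", "bar", "baz"]})

results = why.log(
    df,
    schema=DatasetSchema(segments=segment_on_column("col1")),
    segment_key_values={"version": "1.0.0"},
)

# Each segment key is the column value followed by the explicit values
# (sorted by key name), e.g. ("1", "1.0.0") for the rows where col1 == 1.
for segment in results.segments():
    print(segment.key)
```

Passing a key that collides with a segmented column raises a `ValueError`, as exercised by `test_throw_on_duplicate_keys` below.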


- [ ] I have reviewed the [Guidelines for Contributing](CONTRIBUTING.md)
and the [Code of Conduct](CODE_OF_CONDUCT.md).
richard-rogers authored Sep 25, 2024
1 parent 8b8bff7 commit 4ad4cc9
Showing 6 changed files with 167 additions and 11 deletions.
89 changes: 89 additions & 0 deletions python/tests/api/logger/test_segments.py
@@ -95,6 +95,60 @@ def test_single_column_segment() -> None:
assert cardinality == 1.0


def test_single_column_and_manual_segment() -> None:
input_rows = 100
segment_column = "col3"
number_of_segments = 5
d = {
"col1": [i for i in range(input_rows)],
"col2": [i * i * 1.1 for i in range(input_rows)],
segment_column: [f"x{str(i%number_of_segments)}" for i in range(input_rows)],
}

df = pd.DataFrame(data=d)
test_segments = segment_on_column("col3")
results: SegmentedResultSet = why.log(
df, schema=DatasetSchema(segments=test_segments), segment_key_values={"zzz": "foo", "ver": 1}
)
assert results.count == number_of_segments
partitions = results.partitions
assert len(partitions) == 1
partition = partitions[0]
segments = results.segments_in_partition(partition)
assert len(segments) == number_of_segments

first_segment = next(iter(segments))
assert first_segment.key == ("x0", "1", "foo")
first_segment_profile = results.profile(first_segment)
assert first_segment_profile is not None
assert first_segment_profile._columns["col1"]._schema.dtype == np.int64
assert first_segment_profile._columns["col2"]._schema.dtype == np.float64
assert first_segment_profile._columns["col3"]._schema.dtype.name == "object"
segment_cardinality: CardinalityMetric = (
first_segment_profile.view().get_column(segment_column).get_metric("cardinality")
)
cardinality = segment_cardinality.estimate
assert cardinality is not None
assert cardinality == 1.0


def test_throw_on_duplicate_keys() -> None:
input_rows = 100
segment_column = "col3"
number_of_segments = 5
d = {
"col1": [i for i in range(input_rows)],
"col2": [i * i * 1.1 for i in range(input_rows)],
segment_column: [f"x{str(i%number_of_segments)}" for i in range(input_rows)],
}

df = pd.DataFrame(data=d)
test_segments = segment_on_column("col3")

with pytest.raises(ValueError):
why.log(df, schema=DatasetSchema(segments=test_segments), segment_key_values={segment_column: "foo"})


def test_single_column_segment_with_trace_id() -> None:
input_rows = 100
segment_column = "col3"
@@ -312,6 +366,41 @@ def test_multi_column_segment() -> None:
assert count == 1


def test_multicolumn_and_manual_segment() -> None:
input_rows = 100
d = {
"col1": [i for i in range(input_rows)],
"col2": [i * i * 1.1 for i in range(input_rows)],
"col3": [f"x{str(i%5)}" for i in range(input_rows)],
}

df = pd.DataFrame(data=d)
segmentation_partition = SegmentationPartition(
name="col1,col3", mapper=ColumnMapperFunction(col_names=["col1", "col3"])
)
test_segments = {segmentation_partition.name: segmentation_partition}
results: SegmentedResultSet = why.log(
df, schema=DatasetSchema(segments=test_segments), segment_key_values={"ver": 42, "zzz": "bar"}
)
segments = results.segments()
last_segment = segments[-1]

# Note this segmentation is not very useful: with 100 rows and 100 segments, there is only one
# data point per segment. The segment key is a tuple of strings identifying this segment.
assert last_segment.key == ("99", "x4", "42", "bar")

last_segment_profile = results.profile(last_segment)

assert last_segment_profile._columns["col1"]._schema.dtype == np.int64
assert last_segment_profile._columns["col2"]._schema.dtype == np.float64
assert last_segment_profile._columns["col3"]._schema.dtype.name == "object"

segment_distribution: DistributionMetric = last_segment_profile.view().get_column("col1").get_metric("distribution")
count = segment_distribution.n
assert count is not None
assert count == 1


def test_multi_column_segment_serialization_roundtrip_v0(tmp_path: Any) -> None:
input_rows = 35
d = {
43 changes: 43 additions & 0 deletions python/tests/api/writer/test_whylabs_integration.py
@@ -424,6 +424,49 @@ def test_whylabs_writer_segmented(zipped: bool):
assert deserialized_view.get_columns().keys() == data.keys()


@pytest.mark.load
def test_whylabs_writer_explicit_segmented():
ORG_ID = _get_org()
MODEL_ID = os.environ.get("WHYLABS_DEFAULT_DATASET_ID")
why.init(reinit=True, force_local=True)
schema = DatasetSchema(segments=segment_on_column("col1"))
data = {"col1": [1, 2, 1, 3, 2, 2], "col2": ["foo", "bar", "wat", "foo", "baz", "wat"]}
df = pd.DataFrame(data)
trace_id = str(uuid4())
profile = why.log(df, schema=schema, trace_id=trace_id, segment_key_values={"version": "1.0.0"})

assert profile.count == 3
partitions = profile.partitions
assert len(partitions) == 1
partition = partitions[0]
segments = profile.segments_in_partition(partition)
assert len(segments) == 3

first_segment = next(iter(segments))
assert first_segment.key == ("1", "1.0.0")

writer = WhyLabsWriter()
success, status = writer.write(profile)
assert success
time.sleep(SLEEP_TIME) # platform needs time to become aware of the profile
dataset_api = DatasetProfileApi(writer._api_client)
response: ProfileTracesResponse = dataset_api.get_profile_traces(
org_id=ORG_ID,
dataset_id=MODEL_ID,
trace_id=trace_id,
)
assert len(response.get("traces")) == 3
for trace in response.get("traces"):
download_url = trace.get("download_url")
headers = {"Content-Type": "application/octet-stream"}
downloaded_profile = writer._s3_pool.request(
"GET", download_url, headers=headers, timeout=writer._timeout_seconds
)
deserialized_view = DatasetProfileView.deserialize(downloaded_profile.data)
assert deserialized_view._metadata["whylogs.tag.version"] == "1.0.0"
assert deserialized_view.get_columns().keys() == data.keys()


@pytest.mark.load
@pytest.mark.parametrize(
"segmented,zipped",
6 changes: 1 addition & 5 deletions python/whylogs/api/logger/logger.py
@@ -113,17 +113,13 @@ def log(

# If segments are defined use segment_processing to return a SegmentedResultSet
if active_schema and active_schema.segments:
if segment_key_values:
raise ValueError(
f"using explicit `segment_key_values` {segment_key_values} is not compatible "
f"with segmentation also defined in the DatasetSchema: {active_schema.segments}"
)
segmented_results: SegmentedResultSet = segment_processing(
schema=active_schema,
obj=obj,
pandas=pandas,
row=row,
segment_cache=self._segment_cache,
segment_key_values=segment_key_values,
)
# Update the existing segmented_results metadata with the trace_id and other keys if not present
_populate_common_profile_metadata(segmented_results.metadata, trace_id=trace_id, tags=tags)
26 changes: 21 additions & 5 deletions python/whylogs/api/logger/segment_processing.py
@@ -39,15 +39,15 @@ def _process_segment(
segments[segment_key] = profile


def _get_segment_from_group_key(group_key, partition_id) -> Tuple[str, ...]:
def _get_segment_from_group_key(group_key, partition_id, explicit_keys: Tuple[str, ...] = ()) -> Tuple[str, ...]:
if isinstance(group_key, str):
segment_tuple_key: Tuple[str, ...] = (group_key,)
elif isinstance(group_key, (List, Iterable, Iterator)):
segment_tuple_key = tuple(str(k) for k in group_key)
else:
segment_tuple_key = (str(group_key),)

return Segment(segment_tuple_key, partition_id)
return Segment(segment_tuple_key + explicit_keys, partition_id)


def _is_nan(x):
@@ -65,7 +65,11 @@ def _process_simple_partition(
pandas: Optional[pd.DataFrame] = None,
row: Optional[Mapping[str, Any]] = None,
segment_cache: Optional[SegmentCache] = None,
segment_key_values: Optional[Dict[str, str]] = None,
):
explicit_keys = (
tuple(str(segment_key_values[k]) for k in sorted(segment_key_values.keys())) if segment_key_values else tuple()
)
if pandas is not None:
# simple means we can segment on column values
grouped_data = pandas.groupby(columns)
@@ -81,11 +85,11 @@ def _process_simple_partition(
pandas_segment = pandas[mask]
else:
pandas_segment = grouped_data.get_group(group)
segment_key = _get_segment_from_group_key(group, partition_id)
segment_key = _get_segment_from_group_key(group, partition_id, explicit_keys)
_process_segment(pandas_segment, segment_key, segments, schema, segment_cache)
elif row:
# TODO: consider if we need to combine with the column names
segment_key = Segment(tuple(str(row[element]) for element in columns), partition_id)
segment_key = Segment(tuple(str(row[element]) for element in columns) + explicit_keys, partition_id)
_process_segment(row, segment_key, segments, schema, segment_cache)


@@ -129,6 +133,7 @@ def _log_segment(
pandas: Optional[pd.DataFrame] = None,
row: Optional[Mapping[str, Any]] = None,
segment_cache: Optional[SegmentCache] = None,
segment_key_values: Optional[Dict[str, str]] = None,
) -> Dict[Segment, Any]:
segments: Dict[Segment, Any] = {}
pandas, row = _pandas_or_dict(obj, pandas, row)
@@ -137,7 +142,13 @@ def _log_segment(
if partition.simple:
columns = partition.mapper.col_names if partition.mapper else None
if columns:
_process_simple_partition(partition.id, schema, segments, columns, pandas, row, segment_cache)
_process_simple_partition(
partition.id, schema, segments, columns, pandas, row, segment_cache, segment_key_values
)
else:
logger.error(
"Segmented DatasetSchema defines no segments; use an unsegmented DatasetSchema or specify columns to segment on."
)
else:
raise NotImplementedError("custom mapped segments not yet implemented")
return segments
@@ -149,6 +160,7 @@ def segment_processing(
pandas: Optional[pd.DataFrame] = None,
row: Optional[Dict[str, Any]] = None,
segment_cache: Optional[SegmentCache] = None,
segment_key_values: Optional[Dict[str, str]] = None,
) -> SegmentedResultSet:
number_of_partitions = len(schema.segments)
logger.info(f"The specified schema defines segments with {number_of_partitions} partitions.")
@@ -160,6 +172,9 @@

for partition_name in schema.segments:
segment_partition = schema.segments[partition_name]
if segment_partition.mapper and segment_key_values:
segment_partition.mapper.set_explicit_names(segment_key_values.keys())

logger.info(f"Processing partition with name({partition_name})")
logger.debug(f"{partition_name}: is simple ({segment_partition.simple}), id ({segment_partition.id})")
if segment_partition.filter:
@@ -176,6 +191,7 @@
pandas=pandas,
row=row,
segment_cache=segment_cache,
segment_key_values=segment_key_values,
)
segmented_profiles[segment_partition.id] = partition_segments
segment_partitions.append(segment_partition)
11 changes: 11 additions & 0 deletions python/whylogs/core/segmentation_partition.py
@@ -26,6 +26,17 @@ def __post_init__(self):
column_string = ",".join(sorted(self.col_names))
segment_hash = hashlib.sha512(bytes(column_string + mapper_string, encoding="utf8"))
self.id = segment_hash.hexdigest()
self.explicit_names = list()

def set_explicit_names(self, key_names: List[str] = []) -> None:
if self.col_names:
for name in key_names:
if name in self.col_names:
raise ValueError(
f"Cannot have segmentation key {name} as both a column name and explicit segment key"
)

self.explicit_names = sorted(key_names)


@dataclass
3 changes: 2 additions & 1 deletion python/whylogs/migration/converters.py
@@ -78,8 +78,9 @@ def _generate_segment_tags_metadata(

segment_tags = []
col_names = partition.mapper.col_names
explicit_names = partition.mapper.explicit_names

for index, column_name in enumerate(col_names):
for index, column_name in enumerate(col_names + explicit_names):
segment_tags.append(SegmentTag(key=_TAG_PREFIX + column_name, value=segment.key[index]))
else:
raise NotImplementedError(