Skip to content

Commit

Permalink
partition bugfix for 2d case
Browse files Browse the repository at this point in the history
  • Loading branch information
elphick committed Nov 27, 2024
1 parent 64bd503 commit c39e2e9
Showing 1 changed file with 17 additions and 23 deletions.
40 changes: 17 additions & 23 deletions elphick/geomet/interval_sample.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,34 +115,28 @@ def split_by_partition(self, partition_definition: Union[pd.Series, Callable], n
:param name_2: The name of the second sample.
:return: A tuple of two IntervalSamples.
"""
if not isinstance(partition_definition, Callable):
raise TypeError("The definition is not a callable function")

# Check that the partition definition has the correct number of arguments and that the names match
if isinstance(self.mass_data.index, pd.MultiIndex):
interval_levels = [level for level in self.mass_data.index.levels if isinstance(level, pd.IntervalIndex)]
else:
interval_levels = [self.mass_data.index] if isinstance(self.mass_data.index, pd.IntervalIndex) else []
dim_cols = [col for col in self.mass_data.index.names if
col != isinstance(self.mass_data.index.get_level_values(col), pd.IntervalIndex)]
fraction_means: pd.DataFrame = self.mass_data.index.to_frame()[dim_cols].apply(
lambda x: MeanIntervalIndex(x).mean, axis=0)

# Get the function from the partial object if necessary
partition_func = partition_definition.func if isinstance(partition_definition,
functools.partial) else partition_definition

# Check that the required argument names are present in the IntervalIndex levels
required_args = partition_func.__code__.co_varnames[:len(interval_levels)]
for arg, level in zip(required_args, interval_levels):
if arg != level.name:
raise ValueError(f"The partition definition argument name does not match the index name. "
f"Expected {level.name}, found {arg}")

fraction_means: dict = {}
# iterate the Index or MultiIndex
if isinstance(self.mass_data.index, pd.MultiIndex):
for idx in self.mass_data.index.levels[0]:
# get the mean of the fractions, by converting to a MeanIntervalIndex
fraction_means[idx] = MeanIntervalIndex(self.mass_data.index.get_loc_level(idx)).mean
if isinstance(partition_definition, Callable):
partition_func = partition_definition.func if isinstance(partition_definition,
functools.partial) else partition_definition
# Check that the required argument names are present in the IntervalIndex levels
required_args = partition_func.__code__.co_varnames[:len(dim_cols)]
elif isinstance(partition_definition, pd.Series):
required_args = partition_definition.index.names
else:
fraction_means[self.mass_data.index.name] = MeanIntervalIndex(self.mass_data.index).mean
raise TypeError(f"The partition definition must be a function or a pandas Series:"
f" type = {type(partition_definition)}")
for arg, dim in zip(required_args, dim_cols):
if arg != dim:
raise ValueError(f"The partition definition argument name does not match the index name. "
f"Expected {dim}, found {arg}")

self.to_stream()
self: 'Stream'
Expand Down

0 comments on commit c39e2e9

Please sign in to comment.