1
1
import logging
2
2
import os
3
+ from collections import Counter
3
4
from copy import deepcopy
4
5
from pathlib import Path
5
6
from typing import Dict , List , Optional , Union , Tuple , Iterable , Callable , Set , Literal , Any
@@ -75,14 +76,36 @@ def __init__(self,
75
76
self .status : Optional [Status ] = None
76
77
77
78
if data is not None :
79
+ data = deepcopy (data ) # preserve the incoming data variable.
78
80
self .set_data (data , constraints = constraints )
79
81
80
82
@staticmethod
def _strip_common_prefix(df: pd.DataFrame) -> Tuple[pd.DataFrame, str]:
    """Strip the dominant ``<prefix>_`` lead-in from the column names of ``df``.

    The prefix is the one reported by ``MassComposition.get_common_prefix`` for
    the column labels. Only a *leading* ``<prefix>_`` is removed; columns that do
    not start with ``<prefix>_`` are left unchanged.

    Args:
        df: The dataframe whose column names are inspected.

    Returns:
        A tuple ``(frame, common_prefix)``. When no common prefix is detected,
        ``frame`` is ``df`` itself (not a copy) and ``common_prefix`` is ``""``.
        Otherwise ``frame`` is a copy of ``df`` with renamed columns.
    """
    common_prefix = MassComposition.get_common_prefix(df.columns.to_list())
    if not common_prefix:
        # Nothing to strip: hand back the incoming frame untouched.
        return df, common_prefix

    lead_in = f"{common_prefix}_"
    res = df.copy()
    # Slice off the lead-in rather than using str.replace: replace() would also
    # delete later occurrences of "<prefix>_" inside the name, and a bare
    # startswith(prefix) test would match unrelated names such as "<prefix>rate".
    res.columns = [col[len(lead_in):] if isinstance(col, str) and col.startswith(lead_in) else col
                   for col in df.columns]
    return res, common_prefix
95
+
96
@staticmethod
def get_common_prefix(columns: List[str], min_count: int = 3) -> str:
    """Return the most frequent underscore-delimited prefix among ``columns``.

    A column contributes a prefix only when it is a string of the form
    ``<prefix>_<rest>`` with a non-empty ``<prefix>``. Names without an
    underscore, names starting with an underscore, and non-string labels are
    ignored, so they can neither crash the scan nor be mistaken for a prefix.

    Args:
        columns: The column names to inspect.
        min_count: Minimum number of columns that must share the prefix for it
            to be reported. Defaults to 3.

    Returns:
        The most common prefix (without the trailing underscore) when it occurs
        at least ``min_count`` times, otherwise an empty string.
    """
    prefix_counter = Counter()
    for col in columns:
        if not isinstance(col, str):
            continue
        prefix, delimiter, _ = col.partition('_')
        # Require the delimiter to be present and something in front of it.
        if delimiter and prefix:
            prefix_counter[prefix] += 1

    if not prefix_counter:
        return ""

    common_prefix, freq = prefix_counter.most_common(1)[0]
    return common_prefix if freq >= min_count else ""
86
109
87
110
def set_data (self , data : Union [pd .DataFrame , xr .Dataset ],
88
111
constraints : Optional [Dict [str , List ]] = None ):
@@ -104,7 +127,8 @@ def set_data(self, data: Union[pd.DataFrame, xr.Dataset],
104
127
# seek a prefix to self assign the name
105
128
data , common_prefix = self ._strip_common_prefix (data )
106
129
if common_prefix :
107
- self ._specified_columns = {k : v .replace (common_prefix , '' ) for k , v in self ._specified_columns .items ()
130
+ self ._specified_columns = {k : v .replace (f"{ common_prefix } _" , '' ) for k , v in
131
+ self ._specified_columns .items ()
108
132
if v is not None }
109
133
110
134
self .variables = Variables (config = self .config ['vars' ],
@@ -630,24 +654,24 @@ def split_by_estimator(self,
630
654
"""
631
655
# Extract feature names from the estimator, and get the actual features
632
656
feature_names : list [str ] = list (extract_feature_names (estimator ))
633
- features : pd .DataFrame = self ._get_features (feature_names , extra_features , allow_prefix_mismatch )
657
+ features : pd .DataFrame = self ._get_features (feature_names , allow_prefix_mismatch = allow_prefix_mismatch ,
658
+ extra_features = extra_features )
634
659
635
660
# Apply the estimator
636
661
estimates : pd .DataFrame = estimator .predict (X = features )
637
662
if isinstance (estimates , np .ndarray ):
638
663
raise NotImplementedError ("The estimator must return a DataFrame" )
639
664
640
665
# Detect a possible prefix from the estimate columns
641
- features_prefix : str = os . path . commonprefix (features .columns .to_list ())
642
- estimates_prefix : str = os . path . commonprefix (estimates .columns .to_list ())
666
+ features_prefix : str = self . get_common_prefix (features .columns .to_list ())
667
+ estimates_prefix : str = self . get_common_prefix (estimates .columns .to_list ())
643
668
644
669
# If there is a prefix, check that it matches name_1, subject to allow_prefix_mismatch
645
- if estimates_prefix .strip (
646
- '_' ) and not allow_prefix_mismatch and name_1 and not name_1 == estimates_prefix .strip ('_' ):
670
+ if estimates_prefix and not allow_prefix_mismatch and name_1 and not name_1 == estimates_prefix :
647
671
raise ValueError (f"Common prefix mismatch: { features_prefix } and name_1: { name_1 } " )
648
672
649
673
# assign the output names, based on specified names, allow for prefix mismatch
650
- name_1 = name_1 if name_1 else estimates_prefix . strip ( '_' )
674
+ name_1 = name_1 if name_1 else estimates_prefix
651
675
652
676
if mass_recovery_column :
653
677
# Transform the mass recovery to mass by applying the mass recovery to the dry mass of the input stream
@@ -661,7 +685,9 @@ def split_by_estimator(self,
661
685
dry_mass_var ].values / mass_recovery_max
662
686
estimates .rename (columns = {mass_recovery_column : dry_mass_var }, inplace = True )
663
687
664
- estimates .columns = [f .replace (estimates_prefix , "" ) for f in estimates .columns ]
688
+ if estimates_prefix :
689
+ col_name_map : dict [str , str ] = {f : f .replace (estimates_prefix + '_' , "" ) for f in estimates .columns }
690
+ estimates .rename (columns = col_name_map , inplace = True )
665
691
666
692
out : MassComposition = MassComposition (name = name_1 , constraints = self .constraints , data = estimates )
667
693
comp : MassComposition = self .sub (other = out , name = name_2 )
@@ -671,7 +697,7 @@ def split_by_estimator(self,
671
697
return out , comp
672
698
673
699
def _get_features (self , feature_names : List [str ], allow_prefix_mismatch : bool ,
674
- extra_features : Optional [pd .DataFrame ] = None ,) -> pd .DataFrame :
700
+ extra_features : Optional [pd .DataFrame ] = None , ) -> pd .DataFrame :
675
701
"""
676
702
This method checks if the feature names required by an estimator are present in the data. If not, it tries to
677
703
match the feature names by considering a common prefix. If a match is found, the columns in the data are renamed
@@ -696,16 +722,16 @@ def _get_features(self, feature_names: List[str], allow_prefix_mismatch: bool,
696
722
feature_name_map = {name .lower (): name for name in feature_names }
697
723
698
724
df_features : pd .DataFrame = self .data .to_dataframe ()
699
- if extra_features :
725
+ if extra_features is not None :
700
726
df_features = pd .concat ([df_features , extra_features ], axis = 1 )
701
727
702
728
missing_features = set (f .lower () for f in feature_names ) - set (c .lower () for c in df_features .columns )
703
729
704
730
if missing_features :
705
731
prefix : str = f"{ self .name } _"
706
- common_prefix : str = os . path . commonprefix (feature_names )
707
- if common_prefix and common_prefix != prefix and allow_prefix_mismatch :
708
- prefix = common_prefix
732
+ common_prefix : str = self . get_common_prefix (feature_names )
733
+ if common_prefix and common_prefix + '_' != prefix and allow_prefix_mismatch :
734
+ prefix = common_prefix + '_'
709
735
710
736
# create a map to support renaming the columns
711
737
prefixed_feature_map : dict [str , str ] = {f : feature_name_map .get (f"{ prefix } { f .lower ()} " ) for f in
0 commit comments