Skip to content

Commit

Permalink
Update GANDLF/utils/generic.py
Browse files Browse the repository at this point in the history
Co-authored-by: Sarthak Pati <patis@iu.edu>
  • Loading branch information
szmazurek and sarthakpati authored Nov 22, 2023
1 parent 111a03a commit 711099d
Showing 1 changed file with 4 additions and 2 deletions.
6 changes: 4 additions & 2 deletions GANDLF/utils/generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -287,26 +287,28 @@ def __update_metric_from_list_to_single_string(input_metrics_dict) -> dict:

def define_average_type_key(
params: Dict[str, Union[Dict[str, Any], Any]], metric_name: str
):
) -> str:
"""Determine if the the 'average' filed is defined in the metric config.
If not, fallback to the default 'macro'
values.
Args:
params (dict): The parameter dictionary containing training and data information.
metric_name (str): The name of the metric.
Returns:
str: The average type key.
"""
average_type_key = params["metrics"][metric_name].get("average", "macro")
return average_type_key


def define_multidim_average_type_key(params, metric_name):
def define_multidim_average_type_key(params, metric_name) -> str:
"""Determine if the the 'multidim_average' filed is defined in the metric config.
If not, fallback to the default 'global'.
Args:
params (dict): The parameter dictionary containing training and data information.
metric_name (str): The name of the metric.
Returns:
str: The average type key.
"""
Expand Down

0 comments on commit 711099d

Please sign in to comment.