-
Notifications
You must be signed in to change notification settings - Fork 7
/
Copy pathfunctions.py
149 lines (117 loc) · 3.95 KB
/
functions.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""Helping functions for 'introduction' and 'quickstart' notebooks."""
# -- File info -- #
__author__ = 'Andreas R. Stokholm'
__contributors__ = ''
__copyright__ = ['Technical University of Denmark', 'European Space Agency']
__contact__ = ['stokholm@space.dtu.dk']
__version__ = '1.0.0'
__date__ = '2022-10-17'
# -- Built-in modules -- #
# -- Third-party modules -- #
import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
from sklearn.metrics import r2_score, f1_score
# -- Proprietary modules -- #
from utils import ICE_STRINGS, GROUP_NAMES
def chart_cbar(ax, n_classes, chart, cmap='viridis'):
    """
    Create discrete colourbar for plot with the sea ice parameter class names.

    Parameters
    ----------
    ax : matplotlib.axes.Axes
        Axes the colourbar is attached to.
    n_classes : int
        Number of classes for the chart parameter, including the final mask class.
        The mask class receives no colour and no tick.
    chart : str
        The relevant chart. Used as key into ICE_STRINGS (colourbar label) and
        GROUP_NAMES (tick labels).
    cmap : str, optional
        Name of the Matplotlib colormap to discretise. Default is 'viridis'.

    Returns
    -------
    cbar : matplotlib.colorbar.Colorbar
        The created colourbar, for further customisation by the caller.
    """
    # One discrete colour per class, excluding the mask class.
    colour_map = plt.get_cmap(cmap, n_classes - 1)
    class_ids = np.arange(0, n_classes)
    # Get colour boundaries. -0.5 centres each tick on its integer class id.
    norm = mpl.colors.BoundaryNorm(class_ids - 0.5, colour_map.N)
    ticks = class_ids[:-1]  # Discount the mask class.
    cbar = plt.colorbar(mpl.cm.ScalarMappable(norm=norm, cmap=colour_map), ticks=ticks,
                        fraction=0.0485, pad=0.049, ax=ax)
    cbar.set_label(label=ICE_STRINGS[chart])
    cbar.set_ticklabels(list(GROUP_NAMES[chart].values()))
    return cbar
def compute_metrics(true, pred, charts, metrics):
    """
    Calculate metrics for each chart and the combined score.

    Parameters
    ----------
    true : dict
        Maps each chart name to a 1d array containing all true pixels.
    pred : dict
        Maps each chart name to a 1d array containing all predicted pixels,
        presumably of the same length as ``true[chart]``.
    charts : list
        List of charts to score.
    metrics : dict
        Per-chart settings: ``metrics[chart]['func']`` is called as
        ``func(true=..., pred=...)`` and ``metrics[chart]['weight']`` is the
        chart's weight in the combined score.

    Returns
    -------
    combined_score : float
        Combined weighted average score.
    scores : dict
        Score for each chart, multiplied by 100 and rounded to 3 decimals.

    Raises
    ------
    ValueError
        If ``true[chart]`` or ``pred[chart]`` is not 1-dimensional for any chart.
    """
    scores = {}
    for chart in charts:
        true_chart = true[chart]
        pred_chart = pred[chart]
        # Fail loudly here: a missing score would otherwise surface later as an
        # opaque KeyError in compute_combined_score.
        if np.ndim(true_chart) != 1 or np.ndim(pred_chart) != 1:
            raise ValueError(
                f"true and pred must be 1D numpy array for chart '{chart}', got "
                f"{np.ndim(true_chart)} and {np.ndim(pred_chart)} dimensions with shape "
                f"{np.shape(true_chart)} and {np.shape(pred_chart)}, respectively")
        scores[chart] = np.round(metrics[chart]['func'](true=true_chart, pred=pred_chart) * 100, 3)
    combined_score = compute_combined_score(scores=scores, charts=charts, metrics=metrics)
    return combined_score, scores
def r2_metric(true, pred):
    """
    Compute the r2 score (coefficient of determination) of the predictions.

    Parameters
    ----------
    true :
        1d array containing all true pixels.
    pred :
        1d array containing all predicted pixels.

    Returns
    -------
    r2 : float
        The r2 score as computed by sklearn's ``r2_score``.
    """
    return r2_score(y_true=true, y_pred=pred)
def f1_metric(true, pred):
    """
    Compute the weighted f1 score of the predictions.

    Per-class f1 scores are averaged using each class's support as weight
    (sklearn's ``average='weighted'``).

    Parameters
    ----------
    true :
        1d array containing all true pixels.
    pred :
        1d array containing all predicted pixels.

    Returns
    -------
    f1 : float
        The weighted f1 score.
    """
    return f1_score(y_true=true, y_pred=pred, average='weighted')
def compute_combined_score(scores, charts, metrics):
    """
    Combine the per-chart scores into a single weighted average.

    Parameters
    ----------
    scores : dict
        Score for each chart, keyed by chart name.
    charts : list
        Charts to include in the average.
    metrics : dict
        Per-chart settings; ``metrics[chart]['weight']`` is the chart's weight.

    Returns
    -------
    : float
        Sum of score * weight over ``charts`` divided by the sum of the weights,
        rounded to 3 decimals. An empty ``charts`` list raises ZeroDivisionError.
    """
    # Collect (score, weight) per chart, preserving the order of ``charts``.
    score_weight_pairs = [(scores[chart], metrics[chart]['weight']) for chart in charts]
    weighted_total = sum(score * weight for score, weight in score_weight_pairs)
    total_weight = sum(weight for _, weight in score_weight_pairs)
    return np.round(weighted_total / total_weight, 3)