3
3
from typing import List , Tuple , Any
4
4
from uuid import UUID
5
5
6
+ from labml_app .analyses .helper import get_similarity
7
+ from labml_app .db import user , run
6
8
from labml_db import Model , Key , Index
7
9
from labml_db .serializer .pickle import PickleSerializer
8
10
from labml_db .serializer .yaml import YamlSerializer
9
11
from fastapi import Request
10
12
11
13
from labml_app .analyses .analysis import Analysis
14
+ from labml_app .analyses .experiments .metrics import MetricsAnalysis
12
15
from labml_app .analyses .preferences import MetricPreferenceModel
13
16
14
17
@@ -83,6 +86,10 @@ def create_custom_metric(self, data: dict):
83
86
mp .save ()
84
87
cm .preference_key = mp .key
85
88
89
+ if 'preferences' in data :
90
+ mp .update_preferences (data ['preferences' ])
91
+ mp .save ()
92
+
86
93
cm .save ()
87
94
self .metrics .append ((cm .metric_id , cm .key ))
88
95
self .save ()
@@ -100,6 +107,9 @@ def delete_custom_metric(self, metric_id: str):
100
107
def get_data (self ):
101
108
return [k [1 ].load ().get_data () for k in self .metrics ]
102
109
110
+ def get_metrics (self ):
111
+ return [k [1 ].load () for k in self .metrics ]
112
+
103
113
def update (self , data : dict ):
104
114
for (metric_id , key ) in self .metrics :
105
115
if metric_id == data ['id' ]:
@@ -115,9 +125,10 @@ async def get_custom_metrics(request: Request, run_uuid: str) -> Any:
115
125
r = CustomMetricsListModel ()
116
126
r .save ()
117
127
CustomMetricsListIndex .set (run_uuid , r .key )
118
- else :
119
- r = list_key . load ()
128
+ await create_magic_metric ( request , run_uuid )
129
+ list_key = r . key
120
130
131
+ r = list_key .load ()
121
132
return {'metrics' : r .get_data ()}
122
133
123
134
@@ -167,3 +178,93 @@ async def delete_custom_metric(request: Request, run_uuid: str) -> Any:
167
178
r .delete_custom_metric (data ['id' ])
168
179
169
180
return {'status' : 'success' }
181
+
182
+
183
+ @Analysis .route ('GET' , 'custom_metrics/{run_uuid}/magic' )
184
+ async def create_magic_metric (request : Request , run_uuid : str ) -> Any :
185
+ list_key = CustomMetricsListIndex .get (run_uuid )
186
+
187
+ run_cm = list_key .load ()
188
+
189
+ current_run = run .get (run_uuid )
190
+ if current_run is None :
191
+ return {'is_success' : False , 'message' : 'Run not found' }
192
+ current_run = current_run .get_summary ()
193
+
194
+ analysis_data = MetricsAnalysis .get_or_create (run_uuid ).get_tracking ()
195
+ indicators = [track_item ['name' ] for track_item in analysis_data ]
196
+ indicators = sorted (indicators )
197
+
198
+ u = user .get_by_session_token ('local' ) # todo
199
+
200
+ default_project = u .default_project
201
+ runs = [r .get_summary () for r in default_project .get_runs ()]
202
+
203
+ runs = sorted (runs , key = lambda i : i ['start_time' ], reverse = True )
204
+ similarity = [get_similarity (current_run , x ) for x in runs ]
205
+ runs = [x for s , x in sorted (zip (similarity , runs ), key = lambda pair : pair [0 ], reverse = True )]
206
+ runs = runs [:20 ]
207
+
208
+ current_metrics = run_cm .get_data ()
209
+ current_indicators = []
210
+ for x in current_metrics :
211
+ current_indicators += x ['preferences' ]['series_preferences' ]
212
+ current_indicators = set (current_indicators )
213
+
214
+ indicator_counts = {}
215
+ for r in runs :
216
+ list_key = CustomMetricsListIndex .get (r ['run_uuid' ])
217
+ if list_key is None :
218
+ continue
219
+ cm = list_key .load ()
220
+ cm = cm .get_metrics ()
221
+
222
+ for m in cm :
223
+ m_data = m .get_data ()
224
+ preferences = m_data ['preferences' ]
225
+ if len (preferences ['series_preferences' ]) == 0 :
226
+ continue
227
+ has_current_indicators = False
228
+ for indicator in current_indicators :
229
+ if indicator in preferences ['series_preferences' ]:
230
+ has_current_indicators = True
231
+ break
232
+ if has_current_indicators :
233
+ continue
234
+
235
+ preference_map_key = '|' .join (sorted (preferences ['series_preferences' ]))
236
+ if preference_map_key not in indicator_counts :
237
+ indicator_counts [preference_map_key ] = []
238
+ indicator_counts [preference_map_key ].append ((m .key , m_data ['created_time' ]))
239
+
240
+ if len (indicator_counts ) == 0 :
241
+ return {'is_success' : False , 'message' : "Couldn't find any related past run." }
242
+
243
+ sorted_keys = sorted (indicator_counts .keys (), key = lambda x : len (indicator_counts [x ]), reverse = True )
244
+
245
+ # find the first indicator list with overlap
246
+ selected = None
247
+ for key in sorted_keys :
248
+ ind = key .split ('|' )
249
+ overlap = False
250
+ for i in indicators :
251
+ if i in ind :
252
+ overlap = True
253
+ break
254
+ if overlap :
255
+ selected = key
256
+ break
257
+
258
+ if selected is None : # return smth
259
+ return {'status' : 'error' , 'message' : 'No similar indicators found' }
260
+
261
+ selected_metric = sorted (indicator_counts [selected ], key = lambda x : x [1 ], reverse = True )[0 ]
262
+ selected_metric = selected_metric [0 ].load ()
263
+
264
+ new_metric_data = selected_metric .get_data ()
265
+ new_metric_data ['preferences' ]['series_preferences' ] = \
266
+ [x for x in new_metric_data ['preferences' ]['series_preferences' ] if x in indicators ]
267
+
268
+ run_cm .create_custom_metric (new_metric_data )
269
+
270
+ return {'is_success' : True }
0 commit comments