Skip to content

Commit 2cc1935

Browse files
authored
Merge pull request #292 from labmlai/default-chart
Default chart
2 parents 803221d + 5c10d7e commit 2cc1935

File tree

4 files changed

+178
-3
lines changed

4 files changed

+178
-3
lines changed

app/server/labml_app/analyses/experiments/custom_metrics.py

Lines changed: 103 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,12 +3,15 @@
33
from typing import List, Tuple, Any
44
from uuid import UUID
55

6+
from labml_app.analyses.helper import get_similarity
7+
from labml_app.db import user, run
68
from labml_db import Model, Key, Index
79
from labml_db.serializer.pickle import PickleSerializer
810
from labml_db.serializer.yaml import YamlSerializer
911
from fastapi import Request
1012

1113
from labml_app.analyses.analysis import Analysis
14+
from labml_app.analyses.experiments.metrics import MetricsAnalysis
1215
from labml_app.analyses.preferences import MetricPreferenceModel
1316

1417

@@ -83,6 +86,10 @@ def create_custom_metric(self, data: dict):
8386
mp.save()
8487
cm.preference_key = mp.key
8588

89+
if 'preferences' in data:
90+
mp.update_preferences(data['preferences'])
91+
mp.save()
92+
8693
cm.save()
8794
self.metrics.append((cm.metric_id, cm.key))
8895
self.save()
@@ -100,6 +107,9 @@ def delete_custom_metric(self, metric_id: str):
100107
def get_data(self):
101108
return [k[1].load().get_data() for k in self.metrics]
102109

110+
def get_metrics(self):
111+
return [k[1].load() for k in self.metrics]
112+
103113
def update(self, data: dict):
104114
for (metric_id, key) in self.metrics:
105115
if metric_id == data['id']:
@@ -115,9 +125,10 @@ async def get_custom_metrics(request: Request, run_uuid: str) -> Any:
115125
r = CustomMetricsListModel()
116126
r.save()
117127
CustomMetricsListIndex.set(run_uuid, r.key)
118-
else:
119-
r = list_key.load()
128+
await create_magic_metric(request, run_uuid)
129+
list_key = r.key
120130

131+
r = list_key.load()
121132
return {'metrics': r.get_data()}
122133

123134

@@ -167,3 +178,93 @@ async def delete_custom_metric(request: Request, run_uuid: str) -> Any:
167178
r.delete_custom_metric(data['id'])
168179

169180
return {'status': 'success'}
181+
182+
183+
@Analysis.route('GET', 'custom_metrics/{run_uuid}/magic')
184+
async def create_magic_metric(request: Request, run_uuid: str) -> Any:
185+
list_key = CustomMetricsListIndex.get(run_uuid)
186+
187+
run_cm = list_key.load()
188+
189+
current_run = run.get(run_uuid)
190+
if current_run is None:
191+
return {'is_success': False, 'message': 'Run not found'}
192+
current_run = current_run.get_summary()
193+
194+
analysis_data = MetricsAnalysis.get_or_create(run_uuid).get_tracking()
195+
indicators = [track_item['name'] for track_item in analysis_data]
196+
indicators = sorted(indicators)
197+
198+
u = user.get_by_session_token('local') # todo
199+
200+
default_project = u.default_project
201+
runs = [r.get_summary() for r in default_project.get_runs()]
202+
203+
runs = sorted(runs, key=lambda i: i['start_time'], reverse=True)
204+
similarity = [get_similarity(current_run, x) for x in runs]
205+
runs = [x for s, x in sorted(zip(similarity, runs), key=lambda pair: pair[0], reverse=True)]
206+
runs = runs[:20]
207+
208+
current_metrics = run_cm.get_data()
209+
current_indicators = []
210+
for x in current_metrics:
211+
current_indicators += x['preferences']['series_preferences']
212+
current_indicators = set(current_indicators)
213+
214+
indicator_counts = {}
215+
for r in runs:
216+
list_key = CustomMetricsListIndex.get(r['run_uuid'])
217+
if list_key is None:
218+
continue
219+
cm = list_key.load()
220+
cm = cm.get_metrics()
221+
222+
for m in cm:
223+
m_data = m.get_data()
224+
preferences = m_data['preferences']
225+
if len(preferences['series_preferences']) == 0:
226+
continue
227+
has_current_indicators = False
228+
for indicator in current_indicators:
229+
if indicator in preferences['series_preferences']:
230+
has_current_indicators = True
231+
break
232+
if has_current_indicators:
233+
continue
234+
235+
preference_map_key = '|'.join(sorted(preferences['series_preferences']))
236+
if preference_map_key not in indicator_counts:
237+
indicator_counts[preference_map_key] = []
238+
indicator_counts[preference_map_key].append((m.key, m_data['created_time']))
239+
240+
if len(indicator_counts) == 0:
241+
return {'is_success': False, 'message': "Couldn't find any related past run."}
242+
243+
sorted_keys = sorted(indicator_counts.keys(), key=lambda x: len(indicator_counts[x]), reverse=True)
244+
245+
# find the first indicator list with overlap
246+
selected = None
247+
for key in sorted_keys:
248+
ind = key.split('|')
249+
overlap = False
250+
for i in indicators:
251+
if i in ind:
252+
overlap = True
253+
break
254+
if overlap:
255+
selected = key
256+
break
257+
258+
if selected is None: # return smth
259+
return {'status': 'error', 'message': 'No similar indicators found'}
260+
261+
selected_metric = sorted(indicator_counts[selected], key=lambda x: x[1], reverse=True)[0]
262+
selected_metric = selected_metric[0].load()
263+
264+
new_metric_data = selected_metric.get_data()
265+
new_metric_data['preferences']['series_preferences'] = \
266+
[x for x in new_metric_data['preferences']['series_preferences'] if x in indicators]
267+
268+
run_cm.create_custom_metric(new_metric_data)
269+
270+
return {'is_success': True}

app/server/labml_app/analyses/helper.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,3 +51,35 @@ def get_mean_series(res: List[Dict[str, Any]]) -> Dict[str, Any]:
5151
last_step = res[0]['last_step']
5252

5353
return {'step': step, 'value': mean_value, 'name': 'mean', 'last_step': last_step}
54+
55+
56+
def edit_distance(str1, str2):
57+
len1 = len(str1)
58+
len2 = len(str2)
59+
60+
dp = [[0 for i in range(len1 + 1)] for j in range(2)]
61+
62+
for i in range(0, len1 + 1):
63+
dp[0][i] = i
64+
65+
for i in range(1, len2 + 1):
66+
for j in range(0, len1 + 1):
67+
if j == 0:
68+
dp[i % 2][j] = i
69+
elif str1[j - 1] == str2[i - 1]:
70+
dp[i % 2][j] = dp[(i - 1) % 2][j - 1]
71+
else:
72+
dp[i % 2][j] = (1 + min(dp[(i - 1) % 2][j],
73+
min(dp[i % 2][j - 1],
74+
dp[(i - 1) % 2][j - 1])))
75+
return dp[len2 % 2][len1]
76+
77+
78+
def get_similarity(run_a, run_b):
79+
name_edit = edit_distance(run_a['name'], run_b['name']) / max(len(run_a['name']), len(run_b['name']))
80+
81+
return 1 - name_edit
82+
83+
84+
85+

app/ui/src/network.ts

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -132,6 +132,10 @@ class Network {
132132
return this.sendHttpRequest('POST', `/${url}/preferences/${runUUID}`, data)['promise']
133133
}
134134

135+
async createMagicMetric(runUUID: string): Promise<any> {
136+
return this.sendHttpRequest('GET', `/custom_metrics/${runUUID}/magic`, {})['promise']
137+
}
138+
135139
async createCustomMetric(runUUID: string, data: object): Promise<any> {
136140
return this.sendHttpRequest('POST', `/custom_metrics/${runUUID}/create`, data)['promise']
137141
}

app/ui/src/views/run_view.ts

Lines changed: 39 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ import {User} from '../models/user'
44
import {ROUTER, SCREEN} from '../app'
55
import {Weya as $, WeyaElement} from '../../../lib/weya/weya'
66
import {DataLoader} from "../components/loader"
7-
import {AddButton, BackButton, CustomButton, ExpandButton, ShareButton} from "../components/buttons"
7+
import {AddButton, BackButton, CustomButton, ExpandButton, IconButton, ShareButton} from "../components/buttons"
88
import {UserMessages} from "../components/user_messages"
99
import {RunHeaderCard} from "../analyses/experiments/run_header/card"
1010
import {experimentAnalyses} from "../analyses/analyses"
@@ -15,6 +15,7 @@ import {AwesomeRefreshButton} from '../components/refresh_button'
1515
import {setTitle} from '../utils/document'
1616
import {ScreenView} from '../screen_view'
1717
import metricAnalysis from "../analyses/experiments/custom_metrics";
18+
import NETWORK from "../network";
1819

1920
class RunView extends ScreenView {
2021
uuid: string
@@ -37,6 +38,7 @@ class RunView extends ScreenView {
3738
private refresh: AwesomeRefreshButton
3839
private share: ShareButton
3940
private addCustomMetricButton: AddButton
41+
private magicMetricButton: IconButton
4042
private isRankExpanded: boolean
4143
private rankElems: WeyaElement
4244
private processContainer: WeyaElement
@@ -69,6 +71,16 @@ class RunView extends ScreenView {
6971
title: 'Add custom metric',
7072
parent: this.constructor.name
7173
})
74+
this.magicMetricButton = new IconButton({
75+
onButtonClick: () => {
76+
this.magicMetricButton.loading = true
77+
this.creatMagicMetric().then(() => {
78+
this.magicMetricButton.loading = false
79+
})
80+
},
81+
title: 'Add magic metric',
82+
parent: this.constructor.name,
83+
}, '.fas.fa-magic')
7284
}
7385

7486
private get isRank(): boolean {
@@ -185,6 +197,7 @@ class RunView extends ScreenView {
185197
}).render($)
186198
}
187199
this.addCustomMetricButton.render($)
200+
this.magicMetricButton.render($)
188201
})
189202
}
190203

@@ -194,6 +207,31 @@ class RunView extends ScreenView {
194207
}
195208
}
196209

210+
async creatMagicMetric() {
211+
let res = null
212+
try {
213+
res = await NETWORK.createMagicMetric(this.uuid)
214+
} catch (e) {
215+
UserMessages.shared.networkError(e, 'Failed to create magic metric')
216+
return
217+
}
218+
219+
220+
if (res['is_success']) {
221+
try {
222+
this.customMetrics = await CACHE.getCustomMetrics(this.uuid).get(true)
223+
} catch (e) {
224+
UserMessages.shared.networkError(e, 'Failed to load custom metrics')
225+
return
226+
}
227+
228+
this.renderCards()
229+
UserMessages.shared.success('Chart created')
230+
} else {
231+
UserMessages.shared.warning(res['message'])
232+
}
233+
}
234+
197235
async createCustomMetric() {
198236
this.addCustomMetricButton.loading = true
199237
try {

0 commit comments

Comments
 (0)