@@ -39,7 +39,7 @@ def get_cohort_specific_color_code(cohort_name: str):
39
39
def enrichment_plot (acc_gpt , acc_mpnet , acc_fuzzy , title , save_plot = False , save_dir = "resources/results/plots" ):
40
40
if not (len (acc_gpt ) == len (acc_fuzzy ) == len (acc_mpnet )):
41
41
raise ValueError ("acc_gpt, acc_mpnet and acc_fuzzy should be of the same length!" )
42
- data = {"Maximum Considered Rank" : list (range (1 , len (acc_gpt ) + 1 )), "GPT" : acc_gpt ,
42
+ data = {"Maximum Considered Rank" : list (range (1 , len (acc_gpt ) + 1 )), "GPT" : acc_gpt ,
43
43
"MPNet" : acc_mpnet , "Fuzzy" : acc_fuzzy }
44
44
df = pd .DataFrame (data )
45
45
sns .set (style = "whitegrid" )
@@ -112,7 +112,7 @@ def bar_chart_average_acc_two_distributions(dist1_fuzzy: pd.DataFrame, dist1_gpt
112
112
113
113
114
114
def scatter_plot_two_distributions (tables1 : [MappingTable ], tables2 : [MappingTable ], label1 : str , label2 : str ,
115
- store_html : bool = True , legend_font_size : int = 16 ,
115
+ store_html : bool = True , legend_font_size : int = 16 ,
116
116
store_destination : str = "resources/results/plots/ad_vs_pd.html" ):
117
117
vectors_tables1 = np .concatenate ([table .get_embeddings_numpy () for table in tables1 ])
118
118
vectors_tables2 = np .concatenate ([table .get_embeddings_numpy () for table in tables2 ])
@@ -157,26 +157,26 @@ def scatter_plot_all_cohorts(tables1: [MappingTable], tables2: [MappingTable], l
157
157
boundaries = np .insert (boundaries , 0 , 0 )
158
158
for idx in range (len (tables1 )):
159
159
if labels1 [idx ]:
160
- fig .add_trace (go .Scatter (x = tsne_result [boundaries [idx ] : boundaries [idx + 1 ], 0 ],
161
- y = tsne_result [boundaries [idx ] : boundaries [idx + 1 ], 1 ],
160
+ fig .add_trace (go .Scatter (x = tsne_result [boundaries [idx ]: boundaries [idx + 1 ], 0 ],
161
+ y = tsne_result [boundaries [idx ]: boundaries [idx + 1 ], 1 ],
162
162
mode = "markers" , name = labels1 [idx ],
163
- text = descriptions [boundaries [idx ] : boundaries [idx + 1 ]],
163
+ text = descriptions [boundaries [idx ]: boundaries [idx + 1 ]],
164
164
# line=dict(color=get_cohort_specific_color_code(labels1[idx]))
165
165
))
166
166
for idy in range (len (tables1 ), len (boundaries ) - 1 ):
167
- fig .add_trace (go .Scatter (x = tsne_result [boundaries [idy ] : boundaries [idy + 1 ], 0 ],
168
- y = tsne_result [boundaries [idy ] : boundaries [idy + 1 ], 1 ],
167
+ fig .add_trace (go .Scatter (x = tsne_result [boundaries [idy ]: boundaries [idy + 1 ], 0 ],
168
+ y = tsne_result [boundaries [idy ]: boundaries [idy + 1 ], 1 ],
169
169
mode = "markers" ,
170
170
name = labels2 [idy - len (tables1 )],
171
- text = descriptions [boundaries [idy ] : boundaries [idy + 1 ]],
171
+ text = descriptions [boundaries [idy ]: boundaries [idy + 1 ]],
172
172
# line=dict(color=get_cohort_specific_color_code(labels2[idy - len(tables1)]))
173
173
))
174
174
if store_html :
175
175
fig .write_html (store_base_dir + "/tsne_all_cohorts.html" )
176
176
fig .show ()
177
177
178
178
179
- def get_html_plot_for_current_database_state (repository : BaseRepository , perplexity : int = 5 ) -> str :
179
+ def get_plot_for_current_database_state (repository : BaseRepository , perplexity : int = 5 , return_type = "html" ) -> str :
180
180
# get up to 1000 entries from db
181
181
mappings = repository .get_all_mappings ()
182
182
# Extract embeddings
@@ -206,10 +206,12 @@ def get_html_plot_for_current_database_state(repository: BaseRepository, perplex
206
206
yaxis = dict (title = 't-SNE Component 2' ),
207
207
)
208
208
fig = go .Figure (data = [scatter_plot ], layout = layout )
209
- # Convert the Plotly figure to HTML
210
- html_plot = fig .to_html (full_html = False )
209
+ if return_type == "html" :
210
+ plot = fig .to_html (full_html = False )
211
+ elif return_type == "json" :
212
+ plot = fig .to_json ()
213
+ else :
214
+ raise ValueError (f'Return type { return_type } is not viable. Use either "html" or "json".' )
211
215
else :
212
- html_plot = '<b>Too few database entries to visualize</b>'
213
- return html_plot
214
-
215
-
216
+ plot = '<b>Too few database entries to visualize</b>'
217
+ return plot
0 commit comments