Skip to content

Commit 76338a1

Browse files
authored
Merge pull request #5 from SCAI-BIO/angular-workshop-changes
Angular workshop changes
2 parents ce0c039 + a20dc6b commit 76338a1

File tree

2 files changed

+19
-16
lines changed

2 files changed

+19
-16
lines changed

datastew/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,5 +10,6 @@
1010
"mapping",
1111
"parsing",
1212
"model",
13-
"sqllite"
13+
"sqllite",
14+
"DataDictionarySource",
1415
]

datastew/visualisation.py

Lines changed: 17 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ def get_cohort_specific_color_code(cohort_name: str):
3939
def enrichment_plot(acc_gpt, acc_mpnet, acc_fuzzy, title, save_plot=False, save_dir="resources/results/plots"):
4040
if not (len(acc_gpt) == len(acc_fuzzy) == len(acc_mpnet)):
4141
raise ValueError("acc_gpt, acc_mpnet and acc_fuzzy should be of the same length!")
42-
data = {"Maximum Considered Rank": list(range(1, len(acc_gpt) + 1)), "GPT": acc_gpt,
42+
data = {"Maximum Considered Rank": list(range(1, len(acc_gpt) + 1)), "GPT": acc_gpt,
4343
"MPNet": acc_mpnet, "Fuzzy": acc_fuzzy}
4444
df = pd.DataFrame(data)
4545
sns.set(style="whitegrid")
@@ -112,7 +112,7 @@ def bar_chart_average_acc_two_distributions(dist1_fuzzy: pd.DataFrame, dist1_gpt
112112

113113

114114
def scatter_plot_two_distributions(tables1: [MappingTable], tables2: [MappingTable], label1: str, label2: str,
115-
store_html: bool = True, legend_font_size: int = 16,
115+
store_html: bool = True, legend_font_size: int = 16,
116116
store_destination: str = "resources/results/plots/ad_vs_pd.html"):
117117
vectors_tables1 = np.concatenate([table.get_embeddings_numpy() for table in tables1])
118118
vectors_tables2 = np.concatenate([table.get_embeddings_numpy() for table in tables2])
@@ -157,26 +157,26 @@ def scatter_plot_all_cohorts(tables1: [MappingTable], tables2: [MappingTable], l
157157
boundaries = np.insert(boundaries, 0, 0)
158158
for idx in range(len(tables1)):
159159
if labels1[idx]:
160-
fig.add_trace(go.Scatter(x=tsne_result[boundaries[idx] : boundaries[idx + 1], 0],
161-
y=tsne_result[boundaries[idx] : boundaries[idx + 1], 1],
160+
fig.add_trace(go.Scatter(x=tsne_result[boundaries[idx]: boundaries[idx + 1], 0],
161+
y=tsne_result[boundaries[idx]: boundaries[idx + 1], 1],
162162
mode="markers", name=labels1[idx],
163-
text=descriptions[boundaries[idx] : boundaries[idx + 1]],
163+
text=descriptions[boundaries[idx]: boundaries[idx + 1]],
164164
# line=dict(color=get_cohort_specific_color_code(labels1[idx]))
165165
))
166166
for idy in range(len(tables1), len(boundaries) - 1):
167-
fig.add_trace(go.Scatter(x=tsne_result[boundaries[idy] : boundaries[idy + 1], 0],
168-
y=tsne_result[boundaries[idy] : boundaries[idy + 1], 1],
167+
fig.add_trace(go.Scatter(x=tsne_result[boundaries[idy]: boundaries[idy + 1], 0],
168+
y=tsne_result[boundaries[idy]: boundaries[idy + 1], 1],
169169
mode="markers",
170170
name=labels2[idy - len(tables1)],
171-
text=descriptions[boundaries[idy] : boundaries[idy + 1]],
171+
text=descriptions[boundaries[idy]: boundaries[idy + 1]],
172172
# line=dict(color=get_cohort_specific_color_code(labels2[idy - len(tables1)]))
173173
))
174174
if store_html:
175175
fig.write_html(store_base_dir + "/tsne_all_cohorts.html")
176176
fig.show()
177177

178178

179-
def get_html_plot_for_current_database_state(repository: BaseRepository, perplexity: int = 5) -> str:
179+
def get_plot_for_current_database_state(repository: BaseRepository, perplexity: int = 5, return_type="html") -> str:
180180
# get up to 1000 entries from db
181181
mappings = repository.get_all_mappings()
182182
# Extract embeddings
@@ -206,10 +206,12 @@ def get_html_plot_for_current_database_state(repository: BaseRepository, perplex
206206
yaxis=dict(title='t-SNE Component 2'),
207207
)
208208
fig = go.Figure(data=[scatter_plot], layout=layout)
209-
# Convert the Plotly figure to HTML
210-
html_plot = fig.to_html(full_html=False)
209+
if return_type == "html":
210+
plot = fig.to_html(full_html=False)
211+
elif return_type == "json":
212+
plot = fig.to_json()
213+
else:
214+
raise ValueError(f'Return type {return_type} is not viable. Use either "html" or "json".')
211215
else:
212-
html_plot = '<b>Too few database entries to visualize</b>'
213-
return html_plot
214-
215-
216+
plot = '<b>Too few database entries to visualize</b>'
217+
return plot

0 commit comments

Comments
 (0)