
Commit

Add function, tests & documentation to visualize vector embeddings for data dictionaries
tiadams committed Oct 21, 2024
1 parent 5ca0f0a commit 7a5c047
Showing 5 changed files with 186 additions and 3 deletions.
14 changes: 14 additions & 0 deletions README.md
@@ -44,6 +44,20 @@ embedding_model = GPT4Adapter(key="your_api_key")
df = map_dictionary_to_dictionary(source, target, embedding_model=embedding_model)
```

You can also retrieve embeddings from data dictionaries and visualize them in the form of an interactive scatter plot to
explore semantic neighborhoods:

```python
from datastew.visualisation import plot_embeddings

# Get embedding vectors for your dictionaries
source_embeddings = source.get_embeddings()

# Plot embedding neighborhoods for several dictionaries
plot_embeddings(data_dictionaries=[source, target])
```
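
`plot_embeddings` also accepts an optional embedding model and a t-SNE perplexity value. A minimal sketch, assuming `GPT4Adapter` is importable from `datastew.embedding` (only `MPNetAdapter` is confirmed to live there) and that the perplexity stays below the number of dictionary entries:

```python
from datastew.embedding import GPT4Adapter  # import path assumed; MPNetAdapter is imported from here in datastew/visualisation.py
from datastew.visualisation import plot_embeddings

# Reuse a dedicated embedding model instead of the default MPNetAdapter
embedding_model = GPT4Adapter(key="your_api_key")

# Perplexity must stay smaller than the total number of dictionary entries
plot_embeddings(data_dictionaries=[source, target], embedding_model=embedding_model, perplexity=10)
```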

### Creating and using stored mappings

A simple example of how to initialize an in-memory database and compute a similarity mapping is shown in
63 changes: 61 additions & 2 deletions datastew/visualisation.py
@@ -8,8 +8,10 @@
import plotly.graph_objects as go
import seaborn as sns
from sklearn.manifold import TSNE
from typing import Optional
from typing import Optional, List

from datastew.process.parsing import DataDictionarySource
from datastew.embedding import EmbeddingModel, MPNetAdapter
from datastew.conf import COLORS_AD, COLORS_PD
from datastew.mapping import MappingTable
from datastew.repository.base import BaseRepository
@@ -177,7 +179,8 @@ def scatter_plot_all_cohorts(tables1: [MappingTable], tables2: [MappingTable], l
fig.show()


def get_plot_for_current_database_state(repository: BaseRepository, terminology: Optional[str] = None, perplexity: int = 5, return_type="html") -> str:
def get_plot_for_current_database_state(repository: BaseRepository, terminology: Optional[str] = None,
                                        perplexity: int = 5, return_type="html") -> str:
    if not terminology:
        mappings = repository.get_mappings()
    else:
@@ -218,3 +221,59 @@ def get_plot_for_current_database_state(repository: BaseRepository, terminology:
    else:
        plot = "<b>Too few database entries to visualize</b>"
    return plot


def plot_embeddings(data_dictionaries: List[DataDictionarySource], embedding_model: Optional[EmbeddingModel] = None,
                    perplexity: int = 5):
    """
    Plots a t-SNE representation of embeddings from multiple data dictionaries and displays the plot.
    :param data_dictionaries: A list of DataDictionarySource objects to extract embeddings from.
    :param embedding_model: The embedding model used to compute embeddings. Defaults to MPNetAdapter.
    :param perplexity: The perplexity for the t-SNE algorithm. Higher values give more global structure.
    """
    if embedding_model is None:
        embedding_model = MPNetAdapter()
    all_embeddings = []
    all_texts = []
    all_colors = []
    plotly_colors = px.colors.qualitative.Plotly
    for idx, dictionary in enumerate(data_dictionaries):
        embeddings_dict = dictionary.get_embeddings(embedding_model=embedding_model)
        embeddings = list(embeddings_dict.values())
        texts = dictionary.to_dataframe()['description']
        color = plotly_colors[idx % len(plotly_colors)]
        all_embeddings.extend(embeddings)
        all_texts.extend(texts)
        all_colors.extend([color] * len(embeddings))
    embeddings_array = np.array(all_embeddings)
    # Cap perplexity at 30 when there are more than 30 points
    if embeddings_array.shape[0] > 30:
        perplexity = min(perplexity, 30)
    if embeddings_array.shape[0] > perplexity:
        # Compute t-SNE embeddings
        tsne_embeddings = TSNE(n_components=2, perplexity=perplexity).fit_transform(embeddings_array)
        # Create Plotly scatter plot
        scatter_plot = go.Scatter(
            x=tsne_embeddings[:, 0],
            y=tsne_embeddings[:, 1],
            mode="markers",
            marker=dict(
                size=8,
                color=all_colors,  # Use the assigned colors from the Plotly palette
                opacity=0.7
            ),
            text=all_texts,
            hoverinfo="text"
        )
        layout = go.Layout(
            title="t-SNE Embeddings of Data Dictionaries",
            xaxis=dict(title="t-SNE Component 1"),
            yaxis=dict(title="t-SNE Component 2"),
        )
        fig = go.Figure(data=[scatter_plot], layout=layout)
        # Display the plot
        fig.show()
    else:
        print("Too few data dictionary entries to visualize. Adjust param 'perplexity' to a value less than the "
              "number of data points.")
51 changes: 51 additions & 0 deletions tests/resources/data_dict1.csv
@@ -0,0 +1,51 @@
VAR,DESC
VAR_1,VAR_1 amounts measured under controlled conditions.
VAR_2,VAR_2 data collected on a monthly basis.
VAR_3,Summarizes VAR_3 occurrences during the experiment.
VAR_4,Summarizes VAR_4 occurrences during the experiment.
VAR_5,This is the index value for VAR_5 activity.
VAR_6,Tracks fluctuations in VAR_6 over time.
VAR_7,This is the index value for VAR_7 activity.
VAR_8,It records the count of VAR_8 for each observed period.
VAR_9,VAR_9 data collected on a monthly basis.
VAR_10,Summarizes VAR_10 occurrences during the experiment.
VAR_11,VAR_11 amounts measured under controlled conditions.
VAR_12,It records the count of VAR_12 for each observed period.
VAR_13,Summarizes VAR_13 occurrences during the experiment.
VAR_14,Summarizes VAR_14 occurrences during the experiment.
VAR_15,VAR_15 data collected on a monthly basis.
VAR_16,VAR_16 amounts measured under controlled conditions.
VAR_17,It records the count of VAR_17 for each observed period.
VAR_18,VAR_18 data collected on a monthly basis.
VAR_19,Tracks fluctuations in VAR_19 over time.
VAR_20,An indicator representing the levels of VAR_20 in a sample.
VAR_21,Variable associated with the change in VAR_21 throughout the day.
VAR_22,Variable associated with the change in VAR_22 throughout the day.
VAR_23,This variable captures the daily average VAR_23 readings.
VAR_24,This is the index value for VAR_24 activity.
VAR_25,Tracks fluctuations in VAR_25 over time.
VAR_26,Tracks fluctuations in VAR_26 over time.
VAR_27,Variable associated with the change in VAR_27 throughout the day.
VAR_28,VAR_28 measurement gathered every quarter.
VAR_29,VAR_29 measurement gathered every quarter.
VAR_30,VAR_30 data collected on a monthly basis.
VAR_31,It records the count of VAR_31 for each observed period.
VAR_32,Summarizes VAR_32 occurrences during the experiment.
VAR_33,Variable associated with the change in VAR_33 throughout the day.
VAR_34,VAR_34 amounts measured under controlled conditions.
VAR_35,It records the count of VAR_35 for each observed period.
VAR_36,VAR_36 amounts measured under controlled conditions.
VAR_37,An indicator representing the levels of VAR_37 in a sample.
VAR_38,Summarizes VAR_38 occurrences during the experiment.
VAR_39,This variable captures the daily average VAR_39 readings.
VAR_40,Variable associated with the change in VAR_40 throughout the day.
VAR_41,VAR_41 amounts measured under controlled conditions.
VAR_42,Variable associated with the change in VAR_42 throughout the day.
VAR_43,Variable associated with the change in VAR_43 throughout the day.
VAR_44,This variable captures the daily average VAR_44 readings.
VAR_45,Summarizes VAR_45 occurrences during the experiment.
VAR_46,Variable associated with the change in VAR_46 throughout the day.
VAR_47,Variable associated with the change in VAR_47 throughout the day.
VAR_48,An indicator representing the levels of VAR_48 in a sample.
VAR_49,VAR_49 amounts measured under controlled conditions.
VAR_50,VAR_50 amounts measured under controlled conditions.
51 changes: 51 additions & 0 deletions tests/resources/data_dict2.csv
@@ -0,0 +1,51 @@
VAR,DESC
VAR_1,Measurement of VAR_1 on a daily timescale.
VAR_2,Measurement of VAR_2 on a daily timescale.
VAR_3,Monthly statistics for VAR_3 collection.
VAR_4,Variable linked to fluctuations in VAR_4 across different times.
VAR_5,Variable linked to fluctuations in VAR_5 across different times.
VAR_6,Captures the total number of VAR_6 per session.
VAR_7,Tracks variations in VAR_7 over multiple intervals.
VAR_8,Measurement of VAR_8 on a daily timescale.
VAR_9,Captures the total number of VAR_9 per session.
VAR_10,Measurement of VAR_10 on a daily timescale.
VAR_11,Monthly statistics for VAR_11 collection.
VAR_12,Captures the total number of VAR_12 per session.
VAR_13,The count of VAR_13 instances recorded throughout the study.
VAR_14,Quarterly data focused on VAR_14 trends.
VAR_15,Monthly statistics for VAR_15 collection.
VAR_16,VAR_16 index reflecting seasonal variation.
VAR_17,Measurement of VAR_17 on a daily timescale.
VAR_18,Variable linked to fluctuations in VAR_18 across different times.
VAR_19,Captures the total number of VAR_19 per session.
VAR_20,The count of VAR_20 instances recorded throughout the study.
VAR_21,Captures the total number of VAR_21 per session.
VAR_22,An estimate of the concentration of VAR_22 in specific regions.
VAR_23,Captures the total number of VAR_23 per session.
VAR_24,The count of VAR_24 instances recorded throughout the study.
VAR_25,Quarterly data focused on VAR_25 trends.
VAR_26,VAR_26 index reflecting seasonal variation.
VAR_27,Variable linked to fluctuations in VAR_27 across different times.
VAR_28,Variable linked to fluctuations in VAR_28 across different times.
VAR_29,Tracks variations in VAR_29 over multiple intervals.
VAR_30,VAR_30 values measured under precise situations.
VAR_31,Tracks variations in VAR_31 over multiple intervals.
VAR_32,Monthly statistics for VAR_32 collection.
VAR_33,VAR_33 values measured under precise situations.
VAR_34,Tracks variations in VAR_34 over multiple intervals.
VAR_35,VAR_35 values measured under precise situations.
VAR_36,Tracks variations in VAR_36 over multiple intervals.
VAR_37,An estimate of the concentration of VAR_37 in specific regions.
VAR_38,VAR_38 index reflecting seasonal variation.
VAR_39,VAR_39 index reflecting seasonal variation.
VAR_40,Variable linked to fluctuations in VAR_40 across different times.
VAR_41,Quarterly data focused on VAR_41 trends.
VAR_42,VAR_42 values measured under precise situations.
VAR_43,An estimate of the concentration of VAR_43 in specific regions.
VAR_44,Variable linked to fluctuations in VAR_44 across different times.
VAR_45,The count of VAR_45 instances recorded throughout the study.
VAR_46,Monthly statistics for VAR_46 collection.
VAR_47,Monthly statistics for VAR_47 collection.
VAR_48,An estimate of the concentration of VAR_48 in specific regions.
VAR_49,Tracks variations in VAR_49 over multiple intervals.
VAR_50,Variable linked to fluctuations in VAR_50 across different times.
10 changes: 9 additions & 1 deletion tests/test_visualisation.py
@@ -10,7 +10,7 @@
bar_chart_average_acc_two_distributions,
enrichment_plot,
scatter_plot_all_cohorts,
scatter_plot_two_distributions,
scatter_plot_two_distributions, plot_embeddings,
)


@@ -127,3 +127,11 @@ def test_bar_chart_average_acc_two_distributions(self):
mpnet_2 = pd.DataFrame({"M1": [0.9, 0.78, 0.68], "M2": [0.69, 0.9, 0.68], "M3": [0.71, 0.75, 0.9]}, index=labels).T

bar_chart_average_acc_two_distributions(fuzzy_1, gpt_1, mpnet_1, fuzzy_2, gpt_2, mpnet_2, "title", "AD", "PD")

    def test_plot_data_dict(self):
        TEST_DIR_PATH = os.path.dirname(os.path.realpath(__file__))
        data_dictionary_source_1 = DataDictionarySource(os.path.join(TEST_DIR_PATH, "resources", "data_dict1.csv"),
                                                        "VAR", "DESC")
        data_dictionary_source_2 = DataDictionarySource(os.path.join(TEST_DIR_PATH, "resources", "data_dict2.csv"),
                                                        "VAR", "DESC")
        plot_embeddings([data_dictionary_source_1, data_dictionary_source_2])
