Skip to content

Commit 5ca0f0a

Browse files
authored
Merge pull request #35 from SCAI-BIO/add-embedding-function-to-data-dict
Add function to directly retrieve embedding from data dictionary
2 parents 2d6bcb3 + 1cba9ef commit 5ca0f0a

File tree

2 files changed

+58
-10
lines changed

2 files changed

+58
-10
lines changed

datastew/process/parsing.py

Lines changed: 53 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
11
from abc import ABC
2+
from typing import Dict
3+
from datastew.embedding import EmbeddingModel, MPNetAdapter
24

35
import pandas as pd
46
import numpy as np
@@ -52,16 +54,34 @@ def to_dataframe(self) -> pd.DataFrame:
5254

5355

5456
class DataDictionarySource(Source):
55-
"""
56-
Contains mapping of variable -> description
57-
"""
5857

5958
def __init__(self, file_path: str, variable_field: str, description_field: str):
60-
self.file_path = file_path
61-
self.variable_field = variable_field
62-
self.description_field = description_field
63-
64-
def to_dataframe(self) -> pd.DataFrame:
59+
"""
60+
Initialize the DataDictionarySource with the path to the data dictionary file
61+
and the fields that represent the variables and their descriptions.
62+
63+
:param file_path: Path to the data dictionary file.
64+
:param variable_field: The column that contains the variable names.
65+
:param description_field: The column that contains the variable descriptions.
66+
"""
67+
self.file_path: str = file_path
68+
self.variable_field: str = variable_field
69+
self.description_field: str = description_field
70+
71+
def to_dataframe(self, dropna: bool = True) -> pd.DataFrame:
72+
"""
73+
Load the data dictionary file into a pandas DataFrame, select the variable and
74+
description fields, and ensure they exist. Optionally remove rows with missing
75+
variables or descriptions based on the 'dropna' parameter.
76+
77+
:param dropna: If True, rows with missing 'variable' or 'description' values are
78+
dropped. Defaults to True.
79+
:return: A DataFrame containing two columns:
80+
- 'variable': The variable names from the data dictionary.
81+
- 'description': The descriptions corresponding to each variable.
82+
:raises ValueError: If either the variable field or the description field is not
83+
found in the data dictionary file.
84+
"""
6585
df = super().to_dataframe()
6686
# sanity check
6787
if self.variable_field not in df.columns:
@@ -70,9 +90,32 @@ def to_dataframe(self) -> pd.DataFrame:
7090
raise ValueError(f"Description field {self.description_field} not found in {self.file_path}")
7191
df = df[[self.variable_field, self.description_field]]
7292
df = df.rename(columns={self.variable_field: "variable", self.description_field: "description"})
73-
df.dropna(subset=["variable", "description"], inplace=True)
93+
if dropna:
94+
df.dropna(subset=["variable", "description"], inplace=True)
7495
return df
75-
96+
97+
def get_embeddings(self, embedding_model: EmbeddingModel = None) -> Dict[str, list]:
98+
"""
99+
Compute embedding vectors for each description in the data dictionary. The
100+
resulting vectors are mapped to their respective variables and returned as a
101+
dictionary.
102+
103+
:param embedding_model: The embedding model used to compute embeddings for the descriptions.
104+
If None, a new MPNetAdapter instance is created and used.
105+
:return: A dictionary where each key is a variable name and the value is the
106+
embedding vector for the corresponding description.
107+
:rtype: Dict[str, list]
108+
"""
109+
# Compute vectors for all descriptions
110+
df: pd.DataFrame = self.to_dataframe()
111+
descriptions: list[str] = df["description"].tolist()
112+
if embedding_model is None:
113+
embedding_model = MPNetAdapter()
114+
embeddings = embedding_model.get_embeddings(descriptions)
115+
# variables identify descriptions -> map each variable to its embedding vector
116+
variable_to_embedding: Dict[str, list] = dict(zip(df["variable"], embeddings))
117+
return variable_to_embedding
118+
76119

77120
class EmbeddingSource:
78121
def __init__(self, source_path: str):

tests/test_parser.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,3 +39,8 @@ def test_parse_data_dict_excel(self):
3939
mapping_table.add_descriptions(data_dictionary_source)
4040
mappings = mapping_table.get_mappings()
4141
self.assertEqual(11, len(mappings))
42+
43+
def test_get_embeddings(self):
44+
vectors = self.data_dictionary_source.get_embeddings()
45+
self.assertEqual(len(vectors), 11)
46+
self.assertIn("Q_8", vectors)

0 commit comments

Comments
 (0)