From ce0c039a2665e9c39cc99b5f0c913650a255b7d8 Mon Sep 17 00:00:00 2001 From: TimAdams84 Date: Mon, 1 Jul 2024 17:45:55 +0200 Subject: [PATCH] feat: add sanitization --- datastew/embedding.py | 17 ++++++++++++----- tests/test_embedding.py | 9 +++++++++ 2 files changed, 21 insertions(+), 5 deletions(-) diff --git a/datastew/embedding.py b/datastew/embedding.py index 510ac0c..8a77ac6 100644 --- a/datastew/embedding.py +++ b/datastew/embedding.py @@ -12,6 +12,9 @@ def get_embedding(self, text: str) -> [float]: def get_embeddings(self, messages: [str]) -> [[float]]: pass + def sanitize(self, message: str) -> str: + return message.strip().lower() + class GPT4Adapter(EmbeddingModel): def __init__(self, api_key: str): @@ -27,18 +30,20 @@ def get_embedding(self, text: str, model="text-embedding-ada-002"): return None if isinstance(text, str): text = text.replace("\n", " ") + text = self.sanitize(text) return openai.Embedding.create(input=[text], model=model)["data"][0]["embedding"] except Exception as e: logging.error(f"Error getting embedding for {text}: {e}") return None def get_embeddings(self, messages: [str], model="text-embedding-ada-002", max_length=2048): + sanitized_messages = [self.sanitize(message) for message in messages] embeddings = [] - total_chunks = (len(messages) + max_length - 1) // max_length + total_chunks = (len(sanitized_messages) + max_length - 1) // max_length current_chunk = 0 - for i in range(0, len(messages), max_length): + for i in range(0, len(sanitized_messages), max_length): current_chunk += 1 - chunk = messages[i:i + max_length] + chunk = sanitized_messages[i:i + max_length] response = openai.Embedding.create(input=chunk, model=model) embeddings.extend([item["embedding"] for item in response["data"]]) logging.info("Processed chunk %d/%d", current_chunk, total_chunks) @@ -58,16 +63,18 @@ def get_embedding(self, text: str): return None if isinstance(text, str): text = text.replace("\n", " ") + text = self.sanitize(text) return 
self.mpnet_model.encode(text) except Exception as e: logging.error(f"Error getting embedding for {text}: {e}") return None def get_embeddings(self, messages: [str]) -> [[float]]: + sanitized_messages = [self.sanitize(message) for message in messages] try: - embeddings = self.mpnet_model.encode(messages) + embeddings = self.mpnet_model.encode(sanitized_messages) except Exception as e: - logging.error(f"Failed for messages {messages}") + logging.error(f"Failed for messages {sanitized_messages}: {e}") flattened_embeddings = [[float(element) for element in row] for row in embeddings] return flattened_embeddings diff --git a/tests/test_embedding.py b/tests/test_embedding.py index 019bb3f..976db2a 100644 --- a/tests/test_embedding.py +++ b/tests/test_embedding.py @@ -26,3 +26,12 @@ def test_text_embedding(self): text_embedding = TextEmbedding(text, embedding) self.assertEqual(text_embedding.text, text) self.assertEqual(text_embedding.embedding, embedding) + + def test_sanitization(self): + text1 = " Test" + text2 = "test " + embedding1 = self.mpnet_adapter.get_embedding(text1) + embedding2 = self.mpnet_adapter.get_embedding(text2) + self.assertListEqual(embedding1.tolist(), embedding2.tolist()) + +