Skip to content

Commit

Permalink
feat: add sanitization
Browse files Browse the repository at this point in the history
  • Loading branch information
tiadams committed Jul 1, 2024
1 parent 888697e commit ce0c039
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 5 deletions.
17 changes: 12 additions & 5 deletions datastew/embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,9 @@ def get_embedding(self, text: str) -> [float]:
def get_embeddings(self, messages: [str]) -> [[float]]:
pass

def sanitize(self, message: str) -> str:
    """Normalize a message before it is embedded.

    Leading and trailing whitespace is removed and the remaining text is
    lower-cased, so inputs that differ only in case or surrounding
    whitespace map to the same string.

    :param message: raw text to normalize.
    :return: the trimmed, lower-cased text.
    """
    trimmed = message.strip()
    return trimmed.lower()


class GPT4Adapter(EmbeddingModel):
def __init__(self, api_key: str):
Expand All @@ -27,18 +30,20 @@ def get_embedding(self, text: str, model="text-embedding-ada-002"):
return None
if isinstance(text, str):
text = text.replace("\n", " ")
text = self.sanitize(text)
return openai.Embedding.create(input=[text], model=model)["data"][0]["embedding"]
except Exception as e:
logging.error(f"Error getting embedding for {text}: {e}")
return None

def get_embeddings(self, messages: [str], model="text-embedding-ada-002", max_length=2048):
sanitized_messages = [self.sanitize(message) for message in messages]
embeddings = []
total_chunks = (len(messages) + max_length - 1) // max_length
total_chunks = (len(sanitized_messages) + max_length - 1) // max_length
current_chunk = 0
for i in range(0, len(messages), max_length):
for i in range(0, len(sanitized_messages), max_length):
current_chunk += 1
chunk = messages[i:i + max_length]
chunk = sanitized_messages[i:i + max_length]
response = openai.Embedding.create(input=chunk, model=model)
embeddings.extend([item["embedding"] for item in response["data"]])
logging.info("Processed chunk %d/%d", current_chunk, total_chunks)
Expand All @@ -58,16 +63,18 @@ def get_embedding(self, text: str):
return None
if isinstance(text, str):
text = text.replace("\n", " ")
text = self.sanitize(text)
return self.mpnet_model.encode(text)
except Exception as e:
logging.error(f"Error getting embedding for {text}: {e}")
return None

def get_embeddings(self, messages: [str]) -> [[float]]:
    """Embed a batch of messages with the MPNet model.

    Every message is normalized with ``self.sanitize`` first, then the
    whole batch is handed to ``self.mpnet_model.encode`` in one call.

    :param messages: texts to embed.
    :return: one list of Python floats per message, in input order; an
        empty list if the model fails to encode the batch.
    """
    sanitized_messages = [self.sanitize(message) for message in messages]
    try:
        embeddings = self.mpnet_model.encode(sanitized_messages)
    except Exception as e:
        # Return early: falling through would hit the conversion below with
        # `embeddings` unbound and mask the real error with an
        # UnboundLocalError.
        logging.error(f"Failed for messages {sanitized_messages}: {e}")
        return []
    # Convert each row element-wise to plain Python floats.
    return [[float(element) for element in row] for row in embeddings]

Expand Down
9 changes: 9 additions & 0 deletions tests/test_embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,3 +26,12 @@ def test_text_embedding(self):
text_embedding = TextEmbedding(text, embedding)
self.assertEqual(text_embedding.text, text)
self.assertEqual(text_embedding.embedding, embedding)

def test_sanitization(self):
    """Inputs differing only in case and surrounding whitespace embed identically."""
    variants = [" Test", "test "]
    # Embed both variants first, then compare them as plain lists.
    vectors = [self.mpnet_adapter.get_embedding(variant) for variant in variants]
    self.assertListEqual(vectors[0].tolist(), vectors[1].tolist())


0 comments on commit ce0c039

Please sign in to comment.