Skip to content

Update embed request with embedding_types #359

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 10 commits into from
Jan 10, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 8 additions & 12 deletions cohere/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@
from cohere.responses.dataset import BaseDataset, Dataset, DatasetUsage, ParseInfo
from cohere.responses.detectlang import DetectLanguageResponse, Language
from cohere.responses.embed_job import EmbedJob
from cohere.responses.embeddings import Embeddings
from cohere.responses.embeddings import EmbeddingResponses, Embeddings
from cohere.responses.feedback import (
GenerateFeedbackResponse,
GeneratePreferenceFeedbackResponse,
Expand Down Expand Up @@ -395,22 +395,19 @@ def embed(
texts: List[str],
model: Optional[str] = None,
truncate: Optional[str] = None,
compression: Optional[str] = None,
input_type: Optional[str] = None,
embedding_types: Optional[List[str]] = None,
) -> Embeddings:
"""Returns an Embeddings object for the provided texts. Visit https://cohere.ai/embed to learn about embeddings.

Args:
texts (List[str]): A list of strings to embed.
model (str): (Optional) The model ID to use for embedding the text.
truncate (str): (Optional) One of NONE|START|END, defaults to END. How the API handles text longer than the maximum token length.
compression (str): (Optional) One of "int8" or "binary". The type of compression to use for the embeddings.
input_type (str): (Optional) One of "classification", "clustering", "search_document", "search_query". The type of input text provided to embed.
embedding_types (List[str]): (Optional) Specifies the types of embeddings you want to get back. Not required and default is None, which returns the float embeddings in the response's embeddings field. Can be one or more of the following types: "float", "int8", "uint8", "binary", "ubinary".
"""
responses = {
"embeddings": [],
"compressed_embeddings": [],
}
embedding_responses = EmbeddingResponses()
json_bodys = []

for i in range(0, len(texts), self.batch_size):
Expand All @@ -420,20 +417,19 @@ def embed(
"model": model,
"texts": texts_batch,
"truncate": truncate,
"compression": compression,
"input_type": input_type,
"embedding_types": embedding_types,
}
)

meta = None
for result in self._executor.map(lambda json_body: self._request(cohere.EMBED_URL, json=json_body), json_bodys):
responses["embeddings"].extend(result["embeddings"])
responses["compressed_embeddings"].extend(result.get("compressed_embeddings", []))
embedding_responses.add_response(result)
meta = result["meta"] if not meta else meta

return Embeddings(
embeddings=responses["embeddings"],
compressed_embeddings=responses["compressed_embeddings"],
embeddings=embedding_responses.get_embeddings(),
response_type=embedding_responses.response_type,
meta=meta,
)

Expand Down
16 changes: 9 additions & 7 deletions cohere/client_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
Connector,
DetectLanguageResponse,
Detokenization,
EmbeddingResponses,
Embeddings,
GenerateFeedbackResponse,
GeneratePreferenceFeedbackResponse,
Expand Down Expand Up @@ -279,36 +280,37 @@ async def embed(
texts: List[str],
model: Optional[str] = None,
truncate: Optional[str] = None,
compression: Optional[str] = None,
input_type: Optional[str] = None,
embedding_types: Optional[List[str]] = None,
) -> Embeddings:
"""Returns an Embeddings object for the provided texts. Visit https://cohere.ai/embed to learn about embeddings.

Args:
texts (List[str]): A list of strings to embed.
model (str): (Optional) The model ID to use for embedding the text.
truncate (str): (Optional) One of NONE|START|END, defaults to END. How the API handles text longer than the maximum token length.
compression (str): (Optional) One of "int8" or "binary". The type of compression to use for the embeddings.
input_type (str): (Optional) One of "classification", "clustering", "search_document", "search_query". The type of input text provided to embed.
embedding_types (List[str]): (Optional) Specifies the types of embeddings you want to get back. Not required and default is None, which returns the float embeddings in the response's embeddings field. Can be one or more of the following types: "float", "int8", "uint8", "binary", "ubinary".
"""
json_bodys = [
dict(
texts=texts[i : i + cohere.COHERE_EMBED_BATCH_SIZE],
model=model,
truncate=truncate,
compression=compression,
input_type=input_type,
embedding_types=embedding_types,
)
for i in range(0, len(texts), cohere.COHERE_EMBED_BATCH_SIZE)
]
responses = await asyncio.gather(*[self._request(cohere.EMBED_URL, json) for json in json_bodys])
meta = responses[0]["meta"] if responses else None
embedding_responses = EmbeddingResponses()
for response in responses:
embedding_responses.add_response(response)

return Embeddings(
embeddings=[e for res in responses for e in res["embeddings"]],
compressed_embeddings=[e for res in responses for e in res["compressed_embeddings"]]
if compression
else None,
embeddings=embedding_responses.get_embeddings(),
response_type=embedding_responses.response_type,
meta=meta,
)

Expand Down
2 changes: 1 addition & 1 deletion cohere/responses/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from cohere.responses.connector import Connector, ConnectorOAuth, ConnectorServiceAuth
from cohere.responses.dataset import AsyncDataset, Dataset
from cohere.responses.detectlang import DetectLanguageResponse, Language
from cohere.responses.embeddings import Embeddings
from cohere.responses.embeddings import EmbeddingResponses, Embeddings
from cohere.responses.feedback import (
GenerateFeedbackResponse,
GeneratePreferenceFeedbackResponse,
Expand Down
52 changes: 48 additions & 4 deletions cohere/responses/embeddings.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,65 @@
from typing import Any, Dict, Iterator, List, Optional
from typing import Any, Dict, Iterator, List, Optional, Union

from cohere.responses.base import CohereObject

EMBEDDINGS_FLOATS_RESPONSE_TYPE = "embeddings_floats"
EMBEDDINGS_BY_TYPE_RESPONSE_TYPE = "embeddings_by_type"


class EmbeddingsByType(CohereObject):
    """Holds one optional list of embedding vectors per embedding type.

    Every constructor argument is stored unchanged on the attribute of the
    same name; a type that was not supplied stays ``None``.

    Args:
        float: (Optional) Float embedding vectors.
        int8: (Optional) ``int8`` embedding vectors.
        uint8: (Optional) ``uint8`` embedding vectors.
        binary: (Optional) ``binary`` embedding vectors.
        ubinary: (Optional) ``ubinary`` embedding vectors.
    """

    def __init__(
        self,
        float: Optional[List[List[float]]] = None,
        int8: Optional[List[List[int]]] = None,
        uint8: Optional[List[List[int]]] = None,
        binary: Optional[List[List[int]]] = None,
        ubinary: Optional[List[List[int]]] = None,
    ) -> None:
        # The parameter names are the keyword interface of this class, so
        # ``float`` deliberately shadows the builtin inside this method only.
        # Attributes are assigned in the same order as the parameters.
        self.float, self.int8, self.uint8 = float, int8, uint8
        self.binary, self.ubinary = binary, ubinary


class Embeddings(CohereObject):
def __init__(
self,
embeddings: List[List[float]],
compressed_embeddings: List[List[int]] = None,
embeddings: Union[List[List[float]], EmbeddingsByType],
response_type: str,
meta: Optional[Dict[str, Any]] = None,
) -> None:
self.response_type = response_type
self.embeddings = embeddings
self.compressed_embeddings = compressed_embeddings
self.meta = meta

def __iter__(self) -> Iterator:
return iter(self.embeddings)

def __len__(self) -> int:
return len(self.embeddings)


class EmbeddingResponses:
    """Accumulates the embeddings of several embed responses into one result.

    Responses are merged in the order they are passed to ``add_response``.
    ``response_type`` always reflects the most recently added response and is
    ``None`` until the first response has been added.
    """

    def __init__(
        self,
    ) -> None:
        self.response_type = None
        # Flat list of vectors, filled by "embeddings_floats" responses.
        self.embeddings_floats = []
        # Embedding type name -> list of vectors, filled by
        # "embeddings_by_type" responses.
        self.embeddings_by_type = {}

    def add_response(self, response: Dict[str, Any]) -> None:
        """Merge the embeddings of one response into the accumulated state.

        Args:
            response: A response mapping with a ``"response_type"`` key and an
                ``"embeddings"`` key. For the floats type ``"embeddings"`` is
                extended onto the flat list; for the by-type type it is a
                mapping whose values are extended per key.

        Raises:
            KeyError: If ``"response_type"`` is missing, or ``"embeddings"``
                is missing for one of the two known response types.
        """
        self.response_type = response["response_type"]
        if self.response_type == EMBEDDINGS_FLOATS_RESPONSE_TYPE:
            self.embeddings_floats.extend(response["embeddings"])
        elif self.response_type == EMBEDDINGS_BY_TYPE_RESPONSE_TYPE:
            for embedding_type, vectors in response["embeddings"].items():
                self.embeddings_by_type.setdefault(embedding_type, []).extend(vectors)

    def get_embeddings(self) -> Optional[Union[List[List[float]], "EmbeddingsByType"]]:
        """Return the accumulated embeddings in the shape of the last response type.

        Returns:
            The (empty) float list when no response has been added yet, the
            float list for the floats response type, an ``EmbeddingsByType``
            for the by-type response type, and ``None`` for any other type.
        """
        if self.response_type is None:
            # Nothing was added (e.g. an empty batch of texts): hand back the
            # empty list instead of None so len()/iter() on the result work.
            return self.embeddings_floats
        if self.response_type == EMBEDDINGS_FLOATS_RESPONSE_TYPE:
            return self.embeddings_floats
        if self.response_type == EMBEDDINGS_BY_TYPE_RESPONSE_TYPE:
            return EmbeddingsByType(**self.embeddings_by_type)
        # Unrecognized response type: nothing was accumulated for it.
        return None