Skip to content

Commit b7cf7b5

Browse files
committed
embedding types response update
1 parent 09fb178 commit b7cf7b5

File tree

4 files changed

+60
-12
lines changed

4 files changed

+60
-12
lines changed

cohere/client.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@
4747
from cohere.responses.dataset import BaseDataset, Dataset, DatasetUsage, ParseInfo
4848
from cohere.responses.detectlang import DetectLanguageResponse, Language
4949
from cohere.responses.embed_job import EmbedJob
50-
from cohere.responses.embeddings import Embeddings
50+
from cohere.responses.embeddings import EmbeddingResponses, Embeddings
5151
from cohere.responses.feedback import (
5252
GenerateFeedbackResponse,
5353
GeneratePreferenceFeedbackResponse,
@@ -407,9 +407,7 @@ def embed(
407407
input_type (str): (Optional) One of "classification", "clustering", "search_document", "search_query". The type of input text provided to embed.
408408
embedding_types (List[str]): (Optional) Specifies the types of embeddings you want to get back. Not required and default is None, which returns the float embeddings in the response's embeddings field. Can be one or more of the following types: "float", "int8", "uint8", "binary", "ubinary".
409409
"""
410-
responses = {
411-
"embeddings": [],
412-
}
410+
embedding_responses = EmbeddingResponses()
413411
json_bodys = []
414412

415413
for i in range(0, len(texts), self.batch_size):
@@ -426,11 +424,12 @@ def embed(
426424

427425
meta = None
428426
for result in self._executor.map(lambda json_body: self._request(cohere.EMBED_URL, json=json_body), json_bodys):
429-
responses["embeddings"].extend(result["embeddings"])
427+
embedding_responses.add_response(result)
430428
meta = result["meta"] if not meta else meta
431429

432430
return Embeddings(
433-
embeddings=responses["embeddings"],
431+
embeddings=embedding_responses.get_embeddings(),
432+
response_type=embedding_responses.response_type,
434433
meta=meta,
435434
)
436435

cohere/client_async.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
Connector,
3131
DetectLanguageResponse,
3232
Detokenization,
33+
EmbeddingResponses,
3334
Embeddings,
3435
GenerateFeedbackResponse,
3536
GeneratePreferenceFeedbackResponse,
@@ -303,9 +304,13 @@ async def embed(
303304
]
304305
responses = await asyncio.gather(*[self._request(cohere.EMBED_URL, json) for json in json_bodys])
305306
meta = responses[0]["meta"] if responses else None
307+
embedding_responses = EmbeddingResponses()
308+
for response in responses:
309+
embedding_responses.add_response(response)
306310

307311
return Embeddings(
308-
embeddings=[e for res in responses for e in res["embeddings"]],
312+
embeddings=embedding_responses.get_embeddings(),
313+
response_type=embedding_responses.response_type,
309314
meta=meta,
310315
)
311316

cohere/responses/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from cohere.responses.connector import Connector, ConnectorOAuth, ConnectorServiceAuth
66
from cohere.responses.dataset import AsyncDataset, Dataset
77
from cohere.responses.detectlang import DetectLanguageResponse, Language
8-
from cohere.responses.embeddings import Embeddings
8+
from cohere.responses.embeddings import EmbeddingResponses, Embeddings
99
from cohere.responses.feedback import (
1010
GenerateFeedbackResponse,
1111
GeneratePreferenceFeedbackResponse,

cohere/responses/embeddings.py

Lines changed: 48 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,65 @@
1-
from typing import Any, Dict, Iterator, List, Optional
1+
from typing import Any, Dict, Iterator, List, Optional, Union
22

33
from cohere.responses.base import CohereObject
44

5+
EMBEDDINGS_FLOATS_RESPONSE_TYPE = "embeddings_floats"
6+
EMBEDDINGS_BY_TYPE_RESPONSE_TYPE = "embeddings_by_type"
7+
8+
9+
class EmbeddingsByType(CohereObject):
    """Container holding one optional list of embedding vectors per embedding type.

    Each attribute is named after the embedding type it stores ("float",
    "int8", "uint8", "binary", "ubinary") and stays ``None`` when that type
    was not supplied to the constructor.
    """

    def __init__(
        self,
        float: Optional[List[List[float]]] = None,
        int8: Optional[List[List[int]]] = None,
        uint8: Optional[List[List[int]]] = None,
        binary: Optional[List[List[int]]] = None,
        ubinary: Optional[List[List[int]]] = None,
    ) -> None:
        # The parameter names intentionally match the embedding type names so
        # that a mapping keyed by type can be expanded with ``**`` by callers.
        self.float, self.int8, self.uint8 = float, int8, uint8
        self.binary, self.ubinary = binary, ubinary
23+
524

625
class Embeddings(CohereObject):
726
def __init__(
827
self,
9-
embeddings: List[List[float]],
10-
compressed_embeddings: List[List[int]] = None,
28+
embeddings: Union[List[List[float]], EmbeddingsByType],
29+
response_type: str,
1130
meta: Optional[Dict[str, Any]] = None,
1231
) -> None:
32+
self.response_type = response_type
1333
self.embeddings = embeddings
14-
self.compressed_embeddings = compressed_embeddings
1534
self.meta = meta
1635

1736
def __iter__(self) -> Iterator:
1837
return iter(self.embeddings)
1938

2039
def __len__(self) -> int:
2140
return len(self.embeddings)
41+
42+
43+
class EmbeddingResponses:
    """Accumulate the embeddings of several embed responses into one result.

    Responses are merged in the order they are added. Two payload layouts are
    handled: a plain list of float vectors (``"embeddings_floats"``) and a
    mapping from embedding type name to a list of vectors
    (``"embeddings_by_type"``).
    """

    def __init__(
        self,
    ) -> None:
        # Type of the most recently added response; None until one is added.
        self.response_type: Optional[str] = None
        self.embeddings_floats: List[List[float]] = []
        self.embeddings_by_type: Dict[str, List[List[Any]]] = {}

    def add_response(self, response: Dict[str, Any]) -> None:
        """Merge the embeddings of one response into the accumulated state.

        Args:
            response: Decoded response body. Its ``"embeddings"`` entry is a
                list of vectors or a mapping of type name to list of vectors.
                ``"response_type"`` selects the layout; when it is absent the
                layout is inferred from the shape of ``"embeddings"``.

        A response whose type is neither of the two known values is recorded
        in ``response_type`` but contributes no embeddings.
        """
        response_type = response.get("response_type")
        if response_type is None:
            # Fall back to the payload shape instead of raising KeyError for
            # responses that do not carry an explicit type.
            if isinstance(response.get("embeddings"), dict):
                response_type = EMBEDDINGS_BY_TYPE_RESPONSE_TYPE
            else:
                response_type = EMBEDDINGS_FLOATS_RESPONSE_TYPE
        self.response_type = response_type

        if response_type == EMBEDDINGS_FLOATS_RESPONSE_TYPE:
            self.embeddings_floats.extend(response["embeddings"])
        elif response_type == EMBEDDINGS_BY_TYPE_RESPONSE_TYPE:
            for type_name, vectors in response["embeddings"].items():
                self.embeddings_by_type.setdefault(type_name, []).extend(vectors)

    def get_embeddings(self) -> Union[List[List[float]], "EmbeddingsByType"]:
        """Return everything accumulated so far.

        Returns:
            An ``EmbeddingsByType`` when the last response was of the by-type
            layout; otherwise the list of float vectors. The list is empty
            (never ``None``) when no response has been added, so callers can
            safely take its length or iterate over it.
        """
        if self.response_type is None:
            # Nothing was added (e.g. no batches were sent): empty result.
            return self.embeddings_floats
        if self.response_type == EMBEDDINGS_BY_TYPE_RESPONSE_TYPE:
            return EmbeddingsByType(**self.embeddings_by_type)
        return self.embeddings_floats

0 commit comments

Comments
 (0)