Skip to content

Commit acddc01

Browse files
committed
Revert back to b1dbf95 due to next commit breaking staging
1 parent 61f6f71 commit acddc01

File tree

10 files changed

+58
-78
lines changed

10 files changed

+58
-78
lines changed

pyproject.toml

Lines changed: 1 addition & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[tool.poetry]
22
name = "cohere"
3-
version = "5.13.2"
3+
version = "5.12.0"
44
description = ""
55
readme = "README.md"
66
authors = []
@@ -36,15 +36,6 @@ boto3 = { version="^1.34.0", optional = true}
3636
fastavro = "^1.9.4"
3737
httpx = ">=0.21.2"
3838
httpx-sse = "0.4.0"
39-
# Without specifying a version, numpy and pandas will be 1.24.4 and 2.0.3, respectively
40-
numpy = [
41-
{ version="~1.24.4", python = "<3.12", optional = true },
42-
{ version="~1.26", python = ">=3.12", optional = true }
43-
]
44-
pandas = [
45-
{ version="~2.0.3", python = "<3.13", optional = true },
46-
{ version="~2.2.3", python = ">=3.13", optional = true }
47-
]
4839
parameterized = "^0.9.0"
4940
pydantic = ">= 1.9.2"
5041
pydantic-core = "^2.18.2"

reference.md

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1793,7 +1793,7 @@ We recommend a maximum of 1,000 documents for optimal endpoint performance.
17931793
<dl>
17941794
<dd>
17951795

1796-
**model:** `typing.Optional[str]` — The identifier of the model to use, e.g. `rerank-v3.5`.
1796+
**model:** `typing.Optional[str]` — The identifier of the model to use, one of: `rerank-english-v3.0`, `rerank-multilingual-v3.0`, `rerank-english-v2.0`, `rerank-multilingual-v2.0`
17971797

17981798
</dd>
17991799
</dl>
@@ -2382,6 +2382,7 @@ response = client.v2.chat_stream(
23822382
p=1.1,
23832383
return_prompt=True,
23842384
logprobs=True,
2385+
stream=True,
23852386
)
23862387
for chunk in response:
23872388
yield chunk
@@ -2654,6 +2655,7 @@ client.v2.chat(
26542655
content="messages",
26552656
)
26562657
],
2658+
stream=False,
26572659
)
26582660

26592661
```
@@ -2971,7 +2973,7 @@ Available models and corresponding embedding dimensions:
29712973

29722974
**embedding_types:** `typing.Sequence[EmbeddingType]`
29732975

2974-
Specifies the types of embeddings you want to get back. Can be one or more of the following types.
2976+
Specifies the types of embeddings you want to get back. Optional; defaults to None, which returns the Embed Floats response type. Can be one or more of the following types.
29752977

29762978
* `"float"`: Use this when you want to get back the default float embeddings. Valid for all models.
29772979
* `"int8"`: Use this when you want to get back signed int8 embeddings. Valid for only v3 models.
@@ -3084,7 +3086,15 @@ client.v2.rerank(
30843086
<dl>
30853087
<dd>
30863088

3087-
**model:** `str` — The identifier of the model to use, e.g. `rerank-v3.5`.
3089+
**model:** `str`
3090+
3091+
The identifier of the model to use.
3092+
3093+
Supported models:
3094+
- `rerank-english-v3.0`
3095+
- `rerank-multilingual-v3.0`
3096+
- `rerank-english-v2.0`
3097+
- `rerank-multilingual-v2.0`
30883098

30893099
</dd>
30903100
</dl>

src/cohere/__init__.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -245,7 +245,7 @@
245245
)
246246
from . import connectors, datasets, embed_jobs, finetuning, models, v2
247247
from .aws_client import AwsClient
248-
from .bedrock_client import BedrockClient, BedrockClientV2
248+
from .bedrock_client import BedrockClient
249249
from .client import AsyncClient, Client
250250
from .client_v2 import AsyncClientV2, ClientV2
251251
from .datasets import (
@@ -257,7 +257,7 @@
257257
)
258258
from .embed_jobs import CreateEmbedJobRequestTruncate
259259
from .environment import ClientEnvironment
260-
from .sagemaker_client import SagemakerClient, SagemakerClientV2
260+
from .sagemaker_client import SagemakerClient
261261
from .v2 import (
262262
V2ChatRequestDocumentsItem,
263263
V2ChatRequestSafetyMode,
@@ -287,7 +287,6 @@
287287
"AwsClient",
288288
"BadRequestError",
289289
"BedrockClient",
290-
"BedrockClientV2",
291290
"ChatCitation",
292291
"ChatCitationGenerationEvent",
293292
"ChatConnector",
@@ -461,7 +460,6 @@
461460
"ResponseFormat",
462461
"ResponseFormatV2",
463462
"SagemakerClient",
464-
"SagemakerClientV2",
465463
"SearchQueriesGenerationStreamedChatResponse",
466464
"SearchResultsStreamedChatResponse",
467465
"ServiceUnavailableError",

src/cohere/aws_client.py

Lines changed: 17 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ def __init__(
5656
timeout: typing.Optional[float] = None,
5757
service: typing.Union[typing.Literal["bedrock"], typing.Literal["sagemaker"]],
5858
):
59-
ClientV2.__init__(
59+
Client.__init__(
6060
self,
6161
base_url="https://api.cohere.com", # this url is unused for BedrockClient
6262
environment=ClientEnvironment.PRODUCTION,
@@ -196,13 +196,6 @@ def _hook(
196196

197197
return _hook
198198

199-
def get_boto3_session(
200-
**kwargs: typing.Any,
201-
):
202-
non_none_args = {k: v for k, v in kwargs.items() if v is not None}
203-
return lazy_boto3().Session(**non_none_args)
204-
205-
206199

207200
def map_request_to_bedrock(
208201
service: str,
@@ -211,22 +204,19 @@ def map_request_to_bedrock(
211204
aws_session_token: typing.Optional[str] = None,
212205
aws_region: typing.Optional[str] = None,
213206
) -> EventHook:
214-
session = get_boto3_session(
207+
session = lazy_boto3().Session(
215208
region_name=aws_region,
216209
aws_access_key_id=aws_access_key,
217210
aws_secret_access_key=aws_secret_key,
218211
aws_session_token=aws_session_token,
219212
)
220-
aws_region = session.region_name
221213
credentials = session.get_credentials()
222-
signer = lazy_botocore().auth.SigV4Auth(credentials, service, aws_region)
214+
signer = lazy_botocore().auth.SigV4Auth(credentials, service, session.region_name)
223215

224216
def _event_hook(request: httpx.Request) -> None:
225217
headers = request.headers.copy()
226218
del headers["connection"]
227219

228-
229-
api_version = request.url.path.split("/")[-2]
230220
endpoint = request.url.path.split("/")[-1]
231221
body = json.loads(request.read())
232222
model = body["model"]
@@ -240,9 +230,6 @@ def _event_hook(request: httpx.Request) -> None:
240230
request.url = URL(url)
241231
request.headers["host"] = request.url.host
242232

243-
if endpoint == "rerank":
244-
body["api_version"] = get_api_version(version=api_version)
245-
246233
if "stream" in body:
247234
del body["stream"]
248235

@@ -268,6 +255,20 @@ def _event_hook(request: httpx.Request) -> None:
268255
return _event_hook
269256

270257

258+
def get_endpoint_from_url(url: str,
259+
chat_model: typing.Optional[str] = None,
260+
embed_model: typing.Optional[str] = None,
261+
generate_model: typing.Optional[str] = None,
262+
) -> str:
263+
if chat_model and chat_model in url:
264+
return "chat"
265+
if embed_model and embed_model in url:
266+
return "embed"
267+
if generate_model and generate_model in url:
268+
return "generate"
269+
raise ValueError(f"Unknown endpoint in url: {url}")
270+
271+
271272
def get_url(
272273
*,
273274
platform: str,
@@ -282,12 +283,3 @@ def get_url(
282283
endpoint = "invocations" if not stream else "invocations-response-stream"
283284
return f"https://runtime.sagemaker.{aws_region}.amazonaws.com/endpoints/{model}/{endpoint}"
284285
return ""
285-
286-
287-
def get_api_version(*, version: str):
288-
int_version = {
289-
"v1": 1,
290-
"v2": 2,
291-
}
292-
293-
return int_version.get(version, 1)

src/cohere/base_client.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2012,7 +2012,7 @@ def rerank(
20122012
We recommend a maximum of 1,000 documents for optimal endpoint performance.
20132013
20142014
model : typing.Optional[str]
2015-
The identifier of the model to use, e.g. `rerank-v3.5`.
2015+
The identifier of the model to use, one of: `rerank-english-v3.0`, `rerank-multilingual-v3.0`, `rerank-english-v2.0`, `rerank-multilingual-v2.0`
20162016
20172017
top_n : typing.Optional[int]
20182018
The number of most relevant documents or indices to return, defaults to the length of the documents
@@ -5047,7 +5047,7 @@ async def rerank(
50475047
We recommend a maximum of 1,000 documents for optimal endpoint performance.
50485048
50495049
model : typing.Optional[str]
5050-
The identifier of the model to use, e.g. `rerank-v3.5`.
5050+
The identifier of the model to use, one of: `rerank-english-v3.0`, `rerank-multilingual-v3.0`, `rerank-english-v2.0`, `rerank-multilingual-v2.0`
50515051
50525052
top_n : typing.Optional[int]
50535053
The number of most relevant documents or indices to return, defaults to the length of the documents

src/cohere/bedrock_client.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,8 +25,6 @@ def __init__(
2525
timeout=timeout,
2626
)
2727

28-
def rerank(self, *, query, documents, model = ..., top_n = ..., rank_fields = ..., return_documents = ..., max_chunks_per_doc = ..., request_options = None):
29-
raise NotImplementedError("Please use cohere.BedrockClientV2 instead: Rerank API on Bedrock is not supported with cohere.BedrockClient for this model.")
3028

3129
class BedrockClientV2(AwsClientV2):
3230
def __init__(

src/cohere/core/client_wrapper.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ def get_headers(self) -> typing.Dict[str, str]:
2424
headers: typing.Dict[str, str] = {
2525
"X-Fern-Language": "Python",
2626
"X-Fern-SDK-Name": "cohere",
27-
"X-Fern-SDK-Version": "5.13.2",
27+
"X-Fern-SDK-Version": "5.12.0",
2828
}
2929
if self._client_name is not None:
3030
headers["X-Client-Name"] = self._client_name

src/cohere/sagemaker_client.py

Lines changed: 2 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -26,10 +26,7 @@ def __init__(
2626
aws_region=aws_region,
2727
timeout=timeout,
2828
)
29-
try:
30-
self.sagemaker_finetuning = Client(aws_region=aws_region)
31-
except Exception:
32-
pass
29+
self.sagemaker_finetuning = Client(aws_region=aws_region)
3330

3431

3532
class SagemakerClientV2(AwsClientV2):
@@ -53,7 +50,4 @@ def __init__(
5350
aws_region=aws_region,
5451
timeout=timeout,
5552
)
56-
try:
57-
self.sagemaker_finetuning = Client(aws_region=aws_region)
58-
except Exception:
59-
pass
53+
self.sagemaker_finetuning = Client(aws_region=aws_region)

src/cohere/v2/client.py

Lines changed: 20 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -222,6 +222,7 @@ def chat_stream(
222222
p=1.1,
223223
return_prompt=True,
224224
logprobs=True,
225+
stream=True,
225226
)
226227
for chunk in response:
227228
yield chunk
@@ -534,6 +535,7 @@ def chat(
534535
content="messages",
535536
)
536537
],
538+
stream=False,
537539
)
538540
"""
539541
_response = self._client_wrapper.httpx_client.request(
@@ -568,7 +570,6 @@ def chat(
568570
"p": p,
569571
"return_prompt": return_prompt,
570572
"logprobs": logprobs,
571-
"stream": False,
572573
},
573574
request_options=request_options,
574575
omit=OMIT,
@@ -736,7 +737,7 @@ def embed(
736737
input_type : EmbedInputType
737738
738739
embedding_types : typing.Sequence[EmbeddingType]
739-
Specifies the types of embeddings you want to get back. Can be one or more of the following types.
740+
Specifies the types of embeddings you want to get back. Optional; defaults to None, which returns the Embed Floats response type. Can be one or more of the following types.
740741
741742
* `"float"`: Use this when you want to get back the default float embeddings. Valid for all models.
742743
* `"int8"`: Use this when you want to get back signed int8 embeddings. Valid for only v3 models.
@@ -936,7 +937,13 @@ def rerank(
936937
Parameters
937938
----------
938939
model : str
939-
The identifier of the model to use, e.g. `rerank-v3.5`.
940+
The identifier of the model to use.
941+
942+
Supported models:
943+
- `rerank-english-v3.0`
944+
- `rerank-multilingual-v3.0`
945+
- `rerank-english-v2.0`
946+
- `rerank-multilingual-v2.0`
940947
941948
query : str
942949
The search query
@@ -1301,6 +1308,7 @@ async def main() -> None:
13011308
p=1.1,
13021309
return_prompt=True,
13031310
logprobs=True,
1311+
stream=True,
13041312
)
13051313
async for chunk in response:
13061314
yield chunk
@@ -1340,7 +1348,6 @@ async def main() -> None:
13401348
"p": p,
13411349
"return_prompt": return_prompt,
13421350
"logprobs": logprobs,
1343-
"stream": True,
13441351
},
13451352
request_options=request_options,
13461353
omit=OMIT,
@@ -1621,6 +1628,7 @@ async def main() -> None:
16211628
content="messages",
16221629
)
16231630
],
1631+
stream=False,
16241632
)
16251633
16261634
@@ -1658,7 +1666,6 @@ async def main() -> None:
16581666
"p": p,
16591667
"return_prompt": return_prompt,
16601668
"logprobs": logprobs,
1661-
"stream": False,
16621669
},
16631670
request_options=request_options,
16641671
omit=OMIT,
@@ -1826,7 +1833,7 @@ async def embed(
18261833
input_type : EmbedInputType
18271834
18281835
embedding_types : typing.Sequence[EmbeddingType]
1829-
Specifies the types of embeddings you want to get back. Can be one or more of the following types.
1836+
Specifies the types of embeddings you want to get back. Optional; defaults to None, which returns the Embed Floats response type. Can be one or more of the following types.
18301837
18311838
* `"float"`: Use this when you want to get back the default float embeddings. Valid for all models.
18321839
* `"int8"`: Use this when you want to get back signed int8 embeddings. Valid for only v3 models.
@@ -2034,7 +2041,13 @@ async def rerank(
20342041
Parameters
20352042
----------
20362043
model : str
2037-
The identifier of the model to use, e.g. `rerank-v3.5`.
2044+
The identifier of the model to use.
2045+
2046+
Supported models:
2047+
- `rerank-english-v3.0`
2048+
- `rerank-multilingual-v3.0`
2049+
- `rerank-english-v2.0`
2050+
- `rerank-multilingual-v2.0`
20382051
20392052
query : str
20402053
The search query

tests/test_client_init.py

Lines changed: 0 additions & 16 deletions
This file was deleted.

0 commit comments

Comments
 (0)