1
- from langchain_community .embeddings .cohere import CohereEmbeddings
1
+ from typing import Any
2
+
3
+ import cohere
4
+ from langchain_cohere import CohereEmbeddings
2
5
3
6
from langflow .base .models .model import LCModelComponent
4
7
from langflow .field_typing import Embeddings
5
8
from langflow .io import DropdownInput , FloatInput , IntInput , MessageTextInput , Output , SecretStrInput
6
9
10
+ HTTP_STATUS_OK = 200
11
+
7
12
8
13
class CohereEmbeddingsComponent (LCModelComponent ):
9
14
display_name = "Cohere Embeddings"
@@ -12,9 +17,9 @@ class CohereEmbeddingsComponent(LCModelComponent):
12
17
name = "CohereEmbeddings"
13
18
14
19
inputs = [
15
- SecretStrInput (name = "cohere_api_key " , display_name = "Cohere API Key" , required = True ),
20
+ SecretStrInput (name = "api_key " , display_name = "Cohere API Key" , required = True , real_time_refresh = True ),
16
21
DropdownInput (
17
- name = "model " ,
22
+ name = "model_name " ,
18
23
display_name = "Model" ,
19
24
advanced = False ,
20
25
options = [
@@ -24,6 +29,8 @@ class CohereEmbeddingsComponent(LCModelComponent):
24
29
"embed-multilingual-light-v2.0" ,
25
30
],
26
31
value = "embed-english-v2.0" ,
32
+ refresh_button = True ,
33
+ combobox = True ,
27
34
),
28
35
MessageTextInput (name = "truncate" , display_name = "Truncate" , advanced = True ),
29
36
IntInput (name = "max_retries" , display_name = "Max Retries" , value = 3 , advanced = True ),
@@ -36,11 +43,39 @@ class CohereEmbeddingsComponent(LCModelComponent):
36
43
]
37
44
38
45
def build_embeddings (self ) -> Embeddings :
39
- return CohereEmbeddings (
40
- cohere_api_key = self .cohere_api_key ,
41
- model = self .model ,
42
- truncate = self .truncate ,
43
- max_retries = self .max_retries ,
44
- user_agent = self .user_agent ,
45
- request_timeout = self .request_timeout or None ,
46
- )
46
+ data = None
47
+ try :
48
+ data = CohereEmbeddings (
49
+ cohere_api_key = self .api_key ,
50
+ model = self .model_name ,
51
+ truncate = self .truncate ,
52
+ max_retries = self .max_retries ,
53
+ user_agent = self .user_agent ,
54
+ request_timeout = self .request_timeout or None ,
55
+ )
56
+ except Exception as e :
57
+ msg = (
58
+ "Unable to create Cohere Embeddings. " ,
59
+ "Please verify the API key and model parameters, and try again." ,
60
+ )
61
+ raise ValueError (msg ) from e
62
+ # added status if not the return data would be serialised to create the status
63
+ return data
64
+
65
+ def get_model (self ):
66
+ try :
67
+ co = cohere .ClientV2 (self .api_key )
68
+ response = co .models .list (endpoint = "embed" )
69
+ models = response .models
70
+ return [model .name for model in models ]
71
+ except Exception as e :
72
+ msg = f"Failed to fetch Cohere models. Error: { e } "
73
+ raise ValueError (msg ) from e
74
+
75
+ async def update_build_config (self , build_config : dict , field_value : Any , field_name : str | None = None ):
76
+ if field_name in {"model_name" , "api_key" }:
77
+ if build_config .get ("api_key" , {}).get ("value" , None ):
78
+ build_config ["model_name" ]["options" ] = self .get_model ()
79
+ else :
80
+ build_config ["model_name" ]["options" ] = field_value
81
+ return build_config
0 commit comments