|
4 | 4 | from langchain_core.embeddings import Embeddings
|
5 | 5 | from langchain_core.pydantic_v1 import BaseModel, SecretStr, root_validator
|
6 | 6 | from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env
|
| 7 | +from requests import RequestException |
7 | 8 |
|
8 | 9 | BAICHUAN_API_URL: str = "http://api.baichuan-ai.com/v1/embeddings"
|
9 | 10 |
|
|
22 | 23 | # NOTE!! BaichuanTextEmbeddings only supports Chinese text embedding.
|
23 | 24 | # Multi-language support is coming soon.
|
24 | 25 | class BaichuanTextEmbeddings(BaseModel, Embeddings):
|
25 |
| - """Baichuan Text Embedding models.""" |
| 26 | + """Baichuan Text Embedding models. |
| 27 | +
|
| 28 | + To use, you should set the environment variable ``BAICHUAN_API_KEY`` to |
| 29 | + your API key or pass it as a named parameter to the constructor. |
| 30 | +
|
| 31 | + Example: |
| 32 | + .. code-block:: python |
| 33 | +
|
| 34 | + from langchain_community.embeddings import BaichuanTextEmbeddings |
| 35 | +
|
| 36 | + baichuan = BaichuanTextEmbeddings(baichuan_api_key="my-api-key") |
| 37 | + """ |
26 | 38 |
|
27 | 39 | session: Any #: :meta private:
|
28 | 40 | model_name: str = "Baichuan-Text-Embedding"
|
29 | 41 | baichuan_api_key: Optional[SecretStr] = None
|
| 42 | + """Automatically inferred from env var `BAICHUAN_API_KEY` if not provided.""" |
30 | 43 |
|
31 | 44 | @root_validator(allow_reuse=True)
|
32 | 45 | def validate_environment(cls, values: Dict) -> Dict:
|
@@ -65,29 +78,26 @@ def _embed(self, texts: List[str]) -> Optional[List[List[float]]]:
|
65 | 78 | A list of list of floats representing the embeddings, or None if an
|
66 | 79 | error occurs.
|
67 | 80 | """
|
68 |
| - try: |
69 |
| - response = self.session.post( |
70 |
| - BAICHUAN_API_URL, json={"input": texts, "model": self.model_name} |
| 81 | + response = self.session.post( |
| 82 | + BAICHUAN_API_URL, json={"input": texts, "model": self.model_name} |
| 83 | + ) |
| 84 | + # Raise exception if response status code from 400 to 600 |
| 85 | + response.raise_for_status() |
| 86 | + # Check if the response status code indicates success |
| 87 | + if response.status_code == 200: |
| 88 | + resp = response.json() |
| 89 | + embeddings = resp.get("data", []) |
| 90 | + # Sort resulting embeddings by index |
| 91 | + sorted_embeddings = sorted(embeddings, key=lambda e: e.get("index", 0)) |
| 92 | + # Return just the embeddings |
| 93 | + return [result.get("embedding", []) for result in sorted_embeddings] |
| 94 | + else: |
| 95 | + # Log error or handle unsuccessful response appropriately |
| 96 | + # Handle 100 <= status_code < 400, not include 200 |
| 97 | + raise RequestException( |
| 98 | + f"Error: Received status code {response.status_code} from " |
| 99 | + "`BaichuanEmbedding` API" |
71 | 100 | )
|
72 |
| - # Check if the response status code indicates success |
73 |
| - if response.status_code == 200: |
74 |
| - resp = response.json() |
75 |
| - embeddings = resp.get("data", []) |
76 |
| - # Sort resulting embeddings by index |
77 |
| - sorted_embeddings = sorted(embeddings, key=lambda e: e.get("index", 0)) |
78 |
| - # Return just the embeddings |
79 |
| - return [result.get("embedding", []) for result in sorted_embeddings] |
80 |
| - else: |
81 |
| - # Log error or handle unsuccessful response appropriately |
82 |
| - print( # noqa: T201 |
83 |
| - f"Error: Received status code {response.status_code} from " |
84 |
| - "embedding API" |
85 |
| - ) |
86 |
| - return None |
87 |
| - except Exception as e: |
88 |
| - # Log the exception or handle it as needed |
89 |
| - print(f"Exception occurred while trying to get embeddings: {str(e)}") # noqa: T201 |
90 |
| - return None |
91 | 101 |
|
92 | 102 | def embed_documents(self, texts: List[str]) -> Optional[List[List[float]]]: # type: ignore[override]
|
93 | 103 | """Public method to get embeddings for a list of documents.
|
|
0 commit comments