diff --git a/libs/community/langchain_community/llms/sparkllm.py b/libs/community/langchain_community/llms/sparkllm.py index 4e10e3b253633..d3741aa93d9b3 100644 --- a/libs/community/langchain_community/llms/sparkllm.py +++ b/libs/community/langchain_community/llms/sparkllm.py @@ -24,64 +24,149 @@ class SparkLLM(LLM): - """iFlyTek Spark large language model. + """iFlyTek Spark completion model integration. + + Setup: + To use, you should set environment variables ``IFLYTEK_SPARK_APP_ID``, + ``IFLYTEK_SPARK_API_KEY`` and ``IFLYTEK_SPARK_API_SECRET``. + + .. code-block:: bash + + export IFLYTEK_SPARK_APP_ID="your-app-id" + export IFLYTEK_SPARK_API_KEY="your-api-key" + export IFLYTEK_SPARK_API_SECRET="your-api-secret" + + Key init args — completion params: + model: Optional[str] + Name of IFLYTEK SPARK model to use. + temperature: Optional[float] + Sampling temperature. + top_k: Optional[float] + What search sampling control to use. + streaming: Optional[bool] + Whether to stream the results or not. + + Key init args — client params: + app_id: Optional[str] + IFLYTEK SPARK API KEY. Automatically inferred from env var `IFLYTEK_SPARK_APP_ID` if not provided. + api_key: Optional[str] + IFLYTEK SPARK API KEY. If not passed in will be read from env var IFLYTEK_SPARK_API_KEY. + api_secret: Optional[str] + IFLYTEK SPARK API SECRET. If not passed in will be read from env var IFLYTEK_SPARK_API_SECRET. + api_url: Optional[str] + Base URL for API requests. + timeout: Optional[int] + Timeout for requests. + + See full list of supported init args and their descriptions in the params section. + + Instantiate: + .. code-block:: python + + from langchain_community.llms import SparkLLM - To use, you should pass `app_id`, `api_key`, `api_secret` - as a named parameter to the constructor OR set environment - variables ``IFLYTEK_SPARK_APP_ID``, ``IFLYTEK_SPARK_API_KEY`` and - ``IFLYTEK_SPARK_API_SECRET`` + llm = SparkLLM( + app_id="your-app-id", + api_key="your-api_key", + api_secret="your-api-secret", + # model='Spark4.0 Ultra', + # temperature=..., + # other params... + ) - Example: + Invoke: .. code-block:: python - client = SparkLLM( - spark_app_id="", - spark_api_key="", - spark_api_secret="" - ) - """ + input_text = "用50个字左右阐述,生命的意义在于" + llm.invoke(input_text) + + .. code-block:: python + + '生命的意义在于实现自我价值,追求内心的平静与快乐,同时为他人和社会带来正面影响。' + + Stream: + .. code-block:: python + + for chunk in llm.stream(input_text): + print(chunk) + + .. code-block:: python + + 生命 | 的意义在于 | 不断探索和 | 实现个人潜能,通过 | 学习 | 、成长和对社会 | 的贡献,追求内心的满足和幸福。 + + Async: + .. code-block:: python + + await llm.ainvoke(input_text) + + # stream: + # async for chunk in llm.astream(input_text): + # print(chunk) + + # batch: + # await llm.abatch([input_text]) + + .. code-block:: python + + '生命的意义在于实现自我价值,追求内心的平静与快乐,同时为他人和社会带来正面影响。' + + """ # noqa: E501 client: Any = None #: :meta private: - spark_app_id: Optional[str] = None - spark_api_key: Optional[str] = None - spark_api_secret: Optional[str] = None - spark_api_url: Optional[str] = None - spark_llm_domain: Optional[str] = None + spark_app_id: Optional[str] = Field(default=None, alias="app_id") + """Automatically inferred from env var `IFLYTEK_SPARK_APP_ID` + if not provided.""" + spark_api_key: Optional[str] = Field(default=None, alias="api_key") + """IFLYTEK SPARK API KEY. If not passed in will be read from + env var IFLYTEK_SPARK_API_KEY.""" + spark_api_secret: Optional[str] = Field(default=None, alias="api_secret") + """IFLYTEK SPARK API SECRET. If not passed in will be read from + env var IFLYTEK_SPARK_API_SECRET.""" + spark_api_url: Optional[str] = Field(default=None, alias="api_url") + """Base URL path for API requests, leave blank if not using a proxy or service + emulator.""" + spark_llm_domain: Optional[str] = Field(default=None, alias="model") + """Model name to use.""" spark_user_id: str = "lc_user" streaming: bool = False - request_timeout: int = 30 + """Whether to stream the results or not.""" + request_timeout: int = Field(default=30, alias="timeout") + """request timeout for chat http requests""" temperature: float = 0.5 + """What sampling temperature to use.""" top_k: int = 4 + """What search sampling control to use.""" model_kwargs: Dict[str, Any] = Field(default_factory=dict) + """Holds any model parameters valid for API call not explicitly specified.""" @pre_init def validate_environment(cls, values: Dict) -> Dict: values["spark_app_id"] = get_from_dict_or_env( values, - "spark_app_id", + ["spark_app_id", "app_id"], "IFLYTEK_SPARK_APP_ID", ) values["spark_api_key"] = get_from_dict_or_env( values, - "spark_api_key", + ["spark_api_key", "api_key"], "IFLYTEK_SPARK_API_KEY", ) values["spark_api_secret"] = get_from_dict_or_env( values, - "spark_api_secret", + ["spark_api_secret", "api_secret"], "IFLYTEK_SPARK_API_SECRET", ) values["spark_api_url"] = get_from_dict_or_env( values, - "spark_api_url", + ["spark_api_url", "api_url"], "IFLYTEK_SPARK_API_URL", - "wss://spark-api.xf-yun.com/v3.1/chat", + "wss://spark-api.xf-yun.com/v3.5/chat", ) values["spark_llm_domain"] = get_from_dict_or_env( values, - "spark_llm_domain", + ["spark_llm_domain", "model"], "IFLYTEK_SPARK_LLM_DOMAIN", - "generalv3", + "generalv3.5", ) # put extra params into model_kwargs values["model_kwargs"]["temperature"] = values["temperature"] or cls.temperature @@ -163,7 +248,7 @@ def _stream( [{"role": "user", "content": prompt}], self.spark_user_id, self.model_kwargs, - self.streaming, + True, ) for content in self.client.subscribe(timeout=self.request_timeout): if "data" not in content: @@ -200,11 +285,11 @@ def __init__( ) self.api_url = ( - "wss://spark-api.xf-yun.com/v3.1/chat" if not api_url else api_url + "wss://spark-api.xf-yun.com/v3.5/chat" if not api_url else api_url ) self.app_id = app_id self.model_kwargs = model_kwargs - self.spark_domain = spark_domain or "generalv3" + self.spark_domain = spark_domain or "generalv3.5" self.queue: Queue[Dict] = Queue() self.blocking_message = {"content": "", "role": "assistant"} self.api_key = api_key diff --git a/libs/community/tests/integration_tests/llms/test_sparkllm.py b/libs/community/tests/integration_tests/llms/test_sparkllm.py index 8cacea8b540af..8bcd28f4df212 100644 --- a/libs/community/tests/integration_tests/llms/test_sparkllm.py +++ b/libs/community/tests/integration_tests/llms/test_sparkllm.py @@ -18,3 +18,28 @@ def test_generate() -> None: output = llm.generate(["Say foo:"]) assert isinstance(output, LLMResult) assert isinstance(output.generations, list) + + +def test_spark_llm_with_param_alias() -> None: + """Test SparkLLM with parameters alias.""" + llm = SparkLLM( # type: ignore[call-arg] + app_id="your-app-id", + api_key="your-api-key", + api_secret="your-api-secret", + model="Spark4.0 Ultra", + api_url="your-api-url", + timeout=20, + ) + assert llm.spark_app_id == "your-app-id" + assert llm.spark_api_key == "your-api-key" + assert llm.spark_api_secret == "your-api-secret" + assert llm.spark_llm_domain == "Spark4.0 Ultra" + assert llm.spark_api_url == "your-api-url" + assert llm.request_timeout == 20 + + +def test_spark_llm_with_stream() -> None: + """Test SparkLLM with stream.""" + llm = SparkLLM() # type: ignore[call-arg] + for chunk in llm.stream("你好呀"): + assert isinstance(chunk, str)