From 2f77c415a1b7dcc0c2f7458b66ad3c1d145b5117 Mon Sep 17 00:00:00 2001 From: Jeremiah Lowin <153965+jlowin@users.noreply.github.com> Date: Fri, 24 May 2024 23:33:13 -0400 Subject: [PATCH] add llm settings --- src/controlflow/core/agent.py | 2 +- src/controlflow/llm/completions.py | 32 ++++++++++++++++++++++++++---- src/controlflow/settings.py | 9 ++++++++- 3 files changed, 37 insertions(+), 6 deletions(-) diff --git a/src/controlflow/core/agent.py b/src/controlflow/core/agent.py index bbe5a50c..57ab3364 100644 --- a/src/controlflow/core/agent.py +++ b/src/controlflow/core/agent.py @@ -76,7 +76,7 @@ class Agent(ControlFlowModel, ExposeSyncMethodsMixin): ) model: str = Field( description="The model used by the agent. If not provided, the default model will be used.", - default_factory=lambda: controlflow.settings.model, + default_factory=lambda: controlflow.settings.llm_model, ) def __init__(self, name, **kwargs): diff --git a/src/controlflow/llm/completions.py b/src/controlflow/llm/completions.py index 94c5236d..48f206fc 100644 --- a/src/controlflow/llm/completions.py +++ b/src/controlflow/llm/completions.py @@ -64,7 +64,13 @@ def completion( handler = CompoundHandler(handlers=handlers or []) if model is None: - model = controlflow.settings.model + model = controlflow.settings.llm_model + if "api_key" not in kwargs: + kwargs["api_key"] = controlflow.settings.llm_api_key + if "api_version" not in kwargs: + kwargs["api_version"] = controlflow.settings.llm_api_version + if "api_base" not in kwargs: + kwargs["api_base"] = controlflow.settings.llm_api_base tools = as_tools(tools or []) @@ -151,7 +157,13 @@ def _completion_stream( handler = CompoundHandler(handlers=handlers or []) if model is None: - model = controlflow.settings.model + model = controlflow.settings.llm_model + if "api_key" not in kwargs: + kwargs["api_key"] = controlflow.settings.llm_api_key + if "api_version" not in kwargs: + kwargs["api_version"] = controlflow.settings.llm_api_version + if "api_base" not in kwargs: + kwargs["api_base"] = controlflow.settings.llm_api_base tools = as_tools(tools or []) @@ -273,7 +285,13 @@ async def completion_async( handler = CompoundHandler(handlers=handlers or []) if model is None: - model = controlflow.settings.model + model = controlflow.settings.llm_model + if "api_key" not in kwargs: + kwargs["api_key"] = controlflow.settings.llm_api_key + if "api_version" not in kwargs: + kwargs["api_version"] = controlflow.settings.llm_api_version + if "api_base" not in kwargs: + kwargs["api_base"] = controlflow.settings.llm_api_base tools = as_tools(tools or []) @@ -379,7 +397,13 @@ async def _completion_stream_async( handler = CompoundHandler(handlers=handlers or []) if model is None: - model = controlflow.settings.model + model = controlflow.settings.llm_model + if "api_key" not in kwargs: + kwargs["api_key"] = controlflow.settings.llm_api_key + if "api_version" not in kwargs: + kwargs["api_version"] = controlflow.settings.llm_api_version + if "api_base" not in kwargs: + kwargs["api_base"] = controlflow.settings.llm_api_base tools = as_tools(tools or []) diff --git a/src/controlflow/settings.py b/src/controlflow/settings.py index be9e49b8..51f241ba 100644 --- a/src/controlflow/settings.py +++ b/src/controlflow/settings.py @@ -90,7 +90,14 @@ class Settings(ControlFlowSettings): # ------------ LLM settings ------------ - model: str = Field("gpt-4o", description="The LLM model to use.") + llm_model: str = Field("gpt-4o", description="The LLM model to use.") + llm_api_key: Optional[str] = Field(None, description="The LLM API key to use.") + llm_api_base: Optional[str] = Field( + None, description="The LLM API base URL to use." + ) + llm_api_version: Optional[str] = Field( + None, description="The LLM API version to use." + ) # ------------ Flow visualization settings ------------