Skip to content

Commit

Permalink
Merge pull request #71 from jlowin/llm-api
Browse files Browse the repository at this point in the history
add llm settings
  • Loading branch information
jlowin authored May 25, 2024
2 parents e78c87f + 2f77c41 commit 6d11409
Show file tree
Hide file tree
Showing 3 changed files with 37 additions and 6 deletions.
2 changes: 1 addition & 1 deletion src/controlflow/core/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ class Agent(ControlFlowModel, ExposeSyncMethodsMixin):
)
model: str = Field(
description="The model used by the agent. If not provided, the default model will be used.",
default_factory=lambda: controlflow.settings.model,
default_factory=lambda: controlflow.settings.llm_model,
)

def __init__(self, name, **kwargs):
Expand Down
32 changes: 28 additions & 4 deletions src/controlflow/llm/completions.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,13 @@ def completion(
handler = CompoundHandler(handlers=handlers or [])

if model is None:
model = controlflow.settings.model
model = controlflow.settings.llm_model
if "api_key" not in kwargs:
kwargs["api_key"] = controlflow.settings.llm_api_key
if "api_version" not in kwargs:
kwargs["api_version"] = controlflow.settings.llm_api_version
if "api_base" not in kwargs:
kwargs["api_base"] = controlflow.settings.llm_api_base

tools = as_tools(tools or [])

Expand Down Expand Up @@ -151,7 +157,13 @@ def _completion_stream(

handler = CompoundHandler(handlers=handlers or [])
if model is None:
model = controlflow.settings.model
model = controlflow.settings.llm_model
if "api_key" not in kwargs:
kwargs["api_key"] = controlflow.settings.llm_api_key
if "api_version" not in kwargs:
kwargs["api_version"] = controlflow.settings.llm_api_version
if "api_base" not in kwargs:
kwargs["api_base"] = controlflow.settings.llm_api_base

tools = as_tools(tools or [])

Expand Down Expand Up @@ -273,7 +285,13 @@ async def completion_async(

handler = CompoundHandler(handlers=handlers or [])
if model is None:
model = controlflow.settings.model
model = controlflow.settings.llm_model
if "api_key" not in kwargs:
kwargs["api_key"] = controlflow.settings.llm_api_key
if "api_version" not in kwargs:
kwargs["api_version"] = controlflow.settings.llm_api_version
if "api_base" not in kwargs:
kwargs["api_base"] = controlflow.settings.llm_api_base

tools = as_tools(tools or [])

Expand Down Expand Up @@ -379,7 +397,13 @@ async def _completion_stream_async(

handler = CompoundHandler(handlers=handlers or [])
if model is None:
model = controlflow.settings.model
model = controlflow.settings.llm_model
if "api_key" not in kwargs:
kwargs["api_key"] = controlflow.settings.llm_api_key
if "api_version" not in kwargs:
kwargs["api_version"] = controlflow.settings.llm_api_version
if "api_base" not in kwargs:
kwargs["api_base"] = controlflow.settings.llm_api_base

tools = as_tools(tools or [])

Expand Down
9 changes: 8 additions & 1 deletion src/controlflow/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,14 @@ class Settings(ControlFlowSettings):

# ------------ LLM settings ------------

model: str = Field("gpt-4o", description="The LLM model to use.")
llm_model: str = Field("gpt-4o", description="The LLM model to use.")
llm_api_key: Optional[str] = Field(None, description="The LLM API key to use.")
llm_api_base: Optional[str] = Field(
None, description="The LLM API base URL to use."
)
llm_api_version: Optional[str] = Field(
None, description="The LLM API version to use."
)

# ------------ Flow visualization settings ------------

Expand Down

0 comments on commit 6d11409

Please sign in to comment.