Merge pull request #406 from PrefectHQ/temp-none
Set temperature to None by default
jlowin authored Feb 6, 2025
2 parents ca4adb0 + 9fe8e1b commit 739b165
Showing 7 changed files with 20 additions and 14 deletions.
6 changes: 3 additions & 3 deletions src/controlflow/agents/agent.py
@@ -83,7 +83,7 @@ class Agent(ControlFlowModel, abc.ABC):
         default=False,
         description="If True, the agent is given tools for interacting with a human user.",
     )
-    memories: list[Memory] | list[AsyncMemory] = Field(
+    memories: list[Union[Memory, AsyncMemory]] = Field(
         default=[],
         description="A list of memory modules for the agent to use.",
     )
@@ -345,7 +345,7 @@ def _run_model(

         create_markdown_artifact(
             markdown=f"""
-{response.content or '(No content)'}
+{response.content or "(No content)"}

 #### Payload
 ```json
@@ -409,7 +409,7 @@ async def _run_model_async(

         create_markdown_artifact(
             markdown=f"""
-{response.content or '(No content)'}
+{response.content or "(No content)"}

 #### Payload
 ```json
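Two kinds of change appear in this file. The `'(No content)'` → `"(No content)"` hunks only normalize quote style and leave behavior unchanged. The `memories` annotation, by contrast, changes semantics: `list[Memory] | list[AsyncMemory]` validates only a homogeneous list, while `list[Union[Memory, AsyncMemory]]` also accepts a mix of sync and async memory modules. A minimal sketch of the difference under pydantic v2, using hypothetical stub classes in place of ControlFlow's real memory types:

```python
from typing import Union

from pydantic import BaseModel, ConfigDict, Field

class Memory: ...       # hypothetical stand-ins for controlflow's
class AsyncMemory: ...  # Memory / AsyncMemory classes

class AgentSketch(BaseModel):
    model_config = ConfigDict(arbitrary_types_allowed=True)
    # New annotation: one list that may mix both memory types.
    memories: list[Union[Memory, AsyncMemory]] = Field(default=[])

# Validates under the new annotation; with the old
# `list[Memory] | list[AsyncMemory]`, a mixed list matched neither branch.
AgentSketch(memories=[Memory(), AsyncMemory()])
```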
4 changes: 1 addition & 3 deletions src/controlflow/defaults.py
@@ -40,9 +40,7 @@ class Defaults(ControlFlowModel):
     model: Optional[Any]
     history: History
     agent: Agent
-    memory_provider: (
-        Optional[Union[MemoryProvider, str]] | Optional[Union[AsyncMemoryProvider, str]]
-    )
+    memory_provider: Optional[Union[MemoryProvider, AsyncMemoryProvider, str]]

     # add more defaults here
     def __repr__(self) -> str:
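Both spellings of `memory_provider` admit the same values, since Python's typing flattens nested unions: `Optional[Union[MemoryProvider, str]] | Optional[Union[AsyncMemoryProvider, str]]` reduces to `Optional[Union[MemoryProvider, AsyncMemoryProvider, str]]`. The new form is just the canonical flat spelling.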
8 changes: 7 additions & 1 deletion src/controlflow/llm/models.py
@@ -52,6 +52,8 @@ def get_model(
                 "To use Google as an LLM provider, please install the `langchain_google_genai` package."
             )
         cls = ChatGoogleGenerativeAI
+        if temperature is None:
+            temperature = 0.7
     elif provider == "groq":
         try:
             from langchain_groq import ChatGroq
@@ -60,6 +62,8 @@
                 "To use Groq as an LLM provider, please install the `langchain_groq` package."
             )
         cls = ChatGroq
+        if temperature is None:
+            temperature = 0.7
     elif provider == "ollama":
         try:
             from langchain_ollama import ChatOllama
@@ -73,7 +77,9 @@
             f"Could not load provider `{provider}` automatically. Please provide the LLM class manually."
         )

-    return cls(model=model, temperature=temperature, **kwargs)
+    if temperature is not None:
+        kwargs["temperature"] = temperature
+    return cls(model=model, **kwargs)


 def _get_initial_default_model() -> BaseChatModel:
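This file carries the substantive change of the PR: `get_model` now forwards `temperature` to the provider class only when one was explicitly set, so each provider's own default applies otherwise, while the Google and Groq branches fall back to an explicit 0.7 when none is given. A minimal sketch of the new forwarding logic, using a hypothetical `FakeChat` class in place of a real provider:

```python
from typing import Any, Optional

class FakeChat:
    """Hypothetical provider chat-model class, for illustration only."""

    def __init__(self, model: str, temperature: float = 0.5, **kwargs: Any):
        self.model = model
        self.temperature = temperature

def get_model_sketch(
    model: str, temperature: Optional[float] = None, **kwargs: Any
) -> FakeChat:
    # Mirrors the patched return path: only pass `temperature` through
    # when it was explicitly set, so the provider default wins otherwise.
    if temperature is not None:
        kwargs["temperature"] = temperature
    return FakeChat(model=model, **kwargs)

assert get_model_sketch("some-model").temperature == 0.5  # provider default
assert get_model_sketch("some-model", temperature=0.2).temperature == 0.2
```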
4 changes: 2 additions & 2 deletions src/controlflow/orchestration/orchestrator.py
@@ -188,7 +188,7 @@ def get_tools(self) -> list[Tool]:
         tools = as_tools(tools)
         return tools

-    def get_memories(self) -> list[Memory] | list[AsyncMemory]:
+    def get_memories(self) -> list[Union[Memory, AsyncMemory]]:
         memories = set()

         memories.update(self.agent.memories)
@@ -525,7 +525,7 @@ def compile_prompt(self) -> str:
         ]

         prompt = "\n\n".join([p for p in prompts if p])
-        logger.debug(f"{'='*10}\nCompiled prompt: {prompt}\n{'='*10}")
+        logger.debug(f"{'=' * 10}\nCompiled prompt: {prompt}\n{'=' * 10}")
         return prompt

     def compile_messages(self) -> list[BaseMessage]:
4 changes: 2 additions & 2 deletions src/controlflow/orchestration/prompt_templates.py
@@ -1,4 +1,4 @@
-from typing import Any, Dict, List, Optional
+from typing import Any, Dict, List, Optional, Union

 from pydantic import model_validator

@@ -98,7 +98,7 @@ def should_render(self) -> bool:

 class MemoryTemplate(Template):
     template_path: str = "memories.jinja"
-    memories: list[Memory] | list[AsyncMemory]
+    memories: list[Union[Memory, AsyncMemory]]

     def should_render(self) -> bool:
         return bool(self.memories)
4 changes: 3 additions & 1 deletion src/controlflow/settings.py
@@ -109,7 +109,9 @@ def _validate_pretty_print_agent_events(cls, data: dict) -> dict:
         default="openai/gpt-4o",
         description="The default LLM model for agents.",
     )
-    llm_temperature: float = Field(0.7, description="The temperature for LLM sampling.")
+    llm_temperature: Union[float, None] = Field(
+        None, description="The temperature for LLM sampling."
+    )
     max_input_tokens: int = Field(
         100_000, description="The maximum number of tokens to send to an LLM."
     )
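With the setting now optional and defaulting to `None`, callers who relied on the old fixed sampling temperature can restore it explicitly. A sketch assuming ControlFlow's usual `controlflow.settings` access pattern, which this diff does not itself show:

```python
import controlflow

# Restore the pre-change behavior of a fixed 0.7 sampling temperature;
# leaving the setting at None defers to each provider's own default.
controlflow.settings.llm_temperature = 0.7
```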
4 changes: 2 additions & 2 deletions src/controlflow/tasks/task.py
@@ -75,7 +75,7 @@ def __getitem__(self, item):
         return self.root[item]

     def __repr__(self) -> str:
-        return f'Labels: {", ".join(self.root)}'
+        return f"Labels: {', '.join(self.root)}"


 class TaskStatus(Enum):
@@ -162,7 +162,7 @@ class Task(ControlFlowModel):
         description="Agents that are allowed to mark this task as complete. If None, all agents are allowed.",
     )
     interactive: bool = False
-    memories: list[Memory] | list[AsyncMemory] = Field(
+    memories: list[Union[Memory, AsyncMemory]] = Field(
         default=[],
         description="A list of memory modules for the task to use.",
     )
