Skip to content

Commit

Permalink
Merge pull request #389 from namin/ollama-autoconf
Browse files Browse the repository at this point in the history
Support provider Ollama for automatic configuration.
  • Loading branch information
jlowin authored Jan 11, 2025
2 parents bcf422e + 6cab490 commit daf0883
Show file tree
Hide file tree
Showing 4 changed files with 15 additions and 0 deletions.
1 change: 1 addition & 0 deletions docs/guides/configure-llms.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ At this time, supported providers for automatic configuration include:
| Anthropic | `anthropic` | (included) |
| Google | `google` | `langchain_google_genai` |
| Groq | `groq` | `langchain_groq` |
| Ollama | `ollama` | `langchain_ollama` |

If the required dependencies are not installed, ControlFlow will be unable to load the model and will raise an error.

Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ tests = [
"langchain_community",
"langchain_google_genai",
"langchain_groq",
"langchain-ollama',
"pytest-asyncio>=0.18.2,!=0.22.0,<0.23.0",
"pytest-env>=0.8,<2.0",
"pytest-rerunfailures>=10,<14",
Expand Down
8 changes: 8 additions & 0 deletions src/controlflow/llm/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,14 @@ def get_model(
"To use Groq as an LLM provider, please install the `langchain_groq` package."
)
cls = ChatGroq
elif provider == "ollama":
try:
from langchain_ollama import ChatOllama
except ImportError:
raise ImportError(
"To use Ollama as an LLM provider, please install the `langchain-ollama` package."
)
cls = ChatOllama
else:
raise ValueError(
f"Could not load provider `{provider}` automatically. Please provide the LLM class manually."
Expand Down
5 changes: 5 additions & 0 deletions tests/llm/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from langchain_anthropic import ChatAnthropic
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_groq import ChatGroq
from langchain_ollama import ChatOllama
from langchain_openai import AzureChatOpenAI, ChatOpenAI

from controlflow.llm.models import get_model
Expand Down Expand Up @@ -45,6 +46,10 @@ def test_get_groq_model(monkeypatch):
assert isinstance(model, ChatGroq)
assert model.model_name == "mixtral-8x7b-32768"

def test_get_ollama_model():
    """Automatic provider configuration loads a ChatOllama model.

    Unlike the cloud-provider tests, no API key needs to be monkeypatched:
    Ollama runs locally, so the unused ``monkeypatch`` fixture is dropped.
    The "ollama/" prefix selects the provider; the remainder is the model name.
    """
    model = get_model("ollama/qwen2.5")
    assert isinstance(model, ChatOllama)
    # ChatOllama exposes the model name via `.model` (not `.model_name`).
    assert model.model == "qwen2.5"

def test_get_model_with_invalid_format():
with pytest.raises(ValueError, match="The model `gpt-4o` is not valid."):
Expand Down

0 comments on commit daf0883

Please sign in to comment.