Skip to content

Commit

Permalink
Update test_models.py
Browse files Browse the repository at this point in the history
  • Loading branch information
aaazzam committed Jun 12, 2024
1 parent 5af9220 commit 0116614
Showing 1 changed file with 26 additions and 12 deletions.
38 changes: 26 additions & 12 deletions tests/llm/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,38 +7,47 @@
)


def test_get_provider_from_string_openai():
def test_get_provider_from_string_openai() -> None:
provider = get_provider_from_string("openai")
assert provider.__name__ == "ChatOpenAI"

def test_get_provider_from_string_azure_openai():

def test_get_provider_from_string_azure_openai() -> None:
provider = get_provider_from_string("azure_openai")
assert provider.__name__ == "AzureChatOpenAI"

def test_get_provider_from_string_anthropic():

def test_get_provider_from_string_anthropic() -> None:
pytest.importorskip("langchain_anthropic")
provider = get_provider_from_string("anthropic")
assert provider.__name__ == "ChatAnthropic"

def test_get_provider_from_string_google():

def test_get_provider_from_string_google() -> None:
pytest.importorskip("langchain_google_genai")
provider = get_provider_from_string("google")
assert provider.__name__ == "ChatGoogleGenerativeAI"

def test_get_provider_from_string_invalid_provider():

def test_get_provider_from_string_invalid_provider() -> None:
with pytest.raises(ValueError):
get_provider_from_string("invalid_provider.gpt4")

def test_get_provider_from_string_missing_module():

def test_get_provider_from_string_missing_module() -> None:
with pytest.raises(ImportError):
get_provider_from_string("openai.missing_module")

def test_get_model_from_string(mocker: pytest_mock.MockFixture):

def test_get_model_from_string(mocker: pytest_mock.MockFixture) -> None:
# Test getting a model from string
mock_provider_class = mocker.Mock()
mock_provider_instance = mocker.Mock()
mock_provider_class.return_value = mock_provider_instance
mocker.patch("controlflow.llm.models.get_provider_from_string", return_value=mock_provider_class)
mocker.patch(
"controlflow.llm.models.get_provider_from_string",
return_value=mock_provider_class,
)
model = get_model_from_string("openai/davinci", temperature=0.5)
assert model == mock_provider_instance
mock_provider_class.assert_called_once_with(
Expand All @@ -50,16 +59,21 @@ def test_get_model_from_string(mocker: pytest_mock.MockFixture):
mock_provider_class.reset_mock()
mocker.patch("controlflow.settings.settings.llm_model", "anthropic/claude")
mocker.patch("controlflow.settings.settings.llm_temperature", 0.7)
pytest.importorskip("langchain_anthropic") # Skip if langchain_anthropic is not installed
pytest.importorskip(
"langchain_anthropic"
) # Skip if langchain_anthropic is not installed
model = get_model_from_string()
assert model == mock_provider_instance
mock_provider_class.assert_called_once_with(
name="claude",
temperature=0.7,
)

def test_get_default_model(mocker: pytest_mock.MockFixture):

def test_get_default_model(mocker: pytest_mock.MockFixture) -> None:
# Test getting the default model
mock_get_model_from_string = mocker.patch("controlflow.llm.models.get_model_from_string")
mock_get_model_from_string = mocker.patch(
"controlflow.llm.models.get_model_from_string"
)
get_default_model()
mock_get_model_from_string.assert_called_once_with()
mock_get_model_from_string.assert_called_once_with()

0 comments on commit 0116614

Please sign in to comment.