7
7
from controlflow .llm .models import get_model
8
8
9
9
10
- def test_get_model_from_openai ():
10
+ def test_get_model_from_openai (monkeypatch ):
11
+ monkeypatch .setenv ("OPENAI_API_KEY" , "fake_openai_api_key" )
11
12
model = get_model ("openai/gpt-4o-mini" )
12
13
assert isinstance (model , ChatOpenAI )
13
14
assert model .model_name == "gpt-4o-mini"
14
15
15
16
16
- def test_get_model_from_anthropic ():
17
+ def test_get_model_from_anthropic (monkeypatch ):
18
+ monkeypatch .setenv ("ANTHROPIC_API_KEY" , "fake_anthropic_api_key" )
17
19
model = get_model ("anthropic/claude-3-haiku-20240307" )
18
20
assert isinstance (model , ChatAnthropic )
19
21
assert model .model == "claude-3-haiku-20240307"
20
22
21
23
22
- def test_get_azure_openai_model ():
24
+ def test_get_azure_openai_model (monkeypatch ):
25
+ monkeypatch .setenv ("AZURE_OPENAI_API_KEY" , "fake_azure_openai_api_key" )
26
+ monkeypatch .setenv (
27
+ "AZURE_OPENAI_ENDPOINT" , "https://fake-endpoint.openai.azure.com"
28
+ )
29
+ monkeypatch .setenv ("OPENAI_API_VERSION" , "2024-05-01-preview" )
23
30
model = get_model ("azure-openai/gpt-4" )
24
31
assert isinstance (model , AzureChatOpenAI )
25
- assert model .deployment_name == "gpt-4"
32
+ assert model .model_name == "gpt-4"
26
33
27
34
28
- def test_get_google_model ():
35
+ def test_get_google_model (monkeypatch ):
36
+ monkeypatch .setenv ("GOOGLE_API_KEY" , "fake_google_api_key" )
29
37
model = get_model ("google/gemini-1.5-pro" )
30
38
assert isinstance (model , ChatGoogleGenerativeAI )
31
39
assert model .model == "models/gemini-1.5-pro"
32
40
33
41
34
- def test_get_groq_model ():
42
+ def test_get_groq_model (monkeypatch ):
43
+ monkeypatch .setenv ("GROQ_API_KEY" , "fake_groq_api_key" )
35
44
model = get_model ("groq/mixtral-8x7b-32768" )
36
45
assert isinstance (model , ChatGroq )
37
46
assert model .model_name == "mixtral-8x7b-32768"
@@ -49,7 +58,8 @@ def test_get_model_with_unsupported_provider():
49
58
get_model ("unsupported/model-name" )
50
59
51
60
52
- def test_get_model_with_temperature ():
61
+ def test_get_model_with_temperature (monkeypatch ):
62
+ monkeypatch .setenv ("ANTHROPIC_API_KEY" , "fake_anthropic_api_key" )
53
63
model = get_model ("anthropic/claude-3-haiku-20240307" , temperature = 0.7 )
54
64
assert isinstance (model , ChatAnthropic )
55
65
assert model .temperature == 0.7
0 commit comments