Skip to content

Commit 20d677c

Browse files
authored
Merge pull request #6 from Mindinventory/improvement_002
Make LLM instances configurable.
2 parents 66137f7 + aa22c50 commit 20d677c

File tree

4 files changed

+13
-12
lines changed

4 files changed

+13
-12
lines changed

mindsql/_utils/constants.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
GOOGLE_GEN_AI_VALUE_ERROR = "For GoogleGenAI, config must be provided with an api_key"
2828
GOOGLE_GEN_AI_APIKEY_ERROR = "config must contain a Google AI Studio api_key"
2929
LLAMA_VALUE_ERROR = "For LlamaAI, config must be provided with a model_path"
30+
CONFIG_REQUIRED_ERROR = "Configuration is required."
3031
LLAMA_PROMPT_EXCEPTION = "Prompt cannot be empty."
3132
OPENAI_VALUE_ERROR = "OpenAI API key is required"
3233
OPENAI_PROMPT_EMPTY_EXCEPTION = "Prompt cannot be empty."

mindsql/llms/googlegenai.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,9 +20,9 @@ def __init__(self, config=None):
2020

2121
if 'api_key' not in config:
2222
raise ValueError(GOOGLE_GEN_AI_APIKEY_ERROR)
23-
api_key = config['api_key']
23+
api_key = config.pop('api_key')
2424
genai.configure(api_key=api_key)
25-
self.model = genai.GenerativeModel('gemini-pro')
25+
self.model = genai.GenerativeModel('gemini-pro', **config)
2626

2727
def system_message(self, message: str) -> any:
2828
"""

mindsql/llms/llama.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from llama_cpp import Llama
22

3-
from .._utils.constants import LLAMA_VALUE_ERROR, LLAMA_PROMPT_EXCEPTION
3+
from .._utils.constants import LLAMA_VALUE_ERROR, LLAMA_PROMPT_EXCEPTION, CONFIG_REQUIRED_ERROR
44
from .illm import ILlm
55

66

@@ -16,13 +16,13 @@ def __init__(self, config=None):
1616
None
1717
"""
1818
if config is None:
19-
raise ValueError("")
19+
raise ValueError(CONFIG_REQUIRED_ERROR)
2020

2121
if 'model_path' not in config:
2222
raise ValueError(LLAMA_VALUE_ERROR)
23-
path = config['model_path']
23+
path = config.pop('model_path')
2424

25-
self.model = Llama(model_path=path)
25+
self.model = Llama(model_path=path, **config)
2626

2727
def system_message(self, message: str) -> any:
2828
"""

mindsql/llms/open_ai.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from openai import OpenAI
22

3-
from .._utils.constants import OPENAI_VALUE_ERROR, OPENAI_PROMPT_EMPTY_EXCEPTION
43
from . import ILlm
4+
from .._utils.constants import OPENAI_VALUE_ERROR, OPENAI_PROMPT_EMPTY_EXCEPTION
55

66

77
class OpenAi(ILlm):
@@ -16,6 +16,7 @@ def __init__(self, config=None, client=None):
1616
Returns:
1717
None
1818
"""
19+
self.config = config
1920
self.client = client
2021

2122
if client is not None:
@@ -24,9 +25,8 @@ def __init__(self, config=None, client=None):
2425

2526
if 'api_key' not in config:
2627
raise ValueError(OPENAI_VALUE_ERROR)
27-
28-
if 'api_key' in config:
29-
self.client = OpenAI(api_key=config['api_key'])
28+
api_key = config.pop('api_key')
29+
self.client = OpenAI(api_key=api_key, **config)
3030

3131
def system_message(self, message: str) -> any:
3232
"""
@@ -82,6 +82,6 @@ def invoke(self, prompt, **kwargs) -> str:
8282
model = self.config.get("model", "gpt-3.5-turbo")
8383
temperature = kwargs.get("temperature", 0.1)
8484
max_tokens = kwargs.get("max_tokens", 500)
85-
response = self.client.chat.completions.create(model=model, messages=[{"role": "user", "content": prompt}], max_tokens=max_tokens, stop=None,
86-
temperature=temperature)
85+
response = self.client.chat.completions.create(model=model, messages=[{"role": "user", "content": prompt}],
86+
max_tokens=max_tokens, stop=None, temperature=temperature)
8787
return response.choices[0].message.content

0 commit comments

Comments (0)