Skip to content

Commit af2dab8

Browse files
committed
support for OpenRouter and LiteLLM as options, adding model provider option
1 parent 9e6e9bc commit af2dab8

File tree

4 files changed

+58
-32
lines changed

4 files changed

+58
-32
lines changed

gpt_migrate/ai.py

Lines changed: 51 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
from langchain.chat_models import ChatOpenAI
21
import os
32
import openai
43
from utils import parse_code_string
@@ -8,25 +7,35 @@
87
openai.api_key = os.getenv("OPENROUTER_API_KEY")
98

109
class AI:
11-
def __init__(self, model="gpt-4-32k", temperature=0.1, max_tokens=10000):
10+
def __init__(self, model="gpt-4-32k", model_provider="openai", modelrouter="openrouter", temperature=0.1, max_tokens=10000):
1211
self.temperature = temperature
1312
self.max_tokens = max_tokens
13+
self.model_provider = model_provider
1414
self.model_name = model
15-
try:
16-
_ = ChatOpenAI(model_name=model) # check to see if model is available to user
17-
except Exception as e:
18-
print(e)
19-
self.model_name = "gpt-3.5-turbo"
15+
self.modelrouter = modelrouter
2016

2117
def write_code(self, prompt):
2218
message=[{"role": "user", "content": str(prompt)}]
23-
response = completion(
24-
messages=message,
25-
stream=False,
26-
model=self.model_name,
27-
max_tokens=self.max_tokens,
28-
temperature=self.temperature
29-
)
19+
if self.modelrouter == "openrouter":
20+
response = openai.ChatCompletion.create(
21+
model="{}/{}".format(self.model_provider,self.model_name), # Optional (user controls the default)
22+
messages=message,
23+
stream=False,
24+
max_tokens=self.max_tokens,
25+
temperature=self.temperature,
26+
headers={
27+
"HTTP-Referer": "https://gpt-migrate.com",
28+
"X-Title": "GPT-Migrate",
29+
},
30+
)
31+
else:
32+
response = completion(
33+
messages=message,
34+
stream=False,
35+
model=self.model_name,
36+
max_tokens=self.max_tokens,
37+
temperature=self.temperature
38+
)
3039
if response["choices"][0]["message"]["content"].startswith("INSTRUCTIONS:"):
3140
return ("INSTRUCTIONS:","",response["choices"][0]["message"]["content"][14:])
3241
else:
@@ -35,17 +44,32 @@ def write_code(self, prompt):
3544

3645
def run(self, prompt):
3746
message=[{"role": "user", "content": str(prompt)}]
38-
response = completion(
39-
messages=message,
40-
stream=True,
41-
model=self.model_name,
42-
max_tokens=self.max_tokens,
43-
temperature=self.temperature
44-
)
45-
chat = ""
46-
for chunk in response:
47-
delta = chunk["choices"][0]["delta"]
48-
msg = delta.get("content", "")
49-
chat += msg
50-
return chat
47+
if self.modelrouter == "openrouter":
48+
response = openai.ChatCompletion.create(
49+
model="{}/{}".format(self.model_provider,self.model_name), # Optional (user controls the default)
50+
messages=message,
51+
stream=False,
52+
max_tokens=self.max_tokens,
53+
temperature=self.temperature,
54+
headers={
55+
"HTTP-Referer": "https://gpt-migrate.com",
56+
"X-Title": "GPT-Migrate",
57+
},
58+
)
59+
return response["choices"][0]["message"]["content"]
60+
else:
61+
response = completion(
62+
messages=message,
63+
stream=True,
64+
model=self.model_name,
65+
max_tokens=self.max_tokens,
66+
temperature=self.temperature
67+
)
68+
chat = ""
69+
for chunk in response:
70+
delta = chunk["choices"][0]["delta"]
71+
msg = delta.get("content", "")
72+
chat += msg
73+
return chat
74+
5175

gpt_migrate/main.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,8 @@ def __init__(self, sourcedir, targetdir, sourcelang, targetlang, sourceentry, so
3232
@app.command()
3333
def main(
3434
model: str = typer.Option("gpt-4-32k", help="Large Language Model to be used."),
35+
model_provider: str = typer.Option("openai", help="Model provider to be used."),
36+
modelrouter: str = typer.Option("openrouter", help="Model router to be used. Options are 'openrouter' or 'litellm'."),
3537
temperature: float = typer.Option(0, help="Temperature setting for the AI model."),
3638
sourcedir: str = typer.Option("../benchmarks/flask-nodejs/source", help="Source directory containing the code to be migrated."),
3739
sourcelang: str = typer.Option(None, help="Source language or framework of the code to be migrated."),
@@ -48,7 +50,9 @@ def main(
4850

4951
ai = AI(
5052
model=model,
53+
model_provider=model_provider,
5154
temperature=temperature,
55+
modelrouter=modelrouter
5256
)
5357

5458
sourcedir = os.path.abspath(sourcedir)
Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,3 @@
1-
rocket
2-
serde
3-
serde_json
1+
express
42
bcrypt
5-
rusqlite
6-
serde_json
3+
node-json-db
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
[{"signature": "readItems() -> object", "description": "Reads the items from a JSON file and returns them as an object."}, {"signature": "writeItems(groceryItems: object)", "description": "Writes the given grocery items object to a JSON file."}]

0 commit comments

Comments
 (0)