6
6
7
7
"""
8
8
import json
9
- import os
10
9
from enum import Enum
11
10
from pathlib import Path
12
11
from typing import Dict , Set
13
12
14
- import openai
13
+ from openai import OpenAI
15
14
from pylogics .syntax .base import Formula
16
15
17
16
from nl2ltl .engines .base import Engine
18
17
from nl2ltl .engines .gpt import ENGINE_ROOT
19
18
from nl2ltl .engines .gpt .output import GPTOutput , parse_gpt_output , parse_gpt_result
20
19
from nl2ltl .filters .base import Filter
21
20
22
- openai .api_key = os .getenv ("OPENAI_API_KEY" )
21
+ try :
22
+ client = OpenAI ()
23
+ except Exception :
24
+ client = None
25
+
23
26
engine_root = ENGINE_ROOT
24
27
DATA_DIR = engine_root / "data"
25
28
PROMPT_PATH = engine_root / DATA_DIR / "prompt.json"
@@ -75,7 +78,7 @@ def _check_consistency(self) -> None:
75
78
76
79
def __check_openai_version (self ):
77
80
"""Check that the GPT tool is at the right version."""
78
- is_right_version = openai . __version__ == "1.12.0"
81
+ is_right_version = client . _version == "1.12.0"
79
82
if not is_right_version :
80
83
raise Exception (
81
84
"OpenAI needs to be at version 1.12.0. "
@@ -149,7 +152,7 @@ def _process_utterance(
149
152
query = f"NL: { utterance } \n "
150
153
messages = [{"role" : "user" , "content" : prompt + query }]
151
154
if operation_mode == OperationModes .CHAT .value :
152
- prediction = openai .chat .completions .create (
155
+ prediction = client .chat .completions .create (
153
156
model = model ,
154
157
messages = messages ,
155
158
temperature = temperature ,
@@ -160,7 +163,7 @@ def _process_utterance(
160
163
stop = ["\n \n " ],
161
164
)
162
165
else :
163
- prediction = openai .completions .create (
166
+ prediction = client .completions .create (
164
167
model = model ,
165
168
prompt = messages [0 ]["content" ],
166
169
temperature = temperature ,
0 commit comments