|
1 | 1 | import os
|
| 2 | +import difflib |
2 | 3 | from rich import print
|
3 | 4 |
|
4 | 5 | from typing import Optional, Dict, Any, List
|
@@ -32,9 +33,9 @@ def check_if_new_openai_version():
|
32 | 33 | from tenacity import retry, stop_after_attempt, wait_random
|
33 | 34 |
|
34 | 35 |
|
35 |
| -@retry(wait=wait_random(min=5, max=10), stop=stop_after_attempt(6)) |
| 36 | +@retry(wait=wait_random(min=5, max=10), stop=stop_after_attempt(3)) |
36 | 37 | def chat_completion_call(model, messages):
|
37 |
| - return openai.ChatCompletion.create(model=model, messages=messages) |
| 38 | + return openai.ChatCompletion.create(model=model, messages=messages, timeout=120, request_timeout=120) |
38 | 39 |
|
39 | 40 |
|
40 | 41 | class OpenAIChatRuntime(Runtime):
|
@@ -158,8 +159,27 @@ def record_to_record(
|
158 | 159 | ]
|
159 | 160 |
|
160 | 161 | completion_text = self.execute(messages)
|
| 162 | + |
| 163 | + field_schema = field_schema or {} |
| 164 | + if output_field_name in field_schema and field_schema[output_field_name]["type"] == "array": |
| 165 | + # expected output is one item from the array |
| 166 | + expected_items = field_schema[output_field_name]['items']['enum'] |
| 167 | + completion_text = self._match_items(completion_text, expected_items) |
| 168 | + |
161 | 169 | return {output_field_name: completion_text}
|
162 | 170 |
|
| 171 | + def _match_items(self, query: str, items: List[str]) -> str: |
| 172 | + # hard constraint: the item must be in the query |
| 173 | + filtered_items = [item for item in items if item in query] |
| 174 | + if not filtered_items: |
| 175 | + # make the best guess - find the most similar item to the query |
| 176 | + filtered_items = items |
| 177 | + |
| 178 | + # soft constraint: find the most similar item to the query |
| 179 | + scores = list(map(lambda item: difflib.SequenceMatcher(None, query, item).ratio(), filtered_items)) |
| 180 | + matched_item = filtered_items[scores.index(max(scores))] |
| 181 | + return matched_item |
| 182 | + |
163 | 183 |
|
164 | 184 | class OpenAIVisionRuntime(OpenAIChatRuntime):
|
165 | 185 | """
|
|
0 commit comments