Skip to content

Commit e73e317

Browse files
author
nik
committed
Add labels match to OpenAIChatRuntime, RAG with only errors
1 parent 05e5ba4 commit e73e317

File tree

2 files changed

+31
-7
lines changed

2 files changed

+31
-7
lines changed

adala/runtimes/_openai.py

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import os
2+
import difflib
23
from rich import print
34

45
from typing import Optional, Dict, Any, List
@@ -32,9 +33,9 @@ def check_if_new_openai_version():
3233
from tenacity import retry, stop_after_attempt, wait_random
3334

3435

35-
@retry(wait=wait_random(min=5, max=10), stop=stop_after_attempt(6))
36+
@retry(wait=wait_random(min=5, max=10), stop=stop_after_attempt(3))
3637
def chat_completion_call(model, messages):
37-
return openai.ChatCompletion.create(model=model, messages=messages)
38+
return openai.ChatCompletion.create(model=model, messages=messages, timeout=120, request_timeout=120)
3839

3940

4041
class OpenAIChatRuntime(Runtime):
@@ -158,8 +159,27 @@ def record_to_record(
158159
]
159160

160161
completion_text = self.execute(messages)
162+
163+
field_schema = field_schema or {}
164+
if output_field_name in field_schema and field_schema[output_field_name]["type"] == "array":
165+
# expected output is one item from the array
166+
expected_items = field_schema[output_field_name]['items']['enum']
167+
completion_text = self._match_items(completion_text, expected_items)
168+
161169
return {output_field_name: completion_text}
162170

171+
def _match_items(self, query: str, items: List[str]) -> str:
172+
# hard constraint: the item must be in the query
173+
filtered_items = [item for item in items if item in query]
174+
if not filtered_items:
175+
# make the best guess - find the most similar item to the query
176+
filtered_items = items
177+
178+
# soft constraint: find the most similar item to the query
179+
scores = list(map(lambda item: difflib.SequenceMatcher(None, query, item).ratio(), filtered_items))
180+
matched_item = filtered_items[scores.index(max(scores))]
181+
return matched_item
182+
163183

164184
class OpenAIVisionRuntime(OpenAIChatRuntime):
165185
"""

adala/skills/collection/rag.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@ class RAGSkill(TransformSkill):
4646
output_template: str = "{rag}"
4747
num_results: int = 1
4848
memory: Memory = None
49+
only_errors: bool = True
4950

5051
@model_validator(mode="after")
5152
def init_memory(self):
@@ -119,13 +120,16 @@ def improve(
119120
runtime: Runtime to use for generation (not used).
120121
"""
121122

122-
error_indices = feedback.match[
123-
(feedback.match.fillna(True) == False).any(axis=1)
124-
].index
125-
inputs = predictions.loc[error_indices]
123+
if self.only_errors:
124+
indices = feedback.match[
125+
(feedback.match.fillna(True) == False).any(axis=1)
126+
].index
127+
else:
128+
indices = feedback.match.index
129+
inputs = predictions.loc[indices]
126130
input_strings = inputs.apply(
127131
lambda r: self.input_template.format(**r), axis=1
128132
).tolist()
129-
fb = feedback.feedback.loc[error_indices].rename(columns=lambda c: f"{c}__fb")
133+
fb = feedback.feedback.loc[indices].rename(columns=lambda c: f"{c}__fb")
130134
inputs = inputs.join(fb)
131135
self.memory.remember_many(input_strings, inputs.to_dict(orient="records"))

0 commit comments

Comments
 (0)