From d7c2e23d477e330b758f229a72b4fb9c37a3b756 Mon Sep 17 00:00:00 2001 From: nik Date: Thu, 30 Nov 2023 09:20:48 +0000 Subject: [PATCH] Reformat the code --- adala/runtimes/_openai.py | 18 ++++++++++++++---- 1 file changed, 14 insertions(+), 4 deletions(-) diff --git a/adala/runtimes/_openai.py b/adala/runtimes/_openai.py index 25daf39d..bb9d137c 100644 --- a/adala/runtimes/_openai.py +++ b/adala/runtimes/_openai.py @@ -35,7 +35,9 @@ def check_if_new_openai_version(): @retry(wait=wait_random(min=5, max=10), stop=stop_after_attempt(3)) def chat_completion_call(model, messages): - return openai.ChatCompletion.create(model=model, messages=messages, timeout=120, request_timeout=120) + return openai.ChatCompletion.create( + model=model, messages=messages, timeout=120, request_timeout=120 + ) class OpenAIChatRuntime(Runtime): @@ -161,9 +163,12 @@ def record_to_record( completion_text = self.execute(messages) field_schema = field_schema or {} - if output_field_name in field_schema and field_schema[output_field_name]["type"] == "array": + if ( + output_field_name in field_schema + and field_schema[output_field_name]["type"] == "array" + ): # expected output is one item from the array - expected_items = field_schema[output_field_name]['items']['enum'] + expected_items = field_schema[output_field_name]["items"]["enum"] completion_text = self._match_items(completion_text, expected_items) return {output_field_name: completion_text} @@ -176,7 +181,12 @@ def _match_items(self, query: str, items: List[str]) -> str: filtered_items = items # soft constraint: find the most similar item to the query - scores = list(map(lambda item: difflib.SequenceMatcher(None, query, item).ratio(), filtered_items)) + scores = list( + map( + lambda item: difflib.SequenceMatcher(None, query, item).ratio(), + filtered_items, + ) + ) matched_item = filtered_items[scores.index(max(scores))] return matched_item