diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index d601eaa..3d11d52 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,15 +1,7 @@ repos: -- repo: https://github.com/ambv/black - rev: 22.6.0 + - repo: https://github.com/astral-sh/ruff-pre-commit + rev: v0.2.1 hooks: - - id: black -- repo: https://github.com/PyCQA/flake8 - rev: 4.0.1 - hooks: - - id: flake8 - args: [--exclude=tests/*] -- repo: https://github.com/pycqa/isort - rev: 5.12.0 - hooks: - - id: isort - args: ["--profile", "black"] \ No newline at end of file + - id: ruff + args: [ --fix ] + - id: ruff-format diff --git a/CHANGES.md b/CHANGES.md index 85c4939..fb7bb76 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -1,5 +1,12 @@ # Change Log +## 0.0.6 (2024-02-15) + +* Use OpenAI v1.12.0. +* Update OpenAI API calls. +* Fix default value in greedy filter. +* Update tests. + ## 0.0.5 (2024-02-10) * Make Rasa an optional package. diff --git a/README.md b/README.md index a9f385c..f0c2a93 100644 --- a/README.md +++ b/README.md @@ -37,11 +37,11 @@ pip install -e . Once you have installed all dependencies you are ready to go with: ```python from nl2ltl import translate -from nl2ltl.engines.rasa.core import RasaEngine +from nl2ltl.engines.gpt.core import GPTEngine, Models from nl2ltl.filters.simple_filters import BasicFilter from nl2ltl.engines.utils import pretty -engine = RasaEngine() +engine = GPTEngine() filter = BasicFilter() utterance = "Eventually send me a Slack after receiving a Gmail" @@ -65,7 +65,8 @@ For instance, Rasa requires a `.tar.gz` format trained model in the - [x] [Rasa](https://rasa.com/) intents/entities classifier (to use Rasa, please install it with `pip install -e ".[rasa]"`) - [ ] [Watson Assistant](https://www.ibm.com/products/watson-assistant) intents/entities classifier -- Planned -To use GPT models you need to have the OPEN_API_KEY set as environment variable. 
To set it: +**NOTE**: To use OpenAI GPT models don't forget to add the `OPENAI_API_KEY` environment +variable with: ```bash export OPENAI_API_KEY=your_api_key ``` @@ -118,7 +119,11 @@ ltl_formulas = translate(utterance, engine=my_engine, filter=my_filter) Contributions are welcome! Here's how to set up the development environment: - set up your preferred virtualenv environment - clone the repo: `git clone https://github.com/IBM/nl2ltl.git && cd nl2ltl` +- install dependencies: `pip install -e .` - install dev dependencies: `pip install -e ".[dev]"` +- install pre-commit: `pre-commit install` +- sign-off your commits using the `-s` flag in the commit message to be compliant with +the [DCO](https://developercertificate.org/) ## Tests diff --git a/nl2ltl/engines/gpt/core.py b/nl2ltl/engines/gpt/core.py index 4cce8d0..4068b57 100644 --- a/nl2ltl/engines/gpt/core.py +++ b/nl2ltl/engines/gpt/core.py @@ -6,12 +6,11 @@ """ import json -import os from enum import Enum from pathlib import Path from typing import Dict, Set -import openai +from openai import OpenAI from pylogics.syntax.base import Formula from nl2ltl.engines.base import Engine @@ -19,7 +18,11 @@ from nl2ltl.engines.gpt.output import GPTOutput, parse_gpt_output, parse_gpt_result from nl2ltl.filters.base import Filter -openai.api_key = os.getenv("OPENAI_API_KEY") +try: + client = OpenAI() +except Exception: + client = None + engine_root = ENGINE_ROOT DATA_DIR = engine_root / "data" PROMPT_PATH = engine_root / DATA_DIR / "prompt.json" @@ -75,7 +78,7 @@ def _check_consistency(self) -> None: def __check_openai_version(self): """Check that the GPT tool is at the right version.""" - is_right_version = openai.__version__ == "1.12.0" + is_right_version = client._version == "1.12.0" if not is_right_version: raise Exception( "OpenAI needs to be at version 1.12.0. 
" @@ -149,7 +152,7 @@ def _process_utterance( query = f"NL: {utterance}\n" messages = [{"role": "user", "content": prompt + query}] if operation_mode == OperationModes.CHAT.value: - prediction = openai.chat.completions.create( + prediction = client.chat.completions.create( model=model, messages=messages, temperature=temperature, @@ -160,7 +163,7 @@ def _process_utterance( stop=["\n\n"], ) else: - prediction = openai.completions.create( + prediction = client.completions.create( model=model, prompt=messages[0]["content"], temperature=temperature, diff --git a/nl2ltl/engines/gpt/output.py b/nl2ltl/engines/gpt/output.py index 5377933..b35ba31 100644 --- a/nl2ltl/engines/gpt/output.py +++ b/nl2ltl/engines/gpt/output.py @@ -40,7 +40,7 @@ def pattern(self) -> str: Match, re.search( "PATTERN: (.*)\n", - self.output["choices"][0]["message"]["content"], + self.output.choices[0].message.content, ), ).group(1) ) @@ -48,7 +48,7 @@ def pattern(self) -> str: return str( cast( Match, - re.search("PATTERN: (.*)\n", self.output["choices"][0]["text"]), + re.search("PATTERN: (.*)\n", self.output.choices[0].text), ).group(1) ) @@ -61,15 +61,13 @@ def entities(self) -> Tuple[str]: return tuple( cast( Match, - re.search("SYMBOLS: (.*)", self.output["choices"][0]["message"]["content"]), + re.search("SYMBOLS: (.*)", self.output.choices[0].message.content), ) .group(1) .split(", ") ) else: - return tuple( - cast(Match, re.search("SYMBOLS: (.*)", self.output["choices"][0]["text"])).group(1).split(", ") - ) + return tuple(cast(Match, re.search("SYMBOLS: (.*)", self.output.choices[0].text)).group(1).split(", ")) def parse_gpt_output(gpt_output: dict, operation_mode: str) -> GPTOutput: diff --git a/nl2ltl/filters/simple_filters.py b/nl2ltl/filters/simple_filters.py index 506b40d..65ae13c 100644 --- a/nl2ltl/filters/simple_filters.py +++ b/nl2ltl/filters/simple_filters.py @@ -3,6 +3,7 @@ from pylogics.syntax.base import Formula +from nl2ltl.declare.base import Template from nl2ltl.filters.base 
import Filter from nl2ltl.filters.utils.conflicts import conflicts from nl2ltl.filters.utils.subsumptions import subsumptions @@ -44,7 +45,7 @@ def enforce(output: Dict[Formula, float], entities: Dict[str, float], **kwargs) """ result_set = set() - highest_scoring_formula = max(output, key=output.get) + highest_scoring_formula = max(output, key=output.get, default=Template) formula_conflicts = conflicts(highest_scoring_formula) formula_subsumptions = subsumptions(highest_scoring_formula) diff --git a/pyproject.toml b/pyproject.toml index 05f2528..5a5e947 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "nl2ltl" -version = "0.0.5" +version = "0.0.6" license = {file = "LICENSE"} authors = [ { name = "Francesco Fuggitti", email = "francesco.fuggitti@gmail.com" }, @@ -34,7 +34,7 @@ classifiers = [ dependencies = [ "pylogics", - "openai" + "openai==1.12.0" ] [project.optional-dependencies] diff --git a/tests/conftest.py b/tests/conftest.py index 0e2ff51..ca9ca9b 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -13,8 +13,8 @@ class UtterancesFixtures: utterances = [ - "whenever I get a Slack, send a Gmail", - "Invite Sales employees to Thursday's meeting", - "If a new Eventbrite is created, alert me through Slack", - "send me a Slack whenever I get a Gmail", + "whenever I get a Slack, send a Gmail.", + "Invite Sales employees.", + "If a new Eventbrite is created, alert me through Slack.", + "send me a Slack whenever I get a Gmail.", ] diff --git a/tests/test_gpt.py b/tests/test_gpt.py index 643bfa6..841e13a 100644 --- a/tests/test_gpt.py +++ b/tests/test_gpt.py @@ -4,7 +4,7 @@ import pytest from nl2ltl import translate -from nl2ltl.engines.gpt.core import GPTEngine, Models +from nl2ltl.engines.gpt.core import GPTEngine from nl2ltl.filters.simple_filters import BasicFilter, GreedyFilter from .conftest import UtterancesFixtures @@ -18,7 +18,7 @@ def setup_class(cls): """Setup any state specific to the execution of the given class 
(which usually contains tests). """ - cls.gpt_engine = GPTEngine(model=Models.GPT35_INSTRUCT.value) + cls.gpt_engine = GPTEngine() cls.basic_filter = BasicFilter() cls.greedy_filter = GreedyFilter() diff --git a/tox.ini b/tox.ini index 7c3febc..906f448 100644 --- a/tox.ini +++ b/tox.ini @@ -35,7 +35,7 @@ commands = ruff check . [testenv:ruff-check-apply] skip_install = True -deps = ruff==0.1.9r +deps = ruff==0.1.9 commands = ruff check --fix --show-fixes . [testenv:ruff-format]