
Merge pull request #29 from IBM/dev
hotfix: release 0.0.6
francescofuggitti committed Feb 15, 2024
2 parents 06a6945 + 0791620 commit 407c903
Showing 10 changed files with 44 additions and 38 deletions.
18 changes: 5 additions & 13 deletions .pre-commit-config.yaml
@@ -1,15 +1,7 @@
repos:
-  - repo: https://github.com/ambv/black
-    rev: 22.6.0
+  - repo: https://github.com/astral-sh/ruff-pre-commit
+    rev: v0.2.1
    hooks:
-      - id: black
-  - repo: https://github.com/PyCQA/flake8
-    rev: 4.0.1
-    hooks:
-      - id: flake8
-        args: [--exclude=tests/*]
-  - repo: https://github.com/pycqa/isort
-    rev: 5.12.0
-    hooks:
-      - id: isort
-        args: ["--profile", "black"]
+      - id: ruff
+        args: [ --fix ]
+      - id: ruff-format
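The updated config replaces the separate black, flake8, and isort hooks with ruff's linter (run with `--fix`) and formatter. Outside pre-commit, the same two checks can be run by hand — a minimal sketch assuming `ruff` is installed locally, with placeholder paths, not a script from this repository:

```python
# Hypothetical helper mirroring the two new hooks: `ruff` (lint + autofix)
# and `ruff-format`. The target paths are placeholders.
import subprocess


def run_ruff(paths=("nl2ltl", "tests")) -> None:
    """Run the same checks the updated pre-commit hooks run."""
    subprocess.run(["ruff", "check", "--fix", *paths], check=True)  # id: ruff
    subprocess.run(["ruff", "format", *paths], check=True)          # id: ruff-format


if __name__ == "__main__":
    run_ruff()
```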
7 changes: 7 additions & 0 deletions CHANGES.md
@@ -1,5 +1,12 @@
# Change Log

## 0.0.6 (2024-02-15)

* Use OpenAI v1.12.0.
* Update OpenAI API calls.
* Fix default value in greedy filter.
* Update tests.

## 0.0.5 (2024-02-10)

* Make Rasa an optional package.
11 changes: 8 additions & 3 deletions README.md
Expand Up @@ -37,11 +37,11 @@ pip install -e .
Once you have installed all dependencies you are ready to go with:
```python
from nl2ltl import translate
-from nl2ltl.engines.rasa.core import RasaEngine
+from nl2ltl.engines.gpt.core import GPTEngine, Models
from nl2ltl.filters.simple_filters import BasicFilter
from nl2ltl.engines.utils import pretty

-engine = RasaEngine()
+engine = GPTEngine()
filter = BasicFilter()
utterance = "Eventually send me a Slack after receiving a Gmail"

Expand All @@ -65,7 +65,8 @@ For instance, Rasa requires a `.tar.gz` format trained model in the
- [x] [Rasa](https://rasa.com/) intents/entities classifier (to use Rasa, please install it with `pip install -e ".[rasa]"`)
- [ ] [Watson Assistant](https://www.ibm.com/products/watson-assistant) intents/entities classifier -- Planned

-To use GPT models you need to have the OPEN_API_KEY set as environment variable. To set it:
+**NOTE**: To use OpenAI GPT models don't forget to add the `OPEN_API_KEY` environment
+variable with:
```bash
export OPENAI_API_KEY=your_api_key
```
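Putting the updated README pieces together, a minimal end-to-end call could look like the sketch below. It assumes `OPENAI_API_KEY` is exported as shown above; the `pretty` helper is taken from the README imports and its exact signature is not verified here.

```python
# A minimal sketch of the documented GPT flow, not verbatim README content.
from nl2ltl import translate
from nl2ltl.engines.gpt.core import GPTEngine
from nl2ltl.filters.simple_filters import BasicFilter
from nl2ltl.engines.utils import pretty

engine = GPTEngine()      # GPT-based engine, as in the updated snippet
filter = BasicFilter()
utterance = "Eventually send me a Slack after receiving a Gmail"

# translate() maps the utterance to formulas via the chosen engine and filter
ltl_formulas = translate(utterance, engine=engine, filter=filter)
pretty(ltl_formulas)
```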
@@ -118,7 +119,11 @@ ltl_formulas = translate(utterance, engine=my_engine, filter=my_filter)
Contributions are welcome! Here's how to set up the development environment:
- set up your preferred virtualenv environment
- clone the repo: `git clone https://github.com/IBM/nl2ltl.git && cd nl2ltl`
- install dependencies: `pip install -e .`
- install dev dependencies: `pip install -e ".[dev]"`
- install pre-commit: `pre-commit install`
- sign-off your commits using the `-s` flag in the commit message to be compliant with
the [DCO](https://developercertificate.org/)

## Tests

15 changes: 9 additions & 6 deletions nl2ltl/engines/gpt/core.py
@@ -6,20 +6,23 @@
"""
import json
import os
from enum import Enum
from pathlib import Path
from typing import Dict, Set

-import openai
+from openai import OpenAI
from pylogics.syntax.base import Formula

from nl2ltl.engines.base import Engine
from nl2ltl.engines.gpt import ENGINE_ROOT
from nl2ltl.engines.gpt.output import GPTOutput, parse_gpt_output, parse_gpt_result
from nl2ltl.filters.base import Filter

-openai.api_key = os.getenv("OPENAI_API_KEY")
+try:
+    client = OpenAI()
+except Exception:
+    client = None

engine_root = ENGINE_ROOT
DATA_DIR = engine_root / "data"
PROMPT_PATH = engine_root / DATA_DIR / "prompt.json"
@@ -75,7 +78,7 @@ def _check_consistency(self) -> None:

def __check_openai_version(self):
"""Check that the GPT tool is at the right version."""
-is_right_version = openai.__version__ == "1.12.0"
+is_right_version = client._version == "1.12.0"
if not is_right_version:
raise Exception(
"OpenAI needs to be at version 1.12.0. "
@@ -149,7 +152,7 @@ def _process_utterance(
query = f"NL: {utterance}\n"
messages = [{"role": "user", "content": prompt + query}]
if operation_mode == OperationModes.CHAT.value:
-prediction = openai.chat.completions.create(
+prediction = client.chat.completions.create(
model=model,
messages=messages,
temperature=temperature,
Expand All @@ -160,7 +163,7 @@ def _process_utterance(
stop=["\n\n"],
)
else:
-prediction = openai.completions.create(
+prediction = client.completions.create(
model=model,
prompt=messages[0]["content"],
temperature=temperature,
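The substantive change in this file is the move from configuring the module-level `openai.api_key` to instantiating an `OpenAI` client (wrapped in `try/except`, presumably so importing the engine does not fail when no API key is configured). A minimal sketch of the v1 call pattern the diff adopts, with a placeholder model and prompt rather than the repository's values:

```python
# openai>=1.x client-style call, mirroring the updated _process_utterance path.
from openai import OpenAI

client = OpenAI()  # reads OPENAI_API_KEY from the environment

completion = client.chat.completions.create(
    model="gpt-3.5-turbo",  # placeholder model name
    messages=[{"role": "user", "content": "NL: Eventually send me a Slack\n"}],
    temperature=0,
    stop=["\n\n"],
)
print(completion.choices[0].message.content)  # attribute access on v1 responses
```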
10 changes: 4 additions & 6 deletions nl2ltl/engines/gpt/output.py
@@ -40,15 +40,15 @@ def pattern(self) -> str:
Match,
re.search(
"PATTERN: (.*)\n",
self.output["choices"][0]["message"]["content"],
self.output.choices[0].message.content,
),
).group(1)
)
else:
return str(
cast(
Match,
re.search("PATTERN: (.*)\n", self.output["choices"][0]["text"]),
re.search("PATTERN: (.*)\n", self.output.choices[0].text),
).group(1)
)

Expand All @@ -61,15 +61,13 @@ def entities(self) -> Tuple[str]:
return tuple(
cast(
Match,
re.search("SYMBOLS: (.*)", self.output["choices"][0]["message"]["content"]),
re.search("SYMBOLS: (.*)", self.output.choices[0].message.content),
)
.group(1)
.split(", ")
)
else:
-return tuple(
-cast(Match, re.search("SYMBOLS: (.*)", self.output["choices"][0]["text"])).group(1).split(", ")
-)
+return tuple(cast(Match, re.search("SYMBOLS: (.*)", self.output.choices[0].text)).group(1).split(", "))


def parse_gpt_output(gpt_output: dict, operation_mode: str) -> GPTOutput:
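Besides switching from dict-style access (`output["choices"][0]...`) to attribute access on the v1 response objects, the parsing logic is unchanged. A standalone sketch of that parsing on a plain string, with made-up `PATTERN`/`SYMBOLS` values rather than a real model response:

```python
import re

content = "PATTERN: existence\nSYMBOLS: slack, gmail\n\n"  # illustrative only

pattern = re.search(r"PATTERN: (.*)\n", content).group(1)
symbols = tuple(re.search(r"SYMBOLS: (.*)", content).group(1).split(", "))

assert pattern == "existence"
assert symbols == ("slack", "gmail")
```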
3 changes: 2 additions & 1 deletion nl2ltl/filters/simple_filters.py
Expand Up @@ -3,6 +3,7 @@

from pylogics.syntax.base import Formula

from nl2ltl.declare.base import Template
from nl2ltl.filters.base import Filter
from nl2ltl.filters.utils.conflicts import conflicts
from nl2ltl.filters.utils.subsumptions import subsumptions
@@ -44,7 +45,7 @@ def enforce(output: Dict[Formula, float], entities: Dict[str, float], **kwargs)
"""
result_set = set()

-highest_scoring_formula = max(output, key=output.get)
+highest_scoring_formula = max(output, key=output.get, default=Template)
formula_conflicts = conflicts(highest_scoring_formula)
formula_subsumptions = subsumptions(highest_scoring_formula)

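The greedy-filter fix adds a `default=` argument to `max()`, which otherwise raises `ValueError` when the engine output is empty. A quick illustration with a plain dict standing in for the `Dict[Formula, float]` output (the repository uses `default=Template`; `None` is used here only to keep the sketch self-contained):

```python
scores = {}  # e.g. the engine produced no scored formulas

try:
    max(scores, key=scores.get)  # pre-fix behaviour
except ValueError as err:
    print(f"without default: {err}")  # max() arg is an empty sequence

best = max(scores, key=scores.get, default=None)  # post-fix behaviour
print(best)  # None, instead of an exception
```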
4 changes: 2 additions & 2 deletions pyproject.toml
@@ -1,6 +1,6 @@
[project]
name = "nl2ltl"
version = "0.0.5"
version = "0.0.6"
license = {file = "LICENSE"}
authors = [
{ name = "Francesco Fuggitti", email = "francesco.fuggitti@gmail.com" },
@@ -34,7 +34,7 @@ classifiers = [

dependencies = [
"pylogics",
"openai"
"openai==1.12.0"
]

[project.optional-dependencies]
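Pinning `openai==1.12.0` matches the engine's runtime version check shown earlier in this commit. A small guard in the same spirit (the wording and placement are illustrative, not the repository's exact check):

```python
import openai

EXPECTED = "1.12.0"
if openai.__version__ != EXPECTED:
    raise RuntimeError(
        f"nl2ltl is tested against openai=={EXPECTED}, found {openai.__version__}"
    )
```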
8 changes: 4 additions & 4 deletions tests/conftest.py
Expand Up @@ -13,8 +13,8 @@

class UtterancesFixtures:
utterances = [
"whenever I get a Slack, send a Gmail",
"Invite Sales employees to Thursday's meeting",
"If a new Eventbrite is created, alert me through Slack",
"send me a Slack whenever I get a Gmail",
"whenever I get a Slack, send a Gmail.",
"Invite Sales employees.",
"If a new Eventbrite is created, alert me through Slack.",
"send me a Slack whenever I get a Gmail.",
]
4 changes: 2 additions & 2 deletions tests/test_gpt.py
Expand Up @@ -4,7 +4,7 @@
import pytest

from nl2ltl import translate
-from nl2ltl.engines.gpt.core import GPTEngine, Models
+from nl2ltl.engines.gpt.core import GPTEngine
from nl2ltl.filters.simple_filters import BasicFilter, GreedyFilter

from .conftest import UtterancesFixtures
@@ -18,7 +18,7 @@ def setup_class(cls):
"""Setup any state specific to the execution of the given class (which
usually contains tests).
"""
-cls.gpt_engine = GPTEngine(model=Models.GPT35_INSTRUCT.value)
+cls.gpt_engine = GPTEngine()
cls.basic_filter = BasicFilter()
cls.greedy_filter = GreedyFilter()

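The test now relies on `GPTEngine`'s default model instead of passing `Models.GPT35_INSTRUCT`. Both constructions that appear in this commit, side by side; whether the explicit form remains supported is assumed from the README, which still imports `Models`:

```python
from nl2ltl.engines.gpt.core import GPTEngine, Models

engine_default = GPTEngine()                                    # updated test setup
engine_explicit = GPTEngine(model=Models.GPT35_INSTRUCT.value)  # previous test setup
```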
2 changes: 1 addition & 1 deletion tox.ini
Expand Up @@ -35,7 +35,7 @@ commands = ruff check .

[testenv:ruff-check-apply]
skip_install = True
-deps = ruff==0.1.9r
+deps = ruff==0.1.9
commands = ruff check --fix --show-fixes .

[testenv:ruff-format]
