
Commit 407c903

Merge pull request #29 from IBM/dev
hotfix: release 0.0.6
2 parents (06a6945 + 0791620), commit 407c903

10 files changed: +44 -38 lines changed

.pre-commit-config.yaml

+5 -13

@@ -1,15 +1,7 @@
 repos:
-  - repo: https://github.com/ambv/black
-    rev: 22.6.0
+  - repo: https://github.com/astral-sh/ruff-pre-commit
+    rev: v0.2.1
     hooks:
-      - id: black
-  - repo: https://github.com/PyCQA/flake8
-    rev: 4.0.1
-    hooks:
-      - id: flake8
-        args: [--exclude=tests/*]
-  - repo: https://github.com/pycqa/isort
-    rev: 5.12.0
-    hooks:
-      - id: isort
-        args: ["--profile", "black"]
+      - id: ruff
+        args: [ --fix ]
+      - id: ruff-format

CHANGES.md

+7

@@ -1,5 +1,12 @@
 # Change Log
 
+## 0.0.6 (2024-02-15)
+
+* Use OpenAI v1.12.0.
+* Update OpenAI API calls.
+* Fix default value in greedy filter.
+* Update tests.
+
 ## 0.0.5 (2024-02-10)
 
 * Make Rasa an optional package.

README.md

+8 -3

@@ -37,11 +37,11 @@ pip install -e .
 Once you have installed all dependencies you are ready to go with:
 ```python
 from nl2ltl import translate
-from nl2ltl.engines.rasa.core import RasaEngine
+from nl2ltl.engines.gpt.core import GPTEngine, Models
 from nl2ltl.filters.simple_filters import BasicFilter
 from nl2ltl.engines.utils import pretty
 
-engine = RasaEngine()
+engine = GPTEngine()
 filter = BasicFilter()
 utterance = "Eventually send me a Slack after receiving a Gmail"
 
@@ -65,7 +65,8 @@ For instance, Rasa requires a `.tar.gz` format trained model in the
 - [x] [Rasa](https://rasa.com/) intents/entities classifier (to use Rasa, please install it with `pip install -e ".[rasa]"`)
 - [ ] [Watson Assistant](https://www.ibm.com/products/watson-assistant) intents/entities classifier -- Planned
 
-To use GPT models you need to have the OPEN_API_KEY set as environment variable. To set it:
+**NOTE**: To use OpenAI GPT models don't forget to add the `OPEN_API_KEY` environment
+variable with:
 ```bash
 export OPENAI_API_KEY=your_api_key
 ```
@@ -118,7 +119,11 @@ ltl_formulas = translate(utterance, engine=my_engine, filter=my_filter)
 Contributions are welcome! Here's how to set up the development environment:
 - set up your preferred virtualenv environment
 - clone the repo: `git clone https://github.com/IBM/nl2ltl.git && cd nl2ltl`
+- install dependencies: `pip install -e .`
 - install dev dependencies: `pip install -e ".[dev]"`
+- install pre-commit: `pre-commit install`
+- sign-off your commits using the `-s` flag in the commit message to be compliant with
+  the [DCO](https://developercertificate.org/)
 
 ## Tests
 
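For context, the updated quickstart corresponds to an end-to-end flow roughly like the sketch below. It assumes the package is installed and a valid `OPENAI_API_KEY` is exported; the `translate(...)` call follows the signature shown further down in the README, and the output formulas themselves are only illustrative.

```python
# Sketch of the GPT-based quickstart the README now documents.
from nl2ltl import translate
from nl2ltl.engines.gpt.core import GPTEngine
from nl2ltl.filters.simple_filters import BasicFilter
from nl2ltl.engines.utils import pretty

engine = GPTEngine()  # default model, as in the updated tests
filter = BasicFilter()
utterance = "Eventually send me a Slack after receiving a Gmail"

# translate(...) returns candidate LTL formulas produced by the engine and
# post-processed by the filter; pretty(...) prints them in readable form.
ltl_formulas = translate(utterance, engine=engine, filter=filter)
pretty(ltl_formulas)
```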

nl2ltl/engines/gpt/core.py

+9 -6

@@ -6,20 +6,23 @@
 
 """
 import json
-import os
 from enum import Enum
 from pathlib import Path
 from typing import Dict, Set
 
-import openai
+from openai import OpenAI
 from pylogics.syntax.base import Formula
 
 from nl2ltl.engines.base import Engine
 from nl2ltl.engines.gpt import ENGINE_ROOT
 from nl2ltl.engines.gpt.output import GPTOutput, parse_gpt_output, parse_gpt_result
 from nl2ltl.filters.base import Filter
 
-openai.api_key = os.getenv("OPENAI_API_KEY")
+try:
+    client = OpenAI()
+except Exception:
+    client = None
+
 engine_root = ENGINE_ROOT
 DATA_DIR = engine_root / "data"
 PROMPT_PATH = engine_root / DATA_DIR / "prompt.json"
@@ -75,7 +78,7 @@ def _check_consistency(self) -> None:
 
     def __check_openai_version(self):
         """Check that the GPT tool is at the right version."""
-        is_right_version = openai.__version__ == "1.12.0"
+        is_right_version = client._version == "1.12.0"
         if not is_right_version:
             raise Exception(
                 "OpenAI needs to be at version 1.12.0. "
@@ -149,7 +152,7 @@ def _process_utterance(
     query = f"NL: {utterance}\n"
     messages = [{"role": "user", "content": prompt + query}]
     if operation_mode == OperationModes.CHAT.value:
-        prediction = openai.chat.completions.create(
+        prediction = client.chat.completions.create(
             model=model,
             messages=messages,
             temperature=temperature,
@@ -160,7 +163,7 @@
             stop=["\n\n"],
         )
     else:
-        prediction = openai.completions.create(
+        prediction = client.completions.create(
            model=model,
             prompt=messages[0]["content"],
             temperature=temperature,
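The substantive change in this file is the migration from the module-level `openai` API (openai < 1.0) to the client-based API of openai 1.x. A minimal sketch of the new calling convention, assuming `OPENAI_API_KEY` is exported and using an illustrative model name:

```python
from openai import OpenAI

# OpenAI() reads OPENAI_API_KEY from the environment; constructing it without
# a key raises an error, which is why core.py wraps it in try/except above.
client = OpenAI()

response = client.chat.completions.create(
    model="gpt-3.5-turbo",  # illustrative model name
    messages=[{"role": "user", "content": "NL: send me a Slack whenever I get a Gmail\n"}],
    temperature=0,
)

# In 1.x the response is a typed object, so fields are accessed as attributes
# rather than dict keys (see the matching change in output.py).
print(response.choices[0].message.content)
```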

nl2ltl/engines/gpt/output.py

+4 -6

@@ -40,15 +40,15 @@ def pattern(self) -> str:
                     Match,
                     re.search(
                         "PATTERN: (.*)\n",
-                        self.output["choices"][0]["message"]["content"],
+                        self.output.choices[0].message.content,
                     ),
                 ).group(1)
             )
         else:
             return str(
                 cast(
                     Match,
-                    re.search("PATTERN: (.*)\n", self.output["choices"][0]["text"]),
+                    re.search("PATTERN: (.*)\n", self.output.choices[0].text),
                 ).group(1)
             )
 
@@ -61,15 +61,13 @@ def entities(self) -> Tuple[str]:
             return tuple(
                 cast(
                     Match,
-                    re.search("SYMBOLS: (.*)", self.output["choices"][0]["message"]["content"]),
+                    re.search("SYMBOLS: (.*)", self.output.choices[0].message.content),
                 )
                 .group(1)
                 .split(", ")
             )
         else:
-            return tuple(
-                cast(Match, re.search("SYMBOLS: (.*)", self.output["choices"][0]["text"])).group(1).split(", ")
-            )
+            return tuple(cast(Match, re.search("SYMBOLS: (.*)", self.output.choices[0].text)).group(1).split(", "))
 
 
 def parse_gpt_output(gpt_output: dict, operation_mode: str) -> GPTOutput:
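Because the 1.x SDK returns typed objects instead of dicts, the parsing here switches from `output["choices"][0][...]` to attribute access; the regex extraction itself is unchanged. A small self-contained illustration of what that extraction does (the completion text is made up):

```python
import re
from typing import Match, cast

# Hypothetical model completion in the PATTERN/SYMBOLS format the prompt requests.
content = "PATTERN: Response\nSYMBOLS: gmail, slack\n"

pattern = cast(Match, re.search("PATTERN: (.*)\n", content)).group(1)
symbols = tuple(cast(Match, re.search("SYMBOLS: (.*)", content)).group(1).split(", "))

print(pattern)  # Response
print(symbols)  # ('gmail', 'slack')
```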

nl2ltl/filters/simple_filters.py

+2 -1

@@ -3,6 +3,7 @@
 
 from pylogics.syntax.base import Formula
 
+from nl2ltl.declare.base import Template
 from nl2ltl.filters.base import Filter
 from nl2ltl.filters.utils.conflicts import conflicts
 from nl2ltl.filters.utils.subsumptions import subsumptions
@@ -44,7 +45,7 @@ def enforce(output: Dict[Formula, float], entities: Dict[str, float], **kwargs)
         """
         result_set = set()
 
-        highest_scoring_formula = max(output, key=output.get)
+        highest_scoring_formula = max(output, key=output.get, default=Template)
         formula_conflicts = conflicts(highest_scoring_formula)
         formula_subsumptions = subsumptions(highest_scoring_formula)
 
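The greedy-filter fix relies on the `default=` argument of the built-in `max()`, which avoids a `ValueError` when the engine returns no scored formulas. A minimal illustration (plain strings stand in for `Formula` objects, and `None` is used as the fallback here purely for demonstration):

```python
scores = {}  # e.g., the engine produced no candidate formulas

# max(scores, key=scores.get) would raise:
#   ValueError: max() arg is an empty sequence
best = max(scores, key=scores.get, default=None)
print(best)  # None

scores = {"RespondedExistence(slack, gmail)": 0.9, "Response(gmail, slack)": 0.7}
best = max(scores, key=scores.get, default=None)
print(best)  # RespondedExistence(slack, gmail)
```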

pyproject.toml

+2 -2

@@ -1,6 +1,6 @@
 [project]
 name = "nl2ltl"
-version = "0.0.5"
+version = "0.0.6"
 license = {file = "LICENSE"}
 authors = [
     { name = "Francesco Fuggitti", email = "francesco.fuggitti@gmail.com" },
@@ -34,7 +34,7 @@ classifiers = [
 
 dependencies = [
     "pylogics",
-    "openai"
+    "openai==1.12.0"
 ]
 
 [project.optional-dependencies]

tests/conftest.py

+4 -4

@@ -13,8 +13,8 @@
 
 class UtterancesFixtures:
     utterances = [
-        "whenever I get a Slack, send a Gmail",
-        "Invite Sales employees to Thursday's meeting",
-        "If a new Eventbrite is created, alert me through Slack",
-        "send me a Slack whenever I get a Gmail",
+        "whenever I get a Slack, send a Gmail.",
+        "Invite Sales employees.",
+        "If a new Eventbrite is created, alert me through Slack.",
+        "send me a Slack whenever I get a Gmail.",
     ]

tests/test_gpt.py

+2 -2

@@ -4,7 +4,7 @@
 import pytest
 
 from nl2ltl import translate
-from nl2ltl.engines.gpt.core import GPTEngine, Models
+from nl2ltl.engines.gpt.core import GPTEngine
 from nl2ltl.filters.simple_filters import BasicFilter, GreedyFilter
 
 from .conftest import UtterancesFixtures
@@ -18,7 +18,7 @@ def setup_class(cls):
         """Setup any state specific to the execution of the given class (which
         usually contains tests).
         """
-        cls.gpt_engine = GPTEngine(model=Models.GPT35_INSTRUCT.value)
+        cls.gpt_engine = GPTEngine()
         cls.basic_filter = BasicFilter()
         cls.greedy_filter = GreedyFilter()
 

tox.ini

+1 -1

@@ -35,7 +35,7 @@ commands = ruff check .
 
 [testenv:ruff-check-apply]
 skip_install = True
-deps = ruff==0.1.9r
+deps = ruff==0.1.9
 commands = ruff check --fix --show-fixes .
 
 [testenv:ruff-format]
