Skip to content

Commit

Permalink
Merge pull request #7 from LemurPwned/feat/query-optim
Browse files Browse the repository at this point in the history
Feat/query optim
  • Loading branch information
LemurPwned authored Jan 4, 2025
2 parents e7a096b + a73879b commit 72518be
Show file tree
Hide file tree
Showing 9 changed files with 101 additions and 22 deletions.
9 changes: 6 additions & 3 deletions .github/workflows/release.yml
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,9 @@ jobs:
permissions:
contents: read
packages: write
env:
REGISTRY: ghcr.io/lemurpwned
RELEASE_VERSION: ${{ needs.build.outputs.version }}

steps:
- name: Checkout repository
Expand All @@ -88,8 +91,8 @@ jobs:
uses: docker/build-push-action@v5
with:
context: .
file: ./docker/Dockerfile
push: true
tags: |
latest
${{ needs.build.outputs.version }}
${{ env.REGISTRY }}/cypher-shell:latest
${{ env.REGISTRY }}/cypher-shell:${RELEASE_VERSION}
8 changes: 6 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ python -m cypher_shell --help
or

```bash
python -m cypher_shell --cfg-path configs/movies.yaml
python -m cypher_shell run --cfg-path configs/movies.yaml
```

where `configs/movies.yaml` is a configuration file that contains the node and relationship descriptions.
Expand All @@ -31,9 +31,13 @@ You need to set the `.env` file with your OpenAI API key and Neo4j credentials.
You can also run the tool using Docker.

```bash
docker run --env .env -it ghcr.io/lemurpwned/cypher-shell:latest python3 -m cypher_shell --cfg-path configs/movies.yaml
docker run --env .env -it ghcr.io/lemurpwned/cypher-shell:latest python3 -m cypher_shell run --cfg-path configs/movies.yaml
```

### Run query without LLM

Just preface the query with: `cs:` and the query will not be rewritten by the llm.

## Notes:

- sometimes getting the schema automatically is better than providing it manually.
49 changes: 42 additions & 7 deletions cypher_shell/__main__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import json
import logging
import os

Expand All @@ -16,6 +17,7 @@
from rich.console import Console

from .agent import CypherFlowSimple
from .optimizer import Optimizer
from .query_runner import QueryRunner
from .utils import get_logger

Expand All @@ -38,13 +40,7 @@ def validate(self, document):
raise ValidationError(message="Query cannot be empty", cursor_position=0)


@app.command(help="Run a Cypher shell")
def run(
cfg_path: str | None = typer.Option(default=None, help="Path to the .yaml configuration file"),
env_path: str | None = typer.Option(default=None, help="Path to the .env file"),
debug: bool = typer.Option(default=False, help="Enable debug mode"),
):
load_dotenv(env_path, override=True)
def load_cfg(cfg_path: str | None = None) -> dict:
cfg = {}
if cfg_path is None:
console.print(
Expand All @@ -59,6 +55,45 @@ def run(
assert (
"node_descriptions" in cfg and "relationship_descriptions" in cfg
), "Both node_descriptions and relationship_descriptions must be provided in the configuration file"
return cfg


@app.command(help="Optimize a Cypher query. Based on the query logs.")
def optimize(
log_path: str,
cfg_path: str | None = typer.Option(default=None, help="Path to the .yaml configuration file"),
env_path: str | None = typer.Option(default=None, help="Path to the .env file"),
debug: bool = typer.Option(default=False, help="Enable debug mode"),
min_timing: float = typer.Option(default=15.0, help="Minimum timing to consider for optimization"),
):
load_dotenv(env_path, override=True)
cfg = load_cfg(cfg_path)
optimizer = Optimizer(cfg)
# read up all the lines in the file
with open(log_path) as f:
all_queries = filter(
lambda x: x["timing"] > min_timing,
[json.loads(line) for line in f if json.loads(line)],
)
for query in sorted(
all_queries,
key=lambda x: x["timing"],
reverse=True,
):
logger.info(f"Optimizing query: {query}")
resp = optimizer.optimize_query(query)
console.print(resp)


@app.command(help="Run a Cypher shell")
def run(
cfg_path: str | None = typer.Option(default=None, help="Path to the .yaml configuration file"),
env_path: str | None = typer.Option(default=None, help="Path to the .env file"),
debug: bool = typer.Option(default=False, help="Enable debug mode"),
):
load_dotenv(env_path, override=True)
cfg = load_cfg(cfg_path)

if debug:
logger.setLevel(logging.DEBUG)
query_runner = QueryRunner(
Expand Down
7 changes: 3 additions & 4 deletions cypher_shell/memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,9 +49,9 @@ def add(self, message: MemoryMessage):
self.memory.append(message)

def add_user_result(self, user_query: str, machine_query: str, result: str, timing: float = -1):
self.memory.append(MemoryMessage(source="user", type="result", content=result))
self.memory.append(MemoryMessage(source="user", type="query", content=user_query))
self.memory.append(MemoryMessage(source="system", type="query", content=machine_query))
self.add(MemoryMessage(source="user", type="result", content=result))
self.add(MemoryMessage(source="user", type="query", content=user_query))
self.add(MemoryMessage(source="system", type="query", content=machine_query))
if self.track_user_queries:
self.user_queries[user_query] = result
if self.write_to_file:
Expand All @@ -62,7 +62,6 @@ def add_user_result(self, user_query: str, machine_query: str, result: str, timi
"cypher_query": machine_query,
"timing": timing,
},
indent=4,
)
)

Expand Down
33 changes: 33 additions & 0 deletions cypher_shell/optimizer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
from typing import Any

from rich.markdown import Markdown
from swarm import Swarm

from .agent import Agent
from .prompts.optim import OPTIMIZATION_PROMPT_GENERAL


class Optimizer:
def __init__(self, cfg: dict):
self.optimizer_agent = Agent(
name="Cypher Query Optimizer",
model="gpt-4o-mini",
temperature=0.0,
instructions=OPTIMIZATION_PROMPT_GENERAL,
)
self.client = Swarm()

def __call__(self, *args: Any, **kwds: Any) -> Any:
pass

def optimize_query(self, query: dict[str, Any]) -> str:
msg = self.client.run(
agent=self.optimizer_agent,
messages=[
{
"role": "user",
"content": f"Query: {query}. Reread the query carefully: {query}",
}
],
)
return Markdown(msg.messages[-1]["content"])
2 changes: 2 additions & 0 deletions cypher_shell/prompts/cypher.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ def _query_run(session: neo4j.Session, query: str):


def get_nodes_schema(session: neo4j.Session):
logger.info("Retrieving node schema")
schema_query = """CALL db.schema.visualization()"""
results = session.run(schema_query)
data = results.data()
Expand All @@ -71,6 +72,7 @@ def node_and_rel_labels(session: neo4j.Session):


def get_properties(session: neo4j.Session):
logger.info("Retrieving node and relationship properties")
node_results = session.run("CALL db.schema.nodeTypeProperties()")
rel_results = session.run("CALL db.schema.relTypeProperties()")
node_data = node_results.data()
Expand Down
8 changes: 8 additions & 0 deletions cypher_shell/prompts/optim.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
OPTIMIZATION_PROMPT_GENERAL = """
You're an expert at optimizing Cypher queries.
You're given a Cypher query and the user query.
You need to optimize the query to be more efficient.
If you think adding an index will help, suggest it.
"""
5 changes: 0 additions & 5 deletions docker/Dockerfile
Original file line number Diff line number Diff line change
@@ -1,11 +1,6 @@
FROM python:3.12-slim

WORKDIR /app
RUN apt-get update && \
apt-get install -y git && \
apt-get clean && \
rm -rf /var/lib/apt/lists/*

COPY . .

RUN python3 -m pip install -e .
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
name = "cypher-shell"
description = "Cypher Shell -- a shell for querying Neo4j with Cypher"
url = "https://github.com/LemurPwned/cypher-shell"
version = "0.3"
version = "0.4"
authors = [
{ name = "LemurPwned", email = "lemurpwned@gmail.com" }
]
Expand Down

0 comments on commit 72518be

Please sign in to comment.