Skip to content

Commit

Permalink
phi2 update
Browse files Browse the repository at this point in the history
  • Loading branch information
Chirayu-Tripathi committed Apr 27, 2024
1 parent 38f39d2 commit ea83abd
Show file tree
Hide file tree
Showing 8 changed files with 276 additions and 21 deletions.
7 changes: 0 additions & 7 deletions CHANGELOG.rst

This file was deleted.

22 changes: 22 additions & 0 deletions CITATION.cff
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
# This CITATION.cff file was generated with cffinit.
# Visit https://bit.ly/cffinit to generate yours today!

cff-version: 1.2.0
title: nl2query
message: >-
If you use this software, please cite it using the
metadata from this file.
type: software
authors:
- given-names: Chirayu
family-names: Tripathi
email: chirayutripathi7@gmail.com
orcid: 'https://orcid.org/0000-0001-9495-0063'
repository-code: 'https://github.com/Chirayu-Tripathi/nl2query.git'
abstract: >-
Convert natural language text inputs to Pandas, MongoDB,
Kusto, and Cypher(Neo4j) queries. The models used are
fine-tuned versions of CodeT5+ 220m and Phi2 models.
license: MIT
version: 0.1.6
date-released: '2024-04-27'
110 changes: 106 additions & 4 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# nl2query

> Convert natural language text inputs to Pandas, MongoDB, Kusto, and Cypher(Neo4j) queries. The models used are fine-tuned versions of CodeT5+ 220m models.
> Convert natural language text inputs to Pandas, MongoDB, Kusto, and Cypher(Neo4j) queries. The models used are fine-tuned versions of the CodeT5+ 220m and Phi2 models.

[![Downloads](https://static.pepy.tech/badge/nl2query)](https://pepy.tech/project/nl2query)
Expand Down Expand Up @@ -37,22 +37,124 @@ queryfier.generate_query('''which cabinet has average age less than 21?''') #Gro
```

## 2. MongoDB Query
Suppose you want to convert the textual question to Mongo query, follow the code below
Suppose you want to convert a textual question to a MongoDB query; follow the code below.

### MongoDB query using CodeT5

The generate_query method takes a textual query and returns a MongoDB query. It also accepts optional parameters to control the generation process, such as num_beams, max_length, repetition_penalty, length_penalty, early_stopping, top_p, top_k, and num_return_sequences.

```py
from nl2query import MongoQuery
import pymongo # import if performing analysis using python client
keys = ['_id', 'index', 'passengerid', 'survived', 'Pclass', 'name', 'sex', 'age', 'sibsp', 'parch', 'ticket', 'fare', 'cabin', 'embarked'] #keys present in the collection to be queried.
queryfier = MongoQuery(keys, 'titanic')
queryfier = MongoQuery('T5', collection_keys = keys, collection_name = 'titanic')
queryfier.generate_query('''which pclass has the minimum average fare?''')

keys = ['_id', 'index', 'total_bill', 'tip', 'sex', 'smoker', 'day', 'time', 'size']
queryfier = MongoQuery(keys, 'tips')
queryfier = MongoQuery('T5', collection_keys = keys, collection_name = 'tips')
queryfier.generate_query('''find the day on which combined sales was highest''')

```
In the above code, the keys can be found by running the following snippet: `db.tips.find_one({}).keys()`

### MongoDB query using Phi2

The generate_query method takes a database schema and a textual query and returns a MongoDB query. It also accepts optional parameters to control the generation process, such as max_length, no_repeat_ngram_size, and repetition_penalty. *The Phi2 model performs better than the CodeT5+ model.*

```py
from nl2query import MongoQuery
schema = shipwreck = '''{
"collections": [
{
"name": "shipwrecks",
"indexes": [
{
"key": {
"_id": 1
}
},
{
"key": {
"feature_type": 1
}
},
{
"key": {
"chart": 1
}
},
{
"key": {
"latdec": 1,
"londec": 1
}
}
],
"uniqueIndexes": [],
"document": {
"properties": {
"_id": {
"bsonType": "string"
},
"recrd": {
"bsonType": "string"
},
"vesslterms": {
"bsonType": "string"
},
"feature_type": {
"bsonType": "string"
},
"chart": {
"bsonType": "string"
},
"latdec": {
"bsonType": "double"
},
"londec": {
"bsonType": "double"
},
"gp_quality": {
"bsonType": "string"
},
"depth": {
"bsonType": "string"
},
"sounding_type": {
"bsonType": "string"
},
"history": {
"bsonType": "string"
},
"quasou": {
"bsonType": "string"
},
"watlev": {
"bsonType": "string"
},
"coordinates": {
"bsonType": "array",
"items": {
"bsonType": "double"
}
}
}
}
}
],
"version": 1
}'''

queryfier = MongoQuery('Phi2')
text = 'Find the count of shipwrecks for each unique combination of "latdec" and "londec"'
queryfier.generate_query(schema, text, max_length = 1024)

text = 'Find the total count of shipwreck for each unique category of chart'
queryfier.generate_query(schema, text, max_length = 1024)


```


## 3. Kusto Query
Suppose you want to convert the textual question to Kusto query, follow the code below
Expand Down
34 changes: 34 additions & 0 deletions changelog.d/20240426_154856_chirayutripathi7.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
.. A new scriv changelog fragment.
..
.. Uncomment the header that is right (remove the leading dots).
..
.. Removed
.. -------
..
.. - A bullet item for the Removed category.
..
Added
-----

- Added support for the Phi2 model for MongoDB query conversion.
.. - A bullet item for the Added category.
..
.. Changed
.. -------
..
.. - A bullet item for the Changed category.
..
.. Deprecated
.. ----------
..
.. - A bullet item for the Deprecated category.
..
.. Fixed
.. -----
..
.. - A bullet item for the Fixed category.
..
.. Security
.. --------
..
.. - A bullet item for the Security category.
..
4 changes: 3 additions & 1 deletion nl2query/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@

import pandas as pd
import torch
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
from peft import PeftModel


from .cypherquery import CypherQuery
from .kustoquery import KustoQuery
Expand Down
113 changes: 108 additions & 5 deletions nl2query/mongoquery.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,31 @@
import re
"""
This module contains classes for generating MongoDB queries using different models.
import torch
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
Classes:
MongoQueryT5: Uses T5 model to generate MongoDB queries.
MongoQueryPhi2: Uses Phi2 model to generate MongoDB queries.
MongoQuery: Factory class to create an instance of either MongoQueryT5 or MongoQueryPhi2.
Each class has methods to load the model, preprocess the input, and generate the query.
Example:
To create an instance of MongoQueryT5:
>>> mq = MongoQuery('T5', collection_keys=['key1', 'key2'], collection_name='my_collection')
To generate a query:
>>> query = mq.generate_query('Find all documents where key1 is "value1"')
"""


import re
import torch
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
from peft import PeftModel
from .base import QueryLanguage


class MongoQuery(QueryLanguage):
class MongoQueryT5(QueryLanguage):
"""Base QueryLanguage class extended to perform query generation for MongoDB"""

def __init__(
Expand All @@ -26,7 +45,7 @@ def __init__(
# self.db = db

def _load_model(self) -> object:
"""Constructor for MongoQuery class"""
"""Helper function to load the model for MongoQuery class"""
model = AutoModelForSeq2SeqLM.from_pretrained(self.path)
self.tokenizer = AutoTokenizer.from_pretrained(self.path)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
Expand Down Expand Up @@ -97,3 +116,87 @@ def generate_query(
pattern, lambda x: {**self.keys_mapping, **upper_text}[x.group()], query
)
return query

class MongoQueryPhi2(QueryLanguage):
    """Generate MongoDB queries with a PEFT adapter on top of the Phi2 model.

    The constructor loads the tokenizer and the 4-bit quantized base model
    ("microsoft/phi-2") and attaches the adapter given by ``path``.
    """

    def __init__(
        self,
        path: str = "Chirayu/phi-2-mongodb",
    ):
        """Store the adapter id, pick the device and load the model.

        Args:
            path: Identifier of the PEFT adapter passed to
                ``PeftModel.from_pretrained``.
        """
        self.adapter = path
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self._load_model()

    def _load_model(self) -> tuple:
        """Load tokenizer, quantized base model and adapter.

        Sets ``self.tokenizer`` and ``self.model`` as a side effect.

        Returns:
            The ``(model, tokenizer)`` pair that was stored on the instance.
        """
        base_model_id = "microsoft/phi-2"
        self.tokenizer = AutoTokenizer.from_pretrained(base_model_id, use_fast=True)

        # 4-bit NF4 quantization with double quantization; float16 compute.
        bnb_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype=torch.float16,
            bnb_4bit_use_double_quant=True,
        )

        # NOTE(review): device_map={"": 0} presumably pins the whole model to
        # GPU 0 regardless of self.device -- confirm behaviour on CPU-only hosts.
        # NOTE(review): flash_attn / flash_rotary / fused_dense look specific to
        # the remote code at revision "refs/pr/23" -- confirm before bumping it.
        base_model = AutoModelForCausalLM.from_pretrained(
            base_model_id,
            trust_remote_code=True,
            quantization_config=bnb_config,
            revision="refs/pr/23",
            device_map={"": 0},
            torch_dtype="auto",
            flash_attn=True,
            flash_rotary=True,
            fused_dense=True,
        )

        self.model = PeftModel.from_pretrained(base_model, self.adapter).to(self.device)
        return self.model, self.tokenizer

    def preprocess(self, db_schema: str, text: str) -> str:
        """Build the model prompt from a schema and a natural-language request.

        Every newline and every space character is removed from ``db_schema``
        before it is embedded in the prompt.

        Args:
            db_schema: MongoDB schema description as a string.
            text: The natural-language instruction to convert.

        Returns:
            The full prompt, ending with the ``### Output:`` marker line.
        """
        compact_schema = db_schema.replace("\n", "").replace(" ", "")

        # NOTE(review): the exact layout (line breaks / leading whitespace) of
        # this prompt presumably has to match what the adapter was fine-tuned
        # on -- confirm against the training data before changing it.
        return (
            "<s>\n"
            "Task Description:\n"
            "Your task is to create a MongoDB query that accurately fulfills the provided Instruct while strictly adhering to the given MongoDB schema. Ensure that the query solely relies on keys and columns present in the schema. Minimize the usage of lookup operations wherever feasible to enhance query efficiency.\n"
            "MongoDB Schema:\n"
            f"{compact_schema}\n"
            "### Instruct:\n"
            f"{text}\n"
            "### Output:\n"
        )

    @staticmethod
    def _extract_query(generated_text: str, stop_markers: tuple) -> str:
        """Cut ``generated_text`` at the earliest stop marker and strip it."""
        end = len(generated_text)
        for marker in stop_markers:
            if not marker:
                # Tokenizers may expose no EOS string; nothing to search for.
                continue
            position = generated_text.find(marker)
            if position != -1:
                end = min(end, position)
        return generated_text[:end].strip()

    def generate_query(
        self,
        db_schema: str,
        textual_query: str,
        max_length: int = 1024,
        no_repeat_ngram_size: int = 10,
        repetition_penalty: float = 1.02,
    ) -> str:
        """Run Phi2 and return the generated MongoDB query.

        Args:
            db_schema: MongoDB schema description as a string.
            textual_query: Natural-language question to convert.
            max_length: Forwarded to ``model.generate`` as ``max_length``.
            no_repeat_ngram_size: Forwarded to ``model.generate``.
            repetition_penalty: Forwarded to ``model.generate``.

        Returns:
            The text produced after the prompt, truncated at the first
            ``</s>`` (or the tokenizer's EOS string) and stripped.
        """
        prompt = self.preprocess(db_schema, textual_query)
        model_inputs = self.tokenizer(prompt, return_tensors="pt").to(self.device)
        output = self.model.generate(
            **model_inputs,
            max_length=max_length,
            no_repeat_ngram_size=no_repeat_ngram_size,
            repetition_penalty=repetition_penalty,
            pad_token_id=self.tokenizer.eos_token_id,
            eos_token_id=self.tokenizer.eos_token_id,
        )[0]

        # Decode only the newly generated tokens. Searching the full decoded
        # text for the word "Output" breaks as soon as the question or schema
        # itself contains that word, and relies on a hard-coded offset.
        prompt_length = model_inputs["input_ids"].shape[1]
        generated_text = self.tokenizer.decode(
            output[prompt_length:], skip_special_tokens=False
        )
        return self._extract_query(
            generated_text, ("</s>", self.tokenizer.eos_token)
        )


class MongoQuery:
    """Factory entry point that builds the generator for the requested model.

    ``MongoQuery('T5', ...)`` yields a ``MongoQueryT5`` instance and
    ``MongoQuery('Phi2', ...)`` yields a ``MongoQueryPhi2`` instance; keyword
    arguments are forwarded unchanged to the selected class.
    """

    def __new__(cls, model_type, **kwargs):
        # Guard clauses: return as soon as a known model type is recognised.
        if model_type == 'T5':
            return MongoQueryT5(**kwargs)
        if model_type == 'Phi2':
            return MongoQueryPhi2(**kwargs)
        # Anything else is rejected.
        raise ValueError("Invalid model_type. Expected 'T5' or 'Phi2'")
5 changes: 3 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "nl2query"
version = "0.1.5"
version = "0.1.6"
description = ""
authors = ["Chirayu-Tripathi <chirayutripathi7@gmail.com>"]
readme = "README.md"
Expand All @@ -9,8 +9,9 @@ readme = "README.md"
python = "^3.8"
torch = { version = "^2.0.1", source = "pytorch" }
regex = "^2023.6.3"
transformers = "^4.31.0"
transformers = "^4.38.2"
pandas = "^1.5.3"
peft = "^0.5.0"

[tool.poetry.group.dev.dependencies]
pre-commit = "^3.3.3"
Expand Down
2 changes: 0 additions & 2 deletions tox.ini
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,11 @@ deps =
flake8
isort
mccabe
pylint
pytest
commands =
black --check nl2query
isort --check nl2query
flake8 nl2query --max-complexity 10
pylint nl2query
pytest .
coverage run --source=nl2query --branch -m pytest .
coverage report -m
Expand Down

0 comments on commit ea83abd

Please sign in to comment.