
Commit

phi2 update
Chirayu-Tripathi committed Apr 27, 2024
1 parent ea83abd commit ee0b10c
Showing 2 changed files with 45 additions and 18 deletions.
7 changes: 6 additions & 1 deletion nl2query/__init__.py
@@ -2,7 +2,12 @@
 
 import pandas as pd
 import torch
-from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
+from transformers import (
+    AutoModelForSeq2SeqLM,
+    AutoTokenizer,
+    AutoModelForCausalLM,
+    BitsAndBytesConfig,
+)
 from peft import PeftModel
 
 
56 changes: 39 additions & 17 deletions nl2query/mongoquery.py
@@ -17,10 +17,14 @@
 """
 
 
 import re
 import torch
-from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
+from transformers import (
+    AutoModelForSeq2SeqLM,
+    AutoTokenizer,
+    AutoModelForCausalLM,
+    BitsAndBytesConfig,
+)
 from peft import PeftModel
 from .base import QueryLanguage
 
@@ -117,14 +121,16 @@ def generate_query(
         )
         return query
 
+
 class MongoQueryPhi2(QueryLanguage):
     """Base QueryLanguage class extended to perform query generation for MongoDB using Phi2 model."""
 
     def __init__(
         self,
         path: str = "Chirayu/phi-2-mongodb",
     ):
         """Constructor for MongoQuery class"""
+
         # self.db_schema = db_schema
         self.adapter = path
         self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
@@ -138,21 +144,29 @@ def _load_model(self) -> object:
         self.tokenizer = AutoTokenizer.from_pretrained(base_model_id, use_fast=True)
         compute_dtype = getattr(torch, "float16")
         bnb_config = BitsAndBytesConfig(
-            load_in_4bit=True,
-            bnb_4bit_quant_type="nf4",
-            bnb_4bit_compute_dtype=compute_dtype,
-            bnb_4bit_use_double_quant=True,
+            load_in_4bit=True,
+            bnb_4bit_quant_type="nf4",
+            bnb_4bit_compute_dtype=compute_dtype,
+            bnb_4bit_use_double_quant=True,
         )
         model = AutoModelForCausalLM.from_pretrained(
-            base_model_id, trust_remote_code=True, quantization_config=bnb_config, revision="refs/pr/23", device_map={"": 0}, torch_dtype="auto", flash_attn=True, flash_rotary=True, fused_dense=True
+            base_model_id,
+            trust_remote_code=True,
+            quantization_config=bnb_config,
+            revision="refs/pr/23",
+            device_map={"": 0},
+            torch_dtype="auto",
+            flash_attn=True,
+            flash_rotary=True,
+            fused_dense=True,
         )
 
         self.model = PeftModel.from_pretrained(model, self.adapter).to(self.device)
         return self.model, self.tokenizer
 
     def preprocess(self, db_schema: str, text: str) -> str:
         """Pre-Process the db_schema by removing new line and extra spaces, and creates a prompt for the model."""
-        db_schema = db_schema.replace("\n","").replace(" ","")
+        db_schema = db_schema.replace("\n", "").replace(" ", "")
 
         prompt_template = f"""<s>
 Task Description:
@@ -180,23 +194,31 @@ def generate_query(
         """Execute the Phi2 to generate the query for the MongoDB framework."""
         prompt = self.preprocess(db_schema, textual_query)
         model_inputs = self.tokenizer(prompt, return_tensors="pt").to(self.device)
-        output = self.model.generate(**model_inputs, max_length = max_length, no_repeat_ngram_size = no_repeat_ngram_size, repetition_penalty = repetition_penalty, pad_token_id = self.tokenizer.eos_token_id, eos_token_id = self.tokenizer.eos_token_id)[0]
+        output = self.model.generate(
+            **model_inputs,
+            max_length=max_length,
+            no_repeat_ngram_size=no_repeat_ngram_size,
+            repetition_penalty=repetition_penalty,
+            pad_token_id=self.tokenizer.eos_token_id,
+            eos_token_id=self.tokenizer.eos_token_id,
+        )[0]
         query = self.tokenizer.decode(output, skip_special_tokens=False)
-        start_idx = query.index('Output')
+        start_idx = query.index("Output")
         try:
-            stop_idx = query.index('</s>')
+            stop_idx = query.index("</s>")
         except Exception as e:
             print(e)
             stop_idx = len(query)
-        return query[start_idx+8:stop_idx].strip()
+        return query[start_idx + 8 : stop_idx].strip()
 
+
 class MongoQuery:
     """Primary class to call the appropriate model"""
 
     def __new__(cls, model_type, **kwargs):
-        if model_type == 'T5':
+        if model_type == "T5":
             return MongoQueryT5(**kwargs)
-        elif model_type == 'Phi2':
+        elif model_type == "Phi2":
             return MongoQueryPhi2(**kwargs)
         else:
-            raise ValueError("Invalid model_type. Expected 'T5' or 'Phi2'")
+            raise ValueError("Invalid model_type. Expected 'T5' or 'Phi2'")
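For reference, a minimal usage sketch of the Phi2 path this commit touches. It assumes `MongoQuery` is imported from `nl2query.mongoquery`, that `generate_query` accepts `textual_query` and `db_schema` keyword arguments (the names are inferred from the method body above; the full signature is not shown in this diff), and that a CUDA device is available, since `_load_model` quantizes the base model to 4-bit and pins it to GPU 0:

# Hypothetical usage sketch -- argument names and the example schema are
# inferred from the diff above, not taken from the project's documentation.
from nl2query.mongoquery import MongoQuery

# Illustrative schema string; the exact format expected by preprocess()
# is not shown in this commit.
db_schema = """
{
    "collection": "movies",
    "fields": ["title", "year", "imdb_rating"]
}
"""

# MongoQuery.__new__ dispatches on model_type: "T5" or "Phi2".
# "Phi2" builds MongoQueryPhi2, which loads the Chirayu/phi-2-mongodb
# PEFT adapter on top of the quantized base model.
queryfier = MongoQuery("Phi2")

mongo_query = queryfier.generate_query(
    textual_query="Find all movies released after 2015, sorted by imdb_rating",
    db_schema=db_schema,
    # max_length, no_repeat_ngram_size, and repetition_penalty may also need
    # to be passed explicitly if generate_query defines no defaults for them.
)
print(mongo_query)

Because `_load_model` uses a 4-bit `BitsAndBytesConfig` with `device_map={"": 0}`, this path also requires the bitsandbytes package and a CUDA-capable GPU.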
