Merge branch 'main' into 'bugfix/cleanup'

maxzuo committed Jun 23, 2024
2 parents dfa8113 + b03835f commit 2b6518e
Showing 13 changed files with 1,234 additions and 933 deletions.
12 changes: 6 additions & 6 deletions evaluate.py
@@ -14,7 +14,7 @@
import tqdm
import torch

from planetarium import pddl, graph, metric, oracle
from planetarium import builder, graph, metric, oracle
import llm_planner as llmp

from utils import apply_template
@@ -197,18 +197,18 @@ def result():

try:
# try to parse the LLM output
llm_problem_graph = pddl.build(llm_problem_pddl)
llm_problem_graph = builder.build(llm_problem_pddl)
parseable = True

# reduce and further validate the LLM output
oracle.reduce(llm_problem_graph.decompose()[0], validate=True)
oracle.reduce(llm_problem_graph.decompose()[1], validate=True)
valid = True

problem_graph = pddl.build(problem_pddl)
problem_graph = builder.build(problem_pddl)
init, _ = problem_graph.decompose()

if len(llm_problem_graph._constants) != len(problem_graph._constants):
if len(llm_problem_graph.constants) != len(problem_graph.constants):
resolved = True
return result()

@@ -253,8 +253,8 @@ def full_equivalence(
bool: True if the scene graphs are equivalent, False otherwise.
"""
return metric.equals(
oracle.fully_specify(source),
oracle.fully_specify(target),
oracle.fully_specify(source, return_reduced=True),
oracle.fully_specify(target, return_reduced=True),
is_placeholder=is_placeholder,
)

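For context, here is a minimal sketch (not part of the commit) of how the renamed builder module and the updated oracle/metric calls in this file fit together. It assumes llm_problem_pddl and problem_pddl are PDDL problem strings already in scope, that full_equivalence compares the two parsed problem graphs, and that the is_placeholder flag is passed explicitly; all variable names are illustrative.

from planetarium import builder, metric, oracle

llm_problem_graph = builder.build(llm_problem_pddl)             # parse the LLM output into a problem graph
oracle.reduce(llm_problem_graph.decompose()[1], validate=True)  # reduce and validate the goal scene graph

problem_graph = builder.build(problem_pddl)                     # ground-truth problem graph

equivalent = metric.equals(
    oracle.fully_specify(llm_problem_graph, return_reduced=True),
    oracle.fully_specify(problem_graph, return_reduced=True),
    is_placeholder=False,  # assumed value; full_equivalence takes this as a parameter
)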
61 changes: 58 additions & 3 deletions finetune.py
@@ -1,5 +1,7 @@
from collections import defaultdict
from functools import partial
import os
import sqlite3
import yaml

import dotenv
@@ -10,6 +12,7 @@
from torch import nn

import bitsandbytes as bnb
from datasets import Dataset
from peft import LoraConfig, get_peft_model
from transformers import (
AutoTokenizer,
@@ -23,14 +26,62 @@
import tqdm as tqdm

import llm_planner as llmp
from utils import apply_template, load_dataset, strip
from utils import apply_template

from accelerate import Accelerator


HF_USER_TOKEN = os.getenv("HF_USER_TOKEN")


def load_dataset(config: dict) -> dict[str, Dataset]:
    """Load the dataset from the configuration.
    Args:
        config (dict): The dataset configuration.
    Returns:
        dict[str, Dataset]: The loaded dataset.
    """
    with open(config["splits_path"], "r") as f:
        split_ids_cfg = yaml.safe_load(f)

    splits: set[str] = config.get("splits", {}).keys()
    dataset = {split: defaultdict(list) for split in splits}

    # Connect to database
    conn = sqlite3.connect(config["database_path"])
    c = conn.cursor()

    # load domains
    domains = {}
    c.execute("SELECT name, domain_pddl FROM domains")
    for domain_name, domain_pddl in c.fetchall():
        domains[domain_name] = domain_pddl

    # load problems
    for split in splits:
        queries = []
        split_keys: list[str] = config["splits"][split]
        for split_key in split_keys:
            split_ids = split_ids_cfg
            for key in split_key:
                split_ids = split_ids[key]

            c.execute(
                f"SELECT domain, problem_pddl, natural_language FROM problems WHERE id in ({', '.join(['?'] * len(split_ids))})",
                split_ids,
            )
            queries.extend(c.fetchall())

        for domain, problem_pddl, natural_language in queries:
            dataset[split]["domain"].append(domains[domain])
            dataset[split]["problem"].append(problem_pddl)
            dataset[split]["natural_language"].append(natural_language)

    return {s: Dataset.from_dict(d, split=s) for s, d in dataset.items()}
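A hypothetical configuration for the new load_dataset helper, shown only to illustrate the shape it expects: the database path, the splits YAML path, and per-split lists of key paths that are walked through the YAML file. The file names and split keys below are made up, not taken from the repository.

config = {
    "database_path": "dataset.db",    # SQLite database with `domains` and `problems` tables
    "splits_path": "splits.yaml",     # YAML file holding nested lists of problem ids
    "splits": {
        # each split is a list of key paths into the YAML file,
        # e.g. ["blocksworld", "train"] resolves split_ids_cfg["blocksworld"]["train"]
        "train": [["blocksworld", "train"], ["gripper", "train"]],
        "test": [["blocksworld", "test"], ["gripper", "test"]],
    },
}

datasets = load_dataset(config)       # {"train": Dataset, "test": Dataset}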


def find_all_linear_names(
    model: nn.Module,
    bits: int | None = None,
@@ -62,6 +113,10 @@ def find_all_linear_names(
    return list(lora_module_names)


def strip(text: str, bos_token: str, eos_token: str) -> str:
    return text.removeprefix(bos_token) + eos_token
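The strip helper above simply drops the tokenizer's beginning-of-sequence token from the front of the text and appends the end-of-sequence token. A quick illustration with assumed Llama-style special tokens:

strip("<s>(define (problem ...) ...)", bos_token="<s>", eos_token="</s>")
# -> "(define (problem ...) ...)</s>"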


def preprocess(
    tokenizer: PreTrainedTokenizer,
    examples,
@@ -130,7 +185,7 @@ def load_model(config: dict) -> tuple[PreTrainedTokenizer, PreTrainedModel]:
)
else:
bnb_config = None

device_index = Accelerator().process_index
device_map = {"": device_index}
model = AutoModelForCausalLM.from_pretrained(
@@ -139,7 +194,7 @@ def load_model(config: dict) -> tuple[PreTrainedTokenizer, PreTrainedModel]:
token=HF_USER_TOKEN,
torch_dtype=torch.bfloat16,
quantization_config=bnb_config,
device_map=device_map
device_map=device_map,
)

lora_config = LoraConfig(
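A side note on the device_map change above: mapping the empty module name "" to the current Accelerate process index pins the entire model to that process's GPU, which is the usual pattern for multi-GPU fine-tuning of a quantized model where device_map="auto" would instead shard it across devices. A minimal sketch, assuming one process is launched per GPU:

from accelerate import Accelerator

device_index = Accelerator().process_index   # 0, 1, ... one index per launched process
device_map = {"": device_index}              # "" is the root module, so the whole model lands on one device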
File renamed without changes.