Commit
Updated eval + llm_planner for o1 models
maxzuo committed Oct 16, 2024
1 parent b974c18 commit 5ef6e3f
Showing 2 changed files with 58 additions and 20 deletions.
26 changes: 22 additions & 4 deletions evaluate.py
@@ -22,7 +22,7 @@

 HF_USER_TOKEN = os.getenv("HF_USER_TOKEN")
 VALIDATE = os.getenv("VALIDATE", "Validate")
-
+DOWNWARD = os.getenv("DOWNWARD", "downward")
 
 def signal_handler(signum, frame):
     raise TimeoutError("Timed out")
@@ -140,6 +140,7 @@ def load_planner(config: Mapping[str, dict[str, str]]) -> llmp.Planner:
     elif config["model"]["type"] == "hf":
         llm = llmp.VLLMPlanner(
             config["model"]["model_name"],
+            lora=config["model"].get("lora"),
             tokenizer=config["model"]["tokenizer_name"],
             trust_remote_code=True,
             dtype=torch.bfloat16,
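For reference, a config that exercises the new optional field could look like the sketch below; only the "type", "model_name", "tokenizer_name", and "lora" keys are grounded in this diff, and every value is a placeholder.

# Hypothetical config mapping passed to load_planner
config = {
    "model": {
        "type": "hf",
        "model_name": "meta-llama/Llama-2-7b-hf",      # placeholder model
        "tokenizer_name": "meta-llama/Llama-2-7b-hf",  # placeholder tokenizer
        "lora": "/path/to/lora_adapter",               # omit to run the base model
    }
}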
@@ -274,6 +275,8 @@ def clean(pddl_str: str) -> str:
 def validate(
     pddl_str: str,
     domain_str: str,
+    fast_downward: str = DOWNWARD,
+    **downward_args,
 ) -> bool:
     """Validate a PDDL problem as "solvable".
@@ -292,6 +295,17 @@ def validate(
         valid = downward.validate(domain_str, pddl_str, plan, VALIDATE)
     except (LarkError, AttributeError, ValueError):
         pass
+    except (oracle.DomainNotSupportedError, NotImplementedError):
+        try:
+            plan_str, _ = downward.plan(
+                domain_str,
+                pddl_str,
+                downward=fast_downward,
+                **downward_args,
+            )
+            valid = downward.validate(domain_str, pddl_str, plan_str, VALIDATE)
+        except:
+            pass
 
     return valid
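The new except branch gives validate a fallback: when the oracle planner does not support a domain, a plan is sought with Fast Downward and then checked with VAL. The added parameters also let callers forward planner options, roughly as in this hypothetical call (values are illustrative):

# Hypothetical call site for the extended signature
valid = validate(
    llm_problem_pddl,
    domain_str,
    fast_downward="downward",  # path or command for Fast Downward
    alias="lama-first",        # forwarded to downward.plan via **downward_args
)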

@@ -325,7 +339,11 @@ def equivalence(

     return (
         parseable,
-        validate(llm_problem_pddl, domains[graphs["llm_problem_graph"].domain]),
+        validate(
+            llm_problem_pddl,
+            domains[graphs["llm_problem_graph"].domain],
+            alias="lama-first",
+        ),
         full_equivalence(
             graphs["problem_graph"],
             graphs["llm_problem_graph"],
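Here "lama-first" names a stock Fast Downward alias that returns the first satisficing plan the LAMA configuration finds, trading optimality for speed; any plan at all settles solvability, so the cheapest search suffices. Run directly, the alias amounts to roughly this (file names are placeholders):

import subprocess

# Roughly the search configuration downward.plan requests above
subprocess.run(
    ["./fast-downward.py", "--alias", "lama-first", "domain.pddl", "problem.pddl"],
    check=True,
)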
@@ -465,7 +483,7 @@ def generate_openai(
                     problem_id,
                     config_str,
                     model_name,
-                    llm_problem_pddl,
+                    llm_problem_pddl[0],
                 ),
             )
             pbar.update()
@@ -574,7 +592,7 @@ def _evaluate(args):
         return problem_id, config_str, model_name, (None, None, None)
     except Exception as e:
         equivalent = None
-        raise e
+        print("ERROR", e, problem_id, llm_problem_pddl)
     cursor.close()
     return problem_id, config_str, model_name, (parseable, valid, equivalent)

52 changes: 36 additions & 16 deletions llm_planner.py
@@ -10,6 +10,7 @@
 )
 
 from vllm import LLM, RequestOutput, SamplingParams
+from vllm.lora.request import LoRARequest
 
 
 class PlanningProblem:
@@ -193,14 +194,15 @@ def plan_chat(
 class VLLMPlanner(Planner):
     """A class for planning using VLLM models."""
 
-    def __init__(self, model_name: str, **kwargs):
+    def __init__(self, model_name: str, lora: str | None = None, **kwargs):
         """Initializes a new VLLMPlanner.
 
         Args:
             model_name (str): The name of the model to be used.
             kwargs: Additional keyword arguments to be passed to the model.
         """
-        self.model = LLM(model_name, **kwargs)
+        self.lora = LoRARequest(lora, 1, lora) if lora else None
+        self.model = LLM(model_name, enable_lora=bool(lora), **kwargs)
         self.tokenizer = self.model.get_tokenizer()

     def plan_chat(
@@ -236,6 +238,7 @@ def plan_chat(
             encoded,
             params,
             use_tqdm=False,
+            lora_request=self.lora,
         )
 
         return [output.outputs[0].text for output in outputs]
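vLLM only serves adapters when the engine is built with enable_lora=True, and the adapter is then attached per generate call, which is what the two changes above do together. A minimal sketch of that flow, with placeholder model and adapter paths:

from vllm import LLM, SamplingParams
from vllm.lora.request import LoRARequest

# The engine must be constructed with enable_lora=True to accept adapters.
llm = LLM("meta-llama/Llama-2-7b-hf", enable_lora=True)

# LoRARequest(name, unique_int_id, path); name and path coincide here,
# mirroring LoRARequest(lora, 1, lora) in the diff.
adapter = LoRARequest("my_adapter", 1, "/path/to/lora_adapter")

outputs = llm.generate(
    ["(define (problem blocks-4) ..."],  # placeholder prompt
    SamplingParams(max_tokens=64),
    lora_request=adapter,  # pass None to generate with the base weights
)
print(outputs[0].outputs[0].text)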
@@ -254,6 +257,7 @@ def __init__(self, model_name: str, **kwargs):
"""
self.client = OpenAI(**kwargs)
self.model_name = model_name
self.is_o1 = model_name.startswith("o1")

     def _plan_chat(
         self,
@@ -273,20 +277,36 @@ def _plan_chat(
             str: The message completion.
         """

-        return (
-            self.client.chat.completions.create(
-                model=self.model_name,
-                messages=messages,
-                frequency_penalty=kwargs.get("frequency_penalty", None),
-                max_tokens=max_new_tokens,
-                n=1,
-                presence_penalty=kwargs.get("presence_penalty", None),
-                temperature=kwargs.get("temperature", 0.0),
-                top_p=kwargs.get("top_p", None),
-            )
-            .choices[0]
-            .message.content
-        )
+        if self.is_o1:
+            return (
+                self.client.chat.completions.create(
+                    model=self.model_name,
+                    messages=messages,
+                    frequency_penalty=kwargs.get("frequency_penalty", None),
+                    max_completion_tokens=max_new_tokens,
+                    n=1,
+                    presence_penalty=kwargs.get("presence_penalty", None),
+                    temperature=kwargs.get("temperature", 0.0),
+                    top_p=kwargs.get("top_p", None),
+                )
+                .choices[0]
+                .message.content
+            )
+        else:
+            return (
+                self.client.chat.completions.create(
+                    model=self.model_name,
+                    messages=messages,
+                    frequency_penalty=kwargs.get("frequency_penalty", None),
+                    max_tokens=max_new_tokens,
+                    n=1,
+                    presence_penalty=kwargs.get("presence_penalty", None),
+                    temperature=kwargs.get("temperature", 0.0),
+                    top_p=kwargs.get("top_p", None),
+                )
+                .choices[0]
+                .message.content
+            )

     def plan_chat(
         self,
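The branch above exists because o1-series chat completions reject max_tokens and take max_completion_tokens instead, a cap that covers hidden reasoning tokens as well as visible output. A minimal sketch of the o1 path (assumes OPENAI_API_KEY is set; model and prompt are placeholders):

from openai import OpenAI

client = OpenAI()

# o1 models take max_completion_tokens where older models took max_tokens.
response = client.chat.completions.create(
    model="o1-mini",
    messages=[{"role": "user", "content": "Stack block A on block B."}],
    max_completion_tokens=256,
)
print(response.choices[0].message.content)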
@@ -313,4 +333,4 @@ def plan_chat(
                 **kwargs,
             )
             for message in messages
-        ]
\ No newline at end of file
+        ]
