41 changes: 35 additions & 6 deletions src/lighteval/main_inspect.py
@@ -22,6 +22,7 @@

import logging
from collections import defaultdict
from datetime import datetime
from typing import Literal

import requests
@@ -211,6 +212,20 @@ def eval( # noqa C901
models: Annotated[list[str], Argument(help="Models to evaluate")],
tasks: Annotated[str, Argument(help="Tasks to evaluate")],
# model arguments
model_base_url: Annotated[
str | None,
Option(
help="Base URL for communicating with the model API.",
rich_help_panel=HELP_PANEL_NAME_1,
),
] = None,
model_roles: Annotated[
str | None,
Option(
help="Model creation args (as a dictionary or as a path to a JSON or YAML config file)",
rich_help_panel=HELP_PANEL_NAME_1,
),
] = None,
max_tokens: Annotated[
int | None,
Option(
@@ -382,9 +397,9 @@ def eval( # noqa C901
] = None,
# Logging parameters
log_dir: Annotated[
str,
str | None,
Option(help="Log directory to use, will be created if it doesn't exist", rich_help_panel=HELP_PANEL_NAME_4),
] = "lighteval-logs",
] = None,
log_dir_allow_dirty: Annotated[
bool, Option(help="Allow dirty log directory", rich_help_panel=HELP_PANEL_NAME_4)
] = True,
@@ -396,6 +411,10 @@ def eval( # noqa C901
str | None,
Option(help="Bundle directory to use, will be created if it doesn't exist", rich_help_panel=HELP_PANEL_NAME_4),
] = None,
bundle_overwrite: Annotated[
bool,
Option(help="Overwrite bundle directory if it exists", rich_help_panel=HELP_PANEL_NAME_4),
] = True,
repo_id: Annotated[
str | None,
Option(help="Repository ID to use, will be created if it doesn't exist", rich_help_panel=HELP_PANEL_NAME_4),
@@ -428,6 +447,9 @@ def eval( # noqa C901
providers = _get_huggingface_providers(model)
models = [f"{model.replace(':all', '')}:{provider}" for provider in providers]

if log_dir is None:
log_dir = f"lighteval-logs-{datetime.now().strftime('%Y%m%d%H%M%S')}"

success, logs = inspect_ai_eval_set(
inspect_ai_tasks,
model=models,
@@ -440,7 +462,6 @@ def eval( # noqa C901
log_dir=log_dir,
log_dir_allow_dirty=log_dir_allow_dirty,
display=display,
bundle_dir=bundle_dir,
model_args=model_args,
max_tokens=max_tokens,
system_message=system_message,
@@ -463,10 +484,13 @@ def eval( # noqa C901
parallel_tool_calls=parallel_tool_calls,
max_tool_output=max_tool_output,
internal_tools=internal_tools,
overwrite=True,
bundle_dir=bundle_dir,
bundle_overwrite=bundle_overwrite,
)

if not success:
print("Error evaluating models")
print(f"run the same command with --log-dir {log_dir} to retry !")
return

results_per_model_per_task = {}
@@ -482,12 +506,17 @@ def eval( # noqa C901
table_md = results_to_markdown_table(results_per_model_per_task_agg)

if repo_id is not None:
push_to_hub(bundle_dir, repo_id, public=public)
if bundle_dir is not None:
push_to_hub(bundle_dir, repo_id, public=public)

print()
print(table_md)
print(f"results saved to {log_dir}")
print(f'run "inspect view --log-dir {log_dir}" to view the results')

if log_dir is not None:
print(f'run "inspect view --log-dir {log_dir}" to view the results')
else:
print("run 'inspect view' to view the results")


if __name__ == "__main__":
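
For reference, a minimal usage sketch of the new options. This is hypothetical and not taken from the diff: the model and task names are placeholders, the subcommand is assumed to be exposed as "lighteval eval", and the flag spellings assume Typer's default parameter-to-flag conversion (model_base_url -> --model-base-url, bundle_overwrite -> --bundle-overwrite/--no-bundle-overwrite).

    # evaluate against a self-hosted OpenAI-compatible endpoint and bundle the logs
    lighteval eval <model> <task> \
        --model-base-url http://localhost:8000/v1 \
        --bundle-dir ./bundle --no-bundle-overwrite

    # log_dir now defaults to a timestamped lighteval-logs-<YYYYmmddHHMMSS> directory;
    # if a run fails, re-run with the printed --log-dir to retry
    lighteval eval <model> <task> --log-dir lighteval-logs-20250101120000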