Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor benchmark backend code #540

Merged
merged 2 commits into from
Dec 5, 2022
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
69 changes: 31 additions & 38 deletions backend/src/impl/db_utils/benchmark_db_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,14 @@

import pandas as pd
from explainaboard.utils.cache_api import get_cache_dir, open_cached_file
from explainaboard.utils.typing_utils import unwrap
from pandas import Series

from explainaboard_web.impl.auth import get_user
from explainaboard_web.impl.constants import ALL_LANG, LING_WEIGHT, POP_WEIGHT
from explainaboard_web.impl.db_utils.dataset_db_utils import DatasetDBUtils
from explainaboard_web.impl.db_utils.db_utils import DBUtils
from explainaboard_web.impl.db_utils.system_db_utils import SystemDBUtils
from explainaboard_web.impl.internal_models.system_model import SystemModel
from explainaboard_web.impl.utils import abort_with_error_message
from explainaboard_web.models import (
BenchmarkConfig,
Expand Down Expand Up @@ -126,8 +126,7 @@ def delete_benchmark_by_id(benchmark_id: str):
raise RuntimeError(f"failed to delete benchmark {benchmark_id}")

@staticmethod
def load_sys_infos(config: BenchmarkConfig) -> list[dict]:
sys_infos: list[dict] = []
def load_sys_infos(config: BenchmarkConfig) -> list[SystemModel]:
if config.system_query is not None:
systems_return = SystemDBUtils.find_systems(
dataset_name=config.system_query.get("dataset_name"),
Expand All @@ -153,28 +152,20 @@ def load_sys_infos(config: BenchmarkConfig) -> list[dict]:
else:
raise ValueError("system_query or datasets must be set by each benchmark")

systems = systems_return.systems
for system in systems:
temp = system.get_system_info().to_dict()
# Don't include systems with no dataset
if temp["dataset_name"] is not None:
temp["creator"] = system.creator.split("@")[0]
temp["created_at"] = system.created_at
sys_infos.append(temp)
return sys_infos
return systems_return.systems

@staticmethod
def generate_dataframe_from_sys_ids(config: BenchmarkConfig, system_ids: list[str]):
return NotImplementedError

@staticmethod
def generate_dataframe_from_sys_infos(
benchmark_config: BenchmarkConfig, systems: list[dict]
benchmark_config: BenchmarkConfig, systems: list[SystemModel]
):
"""
Generate a leaderboard from a list of system_output_info:SysOutputInfo
:param config: A benchmark config
:param systems: A list of system info dictionaries
:param systems: A list of SystemModel
:return: leaderboard:Leaderboard
"""
# --- Get df entries
Expand All @@ -191,7 +182,7 @@ def generate_dataframe_from_sys_infos(
else:
dataset_tuples = list(
{
(x["dataset_name"], x["sub_dataset_name"], x["dataset_split"])
(x.dataset.dataset_name, x.dataset.sub_dataset, x.dataset.split)
for x in systems
}
)
Expand All @@ -216,17 +207,20 @@ def generate_dataframe_from_sys_infos(
dataset_metadatas.append(dataset_return.datasets[0])

# --- Rearrange so we have each system's result over each dataset
system_dataset_results: dict[str, list[dict | None]] = {}
system_dataset_results: dict[str, list[SystemModel | None]] = {}
for sys in systems:
sys_name = sys["system_name"]
sys_name = sys.system_name
if sys_name not in system_dataset_results:
system_dataset_results[sys_name] = [
{"creator": sys["creator"]} for _ in dataset_configs
]
system_dataset_results[sys_name] = [None for _ in dataset_configs]
dataset_id = dataset_to_id[
(sys["dataset_name"], sys["sub_dataset_name"], sys["dataset_split"])
(sys.dataset.dataset_name, sys.dataset.sub_dataset, sys.dataset.split)
]
system_dataset_results[sys_name][dataset_id] = sys

system_to_creator: dict[str, str] = {
sys.system_name: sys.creator for sys in systems
}

# --- Set up the columns of the dataframe
# Default dataset information columns
df_input: dict[str, list] = {
Expand Down Expand Up @@ -255,11 +249,10 @@ def generate_dataframe_from_sys_infos(
df_input["score"] = []

# --- Create the actual data
for sys_name, sys_infos in system_dataset_results.items():
for dataset_config, dataset_metadata, sys_info_tmp in zip(
dataset_configs, dataset_metadatas, sys_infos
for sys_name, systems in system_dataset_results.items():
for dataset_config, dataset_metadata, sys in zip(
dataset_configs, dataset_metadatas, systems
):
sys_info = unwrap(sys_info_tmp)
column_dict = dict(dataset_config)
column_dict["system_name"] = sys_name
dataset_metrics: list[BenchmarkMetric] = dataset_config.get(
Expand All @@ -279,25 +272,25 @@ def generate_dataframe_from_sys_infos(
column_dict["metric_weight"] = dataset_metric.get(
"weight", 1.0 / len(dataset_metrics)
)
if len(sys_info) != 1:

column_dict["creator"] = sys_info["creator"]
matching_results = [
x
for x in sys_info["results"]["overall"][0]
if x.metric_name == dataset_metric["name"]
]
if sys is not None:

column_dict["creator"] = sys.creator
matching_results = []
for level, m in sys.results.items():
for k, v in m.items():
if k == dataset_metric["name"]:
matching_results.append(v)
if len(matching_results) != 1:
performance = None
else:
performance = matching_results[0]
column_dict["score"] = (
performance.value
performance
if performance
else (dataset_metric.get("default") or 0.0)
)
else:
column_dict["creator"] = sys_info["creator"]
column_dict["creator"] = system_to_creator[sys_name]
column_dict["score"] = dataset_metric.get("default") or 0.0
for df_key, df_arr in df_input.items():
if df_key in column_dict:
Expand Down Expand Up @@ -479,13 +472,13 @@ def generate_plots(benchmark_id):
json_dict = {k.name: [] for k in config.views}
json_dict["Original"] = []
json_dict["times"] = []
unique_dates = sorted(list({x["created_at"].date() for x in sys_infos}))
unique_dates = sorted(list({x.created_at.date() for x in sys_infos}))

for date in unique_dates:
systems = [sys for sys in sys_infos if sys["created_at"].date() <= date]
systems = [sys for sys in sys_infos if sys.created_at.date() <= date]
orig_df = BenchmarkDBUtils.generate_dataframe_from_sys_infos(
config, systems
)

system_dfs = BenchmarkDBUtils.generate_view_dataframes(
config, orig_df, by_creator=False
)
Expand Down