diff --git a/ldp/alg/__init__.py b/ldp/alg/__init__.py
index c26a0f1..93241f4 100644
--- a/ldp/alg/__init__.py
+++ b/ldp/alg/__init__.py
@@ -1,4 +1,4 @@
-from .algorithms import to_network
+from .algorithms import evaluate_consensus, to_network
 from .beam_search import Beam, BeamSearchRollout
 from .callbacks import (
     Callback,
@@ -45,5 +45,6 @@
     "TrajectoryMetricsCallback",
     "TreeSearchRollout",
     "WandBLoggingCallback",
+    "evaluate_consensus",
     "to_network",
 ]
diff --git a/ldp/alg/algorithms.py b/ldp/alg/algorithms.py
index 9b43282..15ce489 100644
--- a/ldp/alg/algorithms.py
+++ b/ldp/alg/algorithms.py
@@ -1,9 +1,12 @@
+import asyncio
+import collections
 import itertools
-from collections.abc import Sequence
-from typing import Any
+import random
+from collections.abc import Awaitable, Callable, Hashable, Sequence
+from typing import Any, Literal, TypeVar, cast
 
 import networkx as nx
-from aviary.core import Message, Tool, ToolRequestMessage, join
+from aviary.core import Message, Tool, ToolRequestMessage, is_coroutine_callable, join
 
 from ldp.graph import OpResult
 from ldp.graph.ops import GradOutType
@@ -120,3 +123,70 @@ def gvizify(x: Any) -> str:
         G.add_edge(op_call_str, arg_str, label=gvizify(grad), style="dotted")
 
     return G
+
+
+TData = TypeVar("TData")
+TGroupKey = TypeVar("TGroupKey", bound=Hashable)
+TAnswer = TypeVar("TAnswer")
+NO_IDEAL_ANSWER_FN: Literal["NO_IDEAL_ANSWER_FN"] = "NO_IDEAL_ANSWER_FN"  # Sentinel
+
+
+async def evaluate_consensus(
+    data: Sequence[TData],
+    grouping_fn: Callable[[TData], TGroupKey],
+    extract_answer_fn: Callable[[TData], TAnswer | Awaitable[TAnswer]],
+    ideal_answer_fn: (
+        Callable[[TData], TAnswer] | Literal["NO_IDEAL_ANSWER_FN"]
+    ) = NO_IDEAL_ANSWER_FN,
+    num_samples: int = 1,
+    seed: int | None = None,
+) -> tuple[dict[TGroupKey, list[tuple[TAnswer, int]]], float]:
+    """
+    Create consensus groups and evaluate the consensus accuracy for each one.
+
+    Args:
+        data: Data to evaluate consensus upon, length is called N.
+        grouping_fn: Function to extract the group key from a datum.
+        extract_answer_fn: Function to extract the actual answer from a datum. It can
+            be async so this can be done using an LLM call.
+        ideal_answer_fn: Optional function to extract the ideal answer from a datum to
+            compute accuracy with, or pass NO_IDEAL_ANSWER_FN to skip this calculation.
+        num_samples: Number of samples to choose from the N total.
+        seed: Optional seed for sampling.
+
+    Returns:
+        Two-tuple of group key mapped to consensus (collections.Counter.most_common
+        list), and the proportion of groups whose consensus matches the ideal answer.
+ """ + groups = collections.defaultdict(list) + for x in data: + groups[grouping_fn(x)].append(x) + + ideal_count = 0 + grouped_consensus: dict[TGroupKey, list[tuple[TAnswer, int]]] = {} + rand = random.Random(seed) if seed is not None else random + for group_key, group in groups.items(): + if len(group) < num_samples: # Too few items, sample with replacement + sampled = [rand.choice(group) for _ in range(num_samples)] + else: # Sample without replacement + sampled = rand.sample(group, num_samples) + + # Get answers for the sampled data + if is_coroutine_callable(extract_answer_fn): + extract_answer_fn = cast( + Callable[[TData], Awaitable[TAnswer]], extract_answer_fn + ) + answers = await asyncio.gather(*(extract_answer_fn(x) for x in sampled)) + else: + extract_answer_fn = cast(Callable[[TData], TAnswer], extract_answer_fn) + answers = [extract_answer_fn(x) for x in sampled] + + # Compute consensus: mode of the sampled answers + grouped_consensus[group_key] = collections.Counter(answers).most_common() + # NOTE: If there are multiple modes, just use the first one + consensus: TAnswer = grouped_consensus[group_key][0][0] + if ideal_answer_fn != NO_IDEAL_ANSWER_FN: + # Assume all items in the group have the same ideal answer + ideal_count += consensus == ideal_answer_fn(group[0]) + + return grouped_consensus, ideal_count / len(groups) if groups else 0.0 diff --git a/tests/test_algorithms.py b/tests/test_algorithms.py index e1ec2c4..e326265 100644 --- a/tests/test_algorithms.py +++ b/tests/test_algorithms.py @@ -1,7 +1,11 @@ +import operator + import pytest from aviary.core import DummyEnv +from aviary.utils import MultipleChoiceQuestion from ldp.agent import SimpleAgent +from ldp.alg import evaluate_consensus from ldp.utils import discounted_returns @@ -32,3 +36,74 @@ async def test_rollout_and_discounting(dummy_env: DummyEnv) -> None: print(terms) d_returns = discounted_returns(rewards, terms, 0.5) print(d_returns) + + +@pytest.mark.asyncio +async def test_evaluate_consensus() -> None: + # We have two questions, so let's group based on question + question_1 = MultipleChoiceQuestion( + question="What is the meaning of life?", + options=["-84", "11", "cheesecake"], + ideal_answer="42", + ) + question_2 = MultipleChoiceQuestion( + question="What is a healthy fruit?", + options=["brownie", "chocolate bar", "french fry"], + ideal_answer="apple", + ) + question_3 = MultipleChoiceQuestion( + question="What is the highest number?", + options=["1", "2", "4"], + ideal_answer="8", + ) + data_with_several_groups: list[tuple[MultipleChoiceQuestion, str]] = [ + # Correct consensus + (question_1, "-84"), + (question_1, "11"), + (question_1, "11"), + (question_1, "cheesecake"), + (question_1, "42"), + (question_1, "42"), + (question_1, "42"), + (question_1, "42"), + (question_1, "42"), + (question_1, "42"), + # Correct consensus + (question_2, "brownie"), + (question_2, "brownie"), + (question_2, "apple"), + (question_2, "apple"), + (question_2, "apple"), + (question_2, "apple"), + (question_2, "apple"), + (question_2, "apple"), + # Incorrect consensus + (question_3, "1"), + (question_3, "2"), + (question_3, "1"), + (question_3, "2"), + ] + # NOTE: this consensus is sensitive to seed + expected_consensus = { + question_1.question: [("42", 3), ("11", 1), ("-84", 1)], + question_2.question: [("apple", 4), ("brownie", 1)], + question_3.question: [("1", 3), ("2", 2)], + } + + # Check accuracy is 0% without an ideal answer + assert await evaluate_consensus( + data_with_several_groups, + grouping_fn=lambda x: 
x[0].question, + extract_answer_fn=operator.itemgetter(1), + num_samples=5, + seed=42, + ) == (expected_consensus, 0.0) + # Check accuracy is present when we can get an ideal answer + assert await evaluate_consensus( + data_with_several_groups, + grouping_fn=lambda x: x[0].question, + extract_answer_fn=operator.itemgetter(1), + ideal_answer_fn=lambda x: x[0].ideal_answer, + num_samples=5, + seed=42, + ) == (expected_consensus, 2 / 3)
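
Usage sketch for the new evaluate_consensus helper, mirroring the test above: it groups (question, answer) tuples by question text and reports per-group consensus. The data literal and the num_samples/seed values below are illustrative assumptions only, and extract_answer_fn may instead be an async callable, since the function checks is_coroutine_callable.

import asyncio
import operator

from ldp.alg import evaluate_consensus


async def main() -> None:
    # Hypothetical (question, sampled answer) pairs -- not taken from this PR
    data = [
        ("What is 2 + 2?", "4"),
        ("What is 2 + 2?", "4"),
        ("What is 2 + 2?", "5"),
        ("Capital of France?", "Paris"),
        ("Capital of France?", "Lyon"),
    ]
    consensus, accuracy = await evaluate_consensus(
        data,
        grouping_fn=operator.itemgetter(0),  # group by question text
        extract_answer_fn=operator.itemgetter(1),  # answer is already extracted
        num_samples=2,
        seed=0,
    )
    # consensus maps each question to Counter.most_common output, e.g.
    # {"What is 2 + 2?": [("4", 2)], ...} depending on sampling; accuracy is
    # 0.0 here because no ideal_answer_fn was supplied
    print(consensus, accuracy)


asyncio.run(main())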