From e44188a0883a9783d3a025537689b771e3063399 Mon Sep 17 00:00:00 2001 From: David Joy <10147749+dmjoy@users.noreply.github.com> Date: Fri, 19 Jul 2024 13:06:59 -0400 Subject: [PATCH] Add option to outlines adm to filter votes to positive responses only --- CHANGELOG.md | 1 + align_system/algorithms/outlines_adm.py | 21 +++++++++++++++++++-- align_system/utils/voting.py | 11 +++++++++++ 3 files changed, 31 insertions(+), 2 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 4b2e0afd..54c465e5 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -28,6 +28,7 @@ This project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.htm * Added dedicated function to utils for calculating votes (same voting scheme as the single KDMA ADM) * Added top level config options to force determinism and fix seeds; along with an example experiment to demonstrate * Added sampler parameter to outlines ADMs (example usage in `align_system/configs/experiment/examples/outlines_sampler.yaml`) +* Added option (on by default) to outlines ADM to filter votes to positive options only, can disable on the command line with `+adm.inference_kwargs.filter_votes_to_positives=False` ### Deprecated * The algorithm `align_system/algorithms/chat_kdma_predicting_adm.py` has been replaced by `align_system/algorithms/outlines_regression_adm.py` diff --git a/align_system/algorithms/outlines_adm.py b/align_system/algorithms/outlines_adm.py index 9568756f..8c86ba29 100644 --- a/align_system/algorithms/outlines_adm.py +++ b/align_system/algorithms/outlines_adm.py @@ -15,7 +15,10 @@ from align_system.utils import logging from align_system.utils import get_swagger_class_enum_values -from align_system.utils.voting import calculate_votes +from align_system.utils.voting import ( + calculate_votes, + filter_votes_to_responses, +) from align_system.algorithms.abstracts import ActionBasedADM from align_system.prompt_engineering.outlines_prompts import ( baseline_system_prompt, @@ -221,8 +224,22 @@ def top_level_choose_action(self, extra={"markup": True}) log.explain(votes, extra={"highlighter": JSON_HIGHLIGHTER}) + if kwargs.get('filter_votes_to_positives', True): + filtered_votes = filter_votes_to_responses( + votes, positive_responses_choices) + + if filtered_votes != votes: + log.explain("Filtering votes down to choices where we " + "have a positive reponse") + log.explain(filtered_votes, + extra={"highlighter": JSON_HIGHLIGHTER}) + + final_votes = filtered_votes + else: + final_votes = votes + # Take top choice by score (votes is a dictionary of choice: score) - top_choice, top_choice_score = max(votes.items(), key=lambda x: x[1]) + top_choice, top_choice_score = max(final_votes.items(), key=lambda x: x[1]) # Just taking first justification from the positive responses # where the top choice was selected. A better approach might # be to somehow summarized all justifications with the diff --git a/align_system/utils/voting.py b/align_system/utils/voting.py index b459346a..fbd6fa23 100644 --- a/align_system/utils/voting.py +++ b/align_system/utils/voting.py @@ -33,3 +33,14 @@ def calculate_votes(possible_choices, for choice, score in tmp_normalized_votes.items()} return normalized_votes + + +def filter_votes_to_responses(votes, responses): + filtered_votes = {choice: score for choice, score in votes.items() + if choice in responses} + + if len(filtered_votes) == 0: + raise RuntimeError( + "No votes left after filtering, was `reponses` empty?") + + return filtered_votes