Skip to content

Commit

Permalink
Add option to outlines adm to filter votes to positive responses only
Browse files Browse the repository at this point in the history
  • Loading branch information
dmjoy committed Jul 19, 2024
1 parent c99ca28 commit e44188a
Show file tree
Hide file tree
Showing 3 changed files with 31 additions and 2 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ This project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.htm
* Added dedicated function to utils for calculating votes (same voting scheme as the single KDMA ADM)
* Added top level config options to force determinism and fix seeds; along with an example experiment to demonstrate
* Added sampler parameter to outlines ADMs (example usage in `align_system/configs/experiment/examples/outlines_sampler.yaml`)
* Added option (on by default) to outlines ADM to filter votes to positive options only; it can be disabled on the command line with `+adm.inference_kwargs.filter_votes_to_positives=False`

### Deprecated
* The algorithm `align_system/algorithms/chat_kdma_predicting_adm.py` has been replaced by `align_system/algorithms/outlines_regression_adm.py`
Expand Down
21 changes: 19 additions & 2 deletions align_system/algorithms/outlines_adm.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,10 @@

from align_system.utils import logging
from align_system.utils import get_swagger_class_enum_values
from align_system.utils.voting import calculate_votes
from align_system.utils.voting import (
calculate_votes,
filter_votes_to_responses,
)
from align_system.algorithms.abstracts import ActionBasedADM
from align_system.prompt_engineering.outlines_prompts import (
baseline_system_prompt,
Expand Down Expand Up @@ -221,8 +224,22 @@ def top_level_choose_action(self,
extra={"markup": True})
log.explain(votes, extra={"highlighter": JSON_HIGHLIGHTER})

if kwargs.get('filter_votes_to_positives', True):
filtered_votes = filter_votes_to_responses(
votes, positive_responses_choices)

if filtered_votes != votes:
log.explain("Filtering votes down to choices where we "
"have a positive reponse")
log.explain(filtered_votes,
extra={"highlighter": JSON_HIGHLIGHTER})

final_votes = filtered_votes
else:
final_votes = votes

# Take top choice by score (votes is a dictionary of choice: score)
top_choice, top_choice_score = max(votes.items(), key=lambda x: x[1])
top_choice, top_choice_score = max(final_votes.items(), key=lambda x: x[1])
# Just taking first justification from the positive responses
# where the top choice was selected. A better approach might
# be to somehow summarized all justifications with the
Expand Down
11 changes: 11 additions & 0 deletions align_system/utils/voting.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,3 +33,14 @@ def calculate_votes(possible_choices,
for choice, score in tmp_normalized_votes.items()}

return normalized_votes


def filter_votes_to_responses(votes, responses):
    """Return only the votes whose choice appears in ``responses``.

    Args:
        votes: Dictionary mapping each choice to its score.
        responses: Container of choices to keep; membership is tested
            with ``in``.

    Returns:
        A new dictionary holding the ``choice: score`` pairs from
        ``votes`` whose choice is in ``responses``. ``votes`` itself is
        not modified.

    Raises:
        RuntimeError: If no votes remain after filtering (for example
            when ``responses`` is empty or shares no choice with
            ``votes``).
    """
    filtered_votes = {choice: score for choice, score in votes.items()
                      if choice in responses}

    # An empty result would leave the caller with nothing to pick a top
    # choice from, so fail loudly here instead.
    if not filtered_votes:
        raise RuntimeError(
            "No votes left after filtering, was `responses` empty?")

    return filtered_votes

0 comments on commit e44188a

Please sign in to comment.