Merge pull request #47 from ITM-Kitware/dev/metrics-eval-quickfixes
Dev/metrics eval quickfixes
dmjoy authored Apr 24, 2024
2 parents 42a40e6 + 0618a62 commit 1ca9d4f
Showing 7 changed files with 30 additions and 12 deletions.
13 changes: 12 additions & 1 deletion CHANGELOG.md
@@ -3,13 +3,24 @@
This changelog follows the specifications detailed in: [Keep a Changelog](https://keepachangelog.com/en/1.0.0/).
This project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html), although we have not yet reached a `1.0.0` release.

-## Unreleased
+## 0.3.3

### Changed

* Modified the prompt for PulseTaggingADM, removed a duplicated inference call within the `identify_tag_color` method, and replaced a duplicated RED tag in-context example with the missing BLACK tag example.
* Changed default maximization prompt for Kaleido

### Fixed

* Applied attention fixes for Kaleido provided by UWash
* Fixed an "other choice" ordering issue in Kaleido ADM

### Added

* Added an additional parsing guard in Llama2SingleKDMAADM
* Added `do_sample` as an init kwarg for Llama2SingleKDMAADM (set to `False` for temperature 0)

## 0.3.2

@@ -21,7 +21,7 @@ adm:
  answer_attempts: 5

alignment_target_override:
-  id: maximization_low
+  id: maximization_high
  kdma_values:
    - kdma: maximization
-      value: 0.1
+      value: 0.9
2 changes: 1 addition & 1 deletion align_system/algorithms/kaleido_adm.py
@@ -102,7 +102,7 @@ def estimate_kdma_values(self,

        rows = []
        for choice in choices:
-            other_choices_str = ', '.join(['"{}"'.format(c) for c in (set(choices) - {choice})])
+            other_choices_str = ', '.join(['"{}"'.format(c) for c in choices if c != choice])
            choice_prompt = format_template(
                prompt_template,
                allow_extraneous=True,
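Note: the fix above swaps a set difference for an order-preserving filter. A minimal standalone sketch (illustrative choice strings, not repo data) of why this matters:

```python
# Set iteration order is arbitrary in Python, so building the "other
# choices" string from a set difference can present the choices in a
# different order across runs or processes; filtering the original list
# preserves the intended order.
choices = ["Treat Patient A", "Treat Patient B", "Tag and move on"]
choice = "Treat Patient B"

set_based = ', '.join('"{}"'.format(c) for c in (set(choices) - {choice}))
list_based = ', '.join('"{}"'.format(c) for c in choices if c != choice)

print(set_based)   # order not guaranteed, can vary between runs
print(list_based)  # always: "Treat Patient A", "Tag and move on"
```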
4 changes: 2 additions & 2 deletions align_system/algorithms/lib/kaleido.py
@@ -488,10 +488,10 @@ def get_probs(self, inputs, batch_size=None):
            encoded_batch = self.tokenizer.batch_encode_plus(
                inputs[inds].tolist(),
                return_tensors='pt', padding=True, truncation=False, max_length=128,
-            ).to(self.device).input_ids
+            ).to(self.device)
            # batch_inputs = encoded_batch[i*batch_size:(i+1)*batch_size]
            # Run through model, get last logits
-            logits_batch = self.model(input_ids=encoded_batch, labels=self.get_dummy(encoded_batch)).logits[:, -1, :].detach().cpu()
+            logits_batch = self.model(input_ids=encoded_batch.input_ids, attention_mask=encoded_batch.attention_mask, labels=self.get_dummy(encoded_batch.input_ids)).logits[:, -1, :].detach().cpu()
            logits.append(logits_batch)

            # concatenate logits
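Note: the change above forwards the tokenizer's attention mask to the model instead of passing raw input IDs alone. A small sketch (using `t5-small` as a stand-in checkpoint, not the repo's Kaleido weights) of what that mask contains when `padding=True`:

```python
# With padding=True, shorter sequences are padded to the longest row, and
# attention_mask marks real tokens (1) versus padding (0). Forwarding the
# mask keeps the model from attending to padding tokens, which would
# otherwise skew the last-position logits.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("t5-small")  # stand-in checkpoint
encoded = tokenizer.batch_encode_plus(
    ["short input", "a much longer input that forces the first row to be padded"],
    return_tensors="pt", padding=True,
)
print(encoded.input_ids.shape)    # both rows padded to the same length
print(encoded.attention_mask[0])  # trailing zeros mark the padding
```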
@@ -5,5 +5,5 @@ MoralDesert:

maximization:
  name: Maximization
-  description: Prioritize gathering the most information possible before acting on that information, but not at the detriment of casualties
+  description: Prioritize gathering information to take the most informed, life-saving actions
  vrd: Value
15 changes: 11 additions & 4 deletions align_system/algorithms/llama_2_single_kdma_adm.py
@@ -109,10 +109,11 @@ def to_probabilities(logits):

class Llama2SingleKDMAADM(AlignedDecisionMaker):

-    def __init__(self, device='cuda', hf_model='meta-llama/Llama-2-7b-chat-hf', precision='full', temperature=0.7, **kwargs):
+    def __init__(self, device='cuda', hf_model='meta-llama/Llama-2-7b-chat-hf', precision='full', temperature=0.7, do_sample=True, **kwargs):
        self.device = device
        self.hf_model = hf_model
        self.temperature = temperature
+        self.do_sample = do_sample
        self.chat_template = kwargs.get('chat_template', None)

        assert precision in ['full', 'half'], "precision must be either 'full' or 'half'."
@@ -297,7 +298,13 @@ def respond_to_dialog(self, dialog, prefix=None):
        if self.device != 'auto':
            prompt_tokens = prompt_tokens.to(self.device)

-        outputs = self.model.generate(prompt_tokens, return_dict_in_generate=True, output_scores=True, max_new_tokens=512, temperature=self.temperature, do_sample=True)
+        outputs = self.model.generate(
+            prompt_tokens,
+            return_dict_in_generate=True,
+            output_scores=True,
+            max_new_tokens=512,
+            temperature=self.temperature,
+            do_sample=self.do_sample)

        # Print the generated model output
        generated_output = self.tokenizer.decode(outputs.sequences[0][prompt_length:])
@@ -347,8 +354,8 @@ def respond_to_dialogs_batched(self, dialogs, prefixes=None):
            return_dict_in_generate=True,
            output_scores=True,
            max_new_tokens=512,
-            temperature=self.temperature
-        )
+            temperature=self.temperature,
+            do_sample=self.do_sample)

        # Split the sequences based on prompt lengths
        split_outputs = torch.split(outputs.sequences, 1, dim=0)
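Note: with `do_sample=False`, Hugging Face `generate()` falls back to greedy decoding, so the output is deterministic and the `temperature` value has no effect, which is what the changelog's "set to `False` for temperature 0" refers to. A hypothetical usage sketch (the import path is inferred from the file path above):

```python
from align_system.algorithms.llama_2_single_kdma_adm import Llama2SingleKDMAADM

# do_sample=False requests greedy decoding from generate(): deterministic
# output, with the temperature setting effectively ignored.
adm = Llama2SingleKDMAADM(
    device='cuda',
    hf_model='meta-llama/Llama-2-7b-chat-hf',
    precision='half',
    do_sample=False,
)
```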
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -1,6 +1,6 @@
[tool.poetry]
name = "align-system"
version = "0.3.2"
version = "0.3.3"
description = ""
authors = ["David Joy <10147749+dmjoy@users.noreply.github.com>"]
readme = "README.md"
