Add do_sample arg for single kdma adm; update changelog and version
dmjoy committed Apr 24, 2024
1 parent 3b3ea29 commit 0618a62
Showing 3 changed files with 24 additions and 6 deletions.
CHANGELOG.md — 13 changes: 12 additions & 1 deletion

@@ -3,13 +3,24 @@
This changelog follows the specifications detailed in: [Keep a Changelog](https://keepachangelog.com/en/1.0.0/).
This project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html), although we have not yet reached a `1.0.0` release.

-## Unreleased
+## 0.3.3

### Changed

* Modified the prompt for PulseTaggingADM. Also removed duplicated inference call within `identify_tag_color`
method. Additionally, removed duplicated RED tag in-context example and replaced with missing BLACK tag
example.
* Changed default maximization prompt for Kaleido

### Fixed

* Applied attention fixes for Kaleido provided by UWash
* Fixed an "other choice" ordering issue in Kaleido ADM

### Added

* Added an additional parsing guard in `Llama2SingleKDMAADM`
* Added `do_sample` as an init kwarg for `Llama2SingleKDMAADM` (set to `False` for temperature 0); see the usage sketch below
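
A minimal usage sketch of the new kwarg, assuming only the constructor signature shown in this commit's diff; the module path comes from the changed file, and any model loading the constructor performs beyond the visible lines is omitted here:

```python
# Minimal sketch, assuming the constructor signature shown in the diff below.
# `do_sample=False` selects greedy decoding, which is what the changelog
# entry means by "temperature 0" (the temperature value is then unused).
from align_system.algorithms.llama_2_single_kdma_adm import Llama2SingleKDMAADM

adm = Llama2SingleKDMAADM(
    device='cuda',
    hf_model='meta-llama/Llama-2-7b-chat-hf',
    precision='half',
    do_sample=False,
)
```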

## 0.3.2

align_system/algorithms/llama_2_single_kdma_adm.py — 15 changes: 11 additions & 4 deletions

@@ -109,10 +109,11 @@ def to_probabilities(logits):

 class Llama2SingleKDMAADM(AlignedDecisionMaker):
 
-    def __init__(self, device='cuda', hf_model='meta-llama/Llama-2-7b-chat-hf', precision='full', temperature=0.7, **kwargs):
+    def __init__(self, device='cuda', hf_model='meta-llama/Llama-2-7b-chat-hf', precision='full', temperature=0.7, do_sample=True, **kwargs):
         self.device = device
         self.hf_model = hf_model
         self.temperature = temperature
+        self.do_sample = do_sample
         self.chat_template = kwargs.get('chat_template', None)
 
         assert precision in ['full', 'half'], "precision must be either 'full' or 'half'."
@@ -297,7 +298,13 @@ def respond_to_dialog(self, dialog, prefix=None):
         if self.device != 'auto':
             prompt_tokens = prompt_tokens.to(self.device)
 
-        outputs = self.model.generate(prompt_tokens, return_dict_in_generate=True, output_scores=True, max_new_tokens=512, temperature=self.temperature, do_sample=True)
+        outputs = self.model.generate(
+            prompt_tokens,
+            return_dict_in_generate=True,
+            output_scores=True,
+            max_new_tokens=512,
+            temperature=self.temperature,
+            do_sample=self.do_sample)
 
         # Print the generated model output
         generated_output = self.tokenizer.decode(outputs.sequences[0][prompt_length:])
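
For context: in Hugging Face `transformers`, `do_sample=False` switches `generate` to greedy decoding, and the `temperature` argument no longer affects the output — the deterministic "temperature 0" behavior the changelog refers to. A standalone sketch of the two modes (the model and prompt here are illustrative stand-ins, not from this repo):

```python
# Illustrative sketch of greedy vs. sampled decoding with Hugging Face
# transformers; gpt2 and the prompt are stand-ins, not from this repo.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained('gpt2')
model = AutoModelForCausalLM.from_pretrained('gpt2')
inputs = tokenizer('The first patient is', return_tensors='pt')

# do_sample=False: greedy argmax at each step; deterministic, temperature unused.
greedy = model.generate(**inputs, do_sample=False, max_new_tokens=20)

# do_sample=True: tokens drawn from the temperature-rescaled distribution.
torch.manual_seed(0)
sampled = model.generate(**inputs, do_sample=True, temperature=0.7, max_new_tokens=20)

print(tokenizer.decode(greedy[0]))
print(tokenizer.decode(sampled[0]))
```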
@@ -347,8 +354,8 @@ def respond_to_dialogs_batched(self, dialogs, prefixes=None):
             return_dict_in_generate=True,
             output_scores=True,
             max_new_tokens=512,
-            temperature=self.temperature
-        )
+            temperature=self.temperature,
+            do_sample=self.do_sample)
 
         # Split the sequences based on prompt lengths
         split_outputs = torch.split(outputs.sequences, 1, dim=0)
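
Both call sites pass `return_dict_in_generate=True` and `output_scores=True`, so `outputs.scores` carries one logits tensor per generated step. The first hunk header above shows a `to_probabilities(logits)` helper in this module; its body is not visible in this diff, but a plausible implementation is a softmax over the vocabulary dimension:

```python
import torch

def to_probabilities(logits):
    # Plausible sketch only: the actual body is not shown in this diff.
    # Softmax over the vocabulary dimension converts per-step logits
    # into per-token probabilities.
    return torch.softmax(logits, dim=-1)
```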
pyproject.toml — 2 changes: 1 addition & 1 deletion
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "align-system"
-version = "0.3.2"
+version = "0.3.3"
 description = ""
 authors = ["David Joy <10147749+dmjoy@users.noreply.github.com>"]
 readme = "README.md"