From 0618a62bb5f8b3f3a6ead73108af3208772d6721 Mon Sep 17 00:00:00 2001
From: David Joy <10147749+dmjoy@users.noreply.github.com>
Date: Wed, 24 Apr 2024 12:24:48 -0400
Subject: [PATCH] Add do_sample arg for single kdma adm; update changelog and
 version

---
 CHANGELOG.md                                  | 13 ++++++++++++-
 .../algorithms/llama_2_single_kdma_adm.py     | 15 +++++++++++----
 pyproject.toml                                |  2 +-
 3 files changed, 24 insertions(+), 6 deletions(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index d96cbfda..6427892b 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -3,13 +3,24 @@
 This changelog follows the specifications detailed in: [Keep a Changelog](https://keepachangelog.com/en/1.0.0/).
 This project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html),
 although we have not yet reached a `1.0.0` release.
 
-## Unreleased
+## 0.3.3
 
 ### Changed
 
 * Modified the prompt for PulseTaggingADM. Also removed duplicated inference call within `identify_tag_color` method. Additionally, removed duplicated RED tag in-context example and replaced with missing BLACK tag example.
+* Changed default maximization prompt for Kaleido
+
+### Fixed
+
+* Applied attention fixes for Kaleido provided by UWash
+* Fixed an "other choice" ordering issue in Kaleido ADM
+
+### Added
+
+* Added an additional parsing guard in Llama2SingleKDMAADM
+* Added do_sample as an init kwarg for Llama2SingleKDMAADM (set to False for temperature 0)
 
 ## 0.3.2
 
diff --git a/align_system/algorithms/llama_2_single_kdma_adm.py b/align_system/algorithms/llama_2_single_kdma_adm.py
index 3e1412d2..22082fb4 100644
--- a/align_system/algorithms/llama_2_single_kdma_adm.py
+++ b/align_system/algorithms/llama_2_single_kdma_adm.py
@@ -109,10 +109,11 @@ def to_probabilities(logits):
 
 
 class Llama2SingleKDMAADM(AlignedDecisionMaker):
-    def __init__(self, device='cuda', hf_model='meta-llama/Llama-2-7b-chat-hf', precision='full', temperature=0.7, **kwargs):
+    def __init__(self, device='cuda', hf_model='meta-llama/Llama-2-7b-chat-hf', precision='full', temperature=0.7, do_sample=True, **kwargs):
         self.device = device
         self.hf_model = hf_model
         self.temperature = temperature
+        self.do_sample = do_sample
         self.chat_template = kwargs.get('chat_template', None)
 
         assert precision in ['full', 'half'], "precision must be either 'full' or 'half'."
@@ -297,7 +298,13 @@ def respond_to_dialog(self, dialog, prefix=None):
         if self.device != 'auto':
             prompt_tokens = prompt_tokens.to(self.device)
 
-        outputs = self.model.generate(prompt_tokens, return_dict_in_generate=True, output_scores=True, max_new_tokens=512, temperature=self.temperature, do_sample=True)
+        outputs = self.model.generate(
+            prompt_tokens,
+            return_dict_in_generate=True,
+            output_scores=True,
+            max_new_tokens=512,
+            temperature=self.temperature,
+            do_sample=self.do_sample)
 
         # Print the generated model output
         generated_output = self.tokenizer.decode(outputs.sequences[0][prompt_length:])
@@ -347,8 +354,8 @@ def respond_to_dialogs_batched(self, dialogs, prefixes=None):
             return_dict_in_generate=True,
             output_scores=True,
             max_new_tokens=512,
-            temperature=self.temperature
-        )
+            temperature=self.temperature,
+            do_sample=self.do_sample)
 
         # Split the sequences based on prompt lengths
         split_outputs = torch.split(outputs.sequences, 1, dim=0)

diff --git a/pyproject.toml b/pyproject.toml
index 6763e198..a1f66a44 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "align-system"
-version = "0.3.2"
+version = "0.3.3"
 description = ""
 authors = ["David Joy <10147749+dmjoy@users.noreply.github.com>"]
 readme = "README.md"
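
Usage note (editorial sketch, not part of the commit): the snippet below shows how the new do_sample init kwarg might be used, assuming the align-system package and the Llama-2 model weights are available locally. The module path, class name, and keyword arguments come from the diff above; the specific values (device, precision, temperature) are arbitrary examples.

    from align_system.algorithms.llama_2_single_kdma_adm import Llama2SingleKDMAADM

    # Sketch only: instantiation requires a CUDA device and the Hugging Face
    # model weights to be available in your environment.
    adm = Llama2SingleKDMAADM(
        device='cuda',
        hf_model='meta-llama/Llama-2-7b-chat-hf',
        precision='half',
        temperature=0.7,
        do_sample=False,  # greedy decoding; temperature no longer affects output
    )

With do_sample=False, Hugging Face generate() performs greedy (deterministic) decoding, which is the "temperature 0" behavior noted in the changelog; do_sample=True (the default) preserves the previous sampling behavior, so existing configurations are unaffected.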