diff --git a/skorch/llm/classifier.py b/skorch/llm/classifier.py
index 23ea8ad98..560fde3e0 100644
--- a/skorch/llm/classifier.py
+++ b/skorch/llm/classifier.py
@@ -169,9 +169,9 @@ def _extend_inputs(inputs, extra):
 
 class _LogitsRecorder(LogitsProcessor):
     """Helper class to record logits and force the given label token ids"""
-    def __init__(self, label_ids, tokenizer):
+    def __init__(self, token_ids, tokenizer):
         self.recorded_scores = []
-        self.label_ids = label_ids
+        self.token_ids = token_ids
         self.tokenizer = tokenizer
 
     def __call__(self, input_ids, scores):
@@ -180,11 +180,87 @@ def __call__(self, input_ids, scores):
         # therefore there is no device mismatch and we save a bit of GPU memory
         self.recorded_scores.append(scores[0].clone().cpu())
         mask = torch.ones(scores.size(), dtype=torch.bool)
-        mask[0, self.label_ids[idx]] = False
+        mask[0, self.token_ids[idx]] = False
         scores[mask] = -float('inf')
         return scores
 
 
+class _CFGuidance(LogitsProcessor):
+    """Helper class to implement Classifier Free Guidance [1],
+    which guides the model sampling in a direction that takes
+    the prompt more into account than the generated output
+    without the prompt.
+
+    Mathematically, this is implemented by the following
+    transformation of the log-probabilities:
+
+    .. math::
+
+        \text{log} \hat{\textbf{P}}_\theta(w|c) \propto
+        \text{log} \textbf{P}_\theta(w_i|w_{j < i})
+        + \gamma (
+        \text{log} \textbf{P}_\theta(w_i|w_{j < i}, c)
+        - \text{log} \textbf{P}_\theta(w_i|w_{j < i})
+        )
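
For reference, the transformation in the docstring above reduces to a one-line tensor operation. The sketch below is illustrative only and not part of the diff: `cfg_logits` and `gamma` are hypothetical names, and the actual `_CFGuidance.__call__` (the hunk is truncated above) would additionally need to obtain the unconditional scores, e.g. from a second forward pass without the prompt.

```python
import torch

def cfg_logits(logits_uncond, logits_cond, gamma):
    # Classifier Free Guidance: start from the unconditional logits and
    # move gamma times the difference toward the conditional (prompted)
    # ones. gamma == 1 recovers the plain conditional logits; gamma > 1
    # over-weights the prompt, gamma < 1 under-weights it.
    return logits_uncond + gamma * (logits_cond - logits_uncond)

# toy example with a 3-token vocabulary
logits_cond = torch.tensor([2.0, 0.5, -1.0])    # scores with the prompt
logits_uncond = torch.tensor([1.0, 1.0, -0.5])  # scores without the prompt
print(cfg_logits(logits_uncond, logits_cond, gamma=1.5))
# tensor([ 2.5000,  0.2500, -1.2500])
```

Note that with `gamma=1.5` the gap between the prompted and unprompted scores is amplified, which is exactly the "take the prompt more into account" behavior the docstring describes.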