forked from facebookresearch/chameleon
-
Notifications
You must be signed in to change notification settings - Fork 0
/
token_selector.py
47 lines (37 loc) · 1.49 KB
/
token_selector.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Chameleon License found in the
# LICENSE file in the root directory of this source tree.
import torch
class TokenSelector:
    """Interface for choosing the next token id for each row of a batch.

    Concrete selectors (e.g. argmax or multinomial sampling) override
    ``__call__``; this base implementation is only a placeholder.
    """

    def __call__(
        self, input_ids: torch.LongTensor, probs: torch.FloatTensor
    ) -> torch.LongTensor:
        """Select one token id per batch row.

        Args:
            input_ids: token ids generated so far, shape ``[batch, seq_len]``.
            probs: per-token scores for the next position, shape
                ``[batch, vocab]`` (presumably normalized probabilities --
                confirm against the caller).

        Returns:
            The selected token ids. The subclasses in this module return a
            ``LongTensor`` of shape ``[batch]``. The base placeholder returns
            ``None``.
        """
        # Placeholder body: subclasses must override this method.
        ...
class ArgmaxTokenSelector(TokenSelector):
    """Greedy selection: pick the highest-scoring token id in every row."""

    def __call__(
        self, _: torch.LongTensor, probs: torch.FloatTensor
    ) -> torch.LongTensor:
        # probs is [batch, vocab]; reducing over the vocab axis gives [batch].
        # The previously generated ids are not needed for greedy choice.
        best_ids = torch.argmax(probs, dim=1)
        return best_ids
class MultinomialTokenSelector(TokenSelector):
    """Stochastic selection: draw one token id per row, weighted by ``probs``."""

    def __call__(
        self, _: torch.LongTensor, probs: torch.FloatTensor
    ) -> torch.LongTensor:
        # probs is [batch, vocab]; a single draw per row comes back as
        # [batch, 1], so keep column 0 to return a flat [batch] tensor.
        sampled = torch.multinomial(probs, num_samples=1)
        return sampled[:, 0]
class ReplicatedInputTokenSelector(TokenSelector):
    """Run a selector on the first replica of a batch and tile its result.

    The incoming batch has ``n * batch`` rows. Only the first ``batch`` rows
    are passed to the wrapped selector. Its output is then repeated ``n`` times
    with ``Tensor.repeat``, so replica ``k`` occupies rows
    ``[k * batch, (k + 1) * batch)`` of the result. Presumably the input rows
    are laid out the same way (replicas stacked along dim 0) -- confirm
    against the caller.
    """

    def __init__(self, token_selector: TokenSelector, n: int):
        # Selector applied to the primary (first) replica only.
        self.token_selector = token_selector
        # Number of replicas stacked along dim 0.
        self.n = n

    def __call__(
        self, input_ids: torch.LongTensor, probs: torch.FloatTensor
    ) -> torch.LongTensor:
        """Select tokens for the primary replica and replicate them ``n`` times.

        Args:
            input_ids: shape ``[n * batch, seq_len]``.
            probs: shape ``[n * batch, vocab]``.

        Returns:
            The wrapped selector's output for the first ``batch`` rows,
            repeated ``n`` times (``[n * batch]`` for a 1-D per-row result).

        Raises:
            ValueError: if ``n < 1`` or ``probs.shape[0]`` is not divisible by
                ``n``. Without this check, ``torch.chunk`` would pick a first
                chunk of size ``ceil(rows / n)``, and the tiled output would
                silently have the wrong length.
        """
        total_rows = probs.shape[0]
        if self.n < 1 or total_rows % self.n != 0:
            raise ValueError(
                f"expected probs.shape[0] ({total_rows}) to be divisible by "
                f"n ({self.n}) with n >= 1"
            )
        batch = total_rows // self.n
        # First replica only; equivalent to torch.chunk(...)[0] when divisible.
        primary_input_ids = input_ids[:batch]
        primary_probs = probs[:batch]
        tokens = self.token_selector(primary_input_ids, primary_probs)
        return tokens.repeat(self.n)