beam search #157

Open · wants to merge 2 commits into main
2 changes: 1 addition & 1 deletion pix2tex/eval.py

@@ -51,7 +51,7 @@ def evaluate(model: Model, dataset: Im2LatexDataset, args: Munch, num_batches: int
         if seq is None or im is None:
             continue
         #loss = decoder(tgt_seq, mask=tgt_mask, context=encoded)
-        dec = model.generate(im.to(device), temperature=args.get('temperature', .2))
+        dec = model.generate(im.to(device), **args)
         pred = detokenize(dec, dataset.tokenizer)
         truth = detokenize(seq['input_ids'], dataset.tokenizer)
         bleus.append(metrics.bleu_score(pred, [alternatives(x) for x in truth]))
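Since `args` is a dict-like Munch, unpacking it forwards every config key (including the new `num_beams` and `length_penalty`) to `model.generate`, instead of cherry-picking `temperature`. A minimal sketch of what the unpacked call amounts to, with hypothetical config values:

```python
# Roughly equivalent to model.generate(im.to(device), **args) for a config
# holding temperature: 0.2, num_beams: 3, length_penalty: 0.7 (made-up values):
dec = model.generate(im.to(device), temperature=0.2, num_beams=3, length_penalty=0.7)
```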
4 changes: 3 additions & 1 deletion pix2tex/model/settings/config-vit.yaml

@@ -49,4 +49,6 @@ test_samples: 5
 testbatchsize: 20
 tokenizer: dataset/tokenizer.json
 valbatches: 100
-valdata: dataset/data/val.pkl
+valdata: dataset/data/val.pkl
+num_beams: 3
+length_penalty: 0.7

(The `valdata` line is textually unchanged; the delete/re-add pair most likely just adds a missing newline at the end of the file.)
2 changes: 2 additions & 0 deletions pix2tex/model/settings/config.yaml

@@ -53,3 +53,5 @@ testbatchsize: 20
 tokenizer: dataset/tokenizer.json
 valbatches: 100
 valdata: dataset/data/val.pkl
+num_beams: 3
+length_penalty: 0.7
4 changes: 4 additions & 0 deletions pix2tex/model/settings/debug.yaml

@@ -65,3 +65,7 @@ pad: False
 pad_token: 0
 bos_token: 1
 eos_token: 2
+
+#beam search
+num_beams: 3
+length_penalty: 0.7
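For context on the two new keys: `num_beams` is the beam width, and `length_penalty` sets the exponent in the length normalization that `BeamHypotheses.add` applies below, `score = sum_logprobs / len(hyp) ** length_penalty`. A quick sanity check of that formula with made-up numbers:

```python
# Length normalization as used in BeamHypotheses.add (all values made up):
sum_logprobs = -10.0    # cumulative log-probability of a finished hypothesis
length = 20             # number of tokens in the hypothesis
length_penalty = 0.7    # the new config value
score = sum_logprobs / length ** length_penalty
print(round(score, 2))  # -1.23; with length_penalty = 1.0 it would be -0.5
```

Relative to plain length averaging (`length_penalty: 1.0`), an exponent below 1 normalizes less aggressively and therefore leans toward shorter hypotheses.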
152 changes: 148 additions & 4 deletions pix2tex/models/transformer.py

@@ -4,12 +4,156 @@
 from x_transformers import TransformerWrapper, Decoder
 
 
+class BeamHypotheses(object):
+    def __init__(self, num_beams: int, length_penalty: float):
+        """
+        Initialize n-best list of hypotheses.
+        """
+        self.length_penalty = length_penalty
+        self.num_beams = num_beams
+        self.beams = []
+        self.worst_score = 1e9
+
+    def __len__(self):
+        """
+        Number of hypotheses in the list.
+        """
+        return len(self.beams)
+
+    def add(self, hyp: torch.LongTensor, sum_logprobs: float):
+        """
+        Add a new hypothesis to the list.
+        """
+        score = sum_logprobs / len(hyp) ** self.length_penalty
+        if len(self) < self.num_beams or score > self.worst_score:
+            self.beams.append((score, hyp))
+            if len(self) > self.num_beams:
+                # Evict the worst hypothesis and remember the new worst score.
+                sorted_scores = sorted(
+                    [(s, idx) for idx, (s, _) in enumerate(self.beams)])
+                del self.beams[sorted_scores[0][1]]
+                self.worst_score = sorted_scores[1][0]
+            else:
+                self.worst_score = min(score, self.worst_score)
+
+    def is_done(self, best_sum_logprobs: float, cur_len: int):
+        """
+        If there are enough hypotheses and none of the ones still being generated
+        can become better than the worst one in the list, we are done with this sentence.
+        """
+        if len(self) < self.num_beams:
+            return False
+        cur_score = best_sum_logprobs / cur_len ** self.length_penalty
+        return self.worst_score >= cur_score
+
+
 class CustomARWrapper(AutoregressiveWrapper):
     def __init__(self, *args, **kwargs):
         super(CustomARWrapper, self).__init__(*args, **kwargs)
 
     @torch.no_grad()
-    def generate(self, start_tokens, seq_len=256, eos_token=None, temperature=1., filter_logits_fn=top_k, filter_thres=0.9, **kwargs):
+    def beam_generate(self, start_tokens, context, seq_len=256, **kwargs):
+        eos_token = kwargs.get('eos_token', None)
+        num_beams = kwargs.get('num_beams', 3)
+        length_penalty = kwargs.get('length_penalty', 0.7)
+        temperature = kwargs.get('temperature', 1.)
+        # Vocabulary size; the beam/token index arithmetic below depends on it,
+        # so it has to be supplied (it is part of the model args).
+        num_tokens = kwargs.get('num_tokens', None)
+        pad_token = kwargs.get('pad_token', 0)
+        batch_size, t = start_tokens.shape
+        was_training = self.net.training
+        self.net.eval()
+        beam_scores = torch.zeros(
+            (batch_size, num_beams)).to(start_tokens.device)
+        # All beams start identical; masking all but the first prevents top-k
+        # from picking the same continuation num_beams times in the first step.
+        beam_scores[:, 1:] = -1e9
+        beam_scores = beam_scores.view(-1)
+        done = [False for _ in range(batch_size)]
+        generated_hyps = [BeamHypotheses(
+            num_beams, length_penalty=length_penalty) for _ in range(batch_size)]
+        input_ids = start_tokens.repeat(num_beams, 1)
+        hidden = context[:, None].repeat(1, num_beams, 1, 1).view(
+            batch_size*num_beams, context.shape[1], context.shape[2])
+        cur_len = t
+        while cur_len < seq_len:
+            outputs = self.net(x=input_ids, context=hidden)
+            next_token_logits = outputs[:, -1, :]
+            scores = F.log_softmax(next_token_logits/temperature, dim=-1)
+            # cumulative log(prob)
+            next_scores = scores + beam_scores[:, None].expand_as(scores)
+            next_scores = next_scores.view(batch_size, num_beams * num_tokens)
+            # Keep 2*num_beams candidates so beams that end in EOS can be
+            # retired without starving the set of continuing beams.
+            next_scores, next_tokens = torch.topk(
+                next_scores, 2*num_beams, dim=1, largest=True, sorted=True)
+            next_batch_beam = []
+            for batch_idx in range(batch_size):
+                if done[batch_idx]:
+                    # (score, token id, beam id)
+                    next_batch_beam.extend([(0, pad_token, 0)] * num_beams)
+                    continue
+                next_sent_beam = []
+                for beam_token_rank, (beam_token_id, beam_token_score) in enumerate(zip(next_tokens[batch_idx], next_scores[batch_idx])):
+                    beam_id = beam_token_id // num_tokens
+                    token_id = beam_token_id % num_tokens
+                    effective_beam_id = batch_idx * num_beams + beam_id
+                    if (eos_token is not None) and (token_id.item() == eos_token):
+                        # A finished hypothesis only counts if it ranks in the top num_beams.
+                        is_beam_token_worse_than_top_num_beams = beam_token_rank >= num_beams
+                        if is_beam_token_worse_than_top_num_beams:
+                            continue
+                        generated_hyps[batch_idx].add(
+                            input_ids[effective_beam_id].clone(), beam_token_score.item())
+                    else:
+                        next_sent_beam.append(
+                            (beam_token_score, token_id, effective_beam_id))
+                    if len(next_sent_beam) == num_beams:
+                        break
+                done[batch_idx] = done[batch_idx] or generated_hyps[batch_idx].is_done(
+                    next_scores[batch_idx].max().item(), cur_len)
+                next_batch_beam.extend(next_sent_beam)
+            if all(done):
+                break
+            # Reorder beams to match the surviving candidates, then append tokens.
+            beam_scores = beam_scores.new([x[0] for x in next_batch_beam])
+            beam_tokens = input_ids.new([x[1] for x in next_batch_beam])
+            beam_idx = input_ids.new([x[2] for x in next_batch_beam])
+            input_ids = input_ids[beam_idx, ...]
+            hidden = hidden[beam_idx, ...]
+            input_ids = torch.cat(
+                [input_ids, beam_tokens.unsqueeze(1)], dim=-1)
+            cur_len = cur_len + 1
+        # Finalize still-open beams for sentences that never emitted EOS.
+        for batch_idx in range(batch_size):
+            if done[batch_idx]:
+                continue
+            for beam_id in range(num_beams):
+                effective_beam_id = batch_idx * num_beams + beam_id
+                final_score = beam_scores[effective_beam_id].item()
+                final_tokens = input_ids[effective_beam_id]
+                generated_hyps[batch_idx].add(final_tokens, final_score)
+        output_num_return_sequences_per_batch = 1
+        output_batch_size = output_num_return_sequences_per_batch * batch_size
+        sent_lengths = input_ids.new(output_batch_size)
+        best = []
+        for i, hypotheses in enumerate(generated_hyps):
+            sorted_hyps = sorted(hypotheses.beams, key=lambda x: x[0])
+            for j in range(output_num_return_sequences_per_batch):
+                effective_batch_idx = output_num_return_sequences_per_batch * i + j
+                best_hyp = sorted_hyps.pop()[1]
+                sent_lengths[effective_batch_idx] = len(best_hyp)
+                best.append(best_hyp)
+        if sent_lengths.min().item() != sent_lengths.max().item():
+            # Pad ragged outputs and terminate each sequence with EOS where it fits.
+            sent_max_len = min(sent_lengths.max().item() + 1, seq_len)
+            decoded = input_ids.new(
+                output_batch_size, sent_max_len).fill_(pad_token)
+            for i, hypo in enumerate(best):
+                decoded[i, :sent_lengths[i]] = hypo
+                if sent_lengths[i] < seq_len:
+                    decoded[i, sent_lengths[i]] = eos_token
+        else:
+            decoded = torch.stack(best).type(torch.long)
+        self.net.train(was_training)
+        decoded = decoded[:, t:]
+        return decoded
+
+    @torch.no_grad()
+    def generate(self, start_tokens, context, seq_len=256, filter_logits_fn=top_k, filter_thres=0.9, **kwargs):
+        eos_token = kwargs.get('eos_token', None)
+        temperature = kwargs.get('temperature', 1.)
         device = start_tokens.device
         was_training = self.net.training
         num_dims = len(start_tokens.shape)

@@ -23,13 +167,13 @@ def generate(self, start_tokens, seq_len=256, eos_token=None, temperature=1., filter_logits_fn=top_k, filter_thres=0.9, **kwargs):
         out = start_tokens
         mask = kwargs.pop('mask', None)
         if mask is None:
-            mask = torch.full_like(out, True, dtype=torch.bool, device=out.device)
+            mask = torch.full_like(out, True, dtype=torch.bool, device=device)
 
         for _ in range(seq_len):
             x = out[:, -self.max_seq_len:]
             mask = mask[:, -self.max_seq_len:]
-            # print('arw:',out.shape)
-            logits = self.net(x, mask=mask, **kwargs)[:, -1, :]
+            # print('arw:', out.shape)
+            logits = self.net(x, mask=mask, context=context)[:, -1, :]
 
             if filter_logits_fn in {top_k, top_p}:
                 filtered_logits = filter_logits_fn(logits, thres=filter_thres)
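To see the n-best bookkeeping in isolation, here is a small sketch (not part of the PR) that drives `BeamHypotheses` directly; all token ids and log-probabilities are invented:

```python
import torch

from pix2tex.models.transformer import BeamHypotheses  # with this PR applied

hyps = BeamHypotheses(num_beams=3, length_penalty=0.7)

# Three hypothetical finished sequences and their cumulative log-probabilities.
hyps.add(torch.LongTensor([1, 5, 9, 2]), sum_logprobs=-2.0)     # score ≈ -0.76
hyps.add(torch.LongTensor([1, 7, 2]), sum_logprobs=-1.5)        # score ≈ -0.70
hyps.add(torch.LongTensor([1, 5, 9, 9, 2]), sum_logprobs=-6.0)  # score ≈ -1.94

# A better hypothesis evicts the current worst; the list stays at num_beams.
hyps.add(torch.LongTensor([1, 8, 8, 2]), sum_logprobs=-1.0)     # score ≈ -0.38
print(len(hyps))  # 3

best_score, best_seq = max(hyps.beams, key=lambda x: x[0])
print(best_score, best_seq)  # ≈ -0.38, tensor([1, 8, 8, 2])

# is_done: once the best possible in-progress score cannot beat the worst kept
# score, the sentence is finished: -9.0 / 6**0.7 ≈ -2.57 is far below ≈ -0.76.
print(hyps.is_done(best_sum_logprobs=-9.0, cur_len=6))  # True
```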
21 changes: 14 additions & 7 deletions pix2tex/models/utils.py

@@ -4,7 +4,7 @@
 from . import hybrid
 from . import vit
 from . import transformer
-
+from munch import Munch
 
 class Model(nn.Module):
     def __init__(self, encoder, decoder, args):

@@ -19,8 +19,10 @@ def data_parallel(self, x: torch.Tensor, device_ids, output_device=None, **kwargs):
         if output_device is None:
             output_device = device_ids[0]
         replicas = nn.parallel.replicate(self, device_ids)
-        inputs = nn.parallel.scatter(x, device_ids)  # Slices tensors into approximately equal chunks and distributes them across given GPUs.
-        kwargs = nn.parallel.scatter(kwargs, device_ids)  # Duplicates references to objects that are not tensors.
+        # Slices tensors into approximately equal chunks and distributes them across given GPUs.
+        inputs = nn.parallel.scatter(x, device_ids)
+        # Duplicates references to objects that are not tensors.
+        kwargs = nn.parallel.scatter(kwargs, device_ids)
         replicas = replicas[:len(inputs)]
         kwargs = kwargs[:len(inputs)]
         outputs = nn.parallel.parallel_apply(replicas, inputs, kwargs)

@@ -32,9 +34,13 @@ def forward(self, x: torch.Tensor, tgt_seq: torch.Tensor, **kwargs):
         return out
 
     @torch.no_grad()
-    def generate(self, x: torch.Tensor, temperature: float = 0.25):
-        return self.decoder.generate((torch.LongTensor([self.args.bos_token]*len(x))[:, None]).to(x.device), self.args.max_seq_len,
-                                     eos_token=self.args.eos_token, context=self.encoder(x), temperature=temperature)
+    def generate(self, x: torch.Tensor, temperature: float = 0.25, **kwargs):
+        args = Munch(self.args)
+        args.update(kwargs)
+        args.temperature = temperature
+        # return self.decoder.beam_generate((torch.LongTensor([self.args.bos_token]*len(x))[:, None]).to(x.device), context=self.encoder(x), seq_len=self.args.max_seq_len, **args)
+        return self.decoder.generate((torch.LongTensor([self.args.bos_token]*len(x))[:, None]).to(x.device), context=self.encoder(x), seq_len=self.args.max_seq_len, **args)
+
 
 
 def get_model(args):

@@ -43,7 +49,8 @@ def get_model(args):
     elif args.encoder_structure.lower() == 'hybrid':
         encoder = hybrid.get_encoder(args)
     else:
-        raise NotImplementedError('Encoder structure "%s" not supported.' % args.encoder_structure)
+        raise NotImplementedError(
+            'Encoder structure "%s" not supported.' % args.encoder_structure)
     decoder = transformer.get_decoder(args)
     encoder.to(args.device)
     decoder.to(args.device)
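With this kwargs plumbing in place, beam decoding is one line away: `Model.generate` already merges call-site options into the model args, and the `beam_generate` call is left commented out. A hypothetical call site, assuming a built model from `get_model(args)` and a preprocessed image batch `x`:

```python
# Hypothetical usage; `model` comes from get_model(args) and `x` is a
# preprocessed image batch on the model's device (generate is already no_grad).
dec = model.generate(x, temperature=0.2)  # default sampling path
# num_beams / length_penalty are merged into args; they take effect once the
# commented-out beam_generate return is switched on, and are ignored otherwise.
dec_beam = model.generate(x, temperature=1.0, num_beams=3, length_penalty=0.7)
```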