-
Notifications
You must be signed in to change notification settings - Fork 3
/
utils.py
2119 lines (1792 loc) · 86.6 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
from allennlp.modules.elmo import batch_to_ids
import torch
import numpy as np
import numpy
from sacrebleu import corpus_bleu as bleu
from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction
from nltk.tokenize import word_tokenize
import tqdm
import os
from subprocess import Popen
import subprocess
import copy
import json
import logging
from collections import defaultdict
from typing import Any, Dict, List, Optional, Sequence, Tuple, TypeVar, Union
import math
from allennlp.common.checks import ConfigurationError
import random
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
# Generic type variable available for annotations in this module.
T = TypeVar("T")
# Shared NLTK `SmoothingFunction` instance.
# NOTE(review): presumably passed to `sentence_bleu` further down the file; not used in the code visible here.
cc = SmoothingFunction()
def bert_dual_sequence_mask(input_ids, sep_token_id, device="cuda"):
    """
    Builds a mask for separator-delimited sequence pairs: in every row, the positions up to
    and including the first separator token are 0 and all later positions are 1.

    # Parameters

    input_ids : `torch.Tensor`
        Token ids, iterated row by row (presumably `(batch_size, sequence_length)`).
        Each row is compared against a tensor placed on `device`, so it must live there too.
    sep_token_id : `int` or sequence of `int`
        The id of the separator token.
    device : `str`, optional (default = `"cuda"`)
        Device the separator id and the returned mask are moved to.

    # Returns

    `torch.LongTensor`
        A tensor with the same shape as `input_ids`, on `device`.

    # Raises

    IndexError
        If a row does not contain the separator token.
    """
    output = torch.ones(input_ids.shape).type(torch.LongTensor)
    # `torch.Tensor(n)` with a plain int allocates an uninitialised tensor of *size* n rather
    # than a tensor holding the value n; `as_tensor` keeps the value for ints and lists alike.
    sep_token_id = torch.as_tensor(sep_token_id).to(device)
    for i, seq in enumerate(input_ids):
        # Index of the first separator in this row.
        seperation_idx = torch.where(seq == sep_token_id)[0][0]
        output[i, : seperation_idx + 1] = 0
    return output.to(device)
def get_output_attribute(out, attribute_name, cuda_device, reduction="sum"):
    """
    Fetches one attribute from a model's output dictionary, reducing it when the model
    ran on several devices.

    When `cuda_device` is a list, the attribute is reduced with `reduction`; otherwise the
    attribute is returned untouched.

    # Parameters

    out : `dict`
        Output of the model's forward pass.
    attribute_name : `str`
        Key to read from `out`.
    cuda_device : `list` or `int`
        A list triggers the reduction.
    reduction : `str`, optional (default = `"sum"`)
        Either `"sum"` or `"mean"` (sum divided by `len` of the attribute).

    # Raises

    ValueError
        If `cuda_device` is a list and `reduction` is not one of the supported values.
    """
    if not isinstance(cuda_device, list):
        return out[attribute_name]

    if reduction not in ("sum", "mean"):
        raise ValueError("invalid reduction type argument")

    total = out[attribute_name].sum()
    if reduction == "sum":
        return total
    return total / float(len(out[attribute_name]))
def removeDuplicates(listofElements):
    """
    Returns a new list with the elements of `listofElements` in first-seen order and
    every later repeat dropped.

    Membership is tested with `in` on the result list, so elements only need to support
    equality (they do not have to be hashable).
    """
    seen_in_order = []
    for candidate in listofElements:
        if candidate in seen_in_order:
            continue
        seen_in_order.append(candidate)
    return seen_in_order
def sequence_mask(lengths, max_len=None):
    """
    Creates a boolean mask from sequence lengths.

    Row `i` of the result is `True` at every position strictly below `lengths[i]`.
    When `max_len` is falsy (e.g. `None`), the largest entry of `lengths` is used as the
    mask width.
    """
    width = max_len if max_len else lengths.max()
    num_rows = lengths.numel()
    # One row of positions 0..width-1, cast to the type of `lengths`, copied per sequence.
    positions = torch.arange(0, width).type_as(lengths)
    tiled_positions = positions.repeat(num_rows, 1)
    return tiled_positions < lengths.unsqueeze(1)
def elmo_batch_to_ids(_str):
    """
    Tokenizes `_str` with `word_tokenize`, runs the single resulting sentence through
    `batch_to_ids`, and returns the entries of the first (only) batch element converted
    to plain Python lists.
    """
    tokens = word_tokenize(_str)
    batched_ids = batch_to_ids([tokens])
    only_sentence = batched_ids[0]
    return [token_ids.tolist() for token_ids in only_sentence]
def find_subtensor(sl, l, device="cuda", scalar=0.75):
    """
    Locates the 1-D tensor `sl` inside the 1-D tensor `l`, falling back to the best
    partial overlap when there is no exact occurrence.

    A window of `len(sl)` elements is slid over `l` for start positions
    `0 .. len(l) - int(len(sl) * scalar) - 1`. Near the end of `l` the window is shorter
    than `sl` and is compared against the equally long prefix of `sl`.

    # Parameters

    sl : `torch.Tensor`
        The tensor to search for; moved to `device` before comparing.
    l : `torch.Tensor`
        The tensor to search in. It is compared against `sl` directly, so it must already
        be on `device`.
    device : `str`, optional (default = `"cuda"`)
        Device for `sl` and for the returned index tensor.
    scalar : `float`, optional (default = `0.75`)
        Fraction of `len(sl)` that limits how far the window may run off the end of `l`.

    # Returns

    `torch.Tensor` or `None`
        `[start, end]` (inclusive) of the first window whose elements all match. If no
        window matches completely, the window with the highest fraction of matching
        positions is returned with `end` clamped to `len(l) - 1`. `None` when no window
        shares a single matching position, or when `l` is too short to form any window.
    """
    sl = sl.to(device)
    sll = len(sl)
    tracker = []
    for ind in range(len(l) - int(sll * scalar)):
        window = l[ind : ind + sll]
        # Number of trailing elements of `sl` that do not fit inside `l` at this offset.
        overhang = max(0, sll - window.shape[0])
        matches = window == (sl[:-overhang] if overhang else sl)
        tracker.append(matches.type(torch.float64).mean().item())
        if matches.all():
            return torch.tensor([ind, ind + sll - 1]).to(device)
    # No candidate window at all (`l` too short): previously `max([])` raised ValueError.
    if not tracker or max(tracker) == 0:
        return None
    ind = np.argmax(np.array(tracker))
    return torch.tensor([ind, min(ind + sll - 1, len(l) - 1)]).to(device)
def batch_bleu(ref, hyp, reduction="average"):
    """
    Calculates sentence-level BLEU for a batch by running the external `./sentence-bleu`
    executable (resolved relative to the current working directory).

    References and hypotheses are written one per line to temporary files; the executable
    receives the reference file as its argument and the hypotheses on stdin, and is
    expected to print one score per line.

    # Parameters

    ref : `List[str]`
        Reference sentences.
    hyp : `List[str]`
        Hypothesis sentences; must have the same length as `ref`.
    reduction : `str`, optional (default = `"average"`)
        `"sum"` returns the sum of the sentence scores, `"average"` divides that sum by
        `len(ref)`.

    # Returns

    `float`
        The reduced score; `0` / `0.0` for an empty batch.

    # Raises

    AssertionError
        If `ref` and `hyp` differ in length.
    NotImplementedError
        If `reduction` is not supported (checked before any file or process work).
    FileNotFoundError
        If `./sentence-bleu` does not exist.
    """
    # Local import: keeps this block self-contained without touching the file's import list.
    import tempfile

    assert len(ref) == len(hyp)
    if reduction not in ("sum", "average"):
        raise NotImplementedError(
            f"{reduction} not in supported reductions: ['sum','average']"
        )
    if not ref:
        # Nothing to score; avoid spawning a process for an empty batch.
        return 0 if reduction == "sum" else 0.0

    # A private temporary directory replaces the random file names in /tmp: no collisions
    # between concurrent calls and everything is removed even when scoring fails.
    with tempfile.TemporaryDirectory() as workdir:
        ref_path = os.path.join(workdir, "ref.txt")
        hyp_path = os.path.join(workdir, "hyp.txt")
        score_path = os.path.join(workdir, "bleu_score")
        with open(ref_path, "w") as f:
            f.write("\n".join(ref))
        with open(hyp_path, "w") as f:
            f.write("\n".join(hyp))
        with open(hyp_path, "r") as hyp_file, open(score_path, "w") as score_file:
            # Argument list instead of a shell string; stderr is discarded rather than
            # piped, since an unread pipe can block the child once its buffer fills up.
            subprocess.run(
                ["./sentence-bleu", ref_path],
                stdin=hyp_file,
                stdout=score_file,
                stderr=subprocess.DEVNULL,
            )
        with open(score_path, "r") as f:
            scores = [float(line) for line in f if line.strip()]

    total_score = sum(scores)
    if reduction == "sum":
        return total_score
    return total_score / len(ref)
def mkdir(path):
    """
    Creates the directory `path` if it does not exist yet and returns `path`.

    This is best effort and never raises for OS-level failures: an already existing path
    is silently accepted, any other `OSError` (missing parent, permissions, ...) is logged
    as a warning instead of being dropped without a trace.
    """
    try:
        os.mkdir(path)
    except FileExistsError:
        # The common, expected case: the directory is already there.
        pass
    except OSError as error:
        logging.getLogger(__name__).warning(
            "Could not create directory %s: %s", path, error
        )
    return path
def has_tensor(obj) -> bool:
    """
    Reports whether `obj` is a `torch.Tensor` or contains one anywhere inside nested
    dicts (values only), lists or tuples. Any other object counts as tensor-free.
    """
    if isinstance(obj, torch.Tensor):
        return True

    if isinstance(obj, dict):
        children = obj.values()
    elif isinstance(obj, (list, tuple)):
        children = obj
    else:
        return False

    return any(has_tensor(child) for child in children)
def move_to_device(obj, cuda_device: int):
    """
    Recursively moves every `torch.Tensor` inside `obj` to the GPU `cuda_device`.

    `obj` is returned unchanged when `cuda_device` is negative or when it holds no
    tensors. Dicts, lists, tuples and named tuples are rebuilt with their items moved;
    anything else is passed through as is.
    """
    if cuda_device < 0 or not has_tensor(obj):
        return obj

    if isinstance(obj, torch.Tensor):
        return obj.cuda(cuda_device)

    if isinstance(obj, dict):
        return {name: move_to_device(entry, cuda_device) for name, entry in obj.items()}

    if isinstance(obj, list):
        return [move_to_device(entry, cuda_device) for entry in obj]

    if isinstance(obj, tuple):
        moved = (move_to_device(entry, cuda_device) for entry in obj)
        # A `_fields` attribute identifies a NamedTuple, which takes its items as
        # positional arguments rather than as a single iterable.
        if hasattr(obj, "_fields"):
            return obj.__class__(*moved)
        return tuple(moved)

    return obj
def clamp_tensor(tensor, minimum, maximum):
    """
    Supports sparse and dense tensors.

    Returns a tensor with values clamped between the provided minimum and maximum,
    without modifying the original tensor.

    # Parameters

    tensor : `torch.Tensor`
        Dense or sparse tensor to clamp.
    minimum, maximum : number
        Lower and upper bound applied to the (stored) values.

    # Returns

    `torch.Tensor`
        A new tensor; for sparse input the result is a coalesced sparse tensor.
    """
    if tensor.is_sparse:
        # `coalesce()` may hand back the input itself when it is already coalesced, in
        # which case the in-place clamp below would silently modify the caller's tensor.
        # Clone in that case so the values we clamp are always our own copy.
        if tensor.is_coalesced():
            coalesced_tensor = tensor.clone()
        else:
            coalesced_tensor = tensor.coalesce()
        coalesced_tensor._values().clamp_(minimum, maximum)
        return coalesced_tensor
    else:
        return tensor.clamp(minimum, maximum)
def batch_tensor_dicts(
    tensor_dicts: List[Dict[str, torch.Tensor]], remove_trailing_dimension: bool = False
) -> Dict[str, torch.Tensor]:
    """
    Merges a list of tensor dictionaries into one dictionary in which the tensors that
    share a key are stacked along a new leading dimension.

    # Parameters

    tensor_dicts : `List[Dict[str, torch.Tensor]]`
        The dictionaries to batch; they are expected to have matching keys.
    remove_trailing_dimension : `bool`
        If `True`, a key whose tensors all have a last dimension of size 1 gets that
        dimension squeezed out of its stacked tensor.

    # Returns

    `Dict[str, torch.Tensor]`
        One stacked tensor per key, in order of first appearance.
    """
    grouped: Dict[str, List[torch.Tensor]] = {}
    for single_dict in tensor_dicts:
        for name, value in single_dict.items():
            grouped.setdefault(name, []).append(value)

    stacked_by_key = {}
    for name, group in grouped.items():
        stacked = torch.stack(group)
        if remove_trailing_dimension:
            if all(member.size(-1) == 1 for member in group):
                stacked = stacked.squeeze(-1)
        stacked_by_key[name] = stacked
    return stacked_by_key
def get_lengths_from_binary_sequence_mask(mask: torch.BoolTensor) -> torch.LongTensor:
    """
    Derives per-sequence lengths from a binary mask by summing it over its last dimension.

    # Parameters

    mask : `torch.BoolTensor`, required.
        A 2D binary mask of shape (batch_size, sequence_length).

    # Returns

    `torch.LongTensor`
        A tensor of shape (batch_size,) holding the number of set mask entries per row.
    """
    lengths = torch.sum(mask, dim=-1)
    return lengths
def get_mask_from_sequence_lengths(
    sequence_lengths: torch.Tensor, max_length: int
) -> torch.BoolTensor:
    """
    Turns a `(batch_size,)` tensor of sequence lengths into a `(batch_size, max_length)`
    boolean mask whose row `i` is `True` for the first `sequence_lengths[i]` positions.

    For example, lengths `[2, 2, 3]` with `max_length` 4 give
    `[[1, 1, 0, 0], [1, 1, 0, 0], [1, 1, 1, 0]]`.

    `max_length` is passed in rather than derived from `sequence_lengths` so that the
    maximum does not have to be computed and copied from the GPU to the CPU just to size
    a new tensor.
    """
    batch_size = sequence_lengths.size(0)
    # Running sum over a matrix of ones yields the 1-based position index in every row.
    # Shape: (batch_size, max_length)
    one_based_positions = torch.cumsum(
        sequence_lengths.new_ones(batch_size, max_length), dim=1
    )
    return one_based_positions <= sequence_lengths.unsqueeze(1)
def sort_batch_by_length(tensor: torch.Tensor, sequence_lengths: torch.Tensor):
    """
    Reorders a batch-first tensor so that its entries appear in order of decreasing
    sequence length.

    # Parameters

    tensor : torch.FloatTensor, required.
        A batch first Pytorch tensor.
    sequence_lengths : torch.LongTensor, required.
        The lengths to sort by, one per batch entry.

    # Returns

    sorted_tensor : `torch.FloatTensor`
        `tensor` reordered along dimension 0 according to `sequence_lengths`.
    sorted_sequence_lengths : `torch.LongTensor`
        `sequence_lengths` in decreasing order.
    restoration_indices : `torch.LongTensor`
        Indices such that
        `sorted_tensor.index_select(0, restoration_indices) == original_tensor`.
    permutation_index : `torch.LongTensor`
        The indices used for sorting; reuse them to sort other tensors the same way.

    # Raises

    ConfigurationError
        If either argument is not a `torch.Tensor`.
    """
    both_are_tensors = isinstance(tensor, torch.Tensor) and isinstance(
        sequence_lengths, torch.Tensor
    )
    if not both_are_tensors:
        raise ConfigurationError(
            "Both the tensor and sequence lengths must be torch.Tensors."
        )

    sorted_sequence_lengths, permutation_index = torch.sort(
        sequence_lengths, dim=0, descending=True
    )
    sorted_tensor = torch.index_select(tensor, 0, permutation_index)

    positions = torch.arange(0, len(sequence_lengths), device=sequence_lengths.device)
    # Sorting the permutation itself gives, for every original position, where that
    # entry ended up in the sorted batch (the inverse permutation).
    inverse_permutation = permutation_index.sort(0, descending=False)[1]
    restoration_indices = positions.index_select(0, inverse_permutation)

    return (
        sorted_tensor,
        sorted_sequence_lengths,
        restoration_indices,
        permutation_index,
    )
def get_final_encoder_states(
    encoder_outputs: torch.Tensor, mask: torch.BoolTensor, bidirectional: bool = False
) -> torch.Tensor:
    """
    Picks the final hidden state of every sequence out of `encoder_outputs`, which has
    shape `(batch_size, sequence_length, encoding_dim)`, and returns a tensor of shape
    `(batch_size, encoding_dim)`.

    Sequences may have different lengths, so `encoder_outputs[:, -1]` is not enough; the
    `(batch_size, sequence_length)` `mask` is used to locate the last real position of
    each sequence. Right padding is assumed.

    With `bidirectional=True` the last dimension is treated as two halves: the first
    half (forward direction) is taken at the last real position, the second half
    (backward direction) is taken at position 0, and the two are concatenated.
    """
    # Index of the last non-padding word per sequence. Shape: (batch_size,)
    last_positions = mask.sum(1) - 1
    batch_size, _, encoder_output_dim = encoder_outputs.size()

    # Shape: (batch_size, 1, encoder_output_dim)
    gather_index = last_positions.view(-1, 1, 1).expand(
        batch_size, 1, encoder_output_dim
    )
    # Shape after squeeze: (batch_size, encoder_output_dim)
    final_states = encoder_outputs.gather(1, gather_index).squeeze(1)

    if not bidirectional:
        return final_states

    half = encoder_output_dim // 2
    forward_part = final_states[:, :half]
    backward_part = encoder_outputs[:, 0, half:]
    return torch.cat([forward_part, backward_part], dim=-1)
def get_dropout_mask(dropout_probability: float, tensor_for_masking: torch.Tensor):
    """
    Samples an element-wise dropout mask shaped like `tensor_for_masking`.

    The mask is NOT applied to the tensor; the tensor only supplies the shape and the
    device of the mask.

    # Parameters

    dropout_probability : float, required.
        Probability of dropping an element.
    tensor_for_masking : torch.Tensor, required.

    # Returns

    `torch.FloatTensor`
        The binary keep-mask scaled by 1 / (1 - dropout_probability), so that applying
        it preserves the expected value of the masked tensor.
    """
    keep = torch.rand(tensor_for_masking.size()) > dropout_probability
    keep = keep.to(tensor_for_masking.device)
    # Scale the kept entries by 1/keep_prob to preserve output statistics.
    keep_probability = 1.0 - dropout_probability
    return keep.float() / keep_probability
def masked_softmax(
    vector: torch.Tensor,
    mask: torch.BoolTensor,
    dim: int = -1,
    memory_efficient: bool = False,
) -> torch.Tensor:
    """
    Softmax over `vector` along `dim` that only distributes probability over the positions
    where `mask` is set. With `mask=None` this is a regular softmax.

    `vector` may have any number of dimensions as long as `mask` is broadcastable to it.
    A `mask` with fewer dimensions is unsqueezed on dimension 1 until the ranks match; if
    you need a different alignment, reshape the mask before calling.

    With `memory_efficient=False`, masked inputs are zeroed before the softmax, masked
    outputs are zeroed afterwards and the result is renormalised. A fully masked vector
    yields all `0.0`, which may produce `NaN` under a categorical cross-entropy loss.

    With `memory_efficient=True`, masked positions are filled with a very large negative
    number before a single softmax. This is only approximately correct but uses less
    memory; a fully masked vector is treated as a softmax over equal numbers.
    """
    if mask is None:
        return torch.nn.functional.softmax(vector, dim=dim)

    while mask.dim() < vector.dim():
        mask = mask.unsqueeze(1)

    if memory_efficient:
        filled = vector.masked_fill(~mask, min_value_of_dtype(vector.dtype))
        return torch.nn.functional.softmax(filled, dim=dim)

    # Zeroing the masked inputs first limits numerical errors from large values that
    # lie outside the mask.
    probabilities = torch.nn.functional.softmax(vector * mask, dim=dim) * mask
    normaliser = probabilities.sum(dim=dim, keepdim=True) + tiny_value_of_dtype(
        probabilities.dtype
    )
    return probabilities / normaliser
def masked_log_softmax(
    vector: torch.Tensor, mask: torch.BoolTensor, dim: int = -1
) -> torch.Tensor:
    """
    Log-softmax over `vector` along `dim` restricted to the positions where `mask` is
    set. With `mask=None` this is a regular log_softmax.

    `vector` may have any number of dimensions as long as `mask` is broadcastable to it.
    A `mask` with fewer dimensions is unsqueezed on dimension 1 until the ranks match; if
    you need a different alignment, reshape the mask before calling.

    For a completely masked vector the returned values are arbitrary but not `nan`; mask
    the result downstream in that case. The handling of that case relies on
    single-precision floats, so half-precision inputs with fully masked vectors will
    likely give `nans`.

    If the logits are all extremely negative (max value around -50 or lower), the masking
    approach used here can distort the result.
    """
    if mask is None:
        return torch.nn.functional.log_softmax(vector, dim=dim)

    while mask.dim() < vector.dim():
        mask = mask.unsqueeze(1)

    # Adding log(mask) pushes masked entries towards -inf in log space. A tiny value is
    # added to the mask first because log(0) would turn a fully masked vector into nans.
    log_mask = (mask + tiny_value_of_dtype(vector.dtype)).log()
    return torch.nn.functional.log_softmax(vector + log_mask, dim=dim)
def masked_max(
    vector: torch.Tensor, mask: torch.BoolTensor, dim: int, keepdim: bool = False,
) -> torch.Tensor:
    """
    Maximum of `vector` along `dim`, ignoring the positions where `mask` is not set.

    # Parameters

    vector : `torch.Tensor`
        The tensor to take the maximum of.
    mask : `torch.BoolTensor`
        The mask of the vector. It must be broadcastable with vector.
    dim : `int`
        The dimension to reduce.
    keepdim : `bool`
        Whether to keep the reduced dimension.

    # Returns

    `torch.Tensor`
        The maximum values. Masked-out positions are replaced by the minimum value of the
        dtype before reducing, so a fully masked slice yields that minimum.
    """
    hidden_positions = ~mask
    lowest = min_value_of_dtype(vector.dtype)
    reduced = vector.masked_fill(hidden_positions, lowest).max(dim=dim, keepdim=keepdim)
    return reduced[0]
def masked_mean(
    vector: torch.Tensor, mask: torch.BoolTensor, dim: int, keepdim: bool = False
) -> torch.Tensor:
    """
    Mean of `vector` along `dim`, counting only the positions where `mask` is set.

    # Parameters

    vector : `torch.Tensor`
        The tensor to average.
    mask : `torch.BoolTensor`
        The mask of the vector. It must be broadcastable with vector.
    dim : `int`
        The dimension to reduce.
    keepdim : `bool`
        Whether to keep the reduced dimension.

    # Returns

    `torch.Tensor`
        The mean values. The divisor is clamped to a tiny positive number, so a fully
        masked slice does not divide by zero.
    """
    hidden_positions = ~mask
    masked_total = vector.masked_fill(hidden_positions, 0.0).sum(dim=dim, keepdim=keepdim)
    kept_count = mask.sum(dim=dim, keepdim=keepdim).float()
    safe_count = kept_count.clamp(min=tiny_value_of_dtype(torch.float))
    return masked_total / safe_count
def masked_flip(
    padded_sequence: torch.Tensor, sequence_lengths: List[int]
) -> torch.Tensor:
    """
    Flips a padded tensor along the time dimension without affecting masked entries.

    Only the first `sequence_lengths[i]` timesteps of row `i` are reversed; every position
    after them is zero in the result.

    # Parameters

    padded_sequence : `torch.Tensor`
        The tensor to flip along the time dimension.
        Assumed to be of dimensions (batch size, num timesteps, ...)
    sequence_lengths : `List[int]`
        The lengths of each unpadded sequence in the batch.

    # Returns

    `torch.Tensor`
        A `torch.Tensor` of the same shape as padded_sequence.

    # Raises

    AssertionError
        If `sequence_lengths` does not have one entry per batch element.
    """
    assert padded_sequence.size(0) == len(sequence_lengths), (
        f"sequence_lengths length {len(sequence_lengths)} does not match batch size"
        f" {padded_sequence.size(0)}"
    )
    num_timesteps = padded_sequence.size(1)
    flipped_padded_sequence = torch.flip(padded_sequence, [1])
    # After flipping the whole row, the real tokens sit at the end; keep just those.
    sequences = [
        flipped_padded_sequence[i, num_timesteps - length :]
        for i, length in enumerate(sequence_lengths)
    ]
    flipped = torch.nn.utils.rnn.pad_sequence(sequences, batch_first=True)
    # `pad_sequence` only pads up to the longest sequence in the batch. When every
    # sequence is shorter than `num_timesteps`, extend with zeros so the result keeps
    # the shape of the input, as documented.
    missing_timesteps = num_timesteps - flipped.size(1)
    if missing_timesteps > 0:
        padding = flipped.new_zeros(
            flipped.size(0), missing_timesteps, *flipped.shape[2:]
        )
        flipped = torch.cat([flipped, padding], dim=1)
    return flipped
def viterbi_decode(
    tag_sequence: torch.Tensor,
    transition_matrix: torch.Tensor,
    tag_observations: Optional[List[int]] = None,
    allowed_start_transitions: torch.Tensor = None,
    allowed_end_transitions: torch.Tensor = None,
    top_k: int = None,
):
    """
    Perform Viterbi decoding in log space over a sequence given a transition matrix
    specifying pairwise (transition) potentials between tags and a matrix of shape
    (sequence_length, num_tags) specifying unary potentials for possible tags per
    timestep.

    # Parameters

    tag_sequence : torch.Tensor, required.
        A tensor of shape (sequence_length, num_tags) representing scores for
        a set of tags over a given sequence.
    transition_matrix : torch.Tensor, required.
        A tensor of shape (num_tags, num_tags) representing the binary potentials
        for transitioning between a given pair of tags.
    tag_observations : Optional[List[int]], optional, (default = None)
        A list of length `sequence_length` containing the class ids of observed
        elements in the sequence, with unobserved elements being set to -1. Note that
        it is possible to provide evidence which results in degenerate labelings if
        the sequences of tags you provide as evidence cannot transition between each
        other, or those transitions are extremely unlikely. In this situation we log a
        warning, but the responsibility for providing self-consistent evidence ultimately
        lies with the user.
    allowed_start_transitions : torch.Tensor, optional, (default = None)
        An optional tensor of shape (num_tags,) describing which tags the START token
        may transition *to*. If provided, additional transition constraints will be used for
        determining the start element of the sequence.
    allowed_end_transitions : torch.Tensor, optional, (default = None)
        An optional tensor of shape (num_tags,) describing which tags may transition *to* the
        end tag. If provided, additional transition constraints will be used for determining
        the end element of the sequence.
    top_k : int, optional, (default = None)
        Optional integer specifying how many of the top paths to return. For top_k>=1, returns
        a tuple of two lists: top_k_paths, top_k_scores, For top_k==None, returns a flattened
        tuple with just the top path and its score (not in lists, for backwards compatibility).

    # Returns

    viterbi_path : `List[int]`
        The tag indices of the maximum likelihood tag sequence (a list of such lists when
        `top_k` is given). Every entry is a plain Python `int`.
    viterbi_score : `torch.Tensor`
        The score of the viterbi path (a tensor of scores when `top_k` is given).

    # Raises

    ValueError
        If `top_k` is neither `None` nor an integer >= 1.
    ConfigurationError
        If `tag_observations` is given but its length differs from the sequence length.
    """
    if top_k is None:
        top_k = 1
        flatten_output = True
    elif top_k >= 1:
        flatten_output = False
    else:
        raise ValueError(
            f"top_k must be either None or an integer >=1. Instead received {top_k}"
        )

    sequence_length, num_tags = list(tag_sequence.size())

    has_start_end_restrictions = (
        allowed_end_transitions is not None or allowed_start_transitions is not None
    )

    if has_start_end_restrictions:
        if allowed_end_transitions is None:
            allowed_end_transitions = torch.zeros(num_tags)
        if allowed_start_transitions is None:
            allowed_start_transitions = torch.zeros(num_tags)

        # Two artificial tags are appended: START at index -2 and END at index -1.
        num_tags = num_tags + 2
        new_transition_matrix = torch.zeros(num_tags, num_tags)
        new_transition_matrix[:-2, :-2] = transition_matrix

        # Start and end transitions are fully defined, but cannot transition between each other.
        allowed_start_transitions = torch.cat(
            [allowed_start_transitions, torch.tensor([-math.inf, -math.inf])]
        )
        allowed_end_transitions = torch.cat(
            [allowed_end_transitions, torch.tensor([-math.inf, -math.inf])]
        )

        # First define how we may transition FROM the start and end tags.
        new_transition_matrix[-2, :] = allowed_start_transitions
        # We cannot transition from the end tag to any tag.
        new_transition_matrix[-1, :] = -math.inf

        new_transition_matrix[:, -1] = allowed_end_transitions
        # We cannot transition to the start tag from any tag.
        new_transition_matrix[:, -2] = -math.inf

        transition_matrix = new_transition_matrix

    if tag_observations:
        if len(tag_observations) != sequence_length:
            raise ConfigurationError(
                "Observations were provided, but they were not the same length "
                "as the sequence. Found sequence of length: {} and evidence: {}".format(
                    sequence_length, tag_observations
                )
            )
    else:
        tag_observations = [-1 for _ in range(sequence_length)]

    if has_start_end_restrictions:
        # Surround the sequence with observed START / END steps and give the two extra
        # tags a score of -inf at every real timestep so they are never chosen there.
        tag_observations = [num_tags - 2] + tag_observations + [num_tags - 1]
        zero_sentinel = torch.zeros(1, num_tags)
        extra_tags_sentinel = torch.ones(sequence_length, 2) * -math.inf
        tag_sequence = torch.cat([tag_sequence, extra_tags_sentinel], -1)
        tag_sequence = torch.cat([zero_sentinel, tag_sequence, zero_sentinel], 0)
        sequence_length = tag_sequence.size(0)

    path_scores = []
    path_indices = []

    if tag_observations[0] != -1:
        # An observed tag is forced by giving it an overwhelmingly large score.
        one_hot = torch.zeros(num_tags)
        one_hot[tag_observations[0]] = 100000.0
        path_scores.append(one_hot.unsqueeze(0))
    else:
        path_scores.append(tag_sequence[0, :].unsqueeze(0))

    # Evaluate the scores for all possible paths.
    for timestep in range(1, sequence_length):
        # Add pairwise potentials to current scores.
        summed_potentials = path_scores[timestep - 1].unsqueeze(2) + transition_matrix
        summed_potentials = summed_potentials.view(-1, num_tags)

        # Best pairwise potential path score from the previous timestep.
        max_k = min(summed_potentials.size()[0], top_k)
        scores, paths = torch.topk(summed_potentials, k=max_k, dim=0)

        # If we have an observation for this timestep, use it
        # instead of the distribution over tags.
        observation = tag_observations[timestep]
        # Warn the user if they have passed
        # invalid/extremely unlikely evidence.
        if tag_observations[timestep - 1] != -1 and observation != -1:
            if transition_matrix[tag_observations[timestep - 1], observation] < -10000:
                logger.warning(
                    "The pairwise potential between tags you have passed as "
                    "observations is extremely unlikely. Double check your evidence "
                    "or transition potentials!"
                )
        if observation != -1:
            one_hot = torch.zeros(num_tags)
            one_hot[observation] = 100000.0
            path_scores.append(one_hot.unsqueeze(0))
        else:
            path_scores.append(tag_sequence[timestep, :] + scores)
        path_indices.append(paths.squeeze())

    # Construct the most likely sequence backwards.
    path_scores_v = path_scores[-1].view(-1)
    max_k = min(path_scores_v.size()[0], top_k)
    viterbi_scores, best_paths = torch.topk(path_scores_v, k=max_k, dim=0)
    viterbi_paths = []
    for i in range(max_k):
        # Convert to a plain int right away: otherwise the final tag of every returned
        # path would be a 0-dim tensor while all other entries are ints.
        viterbi_path = [int(best_paths[i])]
        for backward_timestep in reversed(path_indices):
            viterbi_path.append(int(backward_timestep.view(-1)[viterbi_path[-1]]))
        # Reverse the backward path.
        viterbi_path.reverse()

        if has_start_end_restrictions:
            # Drop the artificial START / END steps again.
            viterbi_path = viterbi_path[1:-1]

        # Viterbi paths uses (num_tags * n_permutations) nodes; therefore, we need to modulo.
        viterbi_path = [j % num_tags for j in viterbi_path]
        viterbi_paths.append(viterbi_path)

    if flatten_output:
        return viterbi_paths[0], viterbi_scores[0]

    return viterbi_paths, viterbi_scores
def get_text_field_mask(
    text_field_tensors: Dict[str, Dict[str, torch.Tensor]], num_wrapping_dims: int = 0
) -> torch.BoolTensor:
    """
    Computes the padding mask for the tensors produced by a `TextField`: 0 where the
    tokens are padding, 1 otherwise. `TextFields` wrapped in `ListFields` are supported
    through `num_wrapping_dims`, the number of wrapping `ListFields`.

    With `num_wrapping_dims == 0` the mask has shape `(batch_size, num_tokens)`; with
    `num_wrapping_dims > 0` it has that many extra dimensions,
    `(batch_size, ..., num_tokens)`.

    If exactly one indexer output contains a "mask" key, that mask is returned (as bool)
    instead of being inferred.

    Otherwise the tensor with the fewest dimensions in the dictionary is used. After
    subtracting `num_wrapping_dims`, a two-dimensional tensor is taken to be token ids of
    shape `(batch_size, ..., num_tokens)` and is masked where it equals zero. A
    three-dimensional tensor is taken to be `(batch_size, ..., num_tokens, num_features)`
    (e.g. character ids) and a token counts as present when any feature is positive.

    # Raises

    ValueError
        If more than one "mask" entry is present, or if the selected tensor has neither
        two nor three dimensions after accounting for `num_wrapping_dims`.
    """
    explicit_masks = [
        indexer_tensors["mask"].bool()
        for indexer_tensors in text_field_tensors.values()
        if "mask" in indexer_tensors
    ]
    if len(explicit_masks) > 1:
        # TODO(mattg): My guess is this will basically never happen, so I'm not writing logic to
        # handle it. Should be straightforward to handle, though. If you see this error in
        # practice, open an issue on github.
        raise ValueError("found two mask outputs; not sure which to use!")
    if len(explicit_masks) == 1:
        return explicit_masks[0]

    # Stable sort by rank: among tensors of equal rank the first one encountered wins.
    by_rank = sorted(
        (
            candidate
            for indexer_output in text_field_tensors.values()
            for candidate in indexer_output.values()
        ),
        key=lambda candidate: candidate.dim(),
    )
    lowest_rank_tensor = by_rank[0]
    effective_rank = lowest_rank_tensor.dim() - num_wrapping_dims

    if effective_rank == 2:
        return lowest_rank_tensor != 0
    if effective_rank == 3:
        return (lowest_rank_tensor > 0).any(dim=-1)
    raise ValueError(
        "Expected a tensor with dimension 2 or 3, found {}".format(effective_rank)
    )
def get_token_ids_from_text_field_tensors(
    text_field_tensors: Dict[str, Dict[str, torch.Tensor]],
) -> torch.Tensor:
    """
    Pulls the token id tensor out of a `TextFieldTensors`-style nested dictionary without the
    caller needing to know which indexer produced it or what that indexer named its output.

    The structure is walked in iteration order: every indexer, and within each indexer every
    named tensor. The first tensor stored under one of the argument names `"tokens"`,
    `"token_ids"` or `"input_ids"` is returned as-is (no copy is made).

    # Parameters

    text_field_tensors : `Dict[str, Dict[str, torch.Tensor]]`
        Maps indexer name to a dictionary of that indexer's named output tensors.

    # Returns

    `torch.Tensor`
        The first tensor found under a recognized token id argument name.

    # Raises

    NotImplementedError
        If no indexer holds a tensor under any of the recognized names.
    """
    recognized_names = ("tokens", "token_ids", "input_ids")

    # Lazily enumerate every tensor stored under a recognized name, keeping the nested
    # dictionaries' own iteration order so the earliest match wins.
    matching_tensors = (
        candidate
        for indexer_outputs in text_field_tensors.values()
        for output_name, candidate in indexer_outputs.items()
        if output_name in recognized_names
    )
    for first_match in matching_tensors:
        return first_match

    raise NotImplementedError(
        "Our heuristic for guessing the right token ids failed. Please open an issue on"
        " github with more detail on how you got this error, so we can implement more"
        " robust logic in this method."
    )
def weighted_sum(matrix: torch.Tensor, attention: torch.Tensor) -> torch.Tensor:
    """
    Combines the rows of `matrix` using the weights in `attention`, which is the usual step
    that follows an attention mechanism: each output vector is the attention-weighted sum of
    the vectors along the second-to-last dimension of `matrix`.

    Higher-order inputs are supported. The sum is always taken over the second-to-last
    dimension of `matrix`, and every dimension of `matrix` before its last one is expected to
    appear in `attention` as well. `attention` may carry additional dimensions, but those must
    come `directly after the batch dimension`; `matrix` is broadcast across them.

    For example, with a `matrix` of shape `(batch_size, num_queries, num_words,
    embedding_dim)`, both of these `attention` shapes are accepted:

        - `(batch_size, num_queries, num_words)`, giving
          `(batch_size, num_queries, embedding_dim)`
        - `(batch_size, num_documents, num_queries, num_words)`, giving
          `(batch_size, num_documents, num_queries, embedding_dim)`

    # Parameters

    matrix : `torch.Tensor`
        The vectors to combine; its last dimension is the vector (embedding) dimension.
    attention : `torch.Tensor`
        The weights over the second-to-last dimension of `matrix`.

    # Returns

    `torch.Tensor`
        The weighted sums, with the summed-over dimension removed.
    """
    attention_rank = attention.dim()
    matrix_rank = matrix.dim()

    # Fast paths: with a 3D matrix, pytorch's batched matrix multiply already computes
    # exactly the weighted sum we need.
    if matrix_rank == 3:
        if attention_rank == 2:
            return attention.unsqueeze(1).bmm(matrix).squeeze(1)
        if attention_rank == 3:
            return attention.bmm(matrix)

    # General path. Any dimensions that `attention` has beyond those of `matrix` (minus the
    # embedding dimension) are inserted into `matrix` right after the batch dimension and
    # expanded to the sizes found in `attention`.
    num_missing_dims = attention_rank - matrix_rank + 1
    if num_missing_dims > 0:
        original_shape = list(matrix.size())
        inserted_sizes = [
            attention.size(position) for position in range(1, num_missing_dims + 1)
        ]
        broadcast_shape = original_shape[:1] + inserted_sizes + original_shape[1:]
        for _ in range(num_missing_dims):
            matrix = matrix.unsqueeze(1)
        matrix = matrix.expand(*broadcast_shape)

    # Scale every row by its weight, then add the rows together.
    weighted_rows = attention.unsqueeze(-1).expand_as(matrix) * matrix
    return weighted_rows.sum(dim=-2)
def sequence_cross_entropy_with_logits(
logits: torch.FloatTensor,
targets: torch.LongTensor,
weights: Union[torch.FloatTensor, torch.BoolTensor],
average: str = "batch",
label_smoothing: float = None,
gamma: float = None,
alpha: Union[float, List[float], torch.FloatTensor] = None,
) -> torch.FloatTensor:
"""
Computes the cross entropy loss of a sequence, weighted with respect to
some user provided weights. Note that the weighting here is not the same as
in the `torch.nn.CrossEntropyLoss()` criterion, which is weighting
classes; here we are weighting the loss contribution from particular elements
in the sequence. This allows loss computations for models which use padding.
# Parameters
logits : `torch.FloatTensor`, required.
A `torch.FloatTensor` of size (batch_size, sequence_length, num_classes)
which contains the unnormalized probability for each class.
targets : `torch.LongTensor`, required.
A `torch.LongTensor` of size (batch, sequence_length) which contains the
index of the true class for each corresponding step.
weights : `Union[torch.FloatTensor, torch.BoolTensor]`, required.
A `torch.FloatTensor` of size (batch, sequence_length)
average: str, optional (default = "batch")
If "batch", average the loss across the batches. If "token", average
the loss across each item in the input. If `None`, return a vector
of losses per batch element.
label_smoothing : `float`, optional (default = None)
Whether or not to apply label smoothing to the cross-entropy loss.
For example, with a label smoothing value of 0.2, a 4 class classification
target would look like `[0.05, 0.05, 0.85, 0.05]` if the 3rd class was
the correct label.
gamma : `float`, optional (default = None)
Focal loss[*] focusing parameter `gamma` to reduces the relative loss for
well-classified examples and put more focus on hard. The greater value
`gamma` is, the more focus on hard examples.
alpha : `float` or `List[float]`, optional (default = None)
Focal loss[*] weighting factor `alpha` to balance between classes. Can be
used independently with `gamma`. If a single `float` is provided, it
is assumed binary case using `alpha` and `1 - alpha` for positive and
negative respectively. If a list of `float` is provided, with the same
length as the number of classes, the weights will match the classes.
[*] T. Lin, P. Goyal, R. Girshick, K. He and P. Dollár, "Focal Loss for
Dense Object Detection," 2017 IEEE International Conference on Computer
Vision (ICCV), Venice, 2017, pp. 2999-3007.
# Returns
`torch.FloatTensor`
A torch.FloatTensor representing the cross entropy loss.
If `average=="batch"` or `average=="token"`, the returned loss is a scalar.
If `average is None`, the returned loss is a vector of shape (batch_size,).
"""
if average not in {None, "token", "batch"}:
raise ValueError(
"Got average f{average}, expected one of None, 'token', or 'batch'"
)
# make sure weights are float
weights = weights.to(logits.dtype)
# sum all dim except batch
non_batch_dims = tuple(range(1, len(weights.shape)))
# shape : (batch_size,)
weights_batch_sum = weights.sum(dim=non_batch_dims)
# shape : (batch * sequence_length, num_classes)
logits_flat = logits.view(-1, logits.size(-1))
# shape : (batch * sequence_length, num_classes)
log_probs_flat = torch.nn.functional.log_softmax(logits_flat, dim=-1)
# shape : (batch * max_len, 1)
targets_flat = targets.view(-1, 1).long()
# focal loss coefficient
if gamma:
# shape : (batch * sequence_length, num_classes)
probs_flat = log_probs_flat.exp()
# shape : (batch * sequence_length,)
probs_flat = torch.gather(probs_flat, dim=1, index=targets_flat)
# shape : (batch * sequence_length,)
focal_factor = (1.0 - probs_flat) ** gamma
# shape : (batch, sequence_length)
focal_factor = focal_factor.view(*targets.size())
weights = weights * focal_factor
if alpha is not None:
# shape : () / (num_classes,)
if isinstance(alpha, (float, int)):
# shape : (2,)
alpha_factor = torch.tensor(
[1.0 - float(alpha), float(alpha)],
dtype=weights.dtype,
device=weights.device,
)
elif isinstance(alpha, (list, numpy.ndarray, torch.Tensor)):
# shape : (c,)
alpha_factor = torch.tensor(
alpha, dtype=weights.dtype, device=weights.device
)
if not alpha_factor.size():
# shape : (1,)
alpha_factor = alpha_factor.view(1)
# shape : (2,)
alpha_factor = torch.cat([1 - alpha_factor, alpha_factor])
else:
raise TypeError(
(
"alpha must be float, list of float, or torch.FloatTensor, {}"
" provided."
).format(type(alpha))
)
# shape : (batch, max_len)
alpha_factor = torch.gather(
alpha_factor, dim=0, index=targets_flat.view(-1)
).view(*targets.size())
weights = weights * alpha_factor
if label_smoothing is not None and label_smoothing > 0.0:
num_classes = logits.size(-1)
smoothing_value = label_smoothing / num_classes
# Fill all the correct indices with 1 - smoothing value.
one_hot_targets = torch.zeros_like(log_probs_flat).scatter_(
-1, targets_flat, 1.0 - label_smoothing
)