-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathutils.py
64 lines (46 loc) · 1.45 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
import json
import torch
import torch.nn as nn
def masked_softmax(similarity_scores, mask):
    """
    Softmax over the last dimension, ignoring masked-out key positions.

    Args:
        similarity_scores: FloatTensor with shape of [bz, seq_len_q, seq_len_k]
            (any floating dtype; extra middle dims broadcast against the mask).
        mask: BoolTensor with shape of [bz, seq_len_k]; True marks positions to
            keep, False marks positions to exclude. Non-bool masks are coerced
            to bool (non-zero == keep).
    Returns:
        masked_attention_weights: tensor with the same shape and dtype as
        `similarity_scores`; masked positions get (effectively) zero weight.
        A row whose keys are all masked gets uniform weights.
    """
    # Coerce to bool: `~` on an integer mask is bitwise NOT (e.g. ~1 == -2),
    # not logical NOT, and masked_fill only accepts bool masks. Also move the
    # mask next to the scores, since masks may be built on CPU.
    mask = mask.to(device=similarity_scores.device, dtype=torch.bool)
    # Insert singleton dims after the batch dim so [bz, seq_len_k] broadcasts
    # against [bz, ..., seq_len_k].
    while mask.dim() < similarity_scores.dim():
        mask = mask.unsqueeze(1)
    # Use the most negative finite value of the scores' dtype: a hard-coded
    # -1e8 does not fit in float16 and makes masked_fill raise an overflow error.
    fill_value = torch.finfo(similarity_scores.dtype).min
    masked_score = similarity_scores.masked_fill(~mask, fill_value)
    masked_attention_weights = torch.nn.functional.softmax(masked_score, dim=-1)
    return masked_attention_weights
def get_mask(tensor, padding_idx=0):
    """
    Get a padding mask for `tensor`.

    Args:
        tensor: LongTensor with shape of [bz, seq_len] (any shape is accepted).
        padding_idx: value that marks padding positions; defaults to 0.
    Returns:
        mask: BoolTensor with the same shape and device as `tensor`; True for
        real tokens, False where `tensor == padding_idx`.
    """
    # A direct elementwise comparison (instead of allocating a CPU mask and
    # boolean-indexing into it) keeps the result on `tensor`'s device and
    # works for inputs of any rank.
    return tensor != padding_idx
def correct_instance_count(pred_logits, labels):
    """
    Count the predictions that agree with their gold labels.

    Args:
        pred_logits: FloatTensor with shape of [bz, number_of_classes_in_labels]
        labels: LongTensor with shape of [bz]
    Returns:
        correct_count: int, the number of rows whose highest-scoring class
        (argmax over dim 1) equals the corresponding entry of `labels`.
    """
    hits = torch.argmax(pred_logits, dim=1).eq(labels)
    return hits.sum().item()
class Args:
    """Plain attribute container; `parse_args` sets one attribute per config key."""


def parse_args(config):
    """
    Load a JSON config file into an `Args` object.

    Each top-level key of the JSON object becomes an attribute of the returned
    `Args` instance, holding the JSON-decoded value.

    Args:
        config: path to a JSON file whose top-level value is an object.
    Returns:
        args: `Args` instance with one attribute per config key.
    Raises:
        OSError: if the file cannot be opened.
        json.JSONDecodeError: if the file is not valid JSON.
        ValueError: if the top-level JSON value is not an object.
    """
    # JSON text is UTF-8; don't depend on the platform's locale encoding.
    with open(config, 'r', encoding='utf-8') as f:
        settings = json.load(f)
    if not isinstance(settings, dict):
        raise ValueError(
            f"config file {config!r} must contain a JSON object, "
            f"got {type(settings).__name__}"
        )
    args = Args()
    for name, val in settings.items():
        setattr(args, name, val)
    return args