
Commit 105a53a

remove UserWarning: masked_fill_

1 parent ab82b68
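Background: PyTorch 1.2 and later deprecate torch.uint8 masks for masked_fill_ (and boolean indexing in general) in favor of torch.bool, emitting a UserWarning on every call with a byte mask. A minimal sketch of the warning this commit silences; the tensor names and values are illustrative, not taken from this repository.

import torch

scores = torch.randn(2, 4)
pad = torch.tensor([[0, 0, 1, 1],
                    [0, 1, 1, 1]])

# Deprecated: a uint8 mask triggers
#   UserWarning: masked_fill_ received a mask with dtype torch.uint8 ...
scores.masked_fill_(pad.byte(), float('-inf'))

# Fixed: a bool mask is accepted without warning.
scores.masked_fill_(pad.bool(), float('-inf'))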

File tree: 3 files changed (+5, -5 lines)

main.py

Lines changed: 2 additions & 2 deletions

@@ -225,7 +225,7 @@ def batchify_sequence_labeling_with_label(input_batch_list, gpu, if_train=True):
     feature_seq_tensors = []
     for idx in range(feature_num):
         feature_seq_tensors.append(torch.zeros((batch_size, max_seq_len),requires_grad = if_train).long())
-    mask = torch.zeros((batch_size, max_seq_len), requires_grad = if_train).byte()
+    mask = torch.zeros((batch_size, max_seq_len), requires_grad = if_train).bool()
     for idx, (seq, label, seqlen) in enumerate(zip(words, labels, word_seq_lengths)):
         seqlen = seqlen.item()
         word_seq_tensor[idx, :seqlen] = torch.LongTensor(seq)

@@ -304,7 +304,7 @@ def batchify_sentence_classification_with_label(input_batch_list, gpu, if_train=
     feature_seq_tensors = []
     for idx in range(feature_num):
         feature_seq_tensors.append(torch.zeros((batch_size, max_seq_len),requires_grad = if_train).long())
-    mask = torch.zeros((batch_size, max_seq_len), requires_grad = if_train).byte()
+    mask = torch.zeros((batch_size, max_seq_len), requires_grad = if_train).bool()
     label_seq_tensor = torch.LongTensor(labels)
     # exit(0)
     for idx, (seq, seqlen) in enumerate(zip(words, word_seq_lengths)):
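The batchify functions above zero-initialize the mask and fill it per sequence. Once the dtype is torch.bool, the same padding mask can also be built in one vectorized comparison, which returns torch.bool directly. A minimal sketch, assuming word_seq_lengths is a 1-D LongTensor of per-sentence lengths (the example values are made up):

import torch

word_seq_lengths = torch.LongTensor([5, 3, 4])   # illustrative lengths
max_seq_len = int(word_seq_lengths.max())

# True over real tokens, False over padding; the comparison already
# yields torch.bool, so no .byte()/.bool() conversion is needed.
mask = torch.arange(max_seq_len).unsqueeze(0) < word_seq_lengths.unsqueeze(1)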

main_parse.py

Lines changed: 1 addition & 1 deletion

@@ -233,7 +233,7 @@ def batchify_with_label(input_batch_list, gpu, volatile_flag=False):
     feature_seq_tensors = []
     for idx in range(feature_num):
         feature_seq_tensors.append(autograd.Variable(torch.zeros((batch_size, max_seq_len)),volatile = volatile_flag).long())
-    mask = autograd.Variable(torch.zeros((batch_size, max_seq_len)),volatile = volatile_flag).byte()
+    mask = autograd.Variable(torch.zeros((batch_size, max_seq_len)),volatile = volatile_flag).bool()
     for idx, (seq, label, seqlen) in enumerate(zip(words, labels, word_seq_lengths)):
         word_seq_tensor[idx, :seqlen] = torch.LongTensor(seq)
         label_seq_tensor[idx, :seqlen] = torch.LongTensor(label)
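Side note, beyond the scope of this commit: autograd.Variable with volatile= is itself a legacy API. Since PyTorch 0.4, Variable is merged into Tensor and the volatile flag has no effect; inference is expressed with the torch.no_grad() context instead. A hedged sketch of the modern spelling of the changed line (the sizes are illustrative):

import torch

batch_size, max_seq_len = 2, 4  # illustrative sizes

# Plain tensors replaced Variable; gradient tracking is disabled per
# block via no_grad() rather than per tensor via volatile.
with torch.no_grad():
    mask = torch.zeros((batch_size, max_seq_len)).bool()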

model/crf.py

Lines changed: 2 additions & 2 deletions

@@ -133,7 +133,7 @@ def _viterbi_decode(self, feats, mask):
        partition_history = list()
        ## reverse mask (bug for mask = 1- mask, use this as alternative choice)
        # mask = 1 + (-1)*mask
-       mask = (1 - mask.long()).byte()
+       mask = (1 - mask.long()).bool()
        _, inivalues = next(seq_iter) # bat_size * from_target_size * to_target_size
        # only need start from start_tag
        partition = inivalues[:, START_TAG, :].clone().view(batch_size, tag_size) # bat_size * to_target_size

@@ -297,7 +297,7 @@ def _viterbi_decode_nbest(self, feats, mask, nbest):
        partition_history = list()
        ## reverse mask (bug for mask = 1- mask, use this as alternative choice)
        # mask = 1 + (-1)*mask
-       mask = (1 - mask.long()).byte()
+       mask = (1 - mask.long()).bool()
        _, inivalues = next(seq_iter) # bat_size * from_target_size * to_target_size
        # only need start from start_tag
        partition = inivalues[:, START_TAG, :].clone() # bat_size * to_target_size
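With a torch.bool mask, both the commented-out workaround and the long round-trip become unnecessary: logical negation inverts a boolean tensor directly. A small sketch (the example mask is made up):

import torch

mask = torch.tensor([[True, True, False],
                     [True, False, False]])

# Equivalent to (1 - mask.long()).bool() when mask is already torch.bool.
reversed_mask = ~mask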
