Skip to content

Commit 355e2c0

Browse files
Merge pull request #3491 from sheldon-roberts/master
Fix TextRegressor label_name bug
2 parents afedf3c + 02f259f commit 355e2c0

File tree

1 file changed

+2
-1
lines changed

1 file changed

+2
-1
lines changed

flair/models/text_regression_model.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,8 @@ def forward_loss(self, sentences: List[Sentence]) -> Tuple[torch.Tensor, int]:
6464

6565
def _labels_to_tensor(self, sentences: List[Sentence]):
6666
indices = [
67-
torch.tensor([float(label.value) for label in sentence.labels], dtype=torch.float) for sentence in sentences
67+
torch.tensor([float(label.value) for label in sentence.get_labels(self.label_name)], dtype=torch.float)
68+
for sentence in sentences
6869
]
6970

7071
vec = torch.cat(indices, 0).to(flair.device)

0 commit comments

Comments
 (0)