forked from tristandeleu/pytorch-meta
-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathutils.py
27 lines (23 loc) · 1.03 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
import torch
def get_accuracy(prototypes, embeddings, targets):
"""Compute the accuracy of the prototypical network on the test/query points.
Parameters
----------
prototypes : `torch.FloatTensor` instance
A tensor containing the prototypes for each class. This tensor has shape
`(meta_batch_size, num_classes, embedding_size)`.
embeddings : `torch.FloatTensor` instance
A tensor containing the embeddings of the query points. This tensor has
shape `(meta_batch_size, num_examples, embedding_size)`.
targets : `torch.LongTensor` instance
A tensor containing the targets of the query points. This tensor has
shape `(meta_batch_size, num_examples)`.
Returns
-------
accuracy : `torch.FloatTensor` instance
Mean accuracy on the query points.
"""
sq_distances = torch.sum((prototypes.unsqueeze(1)
- embeddings.unsqueeze(2)) ** 2, dim=-1)
_, predictions = torch.min(sq_distances, dim=-1)
return torch.mean(predictions.eq(targets).float())