diff --git a/utils/eval.py b/utils/eval.py index 50513501..5b84c88c 100644 --- a/utils/eval.py +++ b/utils/eval.py @@ -13,6 +13,6 @@ def accuracy(output, target, topk=(1,)): res = [] for k in topk: - correct_k = correct[:k].view(-1).float().sum(0) + correct_k = correct[:k].reshape(-1).float().sum(0) res.append(correct_k.mul_(100.0 / batch_size)) - return res \ No newline at end of file + return res