-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathutils.py
112 lines (94 loc) · 3.32 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
import math
import torch
import torch.nn as nn
import random
import numpy as np
import torch.nn.functional as F
import argparse
import os
import shutil
import time
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim
import torch.utils.data
import torchvision.transforms as transforms
import torchvision.datasets as datasets
from torch.autograd import Variable
from torch.utils.data import Dataset
from torch.utils.data.dataset import random_split
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
def reject_CrossEntropyLoss(outputs, m, labels, m2, n_classes):
    r'''
    The L_{CE} loss implementation for CIFAR
    ----
    outputs: network outputs; log2 is applied to them directly, so they are
        expected to already be probabilities. Column `n_classes` is read as
        the extra "defer to expert" output.
    m: weight on the deferral term (cost of deferring to expert, I_{m = y})
    labels: target class index for each sample
    m2: weight on the classifier term
        (cost of classifier predicting, alpha * I_{m \neq y} + I_{m = y})
    n_classes: number of classes (also the column index of the deferral output)
    ----
    returns: mean over the batch of
        -m * log2(outputs[i, n_classes]) - m2 * log2(outputs[i, labels[i]])
    '''
    batch_size = outputs.size()[0]  # batch_size
    rows = range(batch_size)

    # Clamp to the smallest positive normal float before taking the log.
    # Without it, a probability that underflowed to exactly 0 yields
    # log2(0) = -inf, and a zero weight then gives 0 * -inf = NaN, which
    # poisons the whole batch mean. Values above the threshold are untouched.
    tiny = torch.finfo(outputs.dtype).tiny
    defer_prob = outputs[rows, [n_classes] * batch_size].clamp(min=tiny)
    label_prob = outputs[rows, labels].clamp(min=tiny)

    per_sample = -m * torch.log2(defer_prob) - m2 * torch.log2(label_prob)
    return torch.mean(per_sample)
def my_CrossEntropyLoss(outputs, labels):
    """Regular cross entropy loss, measured in bits (log base 2).

    outputs: per-class scores; log2 is taken of them directly, so they are
        presumably probabilities already (e.g. softmax outputs).
    labels: target class index for each sample.
    returns: mean over the batch of -log2(outputs[i, labels[i]]).
    """
    n_samples = outputs.size()[0]
    # Pick, for every row, the score assigned to its target class.
    target_scores = outputs[range(n_samples), labels]
    negative_log = -torch.log2(target_scores)
    return torch.mean(negative_log)
def my_CrossEntropyLossWithSoftmax(outputs, labels):
    """Regular cross entropy loss (in bits) with outputs being raw logits.

    outputs: logits without softmax, indexed as [sample, class].
    labels: target class index for each sample.
    returns: mean over the batch of -log2(softmax(outputs)[i, labels[i]]).
    """
    # log_softmax with an explicit class axis replaces softmax-then-log:
    #  * the implicit-dim form of F.softmax is deprecated and warns;
    #  * softmax followed by log underflows to -inf for large logit gaps,
    #    whereas log_softmax stays finite.
    log_probs = F.log_softmax(outputs, dim=1)
    batch_size = outputs.size()[0]  # batch_size
    target_log_probs = log_probs[range(batch_size), labels]
    # Convert natural log to log base 2 so the scale matches the other losses.
    return torch.mean(-target_log_probs / math.log(2.0))
class AverageMeter:
    """Computes and stores the average and current value"""

    def __init__(self):
        self.reset()

    def reset(self):
        """Clear all accumulated statistics."""
        self.val = 0    # most recent value passed to update()
        self.avg = 0    # running weighted average: sum / count
        self.sum = 0    # weighted sum of all values seen so far
        self.count = 0  # total weight (e.g. number of samples) seen so far

    def update(self, val, n=1):
        """Record `val`, weighted as if it had been observed `n` times."""
        self.val = val
        self.sum += val * n
        self.count += n
        # Guard the division: update(val, n=0) on a fresh meter (e.g. an
        # empty batch) would otherwise divide by zero.
        self.avg = self.sum / self.count if self.count else 0
def accuracy(output, target, topk=(1,)):
    """Computes the precision@k for the specified values of k

    output: per-class scores, indexed as [sample, class].
    target: true class index for each sample.
    topk: iterable of k values to evaluate.
    returns: list with one 0-dim tensor per k, holding the percentage of
        samples whose target is among the k highest-scoring classes.
    """
    maxk = max(topk)
    batch_size = target.size(0)

    # Indices of the maxk best classes per sample, best first.
    _, pred = output.topk(maxk, 1, True, True)
    pred = pred.t()  # -> [maxk, batch]
    correct = pred.eq(target.view(1, -1).expand_as(pred))

    res = []
    for k in topk:
        # reshape, not view: `correct` can inherit the transposed
        # (non-contiguous) layout of `pred`, and view(-1) on a multi-row
        # slice of it raises RuntimeError on recent PyTorch versions.
        correct_k = correct[:k].reshape(-1).float().sum(0)
        res.append(correct_k.mul_(100.0 / batch_size))
    return res
class synth_expert:
    '''
    Synthetic expert used for the CIFAR-10 experiments.
    ----
    k: number of classes the expert predicts correctly (labels 0 .. k-1)
    n_classes: total number of classes (10 for CIFAR-10)
    '''

    def __init__(self, k, n_classes):
        self.n_classes = n_classes
        self.k = k

    def predict(self, input, labels):
        '''
        Produce one expert prediction per sample.
        ----
        input: unused; only the labels determine the prediction
        labels: true class index for each sample
        returns: list of ints, one per sample
        '''
        n_samples = labels.size()[0]
        predictions = []
        for idx in range(n_samples):
            true_label = labels[idx].item()
            if true_label <= self.k - 1:  # CHANGE FROM OLD PAPER
                # Class is within the expert's competence: answer correctly.
                predictions.append(true_label)
                continue
            # Outside its competence the expert guesses uniformly over all
            # classes, so it can still be right by chance (original note:
            # change to deterministically false).
            predictions.append(random.randint(0, self.n_classes - 1))
        return predictions