-
Notifications
You must be signed in to change notification settings - Fork 0
/
MAS_model.py
34 lines (27 loc) · 906 Bytes
/
MAS_model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
import torch
import torch.nn as nn
import os
import shutil
import torchvision
from torchvision import datasets, models, transforms
# The classification head differs per task (it is task-specific);
# class-specific features are confined to the last layer.
class ClassHead(nn.Module):
    """Task-specific classification head.

    Only this final linear layer differs between tasks; everything before it
    is meant to come from a shared backbone.

    Args:
        in_features: Size of each input sample (last dimension of ``x``).
        output_features: Size of each output sample, e.g. the number of
            classes for this task.
    """

    def __init__(self, in_features, output_features):
        super(ClassHead, self).__init__()
        self.classhead = nn.Linear(in_features, output_features)

    def forward(self, x):
        """Apply the task-specific linear layer to ``x``.

        Args:
            x: Tensor whose last dimension is ``in_features`` (as required by
                ``nn.Linear``).

        Returns:
            Tensor of the same leading shape with the last dimension replaced
            by ``output_features``.
        """
        # The head must actually apply its layer; returning ``x`` unchanged
        # would make the head a no-op and leave its parameters untrained.
        return self.classhead(x)
class SharedModel(nn.Module):
    """Backbone shared across all tasks.

    For now the backbone is always torchvision's AlexNet loaded with
    pretrained weights; it serves as the baseline.

    Args:
        model: Currently ignored -- the backbone is always AlexNet.
            NOTE(review): presumably intended to select the backbone; confirm
            before relying on it.
    """

    def __init__(self, model):
        super().__init__()
        # NOTE(review): newer torchvision releases deprecate ``pretrained=``
        # in favour of ``weights=``; kept as-is for compatibility.
        backbone = models.alexnet(pretrained=True)
        self.xmodel = backbone
        # NOTE(review): not read anywhere in this class; kept in case
        # external code uses it.
        self.params = dict()

    def forward(self, x):
        """Run ``x`` through the shared backbone and return its output."""
        backbone = self.xmodel
        return backbone(x)