Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion cnn/genotypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
Genotype = namedtuple('Genotype', 'normal normal_concat reduce reduce_concat')

PRIMITIVES = [
'none',
#'none',
'max_pool_3x3',
'avg_pool_3x3',
'skip_connect',
Expand Down
85 changes: 53 additions & 32 deletions cnn/model_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from torch.autograd import Variable
from genotypes import PRIMITIVES
from genotypes import Genotype
import random

class MixedOp(nn.Module):

Expand All @@ -18,7 +19,11 @@ def __init__(self, C, stride):
self._ops.append(op)

def forward(self, x, weights):
    """Evaluate the mixed op for one edge.

    With one-hot / discrete ``weights`` (the intent of this fork — a single
    positive entry selects one candidate op), only the selected op is run,
    instead of the original DARTS weighted sum over all ops.

    Args:
        x: input tensor fed to the candidate operation.
        weights: per-op weights, one per entry of ``self._ops``; presumably
            one-hot after discretization — TODO confirm against caller.

    Returns:
        ``w * op(x)`` for the first op whose weight is positive; if no
        weight is positive, a zero-weighted sum over all ops (this equals
        the old fall-through ``w * op(x)`` when every weight is 0, but does
        not rely on leaked loop variables and cannot raise ``NameError``
        when ``weights`` is empty).
    """
    for w, op in zip(weights, self._ops):
        if w > 0:
            # Early exit: only the selected op is evaluated, saving the
            # compute of the remaining candidates.
            return w * op(x)
    # No positive weight (e.g. all-zero alphas at init): every term is
    # zero-scaled, so this matches the previous behavior of returning the
    # last zero-weighted op, without depending on loop-variable leakage.
    return sum(w * op(x) for w, op in zip(weights, self._ops))


class Cell(nn.Module):
Expand Down Expand Up @@ -52,15 +57,8 @@ def forward(self, s0, s1, weights, fixed_weights):

states = [s0, s1]
offset = 0
for i in range(self._steps):
if i < self._steps-1:
s = sum(self._ops[offset+j](h, fixed_weights[offset+j]) for j, h in enumerate(states))
else:
#if len(states) > 4:
#for j, h in enumerate(states):
#st = self._ops[offset+j](h, weights[offset+j])
#print (st.size())
s = sum(self._ops[offset+j](h, weights[offset+j]) for j, h in enumerate(states))
for i in range(self._steps):
s = sum(self._ops[offset+j](h, weights[offset+j]) for j, h in enumerate(states))
offset += len(states)
states.append(s)

Expand All @@ -69,7 +67,7 @@ def forward(self, s0, s1, weights, fixed_weights):

class Network(nn.Module):

def __init__(self, C, num_classes, layers, criterion, steps=1, multiplier=1, stem_multiplier=3):
def __init__(self, C, num_classes, layers, criterion, steps=4, multiplier=4, stem_multiplier=3):
super(Network, self).__init__()
self._C = C
self._num_classes = num_classes
Expand Down Expand Up @@ -122,27 +120,30 @@ def forward(self, input):
#normal_weights = Variable(normal_weights, requires_grad=True)
#reduce_weights = torch.from_numpy(reduce_weights).cuda()
#reduce_weights = Variable(reduce_weights, requires_grad=True)
if self._steps > 1:
fixed_normal_edges, fixed_reduce_edges = self.fixed_edges
k = sum(1 for i in range(self._steps-1) for n in range(2+i))
fixed_normal_weights = torch.zeros((k,len(PRIMITIVES)))
fixed_reduce_weights = torch.zeros((k,len(PRIMITIVES)))
for i, edge in enumerate(fixed_reduce_edges):
fixed_reduce_weights[edge[1]][edge[0]] = 1.0
for i, edge in enumerate(fixed_normal_edges):
fixed_normal_weights[edge[1]][edge[0]] = 1.0
fixed_reduce_weights = Variable(fixed_reduce_weights, requires_grad=False).cuda()
fixed_normal_weights = Variable(fixed_normal_weights, requires_grad=False).cuda()
else:
fixed_reduce_weights = None
fixed_normal_weights = None

#if self._steps > 1:
#fixed_normal_edges, fixed_reduce_edges = self.fixed_edges
#k = sum(1 for i in range(self._steps-1) for n in range(2+i))
#fixed_normal_weights = torch.zeros((k,len(PRIMITIVES)))
#fixed_reduce_weights = torch.zeros((k,len(PRIMITIVES)))
#for i, edge in enumerate(fixed_reduce_edges):
#fixed_reduce_weights[edge[1]][edge[0]] = 1.0
#for i, edge in enumerate(fixed_normal_edges):
#fixed_normal_weights[edge[1]][edge[0]] = 1.0
#fixed_reduce_weights = Variable(fixed_reduce_weights, requires_grad=False).cuda()
#fixed_normal_weights = Variable(fixed_normal_weights, requires_grad=False).cuda()
#else:
fixed_reduce_weights = None
fixed_normal_weights = None
for i, cell in enumerate(self.cells):
if cell.reduction:
weights = F.softmax(self.alphas_reduce, dim=-1)
#weights = F.softmax(self.alphas_reduce, dim=-1)
weights = self.alphas_reduce
fixed_weights = fixed_reduce_weights
#weights = reduce_weights
else:
weights = F.softmax(self.alphas_normal, dim=-1)
#weights = F.softmax(self.alphas_normal, dim=-1)
weights = self.alphas_normal
fixed_weights = fixed_normal_weights
#weights = normal_weights
s0, s1 = s1, cell(s0, s1, weights, fixed_weights)
Expand All @@ -158,10 +159,10 @@ def _initialize_alphas(self):
k = sum(1 for i in range(self._steps) for n in range(2+i))
num_ops = len(PRIMITIVES)

self.alphas_normal = Variable(1e-3*torch.randn(k, num_ops).cuda(), requires_grad=True)
self.alphas_reduce = Variable(1e-3*torch.randn(k, num_ops).cuda(), requires_grad=True)
#self.alphas_normal = Variable(1e-3*torch.zeros(k, num_ops).cuda(), requires_grad=True)
#self.alphas_reduce = Variable(1e-3*torch.zeros(k, num_ops).cuda(), requires_grad=True)
#self.alphas_normal = Variable(1e-3*torch.randn(k, num_ops).cuda(), requires_grad=True)
#self.alphas_reduce = Variable(1e-3*torch.randn(k, num_ops).cuda(), requires_grad=True)
self.alphas_normal = Variable(1e-3*torch.zeros(k, num_ops).cuda(), requires_grad=False)
self.alphas_reduce = Variable(1e-3*torch.zeros(k, num_ops).cuda(), requires_grad=False)
self._arch_parameters = [
self.alphas_normal,
self.alphas_reduce,
Expand Down Expand Up @@ -274,4 +275,24 @@ def _compute_weight_from_alphas(self):
for i, edge in enumerate(fixed_reduce_edges):
reduce_weights[edge[1]][edge[0]] = 1.0

return (normal_weights, reduce_weights)
return (normal_weights, reduce_weights)

def generate_random_alphas(self):
    """Sample a random discrete architecture into the alpha tensors.

    Resets both alpha matrices via ``self._initialize_alphas()``, then for
    every intermediate step picks 2 of its incoming edges at random and
    sets one randomly chosen op's alpha to 1.0 on each — done separately
    for the normal and the reduction cell. Step ``i`` has ``2 + i``
    incoming edges (rows ``start..start+n-1`` of the alpha matrix), matching
    the edge layout used by ``_initialize_alphas``.
    """
    self._initialize_alphas()
    start, n = 0, 2
    for _ in range(self._steps):
        end = start + n
        candidates = list(range(start, end))
        # Same RNG call order as before: shuffle + 2 op draws for the
        # normal cell, then shuffle + 2 op draws for the reduction cell.
        for alphas in (self.alphas_normal, self.alphas_reduce):
            random.shuffle(candidates)
            for edge in candidates[:2]:
                op_idx = random.randint(0, len(PRIMITIVES) - 1)
                alphas[edge][op_idx] = 1.0
        start, n = end, n + 1
Loading