Skip to content

Commit

Permalink
Change data paths
Browse files Browse the repository at this point in the history
  • Loading branch information
adursun committed Sep 21, 2019
1 parent c7f0a69 commit a47302e
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 4 deletions.
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,6 @@

data/*
!data/.gitkeep

states/*
!states/.gitkeep
9 changes: 5 additions & 4 deletions wsddn-pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,12 +123,12 @@ def __init__(self, split):
self.split = split

#loaded_mat = loadmat(f"/kaggle/input/selective-search-windows/selective_search_data/voc_2007_{self.split}.mat")
loaded_mat = loadmat(f"~/.datasets/selective_search_data/voc_2007_{self.split}.mat")
loaded_mat = loadmat(f"data/selective_search_data/voc_2007_{self.split}.mat")
self.ssw_boxes = loaded_mat["boxes"][0]
self.ssw_scores = loaded_mat["boxScores"][0]

#voc_dir = f"/kaggle/input/pascal-voc/voc{self.split}_06-nov-2007/VOCdevkit/VOC2007"
voc_dir = f"~/.datasets/VOC{self.split}_06-Nov-2007/VOCdevkit/VOC2007"
voc_dir = f"data/VOC{self.split}_06-Nov-2007/VOCdevkit/VOC2007"
self.ids = [id_.strip() for id_ in open(f"{voc_dir}/ImageSets/Main/{self.split}.txt")]
self.img_paths = [f"{voc_dir}/JPEGImages/{id_}.jpg" for id_ in self.ids]
self.annotation_paths = [f"{voc_dir}/Annotations/{id_}.xml" for id_ in self.ids]
Expand Down Expand Up @@ -230,10 +230,11 @@ def __len__(self):
### Create the network

class WSDDN(nn.Module):
base = alexnet(pretrained=True)
base = alexnet(pretrained=False)

def __init__(self):
super().__init__()
self.base.load_state_dict(torch.load("states/alexnet-owt-4df8aa71.pth"))
self.features = self.base.features[:-1]
self.fcs = self.base.classifier[1:-1]
self.fc_c = nn.Linear(4096, 20)
Expand Down Expand Up @@ -299,7 +300,7 @@ def loss_func(combined_scores, target):

optimizer.step()

torch.save(net, f"epoch_{epoch}.pt")
torch.save(net, f"states/epoch_{epoch}.pt")

print("Avg loss is", epoch_loss / len(train_ds))

Expand Down

0 comments on commit a47302e

Please sign in to comment.