From a47302e29ed0d1f86fb06135ec2d2f8b84516e7f Mon Sep 17 00:00:00 2001 From: Abdullah DURSUN Date: Sat, 21 Sep 2019 13:17:06 +0300 Subject: [PATCH] Change data paths --- .gitignore | 3 +++ wsddn-pytorch.py | 9 +++++---- 2 files changed, 8 insertions(+), 4 deletions(-) diff --git a/.gitignore b/.gitignore index 5e812fb..faf5e09 100644 --- a/.gitignore +++ b/.gitignore @@ -3,3 +3,6 @@ data/* !data/.gitkeep + +states/* +!states/.gitkeep diff --git a/wsddn-pytorch.py b/wsddn-pytorch.py index 87470db..8d1d820 100644 --- a/wsddn-pytorch.py +++ b/wsddn-pytorch.py @@ -123,12 +123,12 @@ def __init__(self, split): self.split = split #loaded_mat = loadmat(f"/kaggle/input/selective-search-windows/selective_search_data/voc_2007_{self.split}.mat") - loaded_mat = loadmat(f"~/.datasets/selective_search_data/voc_2007_{self.split}.mat") + loaded_mat = loadmat(f"data/selective_search_data/voc_2007_{self.split}.mat") self.ssw_boxes = loaded_mat["boxes"][0] self.ssw_scores = loaded_mat["boxScores"][0] #voc_dir = f"/kaggle/input/pascal-voc/voc{self.split}_06-nov-2007/VOCdevkit/VOC2007" - voc_dir = f"~/.datasets/VOC{self.split}_06-Nov-2007/VOCdevkit/VOC2007" + voc_dir = f"data/VOC{self.split}_06-Nov-2007/VOCdevkit/VOC2007" self.ids = [id_.strip() for id_ in open(f"{voc_dir}/ImageSets/Main/{self.split}.txt")] self.img_paths = [f"{voc_dir}/JPEGImages/{id_}.jpg" for id_ in self.ids] self.annotation_paths = [f"{voc_dir}/Annotations/{id_}.xml" for id_ in self.ids] @@ -230,10 +230,11 @@ def __len__(self): ### Create the network class WSDDN(nn.Module): - base = alexnet(pretrained=True) + base = alexnet(pretrained=False) def __init__(self): super().__init__() + self.base.load_state_dict(torch.load("states/alexnet-owt-4df8aa71.pth")) self.features = self.base.features[:-1] self.fcs = self.base.classifier[1:-1] self.fc_c = nn.Linear(4096, 20) @@ -299,7 +300,7 @@ def loss_func(combined_scores, target): optimizer.step() - torch.save(net, f"epoch_{epoch}.pt") + torch.save(net, f"states/epoch_{epoch}.pt") print("Avg loss is", epoch_loss / len(train_ds))