-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
PrimeshES
committed
Dec 12, 2021
1 parent
a1f7ad1
commit 7ce34fd
Showing
9 changed files
with
298 additions
and
4 deletions.
There are no files selected for viewing
Binary file not shown.
Binary file not shown.
Empty file.
Binary file not shown.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,155 @@ | ||
{ | ||
"cells": [ | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 2, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"import torch\n", | ||
"import torch.nn as nn\n", | ||
"from torch.utils.data import Dataset, DataLoader\n", | ||
"from utils_logging import setup_logger\n", | ||
"from models.base import BaseModel\n", | ||
"from dataloader.base import GooDataset\n", | ||
"from training.base import train_base_model, GazeOptimizer\n" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 3, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"logger = setup_logger(name='first_logger',\n", | ||
" log_dir ='./logs/',\n", | ||
" log_file='train_chong_gooreal.log',\n", | ||
" log_format = '%(asctime)s %(levelname)s %(message)s',\n", | ||
" verbose=True)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 4, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"batch_size=4\n", | ||
"workers=12\n", | ||
"images_dir = '/media/primesh/F4D0EA80D0EA4906/PROJECTS/FYP/Gaze detection/Datasets/gooreal/finalrealdatasetImgsV2'\n", | ||
"pickle_path = '/media/primesh/F4D0EA80D0EA4906/PROJECTS/FYP/Gaze detection/Datasets/gooreal/oneshotrealhumansNew2.pickle'\n", | ||
"test_images_dir = '/media/primesh/F4D0EA80D0EA4906/PROJECTS/FYP/Gaze detection/Datasets/gooreal/finalrealdatasetImgsV2'\n", | ||
"test_pickle_path = '/media/primesh/F4D0EA80D0EA4906/PROJECTS/FYP/Gaze detection/Datasets/gooreal/testrealhumansNew2.pickle'\n", | ||
"val_images_dir = '/media/primesh/F4D0EA80D0EA4906/PROJECTS/FYP/Gaze detection/Datasets/gooreal/finalrealdatasetImgsV2'\n", | ||
"val_pickle_path = '/media/primesh/F4D0EA80D0EA4906/PROJECTS/FYP/Gaze detection/Datasets/gooreal/valrealhumansNew2.pickle'\n" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 5, | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"name": "stdout", | ||
"output_type": "stream", | ||
"text": [ | ||
"Train\n" | ||
] | ||
}, | ||
{ | ||
"ename": "FileNotFoundError", | ||
"evalue": "[Errno 2] No such file or directory: '/media/primesh/F4D0EA80D0EA4906/PROJECTS/FYP/Gaze detection/Datasets/gooreal/oneshotrealhumansNew2.pickle'", | ||
"output_type": "error", | ||
"traceback": [ | ||
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", | ||
"\u001b[0;31mFileNotFoundError\u001b[0m Traceback (most recent call last)", | ||
"\u001b[0;32m<ipython-input-5-a6d1dcfee7de>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[0mprint\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0;34m'Train'\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 2\u001b[0;31m \u001b[0mtrain_set\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mGooDataset\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mimages_dir\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mpickle_path\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m'train'\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 3\u001b[0m train_data_loader = DataLoader(dataset=train_set,\n\u001b[1;32m 4\u001b[0m \u001b[0mbatch_size\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mbatch_size\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 5\u001b[0m \u001b[0mshuffle\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mTrue\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", | ||
"\u001b[0;32m/media/primesh/F4D0EA80D0EA49061/PROJECTS/FYP/Gaze detection/code/RetailGaze/dataloader/base.py\u001b[0m in \u001b[0;36m__init__\u001b[0;34m(self, root_dir, mat_file, training, include_path, input_size, output_size, imshow, use_gtbox)\u001b[0m\n\u001b[1;32m 44\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0muse_gtbox\u001b[0m\u001b[0;34m=\u001b[0m \u001b[0muse_gtbox\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 45\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 46\u001b[0;31m \u001b[0;32mwith\u001b[0m \u001b[0mopen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmat_file\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m'rb'\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0mf\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 47\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdata\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mpickle\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mload\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mf\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 48\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mimage_num\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mlen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdata\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", | ||
"\u001b[0;31mFileNotFoundError\u001b[0m: [Errno 2] No such file or directory: '/media/primesh/F4D0EA80D0EA4906/PROJECTS/FYP/Gaze detection/Datasets/gooreal/oneshotrealhumansNew2.pickle'" | ||
] | ||
} | ||
], | ||
"source": [ | ||
"print ('Train')\n", | ||
"train_set = GooDataset(images_dir, pickle_path, 'train')\n", | ||
"train_data_loader = DataLoader(dataset=train_set,\n", | ||
" batch_size=batch_size,\n", | ||
" shuffle=True,\n", | ||
" num_workers=16)\n", | ||
"print ('Val')\n", | ||
"val_set = GooDataset(val_images_dir, val_pickle_path, 'train')\n", | ||
"val_data_loader = DataLoader(dataset=val_set,\n", | ||
" batch_size=4,\n", | ||
" shuffle=True,\n", | ||
" num_workers=16)\n", | ||
"print ('Test')\n", | ||
"test_set = GooDataset(test_images_dir, test_pickle_path, 'test')\n", | ||
"test_data_loader = DataLoader(test_set, batch_size=batch_size//2,\n", | ||
" shuffle=False, num_workers=8)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n", | ||
"model = BaseModel().cuda()\n", | ||
"start_epoch = 0\n", | ||
"max_epoch = 5\n", | ||
"learning_rate = 1e-4\n", | ||
"\n", | ||
"# Initializes Optimizer\n", | ||
"gaze_opt = GazeOptimizer(model, learning_rate)\n", | ||
"optimizer = gaze_opt.getOptimizer(start_epoch)\n", | ||
"# Loss criteria\n", | ||
"criterion = nn.NLLLoss().cuda()" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"from torch.utils.tensorboard import SummaryWriter\n", | ||
"writer = SummaryWriter('runs/base_model')" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"train_base_model(model, train_data_loader, val_data_loader, criterion, optimizer, logger, writer, num_epochs=50, patience=10)\n" | ||
] | ||
} | ||
], | ||
"metadata": { | ||
"interpreter": { | ||
"hash": "9e6729ac249b985ced52126ba586c51d65f0f05ea9e77b2ef41209df81b7a383" | ||
}, | ||
"kernelspec": { | ||
"display_name": "Python 3.9.5 64-bit ('FYP': conda)", | ||
"language": "python", | ||
"name": "python3" | ||
}, | ||
"language_info": { | ||
"codemirror_mode": { | ||
"name": "ipython", | ||
"version": 3 | ||
}, | ||
"file_extension": ".py", | ||
"mimetype": "text/x-python", | ||
"name": "python", | ||
"nbconvert_exporter": "python", | ||
"pygments_lexer": "ipython3", | ||
"version": "3.9.5" | ||
}, | ||
"orig_nbformat": 4 | ||
}, | ||
"nbformat": 4, | ||
"nbformat_minor": 2 | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,115 @@ | ||
from torchvision import transforms | ||
import numpy as np | ||
import random | ||
from PIL import Image, ImageOps | ||
from scipy import signal | ||
import cv2 | ||
import torch | ||
|
||
# Image preprocessing pipelines, keyed by dataset split. Both splits use the
# same steps: resize to 224x224, convert to a tensor, then normalize each
# channel (the constants match the standard ImageNet mean/std).
data_transforms = {
    split: transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
    ])
    for split in ('train', 'test')
}
|
||
|
||
# paste a Gaussian kernel, centered on `points`, into a heatmap of size `shape`
def get_paste_kernel(im_shape, points, kernel, shape=(224 // 4, 224 // 4)):
    """Paste ``kernel`` into a zero heatmap, centered on a normalized point.

    Args:
        im_shape: ``(height, width, ...)`` used to scale ``points`` to pixel
            coordinates. Only the first two entries are read.
        points: ``(x, y)`` position expressed as fractions of the width and
            height given by ``im_shape``.
        kernel: 2D array to paste. Its center element (index ``size // 2`` on
            each axis) is placed on the scaled point.
        shape: ``(height, width)`` of the returned heatmap.

    Returns:
        A float array of exactly ``shape``. It is zero everywhere except
        where the kernel overlaps the map; parts of the kernel that fall
        outside the map are clipped.
    """
    k_height, k_width = kernel.shape[:2]
    x, y = points
    image_height, image_width = im_shape[:2]
    x, y = int(round(image_width * x)), int(round(y * image_height))

    # Top-left corner of the kernel footprint in heatmap coordinates.
    x1, y1 = x - k_width // 2, y - k_height // 2

    out_h, out_w = shape
    heatmap = np.zeros((out_h, out_w))

    # Intersect the footprint with the heatmap bounds. Each axis is clipped
    # against its own extent, so non-square ``shape`` values work.
    top, left = max(y1, 0), max(x1, 0)
    bottom, right = min(y1 + k_height, out_h), min(x1 + k_width, out_w)

    # Nothing to paste when the kernel lies completely outside the map.
    if top < bottom and left < right:
        heatmap[top:bottom, left:right] = kernel[top - y1:bottom - y1,
                                                 left - x1:right - x1]
    return heatmap
|
||
|
||
|
||
def gkern(kernlen=51, std=9):
    """Return a 2D Gaussian kernel array.

    Args:
        kernlen: Side length of the square kernel, in elements.
        std: Standard deviation of the Gaussian, in elements.

    Returns:
        A ``(kernlen, kernlen)`` array formed as the outer product of a 1D
        Gaussian window with itself.
    """
    # ``scipy.signal.gaussian`` was a deprecated alias that newer SciPy
    # releases no longer provide; the window lives in ``signal.windows``.
    gkern1d = signal.windows.gaussian(kernlen, std=std)
    return np.outer(gkern1d, gkern1d)


# Module-level 21x21 Gaussian kernel with a standard deviation of 3.
kernel_map = gkern(21, 3)
|
||
def generate_data_field(eye_point, height=224, width=224):
    """Build a field of unit vectors pointing away from the eye position.

    Args:
        eye_point: ``(x, y)`` eye position, each between 0 and 1, expressed
            as a fraction of the field width and height.
        height: Number of rows in the field. Defaults to 224.
        width: Number of columns in the field. Defaults to 224.

    Returns:
        A ``float32`` array of shape ``(2, height, width)``. Channel 0 holds
        the x component and channel 1 the y component of the direction from
        the eye to each pixel, divided by its length.
    """
    # Pixel coordinate grids: x varies along columns, y along rows.
    x_grid, y_grid = np.meshgrid(np.arange(width), np.arange(height))
    grid = np.stack((x_grid, y_grid)).astype(np.float32)

    x, y = eye_point
    eye_px = np.array([x * width, y * height]).reshape([2, 1, 1])

    # Offset of every pixel from the eye, then normalize to unit length.
    grid -= eye_px.astype(np.float32)
    norm = np.sqrt(np.sum(grid ** 2, axis=0, keepdims=True))
    # avoid zero norm (the pixel that coincides with the eye keeps a 0 vector)
    norm = np.maximum(norm, 0.1)
    grid /= norm
    return grid
|
||
def preprocess_image(image_path, eye):
    """Load an image and build the model inputs for one eye position.

    Args:
        image_path: Path of the image file to read with OpenCV.
        eye: ``(x, y)`` eye position as fractions of the image width and
            height.

    Returns:
        A dict with the keys ``'image'`` (the full image after the ``'test'``
        transform), ``'face_image'`` (a crop around the eye after the same
        transform), ``'eye_position'`` (``eye`` as a ``torch.FloatTensor``)
        and ``'gaze_field'`` (the output of ``generate_data_field`` as a
        tensor).

    Raises:
        OSError: If OpenCV cannot read ``image_path``.
    """
    image = cv2.imread(image_path, cv2.IMREAD_COLOR)
    # cv2.imread signals failure by returning None instead of raising.
    if image is None:
        raise OSError('Could not read image: {}'.format(image_path))

    # crop face: a window of +/-0.15 around the eye, clamped to [0, 1]
    x_c, y_c = eye
    x_0, y_0 = max(x_c - 0.15, 0), max(y_c - 0.15, 0)
    x_1, y_1 = min(x_c + 0.15, 1), min(y_c + 0.15, 1)

    h, w = image.shape[:2]
    face_image = image[int(y_0 * h):int(y_1 * h), int(x_0 * w):int(x_1 * w), :]
    # process face_image for face net (OpenCV loads BGR; PIL expects RGB)
    face_image = cv2.cvtColor(face_image, cv2.COLOR_BGR2RGB)
    face_image = Image.fromarray(face_image)
    face_image = data_transforms['test'](face_image)
    # process image for saliency net
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    image = Image.fromarray(image)
    image = data_transforms['test'](image)

    # generate gaze field
    gaze_field = generate_data_field(eye_point=eye)
    sample = {'image': image,
              'face_image': face_image,
              'eye_position': torch.FloatTensor(eye),
              'gaze_field': torch.from_numpy(gaze_field)}

    return sample
|
||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,24 @@ | ||
import os | ||
import logging | ||
|
||
def setup_logger(name, log_dir, log_file, log_format, level=logging.INFO, verbose=False):
    """Configure file logging and return the logger called ``name``.

    File output is set up through ``logging.basicConfig``, which configures
    the root logger and has no effect once the root logger already has
    handlers. Only the first call in a process therefore decides the log
    file, format and level.

    Args:
        name: Name of the logger to return.
        log_dir: Directory for the log file; created if it does not exist.
        log_file: File name of the log inside ``log_dir``.
        log_format: Format string for log records written to the file.
        level: Logging threshold for the root configuration and the console
            handler. Defaults to ``logging.INFO``.
        verbose: If true, records are also printed to the console.

    Returns:
        The ``logging.Logger`` instance registered under ``name``.
    """
    # exist_ok avoids the check-then-create race; an empty log_dir means the
    # current directory and needs no creation.
    if log_dir:
        os.makedirs(log_dir, exist_ok=True)
    # os.path.join works whether or not log_dir ends with a separator.
    log_path = os.path.join(log_dir, log_file)

    logging.basicConfig(level=level,
                        format=log_format,
                        filename=log_path,
                        filemode='w')

    logger = logging.getLogger(name)

    # prints to command line if true
    if verbose:
        # Exact type check: FileHandler subclasses StreamHandler and must not
        # count as a console handler. Skipping when one is already attached
        # keeps repeated calls from printing every record several times.
        has_console = any(type(h) is logging.StreamHandler
                          for h in logger.handlers)
        if not has_console:
            console = logging.StreamHandler()
            console.setLevel(level)
            logger.addHandler(console)

    return logger