
Commit

base model training added
PrimeshES committed Dec 12, 2021
1 parent a1f7ad1 commit 7ce34fd
Showing 9 changed files with 298 additions and 4 deletions.
Binary file added __pycache__/utils_logging.cpython-39.pyc
Binary file not shown.
Binary file added dataloader/__pycache__/base.cpython-39.pyc
Binary file not shown.
Empty file added logs/train_chong_gooreal.log
Empty file.
Binary file added models/__pycache__/base.cpython-39.pyc
Binary file not shown.
4 changes: 2 additions & 2 deletions models/base.py
@@ -4,9 +4,9 @@
 import numpy as np
 import torchvision.models as models
 
-class Base(nn.Module):
+class BaseModel(nn.Module):
     def __init__(self):
-        super(Base, self).__init__()
+        super(BaseModel, self).__init__()
         self.img_feature_dim = 256  # the dimension of the CNN feature to represent each frame
         self.base_model = models.resnet50(pretrained=True)
         self.base_model.fc2 = nn.Linear(1000, self.img_feature_dim)
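For context, a minimal sketch of how the renamed class can be exercised. It relies only on the constructor shown in this hunk (a ResNet-50 backbone plus a 1000-to-256 fc2 projection); the forward pass lies outside the hunk and is not assumed here:

    from models.base import BaseModel

    model = BaseModel()                # downloads ImageNet weights on first use
    print(model.img_feature_dim)       # 256
    print(model.base_model.fc2)        # Linear(in_features=1000, out_features=256, bias=True)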
155 changes: 155 additions & 0 deletions train_base.ipynb
@@ -0,0 +1,155 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"import torch.nn as nn\n",
"from torch.utils.data import Dataset, DataLoader\n",
"from utils_logging import setup_logger\n",
"from models.base import BaseModel\n",
"from dataloader.base import GooDataset\n",
"from training.base import train_base_model, GazeOptimizer\n"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"logger = setup_logger(name='first_logger',\n",
" log_dir ='./logs/',\n",
" log_file='train_chong_gooreal.log',\n",
" log_format = '%(asctime)s %(levelname)s %(message)s',\n",
" verbose=True)"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"batch_size=4\n",
"workers=12\n",
"images_dir = '/media/primesh/F4D0EA80D0EA4906/PROJECTS/FYP/Gaze detection/Datasets/gooreal/finalrealdatasetImgsV2'\n",
"pickle_path = '/media/primesh/F4D0EA80D0EA4906/PROJECTS/FYP/Gaze detection/Datasets/gooreal/oneshotrealhumansNew2.pickle'\n",
"test_images_dir = '/media/primesh/F4D0EA80D0EA4906/PROJECTS/FYP/Gaze detection/Datasets/gooreal/finalrealdatasetImgsV2'\n",
"test_pickle_path = '/media/primesh/F4D0EA80D0EA4906/PROJECTS/FYP/Gaze detection/Datasets/gooreal/testrealhumansNew2.pickle'\n",
"val_images_dir = '/media/primesh/F4D0EA80D0EA4906/PROJECTS/FYP/Gaze detection/Datasets/gooreal/finalrealdatasetImgsV2'\n",
"val_pickle_path = '/media/primesh/F4D0EA80D0EA4906/PROJECTS/FYP/Gaze detection/Datasets/gooreal/valrealhumansNew2.pickle'\n"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Train\n"
]
},
{
"ename": "FileNotFoundError",
"evalue": "[Errno 2] No such file or directory: '/media/primesh/F4D0EA80D0EA4906/PROJECTS/FYP/Gaze detection/Datasets/gooreal/oneshotrealhumansNew2.pickle'",
"output_type": "error",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mFileNotFoundError\u001b[0m Traceback (most recent call last)",
"\u001b[0;32m<ipython-input-5-a6d1dcfee7de>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[0mprint\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0;34m'Train'\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 2\u001b[0;31m \u001b[0mtrain_set\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mGooDataset\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mimages_dir\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mpickle_path\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m'train'\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 3\u001b[0m train_data_loader = DataLoader(dataset=train_set,\n\u001b[1;32m 4\u001b[0m \u001b[0mbatch_size\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mbatch_size\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 5\u001b[0m \u001b[0mshuffle\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mTrue\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/media/primesh/F4D0EA80D0EA49061/PROJECTS/FYP/Gaze detection/code/RetailGaze/dataloader/base.py\u001b[0m in \u001b[0;36m__init__\u001b[0;34m(self, root_dir, mat_file, training, include_path, input_size, output_size, imshow, use_gtbox)\u001b[0m\n\u001b[1;32m 44\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0muse_gtbox\u001b[0m\u001b[0;34m=\u001b[0m \u001b[0muse_gtbox\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 45\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 46\u001b[0;31m \u001b[0;32mwith\u001b[0m \u001b[0mopen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmat_file\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m'rb'\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0mf\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 47\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdata\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mpickle\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mload\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mf\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 48\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mimage_num\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mlen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdata\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;31mFileNotFoundError\u001b[0m: [Errno 2] No such file or directory: '/media/primesh/F4D0EA80D0EA4906/PROJECTS/FYP/Gaze detection/Datasets/gooreal/oneshotrealhumansNew2.pickle'"
]
}
],
"source": [
"print ('Train')\n",
"train_set = GooDataset(images_dir, pickle_path, 'train')\n",
"train_data_loader = DataLoader(dataset=train_set,\n",
" batch_size=batch_size,\n",
" shuffle=True,\n",
" num_workers=16)\n",
"print ('Val')\n",
"val_set = GooDataset(val_images_dir, val_pickle_path, 'train')\n",
"val_data_loader = DataLoader(dataset=val_set,\n",
" batch_size=4,\n",
" shuffle=True,\n",
" num_workers=16)\n",
"print ('Test')\n",
"test_set = GooDataset(test_images_dir, test_pickle_path, 'test')\n",
"test_data_loader = DataLoader(test_set, batch_size=batch_size//2,\n",
" shuffle=False, num_workers=8)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n",
"model = BaseModel().cuda()\n",
"start_epoch = 0\n",
"max_epoch = 5\n",
"learning_rate = 1e-4\n",
"\n",
"# Initializes Optimizer\n",
"gaze_opt = GazeOptimizer(model, learning_rate)\n",
"optimizer = gaze_opt.getOptimizer(start_epoch)\n",
"# Loss criteria\n",
"criterion = nn.NLLLoss().cuda()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from torch.utils.tensorboard import SummaryWriter\n",
"writer = SummaryWriter('runs/base_model')"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"train_base_model(model, train_data_loader, val_data_loader, criterion, optimizer, logger, writer, num_epochs=50, patience=10)\n"
]
}
],
"metadata": {
"interpreter": {
"hash": "9e6729ac249b985ced52126ba586c51d65f0f05ea9e77b2ef41209df81b7a383"
},
"kernelspec": {
"display_name": "Python 3.9.5 64-bit ('FYP': conda)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.5"
},
"orig_nbformat": 4
},
"nbformat": 4,
"nbformat_minor": 2
}
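The outputs stored in the third code cell show the dataset construction failing with a FileNotFoundError because the GOO-Real pickle is not present at the hard-coded path. A small guard along these lines (a sketch, not part of the commit) surfaces the problem before any DataLoader is built:

    import os

    # the *_pickle_path variables are the ones defined in the configuration cell above
    for p in (pickle_path, val_pickle_path, test_pickle_path):
        if not os.path.isfile(p):
            raise FileNotFoundError(f'annotation pickle not found: {p}')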
4 changes: 2 additions & 2 deletions training/base.py
@@ -38,7 +38,7 @@ def get_bb_binary(box):
                 b[j][k] = 1
     return b
 
-def train_face3d(model,train_data_loader,validation_data_loader, criterion, optimizer, logger, writer ,num_epochs=5,patience=10):
+def train_base_model(model,train_data_loader,validation_data_loader, criterion, optimizer, logger, writer ,num_epochs=5,patience=10):
     since = time.time()
     n_total_steps = len(train_data_loader)
     n_total_steps_val = len(validation_data_loader)
@@ -95,7 +95,7 @@ def train_face3d(model,train_data_loader,validation_data_loader, criterion, opti
     return model
 
 
-def test_face3d(model, test_data_loader, logger, test_depth=True, save_output=False):
+def test_base_model(model, test_data_loader, logger, test_depth=True, save_output=False):
     model.eval()
     angle_error = []
     with torch.no_grad():
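This hunk only renames the training and evaluation entry points; their signatures are unchanged. For completeness, a hedged sketch of calling the renamed test routine with the objects built in train_base.ipynb, keeping the defaults shown in the signature above:

    # evaluate on the held-out GOO-Real split once training has finished
    test_base_model(model, test_data_loader, logger, test_depth=True, save_output=False)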
115 changes: 115 additions & 0 deletions utils.py
@@ -0,0 +1,115 @@
from torchvision import transforms
import numpy as np
import random
from PIL import Image, ImageOps
from scipy import signal
import cv2
import torch

# data transform for image
data_transforms = {
    'train': transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
    'test': transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])
}


# generate a Gaussian centred on a point in a map with im_shape
def get_paste_kernel(im_shape, points, kernel, shape=(224 // 4, 224 // 4)):
    # square kernel
    k_size = kernel.shape[0] // 2
    x, y = points
    image_height, image_width = im_shape[:2]
    x, y = int(round(image_width * x)), int(round(y * image_height))
    x1, y1 = x - k_size, y - k_size
    x2, y2 = x + k_size, y + k_size
    h, w = shape
    if x2 >= w:
        w = x2 + 1
    if y2 >= h:
        h = y2 + 1
    heatmap = np.zeros((h, w))
    left, top, k_left, k_top = x1, y1, 0, 0
    if x1 < 0:
        left = 0
        k_left = -x1
    if y1 < 0:
        top = 0
        k_top = -y1

    heatmap[top:y2+1, left:x2+1] = kernel[k_top:, k_left:]
    return heatmap[0:shape[0], 0:shape[0]]



def gkern(kernlen=51, std=9):
    """Returns a 2D Gaussian kernel array."""
    gkern1d = signal.gaussian(kernlen, std=std).reshape(kernlen, 1)
    gkern2d = np.outer(gkern1d, gkern1d)
    return gkern2d

kernel_map = gkern(21, 3)
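As a quick sanity check of the two helpers above, a sketch assuming im_shape is the 56x56 heatmap grid itself (224 // 4), which is what the normalised-coordinate scaling inside get_paste_kernel implies:

    # paste the 21x21 Gaussian (kernel_map) at a normalised gaze point on a 56x56 grid
    heat = get_paste_kernel(im_shape=(56, 56), points=(0.5, 0.25), kernel=kernel_map)
    print(heat.shape)   # (56, 56)
    print(heat.max())   # 1.0 at the pasted kernel centre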

def generate_data_field(eye_point):
    """eye_point is (x, y) and between 0 and 1"""
    height, width = 224, 224
    x_grid = np.array(range(width)).reshape([1, width]).repeat(height, axis=0)
    y_grid = np.array(range(height)).reshape([height, 1]).repeat(width, axis=1)
    grid = np.stack((x_grid, y_grid)).astype(np.float32)

    x, y = eye_point
    x, y = x * width, y * height

    grid -= np.array([x, y]).reshape([2, 1, 1]).astype(np.float32)
    norm = np.sqrt(np.sum(grid ** 2, axis=0)).reshape([1, height, width])
    # avoid zero norm
    norm = np.maximum(norm, 0.1)
    grid /= norm
    return grid
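A short illustration of the field produced above (the eye point is arbitrary): every spatial location holds the unit vector pointing from the eye position towards that pixel, which downstream code can compare with a predicted gaze direction.

    field = generate_data_field(eye_point=(0.5, 0.5))
    print(field.shape)                      # (2, 224, 224)
    print(np.linalg.norm(field[:, 0, 0]))   # ~1.0, every vector is unit length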

def preprocess_image(image_path, eye):
    image = cv2.imread(image_path, cv2.IMREAD_COLOR)

    # crop face
    x_c, y_c = eye
    x_0 = x_c - 0.15
    y_0 = y_c - 0.15
    x_1 = x_c + 0.15
    y_1 = y_c + 0.15
    if x_0 < 0:
        x_0 = 0
    if y_0 < 0:
        y_0 = 0
    if x_1 > 1:
        x_1 = 1
    if y_1 > 1:
        y_1 = 1

    h, w = image.shape[:2]
    face_image = image[int(y_0 * h):int(y_1 * h), int(x_0 * w):int(x_1 * w), :]
    # process face_image for face net
    face_image = cv2.cvtColor(face_image, cv2.COLOR_BGR2RGB)
    face_image = Image.fromarray(face_image)
    face_image = data_transforms['test'](face_image)
    # process image for saliency net
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    image = Image.fromarray(image)
    image = data_transforms['test'](image)

    # generate gaze field
    gaze_field = generate_data_field(eye_point=eye)
    sample = {'image' : image,
              'face_image': face_image,
              'eye_position': torch.FloatTensor(eye),
              'gaze_field': torch.from_numpy(gaze_field)}

    return sample
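For reference, a hedged usage sketch of preprocess_image; the image path is hypothetical, and the eye coordinates are assumed to be normalised to [0, 1] as the cropping arithmetic above requires:

    # 'frame.jpg' is a placeholder path, not a file in this repository
    sample = preprocess_image('frame.jpg', eye=(0.42, 0.31))
    print(sample['image'].shape)        # torch.Size([3, 224, 224])
    print(sample['face_image'].shape)   # torch.Size([3, 224, 224])
    print(sample['gaze_field'].shape)   # torch.Size([2, 224, 224])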


24 changes: 24 additions & 0 deletions utils_logging.py
@@ -0,0 +1,24 @@
import os
import logging

def setup_logger(name, log_dir, log_file, log_format, level=logging.INFO, verbose=False):
"""To setup as many loggers as you want"""

if not os.path.exists(log_dir):
os.makedirs(log_dir)
log_file = log_dir + log_file

logging.basicConfig(level=logging.INFO,
format=log_format,
filename=log_file,
filemode='w')

console = logging.StreamHandler()
console.setLevel(logging.INFO)
logger = logging.getLogger(name)

#prints to command line if true
if verbose:
logger.addHandler(console)

return logger
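A minimal usage sketch mirroring the call in train_base.ipynb; log_file is built by plain string concatenation, so log_dir is assumed to end with a trailing slash:

    logger = setup_logger(name='first_logger',
                          log_dir='./logs/',
                          log_file='train_chong_gooreal.log',
                          log_format='%(asctime)s %(levelname)s %(message)s',
                          verbose=True)
    logger.info('training started')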
