run_visualize_model.m
function env = run_visualize_model(train_mode, agent_weights, scene, frame, camera)
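% RUN_VISUALIZE_MODEL  Run a trained agent on a Panoptic scene and visualize it.
%
% Inputs (descriptions inferred from how the arguments are used below):
%   train_mode    - config mode string forwarded to load_config
%   agent_weights - path to a .caffemodel with trained agent weights, or
%                   nan to keep the solver's initial weights
%   scene         - (optional) name of the Panoptic scene to visualize
%   frame         - (optional) frame index; chosen at random if omitted
%   camera        - (optional) camera index; chosen at random if omitted
%
% Returns the Panoptic environment, which is also kept as a global so that
% repeated calls skip the expensive dataset loading.
%
% Example call (illustrative values only; actual paths and mode strings
% depend on your setup):
%   run_visualize_model('train', 'models/agent.caffemodel', ...
%                       '160422_ultimatum1', 100, 5);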
clc;
close all
% Add paths
addpath(genpath('code/'));
% Set up global config settings
load_config(train_mode)
% Create helper
helper = Helpers();
helper.setup_caffe();
global CONFIG
CONFIG.use_recorder = 1;
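% Enable the episode recorder so the rollout below can be plotted afterwards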
% Reuse the global env, if it exists, to speed up loading
global env
% Launch Panoptic environment
if ~isobject(env)
    env = Panoptic(CONFIG.dataset_path, CONFIG.dataset_cache, nan);
end
env.reset();
if exist('scene', 'var')
    for i = 1:numel(env.scenes)
        env.goto_scene(i);
        if strcmp(env.scene().scene_name, scene)
            break;
        end
    end
end
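% Note: if no scene matches the given name, the loop above simply leaves
% env at the last scene it visited.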
if ~exist('frame', 'var')
    frame = randi(env.scene().nbr_frames);
    fprintf('Selecting random frame: %d\n', frame);
end
if ~exist('camera', 'var')
    camera = randi(env.scene().nbr_cameras);
    fprintf('Selecting random camera: %d\n', camera);
end
% Load agent
helper.set_train_proto();
solver = caffe.get_solver(CONFIG.agent_solver_proto);
% Get agent network
net = solver.net;
if ~isnan(agent_weights)
    fprintf('Loading RL network weights from %s\n', agent_weights);
    net.copy_from(agent_weights);
end
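% With agent_weights = nan the network keeps its solver-initialized
% weights, i.e., an untrained agent is visualized.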
% Only need to register data names which are given / produced in both the
% forward and backward directions
data_names = {'data', 'pred', 'aux', 'canvas', 'rig', 'm', ...
              'action', 'elev_mult', 'rewards_mises'};
agent = ACTOR(net, solver, CONFIG.training_agent_eps_per_batch, ...
              data_names, CONFIG.sequence_length_train);
stats = StatCollector('Dummy Visualization');
% Create episode recorder
recorder = EpisodeRecorder(env, agent);
agent.reset();
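% Move to the requested (or randomly selected) camera and frame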
env.goto_cam(camera);
env.goto_frame(frame);
% Execute an active sequence
out_sequence = execute_active_sequence(env, agent, stats, 1, ...
                                       agent.last_trained_ep, ...
                                       CONFIG.sequence_length_eval, ...
                                       recorder, 1);
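% (out_sequence is not used below; plotting relies on the recorder passed
% into execute_active_sequence.)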
fprintf('Visualizing %s to %s\nS: %s, F: %d, C: %d\n', agent_weights, ...
        CONFIG.output_dir, env.scene().scene_name, frame, camera);
recorder.plot(0);
end