Forward-looking (FL) and Detailed Balance (DB) losses #253

Merged (42 commits, Dec 26, 2023)
Changes from 19 commits

Commits
c2559f7
simple tetris config
AlexandraVolokhova Nov 14, 2023
21cffef
config fixes
AlexandraVolokhova Nov 14, 2023
8829988
final config for simple tetris
AlexandraVolokhova Nov 16, 2023
97430aa
state flow class
AlexandraVolokhova Nov 23, 2023
ebb2737
fl loss function in gfn + updates in batch class
AlexandraVolokhova Nov 24, 2023
64f7c3b
black, isort
AlexandraVolokhova Nov 24, 2023
9aeec3e
Remove get_parent_is_source() because it is used only once and it is …
alexhernandezgarcia Nov 27, 2023
04bd98e
Make forwardlooking config in gflownet/ as with other losses and move…
alexhernandezgarcia Nov 27, 2023
c733172
Remove attribute self.non_terminal_rewards from both GFlowNetAgent an…
alexhernandezgarcia Nov 27, 2023
764bbec
Merge pull request #255 from alexhernandezgarcia/fl-loss-ahg
AlexandraVolokhova Nov 27, 2023
d1dcbbb
Update gflownet/gflownet.py
AlexandraVolokhova Nov 27, 2023
913bcfb
Update gflownet/gflownet.py
AlexandraVolokhova Nov 27, 2023
9f4dea2
Update gflownet/utils/batch.py, iteration over indices
AlexandraVolokhova Nov 27, 2023
1d7a838
bug fix
AlexandraVolokhova Nov 27, 2023
35f1361
unblack
alexhernandezgarcia Nov 27, 2023
1907e70
Add do_non_terminating=True to get_rewards in getting parents rewards
alexhernandezgarcia Nov 27, 2023
5a135d9
Implementation of the detailed balance loss and the necessary configu…
alexhernandezgarcia Nov 27, 2023
f27a826
Fix: logflows of terminating is log(rewards)
alexhernandezgarcia Nov 27, 2023
9771adb
tests, docstrings, do_non_terminating in get_parents_rewards
AlexandraVolokhova Nov 27, 2023
391c304
Fix bug in FL loss (logflows needed to be squeezed); now it seems to …
alexhernandezgarcia Nov 27, 2023
0aca158
merge Alex's changes
AlexandraVolokhova Nov 27, 2023
c0f4ef4
isort, black
AlexandraVolokhova Nov 27, 2023
e24f966
Merge pull request #257 from alexhernandezgarcia/fl-loss-ahg
AlexandraVolokhova Nov 27, 2023
741f305
move squeeze to state flow call
AlexandraVolokhova Nov 27, 2023
993b9de
terminal -> terminating (for consistency)
alexhernandezgarcia Nov 27, 2023
06a6fb5
Extend docstring and wrap docstring lines
alexhernandezgarcia Nov 27, 2023
1b67f32
Merge pull request #258 from alexhernandezgarcia/state_flow_squeeze
alexhernandezgarcia Nov 27, 2023
6141b57
Edit docstring
alexhernandezgarcia Nov 28, 2023
660d5ef
Edit docstring
alexhernandezgarcia Nov 28, 2023
293010f
Compute DB loss on terminating and intermediate states
alexhernandezgarcia Nov 28, 2023
798b5ef
Remove squeeze in DB loss
alexhernandezgarcia Nov 28, 2023
21e38ac
Merge pull request #259 from alexhernandezgarcia/fl-loss-ahg
alexhernandezgarcia Nov 28, 2023
7a6bcdb
Merge pull request #256 from alexhernandezgarcia/db-loss
alexhernandezgarcia Nov 28, 2023
20285ff
Adapt names of variables for consistency
alexhernandezgarcia Nov 28, 2023
a777e99
Fixes (cherry-pick)
alexhernandezgarcia Nov 28, 2023
f7e5ec3
Merge pull request #260 from alexhernandezgarcia/fl-loss-ahg
alexhernandezgarcia Nov 28, 2023
08ee25b
if smth is False -> if not smth
AlexandraVolokhova Nov 29, 2023
0007bcc
Assert that env is not done for parents.
alexhernandezgarcia Dec 7, 2023
4390b17
Add section about losses in README
alexhernandezgarcia Dec 7, 2023
08351bb
Resolve conflicts with main: changes to policy base
alexhernandezgarcia Dec 8, 2023
d3b0126
Add DB and FL runs to sanity checks.
alexhernandezgarcia Dec 8, 2023
fe37c34
Make reward function boltzmann in mini-tetris sanity check experiments
alexhernandezgarcia Dec 9, 2023
47 changes: 47 additions & 0 deletions config/experiments/simple_tetris.yaml
@@ -0,0 +1,47 @@
# @package _global_

defaults:
  - override /env: tetris
  - override /gflownet: trajectorybalance
  - override /policy: mlp
  - override /proxy: tetris
  - override /logger: wandb

env:
  reward_func: boltzmann
  reward_beta: 10.0
  width: 4
  height: 4
  pieces: ["I", "O", "J", "L", "T"]
  rotations: [0, 90, 180, 270]
  buffer:
    # replay_capacity: 0
    test:
      type: random
      output_csv: simple_tetris_val.csv
      output_pkl: simple_tetris_val.pkl
      n: 100

gflownet:
  random_action_prob: 0.3
  optimizer:
    n_train_steps: 10000
    lr_z_mult: 100
    lr: 0.0001

policy:
  forward:
    type: mlp
    n_hid: 128
    n_layers: 5

  backward:
    shared_weights: True
    checkpoint: null
    reload_ckpt: False

device: cpu
logger:
  do:
    online: True
  project_name: simple_tetris
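For context on the reward settings above: with reward_func: boltzmann, the proxy score of a final Tetris board is mapped to a reward by a Boltzmann transform before the GFlowNet uses it. A minimal sketch of that mapping, assuming the transform is exp(-reward_beta * proxy) with reward_beta = 10.0 as in this config (the authoritative definition lives in the environment's proxy2reward):

import torch

# Hedged sketch of the assumed Boltzmann reward transform exp(-beta * proxy).
reward_beta = 10.0
proxy_scores = torch.tensor([0.0, 0.1, 0.5])  # made-up proxy outputs
rewards = torch.exp(-reward_beta * proxy_scores)
print(rewards)  # tensor([1.0000, 0.3679, 0.0067])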
9 changes: 9 additions & 0 deletions config/gflownet/forwardlooking.yaml
@@ -0,0 +1,9 @@
defaults:
  - gflownet
  - state_flow: mlp

optimizer:
  loss: forwardlooking
  lr: 0.0001
  lr_decay_period: 1000000
  lr_decay_gamma: 0.5
2 changes: 2 additions & 0 deletions config/gflownet/gflownet.yaml
@@ -34,6 +34,8 @@ optimizer:
  # From original implementation
  bootstrap_tau: 0.0
  clip_grad_norm: 0.0
# State flow modelling
state_flow: null
# If True, compute rewards in batches
batch_reward: True
# Force zero probability of sampling invalid actions
9 changes: 9 additions & 0 deletions config/gflownet/state_flow/mlp.yaml
@@ -0,0 +1,9 @@
_target_: gflownet.policy.state_flow.StateFlow

config:
  type: mlp
  n_hid: 128
  n_layers: 2
  checkpoint: null
  reload_ckpt: False
  shared_weights: False
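The _target_ above points to gflownet.policy.state_flow.StateFlow, whose source is not part of this diff view. As a rough, hypothetical sketch of what such a module does, assuming it wraps an MLP that maps a policy-encoded state to a single scalar log-flow and squeezes the trailing dimension when called (the class name StateFlowSketch and its constructor arguments below are illustrative, not the repository's API):

import torch
from torch import nn


class StateFlowSketch(nn.Module):
    """Hypothetical stand-in for a state-flow model: an MLP mapping a
    policy-encoded state to one scalar (log-)flow value per state."""

    def __init__(self, state_dim: int, n_hid: int = 128, n_layers: int = 2):
        super().__init__()
        layers = []
        dim_in = state_dim
        for _ in range(n_layers):
            layers += [nn.Linear(dim_in, n_hid), nn.LeakyReLU()]
            dim_in = n_hid
        layers.append(nn.Linear(dim_in, 1))
        self.model = nn.Sequential(*layers)

    def forward(self, states_policy: torch.Tensor) -> torch.Tensor:
        # Squeeze the trailing dimension so the output has shape [batch]
        return self.model(states_policy).squeeze(-1)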
7 changes: 7 additions & 0 deletions config/policy/mlp_forwardlooking.yaml
@@ -0,0 +1,7 @@
defaults:
  - mlp

backward:
  shared_weights: True
  checkpoint: null
  reload_ckpt: False
4 changes: 2 additions & 2 deletions gflownet/envs/base.py
@@ -780,13 +780,13 @@ def traj2readable(self, traj=None):
"""
return str(traj).replace("(", "[").replace(")", "]").replace(",", "")

def reward(self, state=None, done=None):
def reward(self, state=None, done=None, do_non_terminating=False):
"""
Computes the reward of a state
"""
state = self._get_state(state)
done = self._get_done(done)
if done is False:
if done is False and do_non_terminating is False:
            return tfloat(0.0, float_type=self.float, device=self.device)
        return self.proxy2reward(self.proxy(self.state2proxy(state))[0])

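The new do_non_terminating flag lets callers request a proxy-based reward for intermediate (not done) states, which the forward-looking and detailed balance losses both need; without it, reward() keeps returning 0 for unfinished states. A hedged usage sketch, where env stands for any environment derived from this base class and state is some intermediate state:

# Hypothetical usage of the extended signature shown in the diff above.
r_default = env.reward(state=state, done=False)  # 0.0, as before
r_intermediate = env.reward(state=state, done=False, do_non_terminating=True)  # proxy-based reward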
116 changes: 105 additions & 11 deletions gflownet/gflownet.py
@@ -50,6 +50,7 @@ def __init__(
        logger,
        num_empirical_loss,
        oracle,
        state_flow=None,
        active_learning=False,
        sample_only=False,
        replay_sampling="permutation",
@@ -79,6 +80,9 @@ def __init__(
        elif optimizer.loss in ["trajectorybalance", "tb"]:
            self.loss = "trajectorybalance"
            self.logZ = nn.Parameter(torch.ones(optimizer.z_dim) * 150.0 / 64)
        elif optimizer.loss in ["forwardlooking", "fl"]:
            self.loss = "forwardlooking"
            self.logZ = None
        else:
            print("Unkown loss. Using flowmatch as default")
            self.loss = "flowmatch"
@@ -121,7 +125,8 @@ def __init__(
            print(f"\tStd score: {self.buffer.test['energies'].std()}")
            print(f"\tMin score: {self.buffer.test['energies'].min()}")
            print(f"\tMax score: {self.buffer.test['energies'].max()}")
        # Policy models

        # Models
        self.forward_policy = forward_policy
        if self.forward_policy.checkpoint is not None:
            self.logger.set_forward_policy_ckpt_path(self.forward_policy.checkpoint)
@@ -133,6 +138,7 @@ def __init__(
                print("Reloaded GFN forward policy model Checkpoint")
        else:
            self.logger.set_forward_policy_ckpt_path(None)

        self.backward_policy = backward_policy
        self.logger.set_backward_policy_ckpt_path(None)
        if self.backward_policy.checkpoint is not None:
@@ -145,6 +151,14 @@ def __init__(
                print("Reloaded GFN backward policy model Checkpoint")
        else:
            self.logger.set_backward_policy_ckpt_path(None)

        self.state_flow = state_flow
        if self.state_flow is not None and self.state_flow.checkpoint is not None:
            self.logger.set_state_flow_ckpt_path(self.state_flow.checkpoint)
            # TODO: add the logic and conditions to reload a model
        else:
            self.logger.set_state_flow_ckpt_path(None)

        # Optimizer
        if self.forward_policy.is_model:
            self.target = copy.deepcopy(self.forward_policy.model)
@@ -178,14 +192,16 @@ def __init__(
        self.nll_tt = 0.0

    def parameters(self):
        if self.backward_policy.is_model is False:
            return list(self.forward_policy.model.parameters())
        elif self.loss == "trajectorybalance":
            return list(self.forward_policy.model.parameters()) + list(
                self.backward_policy.model.parameters()
            )
        else:
            raise ValueError("Backward Policy cannot be a nn in flowmatch.")
        parameters = list(self.forward_policy.model.parameters())
        if self.backward_policy.is_model:
            if self.loss == "flowmatch":
                raise ValueError("Backward Policy cannot be a model in flowmatch.")
            parameters += list(self.backward_policy.model.parameters())
        if self.state_flow is not None:
            if self.loss != "forwardlooking":
                raise ValueError(f"State flow cannot be trained with {self.loss} loss.")
            parameters += list(self.state_flow.model.parameters())
        return parameters

    def sample_actions(
        self,
@@ -663,6 +679,78 @@ def trajectorybalance_loss(self, it, batch):
        )
        return loss, loss, loss

    def forwardlooking_loss(self, it, batch):
        """
        Computes the Forward-Looking GFlowNet loss of a batch.

        Reference: https://arxiv.org/pdf/2302.01687.pdf

        Args
        ----
        it : int
            Iteration

        batch : Batch
            A batch of data, containing all the states in the trajectories.

        Returns
        -------
        loss : float

        term_loss : float
            Loss of the terminal nodes only

        nonterm_loss : float
            Loss of the intermediate nodes only
        """

        assert batch.is_valid()
        # Get necessary tensors from batch
        states_policy = batch.get_states(policy=True)
        states = batch.get_states(policy=False)
        actions = batch.get_actions()
        parents_policy = batch.get_parents(policy=True)
        parents = batch.get_parents(policy=False)
        traj_indices = batch.get_trajectory_indices(consecutive=True)
        done = batch.get_done()

        masks_b = batch.get_masks_backward()
        policy_output_b = self.backward_policy(states_policy)
        logprobs_bkw = self.env.get_logprobs(
            policy_output_b, actions, masks_b, states, is_backward=True
        )
        masks_f = batch.get_masks_forward(of_parents=True)
        policy_output_f = self.forward_policy(parents_policy)
        logprobs_fwd = self.env.get_logprobs(
            policy_output_f, actions, masks_f, parents, is_backward=False
        )

        states_log_flflow = self.state_flow(states_policy)
        # The forward-looking flow is 1 in the terminating states, so its log is 0
        states_log_flflow[done.eq(1)] = 0.0
        # Can be optimised by reusing states_log_flflow and batch.get_parent_indices
        parents_log_flflow = self.state_flow(parents_policy)

        rewards_states = batch.get_rewards(do_non_terminating=True)
        rewards_parents = batch.get_rewards_parents()
        energies_states = -torch.log(rewards_states)
        energies_parents = -torch.log(rewards_parents)

        per_node_loss = (
            parents_log_flflow
            - states_log_flflow
            + logprobs_fwd
            - logprobs_bkw
            + energies_states
            - energies_parents
        ).pow(2)

        term_loss = per_node_loss[done].mean()
        nonterm_loss = per_node_loss[~done].mean()
        loss = per_node_loss.mean()

        return loss, term_loss, nonterm_loss
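For readers unfamiliar with the forward-looking parameterization (the arXiv reference in the docstring above), the loss enforces a detailed-balance condition with flows written as F(s) = F~(s) * R(s): in log space, log F~(s) - E(s) + log P_F(s'|s) = log F~(s') - E(s') + log P_B(s|s'), where E = -log R and F~ is fixed to 1 at terminating states. A self-contained toy sketch of the per-transition term, with made-up numbers standing in for the batch tensors above:

import torch

# Toy illustration of the squared forward-looking residual computed in
# forwardlooking_loss(); each position corresponds to one parent -> state transition.
parents_log_flflow = torch.tensor([0.2, -0.1])  # log F~(s)
states_log_flflow = torch.tensor([0.1, 0.0])    # log F~(s'); 0 if s' is terminating
logprobs_fwd = torch.tensor([-0.7, -1.2])       # log P_F(s'|s)
logprobs_bkw = torch.tensor([-0.6, -1.1])       # log P_B(s|s')
energies_states = torch.tensor([1.5, 0.9])      # E(s') = -log R(s')
energies_parents = torch.tensor([1.6, 1.0])     # E(s) = -log R(s)

per_node_loss = (
    parents_log_flflow
    - states_log_flflow
    + logprobs_fwd
    - logprobs_bkw
    + energies_states
    - energies_parents
).pow(2)
loss = per_node_loss.mean()
print(per_node_loss, loss)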

    @torch.no_grad()
    def estimate_logprobs_data(
        self,
@@ -869,6 +957,8 @@ def train(self):
                    losses = self.trajectorybalance_loss(
                        it * self.ttsr + j, batch
                    )  # returns (opt loss, *metrics)
                elif self.loss == "forwardlooking":
                    losses = self.forwardlooking_loss(it * self.ttsr + j, batch)
                else:
                    print("Unknown loss!")
                    # TODO: deal with this in a better way
@@ -932,7 +1022,9 @@ def train(self):
            times.update({"log": t1_log - t0_log})
            # Save intermediate models
            t0_model = time.time()
            self.logger.save_models(self.forward_policy, self.backward_policy, step=it)
            self.logger.save_models(
                self.forward_policy, self.backward_policy, self.state_flow, step=it
            )
            t1_model = time.time()
            times.update({"save_interim_model": t1_model - t0_model})

@@ -961,7 +1053,9 @@ def train(self):
        self.logger.log_time(times, use_context=self.use_context)

        # Save final model
        self.logger.save_models(self.forward_policy, self.backward_policy, final=True)
        self.logger.save_models(
            self.forward_policy, self.backward_policy, self.state_flow, final=True
        )
        # Close logger
        if self.use_context is False:
            self.logger.end()