Merge pull request #255 from alexhernandezgarcia/fl-loss-ahg
Changes to main FL loss PR
AlexandraVolokhova authored Nov 27, 2023
2 parents 64f7c3b + c733172 commit 764bbec
Showing 8 changed files with 50 additions and 38 deletions.
9 changes: 9 additions & 0 deletions config/gflownet/forwardlooking.yaml
@@ -0,0 +1,9 @@
defaults:
- gflownet
- state_flow: mlp

optimizer:
loss: forwardlooking
lr: 0.0001
lr_decay_period: 1000000
lr_decay_gamma: 0.5
2 changes: 2 additions & 0 deletions config/gflownet/gflownet.yaml
@@ -34,6 +34,8 @@ optimizer:
# From original implementation
bootstrap_tau: 0.0
clip_grad_norm: 0.0
# State flow modelling
state_flow: null
# If True, compute rewards in batches
batch_reward: True
# Force zero probability of sampling invalid actions
File renamed without changes.
1 change: 0 additions & 1 deletion config/main.yaml
@@ -3,7 +3,6 @@ defaults:
- env: grid
- gflownet: flowmatch
- policy: mlp_${gflownet}
- state_flow: null
- proxy: corners
- logger: wandb
- user: alex
7 changes: 7 additions & 0 deletions config/policy/mlp_forwardlooking.yaml
@@ -0,0 +1,7 @@
defaults:
- mlp

backward:
shared_weights: True
checkpoint: null
reload_ckpt: False
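
Note on how these config pieces fit together: config/main.yaml selects the policy group through the interpolation policy: mlp_${gflownet}, so choosing gflownet=forwardlooking resolves to the new policy/mlp_forwardlooking.yaml, while the state_flow: mlp default in gflownet/forwardlooking.yaml pulls in a state-flow model. Below is a minimal sketch of inspecting this resolution with Hydra's compose API, assuming Hydra >= 1.2, a config root of config/ with top-level main.yaml as in this diff, and a script run from the repository root (paths are relative to the calling script); this is not part of the PR.

# Hypothetical config inspection, not part of the PR.
from hydra import compose, initialize
from omegaconf import OmegaConf

with initialize(version_base=None, config_path="config"):
    cfg = compose(config_name="main", overrides=["gflownet=forwardlooking"])

print(cfg.gflownet.optimizer.loss)                 # expected: forwardlooking
print(OmegaConf.to_yaml(cfg.gflownet.state_flow))  # expected: the mlp state-flow node
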
14 changes: 1 addition & 13 deletions gflownet/gflownet.py
@@ -77,15 +77,12 @@ def __init__(
if optimizer.loss in ["flowmatch", "flowmatching"]:
self.loss = "flowmatch"
self.logZ = None
self.non_terminal_rewards = False
elif optimizer.loss in ["trajectorybalance", "tb"]:
self.loss = "trajectorybalance"
self.logZ = nn.Parameter(torch.ones(optimizer.z_dim) * 150.0 / 64)
self.non_terminal_rewards = False
elif optimizer.loss in ["forwardlooking", "fl"]:
self.loss = "forwardlooking"
self.logZ = None
self.non_terminal_rewards = True
else:
print("Unkown loss. Using flowmatch as default")
self.loss = "flowmatch"
@@ -428,7 +425,6 @@ def sample_batch(
env=self.env,
device=self.device,
float_type=self.float,
non_terminal_rewards=self.non_terminal_rewards,
)

# ON-POLICY FORWARD trajectories
@@ -438,7 +434,6 @@
env=self.env,
device=self.device,
float_type=self.float,
non_terminal_rewards=self.non_terminal_rewards,
)
while envs:
# Sample actions
@@ -465,7 +460,6 @@
env=self.env,
device=self.device,
float_type=self.float,
non_terminal_rewards=self.non_terminal_rewards,
)
if n_train > 0 and self.buffer.train_pkl is not None:
with open(self.buffer.train_pkl, "rb") as f:
@@ -501,7 +495,6 @@ def sample_batch(
env=self.env,
device=self.device,
float_type=self.float,
non_terminal_rewards=self.non_terminal_rewards,
)
if n_replay > 0 and self.buffer.replay_pkl is not None:
with open(self.buffer.replay_pkl, "rb") as f:
Expand Down Expand Up @@ -754,8 +747,7 @@ def forwardlooking_loss(self, it, batch):
# Can be optimised by reusing states_log_flflow and batch.get_parent_indices
parents_log_flflow = self.state_flow(parents_policy)

assert batch.non_terminal_rewards
rewards_states = batch.get_rewards()
rewards_states = batch.get_rewards(do_non_terminating=True)
rewards_parents = batch.get_rewards_parents()
energies_states = -torch.log(rewards_states)
energies_parents = -torch.log(rewards_parents)
Expand Down Expand Up @@ -869,7 +861,6 @@ def estimate_logprobs_data(
env=self.env,
device=self.device,
float_type=self.float,
non_terminal_rewards=self.non_terminal_rewards,
)
# Create an environment for each data point and trajectory and set the state
envs = []
@@ -973,7 +964,6 @@ def train(self):
env=self.env,
device=self.device,
float_type=self.float,
non_terminal_rewards=self.non_terminal_rewards,
)
for j in range(self.sttr):
sub_batch, times = self.sample_batch(
@@ -1268,7 +1258,6 @@ def test_top_k(self, it, progress=False, gfn_states=None, random_states=None):
env=self.env,
device=self.device,
float_type=self.float,
non_terminal_rewards=self.non_terminal_rewards,
)
self.random_action_prob = 0
t = time.time()
Expand Down Expand Up @@ -1296,7 +1285,6 @@ def test_top_k(self, it, progress=False, gfn_states=None, random_states=None):
env=self.env,
device=self.device,
float_type=self.float,
non_terminal_rewards=self.non_terminal_rewards,
)
self.random_action_prob = 1.0
print("[test_top_k] Sampling at random...", end="\r")
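
For reference, here is a minimal sketch (not part of the PR and not the repository's exact code) of how the forward-looking detailed-balance residual is typically assembled from the quantities computed in forwardlooking_loss above, following the FL-GFN formulation of Pan et al. (2023). Variable names mirror the diff, but the actual terminal handling, masking, and reduction in gflownet.py may differ.

import torch

def forward_looking_residual(
    parents_log_flflow,    # log F~(s)   for each parent state s
    states_log_flflow,     # log F~(s')  for each child state s'
    logprobs_forward,      # log P_F(s'|s)
    logprobs_backward,     # log P_B(s|s')
    energies_parents,      # E(s)  = -log R(s)
    energies_states,       # E(s') = -log R(s')
    done,                  # boolean mask of terminal states
):
    # For terminal states the forward-looking flow is conventionally fixed
    # to 1, so its log contribution is zero.
    states_log_flflow = torch.where(
        done, torch.zeros_like(states_log_flflow), states_log_flflow
    )
    # Forward-looking detailed balance:
    # log F~(s) + log P_F(s'|s) + E(s') - E(s) = log F~(s') + log P_B(s|s')
    residual = (
        parents_log_flflow
        + logprobs_forward
        + energies_states
        - energies_parents
        - states_log_flflow
        - logprobs_backward
    )
    return residual.pow(2).mean()
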
45 changes: 25 additions & 20 deletions gflownet/utils/batch.py
@@ -39,7 +39,6 @@ def __init__(
env: Optional[GFlowNetEnv] = None,
device: Union[str, torch.device] = "cpu",
float_type: Union[int, torch.dtype] = 32,
non_terminal_rewards: bool = False,
):
"""
env : GFlowNetEnv
@@ -57,8 +56,6 @@
self.device = set_device(device)
# Float precision
self.float = set_float_precision(float_type)
# Whether rewards should be computed for non-terminal states
self.non_terminal_rewards = non_terminal_rewards
# Generic environment, properties and dictionary of state and forward mask of
# source (as tensor)
if env is not None:
@@ -531,11 +528,6 @@ def get_parents_indices(self):
self._compute_parents()
return self.parents_indices

def get_parent_is_source(self):
if self.parents_available is False:
self._compute_parents()
return self.parents_indices == -1

def _compute_parents(self):
"""
Obtains the parent (single parent for each state) of all states in the batch.
@@ -842,7 +834,9 @@ def _compute_masks_backward(self):
self.masks_backward_available = True

def get_rewards(
self, force_recompute: Optional[bool] = False
self,
force_recompute: Optional[bool] = False,
do_non_terminating: Optional[bool] = False,
) -> TensorType["n_states"]:
"""
Returns the rewards of all states in the batch (including not done).
@@ -851,26 +845,37 @@ def get_rewards(
----
force_recompute : bool
If True, the rewards are recomputed even if they are available.
do_non_terminating : bool
If True, compute the rewards of the non-terminating states instead of
assigning reward 0.
"""
if self.rewards_available is False or force_recompute is True:
self._compute_rewards()
self._compute_rewards(do_non_terminating)
return self.rewards

def _compute_rewards(self):
def _compute_rewards(self, do_non_terminating: Optional[bool] = False):
"""
Computes rewards for all self.states by first converting the states into proxy
format. The result is stored in self.rewards as a torch.tensor
Args
----
do_non_terminating : bool
If True, compute the rewards of the non-terminating states instead of
assigning reward 0.
"""

self.rewards = torch.zeros(len(self), dtype=self.float, device=self.device)
done = self.get_done()
if self.non_terminal_rewards:
if do_non_terminating:
self.rewards = self.env.proxy2reward(self.env.proxy(self.states2proxy()))
elif len(done) > 0:
states_proxy_done = self.get_terminating_states(proxy=True)
self.rewards[done] = self.env.proxy2reward(
self.env.proxy(states_proxy_done)
)
else:
self.rewards = torch.zeros(len(self), dtype=self.float, device=self.device)
done = self.get_done()
if len(done) > 0:
states_proxy_done = self.get_terminating_states(proxy=True)
self.rewards[done] = self.env.proxy2reward(
self.env.proxy(states_proxy_done)
)
self.rewards_available = True

def get_rewards_parents(self) -> TensorType["n_states"]:
@@ -888,8 +893,8 @@ def _compute_rewards_parents(self):
"""
state_rewards = self.get_rewards()
self.rewards_parents = torch.zeros_like(state_rewards)
parent_is_source = self.get_parent_is_source()
parent_indices = self.get_parents_indices()
parent_is_source = parent_indices == -1
self.rewards_parents[~parent_is_source] = self.rewards[
parent_indices[~parent_is_source]
]
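
With the per-batch non_terminal_rewards flag removed, rewards for intermediate states are now requested per call. A brief usage sketch of the updated Batch interface, with names as in the diff above (caching via rewards_available is glossed over here):

rewards_terminating = batch.get_rewards()
# zeros for non-terminating states; proxy-derived rewards where done is True
rewards_all = batch.get_rewards(do_non_terminating=True, force_recompute=True)
# proxy-derived rewards for every state in the batch, as needed by the FL loss
rewards_parents = batch.get_rewards_parents()
# reward of each state's single parent; zero where the parent is the source state
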
10 changes: 6 additions & 4 deletions main.py
@@ -60,18 +60,18 @@ def main(config):
float_precision=config.float_precision,
base=forward_policy,
)

if config.gflownet.optimizer.loss in ["forwardlooking", "fl"]:
# State flow
if config.gflownet.state_flow is not None:
state_flow = hydra.utils.instantiate(
config.state_flow,
config.gflownet.state_flow,
env=env,
device=config.device,
float_precision=config.float_precision,
base=forward_policy,
)
else:
state_flow = None

# GFlowNet Agent
gflownet = hydra.utils.instantiate(
config.gflownet,
device=config.device,
@@ -83,6 +83,8 @@
buffer=config.env.buffer,
logger=logger,
)

# Train GFlowNet
gflownet.train()

# Sample from trained GFlowNet
