Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add local search sampler #208

Merged
merged 19 commits from hyeok9855/local-search into master on
Jan 13, 2025
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Merge branch 'master' into hyeok9855/local-search
hyeok9855 committed Nov 2, 2024
commit a3af467e5ad1dc462820ebb4edd8033932d53a62
2 changes: 1 addition & 1 deletion src/gfn/containers/trajectories.py
Original file line number Diff line number Diff line change
@@ -108,7 +108,7 @@ def __init__(
)
else:
log_probs = torch.full(size=(0, 0), fill_value=0, dtype=torch.float)
self.log_probs = log_probs
self.log_probs: torch.Tensor = log_probs

self.estimator_outputs = estimator_outputs
if self.estimator_outputs is not None:
19 changes: 8 additions & 11 deletions src/gfn/samplers.py
Original file line number Diff line number Diff line change
@@ -35,7 +35,7 @@ def sample_actions(
save_estimator_outputs: bool = False,
save_logprobs: bool = True,
**policy_kwargs: Any,
) -> Tuple[Actions, torch.Tensor | None, torch.Tensor | None,]:
) -> Tuple[Actions, torch.Tensor | None, torch.Tensor | None]:
"""Samples actions from the given states.

Args:
@@ -374,12 +374,12 @@ def local_search(
].flip(0)

# Forward part
trajectories_states_tsr[n_back : n_back + len_recon, i] = (
recon_trajectories.states.tensor[:, i]
)
trajectories_actions_tsr[n_back : n_back + len_recon - 1, i] = (
recon_trajectories.actions.tensor[: len_recon - 1, i]
)
trajectories_states_tsr[
n_back : n_back + len_recon, i
] = recon_trajectories.states.tensor[:, i]
trajectories_actions_tsr[
n_back : n_back + len_recon - 1, i
] = recon_trajectories.actions.tensor[: len_recon - 1, i]
if save_logprobs: # concatenate log_probs
raise NotImplementedError("metropolis-hastings is not implemented yet.")

@@ -390,9 +390,6 @@ def local_search(
actions=env.Actions(trajectories_actions_tsr),
when_is_done=trajectories_dones,
is_backward=False,
# FIXME: This is weird... since the trajectory contains
# both backward and forward parts.
# Maybe calculate log_pfs for the backward part -> and set is_backward=True?
log_rewards=trajectories_log_rewards,
log_probs=trajectories_logprobs, # TODO: Support log_probs (`None` for now)
)
@@ -476,7 +473,7 @@ def sample_trajectories(
last_indices = torch.arange(
n * (it + 1), n * (it + 2), device=trajectories.states.device
)
prev_log_rewards = trajectories.log_rewards[search_indices]
prev_log_rewards = trajectories.log_rewards[search_indices] # type: ignore # FIXME: pyright error
new_log_rewards = ls_trajectories.log_rewards
update_indices = prev_log_rewards <= new_log_rewards
search_indices[update_indices] = last_indices[update_indices]
You are viewing a condensed version of this merge commit. The full changes are available on the pull request's "Files changed" view.