Skip to content

Commit 78b729a

Browse files
committed
fix merging issues
1 parent c3df427 commit 78b729a

File tree

3 files changed

+10
-4
lines changed

3 files changed

+10
-4
lines changed

src/gfn/containers/trajectories.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -104,7 +104,7 @@ def __init__(
104104
assert (
105105
log_probs.shape == (self.max_length, self.n_trajectories)
106106
and log_probs.dtype == torch.float
107-
)
107+
), f"log_probs.shape={log_probs.shape}, self.max_length={self.max_length}, self.n_trajectories={self.n_trajectories}"
108108
else:
109109
log_probs = torch.full(size=(0, 0), fill_value=0, dtype=torch.float)
110110
self.log_probs: torch.Tensor = log_probs

src/gfn/samplers.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -207,7 +207,6 @@ def sample_trajectories(
207207
all_estimator_outputs.append(estimator_outputs_padded)
208208

209209
actions[~dones] = valid_actions
210-
trajectories_actions.append(actions)
211210
if save_logprobs:
212211
# When off_policy, actions_log_probs are None.
213212
log_probs[~dones] = actions_log_probs
@@ -247,7 +246,9 @@ def sample_trajectories(
247246
trajectories_states.append(deepcopy(states))
248247

249248
trajectories_states = env.States.stack(trajectories_states)
250-
trajectories_actions = env.Actions.stack(trajectories_actions)[1:] # Drop dummy action
249+
trajectories_actions = env.Actions.stack(trajectories_actions)[
250+
1:
251+
] # Drop dummy action
251252
trajectories_logprobs = (
252253
torch.stack(trajectories_logprobs, dim=0)[1:] # Drop dummy logprob
253254
if save_logprobs
@@ -257,7 +258,6 @@ def sample_trajectories(
257258
# TODO: use torch.nested.nested_tensor(dtype, device, requires_grad).
258259
if save_estimator_outputs:
259260
all_estimator_outputs = torch.stack(all_estimator_outputs, dim=0)
260-
261261
trajectories = Trajectories(
262262
env=env,
263263
states=trajectories_states,

testing/test_samplers_and_trajectories.py

+6
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,16 @@
11
from typing import Literal, Tuple
22

33
import pytest
4+
import torch
5+
from tensordict import TensorDict
6+
from torch import nn
7+
from torch_geometric.nn import GCNConv
48

9+
from gfn.actions import GraphActionType
510
from gfn.containers import Trajectories
611
from gfn.containers.replay_buffer import ReplayBuffer
712
from gfn.gym import Box, DiscreteEBM, HyperGrid
13+
from gfn.gym.graph_building import GraphBuilding
814
from gfn.gym.helpers.box_utils import BoxPBEstimator, BoxPBMLP, BoxPFEstimator, BoxPFMLP
915
from gfn.modules import DiscretePolicyEstimator, GFNModule, GraphActionPolicyEstimator
1016
from gfn.samplers import LocalSearchSampler, Sampler

0 commit comments

Comments
 (0)