Skip to content

Commit

Permalink
remove useless pyright ignore
Browse files Browse the repository at this point in the history
  • Loading branch information
hyeok9855 committed Mar 6, 2025
1 parent a6cd815 commit 06cecfc
Show file tree
Hide file tree
Showing 5 changed files with 20 additions and 16 deletions.
14 changes: 10 additions & 4 deletions src/gfn/gym/graph_building.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Callable, Literal, Tuple
from typing import Callable, Literal, Optional, Tuple

import torch
from torch_geometric.data import Batch as GeometricBatch
Expand Down Expand Up @@ -52,9 +52,15 @@ def __init__(
device_str=device_str,
)

def reset(self, batch_shape: Tuple | int) -> GraphStates:
def reset(
self,
batch_shape: int | Tuple[int, ...],
random: bool = False,
sink: bool = False,
seed: Optional[int] = None,
) -> GraphStates:
"""Reset the environment to a new batch of graphs."""
states = super().reset(batch_shape)
states = super().reset(batch_shape, random, sink, seed)
assert isinstance(states, GraphStates)
return states

Expand Down Expand Up @@ -153,7 +159,7 @@ def backward_step(
dtype=torch.bool,
device=graph.x.device,
)
mask[node_idx] = False # pyright: ignore
mask[node_idx] = False

# Update node features
graph.x = graph.x[mask]
Expand Down
8 changes: 4 additions & 4 deletions testing/test_graph_states.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,8 +119,8 @@ def test_getitem_2d(datas):
assert torch.allclose(batch_row.log_rewards, tsr[0])

# Try again with slicing
tsr_row2 = tsr[0, [0, 1]]
batch_row2 = states[0, [0, 1]] # pyright: ignore # TODO: Fix pyright issue
tsr_row2 = tsr[0, :]
batch_row2 = states[0, :] # pyright: ignore # TODO: Fix pyright issue
assert tuple(tsr_row2.shape) == batch_row2.tensor.batch_shape == (2,)
assert torch.equal(batch_row.tensor.x, batch_row2.tensor.x)

Expand Down Expand Up @@ -164,7 +164,7 @@ def test_setitem_1d(datas):
assert states.tensor.batch_shape == (3,) # Batch shape should not change

# Set the new graph in the second and third positions
states[[1, 2]] = new_states
states[1:] = new_states # pyright: ignore # TODO: Fix pyright issue

# Check that the second and third graphs are now the new graph
second_graph = states[1].tensor
Expand All @@ -182,7 +182,7 @@ def test_setitem_1d(datas):
with pytest.raises(AssertionError):
states[0] = new_states
with pytest.raises(AssertionError):
states[[1, 2]] = new_states[0]
states[1:] = new_states[0] # pyright: ignore


def test_setitem_2d(datas):
Expand Down
10 changes: 4 additions & 6 deletions tutorials/examples/train_graph_ring.py
Original file line number Diff line number Diff line change
Expand Up @@ -923,7 +923,7 @@ def forward(self, states_tensor: GeometricBatch) -> torch.Tensor:
trajectories = gflownet.sample_trajectories(
env,
n=BATCH_SIZE,
save_logprobs=True, # pyright: ignore
save_logprobs=True,
epsilon=0.2 * (1 - iteration / N_ITERATIONS),
)
training_samples = gflownet.to_training_samples(trajectories)
Expand All @@ -940,14 +940,12 @@ def forward(self, states_tensor: GeometricBatch) -> torch.Tensor:
with torch.no_grad():
replay_buffer.add(training_samples)
if iteration > 20:
training_samples = training_samples[
: BATCH_SIZE // 2
] # pyright: ignore
training_samples = training_samples[: BATCH_SIZE // 2]
buffer_samples = replay_buffer.sample(n_trajectories=BATCH_SIZE // 2)
training_samples.extend(buffer_samples) # pyright: ignore
training_samples.extend(buffer_samples)

optimizer.zero_grad()
loss = gflownet.loss(env, training_samples) # pyright: ignore
loss = gflownet.loss(env, training_samples)
pct_rings = torch.mean(rewards > 0.1, dtype=torch.float) * 100
print(
"Iteration {} - Loss: {:.02f}, rings: {:.0f}%".format(
Expand Down
2 changes: 1 addition & 1 deletion tutorials/examples/train_hypergrid.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,7 @@ def main(args): # noqa: C901
output_dim=1,
hidden_dim=args.hidden_dim,
n_hidden_layers=args.n_hidden,
trunk=pf_module.trunk if args.tied else None, # pyright: ignore
trunk=pf_module.trunk if args.tied else None,
)

logF_estimator = ScalarEstimator(
Expand Down
2 changes: 1 addition & 1 deletion tutorials/examples/train_line.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,7 +198,7 @@ def to_probability_distribution(
locs, scales = torch.split(module_output, [1, 1], dim=-1)

return ScaledGaussianWithOptionalExit(
states, # pyright: ignore
states,
locs,
scales + scale_factor, # Increase this value to induce exploration.
backward=self.backward,
Expand Down

0 comments on commit 06cecfc

Please sign in to comment.