Skip to content

Commit

Permalink
black pyright and formatting
Browse files Browse the repository at this point in the history
  • Loading branch information
josephdviviano committed Mar 4, 2025
1 parent e1fad74 commit 5105837
Show file tree
Hide file tree
Showing 3 changed files with 11 additions and 5 deletions.
8 changes: 6 additions & 2 deletions src/gfn/gym/graph_building.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,7 @@ def backward_step(

# Remove the node
mask = torch.ones(
size=graph.num_nodes, # pyright: ignore
graph.num_nodes, # pyright: ignore
dtype=torch.bool,
device=graph.x.device,
)
Expand Down Expand Up @@ -219,7 +219,11 @@ def is_action_valid(
src, dst = actions.edge_index[i]

# Check if src and dst are valid node indices
if src >= graph.num_nodes or dst >= graph.num_nodes or src == dst: # pyright: ignore
if (
src >= graph.num_nodes # pyright: ignore
or dst >= graph.num_nodes # pyright: ignore
or src == dst
):
return False

# Check if the edge already exists
Expand Down
4 changes: 3 additions & 1 deletion src/gfn/utils/modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,9 @@ def __init__(
self.trunk.hidden_dim = hidden_dim # pyright: ignore
else:
self.trunk = trunk
self.last_layer = nn.Linear(self.trunk.hidden_dim, output_dim) # pyright: ignore
self.last_layer = nn.Linear(
self.trunk.hidden_dim, output_dim
) # pyright: ignore

def forward(self, preprocessed_states: torch.Tensor) -> torch.Tensor:
"""Forward method for the neural network.
Expand Down
4 changes: 2 additions & 2 deletions src/gfn/utils/training.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ def validate(

logZ = None
if isinstance(gflownet, TBGFlowNet):
logZ = gflownet.logZ.item() # pyright: ignore
logZ = gflownet.logZ.item() # pyright: ignore
if visited_terminating_states is None:
terminating_states = gflownet.sample_terminating_states(
n_validation_samples
Expand Down Expand Up @@ -188,7 +188,7 @@ def warm_up(
else:
loss = gflownet.loss(env, training_trajs)

loss.backward() # pyright: ignore
loss.backward() # pyright: ignore
optimizer.step()
t.set_description(f"{epoch=}, {loss=}")

Expand Down

0 comments on commit 5105837

Please sign in to comment.