From ccefd863f38c474f045eb1c70ea37e5728cc6aeb Mon Sep 17 00:00:00 2001 From: Emmanuel Bengio Date: Tue, 8 Oct 2024 09:21:32 -0600 Subject: [PATCH] ruff & mypy --- src/gflownet/algo/graph_sampling.py | 34 ++++++++++++--------- src/gflownet/algo/local_search_tb.py | 3 -- src/gflownet/algo/trajectory_balance.py | 4 --- src/gflownet/data/data_source.py | 11 +------ src/gflownet/data/replay_buffer.py | 10 +++--- src/gflownet/models/graph_transformer.py | 2 +- src/gflownet/online_trainer.py | 4 +-- src/gflownet/trainer.py | 4 +-- src/gflownet/utils/multiprocessing_proxy.py | 6 ++-- 9 files changed, 33 insertions(+), 45 deletions(-) diff --git a/src/gflownet/algo/graph_sampling.py b/src/gflownet/algo/graph_sampling.py index 62b592f0..60b412a2 100644 --- a/src/gflownet/algo/graph_sampling.py +++ b/src/gflownet/algo/graph_sampling.py @@ -136,16 +136,17 @@ def sample_from_model(self, model: nn.Module, n: int, cond_info: Optional[Tensor # If we're not bootstrapping, we could query the reward # model here, but this is expensive/impractical. Instead # just report forward and backward logprobs + # TODO: stop using dicts and use typed objects data[i]["fwd_logprobs"] = torch.stack(data[i]["fwd_logprobs"]).reshape(-1) data[i]["U_bck_logprobs"] = torch.stack(data[i]["U_bck_logprobs"]).reshape(-1) - data[i]["fwd_logprob"] = data[i]["fwd_logprobs"].sum() - data[i]["U_bck_logprob"] = data[i]["U_bck_logprobs"].sum() + data[i]["fwd_logprob"] = data[i]["fwd_logprobs"].sum() # type: ignore + data[i]["U_bck_logprob"] = data[i]["U_bck_logprobs"].sum() # type: ignore data[i]["result"] = graphs[i] if self.pad_with_terminal_state: - data[i]["traj"].append((graphs[i], GraphAction(GraphActionType.Pad))) + data[i]["traj"].append((graphs[i], GraphAction(GraphActionType.Pad))) # type: ignore data[i]["U_bck_logprobs"] = torch.cat([data[i]["U_bck_logprobs"], torch.tensor([0.0], device=dev)]) - data[i]["is_sink"].append(1) - assert len(data[i]["U_bck_logprobs"]) == len(data[i]["bck_a"]) + data[i]["is_sink"].append(1) # type: ignore + assert len(data[i]["U_bck_logprobs"]) == len(data[i]["bck_a"]) # type: ignore return data def sample_backward_from_graphs( @@ -198,16 +199,19 @@ def sample_backward_from_graphs( for i in range(n): # See comments in sample_from_model - data[i]["traj"] = data[i]["traj"][::-1] + # TODO: stop using dicts and use typed objects + data[i]["traj"] = data[i]["traj"][::-1] # type: ignore # I think this pad is only necessary if we're padding terminal states??? 
- data[i]["bck_a"] = [GraphAction(GraphActionType.Pad)] + data[i]["bck_a"][::-1] - data[i]["is_sink"] = data[i]["is_sink"][::-1] - data[i]["U_bck_logprobs"] = torch.tensor([0] + data[i]["U_bck_logprobs"][::-1], device=dev).reshape(-1) + data[i]["bck_a"] = [GraphAction(GraphActionType.Pad)] + data[i]["bck_a"][::-1] # type: ignore + data[i]["is_sink"] = data[i]["is_sink"][::-1] # type: ignore + data[i]["U_bck_logprobs"] = torch.tensor( + [0] + data[i]["U_bck_logprobs"][::-1], device=dev # type: ignore + ).reshape(-1) if self.pad_with_terminal_state: - data[i]["traj"].append((starting_graphs[i], GraphAction(GraphActionType.Pad))) + data[i]["traj"].append((starting_graphs[i], GraphAction(GraphActionType.Pad))) # type: ignore data[i]["U_bck_logprobs"] = torch.cat([data[i]["U_bck_logprobs"], torch.tensor([0.0], device=dev)]) - data[i]["is_sink"].append(1) - assert len(data[i]["U_bck_logprobs"]) == len(data[i]["bck_a"]) + data[i]["is_sink"].append(1) # type: ignore + assert len(data[i]["U_bck_logprobs"]) == len(data[i]["bck_a"]) # type: ignore return data def local_search_sample_from_model( @@ -249,7 +253,7 @@ def local_search_sample_from_model( ] # type: ignore graphs = [i["traj"][-1][0] for i in current_trajs] done = [False] * n - fwd_a = [] + fwd_a: List[GraphAction] = [] for i in range(cfg.num_bck_steps): # This modifies `bck_trajs` & `graphs` in place, passing fwd_a computes P_F(s|s') for the previous step self._backward_step(model, bck_trajs, graphs, cond_info, done, dev, fwd_a) @@ -264,7 +268,7 @@ def local_search_sample_from_model( {"traj": [], "bck_a": [], "is_sink": [], "bck_logprobs": [], "fwd_logprobs": []} for _ in current_trajs ] # type: ignore done = [False] * n - bck_a = [] + bck_a: List[GraphAction] = [] while not all(done): self._forward_step(model, fwd_trajs, graphs, cond_info, 0, done, rng, dev, random_action_prob, bck_a) done = [d or (len(t["traj"]) + T) >= self.max_len for d, t, T in zip(done, fwd_trajs, trunc_lens)] @@ -281,7 +285,7 @@ def local_search_sample_from_model( sampled_terminals.extend(terminals) for traj, term in zip(fwd_trajs, terminals): traj["result"] = term - traj["is_accept"] = False + traj["is_accept"] = False # type: ignore # Compute rewards for the acceptance if compute_reward is not None: compute_reward(fwd_trajs, cond_info) diff --git a/src/gflownet/algo/local_search_tb.py b/src/gflownet/algo/local_search_tb.py index 4c8f38c9..6bd65fcd 100644 --- a/src/gflownet/algo/local_search_tb.py +++ b/src/gflownet/algo/local_search_tb.py @@ -1,9 +1,6 @@ -import torch - from gflownet import GFNTask from gflownet.algo.trajectory_balance import TrajectoryBalance from gflownet.data.data_source import DataSource -from gflownet.utils.misc import get_worker_device class LocalSearchTB(TrajectoryBalance): diff --git a/src/gflownet/algo/trajectory_balance.py b/src/gflownet/algo/trajectory_balance.py index 8e9d3679..fd23db73 100644 --- a/src/gflownet/algo/trajectory_balance.py +++ b/src/gflownet/algo/trajectory_balance.py @@ -620,10 +620,6 @@ def compute_batch_losses( if self.cfg.mle_loss_multiplier != 0: info["mle_loss"] = mle_loss.item() - if not torch.isfinite(loss): - import pdb - - pdb.set_trace() return loss, info def analytical_maxent_backward(self, batch, first_graph_idx): diff --git a/src/gflownet/data/data_source.py b/src/gflownet/data/data_source.py index bad06268..8b945d05 100644 --- a/src/gflownet/data/data_source.py +++ b/src/gflownet/data/data_source.py @@ -1,3 +1,4 @@ +import traceback import warnings from typing import Callable, Generator, List, Optional @@ 
-92,19 +93,9 @@ def __iter__(self): raise e print(f"Error in DataSource: {e} [tol={self._err_tol}]") # print full traceback - import sys - import traceback traceback.print_exc() continue - except: - print("Unknown error in DataSource") - import sys - import traceback - - traceback.print_exc() - self._err_tol -= 1 - continue def validate_batch(self, batch, trajs): for actions, atypes in [(batch.actions, self.ctx.action_type_order)] + ( diff --git a/src/gflownet/data/replay_buffer.py b/src/gflownet/data/replay_buffer.py index 62b791d2..197258b2 100644 --- a/src/gflownet/data/replay_buffer.py +++ b/src/gflownet/data/replay_buffer.py @@ -1,6 +1,6 @@ import heapq from threading import Lock -from typing import List +from typing import Any, List import numpy as np import torch @@ -15,8 +15,8 @@ def __init__(self, cfg: Config): Replay buffer for storing and sampling arbitrary data (e.g. transitions or trajectories) In self.push(), the buffer detaches any torch tensor and sends it to the CPU. """ - self.capacity = cfg.replay.capacity - self.warmup = cfg.replay.warmup + self.capacity = cfg.replay.capacity or int(1e6) + self.warmup = cfg.replay.warmup or 0 assert self.warmup <= self.capacity, "ReplayBuffer warmup must be smaller than capacity" self.buffer: List[tuple] = [] @@ -24,7 +24,7 @@ def __init__(self, cfg: Config): self.treat_as_heap = cfg.replay.keep_highest_rewards self.filter_uniques = cfg.replay.keep_only_uniques - self._uniques = set() + self._uniques: set[Any] = set() self._lock = Lock() @@ -56,7 +56,7 @@ def push(self, *args, unique_obj=None, priority=None): self._uniques.add(unique_obj) else: if len(self.buffer) < self.capacity: - self.buffer.append(None) + self.buffer.append(()) if self.filter_uniques: if self.position == 0 and len(self.buffer) == self.capacity: # We're about to wrap around, so remove the oldest element diff --git a/src/gflownet/models/graph_transformer.py b/src/gflownet/models/graph_transformer.py index e5a35996..3d5f91b7 100644 --- a/src/gflownet/models/graph_transformer.py +++ b/src/gflownet/models/graph_transformer.py @@ -310,7 +310,7 @@ def _make_cat(self, g: gd.Batch, emb: Dict[str, Tensor], action_types: list[Grap sc = self.logit_scaler( g.cond_info if g.cond_info is not None else torch.ones((g.num_graphs, 1), device=g.x.device) ) - cat.logits = [l * sc[b] for l, b in zip(cat.raw_logits, cat.batch)] # Setting .logits masks them + cat.logits = [lg * sc[b] for lg, b in zip(cat.raw_logits, cat.batch)] # Setting .logits masks them return cat def forward(self, g: gd.Batch): diff --git a/src/gflownet/online_trainer.py b/src/gflownet/online_trainer.py index 52636d9a..c8bebba0 100644 --- a/src/gflownet/online_trainer.py +++ b/src/gflownet/online_trainer.py @@ -1,6 +1,6 @@ import copy -import os import pathlib +from typing import Any import git import torch @@ -137,7 +137,7 @@ def setup(self): def step(self, loss: Tensor, train_it: int): loss.backward() - info = {} + info: dict[str, Any] = {} if train_it % self.cfg.algo.grad_acc_steps != 0: return info if self.cfg.opt.clip_grad_type is not None: diff --git a/src/gflownet/trainer.py b/src/gflownet/trainer.py index 777e1a0f..3321c32e 100644 --- a/src/gflownet/trainer.py +++ b/src/gflownet/trainer.py @@ -272,11 +272,11 @@ def _send_models_to_device(self): self.model.to(self.device) self.sampling_model.to(self.device) if self.world_size > 1: - self.model = DistributedDataParallel( + self.model = nn.parallel.DistributedDataParallel( self.model.to(self.rank), device_ids=[self.rank], output_device=self.rank ) if 
self.sampling_model is not self.model: - self.sampling_model = DistributedDataParallel( + self.sampling_model = nn.parallel.DistributedDataParallel( self.sampling_model.to(self.rank), device_ids=[self.rank], output_device=self.rank ) diff --git a/src/gflownet/utils/multiprocessing_proxy.py b/src/gflownet/utils/multiprocessing_proxy.py index d011b64a..13559514 100644 --- a/src/gflownet/utils/multiprocessing_proxy.py +++ b/src/gflownet/utils/multiprocessing_proxy.py @@ -268,8 +268,8 @@ def run(self): break timeouts = 0 attr, args, kwargs = r - if hasattr(self.obj, "lock"): - f.lock.acquire() + if hasattr(self.obj, "lock"):  # TODO: is this lock actually used anywhere? + self.obj.lock.acquire() f = getattr(self.obj, attr) args = [i.to(self.device) if isinstance(i, self.cuda_types) else i for i in args] kwargs = {k: i.to(self.device) if isinstance(i, self.cuda_types) else i for k, i in kwargs.items()} @@ -293,7 +293,7 @@ def run(self): msg = self.to_cpu(result) self.out_queues[qi].put(self.encode(msg)) if hasattr(self.obj, "lock"): - f.lock.release() + self.obj.lock.release() def terminate(self): self.stop.set()
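
Reviewer note on the multiprocessing_proxy.py change (not part of the patch): the fix correctly acquires the lock on self.obj rather than on the not-yet-assigned f, but a bare acquire()/release() pair still leaks the lock if anything raises between the two calls (the elided lines between the hunks may or may not guard this). A with-block over an optional lock is the usual exception-safe pattern. Below is a minimal, self-contained sketch; Worker and call_with_optional_lock are hypothetical names for illustration, not identifiers from this repo.

    import threading
    from contextlib import nullcontext

    class Worker:
        """Toy stand-in for the proxied object; it may or may not define `lock`."""

        def __init__(self, with_lock: bool = True):
            if with_lock:
                self.lock = threading.Lock()

        def compute(self, x: int) -> int:
            return x * 2

    def call_with_optional_lock(obj, attr, *args, **kwargs):
        # Use the object's lock if it defines one, else a no-op context manager.
        # The with-block guarantees release even if the call raises, unlike a
        # bare acquire()/release() pair.
        ctx = obj.lock if hasattr(obj, "lock") else nullcontext()
        with ctx:
            return getattr(obj, attr)(*args, **kwargs)

    if __name__ == "__main__":
        w = Worker()
        print(call_with_optional_lock(w, "compute", 21))  # prints 42

This keeps the hasattr check from the patch while making the locking exception-safe.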
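Also worth flagging from the replay_buffer.py change (reviewer note, not part of the patch): `cfg.replay.capacity or int(1e6)` falls back to the default not only when the field is None but also when it is 0, because `or` tests truthiness; the same applies to `cfg.replay.warmup or 0`. If a zero-capacity buffer should be representable, an explicit None check is safer. A minimal sketch of the difference, where ReplayConfig is a hypothetical stand-in for the real config class:

    from dataclasses import dataclass
    from typing import Optional

    @dataclass
    class ReplayConfig:
        capacity: Optional[int] = None
        warmup: Optional[int] = None

    cfg = ReplayConfig(capacity=0)
    print(cfg.capacity or int(1e6))                                # 1000000 -- 0 is silently replaced
    print(cfg.capacity if cfg.capacity is not None else int(1e6))  # 0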