Skip to content

Commit

Permalink
Fix select() so that it handles the dictionaries coming from the repl…
Browse files Browse the repository at this point in the history
…ay buffer.
  • Loading branch information
alexhernandezgarcia committed Jan 22, 2024
1 parent 0f43f2a commit a2a2155
Showing 1 changed file with 9 additions and 3 deletions.
12 changes: 9 additions & 3 deletions gflownet/utils/buffer.py
Original file line number Diff line number Diff line change
Expand Up @@ -273,8 +273,11 @@ def compute_stats(data):
def select(data_dict: dict, n: int, mode: str = "permutation", rng=None):
if n == 0:
return []
# TODO: need list()?
samples = data_dict["x"]
# If the data_dict comes from the replay buffer, then samples is a dict and we
# need to keep its values only
if isinstance(samples, dict):
samples = list(samples.values())
if mode == "permutation":
assert rng is not None
samples = [samples[idx] for idx in rng.permutation(n)]
Expand All @@ -285,8 +288,11 @@ def select(data_dict: dict, n: int, mode: str = "permutation", rng=None):
score = "energy"
else:
raise ValueError(f"Data set does not contain reward or energy key.")
# TODO: need fromiter()?
scores = np.fromiter(data_dict[score], dtype=float)
scores = data_dict[score]
# If the data_dict comes from the replay buffer, then scores is a dict and we
# need to keep its values only
if isinstance(scores, dict):
scores = np.fromiter(scores.values(), dtype=float)
indices = np.random.choice(
len(samples),
size=n,
Expand Down

0 comments on commit a2a2155

Please sign in to comment.