Commit 63e3981

switched all tqdm pbars from context manager to manually closing
1 parent 00a2526 commit 63e3981

File tree

5 files changed: +118 -107 lines changed

  apax/bal/api.py
  apax/data/preprocessing.py
  apax/md/nvt.py
  apax/train/eval.py
  apax/train/trainer.py
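
The change applied across all five files follows the same pattern: the with trange(...) as pbar: block is replaced by an explicit pbar = trange(...) call, and pbar.close() is invoked once the loop is done. A minimal sketch of the two forms (the loop and its bounds are placeholders, not code from this repository):

from tqdm import trange

# before: the bar's lifetime is tied to the with-block,
# which forces the whole loop one indentation level deeper
with trange(100, desc="Work", ncols=100) as pbar:
    for _ in range(100):
        pbar.update()

# after: the bar is created up front and closed explicitly,
# so the loop can sit at the surrounding indentation level
pbar = trange(100, desc="Work", ncols=100)
for _ in range(100):
    pbar.update()
pbar.close()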

apax/bal/api.py

Lines changed: 4 additions & 4 deletions
@@ -1,9 +1,9 @@
 from functools import partial
 from typing import List, Union
-from ase import Atoms
 
 import jax
 import numpy as np
+from ase import Atoms
 from click import Path
 from tqdm import trange
 
@@ -55,7 +55,7 @@ def compute_features(feature_fn, dataset: TFPipeline, processing_batch_size: int
         features.append(np.asarray(g))
         pbar.update(g.shape[0])
     pbar.close()
-
+
     features = np.concatenate(features, axis=0)
     return features
 
@@ -67,8 +67,8 @@ def kernel_selection(
     base_fm_options: dict,
     selection_method: str,
     feature_transforms: list = [],
-    selection_batch_size: int =10,
-    processing_batch_size: int =64,
+    selection_batch_size: int = 10,
+    processing_batch_size: int = 64,
 ):
     n_models = 1 if isinstance(model_dir, (Path, str)) else len(model_dir)
     is_ensemble = n_models > 1

apax/data/preprocessing.py

Lines changed: 8 additions & 1 deletion
@@ -78,7 +78,14 @@ def dataset_neighborlist(
         r_max,
     )

-    nl_pbar = trange(len(positions), desc="Precomputing NL", ncols=100, mininterval=0.25, disable=disable_pbar, leave=True)
+    nl_pbar = trange(
+        len(positions),
+        desc="Precomputing NL",
+        ncols=100,
+        mininterval=0.25,
+        disable=disable_pbar,
+        leave=True,
+    )
     for i, position in enumerate(positions):
         if np.all(box[i] < 1e-6):
             position = jnp.asarray(position)

apax/md/nvt.py

Lines changed: 28 additions & 27 deletions
@@ -233,34 +233,35 @@ def body_fn(i, state):
     start = time.time()
     sim_time = n_outer * ensemble.dt  # * units.fs
     log.info("running nvt for %.1f fs", sim_time)
-    with trange(
+    sim_pbar = trange(
         0, n_steps, desc="Simulation", ncols=100, disable=disable_pbar, leave=True
-    ) as sim_pbar:
-        while step < n_outer:
-            new_state, neighbor, current_temperature = sim(state, neighbor)
-
-            if neighbor.did_buffer_overflow:
-                log.info("step %d: neighbor list overflowed, reallocating.", step)
-                traj_handler.reset_buffer()
-                neighbor = neighbor_fn.allocate(
-                    state.position
-                )  # TODO check that this actually works
-            else:
-                state = new_state
-                step += 1
-
-                if np.any(np.isnan(state.position)) or np.any(np.isnan(state.velocity)):
-                    raise ValueError(
-                        f"NaN encountered, simulation aborted after {step} steps."
-                    )
-
-                if step % checkpoint_interval == 0:
-                    log.info("saving checkpoint at step: %d", step)
-                    log.info("checkpoints not yet implemented")
-
-                if step % pbar_update_freq == 0:
-                    sim_pbar.set_postfix(T=f"{(current_temperature):.1f} K")  # set string
-                    sim_pbar.update(pbar_increment)
+    )
+    while step < n_outer:
+        new_state, neighbor, current_temperature = sim(state, neighbor)
+
+        if neighbor.did_buffer_overflow:
+            log.info("step %d: neighbor list overflowed, reallocating.", step)
+            traj_handler.reset_buffer()
+            neighbor = neighbor_fn.allocate(
+                state.position
+            )  # TODO check that this actually works
+        else:
+            state = new_state
+            step += 1
+
+            if np.any(np.isnan(state.position)) or np.any(np.isnan(state.velocity)):
+                raise ValueError(
+                    f"NaN encountered, simulation aborted after {step} steps."
+                )
+
+            if step % checkpoint_interval == 0:
+                log.info("saving checkpoint at step: %d", step)
+                log.info("checkpoints not yet implemented")
+
+            if step % pbar_update_freq == 0:
+                sim_pbar.set_postfix(T=f"{(current_temperature):.1f} K")  # set string
+                sim_pbar.update(pbar_increment)
+    sim_pbar.close()

     barrier_wait()
     traj_handler.write()
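
A side effect of dropping the context manager, visible in this hunk: if the NaN check raises ValueError, the function exits before sim_pbar.close() is reached, whereas the with-block closed the bar on any exit. This is not part of the commit, only a sketch of how the manual form could be guarded if cleanup on error were wanted (the loop body is a placeholder):

from tqdm import trange

pbar = trange(100, desc="Simulation", ncols=100)
try:
    for _ in range(100):
        # placeholder for one outer simulation step
        pbar.update()
finally:
    # runs on normal completion, break, or an exception such as the NaN ValueError
    pbar.close()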

apax/train/eval.py

Lines changed: 10 additions & 8 deletions
@@ -85,17 +85,19 @@ def predict(model, params, Metrics, loss_fn, test_ds, callbacks):

     epoch_loss.update({"test_loss": 0.0})
     test_metrics = Metrics.empty()
-    with trange(
+
+    batch_pbar = trange(
         0, test_steps_per_epoch, desc="Batches", ncols=100, disable=False, leave=True
-    ) as batch_pbar:
-        for batch_idx in range(test_steps_per_epoch):
-            inputs, labels = next(batch_test_ds)
+    )
+    for batch_idx in range(test_steps_per_epoch):
+        inputs, labels = next(batch_test_ds)

-            test_metrics, batch_loss = test_step_fn(params, inputs, labels, test_metrics)
+        test_metrics, batch_loss = test_step_fn(params, inputs, labels, test_metrics)

-            epoch_loss["test_loss"] += batch_loss
-            batch_pbar.set_postfix(test_loss=epoch_loss["test_loss"] / batch_idx)
-            batch_pbar.update()
+        epoch_loss["test_loss"] += batch_loss
+        batch_pbar.set_postfix(test_loss=epoch_loss["test_loss"] / batch_idx)
+        batch_pbar.update()
+    batch_pbar.close()

     epoch_loss["test_loss"] /= test_steps_per_epoch
     epoch_loss["test_loss"] = float(epoch_loss["test_loss"])

apax/train/trainer.py

Lines changed: 68 additions & 67 deletions
@@ -53,83 +53,84 @@ def fit(
     best_loss = np.inf
     early_stopping_counter = 0
     epoch_loss = {}
-    with trange(
+    epoch_pbar = trange(
         start_epoch, n_epochs, desc="Epochs", ncols=100, disable=disable_pbar, leave=True
-    ) as epoch_pbar:
-        for epoch in range(start_epoch, n_epochs):
-            epoch_start_time = time.time()
-            callbacks.on_epoch_begin(epoch=epoch + 1)
-
-            epoch_loss.update({"train_loss": 0.0})
-            train_batch_metrics = Metrics.empty()
-
-            for batch_idx in range(train_steps_per_epoch):
-                callbacks.on_train_batch_begin(batch=batch_idx)
-
-                inputs, labels = next(batch_train_ds)
-                train_batch_metrics, batch_loss, state = train_step(
-                    state, inputs, labels, train_batch_metrics
-                )
-
-                epoch_loss["train_loss"] += batch_loss
-                callbacks.on_train_batch_end(batch=batch_idx)
-
-            epoch_loss["train_loss"] /= train_steps_per_epoch
-            epoch_loss["train_loss"] = float(epoch_loss["train_loss"])
-
-            epoch_metrics = {
-                f"train_{key}": float(val)
-                for key, val in train_batch_metrics.compute().items()
-            }
-
-            if val_ds is not None:
-                epoch_loss.update({"val_loss": 0.0})
-                val_batch_metrics = Metrics.empty()
-                for batch_idx in range(val_steps_per_epoch):
-                    inputs, labels = next(batch_val_ds)
-
-                    val_batch_metrics, batch_loss = val_step(
-                        state.params, inputs, labels, val_batch_metrics
-                    )
-                    epoch_loss["val_loss"] += batch_loss
-
-                epoch_loss["val_loss"] /= val_steps_per_epoch
-                epoch_loss["val_loss"] = float(epoch_loss["val_loss"])
+    )
+    for epoch in range(start_epoch, n_epochs):
+        epoch_start_time = time.time()
+        callbacks.on_epoch_begin(epoch=epoch + 1)

-                epoch_metrics.update(
-                    {
-                        f"val_{key}": float(val)
-                        for key, val in val_batch_metrics.compute().items()
-                    }
-                )
+        epoch_loss.update({"train_loss": 0.0})
+        train_batch_metrics = Metrics.empty()

-            epoch_metrics.update({**epoch_loss})
+        for batch_idx in range(train_steps_per_epoch):
+            callbacks.on_train_batch_begin(batch=batch_idx)

-            epoch_end_time = time.time()
-            epoch_metrics.update({"epoch_time": epoch_end_time - epoch_start_time})
+            inputs, labels = next(batch_train_ds)
+            train_batch_metrics, batch_loss, state = train_step(
+                state, inputs, labels, train_batch_metrics
+            )

-            ckpt = {"model": state, "epoch": epoch}
-            if epoch % ckpt_interval == 0:
-                ckpt_manager.save_checkpoint(ckpt, epoch, latest_dir)
+            epoch_loss["train_loss"] += batch_loss
+            callbacks.on_train_batch_end(batch=batch_idx)

-            if epoch_metrics["val_loss"] < best_loss:
-                best_loss = epoch_metrics["val_loss"]
-                ckpt_manager.save_checkpoint(ckpt, epoch, best_dir)
-                early_stopping_counter = 0
-            else:
-                early_stopping_counter += 1
+        epoch_loss["train_loss"] /= train_steps_per_epoch
+        epoch_loss["train_loss"] = float(epoch_loss["train_loss"])

-            callbacks.on_epoch_end(epoch=epoch, logs=epoch_metrics)
+        epoch_metrics = {
+            f"train_{key}": float(val)
+            for key, val in train_batch_metrics.compute().items()
+        }

-            epoch_pbar.set_postfix(val_loss=epoch_metrics["val_loss"])
-            epoch_pbar.update()
+        if val_ds is not None:
+            epoch_loss.update({"val_loss": 0.0})
+            val_batch_metrics = Metrics.empty()
+            for batch_idx in range(val_steps_per_epoch):
+                inputs, labels = next(batch_val_ds)

-            if patience is not None and early_stopping_counter >= patience:
-                log.info(
-                    "Early stopping patience exceeded. Stopping training after"
-                    f" {epoch} epochs."
+                val_batch_metrics, batch_loss = val_step(
+                    state.params, inputs, labels, val_batch_metrics
                 )
-                break
+                epoch_loss["val_loss"] += batch_loss
+
+            epoch_loss["val_loss"] /= val_steps_per_epoch
+            epoch_loss["val_loss"] = float(epoch_loss["val_loss"])
+
+            epoch_metrics.update(
+                {
+                    f"val_{key}": float(val)
+                    for key, val in val_batch_metrics.compute().items()
+                }
+            )
+
+        epoch_metrics.update({**epoch_loss})
+
+        epoch_end_time = time.time()
+        epoch_metrics.update({"epoch_time": epoch_end_time - epoch_start_time})
+
+        ckpt = {"model": state, "epoch": epoch}
+        if epoch % ckpt_interval == 0:
+            ckpt_manager.save_checkpoint(ckpt, epoch, latest_dir)
+
+        if epoch_metrics["val_loss"] < best_loss:
+            best_loss = epoch_metrics["val_loss"]
+            ckpt_manager.save_checkpoint(ckpt, epoch, best_dir)
+            early_stopping_counter = 0
+        else:
+            early_stopping_counter += 1
+
+        callbacks.on_epoch_end(epoch=epoch, logs=epoch_metrics)
+
+        epoch_pbar.set_postfix(val_loss=epoch_metrics["val_loss"])
+        epoch_pbar.update()
+
+        if patience is not None and early_stopping_counter >= patience:
+            log.info(
+                "Early stopping patience exceeded. Stopping training after"
+                f" {epoch} epochs."
+            )
+            break
+    epoch_pbar.close()
     callbacks.on_train_end()
