Kliff master v1 lightning #182
Conversation
…r stress in training loss
…ckpt resume capabilities + docstrings
optimizer_name: Name of the optimizer to use. Default is "Adam".
lr: Learning rate for the optimizer. Default is 0.001.
energy_weight: Weight for the energy loss. Default is 1.0.
forces_weight: Weight for the forces loss. Default is 1.0.
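For context, here is a minimal sketch of how these docstring parameters might be wired into a LightningModule. The class name `PotentialModule` and its structure are illustrative assumptions, not the actual KLIFF trainer code:

```python
# Hypothetical sketch: consuming optimizer_name, lr, energy_weight, forces_weight.
import torch
import pytorch_lightning as pl


class PotentialModule(pl.LightningModule):  # hypothetical class name
    def __init__(self, model, optimizer_name="Adam", lr=0.001,
                 energy_weight=1.0, forces_weight=1.0):
        super().__init__()
        self.model = model
        self.optimizer_name = optimizer_name
        self.lr = lr
        self.energy_weight = energy_weight
        self.forces_weight = forces_weight

    def configure_optimizers(self):
        # Look up the optimizer class by name, e.g. "Adam" -> torch.optim.Adam.
        optimizer_cls = getattr(torch.optim, self.optimizer_name)
        return optimizer_cls(self.model.parameters(), lr=self.lr)
```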
How about `stress_weight`?
TODO: once the stress issue is fixed.
+ forces_weight * torch.mean(per_atom_force_loss) / 3
)  # divide by 3 to get correct MSE

self.log(
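For readers of the thread, here is a small self-contained sketch of the weighted loss being discussed, assuming `per_atom_force_loss` holds the sum of squared errors over the three force components of each atom (so dividing its mean by 3 gives a per-component MSE). The function is illustrative, not the actual implementation; a `stress_weight` term could be added analogously once the stress issue above is fixed:

```python
# Illustrative sketch of the weighted energy + forces loss (not KLIFF's code).
import torch


def weighted_loss(pred_energy, target_energy, pred_forces, target_forces,
                  energy_weight=1.0, forces_weight=1.0):
    # Mean squared error on the total energy.
    energy_loss = torch.mean((pred_energy - target_energy) ** 2)
    # Per-atom sum of squared errors over the 3 force components.
    per_atom_force_loss = torch.sum((pred_forces - target_forces) ** 2, dim=1)
    return (
        energy_weight * energy_loss
        + forces_weight * torch.mean(per_atom_force_loss) / 3  # /3 -> per-component MSE
    )
```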
Besides the sum of the losses, people would typically be more interested in the separate losses on energy and forces.
Also, people can be interested in metrics other than MSE, such as MAE.
Most importantly, it is not interesting to know the losses/metrics of a validation step for a single batch of data; the losses/metrics over all the data at each epoch are what matter. This requires aggregating results across steps. I'd suggest using torchmetrics, which does this automatically.
So, we need to report separate losses and allow other metrics, e.g. using the torchmetrics package.
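As a rough illustration of this suggestion, a sketch of how torchmetrics could aggregate validation metrics across batches and log per-epoch values. The metric names, logging keys, batch layout, and forward signature are assumptions, not existing KLIFF code:

```python
# Sketch: per-epoch validation metrics via torchmetrics (assumed names/structure).
import pytorch_lightning as pl
from torchmetrics import MeanAbsoluteError, MeanSquaredError


class ValidationMetricsModule(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.val_energy_mse = MeanSquaredError()
        self.val_forces_mae = MeanAbsoluteError()

    def validation_step(self, batch, batch_idx):
        pred_energy, pred_forces = self(batch)  # assumed forward signature
        self.val_energy_mse.update(pred_energy, batch["energy"])
        self.val_forces_mae.update(pred_forces, batch["forces"])
        # Logging the metric objects lets Lightning/torchmetrics handle the
        # cross-batch (and cross-device) aggregation and reset once per epoch.
        self.log("val/energy_mse", self.val_energy_mse, on_step=False, on_epoch=True)
        self.log("val/forces_mae", self.val_forces_mae, on_step=False, on_epoch=True)
```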
torchmetrics looks great. I will add support for it.
Can you explain this a bit:
> Most importantly, it is not interesting to know the losses/metrics of a validation step for a single batch of data; the losses/metrics over all the data at each epoch are what matter.
Currently, the validation loss/metrics are evaluated and reported for each mini-batch of data, which is only part of the validation set. But people generally would be interested in the loss/metrics over all the validation/test data, not each mini-batch.
I don't think so. In logging, when we give `on_epoch=True` and `on_step=False`, then I believe the logger will call the `after_batch_end` function to simply accumulate the loss and only log it in the `after_epoch_end` function. In the log files, too, I could see that validation losses are logged per epoch, not per batch. This is similar to what I am doing in the loss trajectory callback, where `on_validation_batch_end` just gathers the results and `on_validation_epoch_end` writes them to a file. I can recheck the API, but I am quite certain that is the case.
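A hedged sketch of the accumulation pattern described above, i.e. a Lightning `Callback` that gathers per-batch validation losses in `on_validation_batch_end` and writes the per-epoch aggregate in `on_validation_epoch_end`. The output file name and the `"loss"` key in `outputs` are assumptions about what `validation_step` returns, not the actual KLIFF callback:

```python
# Sketch of a gather-per-batch, write-per-epoch validation callback (assumed details).
import pytorch_lightning as pl


class LossTrajectoryCallback(pl.Callback):
    def __init__(self, out_file="loss_trajectory.txt"):
        self.out_file = out_file
        self._batch_losses = []

    def on_validation_batch_end(self, trainer, pl_module, outputs, batch,
                                batch_idx, dataloader_idx=0):
        # Just gather the per-batch result here...
        self._batch_losses.append(float(outputs["loss"]))

    def on_validation_epoch_end(self, trainer, pl_module):
        # ...and only write the aggregated value once per epoch.
        if self._batch_losses:
            epoch_loss = sum(self._batch_losses) / len(self._batch_losses)
            with open(self.out_file, "a") as f:
                f.write(f"{trainer.current_epoch} {epoch_loss}\n")
            self._batch_losses.clear()
```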
@ipcamit Nothing major, but a few clarifying questions and minor tweaks.
Addressed most comments. Need some more time and a meeting with Josh to concretize the loss trajectory. Will address the descriptor dataloader issues with the torch trainer, which uses the descriptor module, so it will be easier to see the design choices.
Summary
Added the base trainer and Lightning trainer, along with new tests for the trainer module.
Additional dependencies introduced (if any)
TODO (if any)
Add KIM Trainer next
Checklist
Before a pull request can be merged, the following items must be checked:
Type check your code.
Note that the CI system will run all the above checks, but it will be much more efficient if you fix most errors before submitting the PR.