All tests actually still pass for me locally if I remove our manual `trainmode!`/`testmode!` usage. Do we really need to manipulate `trainmode!`/`testmode!` manually, then, or is Flux's default `istraining` behavior (which considers `BatchNorm` et al. to be "active" during the forward pass only if the forward pass is executed within Zygote) sufficient?
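To make that concrete, here's a minimal standalone sketch (not LighthouseFlux code; exact behavior depends on the Flux version) of what the default mechanism does: a `BatchNorm` layer that has never been touched by `trainmode!`/`testmode!` behaves as inactive in a plain forward pass, but as active when its forward pass runs under `Zygote.pullback`:

```julia
using Flux, Zygote

bn = BatchNorm(2)          # left in Flux's automatic mode; no manual toggles
x = randn(Float32, 2, 8)

# Plain forward pass: istraining() is false, so the layer normalizes
# with its running statistics (test-mode behavior).
y_test = bn(x)

# Forward pass executed within Zygote: istraining() is true, so the layer
# uses the batch's own statistics and updates its running mean/variance.
y_train, back = Zygote.pullback(() -> sum(bn(x)), Flux.params(bn))
```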
If we do need `trainmode!`/`testmode!` manually, we probably want to make sure we still handle it correctly in the event of an exception:
```diff
diff --git a/src/LighthouseFlux.jl b/src/LighthouseFlux.jl
index 55ccc2b..d5af80b 100644
--- a/src/LighthouseFlux.jl
+++ b/src/LighthouseFlux.jl
@@ -85,24 +85,27 @@ end
 function Lighthouse.train!(classifier::FluxClassifier, batches, logger)
     Flux.trainmode!(classifier.model)
-    weights = Zygote.Params(classifier.params)
-    for batch in batches
-        train_loss, back = log_resource_info!(logger, "train/forward_pass";
-                                              suffix="_per_batch") do
-            f = () -> loss(classifier.model, batch...)
-            return Zygote.pullback(f, weights)
-        end
-        log_value!(logger, "train/loss_per_batch", train_loss)
-        gradients = log_resource_info!(logger, "train/reverse_pass";
-                                       suffix="_per_batch") do
-            return back(Zygote.sensitivity(train_loss))
-        end
-        log_resource_info!(logger, "train/update"; suffix="_per_batch") do
-            Flux.Optimise.update!(classifier.optimiser, weights, gradients)
-            return nothing
+    try
+        weights = Zygote.Params(classifier.params)
+        for batch in batches
+            train_loss, back = log_resource_info!(logger, "train/forward_pass";
+                                                  suffix="_per_batch") do
+                f = () -> loss(classifier.model, batch...)
+                return Zygote.pullback(f, weights)
+            end
+            log_value!(logger, "train/loss_per_batch", train_loss)
+            gradients = log_resource_info!(logger, "train/reverse_pass";
+                                           suffix="_per_batch") do
+                return back(Zygote.sensitivity(train_loss))
+            end
+            log_resource_info!(logger, "train/update"; suffix="_per_batch") do
+                Flux.Optimise.update!(classifier.optimiser, weights, gradients)
+                return nothing
+            end
         end
+    finally
+        Flux.testmode!(classifier.model)
     end
-    Flux.testmode!(classifier.model)
     return nothing
 end
```
If it's the case that explicit `trainmode!`/`testmode!` calls don't actually buy us anything in the current state of model training, then using the default `istraining()` mechanism is sufficient.
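Concretely, that simplification would amount to dropping both mode toggles (and with them the need for the `try`/`finally`) from `train!`. A sketch of the result, reusing the body from the diff above:

```julia
function Lighthouse.train!(classifier::FluxClassifier, batches, logger)
    weights = Zygote.Params(classifier.params)
    for batch in batches
        train_loss, back = log_resource_info!(logger, "train/forward_pass";
                                              suffix="_per_batch") do
            # The forward pass executes within Zygote.pullback, so layers left
            # in automatic mode (BatchNorm et al.) are active here by default.
            f = () -> loss(classifier.model, batch...)
            return Zygote.pullback(f, weights)
        end
        log_value!(logger, "train/loss_per_batch", train_loss)
        gradients = log_resource_info!(logger, "train/reverse_pass";
                                       suffix="_per_batch") do
            return back(Zygote.sensitivity(train_loss))
        end
        log_resource_info!(logger, "train/update"; suffix="_per_batch") do
            Flux.Optimise.update!(classifier.optimiser, weights, gradients)
            return nothing
        end
    end
    return nothing
end
```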