usage of trainmode!/testmode! #6

Open · jrevels opened this issue Apr 1, 2020 · 1 comment
jrevels (Member) commented Apr 1, 2020

All tests actually still pass for me locally if I remove our manual trainmode!/testmode! usage. Do we really need to manually manipulate trainmode!/testmode!, then, or is Flux's default istraining behavior (which considers BatchNorm et al. to be "active" during forward pass only if forward pass is executed within Zygote) sufficient?
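For illustration (not from the issue itself), a minimal sketch of that default behavior: a freshly constructed BatchNorm has active == nothing, so Flux falls back to istraining(), which is false in a plain forward pass and true when the forward pass runs under Zygote:

using Flux, Zygote

bn = BatchNorm(3)
x = randn(Float32, 3, 10)

y_test = bn(x)                       # plain call: normalizes with running stats ("test" behavior)
y_train, _ = Zygote.pullback(bn, x)  # forward pass inside Zygote: uses batch stats ("train" behavior)

@show maximum(abs.(y_test .- y_train))  # nonzero: the two code paths really do differ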

If we do need trainmode!/testmode! manually, we probably want to make sure we still handle it correctly in the event of an exception:

diff --git a/src/LighthouseFlux.jl b/src/LighthouseFlux.jl
index 55ccc2b..d5af80b 100644
--- a/src/LighthouseFlux.jl
+++ b/src/LighthouseFlux.jl
@@ -85,24 +85,27 @@ end

 function Lighthouse.train!(classifier::FluxClassifier, batches, logger)
     Flux.trainmode!(classifier.model)
-    weights = Zygote.Params(classifier.params)
-    for batch in batches
-        train_loss, back = log_resource_info!(logger, "train/forward_pass";
-                                              suffix="_per_batch") do
-            f = () -> loss(classifier.model, batch...)
-            return Zygote.pullback(f, weights)
-        end
-        log_value!(logger, "train/loss_per_batch", train_loss)
-        gradients = log_resource_info!(logger, "train/reverse_pass";
-                                       suffix="_per_batch") do
-            return back(Zygote.sensitivity(train_loss))
-        end
-        log_resource_info!(logger, "train/update"; suffix="_per_batch") do
-            Flux.Optimise.update!(classifier.optimiser, weights, gradients)
-            return nothing
+    try
+        weights = Zygote.Params(classifier.params)
+        for batch in batches
+            train_loss, back = log_resource_info!(logger, "train/forward_pass";
+                                                  suffix="_per_batch") do
+                f = () -> loss(classifier.model, batch...)
+                return Zygote.pullback(f, weights)
+            end
+            log_value!(logger, "train/loss_per_batch", train_loss)
+            gradients = log_resource_info!(logger, "train/reverse_pass";
+                                           suffix="_per_batch") do
+                return back(Zygote.sensitivity(train_loss))
+            end
+            log_resource_info!(logger, "train/update"; suffix="_per_batch") do
+                Flux.Optimise.update!(classifier.optimiser, weights, gradients)
+                return nothing
+            end
         end
+    finally
+        Flux.testmode!(classifier.model)
     end
-    Flux.testmode!(classifier.model)
     return nothing
 end
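For readability, here is what the train! body would look like with that diff applied (assembled directly from the hunks above; the finally block guarantees the model is returned to test mode even if a batch throws):

function Lighthouse.train!(classifier::FluxClassifier, batches, logger)
    Flux.trainmode!(classifier.model)
    try
        weights = Zygote.Params(classifier.params)
        for batch in batches
            train_loss, back = log_resource_info!(logger, "train/forward_pass";
                                                  suffix="_per_batch") do
                f = () -> loss(classifier.model, batch...)
                return Zygote.pullback(f, weights)
            end
            log_value!(logger, "train/loss_per_batch", train_loss)
            gradients = log_resource_info!(logger, "train/reverse_pass";
                                           suffix="_per_batch") do
                return back(Zygote.sensitivity(train_loss))
            end
            log_resource_info!(logger, "train/update"; suffix="_per_batch") do
                Flux.Optimise.update!(classifier.optimiser, weights, gradients)
                return nothing
            end
        end
    finally
        Flux.testmode!(classifier.model)
    end
    return nothing
end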
anoojpatel (Contributor) commented
If it's the case that the explicit trainmode!/testmode! calls don't actually change anything in the current state of model training, then relying on the default istraining() mechanism is sufficient.
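As an aside, if a model has already been pinned with a manual trainmode!/testmode! call, Flux documents an :auto mode for restoring that default; a sketch, assuming the two-argument testmode! form:

using Flux

model = Chain(Dense(4, 4), BatchNorm(4))

Flux.testmode!(model)         # pin every normalization layer to test behavior
Flux.testmode!(model, :auto)  # restore the default: active only under Zygote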
