Skip to content

Commit

Permalink
Optimise Wengert tape for LSTM
Browse files Browse the repository at this point in the history
  • Loading branch information
HuwCampbell committed Dec 16, 2017
1 parent 4332d62 commit a588151
Show file tree
Hide file tree
Showing 6 changed files with 47 additions and 88 deletions.
41 changes: 1 addition & 40 deletions bench/bench-lstm.hs
Original file line number Diff line number Diff line change
Expand Up @@ -5,61 +5,22 @@ import Criterion.Main

import Grenade
import Grenade.Recurrent
import Grenade.Layers.Internal.Update

import qualified Numeric.LinearAlgebra.Static as H

main :: IO ()
main = do
layer60 :: LSTM 40 60 <- createRandom
layer512 :: LSTM 40 512 <- createRandom
input40 :: S ('D1 40) <- randomOfShape
rec60 :: S ('D1 60) <- randomOfShape
rec512 :: S ('D1 512) <- randomOfShape
lstm :: RecNet <- randomRecurrent

let upIn60 :: H.R 3600 = H.randomVector 1 H.Uniform * 2 - 1
let upIn512 :: H.R 262144 = H.randomVector 1 H.Uniform * 2 - 1

defaultMain [
bgroup "lstm" [ bench "forwards-60" $ nf (nfT3 . uncurry (testRun60 layer60)) (rec60, input40)
, bench "forwards-512" $ nf (nfT3 . uncurry (testRun512 layer512)) (rec512, input40)
, bench "backwards-60" $ nf (nfT3 . uncurry4 (testRun60' layer60)) (rec60, input40, rec60, rec60)
, bench "backwards-512" $ nf (nfT3 . uncurry4 (testRun512' layer512)) (rec512, input40, rec512, rec512)
]
, bgroup "update" [ bench "matrix-60x60" $ nf (uncurry3 (descendVector 1 1 1)) (upIn60, upIn60, upIn60)
, bench "matrix-512x512" $ nf (uncurry3 (descendVector 1 1 1)) (upIn512, upIn512, upIn512)
]
, bgroup "train" [ bench "one-time-step" $ whnf (nfT2 . trainRecurrent lp lstm 0) [(input40, Just input40)]
bgroup "train" [ bench "one-time-step" $ whnf (nfT2 . trainRecurrent lp lstm 0) [(input40, Just input40)]
, bench "ten-time-steps" $ whnf (nfT2 . trainRecurrent lp lstm 0) $ replicate 10 (input40, Just input40)
, bench "fifty-time-steps" $ whnf (nfT2 . trainRecurrent lp lstm 0) $ replicate 50 (input40, Just input40)
]
]

testRun60 :: LSTM 40 60 -> S ('D1 60) -> S ('D1 40) -> ((S ('D1 60), S ('D1 40)), S ('D1 60), S ('D1 60))
testRun60 = runRecurrentForwards

testRun60' :: LSTM 40 60 -> S ('D1 60) -> S ('D1 40) -> S ('D1 60) -> S ('D1 60) -> (Gradient (LSTM 40 60), S ('D1 60), S ('D1 40))
testRun60' = curry . runRecurrentBackwards

testRun512 :: LSTM 40 512 -> S ('D1 512) -> S ('D1 40) -> ((S ('D1 512), S ('D1 40)), S ('D1 512), S ('D1 512))
testRun512 = runRecurrentForwards

testRun512' :: LSTM 40 512 -> S ('D1 512) -> S ('D1 40) -> S ('D1 512) -> S ('D1 512) -> (Gradient (LSTM 40 512), S ('D1 512), S ('D1 40))
testRun512' = curry . runRecurrentBackwards

uncurry4 :: (t -> t1 -> t2 -> t3 -> t4) -> (t, t1, t2, t3) -> t4
uncurry4 f (a,b,c,d) = f a b c d

uncurry3 :: (t -> t1 -> t2 -> t3) -> (t, t1, t2) -> t3
uncurry3 f (a,b,c) = f a b c

nfT2 :: (a, b) -> (a, b)
nfT2 (!a, !b) = (a, b)

nfT3 :: (a, b, c) -> (b, c)
nfT3 (!_, !b, !c) = (b, c)


type R = Recurrent
type RecNet = RecurrentNetwork '[ R (LSTM 40 512), R (LSTM 512 40) ]
Expand Down
6 changes: 3 additions & 3 deletions src/Grenade/Layers/Elu.hs
Original file line number Diff line number Diff line change
Expand Up @@ -36,12 +36,12 @@ instance Serialize Elu where
get = return Elu

instance ( KnownNat i) => Layer Elu ('D1 i) ('D1 i) where
type Tape Elu ('D1 i) ('D1 i) = S ('D1 i)
type Tape Elu ('D1 i) ('D1 i) = LAS.R i

runForwards _ (S1D y) = (S1D y, S1D (elu y))
runForwards _ (S1D y) = (y, S1D (elu y))
where
elu = LAS.dvmap (\a -> if a <= 0 then exp a - 1 else a)
runBackwards _ (S1D y) (S1D dEdy) = ((), S1D (elu' y * dEdy))
runBackwards _ y (S1D dEdy) = ((), S1D (elu' y * dEdy))
where
elu' = LAS.dvmap (\a -> if a <= 0 then exp a else 1)

Expand Down
6 changes: 3 additions & 3 deletions src/Grenade/Layers/FullyConnected.hs
Original file line number Diff line number Diff line change
Expand Up @@ -46,12 +46,12 @@ instance (KnownNat i, KnownNat o) => UpdateLayer (FullyConnected i o) where
createRandom = randomFullyConnected

instance (KnownNat i, KnownNat o) => Layer (FullyConnected i o) ('D1 i) ('D1 o) where
type Tape (FullyConnected i o) ('D1 i) ('D1 o) = S ('D1 i)
type Tape (FullyConnected i o) ('D1 i) ('D1 o) = R i
-- Do a matrix vector multiplication and return the result.
runForwards (FullyConnected (FullyConnected' wB wN) _) (S1D v) = (S1D v, S1D (wB + wN #> v))
runForwards (FullyConnected (FullyConnected' wB wN) _) (S1D v) = (v, S1D (wB + wN #> v))

-- Run a backpropogation step for a full connected layer.
runBackwards (FullyConnected (FullyConnected' _ wN) _) (S1D x) (S1D dEdy) =
runBackwards (FullyConnected (FullyConnected' _ wN) _) x (S1D dEdy) =
let wB' = dEdy
mm' = dEdy `outer` x
-- calcluate derivatives for next step
Expand Down
24 changes: 15 additions & 9 deletions src/Grenade/Layers/Logit.hs
Original file line number Diff line number Diff line change
Expand Up @@ -34,18 +34,24 @@ instance UpdateLayer Logit where
createRandom = return Logit

instance (a ~ b, SingI a) => Layer Logit a b where
-- Wengert tape optimisation:
--
-- Derivative of the sigmoid function is
-- d σ(x) / dx = σ(x) • (1 - σ(x))
-- but we have already calculated σ(x) in
-- the forward pass, so just store that
-- and use it in the backwards pass.
type Tape Logit a b = S a
runForwards _ a = (a, logistic a)
runBackwards _ a g = ((), logistic' a * g)
runForwards _ a =
let l = sigmoid a
in (l, l)
runBackwards _ l g =
let sigmoid' = l * (1 - l)
in ((), sigmoid' * g)

instance Serialize Logit where
put _ = return ()
get = return Logit

logistic :: Floating a => a -> a
logistic x = 1 / (1 + exp (-x))

logistic' :: Floating a => a -> a
logistic' x = logix * (1 - logix)
where
logix = logistic x
sigmoid :: Floating a => a -> a
sigmoid x = 1 / (1 + exp (-x))
46 changes: 16 additions & 30 deletions src/Grenade/Recurrent/Layers/LSTM.hs
Original file line number Diff line number Diff line change
Expand Up @@ -127,37 +127,12 @@ instance (KnownNat i, KnownNat o) => RecurrentUpdateLayer (LSTM i o) where

instance (KnownNat i, KnownNat o) => RecurrentLayer (LSTM i o) ('D1 i) ('D1 o) where

type RecTape (LSTM i o) ('D1 i) ('D1 o) = (S ('D1 o), S ('D1 i))
-- The tape stores essentially every variable we calculate,
-- so we don't have to run any forwards component again.
type RecTape (LSTM i o) ('D1 i) ('D1 o) = (R o, R i, R o, R o, R o, R o, R o, R o, R o, R o, R o)
-- Forward propagation for the LSTM layer.
-- The size of the cell state is also the size of the output.
runRecurrentForwards (LSTM (LSTMWeights {..}) _) (S1D cell) (S1D input) =
let -- Forget state vector
f_t = sigmoid $ lstmBf + lstmWf #> input + lstmUf #> cell
-- Input state vector
i_t = sigmoid $ lstmBi + lstmWi #> input + lstmUi #> cell
-- Output state vector
o_t = sigmoid $ lstmBo + lstmWo #> input + lstmUo #> cell
-- Cell input state vector
c_x = tanh $ lstmBc + lstmWc #> input
-- Cell state
c_t = f_t * cell + i_t * c_x
-- Output (it's sometimes recommended to use tanh c_t)
h_t = o_t * c_t
in ((S1D cell, S1D input), S1D c_t, S1D h_t)

-- Run a backpropogation step for an LSTM layer.
-- We're doing all the derivatives by hand here, so one should
-- be extra careful when changing this.
--
-- There's a test version using the AD library without hmatrix in the test
-- suite. These should match always.
runRecurrentBackwards (LSTM (LSTMWeights {..}) _) (S1D cell, S1D input) (S1D cellGrad) (S1D h_t') =
-- We're not keeping the Wengert tape during the forward pass,
-- so we're duplicating some work here.
--
-- If I was being generous, I'd call it checkpointing.
--
-- Maybe think about better ways to store some intermediate states.
let -- Forget state vector
f_s = lstmBf + lstmWf #> input + lstmUf #> cell
f_t = sigmoid f_s
Expand All @@ -172,8 +147,18 @@ instance (KnownNat i, KnownNat o) => RecurrentLayer (LSTM i o) ('D1 i) ('D1 o) w
c_x = tanh c_s
-- Cell state
c_t = f_t * cell + i_t * c_x
-- Output (it's sometimes recommended to use tanh c_t)
h_t = o_t * c_t
in ((cell, input, f_s, f_t, i_s, i_t, o_s, o_t, c_s, c_x, c_t), S1D c_t, S1D h_t)

-- Reverse Mode AD Derivitives
-- Run a backpropogation step for an LSTM layer.
-- We're doing all the derivatives by hand here, so one should
-- be extra careful when changing this.
--
-- There's a test version using the AD library without hmatrix in the test
-- suite. These should match always.
runRecurrentBackwards (LSTM (LSTMWeights {..}) _) (cell, input, f_s, f_t, i_s, i_t, o_s, o_t, c_s, c_x, c_t) (S1D cellGrad) (S1D h_t') =
let -- Reverse Mode AD Derivitives
c_t' = h_t' * o_t + cellGrad

f_t' = c_t' * cell
Expand Down Expand Up @@ -235,7 +220,8 @@ randomLSTM = do

-- | Maths
--
-- TODO: move to not here
-- TODO: Move to not here
-- Optimise backwards derivative
sigmoid :: Floating a => a -> a
sigmoid x = 1 / (1 + exp (-x))

Expand Down
12 changes: 9 additions & 3 deletions test/Test/Grenade/Recurrent/Layers/LSTM.hs
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,9 @@ prop_lstm_reference_backwards =
input :: S.R 3 <- forAll randomVector
cell :: S.R 2 <- forAll randomVector
net@(LSTM lstmWeights _) :: LSTM 3 2 <- forAll genLSTM
let actualBacks = runRecurrentBackwards net (S1D cell, S1D input) (S1D (S.konst 1) :: S ('D1 2)) (S1D (S.konst 1) :: S ('D1 2))
let (tape, _ :: S ('D1 2), _ :: S ('D1 2))
= runRecurrentForwards net (S1D cell) (S1D input)
actualBacks = runRecurrentBackwards net tape (S1D (S.konst 1) :: S ('D1 2)) (S1D (S.konst 1) :: S ('D1 2))
case actualBacks of
(actualGradients, _, _ :: S ('D1 3)) ->
let refNet = Reference.lstmToReference lstmWeights
Expand All @@ -79,7 +81,9 @@ prop_lstm_reference_backwards_input =
input :: S.R 3 <- forAll randomVector
cell :: S.R 2 <- forAll randomVector
net@(LSTM lstmWeights _) :: LSTM 3 2 <- forAll genLSTM
let actualBacks = runRecurrentBackwards net (S1D cell, S1D input) (S1D (S.konst 1) :: S ('D1 2)) (S1D (S.konst 1) :: S ('D1 2))
let (tape, _ :: S ('D1 2), _ :: S ('D1 2))
= runRecurrentForwards net (S1D cell) (S1D input)
actualBacks = runRecurrentBackwards net tape (S1D (S.konst 1) :: S ('D1 2)) (S1D (S.konst 1) :: S ('D1 2))
case actualBacks of
(_, _, S1D actualGradients :: S ('D1 3)) ->
let refNet = Reference.lstmToReference lstmWeights
Expand All @@ -93,7 +97,9 @@ prop_lstm_reference_backwards_cell =
input :: S.R 3 <- forAll randomVector
cell :: S.R 2 <- forAll randomVector
net@(LSTM lstmWeights _) :: LSTM 3 2 <- forAll genLSTM
let actualBacks = runRecurrentBackwards net (S1D cell, S1D input) (S1D (S.konst 1) :: S ('D1 2)) (S1D (S.konst 1) :: S ('D1 2))
let (tape, _ :: S ('D1 2), _ :: S ('D1 2))
= runRecurrentForwards net (S1D cell) (S1D input)
actualBacks = runRecurrentBackwards net tape (S1D (S.konst 1) :: S ('D1 2)) (S1D (S.konst 1) :: S ('D1 2))
case actualBacks of
(_, S1D actualGradients, _ :: S ('D1 3)) ->
let refNet = Reference.lstmToReference lstmWeights
Expand Down

0 comments on commit a588151

Please sign in to comment.