Commit

fix: dropout should init primitive (#2788)
i8run authored Mar 26, 2019
1 parent 1298c80 commit 438861c
Showing 2 changed files with 15 additions and 0 deletions.

@@ -49,6 +49,7 @@ class Dropout(
     _gradOutputFormats = grad.map(x => HeapData(x.shape, format(x.shape)))
     _gradOutputFormatsForWeight = grad.map(x => HeapData(x.shape, format(x.shape)))
     _gradInputFormats = grad.map(x => HeapData(x.shape, format(x.shape)))
+    _gradInputFormats.map(_.getPrimitive(runtime))
     gradInput = initTensor(_gradInputFormats.head)
     (_gradOutputFormats, _gradInputFormats)
   }
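The added line forces getPrimitive(runtime) on every entry of _gradInputFormats, so each backward-input HeapData presumably creates its underlying MKL-DNN memory primitive before initTensor and the later backward pass rely on it. Below is a minimal, self-contained sketch of that lazy-initialization pattern; MemoryDesc, createNativePrimitive, and LazyPrimitiveDemo are hypothetical stand-ins, not BigDL's actual HeapData or runtime types.

// Illustrative sketch only: hypothetical names, not the BigDL implementation.
class MemoryDesc(val shape: Array[Int]) {
  private var primitiveHandle: Long = 0L // 0L means "native primitive not created yet"

  // Creates the native memory primitive on first use and caches the handle.
  def getPrimitive(runtime: Long): Long = {
    if (primitiveHandle == 0L) {
      primitiveHandle = createNativePrimitive(runtime, shape)
    }
    primitiveHandle
  }

  // Stand-in for the real native call that would allocate the MKL-DNN primitive.
  private def createNativePrimitive(runtime: Long, shape: Array[Int]): Long =
    runtime + shape.sum
}

object LazyPrimitiveDemo extends App {
  val runtime = 42L
  val gradInputFormats = Array(new MemoryDesc(Array(2, 3, 4, 4)))

  // Mirrors the fix above: touch getPrimitive for every gradInput format so the
  // handle exists before anything downstream asks for it.
  gradInputFormats.map(_.getPrimitive(runtime))
}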

@@ -67,4 +67,18 @@ class DropoutSpec extends FlatSpec with Matchers {
     val ratio = notEqZeros.toDouble / total
     ratio should be (1.0)
   }
+
+  "dropout in sequential" should "work correctly" in {
+    val shape = Array(2, 3, 4, 4)
+    val dropout = Dropout()
+    val seq = Sequential().add(Input(shape, Memory.Format.nchw))
+      .add(dropout)
+      .add(Output(Memory.Format.nchw))
+
+    seq.compile(TrainingPhase)
+
+    val input = Tensor[Float](shape).rand(-1, 1)
+    seq.forward(input)
+    seq.backward(input, seq.output)
+  }
 }
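Design-wise, the new spec runs the layer through a compiled Sequential (Input → Dropout → Output, compiled for TrainingPhase) rather than in isolation, so both forward and backward exercise the primitive hand-off between layers; presumably the backward call is where an uninitialized gradInput primitive would previously have surfaced.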
