fix #526
nfeybesse committed Mar 4, 2024
1 parent 88a3f69 commit fa29fc7
Showing 1 changed file with 7 additions and 10 deletions.
@@ -42,15 +42,12 @@ public class SoftmaxCrossEntropyWithLogits {
  * <p>Usage:
  *
  * <pre>
- * Operand&lt;TFloat32&gt; logits =
- *     tf.constant(new float[][] {{4.0F, 2.0F, 1.0F}, {0.0F, 5.0F, 1.0F}} );
- * Operand&lt;TFloat32&gt; labels =
- *     tf.constant(new float[][] {{1.0F, 0.0F, 0.0F}, {0.0F, 0.8F, 0.2F}} );
- * Operand&lt;TFloat32&gt; output =
- *     tf.nn.softmaxCrossEntropyWithLogits(labels, logits, -1);
- * // output Shape = [2]
- * // dataType = FLOAT (1)
- * // values { 0.169846, 0.824745 }
+ * Operand&lt;TFloat32&gt; logits = tf.constant(new float[][] { { 4.0F, 2.0F, 1.0F }, { 0.0F, 5.0F, 1.0F } });
+ * Operand&lt;TFloat32&gt; labels = tf.constant(new float[][] { { 1.0F, 0.0F, 0.0F }, { 0.0F, 0.8F, 0.2F } });
+ * Operand&lt;TFloat32&gt; output = tf.nn.softmaxCrossEntropyWithLogits(labels, logits, -1);
+ * // output Shape = [2]
+ * // dataType = FLOAT (1)
+ * // values { 0.169846, 0.824745 }
  * </pre>
  *
  * <p>Backpropagation will happen into both <code>logits</code> and <code>labels</code>. To
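For readers checking the numbers quoted in the Javadoc example above: the per-row loss is -sum(labels * log(softmax(logits))). The standalone snippet below (not part of this commit; the class name SoftmaxCrossEntropyCheck is made up for illustration) reproduces 0.169846 and 0.824745 using only java.lang.Math, without the TensorFlow runtime.

public class SoftmaxCrossEntropyCheck {
  public static void main(String[] args) {
    float[][] logits = {{4.0F, 2.0F, 1.0F}, {0.0F, 5.0F, 1.0F}};
    float[][] labels = {{1.0F, 0.0F, 0.0F}, {0.0F, 0.8F, 0.2F}};
    for (int row = 0; row < logits.length; row++) {
      // Denominator of the softmax for this row.
      double sum = 0;
      for (double l : logits[row]) {
        sum += Math.exp(l);
      }
      // Cross entropy: -sum over classes of labels * log(softmax(logits)).
      double loss = 0;
      for (int i = 0; i < logits[row].length; i++) {
        double softmax = Math.exp(logits[row][i]) / sum;
        loss -= labels[row][i] * Math.log(softmax);
      }
      System.out.println(loss); // ~0.169846 for row 0, ~0.824745 for row 1
    }
  }
}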
@@ -157,7 +154,7 @@ public static <T extends TNumber, U extends TNumber> Operand<T> softmaxCrossEntr
    * @return the flattened logits
    */
   private static <T extends TNumber> Operand<T> flattenOuterDims(Scope scope, Operand<T> logits) {
-    Operand<TInt64> one = Constant.scalarOf(scope, 1L);
+    Operand<TInt64> one = Constant.arrayOf(scope, 1L);
 
     Shape shape = logits.shape();
     int ndims = shape.numDimensions();
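For context on the one-line change above: Constant.scalarOf(scope, 1L) builds a rank-0 (scalar) constant, while Constant.arrayOf(scope, 1L) builds a rank-1 constant of shape [1]. flattenOuterDims assembles a 1-D shape vector for a reshape, so a scalar piece does not fit; that mismatch is presumably what issue #526 hit. Below is a minimal sketch of the rank difference only, not the library's internal code; it assumes the public eager API, where tf.constant(long) and tf.array(long...) are, to my understanding, the Ops-level counterparts of the two Constant factories, and the class name RankDifferenceDemo is illustrative.

import org.tensorflow.EagerSession;
import org.tensorflow.Operand;
import org.tensorflow.op.Ops;
import org.tensorflow.types.TInt64;

public class RankDifferenceDemo {
  public static void main(String[] args) {
    try (EagerSession session = EagerSession.create()) {
      Ops tf = Ops.create(session);

      // Rank-0 (scalar) constant, analogous to Constant.scalarOf(scope, 1L).
      Operand<TInt64> scalarOne = tf.constant(1L);
      // Rank-1 constant of shape [1], analogous to Constant.arrayOf(scope, 1L).
      Operand<TInt64> arrayOne = tf.array(1L);

      System.out.println("scalarOf-style shape: " + scalarOne.shape()); // []
      System.out.println("arrayOf-style shape:  " + arrayOne.shape());  // [1]
    }
  }
}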

0 comments on commit fa29fc7
