Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 7 additions & 7 deletions js/web/test/suite-test-list.jsonc
Original file line number Diff line number Diff line change
Expand Up @@ -2373,11 +2373,11 @@
"test_slice_negative_axes",
"test_slice_start_out_of_bounds",
"test_slice",
// "test_softmax_axis_0_expanded",
"test_softmax_axis_0_expanded",
"test_softmax_axis_0",
// "test_softmax_axis_1_expanded",
"test_softmax_axis_1_expanded",
"test_softmax_axis_1",
// "test_softmax_axis_2_expanded",
"test_softmax_axis_2_expanded",
"test_softmax_axis_2",
// "test_softmax_cross_entropy_input_shape_is_NCd1_mean_weight_negative_ignore_index_expanded",
// "test_softmax_cross_entropy_input_shape_is_NCd1_mean_weight_negative_ignore_index_log_prob_expanded",
Expand Down Expand Up @@ -2447,13 +2447,13 @@
// "test_softmax_cross_entropy_sum_log_prob_expanded",
// "test_softmax_cross_entropy_sum_log_prob",
// "test_softmax_cross_entropy_sum",
// "opset13/test_softmax_default_axis_expanded",
"opset13/test_softmax_default_axis_expanded",
"opset13/test_softmax_default_axis",
// "test_softmax_example_expanded",
"test_softmax_example_expanded",
"test_softmax_example",
// "test_softmax_large_number_expanded",
"test_softmax_large_number_expanded",
"test_softmax_large_number",
// "test_softmax_negative_axis_expanded",
"test_softmax_negative_axis_expanded",
"test_softmax_negative_axis",
// // "test_softplus_example",
// // "test_softplus",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,14 +29,48 @@ Status SoftmaxOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder,
ORT_RETURN_IF_NOT(GetShape(*input_defs[0], input_shape, logger), "Cannot get shape");
const auto input_size = input_shape.size();

emscripten::val options = emscripten::val::object();

NodeAttrHelper helper(node);
const int32_t default_axis = node.SinceVersion() < 13 ? 1 : -1;
const auto since_version = node.SinceVersion();
const int32_t default_axis = since_version < 13 ? 1 : -1;
int32_t axis = helper.Get("axis", default_axis);
axis = static_cast<int32_t>(HandleNegativeAxis(axis, input_size));
axis = SafeInt<int32_t>(HandleNegativeAxis(axis, input_size));

// Prior to opset 13, Softmax operates with different semantics compared to opset 13 and later.
// Specifically, it normalizes over the flattened range of dimensions starting from the specified
// axis to the last dimension.
// In contrast, WebNN's softmax aligns with the behavior introduced in opset 13 and later.
// To handle the differences for earlier opsets, a reshape operation can be applied if necessary.
const bool do_reshape = since_version < 13 && axis != SafeInt<int32_t>(input_size - 1);
std::vector<uint32_t> input_shape_uint32;
if (do_reshape) {
input_shape_uint32 = GetNarrowedIntFromInt64<uint32_t>(input_shape);
// Need to reshape the input to 2D tensor with new shape [M, N].
// M = d0*d1*...*d(axis-1), N = d(axis)*...*d(n-1)
const auto M = Product(std::vector<uint32_t>(input_shape_uint32.begin(), input_shape_uint32.begin() + axis));
const auto N = Product(std::vector<uint32_t>(input_shape_uint32.begin() + axis, input_shape_uint32.end()));
emscripten::val new_shape = emscripten::val::array();
new_shape.set(0, M);
new_shape.set(1, N);

options.set("label", node.Name() + "_reshape_input");
input = model_builder.GetBuilder().call<emscripten::val>("reshape", input, new_shape, options);
// Apply softmax along the last dimension (N).
axis = 1;
}

emscripten::val options = emscripten::val::object();
options.set("label", node.Name());
emscripten::val output = model_builder.GetBuilder().call<emscripten::val>("softmax", input, axis, options);

if (do_reshape) {
// Softmax has the same output shape as input shape.
// Reshape the output back to the original input shape.
options.set("label", node.Name() + "_reshape_output");
output = model_builder.GetBuilder().call<emscripten::val>(
"reshape", output, emscripten::val::array(input_shape_uint32), options);
}

model_builder.AddOperand(node.OutputDefs()[0]->Name(), std::move(output));
return Status::OK();
}
Expand Down