Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement MultilabelSoftMarginLoss for small C #3451

Open
wants to merge 25 commits into
base: develop
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
8442b94
Temporary
littlecutebird Jun 11, 2024
32719a4
fix bug op
littlecutebird Jun 18, 2024
997bf00
remove backward + add fw condition
littlecutebird Jun 18, 2024
6201804
add gtest
littlecutebird Jun 18, 2024
dddecee
edit gtest input value to avoid fp precision issue
littlecutebird Jun 18, 2024
da114ec
remove redundant code
littlecutebird Jun 19, 2024
7d2ecff
resolve comments
littlecutebird Jun 20, 2024
f4ae345
receiving reduction mode in MultilabelSoftMarginLossTestConfigs
littlecutebird Jun 21, 2024
e9f959f
use FLOAT for kernel
littlecutebird Jun 24, 2024
9f6b6ad
change method to generate uncont tensor
littlecutebird Jun 24, 2024
011e43e
remove stride flag in driver
littlecutebird Jun 26, 2024
b196799
Merge branch 'develop-moreh' into nl-impl_MultilabelSoftMarginLoss
littlecutebird Dec 30, 2024
289fa7d
update after merge
littlecutebird Dec 30, 2024
9b1442d
update api, solver and gtest
littlecutebird Dec 30, 2024
7926eb3
update driver
littlecutebird Dec 30, 2024
20f2c8c
Merge remote-tracking branch 'upstream/develop' into nl-impl_Multilab…
littlecutebird Dec 30, 2024
930e8e2
remove some files
littlecutebird Dec 30, 2024
c983ab7
readd gitignore
littlecutebird Dec 30, 2024
8668e6a
update IsImprovementOverROCm
littlecutebird Dec 31, 2024
5384881
Merge remote-tracking branch 'upstream/develop' into nl-impl_Multilab…
littlecutebird Dec 31, 2024
ce2359b
update gtest name
littlecutebird Dec 31, 2024
847772d
Merge remote-tracking branch 'upstream/develop' into nl-impl_Multilab…
littlecutebird Jan 2, 2025
fbe1247
Merge branch 'develop' into nl-impl_MultilabelSoftMarginLoss
littlecutebird Jan 8, 2025
4da10d9
Merge branch 'develop' into nl-impl_MultilabelSoftMarginLoss
long10024070 Jan 13, 2025
5e0ec92
Merge branch 'develop' into nl-impl_MultilabelSoftMarginLoss
littlecutebird Jan 15, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions driver/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ add_executable(MIOpenDriver
dm_glu.cpp
dm_groupnorm.cpp
dm_kthvalue.cpp
dm_multilabelsoftmarginloss.cpp
dm_layernorm.cpp
dm_lrn.cpp
dm_multimarginloss.cpp
Expand Down
40 changes: 40 additions & 0 deletions driver/dm_multilabelsoftmarginloss.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
/*******************************************************************************
*
* MIT License
*
* Copyright (c) 2024 Advanced Micro Devices, Inc.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in all
* copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
* SOFTWARE.
*
*******************************************************************************/
#include "multilabelsoftmarginloss_driver.hpp"
#include "registry_driver_maker.hpp"

static Driver* makeDriver(const std::string& base_arg)
{
if(base_arg == "multilabelsoftmarginloss")
return new MultilabelSoftMarginLossDriver<float, double>();
if(base_arg == "multilabelsoftmarginlossfp16")
return new MultilabelSoftMarginLossDriver<float16, double>();
if(base_arg == "multilabelsoftmarginlossbfp16")
return new MultilabelSoftMarginLossDriver<bfloat16, double>();
return nullptr;
}

REGISTER_DRIVER_MAKER(makeDriver);
6 changes: 4 additions & 2 deletions driver/driver.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -314,7 +314,8 @@ inline void PadBufferSize(size_t& sz, int datatype_sz)
"adamw[fp16], ampadamw, transformersadamw[fp16], transformersampadamw, "
"getitem[bfp16|fp16], reducecalculation[bfp16|fp16], rope[bfp16|fp16], "
"prelu[bfp16|fp16], kthvalue[bfp16|fp16], glu[bfp16|fp16], softmarginloss[bfp16|fp16], "
"multimarginloss[bfp16|fp16]\n");
"multimarginloss[bfp16|fp16], "
"multilabelsoftmarginloss[bfp16|fp16]\n");
exit(0); // NOLINT (concurrency-mt-unsafe)
}

Expand Down Expand Up @@ -352,7 +353,8 @@ inline std::string ParseBaseArg(int argc, char* argv[])
arg != "kthvaluebfp16" && arg != "glu" && arg != "glufp16" && arg != "glubfp16" &&
arg != "softmarginloss" && arg != "softmarginlossfp16" && arg != "softmarginlossbfp16" &&
arg != "multimarginloss" && arg != "multimarginlossfp16" && arg != "multimarginlossbfp16" &&
arg != "--version")
arg != "multilabelsoftmarginloss" && arg != "multilabelsoftmarginlossfp16" &&
arg != "multilabelsoftmarginlossbfp16" && arg != "--version")
{
printf("FAILED: Invalid Base Input Argument\n");
Usage();
Expand Down
Loading
Loading