Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

webgpu: support MultiHeadAttention operator #22144

Draft
wants to merge 2 commits into
base: fs-eire/webgpu-ep
Choose a base branch
from

Conversation

xhcao
Copy link
Contributor

@xhcao xhcao commented Sep 19, 2024

Description

Motivation and Context

@xhcao xhcao mentioned this pull request Sep 19, 2024
@xhcao
Copy link
Contributor Author

xhcao commented Sep 19, 2024

Although I think there are some unreasonable codes in JS MultiHeadAttention operator, I still kept the code in webgpu EP nearly the same as JS EP. Let's adjust and optimize the code in future.
@fs-eire @qjia7


class TransferBSDToBNSHProgram final : public Program<TransferBSDToBNSHProgram> {
public:
TransferBSDToBNSHProgram(const std::string& kernel_name, bool has_bias) : Program{kernel_name}, has_bias_(has_bias) {}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
TransferBSDToBNSHProgram(const std::string& kernel_name, bool has_bias) : Program{kernel_name}, has_bias_(has_bias) {}
TransferBSDToBNSHProgram(bool has_bias) : Program{"TransferBSDToBNSH"}, has_bias_(has_bias) {}

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

shader.AddOutput("present_key", ShaderVariable::UseUniform);
}

shader.AppendImplementation("const TILE_SIZE = ", tile_size_, "u;\n")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

it seems that cache keys for the programs are not set correctly.

for example, here tile_size_ is used as a part of the shader source code, but it is not set in the cache key. Use program.CacheHint() to set the cache key.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

TILE_SIZE can also be declared in overridable constants.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

shader.AppendImplementation("var<workgroup> thread_max: array<f32, ", work_group_size_, ">;\n")
.AppendImplementation("var<workgroup> thread_sum: array<f32, ", work_group_size_, ">;\n");

std::string f32_str = components_ == 4 ? "vec4<f32>" : (components_ == 2 ? "vec2<f32>" : "f32");
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can use x_value_t for the value type of x.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No matter x's type is f32 or f16, the program only uses f32 to define max and sum values.

@fs-eire
Copy link
Contributor

fs-eire commented Sep 19, 2024

I didn't see any call to set program cache key. This may be correct (if necessary information is already in uniform). need to confirm.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants