[NPU]: Add NPU support for the embedding #1028
base: main
Conversation
benchmark_embedding result: (benchmark figure attached in the PR)

Hi @Tcc0403, could you please help me review my code?
Tcc0403 left a comment:
It seems the current implementation is quite inefficient. I've left comments on some possible issues it might have.
    def get_optimal_block_size(total_elements, is_backward: bool):
What does is_backward do?
    @triton.jit
    def embedding_forward_kernel(
        embeddings_ptr,
        indices_ptr,
        output_ptr,
        total_elements,
        n_elements,
        embedding_dim: tl.constexpr,
        BLOCK_SIZE: tl.constexpr,
        NUM_STAGES: tl.constexpr,
    ):
I think the original implementation with 2 block sizes for the tile shape is more readable and more efficient.
A persistent grid loop is fine, but the way this kernel loads the embedding seems to be uncoalesced at some point.
For instance, some dim_idx values will not be consecutive if BLOCK_SIZE is not a multiple of embedding_dim. That makes the second tl.load access different rows within a warp, and the same goes for the last store.
Creating these offsets with a 2D block size is more readable and more efficient, since it avoids the uncoalesced access mentioned above.
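Not part of the PR itself, but a minimal sketch of what a 2D-tiled forward kernel along those lines could look like (the kernel name and the BLOCK_ROWS/BLOCK_COLS parameters are illustrative assumptions, not names from the diff):

```python
import triton
import triton.language as tl


@triton.jit
def embedding_forward_2d_kernel(
    embeddings_ptr,   # [num_embeddings, embedding_dim] weight table
    indices_ptr,      # [n_indices] flattened token indices
    output_ptr,       # [n_indices, embedding_dim] gathered rows
    n_indices,
    embedding_dim: tl.constexpr,
    BLOCK_ROWS: tl.constexpr,  # tokens per tile (power of 2)
    BLOCK_COLS: tl.constexpr,  # dim elements per tile (power of 2)
):
    pid_row = tl.program_id(0)
    pid_col = tl.program_id(1)

    row_offsets = pid_row * BLOCK_ROWS + tl.arange(0, BLOCK_ROWS)
    col_offsets = pid_col * BLOCK_COLS + tl.arange(0, BLOCK_COLS)
    row_mask = row_offsets < n_indices
    col_mask = col_offsets < embedding_dim

    # One token id per tile row.
    token_ids = tl.load(indices_ptr + row_offsets, mask=row_mask, other=0)

    # 2D offsets: consecutive lanes walk the embedding dim of a single row,
    # so loads and stores stay coalesced even when BLOCK sizes don't divide
    # embedding_dim.
    mask_2d = row_mask[:, None] & col_mask[None, :]
    emb_offsets = token_ids[:, None] * embedding_dim + col_offsets[None, :]
    out_offsets = row_offsets[:, None] * embedding_dim + col_offsets[None, :]

    rows = tl.load(embeddings_ptr + emb_offsets, mask=mask_2d, other=0.0)
    tl.store(output_ptr + out_offsets, rows, mask=mask_2d)
```

Such a kernel would be launched over a 2D grid, e.g. grid = (triton.cdiv(n_indices, BLOCK_ROWS), triton.cdiv(embedding_dim, BLOCK_COLS)).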
        tile_shapes = compute_default_tiling_strategy(
            safety_margin=0.9, dtype_size=4, memory_multiplier=multiplier, shapes=((total_elements,),), tiling_dims=(0,)
        )
Shouldn't dtype_size be derived from embedding.dtype?
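For illustration only, one way the suggestion could look, assuming the weight tensor in scope is called embeddings:

```python
# Sketch: derive the element size from the weight tensor instead of
# hard-coding 4 bytes, so fp16/bf16 tables get the right tiling budget.
tile_shapes = compute_default_tiling_strategy(
    safety_margin=0.9,
    dtype_size=embeddings.element_size(),  # 2 for fp16/bf16, 4 for fp32
    memory_multiplier=multiplier,
    shapes=((total_elements,),),
    tiling_dims=(0,),
)
```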
        block_size = tile_shapes[0][0]
        return block_size
    else:
        return triton.next_power_of_2(total_elements)
I think the fallback value should be something workable; triton.next_power_of_2(total_elements) is too large.
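One possible fallback, shown only as a sketch (the 1024 cap is an assumption, not a tuned value):

```python
# Cap the fallback block size instead of letting it grow with the input.
MAX_FALLBACK_BLOCK_SIZE = 1024  # assumed cap; would need tuning on NPU
return min(triton.next_power_of_2(total_elements), MAX_FALLBACK_BLOCK_SIZE)
```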
        embeddings_ptr + embedding_offsets,
        mask=final_mask,
        other=0.0,
    ).to(tl.float32)
Any consideration why we need to upcast it?
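If nothing in the kernel accumulates in fp32, the gather could simply keep the weight's native dtype; a hypothetical sketch (output_offsets is assumed from the surrounding kernel, not quoted in the diff):

```python
# Copy rows in the weight's native dtype; drop the .to(tl.float32) upcast
# unless something downstream actually needs fp32 accumulation.
embeddings = tl.load(
    embeddings_ptr + embedding_offsets,
    mask=final_mask,
    other=0.0,
)
tl.store(output_ptr + output_offsets, embeddings, mask=final_mask)
```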

Summary
Add NPU support for the embedding.
Testing Done
I tested the embedding with the following methods and all cases passed:
- python benchmark/scripts/benchmark_embedding.py
- pytest -v test/transformers/test_embedding.py
- make test to ensure correctness
- make checkstyle to ensure code style
- make test-convergence to ensure convergence