Skip to content

Conversation

@TianHao324
Copy link
Contributor

@TianHao324 TianHao324 commented Jan 19, 2026

Summary

Add NPU support for the embedding.

  • Implements a flattened, grid-stride Triton kernel for embedding forward/backward to improve scalability and reduce launch overhead on Ascend NPUs.
  • Uses UB-aware tiling (compute_default_tiling_strategy) and NPU vector core count to dynamically select block size and grid size for better performance stability.

Testing Done

I tested swiglu by following method and all cases passed:

  • python benchmark/scripts/benchmark_embedding.py
  • pytest -v test/transformers/test_embedding.py
  • Hardware Type: Ascend NPU 910B4
  • run make test to ensure correctness
  • run make checkstyle to ensure code style
  • run make test-convergence to ensure convergence

@TianHao324
Copy link
Contributor Author

test_embedding result:
image

@TianHao324
Copy link
Contributor Author

benchmark_embedding result:

********** Benchmark Data **********
[
  {
    "kernel_name": "embedding",
    "kernel_provider": "liger",
    "metric_name": "speed",
    "metric_unit": "ms",
    "gpu_name": "Ascend910B4",
    "x_name": "V",
    "x_label": "embedding dimension",
    "x_values": [
      1024,
      2048,
      4096,
      8192,
      16384,
      32768,
      65536,
      131072
    ],
    "y_values_50": [
      159.04737854003906,
      159.22134399414062,
      158.77392578125,
      156.7908172607422,
      158.8311004638672,
      159.46414184570312,
      159.698974609375,
      157.46987915039062
    ],
    "y_values_20": [
      159.04737854003906,
      159.22134399414062,
      158.77392578125,
      156.7908172607422,
      158.8311004638672,
      159.46414184570312,
      159.698974609375,
      157.46987915039062
    ],
    "y_values_80": [
      159.04737854003906,
      159.22134399414062,
      158.77392578125,
      156.7908172607422,
      158.8311004638672,
      159.46414184570312,
      159.698974609375,
      157.46987915039062
    ],
    "timestamp": "2026-01-19 11:21:12",
    "kernel_operation_mode": "forward",
    "extra_benchmark_config_str": "{\"B\": 32, \"T\": 512, \"D\": 768, \"dtype\": \"torch.float32\"}",
    "liger_version": "0.0.0"
  },
  {
    "kernel_name": "embedding",
    "kernel_provider": "huggingface",
    "metric_name": "speed",
    "metric_unit": "ms",
    "gpu_name": "Ascend910B4",
    "x_name": "V",
    "x_label": "embedding dimension",
    "x_values": [
      1024,
      2048,
      4096,
      8192,
      16384,
      32768,
      65536,
      131072
    ],
    "y_values_50": [
      0.08076000213623047,
      0.09152999520301819,
      0.11342000216245651,
      0.1491200029850006,
      0.18498000502586365,
      0.21288999915122986,
      0.2282399982213974,
      0.23583999276161194
    ],
    "y_values_20": [
      0.0803999975323677,
      0.09111999720335007,
      0.11283999681472778,
      0.14860399067401886,
      0.18449999392032623,
      0.21233999729156494,
      0.22757600247859955,
      0.23533600568771362
    ],
    "y_values_80": [
      0.08143600076436996,
      0.09251999855041504,
      0.1143999993801117,
      0.1496559977531433,
      0.1855199933052063,
      0.2136079967021942,
      0.22905999422073364,
      0.23659199476242065
    ],
    "timestamp": "2026-01-19 11:21:22",
    "kernel_operation_mode": "forward",
    "extra_benchmark_config_str": "{\"B\": 32, \"T\": 512, \"D\": 768, \"dtype\": \"torch.float32\"}",
    "liger_version": "0.0.0"
  },
  {
    "kernel_name": "embedding",
    "kernel_provider": "torch_compile",
    "metric_name": "speed",
    "metric_unit": "ms",
    "gpu_name": "Ascend910B4",
    "x_name": "V",
    "x_label": "embedding dimension",
    "x_values": [
      1024,
      2048,
      4096,
      8192,
      16384,
      32768,
      65536,
      131072
    ],
    "y_values_50": [
      0.20855998992919922,
      0.23837999999523163,
      0.23678000271320343,
      0.2353000044822693,
      0.22234000265598297,
      0.2555199861526489,
      0.2609800100326538,
      0.26861000061035156
    ],
    "y_values_20": [
      0.20294000208377838,
      0.2349800020456314,
      0.23097999393939972,
      0.23041599988937378,
      0.21833600103855133,
      0.2518959939479828,
      0.2575879991054535,
      0.2648400068283081
    ],
    "y_values_80": [
      0.21427999436855316,
      0.24292799830436707,
      0.24410000443458557,
      0.24070800840854645,
      0.22902800142765045,
      0.2592040002346039,
      0.26579999923706055,
      0.2731640040874481
    ],
    "timestamp": "2026-01-19 11:21:37",
    "kernel_operation_mode": "forward",
    "extra_benchmark_config_str": "{\"B\": 32, \"T\": 512, \"D\": 768, \"dtype\": \"torch.float32\"}",
    "liger_version": "0.0.0"
  },
  {
    "kernel_name": "embedding",
    "kernel_provider": "liger",
    "metric_name": "speed",
    "metric_unit": "ms",
    "gpu_name": "Ascend910B4",
    "x_name": "V",
    "x_label": "embedding dimension",
    "x_values": [
      1024,
      2048,
      4096,
      8192,
      16384,
      32768,
      65536,
      131072
    ],
    "y_values_50": [
      171.3461151123047,
      172.15188598632812,
      173.09751892089844,
      170.68634033203125,
      172.08078002929688,
      172.8400421142578,
      172.94947814941406,
      174.10116577148438
    ],
    "y_values_20": [
      171.3461151123047,
      172.15188598632812,
      173.09751892089844,
      170.68634033203125,
      172.08078002929688,
      172.8400421142578,
      172.94947814941406,
      174.10116577148438
    ],
    "y_values_80": [
      171.3461151123047,
      172.15188598632812,
      173.09751892089844,
      170.68634033203125,
      172.08078002929688,
      172.8400421142578,
      172.94947814941406,
      174.10116577148438
    ],
    "timestamp": "2026-01-19 11:22:00",
    "kernel_operation_mode": "backward",
    "extra_benchmark_config_str": "{\"B\": 32, \"T\": 512, \"D\": 768, \"dtype\": \"torch.float32\"}",
    "liger_version": "0.0.0"
  },
  {
    "kernel_name": "embedding",
    "kernel_provider": "huggingface",
    "metric_name": "speed",
    "metric_unit": "ms",
    "gpu_name": "Ascend910B4",
    "x_name": "V",
    "x_label": "embedding dimension",
    "x_values": [
      1024,
      2048,
      4096,
      8192,
      16384,
      32768,
      65536,
      131072
    ],
    "y_values_50": [
      1.567539930343628,
      1.582919955253601,
      1.621440052986145,
      1.6944499015808105,
      1.8620400428771973,
      2.2697701454162598,
      2.9884400367736816,
      4.289669990539551
    ],
    "y_values_20": [
      1.5666999816894531,
      1.581387996673584,
      1.620255947113037,
      1.6928720474243164,
      1.8604480028152466,
      2.268468141555786,
      2.9868600368499756,
      4.2842559814453125
    ],
    "y_values_80": [
      1.5687999725341797,
      1.5840359926223755,
      1.6228679418563843,
      1.6961640119552612,
      1.8632080554962158,
      2.271224021911621,
      2.989919900894165,
      4.294147968292236
    ],
    "timestamp": "2026-01-19 11:22:11",
    "kernel_operation_mode": "backward",
    "extra_benchmark_config_str": "{\"B\": 32, \"T\": 512, \"D\": 768, \"dtype\": \"torch.float32\"}",
    "liger_version": "0.0.0"
  },
  {
    "kernel_name": "embedding",
    "kernel_provider": "torch_compile",
    "metric_name": "speed",
    "metric_unit": "ms",
    "gpu_name": "Ascend910B4",
    "x_name": "V",
    "x_label": "embedding dimension",
    "x_values": [
      1024,
      2048,
      4096,
      8192,
      16384,
      32768,
      65536,
      131072
    ],
    "y_values_50": [
      1.5685700178146362,
      1.583359956741333,
      1.6218900680541992,
      1.6950700283050537,
      1.8614000082015991,
      2.2811598777770996,
      3.0018599033355713,
      4.2853899002075195
    ],
    "y_values_20": [
      1.5674799680709839,
      1.5822800397872925,
      1.6209839582443237,
      1.6933799982070923,
      1.8602440357208252,
      2.2799479961395264,
      2.9999001026153564,
      4.283323764801025
    ],
    "y_values_80": [
      1.569700002670288,
      1.5844600200653076,
      1.6224479675292969,
      1.6960320472717285,
      1.8625959157943726,
      2.2834479808807373,
      3.004300117492676,
      4.2871198654174805
    ],
    "timestamp": "2026-01-19 11:22:22",
    "kernel_operation_mode": "backward",
    "extra_benchmark_config_str": "{\"B\": 32, \"T\": 512, \"D\": 768, \"dtype\": \"torch.float32\"}",
    "liger_version": "0.0.0"
  },
  {
    "kernel_name": "embedding",
    "kernel_provider": "liger",
    "metric_name": "speed",
    "metric_unit": "ms",
    "gpu_name": "Ascend910B4",
    "x_name": "V",
    "x_label": "embedding dimension",
    "x_values": [
      1024,
      2048,
      4096,
      8192,
      16384,
      32768,
      65536,
      131072
    ],
    "y_values_50": [
      329.966552734375,
      331.5846862792969,
      328.7117004394531,
      330.4582214355469,
      328.7718200683594,
      330.6532897949219,
      330.2299499511719,
      334.3844909667969
    ],
    "y_values_20": [
      329.966552734375,
      331.5846862792969,
      328.7117004394531,
      330.4582214355469,
      328.7718200683594,
      330.6532897949219,
      330.2299499511719,
      334.3844909667969
    ],
    "y_values_80": [
      329.966552734375,
      331.5846862792969,
      328.7117004394531,
      330.4582214355469,
      328.7718200683594,
      330.6532897949219,
      330.2299499511719,
      334.3844909667969
    ],
    "timestamp": "2026-01-19 11:22:53",
    "kernel_operation_mode": "full",
    "extra_benchmark_config_str": "{\"B\": 32, \"T\": 512, \"D\": 768, \"dtype\": \"torch.float32\"}",
    "liger_version": "0.0.0"
  },
  {
    "kernel_name": "embedding",
    "kernel_provider": "huggingface",
    "metric_name": "speed",
    "metric_unit": "ms",
    "gpu_name": "Ascend910B4",
    "x_name": "V",
    "x_label": "embedding dimension",
    "x_values": [
      1024,
      2048,
      4096,
      8192,
      16384,
      32768,
      65536,
      131072
    ],
    "y_values_50": [
      1.659500002861023,
      1.6823298931121826,
      1.7349200248718262,
      1.8467600345611572,
      2.023710012435913,
      2.4428000450134277,
      3.1635000705718994,
      4.471399784088135
    ],
    "y_values_20": [
      1.6574479341506958,
      1.6813240051269531,
      1.7337599992752075,
      1.8458199501037598,
      2.022576093673706,
      2.4407401084899902,
      3.160372018814087,
      4.469580173492432
    ],
    "y_values_80": [
      1.6609920263290405,
      1.683732032775879,
      1.7369400262832642,
      1.848431944847107,
      2.0262999534606934,
      2.445556163787842,
      3.165616035461426,
      4.475900173187256
    ],
    "timestamp": "2026-01-19 11:23:04",
    "kernel_operation_mode": "full",
    "extra_benchmark_config_str": "{\"B\": 32, \"T\": 512, \"D\": 768, \"dtype\": \"torch.float32\"}",
    "liger_version": "0.0.0"
  },
  {
    "kernel_name": "embedding",
    "kernel_provider": "torch_compile",
    "metric_name": "speed",
    "metric_unit": "ms",
    "gpu_name": "Ascend910B4",
    "x_name": "V",
    "x_label": "embedding dimension",
    "x_values": [
      1024,
      2048,
      4096,
      8192,
      16384,
      32768,
      65536,
      131072
    ],
    "y_values_50": [
      1.6604700088500977,
      1.6819599866867065,
      1.734910011291504,
      1.8386199474334717,
      2.0256800651550293,
      2.4474198818206787,
      3.178179979324341,
      4.476739883422852
    ],
    "y_values_20": [
      1.6595079898834229,
      1.6803920269012451,
      1.7331640720367432,
      1.8376519680023193,
      2.0248560905456543,
      2.44594407081604,
      3.1758198738098145,
      4.4761199951171875
    ],
    "y_values_80": [
      1.662611961364746,
      1.6834520101547241,
      1.7368199825286865,
      1.8401119709014893,
      2.027276039123535,
      2.4498159885406494,
      3.180596113204956,
      4.478759765625
    ],
    "timestamp": "2026-01-19 11:23:15",
    "kernel_operation_mode": "full",
    "extra_benchmark_config_str": "{\"B\": 32, \"T\": 512, \"D\": 768, \"dtype\": \"torch.float32\"}",
    "liger_version": "0.0.0"
  },
  {
    "kernel_name": "embedding",
    "kernel_provider": "liger",
    "metric_name": "speed",
    "metric_unit": "ms",
    "gpu_name": "Ascend910B4",
    "x_name": "V",
    "x_label": "embedding dimension",
    "x_values": [
      1024,
      2048,
      4096,
      8192,
      16384,
      32768,
      65536,
      131072
    ],
    "y_values_50": [
      919.2055053710938,
      924.4922485351562,
      925.837890625,
      917.7440185546875,
      933.4002075195312,
      926.0487670898438,
      915.8991088867188,
      901.7867431640625
    ],
    "y_values_20": [
      919.2055053710938,
      924.4922485351562,
      925.837890625,
      917.7440185546875,
      933.4002075195312,
      926.0487670898438,
      915.8991088867188,
      901.7867431640625
    ],
    "y_values_80": [
      919.2055053710938,
      924.4922485351562,
      925.837890625,
      917.7440185546875,
      933.4002075195312,
      926.0487670898438,
      915.8991088867188,
      901.7867431640625
    ],
    "timestamp": "2026-01-19 11:25:08",
    "kernel_operation_mode": "forward",
    "extra_benchmark_config_str": "{\"B\": 8, \"T\": 2048, \"D\": 4096, \"dtype\": \"torch.float32\"}",
    "liger_version": "0.0.0"
  },
  {
    "kernel_name": "embedding",
    "kernel_provider": "huggingface",
    "metric_name": "speed",
    "metric_unit": "ms",
    "gpu_name": "Ascend910B4",
    "x_name": "V",
    "x_label": "embedding dimension",
    "x_values": [
      1024,
      2048,
      4096,
      8192,
      16384,
      32768,
      65536,
      131072
    ],
    "y_values_50": [
      0.48420000076293945,
      0.5853400230407715,
      0.7012799978256226,
      0.798509955406189,
      0.8588399887084961,
      0.8877999782562256,
      0.9045600295066833,
      0.9147800207138062
    ],
    "y_values_20": [
      0.48089200258255005,
      0.5818399786949158,
      0.6985039710998535,
      0.793940007686615,
      0.8546640276908875,
      0.8855999708175659,
      0.9034039974212646,
      0.9137159585952759
    ],
    "y_values_80": [
      0.4874799847602844,
      0.5883399844169617,
      0.704584002494812,
      0.8027079701423645,
      0.8639959692955017,
      0.8940200209617615,
      0.9097560048103333,
      0.9184520244598389
    ],
    "timestamp": "2026-01-19 11:26:04",
    "kernel_operation_mode": "forward",
    "extra_benchmark_config_str": "{\"B\": 8, \"T\": 2048, \"D\": 4096, \"dtype\": \"torch.float32\"}",
    "liger_version": "0.0.0"
  },
  {
    "kernel_name": "embedding",
    "kernel_provider": "torch_compile",
    "metric_name": "speed",
    "metric_unit": "ms",
    "gpu_name": "Ascend910B4",
    "x_name": "V",
    "x_label": "embedding dimension",
    "x_values": [
      1024,
      2048,
      4096,
      8192,
      16384,
      32768,
      65536,
      131072
    ],
    "y_values_50": [
      0.4835599958896637,
      0.5865799784660339,
      0.705079972743988,
      0.8017400503158569,
      0.8565599918365479,
      0.888759970664978,
      0.9037399888038635,
      0.9135800004005432
    ],
    "y_values_20": [
      0.48020797967910767,
      0.5828199982643127,
      0.702243983745575,
      0.7970640063285828,
      0.853227972984314,
      0.8873400092124939,
      0.9027560353279114,
      0.9126359820365906
    ],
    "y_values_80": [
      0.48682400584220886,
      0.589139997959137,
      0.7075120210647583,
      0.8054919838905334,
      0.8639839887619019,
      0.8956599831581116,
      0.9073839783668518,
      0.9195600152015686
    ],
    "timestamp": "2026-01-19 11:26:57",
    "kernel_operation_mode": "forward",
    "extra_benchmark_config_str": "{\"B\": 8, \"T\": 2048, \"D\": 4096, \"dtype\": \"torch.float32\"}",
    "liger_version": "0.0.0"
  },
  {
    "kernel_name": "embedding",
    "kernel_provider": "liger",
    "metric_name": "speed",
    "metric_unit": "ms",
    "gpu_name": "Ascend910B4",
    "x_name": "V",
    "x_label": "embedding dimension",
    "x_values": [
      1024,
      2048,
      4096,
      8192,
      16384,
      32768,
      65536,
      131072
    ],
    "y_values_50": [
      945.84130859375,
      952.7191162109375,
      957.5747680664062,
      950.7302856445312,
      944.4890747070312,
      948.2030029296875,
      964.8849487304688,
      962.3030395507812
    ],
    "y_values_20": [
      945.84130859375,
      952.7191162109375,
      957.5747680664062,
      950.7302856445312,
      944.4890747070312,
      948.2030029296875,
      964.8849487304688,
      962.3030395507812
    ],
    "y_values_80": [
      945.84130859375,
      952.7191162109375,
      957.5747680664062,
      950.7302856445312,
      944.4890747070312,
      948.2030029296875,
      964.8849487304688,
      962.3030395507812
    ],
    "timestamp": "2026-01-19 11:28:58",
    "kernel_operation_mode": "backward",
    "extra_benchmark_config_str": "{\"B\": 8, \"T\": 2048, \"D\": 4096, \"dtype\": \"torch.float32\"}",
    "liger_version": "0.0.0"
  },
  {
    "kernel_name": "embedding",
    "kernel_provider": "huggingface",
    "metric_name": "speed",
    "metric_unit": "ms",
    "gpu_name": "Ascend910B4",
    "x_name": "V",
    "x_label": "embedding dimension",
    "x_values": [
      1024,
      2048,
      4096,
      8192,
      16384,
      32768,
      65536,
      131072
    ],
    "y_values_50": [
      7.050960063934326,
      7.176300048828125,
      7.4766998291015625,
      8.097700119018555,
      9.109689712524414,
      10.970990180969238,
      14.46051025390625,
      21.254928588867188
    ],
    "y_values_20": [
      7.049655914306641,
      7.175192356109619,
      7.476096153259277,
      8.095372200012207,
      9.10888385772705,
      10.968511581420898,
      14.458100318908691,
      21.25203514099121
    ],
    "y_values_80": [
      7.052547931671143,
      7.1777801513671875,
      7.479163646697998,
      8.098332405090332,
      9.111007690429688,
      10.974867820739746,
      14.464619636535645,
      21.258079528808594
    ],
    "timestamp": "2026-01-19 11:29:51",
    "kernel_operation_mode": "backward",
    "extra_benchmark_config_str": "{\"B\": 8, \"T\": 2048, \"D\": 4096, \"dtype\": \"torch.float32\"}",
    "liger_version": "0.0.0"
  },
  {
    "kernel_name": "embedding",
    "kernel_provider": "torch_compile",
    "metric_name": "speed",
    "metric_unit": "ms",
    "gpu_name": "Ascend910B4",
    "x_name": "V",
    "x_label": "embedding dimension",
    "x_values": [
      1024,
      2048,
      4096,
      8192,
      16384,
      32768,
      65536,
      131072
    ],
    "y_values_50": [
      7.051300048828125,
      7.172420024871826,
      7.480500221252441,
      8.09889030456543,
      9.142240524291992,
      10.97877025604248,
      14.473550796508789,
      21.26235008239746
    ],
    "y_values_20": [
      7.050387859344482,
      7.1711320877075195,
      7.479703903198242,
      8.097579956054688,
      9.139100074768066,
      10.975571632385254,
      14.470020294189453,
      21.259340286254883
    ],
    "y_values_80": [
      7.05352783203125,
      7.174144268035889,
      7.482272148132324,
      8.100163459777832,
      9.143535614013672,
      10.9817476272583,
      14.475319862365723,
      21.265239715576172
    ],
    "timestamp": "2026-01-19 11:30:44",
    "kernel_operation_mode": "backward",
    "extra_benchmark_config_str": "{\"B\": 8, \"T\": 2048, \"D\": 4096, \"dtype\": \"torch.float32\"}",
    "liger_version": "0.0.0"
  },
  {
    "kernel_name": "embedding",
    "kernel_provider": "liger",
    "metric_name": "speed",
    "metric_unit": "ms",
    "gpu_name": "Ascend910B4",
    "x_name": "V",
    "x_label": "embedding dimension",
    "x_values": [
      1024,
      2048,
      4096,
      8192,
      16384,
      32768,
      65536,
      131072
    ],
    "y_values_50": [
      1874.319580078125,
      1857.006103515625,
      1877.2384033203125,
      1870.370361328125,
      1845.2091064453125,
      1862.2515869140625,
      1867.099365234375,
      1878.7745361328125
    ],
    "y_values_20": [
      1874.319580078125,
      1857.006103515625,
      1877.2384033203125,
      1870.370361328125,
      1845.2091064453125,
      1862.2515869140625,
      1867.099365234375,
      1878.7745361328125
    ],
    "y_values_80": [
      1874.319580078125,
      1857.006103515625,
      1877.2384033203125,
      1870.370361328125,
      1845.2091064453125,
      1862.2515869140625,
      1867.099365234375,
      1878.7745361328125
    ],
    "timestamp": "2026-01-19 11:33:36",
    "kernel_operation_mode": "full",
    "extra_benchmark_config_str": "{\"B\": 8, \"T\": 2048, \"D\": 4096, \"dtype\": \"torch.float32\"}",
    "liger_version": "0.0.0"
  },
  {
    "kernel_name": "embedding",
    "kernel_provider": "huggingface",
    "metric_name": "speed",
    "metric_unit": "ms",
    "gpu_name": "Ascend910B4",
    "x_name": "V",
    "x_label": "embedding dimension",
    "x_values": [
      1024,
      2048,
      4096,
      8192,
      16384,
      32768,
      65536,
      131072
    ],
    "y_values_50": [
      7.529760360717773,
      7.7440900802612305,
      8.142419815063477,
      8.862719535827637,
      9.918439865112305,
      11.839099884033203,
      15.30836009979248,
      22.129940032958984
    ],
    "y_values_20": [
      7.52623176574707,
      7.740516185760498,
      8.14076042175293,
      8.857259750366211,
      9.917183876037598,
      11.836175918579102,
      15.305279731750488,
      22.125947952270508
    ],
    "y_values_80": [
      7.531836032867432,
      7.745599746704102,
      8.148759841918945,
      8.863639831542969,
      9.924500465393066,
      11.844079971313477,
      15.313380241394043,
      22.142860412597656
    ],
    "timestamp": "2026-01-19 11:34:30",
    "kernel_operation_mode": "full",
    "extra_benchmark_config_str": "{\"B\": 8, \"T\": 2048, \"D\": 4096, \"dtype\": \"torch.float32\"}",
    "liger_version": "0.0.0"
  },
  {
    "kernel_name": "embedding",
    "kernel_provider": "torch_compile",
    "metric_name": "speed",
    "metric_unit": "ms",
    "gpu_name": "Ascend910B4",
    "x_name": "V",
    "x_label": "embedding dimension",
    "x_values": [
      1024,
      2048,
      4096,
      8192,
      16384,
      32768,
      65536,
      131072
    ],
    "y_values_50": [
      7.528960227966309,
      7.742199897766113,
      8.159040451049805,
      8.857500076293945,
      9.937239646911621,
      11.829119682312012,
      15.305660247802734,
      22.119709014892578
    ],
    "y_values_20": [
      7.526879787445068,
      7.739608287811279,
      8.153400421142578,
      8.854619979858398,
      9.934915542602539,
      11.823528289794922,
      15.30228042602539,
      22.117971420288086
    ],
    "y_values_80": [
      7.531447887420654,
      7.743847846984863,
      8.162480354309082,
      8.87007999420166,
      9.942655563354492,
      11.833767890930176,
      15.309659957885742,
      22.12152862548828
    ],
    "timestamp": "2026-01-19 11:35:23",
    "kernel_operation_mode": "full",
    "extra_benchmark_config_str": "{\"B\": 8, \"T\": 2048, \"D\": 4096, \"dtype\": \"torch.float32\"}",
    "liger_version": "0.0.0"
  }
]

@TianHao324
Copy link
Contributor Author

Hi @Tcc0403, could you please help me review my code?

Copy link
Collaborator

@Tcc0403 Tcc0403 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It seems the current implementation is quite inefficient. I've left some comments about some possible issues it might have.

)


def get_optimal_block_size(total_elements, is_backward: bool):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what does is_backward do?

Comment on lines +14 to +24
@triton.jit
def embedding_forward_kernel(
embeddings_ptr,
indices_ptr,
output_ptr,
total_elements,
n_elements,
embedding_dim: tl.constexpr,
BLOCK_SIZE: tl.constexpr,
NUM_STAGES: tl.constexpr,
):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think the original implementation with 2 block sizes for tile shape is more readable and more efficient.

persistant grid loop is fine, but the way this kernel loading embedding seems to be uncoalesced at some point.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For instance, there will be some dim_idx not consecutive if BLOCK_SIZE is not multiple of embedding_dim. It will make the second tl.load trying to access different rows within a warp, as well as the last store.

Make these offsets created with 2d block size is more readable and efficient since we can avoid the uncoalesced access mentioned above.

Comment on lines +110 to +112
tile_shapes = compute_default_tiling_strategy(
safety_margin=0.9, dtype_size=4, memory_multiplier=multiplier, shapes=((total_elements,),), tiling_dims=(0,)
)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

dtype_size should be embedding.dtype?

block_size = tile_shapes[0][0]
return block_size
else:
return triton.next_power_of_2(total_elements)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think fallback value should be workable, triton.next_power_of_2(total_elements) is too large.

embeddings_ptr + embedding_offsets,
mask=final_mask,
other=0.0,
).to(tl.float32)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

any consideration why we need to upcast it?

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants