
Commit d9fed26

micah-wilgshtras and Gregory Shtrasberg authored
Dynamic Scale Factor Calculations for Key/Value Scales With FP8 KV Caching (#317)
* Changed _k_scale and _v_scale to tensors
* Fixed ROCm paged attention with tensor KV scales
* Added on-the-fly scale factor calculation
* Trying to fix attn metadata
* Fixed AttentionMetadata issue, updated description for the calculate-kv-scales flag in arg_utils.py
* Changed K and V scale constants
* Removed unneeded comment
* Changes to pass format.sh, also fixed lingering k_scale/v_scale : float
* Fix for TP > 1
* Ran format.sh
* Removed legacy kv_scale loading from the JSON file
* Removed the outdated KV cache docs
* Reverted some unwanted changes

Co-authored-by: Gregory Shtrasberg <Gregory.Shtrasberg@amd.com>
1 parent d09f1ce commit d9fed26

36 files changed: +194 additions, −1324 deletions
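
The hunks below only change how the scales are passed (device-resident tensors read through `const float*` instead of host `float`/`double` values); the on-the-fly computation itself lives elsewhere in the commit. As a rough, hypothetical sketch of what a dynamic per-tensor scale calculation can look like (the kernel name and the fp32 input type are assumptions, not code from this commit), an amax reduction feeding `scale = amax / 448.0f` for FP8 e4m3:

```cpp
// Hypothetical sketch only: one way to produce a dynamic per-tensor FP8 scale
// on the GPU. The kernel reduces |x| to a global maximum; the consumer then
// uses amax / 448.0f (the finite max of FP8 e4m3) as the scale that the
// attention/cache kernels read through `const float* k_scale` / `v_scale`.
// Input is assumed fp32 for brevity; real K/V activations are usually fp16/bf16.
#include <cuda_runtime.h>
#include <math.h>

// atomicMax for non-negative floats, done on the raw bit pattern (valid
// because IEEE-754 ordering matches integer ordering for non-negative values).
__device__ void atomic_max_abs(float* addr, float val) {
  atomicMax(reinterpret_cast<int*>(addr), __float_as_int(val));
}

__global__ void kv_amax_kernel(const float* __restrict__ data, int n,
                               float* __restrict__ amax /* [1], pre-zeroed */) {
  __shared__ float block_max;
  if (threadIdx.x == 0) block_max = 0.f;
  __syncthreads();

  // Grid-stride loop: each thread folds its elements into a private maximum.
  float local_max = 0.f;
  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < n;
       i += gridDim.x * blockDim.x) {
    local_max = fmaxf(local_max, fabsf(data[i]));
  }
  atomic_max_abs(&block_max, local_max);
  __syncthreads();

  // One thread per block folds the block maximum into the global result,
  // which lives in the one-element device tensor the scale is derived from.
  if (threadIdx.x == 0) atomic_max_abs(amax, block_max);
}
```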

csrc/attention/attention_kernels.cuh

Lines changed: 5 additions & 5 deletions
@@ -105,7 +105,7 @@ __device__ void paged_attention_kernel(
  const int max_num_blocks_per_seq,
  const float* __restrict__ alibi_slopes, // [num_heads]
  const int q_stride, const int kv_block_stride, const int kv_head_stride,
- const float k_scale, const float v_scale, const int tp_rank,
+ const float* k_scale, const float* v_scale, const int tp_rank,
  const int blocksparse_local_blocks, const int blocksparse_vert_stride,
  const int blocksparse_block_size, const int blocksparse_head_sliding_step) {
  const int seq_idx = blockIdx.y;
@@ -285,7 +285,7 @@ __device__ void paged_attention_kernel(
  Quant_vec k_vec_quant = *reinterpret_cast<const Quant_vec*>(
  k_ptr + offset1 * BLOCK_SIZE * x + offset2);
  k_vecs[j] = fp8::scaled_convert<K_vec, Quant_vec, KV_DTYPE>(
- k_vec_quant, k_scale);
+ k_vec_quant, *k_scale);
  }
  }

@@ -415,7 +415,7 @@ __device__ void paged_attention_kernel(
  *reinterpret_cast<const V_quant_vec*>(v_ptr + offset);
  // Vector conversion from V_quant_vec to V_vec.
  v_vec = fp8::scaled_convert<V_vec, V_quant_vec, KV_DTYPE>(v_quant_vec,
- v_scale);
+ *v_scale);
  }
  if (block_idx == num_seq_blocks - 1) {
  // NOTE(woosuk): When v_vec contains the tokens that are out of the
@@ -513,7 +513,7 @@ __global__ void paged_attention_v1_kernel(
  const int max_num_blocks_per_seq,
  const float* __restrict__ alibi_slopes, // [num_heads]
  const int q_stride, const int kv_block_stride, const int kv_head_stride,
- const float k_scale, const float v_scale, const int tp_rank,
+ const float* k_scale, const float* v_scale, const int tp_rank,
  const int blocksparse_local_blocks, const int blocksparse_vert_stride,
  const int blocksparse_block_size, const int blocksparse_head_sliding_step) {
  paged_attention_kernel<scalar_t, cache_t, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS,
@@ -549,7 +549,7 @@ __global__ void paged_attention_v2_kernel(
  const int max_num_blocks_per_seq,
  const float* __restrict__ alibi_slopes, // [num_heads]
  const int q_stride, const int kv_block_stride, const int kv_head_stride,
- const float k_scale, const float v_scale, const int tp_rank,
+ const float* k_scale, const float* v_scale, const int tp_rank,
  const int blocksparse_local_blocks, const int blocksparse_vert_stride,
  const int blocksparse_block_size, const int blocksparse_head_sliding_step) {
  paged_attention_kernel<scalar_t, cache_t, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS,
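
A note on the signature change above: passing `const float*` instead of `float` means the scale is read from GPU memory when the kernel runs, so a freshly computed value is picked up without changing the launch arguments. The illustrative (hypothetical) kernel below shows the usual pattern of loading the pointed-to scale once into a register; the per-element cost matches the old by-value parameter.

```cpp
// Illustrative kernel (names are hypothetical, not from this file): the scale
// arrives as a device pointer, is loaded once into a register, and is then
// reused for every element.
__global__ void apply_kv_scales(const float* __restrict__ k_scale,
                                const float* __restrict__ v_scale,
                                float* __restrict__ out, int n) {
  const float k_s = *k_scale;  // single global load per thread
  const float v_s = *v_scale;
  const int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) out[i] *= k_s * v_s;  // stand-in for fp8::scaled_convert
}
```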

csrc/attention/paged_attention_v1.cu

Lines changed: 10 additions & 7 deletions
@@ -41,7 +41,7 @@
  out_ptr, query_ptr, key_cache_ptr, value_cache_ptr, num_kv_heads, \
  scale, block_tables_ptr, seq_lens_ptr, max_num_blocks_per_seq, \
  alibi_slopes_ptr, q_stride, kv_block_stride, kv_head_stride, \
- k_scale, v_scale, tp_rank, blocksparse_local_blocks, \
+ k_scale_ptr, v_scale_ptr, tp_rank, blocksparse_local_blocks, \
  blocksparse_vert_stride, blocksparse_block_size, \
  blocksparse_head_sliding_step);

@@ -53,10 +53,10 @@ void paged_attention_v1_launcher(
  torch::Tensor& out, torch::Tensor& query, torch::Tensor& key_cache,
  torch::Tensor& value_cache, int num_kv_heads, float scale,
  torch::Tensor& block_tables, torch::Tensor& seq_lens, int max_seq_len,
- const c10::optional<torch::Tensor>& alibi_slopes, float k_scale,
- float v_scale, const int tp_rank, const int blocksparse_local_blocks,
- const int blocksparse_vert_stride, const int blocksparse_block_size,
- const int blocksparse_head_sliding_step) {
+ const c10::optional<torch::Tensor>& alibi_slopes, torch::Tensor& k_scale,
+ torch::Tensor& v_scale, const int tp_rank,
+ const int blocksparse_local_blocks, const int blocksparse_vert_stride,
+ const int blocksparse_block_size, const int blocksparse_head_sliding_step) {
  int num_seqs = query.size(0);
  int num_heads = query.size(1);
  int head_size = query.size(2);
@@ -80,6 +80,8 @@ void paged_attention_v1_launcher(
  CACHE_T* value_cache_ptr = reinterpret_cast<CACHE_T*>(value_cache.data_ptr());
  int* block_tables_ptr = block_tables.data_ptr<int>();
  int* seq_lens_ptr = seq_lens.data_ptr<int>();
+ const float* k_scale_ptr = reinterpret_cast<const float*>(k_scale.data_ptr());
+ const float* v_scale_ptr = reinterpret_cast<const float*>(v_scale.data_ptr());

  constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE;
  int padded_max_seq_len =
@@ -193,8 +195,9 @@ void paged_attention_v1(
  torch::Tensor& seq_lens, // [num_seqs]
  int64_t block_size, int64_t max_seq_len,
  const c10::optional<torch::Tensor>& alibi_slopes,
- const std::string& kv_cache_dtype, double k_scale, double v_scale,
- const int64_t tp_rank, const int64_t blocksparse_local_blocks,
+ const std::string& kv_cache_dtype, torch::Tensor& k_scale,
+ torch::Tensor& v_scale, const int64_t tp_rank,
+ const int64_t blocksparse_local_blocks,
  const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
  const int64_t blocksparse_head_sliding_step, const int64_t num_threads) {
  const bool is_block_sparse = (blocksparse_vert_stride > 1);

csrc/attention/paged_attention_v2.cu

Lines changed: 10 additions & 7 deletions
@@ -45,7 +45,7 @@ typedef __hip_bfloat16 __nv_bfloat16;
  exp_sums_ptr, max_logits_ptr, tmp_out_ptr, query_ptr, key_cache_ptr, \
  value_cache_ptr, num_kv_heads, scale, block_tables_ptr, \
  seq_lens_ptr, max_num_blocks_per_seq, alibi_slopes_ptr, q_stride, \
- kv_block_stride, kv_head_stride, k_scale, v_scale, tp_rank, \
+ kv_block_stride, kv_head_stride, k_scale_ptr, v_scale_ptr, tp_rank, \
  blocksparse_local_blocks, blocksparse_vert_stride, \
  blocksparse_block_size, blocksparse_head_sliding_step); \
  vllm::paged_attention_v2_reduce_kernel<T, HEAD_SIZE, NUM_THREADS, \
@@ -62,10 +62,10 @@ void paged_attention_v2_launcher(
  torch::Tensor& tmp_out, torch::Tensor& query, torch::Tensor& key_cache,
  torch::Tensor& value_cache, int num_kv_heads, float scale,
  torch::Tensor& block_tables, torch::Tensor& seq_lens, int max_seq_len,
- const c10::optional<torch::Tensor>& alibi_slopes, float k_scale,
- float v_scale, const int tp_rank, const int blocksparse_local_blocks,
- const int blocksparse_vert_stride, const int blocksparse_block_size,
- const int blocksparse_head_sliding_step) {
+ const c10::optional<torch::Tensor>& alibi_slopes, torch::Tensor& k_scale,
+ torch::Tensor& v_scale, const int tp_rank,
+ const int blocksparse_local_blocks, const int blocksparse_vert_stride,
+ const int blocksparse_block_size, const int blocksparse_head_sliding_step) {
  int num_seqs = query.size(0);
  int num_heads = query.size(1);
  int head_size = query.size(2);
@@ -92,6 +92,8 @@ void paged_attention_v2_launcher(
  CACHE_T* value_cache_ptr = reinterpret_cast<CACHE_T*>(value_cache.data_ptr());
  int* block_tables_ptr = block_tables.data_ptr<int>();
  int* seq_lens_ptr = seq_lens.data_ptr<int>();
+ const float* k_scale_ptr = reinterpret_cast<const float*>(k_scale.data_ptr());
+ const float* v_scale_ptr = reinterpret_cast<const float*>(v_scale.data_ptr());

  constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE;
  int max_num_partitions = DIVIDE_ROUND_UP(max_seq_len, PARTITION_SIZE);
@@ -212,8 +214,9 @@ void paged_attention_v2(
  torch::Tensor& seq_lens, // [num_seqs]
  int64_t block_size, int64_t max_seq_len,
  const c10::optional<torch::Tensor>& alibi_slopes,
- const std::string& kv_cache_dtype, double k_scale, double v_scale,
- const int64_t tp_rank, const int64_t blocksparse_local_blocks,
+ const std::string& kv_cache_dtype, torch::Tensor& k_scale,
+ torch::Tensor& v_scale, const int64_t tp_rank,
+ const int64_t blocksparse_local_blocks,
  const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
  const int64_t blocksparse_head_sliding_step, const int64_t num_threads) {
  const bool is_block_sparse = (blocksparse_vert_stride > 1);

csrc/cache.h

Lines changed: 3 additions & 3 deletions
@@ -18,15 +18,15 @@ void copy_blocks(std::vector<torch::Tensor> const& key_caches,
  void reshape_and_cache(torch::Tensor& key, torch::Tensor& value,
  torch::Tensor& key_cache, torch::Tensor& value_cache,
  torch::Tensor& slot_mapping,
- const std::string& kv_cache_dtype, const double k_scale,
- const double v_scale);
+ const std::string& kv_cache_dtype,
+ torch::Tensor& k_scale, torch::Tensor& v_scale);

  void reshape_and_cache_flash(torch::Tensor& key, torch::Tensor& value,
  torch::Tensor& key_cache,
  torch::Tensor& value_cache,
  torch::Tensor& slot_mapping,
  const std::string& kv_cache_dtype,
- const double k_scale, const double v_scale);
+ torch::Tensor& k_scale, torch::Tensor& v_scale);

  // Just for unittest
  void convert_fp8(torch::Tensor& dst_cache, torch::Tensor& src_cache,
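
For callers of the declarations above, the change means passing one-element fp32 tensors where doubles used to go. A hedged usage sketch (the wrapper function, tensor shapes, and the "fp8" dtype string are illustrative assumptions, not taken from this diff):

```cpp
// Hedged usage sketch, not from the commit.
#include <torch/torch.h>
#include "cache.h"  // declares reshape_and_cache with the tensor-scale signature

void cache_with_tensor_scales(torch::Tensor& key, torch::Tensor& value,
                              torch::Tensor& key_cache,
                              torch::Tensor& value_cache,
                              torch::Tensor& slot_mapping) {
  auto opts =
      torch::TensorOptions().dtype(torch::kFloat32).device(key.device());
  // The scales now live on the GPU as one-element tensors, so whatever
  // computes them dynamically can overwrite them in place each step.
  torch::Tensor k_scale = torch::ones({1}, opts);
  torch::Tensor v_scale = torch::ones({1}, opts);
  reshape_and_cache(key, value, key_cache, value_cache, slot_mapping,
                    /*kv_cache_dtype=*/"fp8", k_scale, v_scale);
}
```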

csrc/cache_kernels.cu

Lines changed: 17 additions & 13 deletions
@@ -159,8 +159,8 @@ __global__ void reshape_and_cache_kernel(
  // block_size]
  const int64_t* __restrict__ slot_mapping, // [num_tokens]
  const int key_stride, const int value_stride, const int num_heads,
- const int head_size, const int block_size, const int x, const float k_scale,
- const float v_scale) {
+ const int head_size, const int block_size, const int x,
+ const float* k_scale, const float* v_scale) {
  const int64_t token_idx = blockIdx.x;
  const int64_t slot_idx = slot_mapping[token_idx];
  if (slot_idx < 0) {
@@ -196,9 +196,9 @@ __global__ void reshape_and_cache_kernel(
  value_cache[tgt_value_idx] = tgt_value;
  } else {
  key_cache[tgt_key_idx] =
- fp8::scaled_convert<cache_t, scalar_t, kv_dt>(tgt_key, k_scale);
+ fp8::scaled_convert<cache_t, scalar_t, kv_dt>(tgt_key, *k_scale);
  value_cache[tgt_value_idx] =
- fp8::scaled_convert<cache_t, scalar_t, kv_dt>(tgt_value, v_scale);
+ fp8::scaled_convert<cache_t, scalar_t, kv_dt>(tgt_value, *v_scale);
  }
  }
  }
@@ -214,7 +214,7 @@ __global__ void reshape_and_cache_flash_kernel(
  const int64_t* __restrict__ slot_mapping, // [num_tokens]
  const int block_stride, const int key_stride, const int value_stride,
  const int num_heads, const int head_size, const int block_size,
- const float k_scale, const float v_scale) {
+ const float* k_scale, const float* v_scale) {
  const int64_t token_idx = blockIdx.x;
  const int64_t slot_idx = slot_mapping[token_idx];
  // NOTE: slot_idx can be -1 if the token is padded
@@ -239,9 +239,9 @@ __global__ void reshape_and_cache_flash_kernel(
  value_cache[tgt_key_value_idx] = tgt_value;
  } else {
  key_cache[tgt_key_value_idx] =
- fp8::scaled_convert<cache_t, scalar_t, kv_dt>(tgt_key, k_scale);
+ fp8::scaled_convert<cache_t, scalar_t, kv_dt>(tgt_key, *k_scale);
  value_cache[tgt_key_value_idx] =
- fp8::scaled_convert<cache_t, scalar_t, kv_dt>(tgt_value, v_scale);
+ fp8::scaled_convert<cache_t, scalar_t, kv_dt>(tgt_value, *v_scale);
  }
  }
  }
@@ -258,7 +258,9 @@ __global__ void reshape_and_cache_flash_kernel(
  reinterpret_cast<CACHE_T*>(key_cache.data_ptr()), \
  reinterpret_cast<CACHE_T*>(value_cache.data_ptr()), \
  slot_mapping.data_ptr<int64_t>(), key_stride, value_stride, \
- num_heads, head_size, block_size, x, k_scale, v_scale);
+ num_heads, head_size, block_size, x, \
+ reinterpret_cast<const float*>(k_scale.data_ptr()), \
+ reinterpret_cast<const float*>(v_scale.data_ptr()));

  void reshape_and_cache(
  torch::Tensor& key, // [num_tokens, num_heads, head_size]
@@ -268,8 +270,8 @@ void reshape_and_cache(
  torch::Tensor&
  value_cache, // [num_blocks, num_heads, head_size, block_size]
  torch::Tensor& slot_mapping, // [num_tokens]
- const std::string& kv_cache_dtype, const double k_scale,
- const double v_scale) {
+ const std::string& kv_cache_dtype, torch::Tensor& k_scale,
+ torch::Tensor& v_scale) {
  int num_tokens = key.size(0);
  int num_heads = key.size(1);
  int head_size = key.size(2);
@@ -299,7 +301,9 @@ void reshape_and_cache(
  reinterpret_cast<CACHE_T*>(key_cache.data_ptr()), \
  reinterpret_cast<CACHE_T*>(value_cache.data_ptr()), \
  slot_mapping.data_ptr<int64_t>(), block_stride, key_stride, \
- value_stride, num_heads, head_size, block_size, k_scale, v_scale);
+ value_stride, num_heads, head_size, block_size, \
+ reinterpret_cast<const float*>(k_scale.data_ptr()), \
+ reinterpret_cast<const float*>(v_scale.data_ptr()));

  void reshape_and_cache_flash(
  torch::Tensor& key, // [num_tokens, num_heads, head_size]
@@ -308,8 +312,8 @@ void reshape_and_cache_flash(
  torch::Tensor&
  value_cache, // [num_blocks, block_size, num_heads, head_size]
  torch::Tensor& slot_mapping, // [num_tokens] or [num_actual_tokens]
- const std::string& kv_cache_dtype, const double k_scale,
- const double v_scale) {
+ const std::string& kv_cache_dtype, torch::Tensor& k_scale,
+ torch::Tensor& v_scale) {
  // NOTE(woosuk): In vLLM V1, key.size(0) can be different from
  // slot_mapping.size(0) because of padding for CUDA graphs.
  // In vLLM V0, key.size(0) is always equal to slot_mapping.size(0) because
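
Because the cache kernels above now dereference the scale when they run, the host side can rewrite the one-element tensor in place between steps and the next `reshape_and_cache` call sees the new value. A small illustrative helper (not part of the commit):

```cpp
// Illustrative helper, not from the diff: update the scale without
// reallocating, so any device pointer already taken from the tensor
// (including one captured in a CUDA graph) keeps pointing at live data.
#include <torch/torch.h>

void update_kv_scale_in_place(torch::Tensor& scale, float new_value) {
  // fill_ writes into the existing storage; data_ptr() does not change.
  scale.fill_(new_value);
}
```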

csrc/ops.h

Lines changed: 6 additions & 4 deletions
@@ -34,8 +34,9 @@ void paged_attention_v1(
  torch::Tensor& value_cache, int64_t num_kv_heads, double scale,
  torch::Tensor& block_tables, torch::Tensor& seq_lens, int64_t block_size,
  int64_t max_seq_len, const c10::optional<torch::Tensor>& alibi_slopes,
- const std::string& kv_cache_dtype, double k_scale, double v_scale,
- const int64_t tp_rank, const int64_t blocksparse_local_blocks,
+ const std::string& kv_cache_dtype, torch::Tensor& k_scale,
+ torch::Tensor& v_scale, const int64_t tp_rank,
+ const int64_t blocksparse_local_blocks,
  const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
  const int64_t blocksparse_head_sliding_step, const int64_t num_threads);

@@ -45,8 +46,9 @@ void paged_attention_v2(
  torch::Tensor& value_cache, int64_t num_kv_heads, double scale,
  torch::Tensor& block_tables, torch::Tensor& seq_lens, int64_t block_size,
  int64_t max_seq_len, const c10::optional<torch::Tensor>& alibi_slopes,
- const std::string& kv_cache_dtype, double k_scale, double v_scale,
- const int64_t tp_rank, const int64_t blocksparse_local_blocks,
+ const std::string& kv_cache_dtype, torch::Tensor& k_scale,
+ torch::Tensor& v_scale, const int64_t tp_rank,
+ const int64_t blocksparse_local_blocks,
  const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
  const int64_t blocksparse_head_sliding_step, const int64_t num_threads);

csrc/rocm/attention.cu

Lines changed: 12 additions & 8 deletions
@@ -236,7 +236,7 @@ __global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_kernel(
  scalar_t* __restrict__ out, // [num_seqs, num_heads, max_num_partitions,
  // head_size]
  OUTT* __restrict__ final_out, // [num_seqs, num_heads, head_size]
- int max_ctx_blocks, float k_scale, float v_scale,
+ int max_ctx_blocks, const float* k_scale_ptr, const float* v_scale_ptr,
  const float* __restrict__ fp8_out_scale_ptr) {
  constexpr int NWARPS = NUM_THREADS / WARP_SIZE;
  const int warpid = threadIdx.x / WARP_SIZE;
@@ -438,7 +438,7 @@ __global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_kernel(
  // Vlocalb8[h][b * BLOCK_SIZE / 8 + d] = v_ptrh8be[d];
  const _B8x8 Vlocalb8 = v_ptrh8be[d];
  Vlocal[h][b * BLOCK_SIZE / 8 + d] =
- scaled_convert_b8x8<scalar_t, KV_DTYPE>(Vlocalb8, v_scale);
+ scaled_convert_b8x8<scalar_t, KV_DTYPE>(Vlocalb8, *v_scale_ptr);
  }
  }
  }
@@ -448,7 +448,7 @@ __global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_kernel(
  #pragma unroll
  for (int d = 0; d < KHELOOP; d++) {
  Klocal[d] =
- scaled_convert_b8x8<scalar_t, KV_DTYPE>(Klocalb8[d], k_scale);
+ scaled_convert_b8x8<scalar_t, KV_DTYPE>(Klocalb8[d], *k_scale_ptr);
  }
  }

@@ -995,7 +995,7 @@ __global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_kernel(
  scalar_t* __restrict__ out, // [num_seqs, num_heads, max_num_partitions,
  // head_size]
  OUTT* __restrict__ final_out, // [num_seqs, num_heads, head_size]
- int max_ctx_blocks, float k_scale, float v_scale,
+ int max_ctx_blocks, const float* k_scale_ptr, const float* v_scale_ptr,
  const float* __restrict__ fp8_out_scale_ptr) {
  UNREACHABLE_CODE
  }
@@ -1026,7 +1026,7 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel(
  block_tables_ptr, context_lens_ptr, max_num_blocks_per_seq, \
  alibi_slopes_ptr, q_stride, kv_block_stride, kv_head_stride, \
  exp_sums_ptr, max_logits_ptr, tmp_out_ptr, out_ptr, max_ctx_blocks, \
- k_scale, v_scale, fp8_out_scale_ptr);
+ k_scale_ptr, v_scale_ptr, fp8_out_scale_ptr);

  #define LAUNCH_CUSTOM_REDUCTION(NPAR_LOOPS) \
  paged_attention_ll4mi_reduce_kernel<T, OUTT, HEAD_SIZE, HEAD_SIZE, \
@@ -1043,7 +1043,7 @@ void paged_attention_custom_launcher(
  torch::Tensor& value_cache, const int num_kv_heads, float scale,
  torch::Tensor& block_tables, torch::Tensor& context_lens,
  int max_context_len, const c10::optional<torch::Tensor>& alibi_slopes,
- float k_scale, float v_scale,
+ torch::Tensor& k_scale, torch::Tensor& v_scale,
  const c10::optional<torch::Tensor>& fp8_out_scale) {
  int num_seqs = query.size(0);
  int num_heads = query.size(1);
@@ -1068,6 +1068,9 @@ void paged_attention_custom_launcher(
  int* block_tables_ptr = block_tables.data_ptr<int>();
  int* context_lens_ptr = context_lens.data_ptr<int>();

+ const float* k_scale_ptr = reinterpret_cast<const float*>(k_scale.data_ptr());
+ const float* v_scale_ptr = reinterpret_cast<const float*>(v_scale.data_ptr());
+
  // NOTE: fp8_out_scale is optional.
  const float* fp8_out_scale_ptr =
  fp8_out_scale
@@ -1263,8 +1266,9 @@ void paged_attention(
  torch::Tensor& context_lens, // [num_seqs]
  int64_t block_size, int64_t max_context_len,
  const c10::optional<torch::Tensor>& alibi_slopes,
- const std::string& kv_cache_dtype, double k_scale, double v_scale,
- const c10::optional<torch::Tensor>& fp8_out_scale, int64_t partition_size) {
+ const std::string& kv_cache_dtype, torch::Tensor& k_scale,
+ torch::Tensor& v_scale, const c10::optional<torch::Tensor>& fp8_out_scale,
+ int64_t partition_size) {
  const int head_size = query.size(2);
  if (kv_cache_dtype == "auto") {
  if (query.dtype() == at::ScalarType::Half) {
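
The launchers above `reinterpret_cast` the scale tensors straight to `const float*`. A defensive check like the following could guard that cast against callers still passing CPU or non-fp32 tensors; it is a suggestion, not something present in this diff:

```cpp
// Suggested guard, not part of the commit: validate a scale tensor before
// its data pointer is reinterpreted as const float*.
#include <torch/torch.h>

inline void check_kv_scale(const torch::Tensor& scale, const char* name) {
  TORCH_CHECK(scale.scalar_type() == at::ScalarType::Float, name,
              " must be a float32 tensor");
  TORCH_CHECK(scale.numel() == 1, name, " must hold exactly one element");
  TORCH_CHECK(!scale.device().is_cpu(), name, " must live on the GPU");
}
```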
