@@ -36,11 +36,19 @@ static constexpr void manually_unroll_loop(F &&f) {
36
36
37
37
// Forward declaration only; no definition of MatMul appears in the visible
// code. It is instantiated as MatMul<TM, TN, TK> and passed as the template
// argument to get_sg_size<...>(q) in joint_matmul.
// NOTE(review): presumably a SYCL kernel-name tag keyed by tile shape - confirm.
template <size_t TM, size_t TN, size_t TK> class MatMul ;
38
38
39
- template <size_t rowsA, size_t colsA, size_t rowsB, size_t colsB,
39
+ template <
40
+ #if !defined(ARG_DIM) && !defined(RUNTIME_DIM)
41
+ size_t rowsA, size_t colsA, size_t rowsB, size_t colsB,
42
+ #endif // ARG_DIM, RUNTIME_DIM
40
43
size_t vnniFactor, typename TOperand, typename TResult, size_t TM,
41
44
size_t TN, size_t TK, size_t MCache1, size_t NCache1, size_t KCache1,
42
45
size_t MCache2, size_t NCache2, size_t KCache2>
43
- double joint_matmul (TOperand *A, TOperand *B, TResult *C, queue &q, int i) {
46
+ double joint_matmul (TOperand *A, TOperand *B, TResult *C, queue &q, int i
47
+ #if defined(ARG_DIM) || defined(RUNTIME_DIM)
48
+ , size_t rowsA, size_t colsA, size_t rowsB, size_t colsB
49
+ #endif // ARG_DIM, RUNTIME_DIM
50
+ ) {
51
+
44
52
size_t sgSize = get_sg_size<MatMul<TM, TN, TK>>(q);
45
53
range<2 > global{rowsA / MCache1, (colsB / NCache1) * sgSize};
46
54
range<2 > cachelocal{MCache2 / MCache1, NCache2 / NCache1 * sgSize};
@@ -287,8 +295,8 @@ double joint_matmul(TOperand *A, TOperand *B, TResult *C, queue &q, int i) {
287
295
#ifdef PREFETCH
288
296
auto prefetch_offsetA = (m2 * MCache2 + sgId * prefRow) * colsA +
289
297
(k2 + prefDistance) * prefCol;
290
- if ((prefetch_offsetA + (prefRow * MATRIX_SIZE ) + prefCol) <
291
- (MATRIX_SIZE * MATRIX_SIZE ))
298
+ if ((prefetch_offsetA + (prefRow * colsA ) + prefCol) <
299
+ (rowsA * colsA ))
292
300
joint_matrix_prefetch<prefRow, prefCol>(
293
301
sg, A + prefetch_offsetA, colsA, layout::row_major,
294
302
syclex::properties{syclex::prefetch_hint_L1});
@@ -298,8 +306,8 @@ double joint_matmul(TOperand *A, TOperand *B, TResult *C, queue &q, int i) {
298
306
pm1B * prefRow) *
299
307
(colsB)*vnniFactor +
300
308
(n2 * NCache2 * vnniFactor + pn1B * prefCol);
301
- if ((prefetch_offsetB + (prefRow * MATRIX_SIZE * vnniFactor) +
302
- prefCol) < (MATRIX_SIZE * MATRIX_SIZE ))
309
+ if ((prefetch_offsetB + (prefRow * colsB * vnniFactor) +
310
+ prefCol) < (rowsB * colsB ))
303
311
joint_matrix_prefetch<prefRow, prefCol>(
304
312
sg, B + prefetch_offsetB, colsB * vnniFactor,
305
313
layout::row_major,
@@ -349,31 +357,37 @@ double joint_matmul(TOperand *A, TOperand *B, TResult *C, queue &q, int i) {
349
357
template <typename T, typename TResult, size_t vnniFactor, size_t TM, size_t TN,
350
358
size_t TK, size_t MCache1, size_t NCache1, size_t KCache1,
351
359
size_t MCache2, size_t NCache2, size_t KCache2>
352
- void test () {
353
- assert (MATRIX_SIZE >= TM && MATRIX_SIZE >= TK && MATRIX_SIZE >= TN &&
360
+ void test (size_t matrix_size_input) {
361
+ #ifdef RUNTIME_DIM
362
+ size_t matrix_size = matrix_size_input;
363
+ #else
364
+ constexpr size_t matrix_size = MATRIX_SIZE;
365
+ #endif // RUNTIME_DIM
366
+
367
+ assert (matrix_size >= TM && matrix_size >= TK && matrix_size >= TN &&
354
368
" invalid matrix size" );
355
- assert ((MATRIX_SIZE % TM) == 0 && (MATRIX_SIZE % TN) == 0 &&
356
- (MATRIX_SIZE % TK) == 0 &&
369
+ assert ((matrix_size % TM) == 0 && (matrix_size % TN) == 0 &&
370
+ (matrix_size % TK) == 0 &&
357
371
" invalid matrix size detected: not a multiple of <TM,TN,TK>" );
358
372
359
373
std::cout << " Testing: " << TM << " x " << TN << " x " << TK
360
374
<< " [TM x TN x TK]" << std::endl;
361
375
362
376
queue q;
363
- T *A = malloc_shared<T>(MATRIX_SIZE * MATRIX_SIZE , q);
364
- T *B = malloc_shared<T>(MATRIX_SIZE * MATRIX_SIZE , q);
365
- TResult *C = malloc_shared<TResult>(MATRIX_SIZE * MATRIX_SIZE , q);
366
- TResult *refC = malloc_shared<TResult>(MATRIX_SIZE * MATRIX_SIZE , q);
377
+ T *A = malloc_shared<T>(matrix_size * matrix_size , q);
378
+ T *B = malloc_shared<T>(matrix_size * matrix_size , q);
379
+ TResult *C = malloc_shared<TResult>(matrix_size * matrix_size , q);
380
+ TResult *refC = malloc_shared<TResult>(matrix_size * matrix_size , q);
367
381
368
- matrix_rand<T>(MATRIX_SIZE, MATRIX_SIZE , A, T (1 ));
369
- matrix_rand<T>(MATRIX_SIZE, MATRIX_SIZE , B, T (1 ));
382
+ matrix_rand<T>(matrix_size, matrix_size , A, T (1 ));
383
+ matrix_rand<T>(matrix_size, matrix_size , B, T (1 ));
370
384
371
- matrix_multiply_ref<T, T, TResult, 1 >(A, B, refC, MATRIX_SIZE, MATRIX_SIZE ,
372
- MATRIX_SIZE );
385
+ matrix_multiply_ref<T, T, TResult, 1 >(A, B, refC, matrix_size, matrix_size ,
386
+ matrix_size );
373
387
374
388
#ifdef VNNI
375
- T *vnniB = malloc_shared<T>(MATRIX_SIZE * MATRIX_SIZE , q);
376
- matrix_vnni<T>(MATRIX_SIZE, MATRIX_SIZE , B, vnniB, vnniFactor);
389
+ T *vnniB = malloc_shared<T>(matrix_size * matrix_size , q);
390
+ matrix_vnni<T>(matrix_size, matrix_size , B, vnniB, vnniFactor);
377
391
free (B, q);
378
392
B = vnniB;
379
393
#endif
@@ -382,22 +396,31 @@ void test() {
382
396
double totalDuration = 0 ;
383
397
for (unsigned int i = 0 ; i < testIterations; i++) {
384
398
double duration =
385
- joint_matmul<MATRIX_SIZE, MATRIX_SIZE, MATRIX_SIZE, MATRIX_SIZE,
386
- vnniFactor, T, TResult, TM, TN, TK, MCache1, NCache1,
387
- KCache1, MCache2, NCache2, KCache2>(A, B, C, q, i);
399
+ joint_matmul<
400
+ #if !defined(ARG_DIM) && !defined(RUNTIME_DIM)
401
+ matrix_size, matrix_size, matrix_size, matrix_size,
402
+ #endif // ARG_DIM, RUNTIME_DIM
403
+ vnniFactor, T, TResult, TM, TN, TK, MCache1, NCache1,
404
+ KCache1, MCache2, NCache2, KCache2>
405
+ (A, B, C, q, i
406
+ #if defined(ARG_DIM) || defined(RUNTIME_DIM)
407
+ , matrix_size, matrix_size, matrix_size, matrix_size
408
+ #endif // ARG_DIM, RUNTIME_DIM
409
+ );
410
+
388
411
if (i >= recordThresh) {
389
412
totalDuration += duration;
390
413
}
391
414
}
392
415
393
- assert (matrix_compare (MATRIX_SIZE, MATRIX_SIZE , C, refC));
416
+ assert (matrix_compare (matrix_size, matrix_size , C, refC));
394
417
395
418
double msecPerMatrixMul =
396
419
totalDuration / static_cast <double >(testIterations - recordThresh);
397
- double gflops = (2 .f * MATRIX_SIZE * MATRIX_SIZE * MATRIX_SIZE * 1 .0e-9f ) /
420
+ double gflops = (2 .f * matrix_size * matrix_size * matrix_size * 1 .0e-9f ) /
398
421
(msecPerMatrixMul / 1000 .f );
399
422
400
- std::cout << " DONE for size " << MATRIX_SIZE << std::endl;
423
+ std::cout << " DONE for size " << matrix_size << std::endl;
401
424
std::cout << " GOPS is " << gflops << " Gop/s" << std::endl;
402
425
403
426
free (A, q);
@@ -406,7 +429,22 @@ void test() {
406
429
free (refC, q);
407
430
}
408
431
409
- int main () {
432
+ int main (
433
+ #ifdef RUNTIME_DIM
434
+ int argc, char *argv[]
435
+ #endif // RUNTIME_DIM
436
+ ) {
437
+
438
+ size_t matrix_size = -1 ;
439
+ #ifdef RUNTIME_DIM
440
+ if (argc == 2 ) {
441
+ matrix_size = std::stoul (argv[1 ]);
442
+ } else {
443
+ std::cerr << " Usage: ./program matrix_size\n " ;
444
+ return 1 ; // Error if no argument
445
+ }
446
+ #endif // RUNTIME_DIM
447
+
410
448
queue q;
411
449
std::vector<combination> combinations =
412
450
q.get_device ()
@@ -429,22 +467,22 @@ int main() {
429
467
constexpr size_t NCache1 = 32 ;
430
468
constexpr size_t KCache1 = 32 ;
431
469
test<bfloat16, float , VnniFactor, /* TM*/ 16 , /* TN*/ 16 , /* TK*/ 32 ,
432
- MCache1, NCache1, KCache1, MCache2, NCache2, KCache2>();
470
+ MCache1, NCache1, KCache1, MCache2, NCache2, KCache2>(matrix_size );
433
471
break ;
434
472
}
435
473
436
474
if (combinations[i].nsize == 16 ) { // architecture::intel_gpu_pvc
437
475
constexpr size_t NCache1 = 4 * /* TN*/ 16 ;
438
476
constexpr size_t KCache1 = 16 ;
439
477
test<bfloat16, float , VnniFactor, /* TM*/ 8 , /* TN*/ 16 , /* TK*/ 16 , MCache1,
440
- NCache1, KCache1, MCache2, NCache2, KCache2>();
478
+ NCache1, KCache1, MCache2, NCache2, KCache2>(matrix_size );
441
479
#if (!defined(SG_SZ) || SG_SZ != 32)
442
480
// These combinations are not currently supported for subgroup size = 32 in
443
481
// IGC
444
482
test<bfloat16, float , VnniFactor, /* TM*/ 16 , /* TN*/ 16 , /* TK*/ 16 ,
445
- MCache1, NCache1, KCache1, MCache2, NCache2, KCache2>();
483
+ MCache1, NCache1, KCache1, MCache2, NCache2, KCache2>(matrix_size );
446
484
test<bfloat16, float , VnniFactor, /* TM*/ 32 , /* TN*/ 64 , /* TK*/ 16 ,
447
- MCache1, NCache1, KCache1, MCache2, NCache2, KCache2>();
485
+ MCache1, NCache1, KCache1, MCache2, NCache2, KCache2>(matrix_size );
448
486
#endif
449
487
break ;
450
488
}
@@ -454,10 +492,9 @@ int main() {
454
492
constexpr size_t KCache1 = 16 ;
455
493
456
494
test<bfloat16, float , VnniFactor, /* TM*/ 8 , /* TN*/ 8 , /* TK*/ 16 , MCache1,
457
- NCache1, KCache1, MCache2, NCache2, KCache2>();
458
- // test<bfloat16, float, VnniFactor, /*TM*/ 32, /*TN*/ 32, /*TK*/ 16,
459
- // MCache1,
460
- // NCache1, KCache1, MCache2, NCache2, KCache2>();
495
+ NCache1, KCache1, MCache2, NCache2, KCache2>(matrix_size);
496
+ // test<bfloat16, float, VnniFactor, /*TM*/ 32, /*TN*/ 32, /*TK*/ 16, MCache1,
497
+ // NCache1, KCache1, MCache2, NCache2, KCache2>(matrix_size);
461
498
break ;
462
499
}
463
500
}
0 commit comments