Skip to content

Commit

Permalink
switch block read functions to the production names
Browse files (browse the repository at this point in the history)
Tiled kernels still need to be enabled and ported.
  • Loading branch information
bashbaug committed Feb 27, 2025
1 parent dddfdf3 commit 41159a8
Show file tree
Hide file tree
Showing 2 changed files with 50 additions and 26 deletions.
8 changes: 8 additions & 0 deletions samples/99_matrixexperimentsi8/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -443,6 +443,8 @@ static void i8_dpas_blockread_rowmajor(
cl::Kernel kernel{program, kernelName.c_str()};
if (kernel() == nullptr) {
printf("unsupported.\n");
} else if (K < 64 || N < 64) {
printf("matrix pitch for block reads must be >= 64 bytes.\n");
} else {
kernel.setArg(0, C);
kernel.setArg(1, A);
Expand Down Expand Up @@ -502,6 +504,8 @@ static void i8_dpas_blockread_rowmajor_tiled(
printf("M is too small.\n");
} else if (tN * NN > N) {
printf("N is too small.\n");
} else if (K < 64 || N < 64) {
printf("matrix pitch for block reads must be >= 64 bytes.\n");
} else {
kernel.setArg(0, C);
kernel.setArg(1, A);
Expand Down Expand Up @@ -555,6 +559,8 @@ static void i8_dpas_blockread_vnni(
cl::Kernel kernel{program, kernelName.c_str()};
if (kernel() == nullptr) {
printf("unsupported.\n");
} else if (K < 64 || N < 64/4) {
printf("matrix pitch for block reads must be >= 64 bytes.\n");
} else {
kernel.setArg(0, C);
kernel.setArg(1, A);
Expand Down Expand Up @@ -614,6 +620,8 @@ static void i8_dpas_blockread_vnni_tiled(
printf("M is too small.\n");
} else if (tN * NN > N) {
printf("N is too small.\n");
} else if (K < 64 || N < 64/4) {
printf("matrix pitch for block reads must be >= 64 bytes.\n");
} else {
kernel.setArg(0, C);
kernel.setArg(1, A);
Expand Down
68 changes: 42 additions & 26 deletions samples/99_matrixexperimentsi8/matrix_kernels_i8.cl
Original file line number Diff line number Diff line change
Expand Up @@ -380,7 +380,7 @@ kernel void i8_dpas_vnni_m8_n16(global int* C, global char* A, global char* B, i
store_c_rowmajor_int32_m8_nx(C, sum, m, n, N);
}

#ifdef cl_intel_subgroup_extended_block_read
#ifdef cl_intel_subgroup_2d_block_io

__attribute__((intel_reqd_sub_group_size(16))) __attribute__((reqd_work_group_size(16, 1, 1)))
kernel void i8_dpas_blockread_rowmajor_m1_n16(global int* C, global char* A, global char* B, int K)
Expand All @@ -395,13 +395,15 @@ kernel void i8_dpas_blockread_rowmajor_m1_n16(global int* C, global char* A, glo

int sum = 0;
for (int k = 0; k < K; k += tK) {
short aData = as_short(intel_subgroup_block_read_u8_m1k32(A, K * sizeof(char), M, K * sizeof(char), (int2)(k, m)));
int8 bData = as_int8(intel_subgroup_block_read_transform_u8_k32(B, N * sizeof(char), K, N * sizeof(char), (int2)(n, k)));
short aData;
intel_sub_group_2d_block_read_8b_1r32x1c(A, K * sizeof(char), M, K * sizeof(char), (int2)(k, m), (ushort*)&aData);
int8 bData;
intel_sub_group_2d_block_read_transform_8b_32r16x1c(B, N * sizeof(char), K, N * sizeof(char), (int2)(n, k), (uint*)&bData);
sum = mat_mul_sg16(aData, bData, sum);
}

sum = activation(sum);
intel_subgroup_block_write_u32_m1k16(C, N * sizeof(float), M, N * sizeof(float), (int2)(n, m), as_uint(sum));
intel_sub_group_2d_block_write_32b_1r16x1c(C, N * sizeof(float), M, N * sizeof(float), (int2)(n, m), (uint*)&sum);
}

__attribute__((intel_reqd_sub_group_size(16))) __attribute__((reqd_work_group_size(16, 1, 1)))
Expand All @@ -417,13 +419,15 @@ kernel void i8_dpas_blockread_rowmajor_m2_n16(global int* C, global char* A, glo

int2 sum = 0;
for (int k = 0; k < K; k += tK) {
short2 aData = as_short2(intel_subgroup_block_read_u8_m2k32(A, K * sizeof(char), M, K * sizeof(char), (int2)(k, m)));
int8 bData = as_int8(intel_subgroup_block_read_transform_u8_k32(B, N * sizeof(char), K, N * sizeof(char), (int2)(n, k)));
short2 aData;
intel_sub_group_2d_block_read_8b_2r32x1c(A, K * sizeof(char), M, K * sizeof(char), (int2)(k, m), (ushort*)&aData);
int8 bData;
intel_sub_group_2d_block_read_transform_8b_32r16x1c(B, N * sizeof(char), K, N * sizeof(char), (int2)(n, k), (uint*)&bData);
sum = mat_mul_sg16(aData, bData, sum);
}

sum = activation(sum);
intel_subgroup_block_write_u32_m2k16(C, N * sizeof(float), M, N * sizeof(float), (int2)(n, m), as_uint2(sum));
intel_sub_group_2d_block_write_32b_2r16x1c(C, N * sizeof(float), M, N * sizeof(float), (int2)(n, m), (uint*)&sum);
}

__attribute__((intel_reqd_sub_group_size(16))) __attribute__((reqd_work_group_size(16, 1, 1)))
Expand All @@ -439,13 +443,15 @@ kernel void i8_dpas_blockread_rowmajor_m4_n16(global int* C, global char* A, glo

int4 sum = 0;
for (int k = 0; k < K; k += tK) {
short4 aData = as_short4(intel_subgroup_block_read_u8_m4k32(A, K * sizeof(char), M, K * sizeof(char), (int2)(k, m)));
int8 bData = as_int8(intel_subgroup_block_read_transform_u8_k32(B, N * sizeof(char), K, N * sizeof(char), (int2)(n, k)));
short4 aData;
intel_sub_group_2d_block_read_8b_4r32x1c(A, K * sizeof(char), M, K * sizeof(char), (int2)(k, m), (ushort*)&aData);
int8 bData;
intel_sub_group_2d_block_read_transform_8b_32r16x1c(B, N * sizeof(char), K, N * sizeof(char), (int2)(n, k), (uint*)&bData);
sum = mat_mul_sg16(aData, bData, sum);
}

sum = activation(sum);
intel_subgroup_block_write_u32_m4k16(C, N * sizeof(float), M, N * sizeof(float), (int2)(n, m), as_uint4(sum));
intel_sub_group_2d_block_write_32b_4r16x1c(C, N * sizeof(float), M, N * sizeof(float), (int2)(n, m), (uint*)&sum);
}

__attribute__((intel_reqd_sub_group_size(16))) __attribute__((reqd_work_group_size(16, 1, 1)))
Expand All @@ -461,13 +467,15 @@ kernel void i8_dpas_blockread_rowmajor_m8_n16(global int* C, global char* A, glo

int8 sum = 0;
for (int k = 0; k < K; k += tK) {
short8 aData = as_short8(intel_subgroup_block_read_u8_m8k32(A, K * sizeof(char), M, K * sizeof(char), (int2)(k, m)));
int8 bData = as_int8(intel_subgroup_block_read_transform_u8_k32(B, N * sizeof(char), K, N * sizeof(char), (int2)(n, k)));
short8 aData;
intel_sub_group_2d_block_read_8b_8r32x1c(A, K * sizeof(char), M, K * sizeof(char), (int2)(k, m), (ushort*)&aData);
int8 bData;
intel_sub_group_2d_block_read_transform_8b_32r16x1c(B, N * sizeof(char), K, N * sizeof(char), (int2)(n, k), (uint*)&bData);
sum = mat_mul_sg16(aData, bData, sum);
}

sum = activation(sum);
intel_subgroup_block_write_u32_m8k16(C, N * sizeof(float), M, N * sizeof(float), (int2)(n, m), as_uint8(sum));
intel_sub_group_2d_block_write_32b_8r16x1c(C, N * sizeof(float), M, N * sizeof(float), (int2)(n, m), (uint*)&sum);
}

__attribute__((intel_reqd_sub_group_size(16))) __attribute__((reqd_work_group_size(16, 1, 1)))
Expand All @@ -483,13 +491,15 @@ kernel void i8_dpas_blockread_vnni_m1_n16(global int* C, global char* A, global

int sum = 0;
for (int k = 0; k < K; k += tK) {
short aData = as_short(intel_subgroup_block_read_u8_m1k32(A, K * sizeof(char), M, K * sizeof(char), (int2)(k, m)));
int8 bData = as_int8(intel_subgroup_block_read_u32_m8k16(B, N * sizeof(uint), K, N * sizeof(uint), (int2)(n, k / 4)));
short aData;
intel_sub_group_2d_block_read_8b_1r32x1c(A, K * sizeof(char), M, K * sizeof(char), (int2)(k, m), (ushort*)&aData);
int8 bData;
intel_sub_group_2d_block_read_32b_8r16x1c(B, N * sizeof(uint), K, N * sizeof(uint), (int2)(n, k / 4), (uint*)&bData);
sum = mat_mul_sg16(aData, bData, sum);
}

sum = activation(sum);
intel_subgroup_block_write_u32_m1k16(C, N * sizeof(float), M, N * sizeof(float), (int2)(n, m), as_uint(sum));
intel_sub_group_2d_block_write_32b_1r16x1c(C, N * sizeof(float), M, N * sizeof(float), (int2)(n, m), (uint*)&sum);
}

__attribute__((intel_reqd_sub_group_size(16))) __attribute__((reqd_work_group_size(16, 1, 1)))
Expand All @@ -505,13 +515,15 @@ kernel void i8_dpas_blockread_vnni_m2_n16(global int* C, global char* A, global

int2 sum = 0;
for (int k = 0; k < K; k += tK) {
short2 aData = as_short2(intel_subgroup_block_read_u8_m2k32(A, K * sizeof(char), M, K * sizeof(char), (int2)(k, m)));
int8 bData = as_int8(intel_subgroup_block_read_u32_m8k16(B, N * sizeof(uint), K, N * sizeof(uint), (int2)(n, k / 4)));
short2 aData;
intel_sub_group_2d_block_read_8b_2r32x1c(A, K * sizeof(char), M, K * sizeof(char), (int2)(k, m), (ushort*)&aData);
int8 bData;
intel_sub_group_2d_block_read_32b_8r16x1c(B, N * sizeof(uint), K, N * sizeof(uint), (int2)(n, k / 4), (uint*)&bData);
sum = mat_mul_sg16(aData, bData, sum);
}

sum = activation(sum);
intel_subgroup_block_write_u32_m2k16(C, N * sizeof(float), M, N * sizeof(float), (int2)(n, m), as_uint2(sum));
intel_sub_group_2d_block_write_32b_2r16x1c(C, N * sizeof(float), M, N * sizeof(float), (int2)(n, m), (uint*)&sum);
}

__attribute__((intel_reqd_sub_group_size(16))) __attribute__((reqd_work_group_size(16, 1, 1)))
Expand All @@ -527,13 +539,15 @@ kernel void i8_dpas_blockread_vnni_m4_n16(global int* C, global char* A, global

int4 sum = 0;
for (int k = 0; k < K; k += tK) {
short4 aData = as_short4(intel_subgroup_block_read_u8_m4k32(A, K * sizeof(char), M, K * sizeof(char), (int2)(k, m)));
int8 bData = as_int8(intel_subgroup_block_read_u32_m8k16(B, N * sizeof(uint), K, N * sizeof(uint), (int2)(n, k / 4)));
short4 aData;
intel_sub_group_2d_block_read_8b_4r32x1c(A, K * sizeof(char), M, K * sizeof(char), (int2)(k, m), (ushort*)&aData);
int8 bData;
intel_sub_group_2d_block_read_32b_8r16x1c(B, N * sizeof(uint), K, N * sizeof(uint), (int2)(n, k / 4), (uint*)&bData);
sum = mat_mul_sg16(aData, bData, sum);
}

sum = activation(sum);
intel_subgroup_block_write_u32_m4k16(C, N * sizeof(float), M, N * sizeof(float), (int2)(n, m), as_uint4(sum));
intel_sub_group_2d_block_write_32b_4r16x1c(C, N * sizeof(float), M, N * sizeof(float), (int2)(n, m), (uint*)&sum);
}

__attribute__((intel_reqd_sub_group_size(16))) __attribute__((reqd_work_group_size(16, 1, 1)))
Expand All @@ -549,16 +563,18 @@ kernel void i8_dpas_blockread_vnni_m8_n16(global int* C, global char* A, global

int8 sum = 0;
for (int k = 0; k < K; k += tK) {
short8 aData = as_short8(intel_subgroup_block_read_u8_m8k32(A, K * sizeof(char), M, K * sizeof(char), (int2)(k, m)));
int8 bData = as_int8(intel_subgroup_block_read_u32_m8k16(B, N * sizeof(uint), K, N * sizeof(uint), (int2)(n, k / 4)));
short8 aData;
intel_sub_group_2d_block_read_8b_8r32x1c(A, K * sizeof(char), M, K * sizeof(char), (int2)(k, m), (ushort*)&aData);
int8 bData;
intel_sub_group_2d_block_read_32b_8r16x1c(B, N * sizeof(uint), K, N * sizeof(uint), (int2)(n, k / 4), (uint*)&bData);
sum = mat_mul_sg16(aData, bData, sum);
}

sum = activation(sum);
intel_subgroup_block_write_u32_m8k16(C, N * sizeof(float), M, N * sizeof(float), (int2)(n, m), as_uint8(sum));
intel_sub_group_2d_block_write_32b_8r16x1c(C, N * sizeof(float), M, N * sizeof(float), (int2)(n, m), (uint*)&sum);
}

#endif // cl_intel_subgroup_extended_block_read
#endif // cl_intel_subgroup_2d_block_io

#if 0 // disable the tiled cases for now

Expand Down

0 comments on commit 41159a8

Please sign in to comment.