Skip to content

Commit 1ecd0d7

Browse files
committed
[ BLAS ] Implement transpose case functions for K=1 GEMM
- To cover transpose cases like, (1,M).T * (1,N) and all other transpose combinations, transpose with SIMD, and apply the original kernel **Self evaluation:** 1. Build test: [X]Passed [ ]Failed [ ]Skipped 2. Run test: [X]Passed [ ]Failed [ ]Skipped Signed-off-by: skykongkong8 <ss.kong@samsung.com>
1 parent 7abcbd3 commit 1ecd0d7

File tree

4 files changed

+144
-9
lines changed

4 files changed

+144
-9
lines changed

nntrainer/tensor/blas_neon.cpp

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1589,9 +1589,7 @@ unsigned int isamax(const unsigned int N, const __fp16 *X) {
15891589
void hgemm(const __fp16 *A, const __fp16 *B, __fp16 *C, uint32_t M, uint32_t N,
15901590
uint32_t K, float alpha, float beta, bool TransA, bool TransB) {
15911591
if (K == 1) {
1592-
unsigned int lda = (TransA) ? M : K;
1593-
unsigned int ldb = (TransB) ? K : N;
1594-
return hgemm_K1(M, N, K, A, lda, B, ldb, C, N, alpha, beta);
1592+
return hgemm_K1(A, B, C, M, N, K, alpha, beta, TransA, TransB);
15951593
}
15961594
// dynamic creation to avoid reaching stack limit(causes segmentation fault)
15971595
float *C32 = (float *)malloc(M * N * sizeof(float));
@@ -1644,6 +1642,22 @@ void hgemm(const __fp16 *A, const __fp16 *B, __fp16 *C, uint32_t M, uint32_t N,
16441642
free(C32);
16451643
}
16461644

1645+
void hgemm_K1(const __fp16 *A, const __fp16 *B, __fp16 *C, uint32_t M,
1646+
uint32_t N, uint32_t K, float alpha, float beta, bool TransA,
1647+
bool TransB) {
1648+
unsigned int lda = (TransA) ? M : K;
1649+
unsigned int ldb = (TransB) ? K : N;
1650+
if (!TransA && TransB) {
1651+
hgemm_K1_TransB(M, N, K, A, lda, B, ldb, C, N, alpha, beta);
1652+
} else if (TransA && !TransB) {
1653+
hgemm_K1_TransA(M, N, K, A, lda, B, ldb, C, N, alpha, beta);
1654+
} else if (!TransA && !TransB) {
1655+
hgemm_K1_noTrans(M, N, K, A, lda, B, ldb, C, N, alpha, beta);
1656+
} else { // TransA && TransB
1657+
hgemm_K1_TransAB(M, N, K, A, lda, B, ldb, C, N, alpha, beta);
1658+
}
1659+
}
1660+
16471661
void ele_mul(const unsigned int N, const __fp16 *X, const __fp16 *Y, __fp16 *Z,
16481662
float alpha, float beta) {
16491663
unsigned int i = 0;

nntrainer/tensor/blas_neon.h

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -330,6 +330,22 @@ unsigned int isamax(const unsigned int N, const __fp16 *X);
330330
void hgemm(const __fp16 *A, const __fp16 *B, __fp16 *C, uint32_t M, uint32_t N,
331331
uint32_t K, float alpha, float beta, bool TransA, bool TransB);
332332

333+
/**
334+
* @brief hgemm computation with neon : Y = alpha*op(A)*op(B) + beta*C,
335+
* where op(X) is one of X or X**T
336+
* @param[in] A __fp16 * for Matrix A
337+
* @param[in] B __fp16 * for Matrix B
338+
* @param[in] C __fp16 * for Matrix C
339+
* @param[in] M number of op(A)'s and C's row
340+
* @param[in] N number of op(B)'s and C's columns
341+
* @param[in] K number of op(A)'s and columns and op(B)'s rows
342+
* @param[in] alpha float number
343+
* @param[in] beta float number
344+
*/
345+
void hgemm_K1(const __fp16 *A, const __fp16 *B, __fp16 *C, uint32_t M,
346+
uint32_t N, uint32_t K, float alpha, float beta, bool TransA,
347+
bool TransB);
348+
333349
/**
334350
* @brief squared root transformation with neon : X = sqrt(X)
335351
*

nntrainer/tensor/hgemm/hgemm.cpp

Lines changed: 56 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -76,23 +76,74 @@ void hgemm_noTrans(const __fp16 *A, const __fp16 *B, __fp16 *C, unsigned int M,
7676
}
7777
}
7878

79-
void hgemm_K1(unsigned int M, unsigned int N, unsigned int K,
79+
void hgemm_K1_noTrans(unsigned int M, unsigned int N, unsigned int K,
8080
const __fp16 *A, unsigned int lda, const __fp16 *B,
8181
unsigned int ldb, __fp16 *C, unsigned int ldc,
8282
float alpha, float beta) {
83+
const float eps = std::numeric_limits<float>::epsilon();
8384
float16x8_t a_vec;
8485
unsigned int N8 = (N >> 3) << 3;
8586
for (unsigned int m = 0; m < M; ++m) {
86-
a_vec = vmovq_n_f16(A[m]);
87-
for (unsigned int n = 0; n < N8; n += 8) {
88-
vst1q_f16(&C[m * ldc + n], vmulq_f16(a_vec, vld1q_f16(&B[n])));
87+
a_vec = vmovq_n_f16(alpha * A[m]);
88+
if (std::fpclassify(beta) != FP_ZERO) {
89+
for (unsigned int n = 0; n < N8; n += 8) {
90+
vst1q_f16(&C[m * ldc + n],
91+
vaddq_f16(vmulq_f16(a_vec, vld1q_f16(&B[n])),
92+
vmulq_n_f16(vld1q_f16(&C[m * ldc + n]), beta)));
93+
}
94+
} else {
95+
for (unsigned int n = 0; n < N8; n += 8) {
96+
vst1q_f16(&C[m * ldc + n], vmulq_f16(a_vec, vld1q_f16(&B[n])));
97+
}
8998
}
9099
for (unsigned int n = N8; n < N; ++n) {
91-
C[m * ldc + n] = A[m] * B[n];
100+
C[m * ldc + n] = alpha * A[m] * B[n] + beta * C[m * ldc + n];
92101
}
93102
}
94103
}
95104

105+
void hgemm_K1_transA(unsigned int M, unsigned int N, unsigned int K,
106+
const __fp16 *A, unsigned int lda, const __fp16 *B,
107+
unsigned int ldb, __fp16 *C, unsigned int ldc, float alpha,
108+
float beta) {
109+
__fp16 *A_T = new __fp16[M * K];
110+
111+
transpose_neon<__fp16>(K, M, A, M, A_T, K);
112+
113+
hgemm_K1_noTrans(M, N, K, A_T, lda, B, ldb, C, ldc, alpha, beta);
114+
115+
free(A_T);
116+
}
117+
118+
void hgemm_K1_transB(unsigned int M, unsigned int N, unsigned int K,
119+
const __fp16 *A, unsigned int lda, const __fp16 *B,
120+
unsigned int ldb, __fp16 *C, unsigned int ldc, float alpha,
121+
float beta) {
122+
__fp16 *B_T = new __fp16[K * N];
123+
124+
transpose_neon<__fp16>(N, K, B, K, B_T, N);
125+
126+
hgemm_K1_noTrans(M, N, K, A, lda, B_T, ldb, C, ldc, alpha, beta);
127+
128+
free(B_T);
129+
}
130+
131+
void hgemm_K1_transAB(unsigned int M, unsigned int N, unsigned int K,
132+
const __fp16 *A, unsigned int lda, const __fp16 *B,
133+
unsigned int ldb, __fp16 *C, unsigned int ldc,
134+
float alpha, float beta) {
135+
__fp16 *A_T = new __fp16[M * K];
136+
__fp16 *B_T = new __fp16[K * N];
137+
138+
transpose_neon<__fp16>(K, M, A, M, A_T, K);
139+
transpose_neon<__fp16>(N, K, B, K, B_T, N);
140+
141+
hgemm_K1_noTrans(M, N, K, A_T, lda, B_T, ldb, C, ldc, alpha, beta);
142+
143+
free(A_T);
144+
free(B_T);
145+
}
146+
96147
void hgemm_noTrans_1x4(unsigned int M, unsigned int N, unsigned int K,
97148
const __fp16 *A, unsigned int lda, const __fp16 *B,
98149
unsigned int ldb, __fp16 *C, unsigned int ldc,

nntrainer/tensor/hgemm/hgemm.h

Lines changed: 55 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,61 @@ void hgemm_noTrans_fallback(unsigned int M, unsigned int N, unsigned int K,
7575
* @param[in] alpha float number
7676
* @param[in] beta float number
7777
*/
78-
void hgemm_K1(unsigned int M, unsigned int N, unsigned int K,
78+
void hgemm_K1_noTrans(unsigned int M, unsigned int N, unsigned int K,
79+
const __fp16 *A, unsigned int lda, const __fp16 *B,
80+
unsigned int ldb, __fp16 *C, unsigned int ldc,
81+
float alpha = 1.F, float beta = 0.F);
82+
/**
83+
* @brief hgemm fallback with neon : Y = alpha*op(A)*op(B) + beta*C,
84+
* @param M length of the row of matrix A
85+
* @param N length of the col of matrix B
86+
* @param K length of the col of matrix A
87+
* @param A input matrix A
88+
* @param lda length of the col of matrix A
89+
* @param B input matrix B
90+
* @param ldb length of the col of matrix B
91+
* @param C output matrix C
92+
* @param ldc length of the col of matrix C
93+
* @param[in] alpha float number
94+
* @param[in] beta float number
95+
*/
96+
void hgemm_K1_transA(unsigned int M, unsigned int N, unsigned int K,
97+
const __fp16 *A, unsigned int lda, const __fp16 *B,
98+
unsigned int ldb, __fp16 *C, unsigned int ldc,
99+
float alpha = 1.F, float beta = 0.F);
100+
/**
101+
* @brief hgemm fallback with neon : Y = alpha*op(A)*op(B) + beta*C,
102+
* @param M length of the row of matrix A
103+
* @param N length of the col of matrix B
104+
* @param K length of the col of matrix A
105+
* @param A input matrix A
106+
* @param lda length of the col of matrix A
107+
* @param B input matrix B
108+
* @param ldb length of the col of matrix B
109+
* @param C output matrix C
110+
* @param ldc length of the col of matrix C
111+
* @param[in] alpha float number
112+
* @param[in] beta float number
113+
*/
114+
void hgemm_K1_transB(unsigned int M, unsigned int N, unsigned int K,
115+
const __fp16 *A, unsigned int lda, const __fp16 *B,
116+
unsigned int ldb, __fp16 *C, unsigned int ldc,
117+
float alpha = 1.F, float beta = 0.F);
118+
/**
119+
* @brief hgemm fallback with neon : Y = alpha*op(A)*op(B) + beta*C,
120+
* @param M length of the row of matrix A
121+
* @param N length of the col of matrix B
122+
* @param K length of the col of matrix A
123+
* @param A input matrix A
124+
* @param lda length of the col of matrix A
125+
* @param B input matrix B
126+
* @param ldb length of the col of matrix B
127+
* @param C output matrix C
128+
* @param ldc length of the col of matrix C
129+
* @param[in] alpha float number
130+
* @param[in] beta float number
131+
*/
132+
void hgemm_K1_transAB(unsigned int M, unsigned int N, unsigned int K,
79133
const __fp16 *A, unsigned int lda, const __fp16 *B,
80134
unsigned int ldb, __fp16 *C, unsigned int ldc,
81135
float alpha = 1.F, float beta = 0.F);

0 commit comments

Comments
 (0)