Skip to content

Commit dc8ca49

Browse files
committed
[software] Add flag to fold q16 MMSE
1 parent 218f1ef commit dc8ca49

File tree

3 files changed

+78
-36
lines changed

3 files changed

+78
-36
lines changed

software/apps/baremetal/mimo_mmse_q16/main.c

Lines changed: 59 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -22,18 +22,29 @@ Parameters and defines
2222
2323
PARALLEL: When defined, benchmark the parallel MIMO-MMSE.
2424
SINGLE: When defined, benchmark the single-core MIMO-MMSE.
25+
FOLD: When defined as 1, fold the matrices in memory.
2526
*/
2627

27-
int16_t l1_H[2 * N_TX * N_RX * N_ITR]
28-
__attribute__((aligned(BANKING_FACTOR * NUM_CORES * sizeof(int32_t)),
29-
section(".l1_prio")));
28+
#define FOLD (1)
29+
#define PARALLEL
30+
31+
#if FOLD
32+
#define NUM_ROW (1 + ((N_ITR * N_TX - 1) / NUM_BANKS))
33+
#define NUM_COL (NUM_BANKS / N_TX)
34+
35+
int16_t l1_G[2 * N_TX * NUM_BANKS * NUM_ROW]
36+
__attribute__((aligned(sizeof(int32_t)), section(".l1_prio")));
37+
int16_t l1_L[2 * N_TX * NUM_BANKS * NUM_ROW]
38+
__attribute__((aligned(sizeof(int32_t)), section(".l1_prio")));
39+
#else
3040
int16_t l1_G[2 * N_TX * N_TX * N_ITR]
31-
__attribute__((aligned(BANKING_FACTOR * NUM_CORES * sizeof(int32_t)),
32-
section(".l1_prio")));
41+
__attribute__((aligned(sizeof(int32_t)), section(".l1_prio")));
3342
int16_t l1_L[2 * N_TX * N_TX * N_ITR]
34-
__attribute__((aligned(BANKING_FACTOR * NUM_CORES * sizeof(int32_t)),
35-
section(".l1_prio")));
43+
__attribute__((aligned(sizeof(int32_t)), section(".l1_prio")));
44+
#endif
3645

46+
int16_t l1_H[2 * N_TX * N_RX * N_ITR]
47+
__attribute__((aligned(sizeof(int32_t)), section(".l1_prio")));
3748
int16_t l1_S[2 * N_TX * N_ITR]
3849
__attribute__((aligned(sizeof(int32_t)), section(".l1_prio")));
3950
int16_t l1_y[2 * N_RX * N_ITR]
@@ -51,12 +62,14 @@ int main() {
5162
uint32_t core_id = mempool_get_core_id();
5263
uint32_t num_cores = mempool_get_core_count();
5364
mempool_barrier_init(core_id); // Initialize barrier and synchronize
65+
uint32_t time_init, time_end;
5466

5567
/* Initialize matrices */
5668
if (core_id == 0) {
5769
dma_memcpy_blocking(l1_H, l2_H, N_TX * N_RX * N_ITR * sizeof(int32_t));
5870
dma_memcpy_blocking(l1_y, l2_y, N_RX * N_ITR * sizeof(int32_t));
5971
dma_memcpy_blocking(l1_S, l2_S, N_TX * N_ITR * sizeof(int32_t));
72+
printf("Data transferred\n");
6073
}
6174
mempool_barrier(num_cores);
6275

@@ -65,13 +78,18 @@ int main() {
6578

6679
if (core_id == 0) {
6780
mempool_start_benchmark();
68-
mempool_hermitian_q16vecs((v2s *)l1_H, (v2s *)l1_G, (v2s *)l1_Sigma, N_RX,
69-
N_TX);
70-
mempool_MVP_conjtransp_q16vecs((v2s *)l1_H, (v2s *)l1_y, (v2s *)y2, N_RX,
71-
N_TX, 0);
72-
mempool_cholesky_q16vecs(l1_G, l1_L, N_TX);
73-
mempool_Ltrisol_q16vecs(l1_L, y2, y3, N_TX, 0);
74-
mempool_Ltrisol_q16vecs(l1_L, y3, l1_x, N_TX, 1);
81+
time_init = mempool_get_timer();
82+
v2s *PtrH = (v2s *)l1_H;
83+
v2s *PtrG = (v2s *)l1_G;
84+
v2s *PtrS = (v2s *)l1_Sigma;
85+
v2s *Ptry = (v2s *)l1_y;
86+
v2s *Ptry2 = (v2s *)y2;
87+
mempool_hermitian_q16vecs(PtrH, PtrG, PtrS, N_RX, N_TX);
88+
mempool_MVP_conjtransp_q16vecs(PtrH, Ptry, Ptry2, N_RX, N_TX, FOLD);
89+
mempool_cholesky_q16vecs(l1_G, l1_L, N_TX, FOLD);
90+
mempool_Ltrisol_q16vecs(l1_L, y2, y3, N_TX, 0, FOLD);
91+
mempool_Ltrisol_q16vecs(l1_L, y3, l1_x, N_TX, 1, FOLD);
92+
time_end = mempool_get_timer();
7593
mempool_stop_benchmark();
7694
}
7795
mempool_barrier(num_cores);
@@ -81,30 +99,49 @@ int main() {
8199
#ifdef PARALLEL
82100

83101
mempool_start_benchmark();
102+
time_init = mempool_get_timer();
84103
for (uint32_t itr = core_id; itr < N_ITR; itr += num_cores) {
85104

86105
int16_t *PtrH = l1_H + itr * (2 * N_TX * N_RX);
87106
int16_t *Ptry = l1_y + itr * (2 * N_RX);
88-
int16_t *PtrSigma = l1_S + itr * (2 * N_TX);
89-
107+
int16_t *PtrS = l1_S + itr * (2 * N_TX);
108+
109+
#if FOLD
110+
int16_t *PtrG = l1_G + (itr / NUM_COL) * (2 * N_TX * NUM_BANKS) +
111+
(itr % NUM_COL) * (2 * N_TX);
112+
int16_t *PtrL = l1_L + (itr / NUM_COL) * (2 * N_TX * NUM_BANKS) +
113+
(itr % NUM_COL) * (2 * N_TX);
114+
int16_t *Ptry2 =
115+
y2 + (itr / NUM_COL) * (2 * NUM_BANKS) + (itr % NUM_COL) * (2 * N_TX);
116+
int16_t *Ptry3 =
117+
y3 + (itr / NUM_COL) * (2 * NUM_BANKS) + (itr % NUM_COL) * (2 * N_TX);
118+
int16_t *Ptrx = l1_x + itr * (2 * N_TX);
119+
#else
90120
int16_t *PtrG = l1_G + itr * (2 * N_TX * N_TX);
91121
int16_t *PtrL = l1_L + itr * (2 * N_TX * N_TX);
92122
int16_t *Ptry2 = y2 + itr * (2 * N_TX);
93123
int16_t *Ptry3 = y3 + itr * (2 * N_TX);
94124
int16_t *Ptrx = l1_x + itr * (2 * N_TX);
125+
#endif
95126

96-
mempool_hermitian_q16vecs((v2s *)PtrH, (v2s *)PtrG, (v2s *)PtrSigma, N_RX,
127+
mempool_hermitian_q16vecs((v2s *)PtrH, (v2s *)PtrG, (v2s *)PtrS, N_RX,
97128
N_TX);
98129
mempool_MVP_conjtransp_q16vecs((v2s *)PtrH, (v2s *)Ptry, (v2s *)Ptry2, N_RX,
99-
N_TX, 0);
100-
mempool_cholesky_q16vecs(PtrG, PtrL, N_TX);
101-
mempool_Ltrisol_q16vecs(PtrL, Ptry2, Ptry3, N_TX, 0);
102-
mempool_Ltrisol_q16vecs(PtrL, Ptry3, Ptrx, N_TX, 1);
130+
N_TX, FOLD);
131+
mempool_cholesky_q16vecs(PtrG, PtrL, N_TX, FOLD);
132+
mempool_Ltrisol_q16vecs(PtrL, Ptry2, Ptry3, N_TX, 0, FOLD);
133+
mempool_Ltrisol_q16vecs(PtrL, Ptry3, Ptrx, N_TX, 1, FOLD);
103134
}
104-
mempool_log_barrier(2, core_id);
135+
mempool_barrier(num_cores);
136+
time_end = mempool_get_timer();
105137
mempool_stop_benchmark();
106138

107139
#endif
108140

141+
if (core_id == 0) {
142+
printf("Runtime: %d\n", time_end - time_init);
143+
}
144+
mempool_barrier(num_cores);
145+
109146
return 0;
110147
}

software/kernels/baremetal/mempool_cholesky_q16s.h

Lines changed: 13 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -16,13 +16,15 @@
1616
@param[in] n dimension of the input data
1717
@return none
1818
*/
19-
void mempool_cholesky_q16vecs(int16_t *pSrc, int16_t *pL, const uint32_t n) {
19+
void mempool_cholesky_q16vecs(int16_t *pSrc, int16_t *pL, const uint32_t n,
20+
const uint32_t folded) {
2021

21-
uint32_t i, j, k;
2222
int32_t sum; // Sum for elements on diagonal (real)
2323
int32_t diag; // Diagonal element (real)
2424
int32_t as, bs; // Sum for elements on rows (complex)
2525
int32_t ap, bp; // Pivot elements (complex)
26+
uint32_t i, j, k;
27+
const uint32_t offset = folded ? NUM_BANKS : n;
2628

2729
v2s ab = (v2s){0, 0};
2830
v2s cd = (v2s){0, 0};
@@ -33,30 +35,30 @@ void mempool_cholesky_q16vecs(int16_t *pSrc, int16_t *pL, const uint32_t n) {
3335

3436
// Elements on diagonal (input matrix is positive-definite)
3537
sum = 0;
36-
diag = (int32_t)pSrc[2 * (j * n + j)];
38+
diag = (int32_t)pSrc[2 * (j * offset + j)];
3739
for (k = 0; k < j; k++) {
38-
ab = *(v2s *)&pL[2 * (j * n + k)];
40+
ab = *(v2s *)&pL[2 * (j * offset + k)];
3941
asm volatile("pv.dotsp.h %[sum], %[ab], %[ab];"
4042
"srai %[sum], %[sum], 0x8;"
4143
"p.clip %[sum], %[sum], 0x16;"
4244
: [sum] "+&r"(sum)
4345
: [ab] "r"(ab)
4446
:);
4547
}
46-
pL[2U * (j * n + j)] = (int16_t)mempool_sqrt_q32s(diag - sum, 16);
48+
pL[2U * (j * offset + j)] = (int16_t)mempool_sqrt_q32s(diag - sum, 16);
4749

4850
// Elements on rows
4951
for (i = j + 1; i < n; i++) {
50-
ap = (int32_t)pSrc[2 * (i * n + j)]; // Pivot
51-
bp = (int32_t)pSrc[2 * (i * n + j) + 1]; // Pivot
52-
diag = (int32_t)pL[2 * (j * n + j)]; // Diag
52+
ap = (int32_t)pSrc[2 * (i * offset + j)]; // Pivot
53+
bp = (int32_t)pSrc[2 * (i * offset + j) + 1]; // Pivot
54+
diag = (int32_t)pL[2 * (j * offset + j)]; // Diag
5355

5456
as = 0;
5557
bs = 0;
5658
// Sum -> s = s + (ac + bd) + j*(bc - ad)
5759
for (k = 0; k < j; k++) {
58-
ab = *(v2s *)&pL[2U * (i * n + k)];
59-
cd = *(v2s *)&pL[2U * (j * n + k)];
60+
ab = *(v2s *)&pL[2U * (i * offset + k)];
61+
cd = *(v2s *)&pL[2U * (j * offset + k)];
6062
const uint32_t shuffle_mask = 0x00020003;
6163
asm volatile(
6264
// s = s + (ac + bd) + j(bc - ad)
@@ -81,7 +83,7 @@ void mempool_cholesky_q16vecs(int16_t *pSrc, int16_t *pL, const uint32_t n) {
8183
: [ap] "+&r"(ap), [bp] "+&r"(bp), [res] "+&r"(res)
8284
: [as] "r"(as), [bs] "r"(bs), [diag] "r"(diag)
8385
:);
84-
(*(v2s *)&pL[2 * (i * n + j)]) = res;
86+
(*(v2s *)&pL[2 * (i * offset + j)]) = res;
8587
}
8688
}
8789
return;

software/kernels/baremetal/mempool_linearsolver_q16s.h

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,26 +17,29 @@
1717
*/
1818

1919
void mempool_Ltrisol_q16vecs(int16_t *pL, int16_t *y, int16_t *x,
20-
const uint32_t n, const uint32_t transposed) {
20+
const uint32_t n, const uint32_t transposed,
21+
const uint32_t folded) {
2122

2223
uint32_t i, j;
2324
int32_t as, bs, diag;
2425
v2s ab, cd;
2526
v2s res = (v2s){0, 0};
2627
v2s ndc = (v2s){0, 0};
28+
const uint32_t offset = folded ? NUM_BANKS : n;
2729

2830
// Solve for each variable x[i] in loop
2931
for (i = 0; i < n; i++) {
3032
uint32_t ridx = transposed ? (n - i - 1) : i;
31-
diag = pL[2U * (ridx + ridx)];
33+
diag = pL[2U * (ridx * offset + ridx)];
3234
// Initialize the sums
3335
as = 0;
3436
bs = 0;
3537
// Use the previously solved variables to compute the sum
3638
for (j = 0; j < i; j++) {
39+
3740
uint32_t cidx = transposed ? (n - j - 1) : j;
3841
if (!transposed) {
39-
ab = *(v2s *)&pL[2U * (ridx * n + cidx)];
42+
ab = *(v2s *)&pL[2U * (ridx * offset + cidx)];
4043
} else {
4144
ab = *(v2s *)&pL[2U * (cidx * n + ridx)];
4245
}

0 commit comments

Comments
 (0)