Skip to content

Commit

Permalink
Refactor indexing
Browse files Browse the repository at this point in the history
  • Loading branch information
Aba committed Oct 27, 2023
1 parent 5274542 commit aa18cb5
Showing 1 changed file with 28 additions and 38 deletions.
66 changes: 28 additions & 38 deletions c/runtime.h
Original file line number Diff line number Diff line change
Expand Up @@ -40,24 +40,20 @@ static inline int quant_lrelu(int x, signed char nzero, signed char shift, signe
return x;
}

static inline void write_x(signed char val, int ib, int ixp, int ixn, int ixl, int ixw, int ixcm, int ixr, Bundle_t *p_bo, int X_CMP){

int idx_n2r = ixn * (p_bo->l * p_bo->w * X_CMP * (PE_ROWS+X_PAD))
+ ixl * ( p_bo->w * X_CMP * (PE_ROWS+X_PAD))
+ ixw * ( X_CMP * (PE_ROWS+X_PAD))
+ ixcm * ( (PE_ROWS+X_PAD))
+ ixr;
int idx = (ixp == 0) ? idx_n2r : ( p_bo->n * p_bo->l * p_bo->w * p_bo->cm_p0 * (PE_ROWS+X_PAD)
+ (ixp-1) * p_bo->n * p_bo->l * p_bo->w * p_bo->cm * (PE_ROWS+X_PAD)
+ idx_n2r);
mem.nx[idx] = val;

if (!(ixr < PE_ROWS+X_PAD)) assert(0*printf("%d >= %d --------- ib:%d ixp:%d ixn:%d ixl:%d ixw:%d ixcm:%d ixr:%d X_CMP:%d \n", ixr, PE_ROWS+X_PAD, ib,ixp,ixn,ixl,ixw,ixcm,ixr,X_CMP));
if (!(ixcm < X_CMP )) assert(0*printf("%d >= %d --------- ib:%d ixp:%d ixn:%d ixl:%d ixw:%d ixcm:%d ixr:%d X_CMP:%d \n", ixcm, X_CMP, ib,ixp,ixn,ixl,ixw,ixcm,ixr,X_CMP));
if (!(ixw < p_bo->w )) assert(0*printf("%d >= %d --------- ib:%d ixp:%d ixn:%d ixl:%d ixw:%d ixcm:%d ixr:%d X_CMP:%d \n", ixw , p_bo->w, ib,ixp,ixn,ixl,ixw,ixcm,ixr,X_CMP));
if (!(ixl < p_bo->l )) assert(0*printf("%d >= %d --------- ib:%d ixp:%d ixn:%d ixl:%d ixw:%d ixcm:%d ixr:%d X_CMP:%d \n", ixl , p_bo->l, ib,ixp,ixn,ixl,ixw,ixcm,ixr,X_CMP));
if (!(ixn < p_bo->n )) assert(0*printf("%d >= %d --------- ib:%d ixp:%d ixn:%d ixl:%d ixw:%d ixcm:%d ixr:%d X_CMP:%d \n", ixn , p_bo->n, ib,ixp,ixn,ixl,ixw,ixcm,ixr,X_CMP));
if (!(ixp < p_bo->p )) assert(0*printf("%d >= %d --------- ib:%d ixp:%d ixn:%d ixl:%d ixw:%d ixcm:%d ixr:%d X_CMP:%d \n", ixp , p_bo->p, ib,ixp,ixn,ixl,ixw,ixcm,ixr,X_CMP));
/*
 * write_x: store `val` into mem.nx at the flat position addressed by the
 * coordinates (ixp, ixn, ixl, ixw, ixcm, ixr).
 *
 *   val                       value to store
 *   ib                        only used in the diagnostic messages below
 *   ixp                       selects the base offset: 0 for ixp == 0, otherwise
 *                             (cm_p0 + (ixp-1)*cm) * n*l*w*(PE_ROWS+X_PAD)
 *   ixn, ixl, ixw, ixcm, ixr  indices flattened in [n,l,w,cm,r] order
 *   p_bo                      supplies n, l, w, p, cm and cm_p0
 *   xcm                       size of the cm dimension used for flattening
 *
 * Every index is compared against its upper bound before the store. With
 * assertions enabled, a violation prints all coordinates and fails assert().
 */
static inline void write_x(signed char val, int ib, int ixp, int ixn, int ixl, int ixw, int ixcm, int ixr, Bundle_t *p_bo, int xcm ){

  // Base offset selected by ixp (stays 0 for ixp == 0)
  int slice_base = 0;
  if (ixp != 0) {
    slice_base = (p_bo->cm_p0 + (ixp-1)*p_bo->cm) *p_bo->n*p_bo->l*p_bo->w*(PE_ROWS+X_PAD);
  }

  // multidim_index -> flat_index [n,l,w,cm,r], folded one dimension at a time
  int offset_n2r = ixn;
  offset_n2r = offset_n2r*p_bo->l + ixl;
  offset_n2r = offset_n2r*p_bo->w + ixw;
  offset_n2r = offset_n2r*xcm + ixcm;
  offset_n2r = offset_n2r*(PE_ROWS+X_PAD) + ixr;

  // Upper-bound checks, innermost dimension first; 0*printf(...) is 0, so the assert always fires after printing
  if (ixr >= PE_ROWS+X_PAD) {
    assert(0*printf("ixr : %d >= %d --------- ib:%d ixp:%d ixn:%d ixl:%d ixw:%d ixcm:%d ixr:%d xcm :%d \n", ixr, PE_ROWS+X_PAD, ib,ixp,ixn,ixl,ixw,ixcm,ixr,xcm ));
  }
  if (ixcm >= xcm) {
    assert(0*printf("ixcm: %d >= %d --------- ib:%d ixp:%d ixn:%d ixl:%d ixw:%d ixcm:%d ixr:%d xcm :%d \n", ixcm, xcm , ib,ixp,ixn,ixl,ixw,ixcm,ixr,xcm ));
  }
  if (ixw >= p_bo->w) {
    assert(0*printf("ixw : %d >= %d --------- ib:%d ixp:%d ixn:%d ixl:%d ixw:%d ixcm:%d ixr:%d xcm :%d \n", ixw , p_bo->w, ib,ixp,ixn,ixl,ixw,ixcm,ixr,xcm ));
  }
  if (ixl >= p_bo->l) {
    assert(0*printf("ixl : %d >= %d --------- ib:%d ixp:%d ixn:%d ixl:%d ixw:%d ixcm:%d ixr:%d xcm :%d \n", ixl , p_bo->l, ib,ixp,ixn,ixl,ixw,ixcm,ixr,xcm ));
  }
  if (ixn >= p_bo->n) {
    assert(0*printf("ixn : %d >= %d --------- ib:%d ixp:%d ixn:%d ixl:%d ixw:%d ixcm:%d ixr:%d xcm :%d \n", ixn , p_bo->n, ib,ixp,ixn,ixl,ixw,ixcm,ixr,xcm ));
  }
  if (ixp >= p_bo->p) {
    assert(0*printf("ixp : %d >= %d --------- ib:%d ixp:%d ixn:%d ixl:%d ixw:%d ixcm:%d ixr:%d xcm :%d \n", ixp , p_bo->p, ib,ixp,ixn,ixl,ixw,ixcm,ixr,xcm ));
  }

  mem.nx[slice_base + offset_n2r] = val;
}


Expand Down Expand Up @@ -132,7 +128,7 @@ extern EXT_C void load_y (unsigned char *p_done, unsigned char *pt_done_proc, c

if (ib == N_BUNDLES-1){ // Last bundle: save as nhwc in out buffer

int idx = (p_bundle->h*p_bundle->w*p_bundle->co)* i_yn + (p_bundle->w*p_bundle->co)* i_yh + (p_bundle->co)* i_yw + i_yc;
int idx = ((i_yn*p_bundle->h + i_yh)*p_bundle->w + i_yw)*p_bundle->co + i_yc;
mem.y[idx] = out_val;

} else {
Expand All @@ -141,10 +137,10 @@ extern EXT_C void load_y (unsigned char *p_done, unsigned char *pt_done_proc, c
int i_xn, i_xh, i_xw, i_xc;

if (p_bundle->conv2dense){
i_xn = 0 ; // N=1
i_xh = i_yn; // N -> H
i_xw = 0 ; // W=1
i_xc = (p_bundle->w*p_bundle->co)* i_yh + (p_bundle->co)* i_yw + i_yc; // (H*W*C) -> C
i_xn = 0 ; // N=1
i_xh = i_yn; // N -> H
i_xw = 0 ; // W=1
i_xc = (i_yh*p_bundle->w + i_yw)*p_bundle->co + i_yc; // (H*W*C) -> C
} else {
i_xn = i_yn;
i_xh = i_yh;
Expand All @@ -154,40 +150,34 @@ extern EXT_C void load_y (unsigned char *p_done, unsigned char *pt_done_proc, c

// Calc x coordinates: [p, n, l, w,cmp, r+pad]
Bundle_t *p_bo = ib == N_BUNDLES-1 ? &bundles[ib] : &bundles[ib+1];
char xp_first = i_xc < p_bo->cm_p0;

int i_xr = i_xh % PE_ROWS;
int i_xl = i_xh / PE_ROWS;
int i_xp, i_xcm, X_CMP;

if (i_xc < p_bo->cm_p0) { // first xp
i_xp = 0;
i_xcm = i_xc;
X_CMP = p_bo->cm_p0;
} else { // following xps
i_xp = (i_xc - p_bo->cm_p0) / p_bo->cm + 1;
i_xcm = (i_xc - p_bo->cm_p0) % p_bo->cm;
X_CMP = p_bo->cm;
}

int i_xp = xp_first ? 0 : (i_xc - p_bo->cm_p0) / p_bo->cm + 1;
int i_xcm = xp_first ? i_xc : (i_xc - p_bo->cm_p0) % p_bo->cm ;
int xcm = xp_first ? p_bo->cm_p0 : p_bo->cm ;

// ------ STORE ------

write_x(out_val, ib, i_xp, i_xn, i_xl, i_xw, i_xcm, i_xr, p_bo, X_CMP);
write_x(out_val, ib, i_xp, i_xn, i_xl, i_xw, i_xcm, i_xr, p_bo, xcm);

// --- PADDING: fill the [bottom X_PAD rows of previous block (l-1)] with [first X_PAD rows of this block (l)]
if (i_xr < X_PAD) {
int pad_val = (i_xl == 0) ? 0 : out_val;
int dest_xl = (i_xl == 0) ? p_bo->l-1 : i_xl-1;
write_x(pad_val, ib, i_xp, i_xn, dest_xl, i_xw, i_xcm, i_xr+PE_ROWS, p_bo, X_CMP);
write_x(pad_val, ib, i_xp, i_xn, dest_xl, i_xw, i_xcm, i_xr+PE_ROWS, p_bo, xcm);
}

// --- PADDING: L*PE_ROWS-H rows with zeros, and pad their other blocks accordingly
if ((i_xl == p_bo->l-1) && (i_xr == p_bo->r_ll-1)) {
for (int ir_hpad = p_bo->r_ll; ir_hpad < PE_ROWS; ir_hpad++){
write_x(0, ib, i_xp, i_xn, i_xl, i_xw, i_xcm, ir_hpad, p_bo, X_CMP);
write_x(0, ib, i_xp, i_xn, i_xl, i_xw, i_xcm, ir_hpad, p_bo, xcm);

if (ir_hpad < X_PAD) {
int dest_xl = (i_xl == 0) ? p_bo->l-1 : i_xl-1;
write_x(0, ib, i_xp, i_xn, dest_xl, i_xw, i_xcm, ir_hpad+PE_ROWS, p_bo, X_CMP);
write_x(0, ib, i_xp, i_xn, dest_xl, i_xw, i_xcm, ir_hpad+PE_ROWS, p_bo, xcm);
}
}
}
Expand Down

0 comments on commit aa18cb5

Please sign in to comment.