diff --git a/lib_nn/src/c/vpu_sim.c b/lib_nn/src/c/vpu_sim.c index 709af474..05f31580 100644 --- a/lib_nn/src/c/vpu_sim.c +++ b/lib_nn/src/c/vpu_sim.c @@ -466,45 +466,54 @@ void VLSUB(xs3_vpu *vpu, const void *addr) { } } +static inline +unsigned _VLMUL_GET_SHIFT(const nn_target_arch_t arch, xs3_vpu *vpu) { + // VLMUL shift = bpe - 2 for XS3A, bpe - 1 for VX4A + assert(arch == TARGET_ARCH_XS3A || arch == TARGET_ARCH_VX4A); + unsigned shift = 0; + unsigned adj = (arch == TARGET_ARCH_XS3A) ? 0 : 1; + switch (vpu->mode) { + case MODE_S8: + shift = 8 - 2 + adj; + break; + case MODE_S16: + shift = 16 - 2 + adj; + break; + case MODE_S32: + shift = 32 - 2 + adj; + break; + default: + assert(0); // How'd this happen? + break; + } + return shift; +} + void VLMUL(xs3_vpu *vpu, const void *addr) { #ifdef __XS3A__ assert_word_aligned(addr); #endif - int VLMUL_SHR_S16; - if(NN_ARCH == TARGET_ARCH_XS3A){ - VLMUL_SHR_S16 = VLMUL_SHR_XS3A; - } else if (NN_ARCH == TARGET_ARCH_VX4A){ - VLMUL_SHR_S16 = VLMUL_SHR_VX4A; - } else { - assert(false); - } - + const unsigned shift = _VLMUL_GET_SHIFT(NN_ARCH, vpu); if (vpu->mode == MODE_S8) { const int8_t *addr8 = (const int8_t *)addr; for (int i = 0; i < VPU_INT8_EPV; i++) { int32_t val = addr8[i]; - int32_t res = ((int32_t)vpu->vR.s8[i] * val + (1<<5)) >> 6; // TODO use macros - if (NN_ARCH == TARGET_ARCH_VX4A){ - res = res >> 1; - } + int32_t res = ((int32_t)vpu->vR.s8[i] * (int32_t)val + (1L<<(shift - 1))) >> shift; vpu->vR.s8[i] = vpu_saturate(res, 8); } } else if (vpu->mode == MODE_S16) { const int16_t *addr16 = (const int16_t *)addr; - for (int i = 0; i < VPU_INT16_EPV; i++) { int64_t val = addr16[i]; - int64_t res = - ((int64_t)vpu->vR.s16[i] * (int64_t)val + (1LL<<(VLMUL_SHR_S16 - 1))) >> VLMUL_SHR_S16; // TODO use macros + int64_t res = ((int64_t)vpu->vR.s16[i] * (int64_t)val + (1LL<<(shift - 1))) >> shift; vpu->vR.s16[i] = vpu_saturate(res, 16); } } else if (vpu->mode == MODE_S32) { const int32_t *addr32 = (const int32_t *)addr; - for (int i = 0; i < VPU_INT32_EPV; i++) { int64_t val = addr32[i]; - int64_t res = (vpu->vR.s32[i] * val + (1<<29)) >> 30; // TODO use macros + int64_t res = ((int64_t)vpu->vR.s32[i] * (int64_t)val + (1LL<<(shift - 1))) >> shift; vpu->vR.s32[i] = vpu_saturate(res, 32); } } else {