diff --git a/deepsocflow/py/dataflow.py b/deepsocflow/py/dataflow.py index e5a3053..8b1c623 100644 --- a/deepsocflow/py/dataflow.py +++ b/deepsocflow/py/dataflow.py @@ -107,12 +107,9 @@ def pack_bits(arr, total): packed |= val << sum_width sum_width += width assert sum_width <= total, f"Number of total packed bits {sum_width} is more than input DMA width {total}" - packed_le = np.array([packed],dtype=np.uint64) - packed_be = np.frombuffer(packed_le.tobytes(), dtype=np.dtype(np.uint64).newbyteorder('>')) - return packed_le, packed_be # np.arrays + return np.array([packed],dtype=np.uint64)[0] - d = {'w_header_le_p':[],'w_header_be_p':[]} - + d = {} d['header'] = pack_bits([ (r.KW//2 , hw.BITS_KW2), (r.XW-1 , hw.BITS_COLS_MAX), @@ -122,24 +119,7 @@ def pack_bits(arr, total): (r.XN-1 , hw.BITS_XN_MAX), (hw.CONFIG_BEATS + r.KH*r.CM_0-1, hw.BITS_RAM_WEIGHTS_ADDR), (hw.CONFIG_BEATS + r.KH*r.CM-1, hw.BITS_RAM_WEIGHTS_ADDR), - ], hw.HEADER_WIDTH)[0] # little endian - - for ip in range(min(2, r.CP)): - CM_p = r.CM_0 if ip==0 else r.CM - print(f'headers: ip={ip}, CM_p={CM_p}') - - ''' Weights Config''' - - w_header_le, w_header_be = pack_bits([ - (r.KW//2, hw.BITS_KW2), - (CM_p-1 , hw.BITS_CIN_MAX), - (r.XW-1 , hw.BITS_COLS_MAX), - (r.XL-1 , hw.BITS_BLOCKS_MAX), - (r.XN-1 , hw.BITS_XN_MAX), - (hw.CONFIG_BEATS + r.KH*CM_p-1, hw.BITS_RAM_WEIGHTS_ADDR) - ], hw.AXI_WIDTH-1) - d['w_header_le_p'] += [w_header_le] - d['w_header_be_p'] += [w_header_be] + ], hw.HEADER_WIDTH) n = namedtuple('Runtime', d)(**d) diff --git a/deepsocflow/py/xmodel.py b/deepsocflow/py/xmodel.py index fdbf545..ca69af5 100644 --- a/deepsocflow/py/xmodel.py +++ b/deepsocflow/py/xmodel.py @@ -156,8 +156,8 @@ def export_inference(model, hw): for ib, b in enumerate(BUNDLES): assert ib == b.ib - w_bpt = (hw.K_BITS*b.we[-1][0].size + hw.AXI_WIDTH)//8 - w_bpt_p0 = (hw.K_BITS*b.we[0][0].size + hw.AXI_WIDTH )//8 + w_bpt = (hw.K_BITS*b.we[-1][0].size)//8 + w_bpt_p0 = (hw.K_BITS*b.we[0][0].size)//8 x_bpt = (hw.X_BITS*b.xe[-1].size)//8 x_bpt_p0 = (hw.X_BITS*b.xe[0].size )//8 @@ -223,7 +223,7 @@ def export_inference(model, hw): ch.write( f".ca_nzero={ca_nzero:<3}, .ca_shift={ca_shift:<3}, .ca_pl_scale={ca_pl_scale:<3}, .aa_nzero={aa_nzero:<3}, .aa_shift={aa_shift:<3}, .aa_pl_scale={aa_pl_scale:<3}, .pa_nzero={pa_nzero:<3}, .pa_shift={pa_shift:<3}, .pa_pl_scale={pa_pl_scale:<3}, .softmax_frac={b.softmax_frac:<3}, ") ch.write( f".softmax_max_f={b.softmax_max_f:<15}, ") ch.write( f".csh={b.r.CSH:<3}, .ch={b.r.CYH:<3}, .csh_shift={b.r.CSH_SHIFT:<3}, .pkh={b.r.PKH:<3}, .psh={b.r.PSH:<3}, .ph={b.r.PYH:<3}, .psh_shift={b.r.PSH_SHIFT:<3}, .csw={b.r.CSW:<3}, .cw={b.r.CYW:<3}, .csw_shift={b.r.CSW_SHIFT:<3}, .pkw={b.r.PKW:<3}, .psw={b.r.PSW:<3}, .pw={b.r.PYW:<3}, .psw_shift={b.r.PSW_SHIFT:<3}, .pool={pool_type:<10}, .on={b.r.ON:<3}, .oh={b.r.OH:<3}, .ow={b.r.OW:<3}, .oc={b.r.OC:<4}, ") - ch.write( f".header={b.r.header[0]:>23}u, .w_header={b.r.w_header_le_p[-1][0]:>23}u, .w_header_p0={b.r.w_header_le_p[0][0]:>25}u , ") + ch.write( f".header={b.r.header:>23}u, ") ch.write( f".debug_nhwc_words={b.oe_exp_nhwc.size:<9} }}") b_words += b.be.size if b.core.b else 0 @@ -271,8 +271,6 @@ def export_inference(model, hw): b_bitstring = b'' x_bitstring_0 = b'' - header_padding = b'\x00\x00\x00\x00\x00\x00\x00\x00' if hw.AXI_WIDTH == 128 else b'' - for ib, b in enumerate(BUNDLES): assert ib == b.ib x_bitstring_b = b'' @@ -284,7 +282,7 @@ def export_inference(model, hw): for it in range(b.r.IT): we = pack_words_into_bytes(arr=b.we[ip][it].flatten(), bits=hw.K_BITS) - w_bitstring += b.r.w_header_be_p[ip!=0].tobytes() + header_padding + we.tobytes() + w_bitstring += we.tobytes() x_bitstring += x_bitstring_b with open(f"{hw.DATA_DIR}/{ib}_x_sim.bin", 'wb') as f: f.write(x_bitstring_b) @@ -317,18 +315,9 @@ def export_inference(model, hw): np.savetxt(f"{hw.DATA_DIR}/{b.ib}_{ip}_x.txt", xp, fmt='%d') for it in range(b.r.IT): - - w_config = b.r.w_header_le_p[ip!=0][0] - w_config = format(w_config, f'#0{hw.AXI_WIDTH}b') - w_config_words = [int(w_config[i:i+hw.K_BITS], 2) for i in range(0, len(w_config), hw.K_BITS)] - w_config_words.reverse() - w_config_words = np.array(w_config_words, dtype=np.uint8) - wp = b.we[ip][it].flatten() - wp = np.concatenate([w_config_words, wp], axis=0) - assert wp.shape == (hw.AXI_WIDTH/hw.K_BITS + (CM_p*b.r.KH+hw.CONFIG_BEATS)*hw.COLS,) + assert wp.shape == ((CM_p*b.r.KH+hw.CONFIG_BEATS)*hw.COLS,), f"{wp.shape} != {(CM_p*b.r.KH+hw.CONFIG_BEATS)*hw.COLS}" np.savetxt(f"{hw.DATA_DIR}/{b.ib}_{ip}_{it}_w.txt", wp, fmt='%d') - np.savetxt(f"{hw.DATA_DIR}/{b.ib}_{ip}_{it}_y_exp.txt", b.ye_exp_p[ip][it].flatten(), fmt='%d') y_exp = BUNDLES[-1].o_int.flatten() diff --git a/deepsocflow/rtl/axis_weight_rotator.sv b/deepsocflow/rtl/axis_weight_rotator.sv index 8b9323a..f9b8abe 100644 --- a/deepsocflow/rtl/axis_weight_rotator.sv +++ b/deepsocflow/rtl/axis_weight_rotator.sv @@ -51,7 +51,7 @@ module axis_weight_rotator #( output logic [COLS-1:0] m_axis_tvalid, output logic [COLS-1:0] m_axis_tlast , output tuser_st [COLS-1:0] m_axis_tuser , - //output logic [1:0] m_rd_state, + output logic [COLS-1:0][WORD_WIDTH-1:0] m_axis_tdata ); @@ -59,13 +59,12 @@ module axis_weight_rotator #( // if (s_axis_tvalid && s_axis_tready && s_axis_tlast) // $display("weights: s_axis_tuser = %d", s_axis_tuser); - enum {W_IDLE_S, W_GET_REF_S, W_WRITE_S, W_FILL_1_S, W_FILL_2_S, W_SWITCH_S} state_write; + enum {W_IDLE_S, W_WRITE_S, W_FILL_1_S, W_SWITCH_S} state_write; typedef enum {R_IDLE_S, R_PASS_CONFIG_S, R_READ_S, R_SWITCH_S} rd_state; rd_state state_read [COLS-1:0]; // independent state for each column //enum {R_IDLE_S, R_PASS_CONFIG_S, R_READ_S, R_SWITCH_S} state_read; - enum {DW_PASS_S, DW_BLOCK_S} state_dw; - logic i_write, dw_m_ready, dw_m_valid, dw_m_last, dw_s_valid, dw_s_ready; + logic i_write, dw_m_ready, dw_m_valid, dw_m_last; logic [COLS-1:0] i_read; logic [M_WIDTH-1:0] dw_m_data_flat; logic [1:0][M_WIDTH-1:0] bram_m_data; @@ -79,6 +78,22 @@ module axis_weight_rotator #( logic [COLS-1:0][BITS_SB_CNTR-1:0] fill_skid_buffer_cntr; logic [COLS-1:0] en_count_config, l_config, l_kw, l_cin, l_cols, l_blocks, l_xn, f_kw, f_cin, f_cols, lc_config, lc_kw, lc_cin, lc_cols, lc_blocks, lc_xn; logic [COLS-1:0] last_config; + + typedef struct packed { + logic [BITS_ADDR -1:0] addr_p_max; + logic [BITS_ADDR -1:0] addr_p0_max; + logic [BITS_XN -1:0] xn_1; + logic [BITS_CI -1:0] cin_p_1; + logic [BITS_CI -1:0] cin_p0_1; + logic [BITS_IM_BLOCKS -1:0] blocks_1; + logic [BITS_XW -1:0] cols_1; + logic [BITS_KW2 -1:0] kw2; + logic is_first_p; + } config_input_st; + config_input_st sci; + assign sci = config_input_st'(s_axis_tuser); + + localparam BITS_CONFIG = BITS_ADDR + BITS_XN + BITS_IM_BLOCKS + BITS_XW + BITS_CI + BITS_KW2; typedef struct packed { logic [BITS_ADDR -1:0] addr_max; logic [BITS_XN -1:0] xn_1; @@ -87,10 +102,11 @@ module axis_weight_rotator #( logic [BITS_CI -1:0] cin_1; logic [BITS_KW2 -1:0] kw2; } config_st; - config_st s_config; - logic [1:0][BITS_ADDR + BITS_XN + BITS_IM_BLOCKS + BITS_XW + BITS_CI + BITS_KW2 -1:0] ref_config; - - assign s_config = config_st'(s_axis_tdata); + config_st s_config, dw_config; + assign s_config = {(sci.is_first_p ? sci.addr_p0_max : sci.addr_p_max), sci.xn_1, sci.blocks_1, sci.cols_1, (sci.is_first_p ? sci.cin_p0_1 : sci.cin_p_1), sci.kw2}; + + logic [1:0][BITS_ADDR + BITS_XN + BITS_IM_BLOCKS + BITS_XW + BITS_CI + BITS_KW2-1:0] ref_config; + wire s_handshake = s_axis_tready && s_axis_tvalid; wire s_last_handshake = s_handshake && s_axis_tlast; //assign m_rd_state = state_read; @@ -105,27 +121,28 @@ module axis_weight_rotator #( .M_KEEP_WIDTH (M_WIDTH/WORD_WIDTH), .ID_ENABLE (0), .DEST_ENABLE (0), - .USER_ENABLE (0) + .USER_ENABLE (1), + .USER_WIDTH (BITS_CONFIG) ) DW ( .clk (aclk ), .rstn (aresetn ), - .s_axis_tvalid (dw_s_valid ), - .s_axis_tready (dw_s_ready ), + .s_axis_tvalid (s_axis_tvalid), + .s_axis_tready (s_axis_tready), .s_axis_tdata (s_axis_tdata), .s_axis_tkeep (s_axis_tkeep), .s_axis_tlast (s_axis_tlast), + .s_axis_tuser (s_config ), .m_axis_tvalid (dw_m_valid ), .m_axis_tready (dw_m_ready ), .m_axis_tdata (dw_m_data_flat ), .m_axis_tlast (dw_m_last ), + .m_axis_tuser (dw_config ), // Extras .s_axis_tid ('0), .s_axis_tdest ('0), - .s_axis_tuser ('0), .m_axis_tid (), .m_axis_tdest (), - .m_axis_tkeep (), - .m_axis_tuser () + .m_axis_tkeep () ); wire dw_m_handshake = dw_m_valid && dw_m_ready; @@ -137,12 +154,13 @@ module axis_weight_rotator #( always_ff @(posedge aclk `OR_NEGEDGE(aresetn)) if (!aresetn) state_write <= W_IDLE_S; else unique case (state_write) - W_IDLE_S : if (&done_read [i_write] ) state_write <= W_GET_REF_S; - W_GET_REF_S : if (s_handshake && state_dw == DW_BLOCK_S) state_write <= W_WRITE_S; + W_IDLE_S : if (&done_read [i_write] ) state_write <= W_WRITE_S; W_WRITE_S : if (dw_m_last_handshake ) state_write <= W_FILL_1_S; // dw_m_last_handshake and bram_w_full[w_i] should be same W_FILL_1_S : state_write <= W_SWITCH_S; W_SWITCH_S : state_write <= W_IDLE_S; endcase + + assign dw_m_ready = (state_write == W_WRITE_S); // STATE MACHINE: READ @@ -206,28 +224,7 @@ module axis_weight_rotator #( if (state_read[col] == R_SWITCH_S) i_read[col] <= !i_read[col]; end end - - - // State machine DW - always_ff @(posedge aclk `OR_NEGEDGE(aresetn)) - if (!aresetn) state_dw <= DW_BLOCK_S; - else unique case (state_dw) - DW_BLOCK_S: if (s_handshake) state_dw <= DW_PASS_S; - DW_PASS_S : if (s_last_handshake) state_dw <= DW_BLOCK_S; - endcase - always_comb begin - dw_m_ready = (state_write == W_WRITE_S); - - if (state_dw == DW_BLOCK_S) begin - dw_s_valid = 0; - s_axis_tready = (state_write == W_GET_REF_S); - end - else begin - dw_s_valid = s_axis_tvalid; - s_axis_tready = dw_s_ready; - end - end generate for (genvar i=0; i<2; i++) begin // FSM Output Decoders for indexed signals @@ -240,16 +237,17 @@ module axis_weight_rotator #( done_read_next [i] = done_read[i]; bram_m_ready [i] = '0; - if (i==i_write) + if (i==i_write) begin + en_ref [i] = dw_m_last_handshake; + done_write_next [i] = 0; case (state_write) - W_GET_REF_S : begin - done_write_next [i] = 0; + W_WRITE_S : bram_wen [i] = dw_m_valid; + W_SWITCH_S : begin bram_resetn [i] = 0; - en_ref [i] = s_handshake && (state_dw == DW_BLOCK_S); + done_write_next [i] = 1; end - W_WRITE_S : bram_wen [i] = dw_m_valid; - W_SWITCH_S : done_write_next [i] = 1; endcase + end for (int j=0; j