Skip to content

Commit

Permalink
[hwpe] WIP Support for HWPE version
Browse files Browse the repository at this point in the history
Changes:
- Add register for shape parameters

Current Limitations:
- Only works without biases
  • Loading branch information
Xeratec committed Sep 24, 2024
1 parent 9a0d9b4 commit a21aab1
Show file tree
Hide file tree
Showing 3 changed files with 32 additions and 20 deletions.
3 changes: 3 additions & 0 deletions src/hwpe/ita_hwpe_ctrl.sv
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,9 @@ module ita_hwpe_ctrl
ctrl_engine_o.tile_s = reg_file.hwpe_params[ITA_REG_TILES][3:0];
ctrl_engine_o.tile_e = reg_file.hwpe_params[ITA_REG_TILES][7:4];
ctrl_engine_o.tile_p = reg_file.hwpe_params[ITA_REG_TILES][11:8];
ctrl_engine_o.seq_length = reg_file.hwpe_params[ITA_REG_LENGTH][7:0];
ctrl_engine_o.proj_space = reg_file.hwpe_params[ITA_REG_LENGTH][15:8];
ctrl_engine_o.embed_size = reg_file.hwpe_params[ITA_REG_LENGTH][23:16];
ctrl_engine_o.eps_mult[0] = reg_file.hwpe_params[ITA_REG_EPS_MULT0][7:0];
ctrl_engine_o.eps_mult[1] = reg_file.hwpe_params[ITA_REG_EPS_MULT0][15:8];
ctrl_engine_o.eps_mult[2] = reg_file.hwpe_params[ITA_REG_EPS_MULT0][23:16];
Expand Down
19 changes: 10 additions & 9 deletions src/hwpe/ita_hwpe_package.sv
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ package ita_hwpe_package;
parameter int unsigned N_CORES = 9;
parameter int unsigned N_CONTEXT = 2;
parameter int unsigned ID_WIDTH = 2;
parameter int unsigned ITA_IO_REGS = 14; // 5 address + 8 parameters + 1 sync
parameter int unsigned ITA_IO_REGS = 15; // 5 address + 9 parameters + 1 sync

parameter int unsigned ITA_TCDM_DW = 1024;
parameter int unsigned ITA_INPUT_DW = M*WI;
Expand All @@ -28,13 +28,14 @@ package ita_hwpe_package;
parameter int unsigned ITA_REG_OUTPUT_PTR = 4;
parameter int unsigned ITA_REG_SEQ_LENGTH = 5;
parameter int unsigned ITA_REG_TILES = 6; // tile_s [3:0], tile_e [7:4], tile_p [11:8]
parameter int unsigned ITA_REG_EPS_MULT0 = 7; // eps_mult[0] [7:0], eps_mult[1] [15:8], eps_mult[2] [23:16], eps_mult[3] [31:24]
parameter int unsigned ITA_REG_EPS_MULT1 = 8; // eps_mult[4] [7:0], eps_mult[5] [15:8]
parameter int unsigned ITA_REG_RIGHT_SHIFT0 = 9; // right_shift[0] [7:0], right_shift[1] [15:8], right_shift[2] [23:16], right_shift[3] [31:24]
parameter int unsigned ITA_REG_RIGHT_SHIFT1 = 10; // right_shift[4] [7:0], right_shift[5] [15:8]
parameter int unsigned ITA_REG_ADD0 = 11; // add[0] [7:0], add[1] [15:8], add[2] [23:16], add[3] [31:24]
parameter int unsigned ITA_REG_ADD1 = 12; // add[4] [7:0], add[5] [15:8]
parameter int unsigned ITA_REG_CTRL_STREAM = 13; // ctrl_stream [0]: weight preload, ctrl_stream [1]: weight nextload, ctrl_stream [2]: bias disable, ctrl_stream [3]: bias direction, ctrl_stream [4]: output disable
parameter int unsigned ITA_REG_LENGTH = 7; // tile_s [3:0], tile_e [7:4], tile_p [11:8]
parameter int unsigned ITA_REG_EPS_MULT0 = 8; // eps_mult[0] [7:0], eps_mult[1] [15:8], eps_mult[2] [23:16], eps_mult[3] [31:24]
parameter int unsigned ITA_REG_EPS_MULT1 = 9; // eps_mult[4] [7:0], eps_mult[5] [15:8]
parameter int unsigned ITA_REG_RIGHT_SHIFT0 = 10; // right_shift[0] [7:0], right_shift[1] [15:8], right_shift[2] [23:16], right_shift[3] [31:24]
parameter int unsigned ITA_REG_RIGHT_SHIFT1 = 11; // right_shift[4] [7:0], right_shift[5] [15:8]
parameter int unsigned ITA_REG_ADD0 = 12; // add[0] [7:0], add[1] [15:8], add[2] [23:16], add[3] [31:24]
parameter int unsigned ITA_REG_ADD1 = 13; // add[4] [7:0], add[5] [15:8]
parameter int unsigned ITA_REG_CTRL_STREAM = 14; // ctrl_stream [0]: weight preload, ctrl_stream [1]: weight nextload, ctrl_stream [2]: bias disable, ctrl_stream [3]: bias direction, ctrl_stream [4]: output disable

typedef struct packed {
hci_package::hci_streamer_ctrl_t input_source_ctrl;
Expand Down Expand Up @@ -84,4 +85,4 @@ package ita_hwpe_package;
Done
} state_t;

endpackage : ita_hwpe_package
endpackage : ita_hwpe_package
30 changes: 19 additions & 11 deletions src/hwpe/tb/ita_hwpe_tb.sv
Original file line number Diff line number Diff line change
Expand Up @@ -85,10 +85,10 @@ module ita_hwpe_tb;
logic [MP-1:0] tcdm_r_valid;
logic [MP-1:0] tcdm_r_ready;

hwpe_ctrl_intf_periph #(
.ID_WIDTH (IdWidth)
) periph (
.clk (clk)
hwpe_ctrl_intf_periph #(
.ID_WIDTH (IdWidth)
) periph (
.clk (clk)
);

localparam hci_size_parameter_t `HCI_SIZE_PARAM(tcdm_mem) = '{
Expand Down Expand Up @@ -285,8 +285,13 @@ endfunction
logic [31:0] status;
string STIM_DATA;
logic [31:0] ita_reg_tiles_val;
logic [31:0] ita_reg_length_val;
logic [5:0][31:0] ita_reg_rqs_val;

ita_reg_length_val[7:0] = SEQUENCE_LEN;
ita_reg_length_val[15:8] = PROJECTION_SPACE;
ita_reg_length_val[23:16] = EMBEDDING_SIZE;

$timeformat(-9, 2, " ns", 10);

// Wait for reset to be released
Expand All @@ -308,13 +313,13 @@ endfunction
PERIPH_READ( 32'h04, 32'h0, status, clk);

// 1: Step Q
ita_compute_step(Q, ita_reg_tiles_val, ita_reg_rqs_val, clk);
ita_compute_step(Q, ita_reg_tiles_val, ita_reg_length_val, ita_reg_rqs_val, clk);

// 2: Step K
ita_compute_step(K, ita_reg_tiles_val, ita_reg_rqs_val, clk);
ita_compute_step(K, ita_reg_tiles_val, ita_reg_length_val, ita_reg_rqs_val, clk);

// 3: Step V
ita_compute_step(V, ita_reg_tiles_val, ita_reg_rqs_val, clk);
ita_compute_step(V, ita_reg_tiles_val, ita_reg_length_val, ita_reg_rqs_val, clk);


for (int group = 0; group < N_TILES_SEQUENCE_DIM; group++) begin
Expand All @@ -325,19 +330,19 @@ endfunction
BASE_PTR_OUTPUT[AV] = BASE_PTR[14] + group * N_TILES_OUTER_X[AV] * N_ELEMENTS_PER_TILE;

// 4: Step QK
ita_compute_step(QK, ita_reg_tiles_val, ita_reg_rqs_val, clk);
ita_compute_step(QK, ita_reg_tiles_val, ita_reg_length_val, ita_reg_rqs_val, clk);

// WIESEP: Hack to ensure that during the last tile of AV, the weight pointer is set correctly
if (group == N_TILES_SEQUENCE_DIM-1) begin
BASE_PTR_WEIGHT0[QK] = BASE_PTR_WEIGHT0[OW];
end

// 5: Step AV
ita_compute_step(AV, ita_reg_tiles_val, ita_reg_rqs_val, clk);
ita_compute_step(AV, ita_reg_tiles_val, ita_reg_length_val, ita_reg_rqs_val, clk);
end

// 6: Step OW
ita_compute_step(OW, ita_reg_tiles_val, ita_reg_rqs_val, clk);
ita_compute_step(OW, ita_reg_tiles_val, ita_reg_length_val, ita_reg_rqs_val, clk);

// Wait for the last step to finish
wait(evt);
Expand All @@ -361,6 +366,7 @@ endfunction
task automatic ita_compute_step(
input step_e step,
input logic [31:0] ita_reg_tiles_val,
input logic [31:0] ita_reg_length_val,
input logic [5:0][31:0] ita_reg_rqs_val,
ref logic clk_i
);
Expand Down Expand Up @@ -405,7 +411,7 @@ endfunction
$display(" - ITA Reg En 0x%0h, Ctrl Stream Val 0x%0h, Weight Ptr En %0d, Bias Ptr En %0d", ita_reg_en, ctrl_stream_val, weight_ptr_en, bias_ptr_en);

// Program ITA
PROGRAM_ITA(input_ptr, weight_ptr0, weight_ptr1, weight_ptr_en, bias_ptr, bias_ptr_en, output_ptr, ita_reg_tiles_val, ita_reg_rqs_val, ita_reg_en, ctrl_stream_val, clk_i);
PROGRAM_ITA(input_ptr, weight_ptr0, weight_ptr1, weight_ptr_en, bias_ptr, bias_ptr_en, output_ptr, ita_reg_tiles_val, ita_reg_length_val, ita_reg_rqs_val, ita_reg_en, ctrl_stream_val, clk_i);

// Wait for ITA to finish
@(posedge clk_i);
Expand Down Expand Up @@ -650,6 +656,7 @@ endfunction
input logic bias_ptr_en,
input logic [31:0] output_ptr,
input logic [31:0] ita_reg_tiles_val,
input logic [31:0] ita_reg_length_val,
input logic [5:0][31:0] ita_reg_rqs_val,
input logic ita_reg_en,
input logic [31:0] ctrl_stream_val,
Expand All @@ -664,6 +671,7 @@ endfunction
PERIPH_WRITE( 4*ITA_REG_OUTPUT_PTR, ITA_REG_OFFSET, output_ptr, clk_i);

if (ita_reg_en) begin
PERIPH_WRITE( 4*ITA_REG_LENGTH, ITA_REG_OFFSET, ita_reg_length_val, clk_i);
PERIPH_WRITE( 4*ITA_REG_TILES, ITA_REG_OFFSET, ita_reg_tiles_val, clk_i);
PERIPH_WRITE( 4*ITA_REG_EPS_MULT0, ITA_REG_OFFSET, ita_reg_rqs_val[0], clk_i);
PERIPH_WRITE( 4*ITA_REG_EPS_MULT1, ITA_REG_OFFSET, ita_reg_rqs_val[1], clk_i);
Expand Down

0 comments on commit a21aab1

Please sign in to comment.