// NOTE(review): in this copy of the file the targets of the #include
// directives and most angle-bracketed template arguments are absent (for
// example `tapa::istream & q` and `tapa::vec_t, 128>`). The code below is
// reproduced exactly as found; confirm against the upstream file before
// building.
#include
#include
#include
#define AP_INT_MAX_W 2048
#include
#include

// Compile-time sizes used below as array extents, loop trip counts and
// instruction bounds.
constexpr int D = 1024;
constexpr int D_div_2 = D / 2;
constexpr int D_div_4 = D / 4;
constexpr int D_ffn = 4096;
constexpr int N_head = 16;
constexpr int MAX_SEQ_LEN = 1024;
constexpr int MAX_SEQ_LEN_div_2 = MAX_SEQ_LEN / 2;
constexpr int MAX_SEQ_LEN_div_8 = MAX_SEQ_LEN / 8;
constexpr int NUM_SLR = 4;
constexpr int NUM_DUM_SLR = 4;
constexpr int TOTAL_PORT = NUM_SLR * 2;
constexpr int D_head = D / N_head;
constexpr int D_head_div_32 = D_head / 32;
constexpr int D_head_div_16 = D_head / 16;
constexpr int D_head_div_8 = D_head / 8;
constexpr int D_head_div_4 = D_head / 4;
constexpr int D_head_div_2 = D_head / 2;
constexpr int D_div_8 = D / 8;
constexpr int D_div_16 = D / 16;
constexpr int FFN_WEIGHT_SIZE = D * D_ffn;
constexpr int OUT_WEIGHT_SIZE = D * D_head * NUM_DUM_SLR * 4;
constexpr int WEIGHT_D = D * 2;
constexpr int QKV_WEIGHT_SIZE = D * D_head * NUM_DUM_SLR * 6; // multi-head attention
// Total number of weight elements streamed by read_W (which divides it by 128).
constexpr int TOTAL_WEIGHT_SIZE = OUT_WEIGHT_SIZE + QKV_WEIGHT_SIZE + FFN_WEIGHT_SIZE;
constexpr int CONTEXT_D = D_head_div_8 * 4;
constexpr int D_head_mul_4 = D_head * 4;
constexpr int D_write_zero_acc0 = D / 32;
constexpr int D_write_zero_acc1 = D / 32 + D / 16;

// Vector aliases. NOTE(review): the element types / lengths are not visible in
// this copy; the names suggest 16 ints, 128 4-bit and 64 8-bit lanes - confirm.
using int_v16 = tapa::vec_t;
using int4_v128 = tapa::vec_t, 128>;
using int8_v64 = tapa::vec_t, 64>;

// Drains stream `q` forever: each iteration attempts a try_read into a
// temporary and discards both the value and the success flag. The loop has no
// exit, so this function never returns.
template inline void bh(tapa::istream & q) {
#pragma HLS inline
    for (;;) {
#pragma HLS pipeline II=1 style=stp
        data_t tmp;
        q.try_read(tmp);
    }
}

// Instruction record produced by read_inst and consumed/forwarded by the
// temporal_* kernels.
struct ConfigInst {
    ap_uint<3> stage; // stage 7 -> read L
    // Number of weight-load iterations for a stage; in the initial stage-7
    // record this field carries the sequence length L instead.
    ap_uint<11> weight_bound;
    ap_uint<7> i_bound;  // trip count of the outer `i` loop
    ap_uint<7> j_bound;  // upper bound of the `j` loop
    ap_uint<8> k_bound;  // trip count of the `compute` (reduction) loop
};

// Sink kernels: each one drains its input stream through bh() and never
// returns. One wrapper exists per stream element type.
void black_hole_int(tapa::istream & fifo_in) { bh(fifo_in); }
void black_hole_inst(tapa::istream & fifo_in) { bh(fifo_in); }
void black_hole_int_v16(tapa::istream & fifo_in) { bh(fifo_in); }
void black_hole_x(tapa::istream & fifo_in) { bh(fifo_in); }
void black_hole_w(tapa::istream & fifo_in) { bh(fifo_in); }
void black_hole_ap_uint_8(tapa::istream> & fifo_in) { bh(fifo_in); }
void black_hole_ap_uint_512(tapa::istream> & fifo_in) { bh(fifo_in); }
void black_hole_ap_uint_576(tapa::istream> & fifo_in) { bh(fifo_in); }
void black_hole_ap_uint_1024(tapa::istream> & fifo_in) { bh(fifo_in); }

// Streams the weights from memory into `fifo_out`.
//   vec      - memory port; read addresses 0 .. (TOTAL_WEIGHT_SIZE >> 7) - 1
//              are issued on vec.read_addr.
//   fifo_out - receives every 512-bit word returned on vec.read_data.
// Requests (i_req) and responses (i_resp) are counted separately, so an
// address can be issued in the same iteration that an earlier response is
// consumed. The loop ends once all responses have been forwarded.
void read_W(
    tapa::async_mmap>& vec,
    tapa::ostream>& fifo_out
){
    for(int i_req = 0, i_resp = 0; i_resp < (TOTAL_WEIGHT_SIZE >> 7);){
#pragma HLS pipeline II=1 style=stp
        // Issue the next address only when one is left and the channel has room.
        if((i_req < (TOTAL_WEIGHT_SIZE >> 7)) & !vec.read_addr.full()){
            vec.read_addr.write(i_req);
            i_req++;
        }
        ap_uint<512> tmp_o;
        bool success = vec.read_data.try_read(tmp_o);
        if(success){
            fifo_out.write(tmp_o);
            i_resp++;
        }
    }
}

// Same request/response scheme as read_W, but the word count is (N >> 6),
// where N is supplied by the caller.
void read_X(
    const int N,
    tapa::async_mmap>& vec,
    tapa::ostream>& fifo_out
){
    for(int i_req = 0, i_resp = 0; i_resp < (N >> 6);){
#pragma HLS pipeline II=1 style=stp
        if((i_req < (N >> 6)) & !vec.read_addr.full()){
            vec.read_addr.write(i_req);
            i_req++;
        }
        ap_uint<512> tmp_o;
        bool success = vec.read_data.try_read(tmp_o);
        if(success){
            fifo_out.write(tmp_o);
            i_resp++;
        }
    }
}

// Generates the instruction stream for both accelerators.
//   L             - sequence length; sent first, then used to derive bounds.
//   fifo_out_acc0 - instruction stream for the acc0 chain.
//   fifo_out_acc1 - instruction stream for the acc1 chain.
// Sequence: one stage-7 record carrying L in weight_bound, followed by 14
// per-stage records. stage_i 0..11 cycle through stages 0,1,2 four times;
// stage_i 12 and 13 map to stages 3 and 4. Only stage 2 uses different bounds
// for acc0 and acc1.
void read_inst(
    const int L,
    tapa::ostream& fifo_out_acc0,
    tapa::ostream& fifo_out_acc1
){
    // Only `stage` and `weight_bound` are assigned in this first record; the
    // i/j/k bounds are left unassigned.
    ConfigInst len;
    len.stage = 7;
    len.weight_bound = L;
    fifo_out_acc0.write(len);
    fifo_out_acc1.write(len);
    for(int stage_i = 0; stage_i < 14; stage_i++){
#pragma HLS pipeline II=1 style=stp
        ConfigInst inst_acc0;
        ConfigInst inst_acc1;
        const int stage = (stage_i < 12) ? (stage_i % 3) : (stage_i - 9);
        inst_acc0.stage = ap_uint<3>(stage);
        inst_acc1.stage = ap_uint<3>(stage);
        if(stage == 0){
            inst_acc0.weight_bound = D_head_div_4;
            inst_acc0.i_bound = (L >> 4);
            inst_acc0.j_bound = D_head_div_16;
            inst_acc0.k_bound = D_div_8;
            inst_acc1 = inst_acc0;
        } else if (stage == 1){
            inst_acc0.weight_bound = D_head_div_8;
            inst_acc0.i_bound = (L >> 4);
            inst_acc0.j_bound = D_head_div_32;
            inst_acc0.k_bound = D_div_8;
            inst_acc1 = inst_acc0;
        } else if (stage == 2){
            // No weights are loaded in stage 2; acc0 and acc1 get different
            // j/k bounds here.
            inst_acc0.weight_bound = 0;
            inst_acc0.i_bound = (L >> 4);
            inst_acc0.j_bound = (L >> 4);
            inst_acc0.k_bound = D_head_div_8;
            inst_acc1.weight_bound = 0;
            inst_acc1.i_bound = (L >> 4);
            inst_acc1.j_bound = D_head_div_16;
            inst_acc1.k_bound = (L >> 3);
        } else if (stage == 3){
            inst_acc0.weight_bound = D_head;
            inst_acc0.i_bound = (L >> 5);
            inst_acc0.j_bound = D_div_16;
            inst_acc0.k_bound = D_head_div_2;
            inst_acc1 = inst_acc0;
        } else {
            // Remaining case produced by the mapping above: stage 4.
            inst_acc0.weight_bound = D_div_4;
            inst_acc0.i_bound = (L >> 4);
            inst_acc0.j_bound = D_div_16;
            inst_acc0.k_bound = D_div_8;
            inst_acc1 = inst_acc0;
        }
        fifo_out_acc0.write(inst_acc0);
        fifo_out_acc1.write(inst_acc1);
    }
}

// Reads a single value from `fifo_inst_in` and writes it to both
// `fifo_sfu_out` and `fifo_sfu_gelu`, then returns.
void packet_switch_acc(
    tapa::istream& fifo_inst_in,
    tapa::ostream& fifo_sfu_out,
    tapa::ostream& fifo_sfu_gelu
) {
    const int L = fifo_inst_in.read();
    fifo_sfu_out.write(L);
    fifo_sfu_gelu.write(L);
}

// Writes N words taken from `fifo_in` to `output_mtx` at addresses 0..N-1 and
// then writes `true` to `fifo_fin`.
//   N          - number of words to write.
//   output_mtx - memory port (write_addr / write_data / write_resp channels).
//   fifo_in    - source of the 128-bit data words.
//   fifo_fin   - completion flag, written once after all responses arrived.
void write_mtx(
    const int N,
    tapa::async_mmap>& output_mtx,
    tapa::istream>& fifo_in,
    tapa::ostream& fifo_fin
){
    for(int i_req = 0, i_resp = 0; i_resp < N;){
#pragma HLS pipeline II=1 style=stp
        // All three channels are checked up front, so the try_* calls below
        // are expected to succeed; their return values are not inspected.
        if((i_req < N) & !fifo_in.empty() & !output_mtx.write_addr.full() & !output_mtx.write_data.full()){
            output_mtx.write_addr.try_write(i_req);
            ap_uint<128> tmp;
            fifo_in.try_read(tmp);
            output_mtx.write_data.try_write(tmp);
            ++i_req;
        }
        bool success = false;
        auto resp = output_mtx.write_resp.read(success);
        if(success){
            // NOTE(review): presumably each response encodes (acknowledged
            // writes - 1); confirm against the TAPA async_mmap documentation.
            i_resp += unsigned(resp)+1;
        }
    }
    fifo_fin.write(true);
}

// Pushes L * D all-zero 512-bit words into `fifo_zero`, retrying while the
// FIFO is full. Note that the parameter `D` shadows the file-scope constant
// `D` inside this function.
void write_zero(
    const int L,
    const int D,
    tapa::ostream>& fifo_zero
){
    for(int i = 0; i < L * D;){
        if(!fifo_zero.full()){
            ap_uint<512> tmp = 0;
            fifo_zero.try_write(tmp);
            i++;
        }
    }
}

// Processing element for the acc0 array. Runs forever; for each tile it
//   1. reads k_bound and then stage from `fifo_bound_in` and forwards both on
//      `fifo_bound_out`;
//   2. for k_bound iterations reads one 576-bit word from `fifo_in_x` and one
//      from `fifo_in_y`, forwards them unchanged on `fifo_out_x` /
//      `fifo_out_y`, and accumulates an 8x8 grid of dot products;
//   3. packs the 64 results as 24-bit fields into one word on `fifo_drain`.
// `idx` and `idy` are not referenced in the body.
void PE_GEMM_acc0( // 2x2 systolic array
    const int idx,
    const int idy,
    tapa::istream>& fifo_bound_in,
    tapa::ostream>& fifo_bound_out,
    tapa::istream>& fifo_in_x, // 8*72
    tapa::ostream>& fifo_out_x,
    tapa::istream>& fifo_in_y,
    tapa::ostream>& fifo_out_y,
    tapa::ostream>& fifo_drain
){
    for(;;){
        const int k_bound = fifo_bound_in.read();
        fifo_bound_out.write(ap_uint<8>(k_bound));
        const int stage = fifo_bound_in.read();
        fifo_bound_out.write(ap_uint<8>(stage));

        // [bank][partial sum][y lane][x lane]; the bank is selected by k % 2
        // in the compute loop, matching the distance=2 dependence pragma.
        ap_int<22> acc_vec[2][3][8][8];
#pragma HLS array_partition variable=acc_vec dim=1 complete
#pragma HLS array_partition variable=acc_vec dim=2 complete
#pragma HLS array_partition variable=acc_vec dim=3 complete
#pragma HLS array_partition variable=acc_vec dim=4 complete

        for(int l = 0; l < 2; l++){
#pragma HLS unroll
            for(int ii = 0; ii < 3; ii++){
#pragma HLS unroll
                for(int kk = 0; kk < 8; kk++){
#pragma HLS unroll
                    for(int k = 0; k < 8; k++){
#pragma HLS unroll
                        acc_vec[l][ii][kk][k] = 0;
                    }
                }
            }
        }

compute:
        for(int k = 0; k < k_bound; k++){
#pragma HLS pipeline II=1 style=stp
#pragma HLS dependence variable=acc_vec type=inter distance=2 true
            ap_uint<576> tmp_x = fifo_in_x.read();
            ap_uint<576> tmp_y = fifo_in_y.read();

            // Pass the operands on to the next consumer before using them.
            fifo_out_x.write(tmp_x);
            fifo_out_y.write(tmp_y);

            for(int kk = 0; kk < 8; kk++){
#pragma HLS unroll
                for(int l = 0; l < 8; l++){
#pragma HLS unroll
                    // One 72-bit lane from each operand word.
                    ap_uint<72> op1 = tmp_x(l*72+71, l*72);
                    ap_uint<72> op2 = tmp_y(kk*72+71, kk*72);
                    ap_int<22> res0 = acc_vec[k%2][0][kk][l];
                    ap_int<22> res1 = acc_vec[k%2][1][kk][l];
                    ap_int<22> res2 = acc_vec[k%2][2][kk][l];
                    ap_int<8> a0; ap_int<8> a1; ap_int<8> a2;
                    ap_int<8> b0; ap_int<8> b1; ap_int<8> b2;
                    ap_int<8> a0_1; ap_int<8> a1_1; ap_int<8> a2_1;
                    ap_int<8> b0_1; ap_int<8> b1_1; ap_int<8> b2_1;
                    ap_int<8> a0_2; ap_int<8> a1_2; ap_int<8> a2_2;
                    ap_int<8> b0_2; ap_int<8> b1_2; ap_int<8> b2_2;
                    // op2 is always split into nine 8-bit fields.
                    (a2_2, a1_2, a0_2, a2_1, a1_1, a0_1, a2, a1, a0) = op2;
                    if(stage==2){
                        // Stage 2: op1 also carries nine 8-bit fields.
                        (b2_2, b1_2, b0_2, b2_1, b1_1, b0_1, b2, b1, b0) = op1;
                    }else{
                        // Other stages: op1 carries nine 4-bit fields in its
                        // low 36 bits, each read through ap_int<4>.
                        b0 = ap_int<4>(op1(3, 0));
                        b1 = ap_int<4>(op1(7, 4));
                        b2 = ap_int<4>(op1(11, 8));
                        b0_1 = ap_int<4>(op1(15, 12));
                        b1_1 = ap_int<4>(op1(19, 16));
                        b2_1 = ap_int<4>(op1(23, 20));
                        b0_2 = ap_int<4>(op1(27, 24));
                        b1_2 = ap_int<4>(op1(31, 28));
                        b2_2 = ap_int<4>(op1(35, 32));
                    }
                    // Three independent 3-term multiply-accumulates.
                    res0 = (((a0 * b0) + (a1 * b1)) + (a2 * b2)) + res0;
                    res1 = (((a0_1 * b0_1) + (a1_1 * b1_1)) + (a2_1 * b2_1)) + res1;
                    res2 = (((a0_2 * b0_2) + (a1_2 * b1_2)) + (a2_2 * b2_2)) + res2;
                    acc_vec[k%2][0][kk][l] = res0;
                    acc_vec[k%2][1][kk][l] = res1;
                    acc_vec[k%2][2][kk][l] = res2;
                }
            }
        }

        ap_int<24> acc_final[8][8];
#pragma HLS array_partition variable=acc_final dim=1 complete
#pragma HLS array_partition variable=acc_final dim=2 complete

        for(int ii = 0; ii < 8; ii++){
#pragma HLS unroll
            for(int k = 0; k < 8; k++){
#pragma HLS unroll
                // Start from -4 * k_bound. NOTE(review): this appears to cancel
                // the constant 2 * 2 product contributed on every step by the
                // padding field that the temporal_* feeders place in the top
                // field of each operand - confirm.
                acc_final[ii][k] = -(k_bound << 2);
            }
        }

reduction:
        // Sum both banks and all three partial sums for every output.
        for(int k = 0; k < 2; k++){
            for(int ii = 0; ii < 3; ii++){
                for(int kk = 0; kk < 8; kk++){
#pragma HLS unroll
                    for(int l = 0; l < 8; l++){
#pragma HLS unroll
                        acc_final[kk][l] += acc_vec[k][ii][kk][l];
                    }
                }
            }
        }

        // Pack 8 x 8 results, 24 bits each (192 bits per row).
        ap_uint<1536> final_result;
        for(int i = 0; i < 8; i++){
#pragma HLS unroll
            for(int j = 0; j < 8; j++){
#pragma HLS unroll
                final_result(i*192+j*24+23, i*192+j*24) = acc_final[i][j];
            }
        }
        fifo_drain.write(final_result);
    }
}

// Processing element for the acc1 array. Same structure as PE_GEMM_acc0 with
// wider arithmetic: 24-bit running sums, 27-bit final results, and a
// 1728-bit drain word (8 x 8 fields of 27 bits). The bounds are forwarded
// without an explicit ap_uint<8> cast. `idx` and `idy` are not referenced.
void PE_GEMM_acc1( // 2x2 systolic array
    const int idx,
    const int idy,
    tapa::istream>& fifo_bound_in,
    tapa::ostream>& fifo_bound_out,
    tapa::istream>& fifo_in_x, // 8*72
    tapa::ostream>& fifo_out_x,
    tapa::istream>& fifo_in_y,
    tapa::ostream>& fifo_out_y,
    tapa::ostream>& fifo_drain
){
    for(;;){
        const int k_bound = fifo_bound_in.read();
        fifo_bound_out.write(k_bound);
        const int stage = fifo_bound_in.read();
        fifo_bound_out.write(stage);

        // [bank][partial sum][y lane][x lane]; bank = k % 2.
        ap_int<24> acc_vec[2][3][8][8];
#pragma HLS array_partition variable=acc_vec dim=1 complete
#pragma HLS array_partition variable=acc_vec dim=2 complete
#pragma HLS array_partition variable=acc_vec dim=3 complete
#pragma HLS array_partition variable=acc_vec dim=4 complete

        for(int l = 0; l < 2; l++){
#pragma HLS unroll
            for(int ii = 0; ii < 3; ii++){
#pragma HLS unroll
                for(int kk = 0; kk < 8; kk++){
#pragma HLS unroll
                    for(int k = 0; k < 8; k++){
#pragma HLS unroll
                        acc_vec[l][ii][kk][k] = 0;
                    }
                }
            }
        }

compute:
        for(int k = 0; k < k_bound; k++){
#pragma HLS pipeline II=1 style=stp
#pragma HLS dependence variable=acc_vec type=inter distance=2 true
            ap_uint<576> tmp_x = fifo_in_x.read();
            ap_uint<576> tmp_y = fifo_in_y.read();

            // Pass the operands on to the next consumer before using them.
            fifo_out_x.write(tmp_x);
            fifo_out_y.write(tmp_y);

            for(int kk = 0; kk < 8; kk++){
#pragma HLS unroll
                for(int l = 0; l < 8; l++){
#pragma HLS unroll
                    ap_uint<72> op1 = tmp_x(l*72+71, l*72);
                    ap_uint<72> op2 = tmp_y(kk*72+71, kk*72);
                    ap_int<24> res0 = acc_vec[k%2][0][kk][l];
                    ap_int<24> res1 = acc_vec[k%2][1][kk][l];
                    ap_int<24> res2 = acc_vec[k%2][2][kk][l];
                    ap_int<8> a0; ap_int<8> a1; ap_int<8> a2;
                    ap_int<8> b0; ap_int<8> b1; ap_int<8> b2;
                    ap_int<8> a0_1; ap_int<8> a1_1; ap_int<8> a2_1;
                    ap_int<8> b0_1; ap_int<8> b1_1; ap_int<8> b2_1;
                    ap_int<8> a0_2; ap_int<8> a1_2; ap_int<8> a2_2;
                    ap_int<8> b0_2; ap_int<8> b1_2; ap_int<8> b2_2;
                    // op2 is always split into nine 8-bit fields.
                    (a2_2, a1_2, a0_2, a2_1, a1_1, a0_1, a2, a1, a0) = op2;
                    if(stage==2){
                        // Stage 2: op1 also carries nine 8-bit fields.
                        (b2_2, b1_2, b0_2, b2_1, b1_1, b0_1, b2, b1, b0) = op1;
                    }else{
                        // Other stages: nine 4-bit fields in the low 36 bits.
                        b0 = ap_int<4>(op1(3, 0));
                        b1 = ap_int<4>(op1(7, 4));
                        b2 = ap_int<4>(op1(11, 8));
                        b0_1 = ap_int<4>(op1(15, 12));
                        b1_1 = ap_int<4>(op1(19, 16));
                        b2_1 = ap_int<4>(op1(23, 20));
                        b0_2 = ap_int<4>(op1(27, 24));
                        b1_2 = ap_int<4>(op1(31, 28));
                        b2_2 = ap_int<4>(op1(35, 32));
                    }
                    res0 = (((a0 * b0) + (a1 * b1)) + (a2 * b2)) + res0;
                    res1 = (((a0_1 * b0_1) + (a1_1 * b1_1)) + (a2_1 * b2_1)) + res1;
                    res2 = (((a0_2 * b0_2) + (a1_2 * b1_2)) + (a2_2 * b2_2)) + res2;
                    acc_vec[k%2][0][kk][l] = res0;
                    acc_vec[k%2][1][kk][l] = res1;
                    acc_vec[k%2][2][kk][l] = res2;
                }
            }
        }

        ap_int<27> acc_final[8][8];
#pragma HLS array_partition variable=acc_final dim=1 complete
#pragma HLS array_partition variable=acc_final dim=2 complete

        for(int ii = 0; ii < 8; ii++){
#pragma HLS unroll
            for(int k = 0; k < 8; k++){
#pragma HLS unroll
                // Start from -4 * k_bound (same offset as in PE_GEMM_acc0).
                acc_final[ii][k] = -(k_bound << 2);
            }
        }

reduction:
        for(int k = 0; k < 2; k++){
            for(int ii = 0; ii < 3; ii++){
                for(int kk = 0; kk < 8; kk++){
#pragma HLS unroll
                    for(int l = 0; l < 8; l++){
#pragma HLS unroll
                        acc_final[kk][l] += acc_vec[k][ii][kk][l];
                    }
                }
            }
        }

        // Pack 8 x 8 results, 27 bits each (216 bits per row).
        ap_uint<1728> final_result;
        for(int i = 0; i < 8; i++){
#pragma HLS unroll
            for(int j = 0; j < 8; j++){
#pragma HLS unroll
                final_result(i*216+j*27+26, i*216+j*27) = acc_final[i][j];
            }
        }
        fifo_drain.write(final_result);
    }
}

// acc slr0 master node
//
// Control kernel that feeds a group of four PEs (two x streams, two y streams,
// four drain streams) and post-processes their results.
//
// Flow:
//   * Reads the initial record from `fifo_inst_in`, takes L from its
//     weight_bound field, forwards the record and sends L on `fifo_len_sfu`.
//   * Then handles 14 instructions. For each one:
//       - unless stage == 2, loads W from `fifo_W_in`: the low 128 bits of
//         every 512-bit word are kept as four 32-bit entries and the word
//         shifted right by 128 bits is forwarded on `fifo_W_out`;
//       - during the first instruction only (stage_i == 0), fills the local X
//         buffer from `fifo_X_in`, 16 rows per `i` iteration;
//       - for every (i, j) tile streams k_bound operand pairs to the PEs,
//         reads back a 16x16 result tile and handles it per stage:
//           stage 0      -> written to scratchpad_q (result >> 8, 8 bits)
//           stage 1      -> written to scratchpad_k, interleaved with data
//                           received on `fifo_from_acc1`
//           stage 2 / 4  -> sent as 16 words on `fifo_O_out` / `fifo_ffn_out`
//           otherwise    -> added to `fifo_reduce_recv` and to X, then sent
//                           on `fifo_res_send`
// In stage 2 the j loop additionally stops once j > i.
void temporal_acc0_slr0(
    tapa::istream& fifo_inst_in,
    tapa::ostream& fifo_inst_out,
    tapa::ostream& fifo_len_sfu,
    tapa::istream>& fifo_X_in,
    tapa::ostream>& fifo_X_out, // 8-bit activation
    tapa::istream>& fifo_W_in,
    tapa::ostream>& fifo_W_out, // 4-bit weight
    tapa::istream>& fifo_from_acc1,
    tapa::ostream>& fifo_O_out,
    tapa::ostream>& fifo_ffn_out,
    tapa::istream>& fifo_context,
    tapa::istream>& fifo_ffn_in,
    tapa::istream>& fifo_reduce_recv,
    tapa::ostream>& fifo_res_send,
    // systolic array
    tapa::ostream>& fifo_x0,
    tapa::ostream>& fifo_x1,
    tapa::ostream>& fifo_y0,
    tapa::ostream>& fifo_y1,
    tapa::ostream>& fifo_bound_x0,
    tapa::ostream>& fifo_bound_x1,
    tapa::istreams, 4>& fifo_drain
){
    ap_uint<64> scratchpad_q[MAX_SEQ_LEN][D_head_div_8]; // 8 bit
#pragma HLS array_partition variable=scratchpad_q cyclic dim=1 factor=16
#pragma HLS array_partition variable=scratchpad_q cyclic dim=2 factor=2
#pragma HLS bind_storage variable=scratchpad_q type=ram_2p impl=bram

    ap_uint<64> scratchpad_k[MAX_SEQ_LEN][D_head_div_8]; // 8 bit
#pragma HLS array_partition variable=scratchpad_k cyclic dim=1 factor=16
#pragma HLS array_partition variable=scratchpad_k cyclic dim=2 factor=2
#pragma HLS bind_storage variable=scratchpad_k type=ram_2p impl=bram

    ap_uint<64> X[MAX_SEQ_LEN][D_div_8]; // 8 bit
#pragma HLS array_partition variable=X cyclic dim=1 factor=16
#pragma HLS array_partition variable=X cyclic dim=2 factor=2
#pragma HLS bind_storage variable=X type=ram_2p impl=uram

    // First record carries the sequence length in weight_bound.
    ConfigInst len = fifo_inst_in.read();
    const int L = len.weight_bound;
    fifo_inst_out.write(len);
    fifo_len_sfu.write(L);

    for(int stage_i = 0; stage_i < 14; stage_i++){
        //TODO: stage send from inst
        // stage 0: WqX
        // stage 1: WkX0 <- acc1
        // stage 2: QK^T

        ap_uint<32> W[D][D_div_8]; // TODO: reduce dimension
#pragma HLS array_partition variable=W cyclic dim=1 factor=16
#pragma HLS bind_storage variable=W type=ram_2p impl=uram

        ConfigInst inst = fifo_inst_in.read();
        fifo_inst_out.write(inst);
        const ap_uint<3> stage = inst.stage;

        // load weights and forward
        if(stage != 2) {
            // TODO: 1d array & uniform access
            const int weight_bound = inst.weight_bound;
            for(int i = 0; i < weight_bound; i++){
load_weight:
                for(int j = 0; j < D_div_8;){
                    if(!fifo_W_in.empty()){
                        ap_uint<512> val;
                        fifo_W_in.try_read(val);
                        // Keep the low 128 bits as four rows of W ...
                        for(int k = 0; k < 4; k++){
#pragma HLS unroll
                            W[i*4+k][j] = ap_uint<32>(val(k*32+31, k*32));
                        }
                        // ... and pass the remaining bits downstream.
                        val = ap_uint<512>(val >> 128);
                        fifo_W_out.write(val);
                        j++;
                    }
                }
            }
        }

        // stage 1: compute Q
        const int i_bound = inst.i_bound;
        const int j_bound = inst.j_bound;
        const int k_bound = inst.k_bound;

        for(int i = 0; i < i_bound; i++){ // make sure L is multiple of 16

            if(stage_i == 0){
                for(int ii = 0; ii < 2; ii++){ // load only 1 time
load_x:
                    for(int jj = 0; jj < D_div_8;){
                        if(!fifo_X_in.empty()){
                            ap_uint<512> val;
                            fifo_X_in.try_read(val);
                            // One word supplies the same column of 8 rows.
                            for(int k = 0; k < 8; k++){
#pragma HLS unroll
                                X[i*16+ii*8+k][jj] = ap_uint<64>(val(k*64+63, k*64));
                            }
                            jj++;
                        }
                    }
                }
            }

            // In stage 2 only tiles with j <= i are processed.
            for(int j = 0; (j < j_bound) & ((stage != 2) | (j <= i)); j++){
#pragma HLS loop_flatten off

                // Each PE chain receives k_bound followed by stage.
                fifo_bound_x0.write(ap_uint<8>(k_bound));
                fifo_bound_x0.write(ap_uint<8>(stage));
                fifo_bound_x1.write(ap_uint<8>(k_bound));
                fifo_bound_x1.write(ap_uint<8>(stage));

compute:
                for(int k = 0; k < k_bound; k++){ // reduction dim
#pragma HLS pipeline II=1 style=stp

                    ap_uint<72> op1_mtx[16];
                    ap_uint<72> op2_mtx[16];
#pragma HLS array_partition variable=op1_mtx complete
#pragma HLS array_partition variable=op2_mtx complete

                    // Stages 3 and 4 take their second operand from a stream.
                    ap_uint<1024> recv_pkt;
                    if(stage == 3) {
                        recv_pkt = fifo_context.read();
                    }else if(stage == 4) {
                        recv_pkt = fifo_ffn_in.read();
                    }

                    // Every operand is prefixed with a constant field of value 2.
                    for(int ii = 0; ii < 16; ii++){
#pragma HLS unroll
                        if(stage > 2){
                            op1_mtx[ii] = ap_uint<72>(ap_uint<36>((ap_int<4>(2), W[j*16+ii][k]))); // dummy to enforce dotpra3 binding
                            op2_mtx[ii] = ap_uint<72>((ap_int<8>(2), recv_pkt(ii*64+63, ii*64)));
                        } else if(stage == 2) {
                            op1_mtx[ii] = ap_uint<72>((ap_int<8>(2), scratchpad_k[j*16+ii][k]));
                            op2_mtx[ii] = ap_uint<72>((ap_int<8>(2), scratchpad_q[i*16+ii][k]));
                        } else {
                            op1_mtx[ii] = ap_uint<72>(ap_uint<36>((ap_int<4>(2), W[j*16+ii][k])));
                            op2_mtx[ii] = ap_uint<72>((ap_int<8>(2), X[i*16+ii][k]));
                        }
                    }

                    // Forward the activation data: in stages 0/1 the low 64
                    // bits of each op2 entry, in stage 4 the received packet.
                    if(stage < 2){
                        ap_uint<1024> send_pkt = ap_uint<1024>((
                            op2_mtx[0](63,0), op2_mtx[1](63,0), op2_mtx[2](63,0), op2_mtx[3](63,0),
                            op2_mtx[4](63,0), op2_mtx[5](63,0), op2_mtx[6](63,0), op2_mtx[7](63,0),
                            op2_mtx[8](63,0), op2_mtx[9](63,0), op2_mtx[10](63,0), op2_mtx[11](63,0),
                            op2_mtx[12](63,0), op2_mtx[13](63,0), op2_mtx[14](63,0), op2_mtx[15](63,0)
                        ));
                        fifo_X_out.write(send_pkt);
                    } else if (stage == 4) {
                        fifo_X_out.write(recv_pkt);
                    }

                    // Entries 0..7 go to the *0 streams, 8..15 to the *1 streams.
                    ap_uint<576> pack_x0;
                    ap_uint<576> pack_x1;
                    ap_uint<576> pack_y0;
                    ap_uint<576> pack_y1;
                    for(int ii = 0; ii < 8; ii++){
#pragma HLS unroll
                        pack_x0(ii*72+71, ii*72) = op1_mtx[ii];
                        pack_y0(ii*72+71, ii*72) = op2_mtx[ii];
                        pack_x1(ii*72+71, ii*72) = op1_mtx[ii+8];
                        pack_y1(ii*72+71, ii*72) = op2_mtx[ii+8];
                    }
                    fifo_x0.write(pack_x0);
                    fifo_x1.write(pack_x1);
                    fifo_y0.write(pack_y0);
                    fifo_y1.write(pack_y1);
                }

                ap_int<24> acc_final[16][16];
#pragma HLS array_partition variable=acc_final dim=1 complete
#pragma HLS array_partition variable=acc_final dim=2 complete

                // Assemble the 16x16 tile from the four 8x8 drain words.
                for(int ii = 0; ii < 2; ii++){
#pragma HLS unroll
                    for(int jj = 0; jj < 2; jj++){
#pragma HLS unroll
                        auto pack = fifo_drain[ii*2+jj].read();
                        for(int k = 0; k < 8; k++){
#pragma HLS unroll
                            for(int l = 0; l < 8; l++){
#pragma HLS unroll
                                acc_final[ii*8+k][jj*8+l] = ap_int<24>(pack(k*192+l*24+23, k*192+l*24));
                            }
                        }
                    }
                }

                if(stage == 0){
                    // Keep (result >> 8) as 8 bits, eight values per 64-bit entry.
                    for(int ii = 0; ii < 16; ii++){
#pragma HLS unroll
                        for(int k = 0; k < 16; k++){
#pragma HLS unroll
                            int offset = k%8;
                            scratchpad_q[i*16+ii][j*2+k/8](offset*8+7, offset*8) = ap_int<8>(acc_final[ii][k] >> 8);
                        }
                    }
                } else if (stage == 1){
                    // Even columns hold locally computed values, odd columns
                    // the values received from fifo_from_acc1.
                    for(int ii = 0; ii < 4; ii++){
                        for(int jj = 0; jj < 2; jj++){
#pragma HLS pipeline II=1 style=stp
                            ap_uint<256> tmp = fifo_from_acc1.read();

                            for(int l = 0; l < 4; l++){
#pragma HLS unroll
                                ap_uint<64> tmp_pack;
                                for(int k = 0; k < 8; k++){
#pragma HLS unroll
                                    tmp_pack(k*8+7, k*8) = ap_int<8>(acc_final[ii*4+l][jj*8+k] >> 8);
                                }
                                scratchpad_k[i*16+ii*4+l][j*4+jj*2] = tmp_pack;
                            }
                            for(int l = 0; l < 4; l++){
#pragma HLS unroll
                                scratchpad_k[i*16+ii*4+l][j*4+jj*2+1] = tmp(l*64+63, l*64);
                            }
                        }
                    }
                } else if(stage == 2 || stage == 4){
                    // One 512-bit word per kk, holding 16 32-bit lanes.
                    for(int kk = 0; kk < 16; kk++){
#pragma HLS pipeline II=1 style=stp
                        ap_uint<512> tmp;
                        for(int ii = 0; ii < 16; ii++){
#pragma HLS unroll
                            if(stage == 2 && (i*16+ii < j*16+kk)){
                                tmp(ii*32+31, ii*32) = ap_int<32>(-1e8); // masking (inefficient)
                            } else {
                                tmp(ii*32+31, ii*32) = tapa::bit_cast>(acc_final[ii][kk]);
                            }
                        }
                        if(stage == 2) fifo_O_out.write(tmp);
                        else fifo_ffn_out.write(tmp);
                    }
                } else {
                    // Remaining stage (3): add the partial sums received on
                    // fifo_reduce_recv and the matching 8-bit value from X.
final_acc:
                    for(int ii = 0; ii < 16; ii++){
#pragma HLS pipeline II=1 style=stp
#pragma HLS dependence variable=X type=inter false
                        ap_uint<512> tmp_recv = fifo_reduce_recv.read();
                        ap_uint<512> tmp_send;
                        for(int k = 0; k < 16; k++){
#pragma HLS unroll
                            ap_int<32> tmp = acc_final[ii][k] + ap_int<24>(tmp_recv(k*32+23, k*32));
                            tmp += ap_int<8>(X[i*16+ii][j*2+k/8]((k%8)*8+7, (k%8)*8));
                            tmp_send(k*32+31, k*32) = tmp;
                        }
                        fifo_res_send.write(tmp_send);
                    }
                }
            }
        }
    }
    // fifo_fin.write(true);

    // write:
    // for(int i = 0; i < L; i++){
    //     for(int j = 0; j < D_div_8; j++){
    //     #pragma HLS pipeline II=1 style=stp
    //         fifo_write.write(X[i][j]);
    //     }
    // }
}

// Control kernel for a non-master acc0 node. Same instruction loop, weight
// loading, PE feeding and per-stage result handling as temporal_acc0_slr0,
// with these differences visible in the body:
//   * there is no local X buffer: in every stage other than 2 and 3 each
//     compute step reads a packet from `fifo_X_in` and forwards it unchanged
//     on `fifo_X_out`; stage 3 reads from `fifo_context` instead;
//   * in the final branch (stage 3) the tile is added to the values received
//     on `fifo_reduce_recv` and passed on via `fifo_reduce_send`; only the low
//     24 bits of each 32-bit lane of the outgoing word are assigned.
void temporal_acc0(
    tapa::istream& fifo_inst_in,
    tapa::ostream& fifo_inst_out,
    tapa::ostream& fifo_len_sfu,
    tapa::istream>& fifo_X_in,
    tapa::ostream>& fifo_X_out, // 8-bit activation
    tapa::istream>& fifo_W_in,
    tapa::ostream>& fifo_W_out, // 4-bit weight
    tapa::istream>& fifo_from_acc1,
    tapa::ostream>& fifo_O_out,
    tapa::istream>& fifo_context,
    tapa::ostream>& fifo_ffn_out,
    tapa::istream>& fifo_reduce_recv,
    tapa::ostream>& fifo_reduce_send,
    // systolic array
    tapa::ostream>& fifo_x0,
    tapa::ostream>& fifo_x1,
    tapa::ostream>& fifo_y0,
    tapa::ostream>& fifo_y1,
    tapa::ostream>& fifo_bound_x0,
    tapa::ostream>& fifo_bound_x1,
    tapa::istreams, 4>& fifo_drain
){
    ap_uint<64> scratchpad_q[MAX_SEQ_LEN][D_head_div_8]; // 8 bit
#pragma HLS array_partition variable=scratchpad_q cyclic dim=1 factor=16
#pragma HLS array_partition variable=scratchpad_q cyclic dim=2 factor=2
#pragma HLS bind_storage variable=scratchpad_q type=ram_2p impl=bram

    ap_uint<64> scratchpad_k[MAX_SEQ_LEN][D_head_div_8]; // 8 bit
#pragma HLS array_partition variable=scratchpad_k cyclic dim=1 factor=16
#pragma HLS array_partition variable=scratchpad_k cyclic dim=2 factor=2
#pragma HLS bind_storage variable=scratchpad_k type=ram_2p impl=bram

    // First record carries the sequence length in weight_bound.
    ConfigInst len = fifo_inst_in.read();
    const int L = len.weight_bound;
    fifo_inst_out.write(len);
    fifo_len_sfu.write(L);

    for(int stage_i = 0; stage_i < 14; stage_i++){
#pragma HLS loop_flatten off
        // stage 0: WqX
        // stage 1: WkX0 <- acc1
        // stage 2: QK^T
        // stage 3: WoO

        ap_uint<32> W[D][D_div_8]; // 4 bit
#pragma HLS array_partition variable=W cyclic dim=1 factor=16
#pragma HLS bind_storage variable=W type=ram_2p impl=uram

        ConfigInst inst = fifo_inst_in.read();
        fifo_inst_out.write(inst);
        const ap_uint<3> stage = inst.stage;

        // load weights and forward
        if(stage != 2) {
            const int weight_bound = inst.weight_bound;
            for(int i = 0; i < weight_bound; i++){
load_weight:
                for(int j = 0; j < D_div_8;){
                    if(!fifo_W_in.empty()){
                        ap_uint<512> val;
                        fifo_W_in.try_read(val);
                        // Keep the low 128 bits as four rows of W ...
                        for(int k = 0; k < 4; k++){
#pragma HLS unroll
                            W[i*4+k][j] = ap_uint<32>(val(k*32+31, k*32));
                        }
                        // ... and pass the remaining bits downstream.
                        val = ap_uint<512>(val >> 128);
                        fifo_W_out.write(val);
                        j++;
                    }
                }
            }
        }

        const int i_bound = inst.i_bound;
        const int j_bound = inst.j_bound;
        const int k_bound = inst.k_bound;

        // stage 1: compute Q
        for(int i = 0; i < i_bound; i++){ // make sure L is multiple of 64
            // In stage 2 only tiles with j <= i are processed.
            for(int j = 0; (j < j_bound) & ((stage != 2) | (j <= i)); j++){
#pragma HLS loop_flatten off

                // Each PE chain receives k_bound followed by stage.
                fifo_bound_x0.write(ap_uint<8>(k_bound));
                fifo_bound_x0.write(ap_uint<8>(stage));
                fifo_bound_x1.write(ap_uint<8>(k_bound));
                fifo_bound_x1.write(ap_uint<8>(stage));

compute:
                for(int k = 0; k < k_bound; k++){ // reduction dim
#pragma HLS pipeline II=1 style=stp

                    ap_uint<72> op1_mtx[16];
                    ap_uint<72> op2_mtx[16];
#pragma HLS array_partition variable=op1_mtx complete
#pragma HLS array_partition variable=op2_mtx complete

                    // Stage 3 reads from fifo_context; every other stage
                    // except 2 reads from fifo_X_in and relays the packet.
                    ap_uint<1024> recv_pkt;
                    if(stage == 3){
                        recv_pkt = fifo_context.read();
                    } else if(stage != 2) {
                        recv_pkt = fifo_X_in.read();
                        fifo_X_out.write(recv_pkt);
                    }

                    // Every operand is prefixed with a constant field of value 2.
                    for(int ii = 0; ii < 16; ii++){
#pragma HLS unroll
                        if(stage > 2){
                            op1_mtx[ii] = ap_uint<72>(ap_uint<36>((ap_int<4>(2), W[j*16+ii][k]))); // dummy to enforce dotpra3 binding
                            op2_mtx[ii] = ap_uint<72>((ap_int<8>(2), recv_pkt(ii*64+63, ii*64)));
                        } else if(stage == 2) {
                            op1_mtx[ii] = ap_uint<72>((ap_int<8>(2), scratchpad_k[j*16+ii][k]));
                            op2_mtx[ii] = ap_uint<72>((ap_int<8>(2), scratchpad_q[i*16+ii][k]));
                        } else {
                            op1_mtx[ii] = ap_uint<72>(ap_uint<36>((ap_int<4>(2), W[j*16+ii][k])));
                            op2_mtx[ii] = ap_uint<72>((ap_int<8>(2), recv_pkt(ii*64+63, ii*64)));
                        }
                    }

                    // Entries 0..7 go to the *0 streams, 8..15 to the *1 streams.
                    ap_uint<576> pack_x0;
                    ap_uint<576> pack_x1;
                    ap_uint<576> pack_y0;
                    ap_uint<576> pack_y1;
                    for(int ii = 0; ii < 8; ii++){
#pragma HLS unroll
                        pack_x0(ii*72+71, ii*72) = op1_mtx[ii];
                        pack_y0(ii*72+71, ii*72) = op2_mtx[ii];
                        pack_x1(ii*72+71, ii*72) = op1_mtx[ii+8];
                        pack_y1(ii*72+71, ii*72) = op2_mtx[ii+8];
                    }
                    fifo_x0.write(pack_x0);
                    fifo_x1.write(pack_x1);
                    fifo_y0.write(pack_y0);
                    fifo_y1.write(pack_y1);
                }

                ap_int<24> acc_final[16][16];
#pragma HLS array_partition variable=acc_final dim=1 complete
#pragma HLS array_partition variable=acc_final dim=2 complete

                // read 2x8 data
                for(int ii = 0; ii < 2; ii++){
#pragma HLS unroll
                    for(int jj = 0; jj < 2; jj++){
#pragma HLS unroll
                        auto pack = fifo_drain[ii*2+jj].read();
                        for(int k = 0; k < 8; k++){
#pragma HLS unroll
                            for(int l = 0; l < 8; l++){
#pragma HLS unroll
                                acc_final[ii*8+k][jj*8+l] = ap_int<24>(pack(k*192+l*24+23, k*192+l*24));
                            }
                        }
                    }
                }

                if(stage == 0){
                    // Keep (result >> 8) as 8 bits, eight values per 64-bit entry.
                    for(int ii = 0; ii < 16; ii++){
#pragma HLS unroll
                        for(int k = 0; k < 16; k++){
#pragma HLS unroll
                            int offset = k%8;
                            scratchpad_q[i*16+ii][j*2+k/8](offset*8+7, offset*8) = ap_int<8>(acc_final[ii][k] >> 8);
                        }
                    }
                } else if (stage == 1){
                    // Even columns hold locally computed values, odd columns
                    // the values received from fifo_from_acc1.
                    for(int ii = 0; ii < 4; ii++){
                        for(int jj = 0; jj < 2; jj++){
#pragma HLS pipeline II=1 style=stp
                            ap_uint<256> tmp = fifo_from_acc1.read();

                            for(int l = 0; l < 4; l++){
#pragma HLS unroll
                                ap_uint<64> tmp_pack;
                                for(int k = 0; k < 8; k++){
#pragma HLS unroll
                                    tmp_pack(k*8+7, k*8) = ap_int<8>(acc_final[ii*4+l][jj*8+k] >> 8);
                                }
                                scratchpad_k[i*16+ii*4+l][j*4+jj*2] = tmp_pack;
                            }
                            for(int l = 0; l < 4; l++){
#pragma HLS unroll
                                scratchpad_k[i*16+ii*4+l][j*4+jj*2+1] = tmp(l*64+63, l*64);
                            }
                        }
                    }
                } else if(stage == 2 || stage == 4){
                    // One 512-bit word per kk, holding 16 32-bit lanes.
                    for(int kk = 0; kk < 16; kk++){
#pragma HLS pipeline II=1 style=stp
                        ap_uint<512> tmp;
                        for(int ii = 0; ii < 16; ii++){
#pragma HLS unroll
                            if(stage == 2 && (i*16+ii < j*16+kk)){
                                tmp(ii*32+31, ii*32) = ap_int<32>(-1e8); // masking (inefficient)
                            } else {
                                tmp(ii*32+31, ii*32) = tapa::bit_cast>(acc_final[ii][kk]);
                            }
                        }
                        if(stage == 2) fifo_O_out.write(tmp);
                        else fifo_ffn_out.write(tmp);
                    }
                } else {
                    // Remaining stage (3): accumulate the incoming partial
                    // sums and pass the running total to the next node.
final_acc:
                    for(int ii = 0; ii < 16; ii++){
#pragma HLS pipeline II=1 style=stp
                        ap_uint<512> tmp_recv = fifo_reduce_recv.read();
                        ap_uint<512> tmp;
                        for(int k = 0; k < 16; k++){
#pragma HLS unroll
                            acc_final[ii][k] += ap_int<24>(tmp_recv(k*32+23, k*32));
                            tmp(k*32+23, k*32) = acc_final[ii][k];
                        }
                        fifo_reduce_send.write(tmp);
                    }
                }
            }
        }
    }
}

// acc slr0 master node
void temporal_acc1_slr0(
    tapa::istream& fifo_inst_in,
    tapa::ostream& fifo_inst_out,
    tapa::ostream& fifo_len_context,
    tapa::istream>& fifo_X_in,
    tapa::ostream>& fifo_X_out, // 8-bit activation
    tapa::istream>& fifo_W_in,
tapa::ostream>& fifo_W_out, // 4-bit weight tapa::ostream>& fifo_to_acc0, tapa::istream>& fifo_from_sfu, tapa::ostream>& fifo_O_out, tapa::istream>& fifo_context, tapa::istream>& fifo_reduce_recv, tapa::ostream>& fifo_res_send, tapa::istream>& fifo_gelu_in, tapa::ostream>& fifo_ffn_out, // systolic array tapa::ostream>& fifo_x0, tapa::ostream>& fifo_x1, tapa::ostream>& fifo_y0, tapa::ostream>& fifo_y1, tapa::ostream>& fifo_bound_x0, tapa::ostream>& fifo_bound_x1, tapa::istreams, 4>& fifo_drain ){ ap_uint<64> X[MAX_SEQ_LEN][D_div_8]; // 8 bit #pragma HLS array_partition variable=X cyclic dim=1 factor=16 #pragma HLS array_partition variable=X cyclic dim=2 factor=2 #pragma HLS bind_storage variable=X type=ram_2p impl=uram ap_uint<64> scratchpad[MAX_SEQ_LEN_div_8][D_head]; // 8 bit #pragma HLS array_partition variable=scratchpad cyclic dim=1 factor=2 #pragma HLS array_partition variable=scratchpad cyclic dim=2 factor=16 #pragma HLS bind_storage variable=scratchpad type=ram_2p impl=bram // ap_uint<64> scratchpad_out[MAX_SEQ_LEN][D_head_div_8]; // #pragma HLS array_partition variable=scratchpad_out cyclic dim=1 factor=16 // #pragma HLS array_partition variable=scratchpad_out cyclic dim=2 factor=2 ConfigInst len = fifo_inst_in.read(); const int L = len.weight_bound; fifo_inst_out.write(len); fifo_len_context.write(L); for(int stage_i = 0; stage_i < 14; stage_i++){ // stage 0: WvX // stage 1: WkX1 -> acc0 // stage 2: Softmax(QK)V <- acc0 // stage 3: WoO ap_uint<32> W[D][D_div_8]; // 4 bit #pragma HLS array_partition variable=W cyclic dim=1 factor=16 #pragma HLS bind_storage variable=W type=ram_2p impl=uram ConfigInst inst = fifo_inst_in.read(); fifo_inst_out.write(inst); const ap_uint<3> stage = inst.stage; // load weights and forward if(stage != 2) { const int weight_bound = inst.weight_bound; for(int i = 0; i < weight_bound; i++){ load_weight: for(int j = 0; j < D_div_8;){ if(!fifo_W_in.empty()){ ap_uint<512> val; fifo_W_in.try_read(val); for(int k = 0; k < 4; k++){ 
#pragma HLS unroll W[i*4+k][j] = ap_uint<32>(val(k*32+31, k*32)); } val = ap_uint<512>(val >> 128); fifo_W_out.write(val); j++; } } } } const int i_bound = inst.i_bound; const int j_bound = inst.j_bound; for(int i = 0; i < i_bound; i++){ // make sure L is multiple of 4 const int k_bound = (stage == 2) ? ap_uint<8>((i+1)*2) : inst.k_bound; ap_uint<64> cache_attn[MAX_SEQ_LEN_div_8][16]; #pragma HLS array_partition variable=cache_attn dim=2 complete if(stage_i == 0){ for(int ii = 0; ii < 2; ii++){ // load only 1 time load_x: for(int jj = 0; jj < D_div_8;){ if(!fifo_X_in.empty()){ ap_uint<512> val; fifo_X_in.try_read(val); for(int k = 0; k < 8; k++){ #pragma HLS unroll X[i*16+ii*8+k][jj] = ap_uint<64>(val(k*64+63, k*64)); } jj++; } } } } else if (stage == 2) { for(int ii = 0; ii < ((i+1)*2); ii++){ ap_uint<64> fuse_reg[16]; load_attn: for(int offset = 0; offset < 8;){ #pragma HLS pipeline II=1 style=stp if(!fifo_from_sfu.empty()){ ap_uint<128> val; fifo_from_sfu.try_read(val); for(int k = 0; k < 16; k++){ #pragma HLS unroll fuse_reg[k](offset*8+7, offset*8) = ap_int<8>(val(k*8+7, k*8)); } offset++; } } for(int k = 0; k < 16; k++){ #pragma HLS unroll cache_attn[ii][k] = fuse_reg[k]; } } } for(int j = 0; j < j_bound; j++){ #pragma HLS loop_flatten off fifo_bound_x0.write(ap_uint<8>(k_bound)); fifo_bound_x0.write(ap_uint<8>(stage)); fifo_bound_x1.write(ap_uint<8>(k_bound)); fifo_bound_x1.write(ap_uint<8>(stage)); compute: for(int k = 0; k < k_bound; k++){ #pragma HLS pipeline II=1 style=stp ap_uint<72> op1_mtx[16]; ap_uint<72> op2_mtx[16]; #pragma HLS array_partition variable=op1_mtx complete #pragma HLS array_partition variable=op2_mtx complete ap_uint<1024> recv_pkt; if(stage == 3) { recv_pkt = fifo_context.read(); } else if(stage == 4) { recv_pkt = fifo_gelu_in.read(); } for(int ii = 0; ii < 16; ii++){ #pragma HLS unroll if(stage >= 3){ op1_mtx[ii] = ap_uint<72>(ap_uint<36>((ap_int<4>(2), W[j*16+ii][k]))); op2_mtx[ii] = ap_uint<72>((ap_int<8>(2), recv_pkt(ii*64+63, 
ii*64))); } else if(stage != 2) { op1_mtx[ii] = ap_uint<72>(ap_uint<36>((ap_int<4>(2), W[j*16+ii][k]))); op2_mtx[ii] = ap_uint<72>((ap_int<8>(2), X[i*16+ii][k])); } else { op1_mtx[ii] = ap_uint<72>((ap_int<8>(2), scratchpad[k][j*16+ii])); op2_mtx[ii] = ap_uint<72>((ap_int<8>(2), cache_attn[k][ii])); } } if(stage < 2){ ap_uint<1024> send_pkt = ap_uint<1024>(( op2_mtx[0](63,0), op2_mtx[1](63,0), op2_mtx[2](63,0), op2_mtx[3](63,0), op2_mtx[4](63,0), op2_mtx[5](63,0), op2_mtx[6](63,0), op2_mtx[7](63,0), op2_mtx[8](63,0), op2_mtx[9](63,0), op2_mtx[10](63,0), op2_mtx[11](63,0), op2_mtx[12](63,0), op2_mtx[13](63,0), op2_mtx[14](63,0), op2_mtx[15](63,0) )); fifo_X_out.write(send_pkt); } ap_uint<576> pack_x0; ap_uint<576> pack_x1; ap_uint<576> pack_y0; ap_uint<576> pack_y1; for(int ii = 0; ii < 8; ii++){ #pragma HLS unroll pack_x0(ii*72+71, ii*72) = op1_mtx[ii]; pack_y0(ii*72+71, ii*72) = op2_mtx[ii]; pack_x1(ii*72+71, ii*72) = op1_mtx[ii+8]; pack_y1(ii*72+71, ii*72) = op2_mtx[ii+8]; } fifo_x0.write(pack_x0); fifo_x1.write(pack_x1); fifo_y0.write(pack_y0); fifo_y1.write(pack_y1); } ap_int<27> acc_final[16][16]; #pragma HLS array_partition variable=acc_final dim=1 complete #pragma HLS array_partition variable=acc_final dim=2 complete for(int ii = 0; ii < 2; ii++){ #pragma HLS unroll for(int jj = 0; jj < 2; jj++){ #pragma HLS unroll auto pack = fifo_drain[ii*2+jj].read(); for(int k = 0; k < 8; k++){ #pragma HLS unroll for(int l = 0; l < 8; l++){ #pragma HLS unroll acc_final[ii*8+k][jj*8+l] = ap_int<27>(pack(k*216+l*27+26, k*216+l*27)); } } } } if(stage == 0){ for(int ii = 0; ii < 16; ii++){ #pragma HLS unroll for(int k = 0; k < 16; k++){ #pragma HLS unroll int offset = ii%8; scratchpad[i*2+ii/8][j*16+k](offset*8+7, offset*8) = ap_int<8>(acc_final[ii][k] >> 8); } } } else if (stage == 2){ for(int ii = 0; ii < 2; ii++){ #pragma HLS pipeline II=1 style=stp ap_uint<1024> tmp; for(int jj = 0; jj < 8; jj++){ #pragma HLS unroll for(int k = 0; k < 16; k++){ #pragma HLS unroll 
tmp((jj*16+k)*8+7, (jj*16+k)*8) = ap_int<8>(acc_final[ii*8+jj][k] >> 12); } } fifo_O_out.write(tmp); } } else if (stage == 1) { for(int ii = 0; ii < 4; ii++){ for(int jj = 0; jj < 2; jj++){ ap_uint<256> tmp; for(int k = 0; k < 32; k++){ #pragma HLS unroll tmp(k*8+7, k*8) = ap_int<8>(acc_final[ii*4+k/8][jj*8+k%8] >> 8); } fifo_to_acc0.write(tmp); } } } else { final_acc: for(int ii = 0; ii < 16; ii++){ #pragma HLS pipeline II=1 style=stp #pragma HLS dependence variable=X type=inter false ap_uint<512> tmp_recv = fifo_reduce_recv.read(); ap_uint<512> tmp_send; for(int k = 0; k < 16; k++){ #pragma HLS unroll ap_int<32> tmp = acc_final[ii][k] + ap_int<27>(tmp_recv(k*32+26, k*32)); if(stage == 3) tmp += ap_int<8>(X[i*16+ii][j*2+k/8]((k%8)*8+7, (k%8)*8)); tmp_send(k*32+31, k*32) = tmp; } if(stage == 3) fifo_res_send.write(tmp_send); else fifo_ffn_out.write(tmp_send); } } } } } // fifo_fin.write(true); // write out for debug // write: // for(int i = 0; i < L; i++){ // for(int j = 0; j < D_div_8; j++){ // #pragma HLS pipeline II=1 style=stp // fifo_write.write(X[i][j]); // } // } } void residual( const int L, tapa::istream>& fifo_res_in, tapa::ostream>& fifo_res_out ){ for(int i = 0; i < (L >> 5); i++){ for(int j = 0; j < D_div_16; j++){ ap_uint<32> res_buffer[16][16]; #pragma HLS array_partition variable=res_buffer complete dim=1 #pragma HLS array_partition variable=res_buffer complete dim=2 read: for(int k = 0; k < 16;){ #pragma HLS pipeline II=1 style=stp ap_uint<512> tmp; bool success = fifo_res_in.try_read(tmp); if(success){ for(int l = 0; l < 16; l++){ #pragma HLS unroll res_buffer[k][l] = ap_uint<32>(tmp(l*32+31, l*32)); } k++; } } transpose: for(int k = 0; k < 16; k++){ #pragma HLS pipeline II=1 style=stp ap_uint<512> tmp; for(int l = 0; l < 16; l++){ #pragma HLS unroll tmp(l*32+31, l*32) = ap_uint<32>(res_buffer[l][k]); } fifo_res_out.write(tmp); } } } } void temporal_acc1( tapa::istream& fifo_inst_in, tapa::ostream& fifo_inst_out, tapa::ostream& fifo_len_context, 
tapa::istream>& fifo_X_in, tapa::ostream>& fifo_X_out, // 8-bit activation tapa::istream>& fifo_W_in, tapa::ostream>& fifo_W_out, // 4-bit weight tapa::ostream>& fifo_to_acc0, tapa::istream>& fifo_from_sfu, tapa::ostream>& fifo_O_out, tapa::istream>& fifo_context, tapa::istream>& fifo_reduce_recv, tapa::ostream>& fifo_reduce_send, tapa::istream>& fifo_gelu_in, // systolic array tapa::ostream>& fifo_x0, tapa::ostream>& fifo_x1, tapa::ostream>& fifo_y0, tapa::ostream>& fifo_y1, tapa::ostream>& fifo_bound_x0, tapa::ostream>& fifo_bound_x1, tapa::istreams, 4>& fifo_drain ){ ap_uint<64> scratchpad[MAX_SEQ_LEN_div_8][D_head]; // 8 bit #pragma HLS array_partition variable=scratchpad cyclic dim=1 factor=2 #pragma HLS array_partition variable=scratchpad cyclic dim=2 factor=16 #pragma HLS bind_storage variable=scratchpad type=ram_2p impl=bram // ap_uint<64> scratchpad_out[MAX_SEQ_LEN][D_head_div_8]; // #pragma HLS array_partition variable=scratchpad_out cyclic dim=1 factor=16 // #pragma HLS array_partition variable=scratchpad_out cyclic dim=2 factor=2 ConfigInst len = fifo_inst_in.read(); const int L = len.weight_bound; fifo_inst_out.write(len); fifo_len_context.write(L); for(int stage_i = 0; stage_i < 14; stage_i++){ // stage 0: WvX // stage 1: WkX1 -> acc0 // stage 2: Softmax(QK)V <- acc0 // stage 3: WoO ap_uint<32> W[D][D_div_8]; // 4 bit #pragma HLS array_partition variable=W cyclic dim=1 factor=16 #pragma HLS bind_storage variable=W type=ram_2p impl=uram ConfigInst inst = fifo_inst_in.read(); fifo_inst_out.write(inst); const ap_uint<3> stage = inst.stage; // load weights and forward if(stage != 2) { const int weight_bound = inst.weight_bound; for(int i = 0; i < weight_bound; i++){ load_weight: for(int j = 0; j < D_div_8;){ if(!fifo_W_in.empty()){ ap_uint<512> val; fifo_W_in.try_read(val); for(int k = 0; k < 4; k++){ #pragma HLS unroll W[i*4+k][j] = ap_uint<32>(val(k*32+31, k*32)); } val = ap_uint<512>(val >> 128); fifo_W_out.write(val); j++; } } } } const int i_bound = 
inst.i_bound; const int j_bound = inst.j_bound; for(int i = 0; i < i_bound; i++){ // make sure L is multiple of 4 const int k_bound = (stage == 2) ? ap_uint<8>((i+1)*2) : inst.k_bound; ap_uint<64> cache_attn[MAX_SEQ_LEN_div_8][16]; #pragma HLS array_partition variable=cache_attn dim=2 complete if(stage == 2){ for(int ii = 0; ii < (i+1)*2; ii++){ ap_uint<64> fuse_reg[16]; load_attn: for(int offset = 0; offset < 8;){ #pragma HLS pipeline II=1 style=stp if(!fifo_from_sfu.empty()){ ap_uint<128> val; fifo_from_sfu.try_read(val); for(int k = 0; k < 16; k++){ #pragma HLS unroll fuse_reg[k](offset*8+7, offset*8) = ap_int<8>(val(k*8+7, k*8)); } offset++; } } for(int k = 0; k < 16; k++){ #pragma HLS unroll cache_attn[ii][k] = fuse_reg[k]; } } } for(int j = 0; j < j_bound; j++){ #pragma HLS loop_flatten off fifo_bound_x0.write(ap_uint<8>(k_bound)); fifo_bound_x0.write(ap_uint<8>(stage)); fifo_bound_x1.write(ap_uint<8>(k_bound)); fifo_bound_x1.write(ap_uint<8>(stage)); compute: for(int k = 0; k < k_bound; k++){ #pragma HLS pipeline II=1 style=stp ap_uint<72> op1_mtx[16]; ap_uint<72> op2_mtx[16]; #pragma HLS array_partition variable=op1_mtx complete #pragma HLS array_partition variable=op2_mtx complete ap_uint<1024> recv_pkt; if(stage == 3) { recv_pkt = fifo_context.read(); }else if(stage == 4) { recv_pkt = fifo_gelu_in.read(); }else if(stage != 2) { recv_pkt = fifo_X_in.read(); fifo_X_out.write(recv_pkt); } for(int ii = 0; ii < 16; ii++){ //TODO: change logic #pragma HLS unroll if (stage != 2) { op1_mtx[ii] = ap_uint<72>(ap_uint<36>((ap_int<4>(2), W[j*16+ii][k]))); op2_mtx[ii] = ap_uint<72>((ap_int<8>(2), recv_pkt(ii*64+63, ii*64))); } else { op1_mtx[ii] = ap_uint<72>((ap_int<8>(2), scratchpad[k][j*16+ii])); op2_mtx[ii] = ap_uint<72>((ap_int<8>(2), cache_attn[k][ii])); } } ap_uint<576> pack_x0; ap_uint<576> pack_x1; ap_uint<576> pack_y0; ap_uint<576> pack_y1; for(int ii = 0; ii < 8; ii++){ #pragma HLS unroll pack_x0(ii*72+71, ii*72) = op1_mtx[ii]; pack_y0(ii*72+71, ii*72) = 
op2_mtx[ii]; pack_x1(ii*72+71, ii*72) = op1_mtx[ii+8]; pack_y1(ii*72+71, ii*72) = op2_mtx[ii+8]; } fifo_x0.write(pack_x0); fifo_x1.write(pack_x1); fifo_y0.write(pack_y0); fifo_y1.write(pack_y1); } ap_int<27> acc_final[16][16]; #pragma HLS array_partition variable=acc_final dim=1 complete #pragma HLS array_partition variable=acc_final dim=2 complete for(int ii = 0; ii < 2; ii++){ #pragma HLS unroll for(int jj = 0; jj < 2; jj++){ #pragma HLS unroll auto pack = fifo_drain[ii*2+jj].read(); for(int k = 0; k < 8; k++){ #pragma HLS unroll for(int l = 0; l < 8; l++){ #pragma HLS unroll acc_final[ii*8+k][jj*8+l] = ap_int<27>(pack(k*216+l*27+26, k*216+l*27)); } } } } if(stage == 0){ for(int ii = 0; ii < 16; ii++){ #pragma HLS unroll for(int k = 0; k < 16; k++){ #pragma HLS unroll int offset = ii%8; scratchpad[i*2+ii/8][j*16+k](offset*8+7, offset*8) = ap_int<8>(acc_final[ii][k] >> 8); } } } else if (stage == 2){ for(int ii = 0; ii < 2; ii++){ #pragma HLS pipeline II=1 style=stp ap_uint<1024> tmp; for(int jj = 0; jj < 8; jj++){ #pragma HLS unroll for(int k = 0; k < 16; k++){ #pragma HLS unroll tmp((jj*16+k)*8+7, (jj*16+k)*8) = ap_int<8>(acc_final[ii*8+jj][k] >> 12); } } fifo_O_out.write(tmp); } } else if (stage == 1){ for(int ii = 0; ii < 4; ii++){ for(int jj = 0; jj < 2; jj++){ ap_uint<256> tmp; for(int k = 0; k < 32; k++){ #pragma HLS unroll tmp(k*8+7, k*8) = ap_int<8>(acc_final[ii*4+k/8][jj*8+k%8] >> 8); } fifo_to_acc0.write(tmp); } } } else { final_acc: for(int ii = 0; ii < 16; ii++){ #pragma HLS pipeline II=1 style=stp ap_uint<512> tmp_recv = fifo_reduce_recv.read(); ap_uint<512> tmp; for(int k = 0; k < 16; k++){ #pragma HLS unroll acc_final[ii][k] += ap_int<27>(tmp_recv(k*32+26, k*32)); tmp(k*32+26, k*32) = acc_final[ii][k]; } fifo_reduce_send.write(tmp); } } } } } // write out for debug // write: // for(int i = 0; i < L; i++){ // for(int j = 0; j < D_head_div_8; j++){ // #pragma HLS pipeline II=1 style=stp // fifo_O_out.write(scratchpad_out[i][j]); // } // } } void 
sfu_buffer( // double buffering
    tapa::istream& fifo_inst,
    tapa::istream>& fifo_data_in,
    tapa::ostream>& fifo_data_out
){
    // Buffers one group of 16-lane float words while accumulating a per-lane
    // sum, then emits the sum word followed by the buffered words.
    // fifo_inst supplies L first and then one hidden_bound (word count) per
    // group; 4 * (L >> 5) groups are processed in total.
    const int L = fifo_inst.read();

    for(int stage = 0; stage < 4; stage++){

        for(int l = 0; l < (L >> 5); l++){
            // Eight interleaved partial sums per lane (indexed by i%8 below).
            float sum[8][16];
            float cache[MAX_SEQ_LEN][16];
            #pragma HLS array_partition variable=cache dim=2 complete
            #pragma HLS array_partition variable=sum dim=2 complete

            // Number of words in this group.
            // NOTE(review): cache holds MAX_SEQ_LEN rows and is indexed up to
            // hidden_bound - 1; nothing here checks hidden_bound <= MAX_SEQ_LEN.
            const int hidden_bound = fifo_inst.read();

            for(int i = 0; i < 8; i++){
                for(int j = 0; j < 16; j++){
                    #pragma HLS unroll
                    sum[i][j] = 0.0;
                }
            }

            // Accumulate and cache. Word i updates row i%8 of `sum`, so the
            // same row is touched again only 8 iterations later (matching the
            // distance=8 dependence pragma).
            acc: for(int i = 0; i < hidden_bound; i++){
                #pragma HLS pipeline II=1 style=stp
                #pragma HLS dependence false variable=sum
                #pragma HLS dependence true variable=sum distance=8
                ap_uint<512> tmp = fifo_data_in.read();
                for(int k = 0; k < 16; k++){
                    #pragma HLS unroll
                    float res = tapa::bit_cast(ap_int<32>(tmp(k*32+31, k*32)));
                    sum[i%8][k] += res;
                    cache[i][k] = res;
                }
            }

            // Fold rows 1..7 of `sum` into row 0, two lanes per iteration.
            reduce: for(int i = 1; i < 8; i++){
                for(int j = 0; j < 8; j++){
                    #pragma HLS pipeline II=1 style=stp
                    #pragma HLS dependence true variable=sum distance=8
                    for(int k = 0; k < 2; k++){
                        sum[0][j*2+k] += sum[i][j*2+k];
                    }
                }
            }

            // The first output word of the group carries the 16 lane sums.
            ap_uint<512> tmp;
            for(int i = 0; i < 16; i++){
                #pragma HLS unroll
                tmp(i*32+31, i*32) = tapa::bit_cast>(sum[0][i]);
            }
            fifo_data_out.write(tmp);

            // Then replay the cached words in arrival order.
            write: for(int i = 0; i < hidden_bound; i++){
                #pragma HLS pipeline II=1 style=stp
                ap_uint<512> tmp;
                for(int j = 0; j < 16; j++){
                    #pragma HLS unroll
                    tmp(j*32+31, j*32) = tapa::bit_cast>(cache[i][j]);
                }
                fifo_data_out.write(tmp);
            }
        }
    }
}

// Variant of sfu_buffer with six stages and three inputs.
// Stages 0-3 read from fifo_data_in_exp with a per-group word count taken from
// fifo_inst; stage 4 reads fifo_data_in_ln and stage 5 reads fifo_data_in_ffn,
// both with a fixed word count of D. In stages 4 and 5 a per-lane sum of
// squares is accumulated as well and emitted as a second header word.
void sfu_buffer_slr0( // double buffering
    tapa::istream& fifo_inst,
    tapa::istream>& fifo_data_in_exp,
    tapa::istream>& fifo_data_in_ln,
    tapa::istream>& fifo_data_in_ffn,
    tapa::ostream>& fifo_data_out
){
    const int L = fifo_inst.read();

    for(int stage = 0; stage < 6; stage++){

        // Default word count; overwritten per group when stage < 4.
        int hidden_bound = D;

        for(int l = 0; l < (L >> 5); l++){
            float sum[8][16];
            float var[8][16];
            // NOTE(review): cache has MAX_SEQ_LEN rows but is indexed up to
            // D - 1 in stages 4 and 5; this relies on D <= MAX_SEQ_LEN.
            float cache[MAX_SEQ_LEN][16];
            #pragma HLS array_partition variable=cache dim=2 complete
            #pragma HLS array_partition variable=sum dim=2
complete
            #pragma HLS array_partition variable=var dim=2 complete

            if(stage < 4) hidden_bound = fifo_inst.read();

            for(int i = 0; i < 8; i++){
                for(int j = 0; j < 16; j++){
                    #pragma HLS unroll
                    sum[i][j] = 0.0;
                    var[i][j] = 0.0;
                }
            }

            // Accumulate sum (and, from stage 4 on, sum of squares) per lane
            // while caching every word. Row i%8 is used so that each
            // accumulator is revisited only every 8 iterations.
            acc: for(int i = 0; i < hidden_bound; i++){
                #pragma HLS pipeline II=1 style=stp
                #pragma HLS dependence false variable=sum
                #pragma HLS dependence true variable=sum distance=8
                ap_uint<512> tmp;
                if(stage < 4) {
                    tmp = fifo_data_in_exp.read();
                } else if(stage == 4){
                    tmp = fifo_data_in_ln.read();
                } else {
                    tmp = fifo_data_in_ffn.read();
                }
                for(int k = 0; k < 16; k++){
                    #pragma HLS unroll
                    float res = tapa::bit_cast(ap_int<32>(tmp(k*32+31, k*32)));
                    sum[i%8][k] += res;
                    if(stage >= 4) var[i%8][k] += (res * res);
                    cache[i][k] = res;
                }
            }

            // Fold rows 1..7 into row 0, two lanes per iteration.
            reduce: for(int i = 1; i < 8; i++){
                for(int j = 0; j < 8; j++){
                    #pragma HLS pipeline II=1 style=stp
                    #pragma HLS dependence true variable=sum distance=8
                    #pragma HLS dependence true variable=var distance=8
                    for(int k = 0; k < 2; k++){
                        sum[0][j*2+k] += sum[i][j*2+k];
                        if(stage >= 4) var[0][j*2+k] += var[i][j*2+k];
                    }
                }
            }

            // Header word(s): the lane sums, followed by the lane sums of
            // squares in stages 4 and 5 (tmp_var is only assigned and written
            // in those stages).
            ap_uint<512> tmp;
            ap_uint<512> tmp_var;
            for(int i = 0; i < 16; i++){
                #pragma HLS unroll
                tmp(i*32+31, i*32) = tapa::bit_cast>(sum[0][i]);
                if(stage >= 4) tmp_var(i*32+31, i*32) = tapa::bit_cast>(var[0][i]);
            }
            fifo_data_out.write(tmp);
            if(stage >= 4) fifo_data_out.write(tmp_var);

            // Replay the cached words in arrival order.
            write: for(int i = 0; i < hidden_bound; i++){
                #pragma HLS pipeline II=1 style=stp
                ap_uint<512> tmp;
                for(int j = 0; j < 16; j++){
                    #pragma HLS unroll
                    tmp(j*32+31, j*32) = tapa::bit_cast>(cache[i][j]);
                }
                fifo_data_out.write(tmp);
            }
        }
    }
}

// Applies exp() lane-wise to a stream of 16-lane words and distributes the
// results over two output FIFOs.
// Reads L from fifo_inst and forwards it to both fifo_inst_out channels. Then,
// for each of 4 stages and each group l in [0, L >> 4), it announces the group
// length (l+1) << 4 on fifo_inst_out[l%2] and writes that many transformed
// words to fifo_buf[l%2].
void sfu_acc_exp(
    tapa::istream& fifo_inst,
    tapa::istream>& fifo_data_in,
    tapa::ostreams, 2>& fifo_buf,
    tapa::ostreams& fifo_inst_out
) {
    const int L = fifo_inst.read();
    fifo_inst_out[0].write(L);
    fifo_inst_out[1].write(L);

    for(int stage = 0; stage < 4; stage++){

        for(int l = 0; l < (L >> 4); l++){
            fifo_inst_out[l%2].write(((l+1) << 4));

            // Non-blocking read: i advances only when a word was consumed.
            exp_acc: for(int i = 0; i < ((l+1) << 4);){
                #pragma HLS pipeline II=1 style=stp
                if(!fifo_data_in.empty()){
                    ap_uint<512> tmp; fifo_data_in.try_read(tmp);
                    ap_uint<512> tmp_o;
                    for(int k = 0; k < 16; k++){
                        #pragma HLS unroll
                        // Each lane is an integer; it is shifted right by 10
                        // before exp() is applied.
                        int res = tapa::bit_cast(ap_int<32>(tmp(k*32+31, k*32)));
                        float res_exp = 0.0;
                        res_exp = hls::exp(ap_int<32>(res >> 10));
                        tmp_o(k*32+31, k*32) = tapa::bit_cast>(res_exp);
                    }
                    fifo_buf[l%2].write(tmp_o);
                    i++;
                }
            }
        }
    }
}

// Applies a piecewise-linear GELU approximation lane-wise.
// Reads L from fifo_inst and forwards it on fifo_inst_out, then consumes
// (L >> 4) * D words of 16 32-bit lanes from fifo_ffn and writes one 128-bit
// word (16 x 8-bit results) per input word to fifo_out.
// NOTE(review): unlike the neighbouring SFU loops, neither the word loop nor
// the lane loop carries a pipeline/unroll pragma - confirm this is intended.
void sfu_gelu(
    tapa::istream& fifo_inst,
    tapa::ostream& fifo_inst_out,
    tapa::istream>& fifo_ffn,
    tapa::ostream>& fifo_out
){
    const int L = fifo_inst.read();
    fifo_inst_out.write(L);

    for(int i = 0; i < (L >> 4); i++){
        // Non-blocking read: j advances only when a word was consumed.
        for(int j = 0; j < D;){
            if(!fifo_ffn.empty()){
                ap_uint<512> tmp; fifo_ffn.try_read(tmp);
                ap_uint<128> tmp_out;
                for(int k = 0; k < 16; k++){
                    // piecewise linear approximation: https://github.com/cornell-zhang/allo-pldi24-artifact/blob/main/llm/gpt_region_3.cpp#L335
                    float val = (float) tapa::bit_cast(ap_int<32>(tmp(k*32+31, k*32)));
                    float outp_data = 0.0;
                    // Breakpoints at -3, -1, 0, 1 and 3; zero below -3 and
                    // identity from 3 upwards.
                    if (val < -3)
                        outp_data = 0;
                    else if(val < -1)
                        outp_data = (float) -0.0773 * (val + (float) 3) - (float) 0.004;
                    else if(val < 0)
                        outp_data = (float) 0.1587 * val;
                    else if(val < 1)
                        outp_data = (float) 0.8413 * val;
                    else if(val < 3)
                        outp_data = (float) 1.0773 * (val - (float) 1) + (float) 0.8413;
                    else
                        outp_data = val;
                    // Truncate to int, drop the low 8 bits, keep 8 bits.
                    tmp_out(k*8+7, k*8) = ap_int<8>((int) (outp_data) >> 8);
                }
                fifo_out.write(tmp_out);
                j++;
            }
        }
    }
}

// Repacks 128-bit words (16 lanes x 8 bits) into 1024-bit packets (16 lanes x
// 64 bits) and replays each block of packets.
// For each of L >> 4 row groups it builds D_div_8 packets from eight input
// words each, sends every packet once as it is built, and then resends the
// whole cached block D_div_16 - 1 more times.
void data_packing(
    tapa::istream& fifo_inst,
    tapa::istream>& fifo_in,
    tapa::ostream>& fifo_out
){
    const int L = fifo_inst.read();

    for(int i = 0; i < (L >> 4); i++){
        ap_uint<1024> cache[D_div_8];
        #pragma HLS bind_storage variable=cache type=ram_2p impl=uram

        for(int j = 0; j < D_div_8; j++){
            ap_uint<64> fuse_reg[16];
            ap_uint<1024> send_pkt;
            #pragma HLS array_partition variable=fuse_reg complete
            // Byte l of input word k becomes byte k of lane register l.
            for(int k = 0; k < 8;){
                #pragma HLS pipeline II=1
                if(!fifo_in.empty()){
                    ap_uint<128> tmp; fifo_in.try_read(tmp);
                    for(int l = 0; l < 16; l++){
                        #pragma HLS unroll
                        fuse_reg[l](k*8+7, k*8) = tmp(l*8+7, l*8);
                    }
                    k++;
                }
            }
            send_pkt = ap_uint<1024>((
fuse_reg[0], fuse_reg[1], fuse_reg[2], fuse_reg[3], fuse_reg[4], fuse_reg[5], fuse_reg[6], fuse_reg[7], fuse_reg[8], fuse_reg[9], fuse_reg[10], fuse_reg[11], fuse_reg[12], fuse_reg[13], fuse_reg[14], fuse_reg[15] )); cache[j] = send_pkt; fifo_out.write(send_pkt); } for(int iter = 0; iter < D_div_16 - 1; iter++){ for(int j = 0; j < D_div_8; j++){ #pragma HLS pipeline II=1 fifo_out.write(cache[j]); } } } } void sfu_norm( tapa::istream& fifo_inst, tapa::istreams, 2>& fifo_buf, tapa::ostream>& fifo_data_out ){ const int L = fifo_inst.read(); for(int stage = 0; stage < 4; stage++){ for(int l = 0; l < (L >> 4); l++){ float sum[16]; #pragma HLS array_partition variable=sum complete ap_uint<512> tmp_in = fifo_buf[l%2].read(); for(int i = 0; i < 16; i++){ #pragma HLS unroll sum[i] = 32.0 / tapa::bit_cast(ap_uint<32>(tmp_in(i*32+31, i*32))); } for(int i = 0; i < ((l+1) << 4);){ #pragma HLS pipeline II=1 style=stp if(!fifo_buf[l%2].empty()){ ap_uint<512> tmp_cache; fifo_buf[l%2].try_read(tmp_cache); ap_uint<128> tmp; for(int j = 0; j < 16; j++){ #pragma HLS unroll ap_int<8> res = (int) (tapa::bit_cast(ap_uint<32>(tmp_cache(j*32+31, j*32))) * sum[j]); tmp(j*8 + 7, j*8) = res; } fifo_data_out.write(tmp); i++; } } } } } void sfu_norm_slr0( tapa::istream& fifo_inst, tapa::istreams, 2>& fifo_buf, tapa::ostream>& fifo_data_out, tapa::ostream>& fifo_data_off, tapa::ostream>& fifo_out ){ const int L = fifo_inst.read(); for(int stage = 0; stage < 6; stage++){ for(int l = 0; l < (L >> 4); l++){ float sum[16]; float mean[16]; float var[16]; #pragma HLS array_partition variable=sum complete #pragma HLS array_partition variable=mean complete #pragma HLS array_partition variable=var complete const int fifo_idx = l%2; const int hidden_bound = (stage < 4) ? 
((l+1) << 4) : D; ap_uint<512> tmp_in = fifo_buf[fifo_idx].read(); ap_uint<512> tmp_var; if(stage >= 4) tmp_var = fifo_buf[fifo_idx].read(); if(stage >= 4){ for(int i = 0; i < 16; i++){ #pragma HLS unroll mean[i] = tapa::bit_cast(ap_uint<32>(tmp_in(i*32+31, i*32))) / D; var[i] = 1024 / (tapa::bit_cast(ap_uint<32>(tmp_var(i*32+31, i*32))) / D - mean[i]*mean[i] + (float) 0.00001); } } else { for(int i = 0; i < 16; i++){ #pragma HLS unroll sum[i] = 32.0 / tapa::bit_cast(ap_uint<32>(tmp_in(i*32+31, i*32))); } } for(int i = 0; i < hidden_bound;){ #pragma HLS pipeline II=1 style=stp if(!fifo_buf[fifo_idx].empty()){ ap_uint<512> tmp_cache; fifo_buf[fifo_idx].try_read(tmp_cache); ap_uint<128> tmp; for(int j = 0; j < 16; j++){ #pragma HLS unroll ap_int<8> res; float op1; float op2; if(stage >= 4){ op1 = tapa::bit_cast(ap_uint<32>(tmp_cache(j*32+31, j*32))) - mean[j]; op2 = std::sqrt(var[j]); } else { op1 = tapa::bit_cast(ap_uint<32>(tmp_cache(j*32+31, j*32))); op2 = sum[j]; } res = (int) (op1 * op2); tmp(j*8 + 7, j*8) = res; } if(stage == 4) { fifo_data_off.write(tmp); } else if(stage == 5){ fifo_out.write(tmp); } else { fifo_data_out.write(tmp); } i++; } } } } } void context_buffer( tapa::istream& fifo_inst, tapa::istream>& fifo_context, tapa::ostream>& fifo_to_acc0, tapa::ostream>& fifo_to_acc1 ){ ap_uint<64> context[MAX_SEQ_LEN][D_head_div_2]; #pragma HLS array_partition variable=context cyclic dim=1 factor=32 #pragma HLS bind_storage variable=context type=ram_2p impl=uram const int L = fifo_inst.read(); for(int stage = 0; stage < 4; stage++){ for(int i = 0; i < (L >> 4); i++){ for(int j = stage * D_head_div_8; j < (stage + 1) * D_head_div_8;){ if(!fifo_context.empty()){ ap_uint<1024> tmp; fifo_context.try_read(tmp); for(int ii = 0; ii < 16; ii++){ #pragma HLS unroll context[i*16+ii][j] = tmp(ii*64+63, ii*64); } j++; } } } } // NOTE: change it to write to HBM for debugging // write ops to acc0 and acc1 in parallel for(int i = 0; i < (L >> 5); i++){ for(int l = 0; l < 
D_div_16; l++){
            // Rows i*32 .. i*32+15 go to acc0 and rows i*32+16 .. i*32+31 to
            // acc1; the same D_head_div_2 columns are resent for every l.
            for(int j = 0; j < D_head_div_2; j++){
                ap_uint<1024> tmp_acc0;
                ap_uint<1024> tmp_acc1;
                for(int k = 0; k < 16; k++){
                    #pragma HLS unroll
                    tmp_acc0(k*64+63, k*64) = context[i*32+k][j];
                    tmp_acc1(k*64+63, k*64) = context[i*32+16+k][j];
                }
                fifo_to_acc0.write(tmp_acc0);
                fifo_to_acc1.write(tmp_acc1);
            }
        }
    }
}

// Buffers an L x D_div_8 array of 64-bit words and replays it.
// Fill phase: for each group of 16 rows and each column j, eight 128-bit
// words (16 lanes x 8 bits) are gathered so that byte k of word l becomes
// byte l of row i*16+k.
// Replay phase: for each row group and each iter in [0, D_div_16) it sends all
// D_div_8 columns on fifo_ffn_out, followed by columns iter*2 and iter*2+1 on
// fifo_ffn_res. Lane k of every packet occupies bits [k*64+63 : k*64].
void ffn_buffer(
    const int L,
    tapa::istream>& fifo_ffn_in,
    tapa::ostream>& fifo_ffn_out,
    tapa::ostream>& fifo_ffn_res
){
    ap_uint<64> X[MAX_SEQ_LEN][D_div_8]; // 8 bit
    #pragma HLS array_partition variable=X cyclic dim=1 factor=16
    #pragma HLS bind_storage variable=X type=ram_2p impl=uram

    for(int i = 0; i < (L >> 4); i++){
        for(int j = 0; j < D_div_8; j++){
            ap_uint<64> fuse_reg[16];
            #pragma HLS array_partition variable=fuse_reg complete
            // Non-blocking read: l advances only when a word was consumed.
            for(int l = 0; l < 8;){
                #pragma HLS pipeline II=1
                if(!fifo_ffn_in.empty()){
                    ap_uint<128> tmp; fifo_ffn_in.try_read(tmp);
                    for(int k = 0; k < 16; k++){
                        #pragma HLS unroll
                        fuse_reg[k](l*8+7, l*8) = tmp(k*8+7, k*8);
                    }
                    l++;
                }
            }
            for(int k = 0; k < 16; k++){
                #pragma HLS unroll
                X[i*16+k][j] = fuse_reg[k];
            }
        }
    }

    for(int i = 0; i < (L >> 4); i++){
        for(int iter = 0; iter < D_div_16; iter++){
            for(int j = 0; j < D_div_8; j++){
                ap_uint<1024> tmp;
                for(int k = 0; k < 16; k++){
                    #pragma HLS unroll
                    tmp(k*64+63, k*64) = X[i*16+k][j];
                }
                fifo_ffn_out.write(tmp);
            }
            for(int j = 0; j < 2; j++){
                ap_uint<1024> send;
                for(int k = 0; k < 16; k++){
                    #pragma HLS unroll
                    send(k*64+63, k*64) = X[i*16+k][iter*2+j];
                }
                fifo_ffn_res.write(send);
            }
        }
    }
}

// Adds an 8-bit term taken from fifo_x to 32-bit values taken from fifo_in.
// For each of (L >> 4) * D_div_8 packets read from fifo_x, eight words are
// read from fifo_in; in word k, lane l is incremented by the signed byte k of
// lane l of the fifo_x packet. Results alternate between fifo_out[0] and
// fifo_out[1] according to the parity of the row-group index i.
void ffn_residual(
    const int L,
    tapa::istream>& fifo_x,
    tapa::istream>& fifo_in,
    tapa::ostreams, 2>& fifo_out
){
    for(int i = 0; i < (L >> 4); i++){
        for(int j = 0; j < D_div_8; j++){
            ap_uint<1024> tmp_x = fifo_x.read();
            // Non-blocking read: k advances only when a word was consumed.
            for(int k = 0; k < 8;){
                if(!fifo_in.empty()){
                    ap_uint<512> tmp; fifo_in.try_read(tmp);
                    ap_uint<512> tmp_o;
                    for(int l = 0; l < 16; l++){
                        #pragma HLS unroll
                        ap_int<32> a = tmp(l*32+31, l*32);
                        ap_int<8> b = tmp_x(l*64+k*8+7, l*64+k*8);
                        ap_int<32> res = a + b;
                        tmp_o(l*32+31, l*32) = res;
                    }
                    fifo_out[i%2].write(tmp_o);
                    k++;
                }
            }
        }
    }
}

// Counts loop iterations until a token appears on fifo_fin, then consumes the
// token and stores the count in cycle_count[0].
// NOTE(review): `cycle` is a plain int and is incremented without bound while
// waiting, so a sufficiently long wait overflows it.
void measure_cycle(tapa::istream& fifo_fin, tapa::mmap cycle_count){
    for(int cycle = 0;;cycle++){
        if(!fifo_fin.empty()){
            fifo_fin.read(nullptr);
            cycle_count[0] = cycle;
            break;
        }
    }
}

void opt_kernel( const int L, const int L_out, const int seq_len, // tapa::mmap inst, // inst[0] = L, inst[1] = reload_weight tapa::mmap> X_acc0, tapa::mmap> X_acc1, tapa::mmap> W_acc0, tapa::mmap> W_acc1, tapa::mmap> acc0_out, // tapa::mmap> acc1_out, tapa::mmap cycle_count ){ tapa::streams fifo_inst_acc0("fifo_inst_acc0"); tapa::streams fifo_inst_acc1("fifo_inst_acc1"); tapa::stream, 16> fifo_X_acc0_slr0("fifo_X_acc0_slr0"); tapa::stream, 16> fifo_X_acc1_slr0("fifo_X_acc1_slr0"); tapa::streams, NUM_SLR, 16> fifo_X_acc0("fifo_X_acc0"); tapa::streams, NUM_SLR, 16> fifo_X_acc1("fifo_X_acc1"); tapa::streams, NUM_SLR+1, 8> fifo_W_acc0("fifo_W_acc0"); tapa::streams, NUM_SLR+1, 8> fifo_W_acc1("fifo_W_acc1"); // tapa::streams, NUM_SLR, 4> fifo_acc0_out("fifo_acc0_out"); tapa::streams, NUM_SLR, 16> fifo_acc0_to_sfu("fifo_acc0_to_sfu"); tapa::streams, NUM_SLR*2> fifo_sfu_buf_in("fifo_sfu_buf_in"); tapa::streams, NUM_SLR*2> fifo_sfu_buf_out("fifo_sfu_buf_out"); // tapa::streams, NUM_SLR> fifo_acc1_out("fifo_acc1_out"); tapa::streams, NUM_SLR, 8> fifo_from_acc1_to_acc0("fifo_from_acc1_to_acc0"); tapa::streams, NUM_SLR, 2> fifo_from_sfu_to_acc1("fifo_from_sfu_to_acc1"); tapa::stream fifo_fin("fifo_fin"); tapa::streams, NUM_SLR> fifo_context("fifo_context"); tapa::streams, NUM_SLR, 16> fifo_cont_to_acc0("fifo_cont_to_acc0"); tapa::streams, NUM_SLR, 16> fifo_cont_to_acc1("fifo_cont_to_acc1"); tapa::streams, NUM_SLR> fifo_reduce_acc0("fifo_reduce_acc0"); tapa::streams, NUM_SLR> fifo_reduce_acc1("fifo_reduce_acc1"); // tapa::stream> fifo_acc0_out("fifo_acc0_out"); tapa::stream> fifo_acc1_out("fifo_acc1_out"); tapa::stream, 16> fifo_res_acc0("fifo_res_acc0"); tapa::stream, 16> fifo_res_acc1("fifo_res_acc1"); tapa::stream, D> fifo_ln_acc0("fifo_ln_acc0");
tapa::stream, D> fifo_ln_acc1("fifo_ln_acc1"); tapa::stream> fifo_ffn_buffer_in("fifo_ffn_buffer_in"); tapa::stream, 16> fifo_ffn_buffer_out("fifo_ffn_buffer_out"); tapa::streams, NUM_SLR, 16> fifo_gelu_in("fifo_gelu_in"); tapa::streams, NUM_SLR, D> fifo_gelu_out("fifo_gelu_out"); tapa::streams, NUM_SLR, 16> fifo_gelu_full("fifo_gelu_full"); tapa::stream, 8> fifo_ffn2("fifo_ffn2"); tapa::stream, D_div_8+2> fifo_skip_x("fifo_skip_x"); tapa::streams, 2> fifo_res2("fifo_res2"); tapa::streams fifo_inst_switch_acc0("fifo_inst_switch_acc0"); tapa::streams fifo_inst_switch_acc1("fifo_inst_switch_acc1"); tapa::streams fifo_inst_switch_sfu("fifo_inst_switch_sfu"); tapa::streams fifo_inst_switch_context("fifo_inst_switch_context"); tapa::streams fifo_inst_switch_gelu("fifo_inst_switch_gelu"); tapa::streams fifo_inst_sfu_buffer("fifo_inst_sfu_buffer"); tapa::streams fifo_inst_data_pack("fifo_inst_data_pack"); tapa::streams fifo_inst_norm("fifo_inst_norm"); // systolic array 2x2 tapa::streams, 6*NUM_SLR> fifo_x0("fifo_x0"); tapa::streams, 6*NUM_SLR> fifo_x1("fifo_x1"); tapa::streams, 6*NUM_SLR> fifo_y0("fifo_y0"); tapa::streams, 6*NUM_SLR> fifo_y1("fifo_y1"); tapa::streams, 6*NUM_SLR> fifo_bound_x0("fifo_bound_x0"); tapa::streams, 6*NUM_SLR> fifo_bound_x1("fifo_bound_x1"); tapa::streams, 4*NUM_SLR> fifo_drain_acc0("fifo_drain_acc0"); tapa::streams, 4*NUM_SLR> fifo_drain_acc1("fifo_drain_acc1"); tapa::task() .invoke(read_inst, seq_len, fifo_inst_acc0, fifo_inst_acc1) .invoke(read_W, W_acc0, fifo_W_acc0) .invoke(read_W, W_acc1, fifo_W_acc1) .invoke(read_X, L, X_acc0, fifo_X_acc0_slr0) .invoke(read_X, L, X_acc1, fifo_X_acc1_slr0) .invoke( temporal_acc0_slr0, fifo_inst_acc0, fifo_inst_acc0, fifo_inst_switch_acc0, fifo_X_acc0_slr0, fifo_X_acc0, fifo_W_acc0, fifo_W_acc0, fifo_from_acc1_to_acc0, fifo_acc0_to_sfu, fifo_gelu_in, fifo_cont_to_acc0, fifo_ffn_buffer_out, fifo_reduce_acc0, fifo_res_acc0, fifo_x0, fifo_x1, fifo_y0, fifo_y1, fifo_bound_x0, fifo_bound_x1, fifo_drain_acc0 ) // 
Systolic array 2x2 .invoke( PE_GEMM_acc0, 0, 0, fifo_bound_x0, fifo_bound_x0, fifo_x0, fifo_x0, fifo_y0, fifo_y0, fifo_drain_acc0 ) .invoke( PE_GEMM_acc0, 0, 1, fifo_bound_x0, fifo_bound_x0, fifo_x0, fifo_x0, fifo_y1, fifo_y1, fifo_drain_acc0 ) .invoke( PE_GEMM_acc0, 1, 0, fifo_bound_x1, fifo_bound_x1, fifo_x1, fifo_x1, fifo_y0, fifo_y0, fifo_drain_acc0 ) .invoke( PE_GEMM_acc0, 1, 1, fifo_bound_x1, fifo_bound_x1, fifo_x1, fifo_x1, fifo_y1, fifo_y1, fifo_drain_acc0 ) .invoke( black_hole_ap_uint_8, fifo_bound_x0 ) .invoke( black_hole_ap_uint_8, fifo_bound_x1 ) .invoke( black_hole_ap_uint_576, fifo_y0 ) .invoke( black_hole_ap_uint_576, fifo_y1 ) .invoke( black_hole_ap_uint_576, fifo_x0 ) .invoke( black_hole_ap_uint_576, fifo_x1 ) // end systolic array .invoke( temporal_acc1_slr0, fifo_inst_acc1, fifo_inst_acc1, fifo_inst_switch_acc1, fifo_X_acc1_slr0, fifo_X_acc1, fifo_W_acc1, fifo_W_acc1, fifo_from_acc1_to_acc0, fifo_from_sfu_to_acc1, fifo_context, fifo_cont_to_acc1, fifo_reduce_acc1, fifo_res_acc1, fifo_gelu_full, fifo_ffn2, fifo_x0, fifo_x1, fifo_y0, fifo_y1, fifo_bound_x0, fifo_bound_x1, fifo_drain_acc1 ) // Systolic array 2x2 .invoke( PE_GEMM_acc1, 0, 0, fifo_bound_x0, fifo_bound_x0, fifo_x0, fifo_x0, fifo_y0, fifo_y0, fifo_drain_acc1 ) .invoke( PE_GEMM_acc1, 0, 1, fifo_bound_x0, fifo_bound_x0, fifo_x0, fifo_x0, fifo_y1, fifo_y1, fifo_drain_acc1 ) .invoke( PE_GEMM_acc1, 1, 0, fifo_bound_x1, fifo_bound_x1, fifo_x1, fifo_x1, fifo_y0, fifo_y0, fifo_drain_acc1 ) .invoke( PE_GEMM_acc1, 1, 1, fifo_bound_x1, fifo_bound_x1, fifo_x1, fifo_x1, fifo_y1, fifo_y1, fifo_drain_acc1 ) .invoke( black_hole_ap_uint_8, fifo_bound_x0 ) .invoke( black_hole_ap_uint_8, fifo_bound_x1 ) .invoke( black_hole_ap_uint_576, fifo_y0 ) .invoke( black_hole_ap_uint_576, fifo_y1 ) .invoke( black_hole_ap_uint_576, fifo_x0 ) .invoke( black_hole_ap_uint_576, fifo_x1 ) // end systolic array .invoke( residual, seq_len, fifo_res_acc0, fifo_ln_acc0 ) .invoke( residual, seq_len, fifo_res_acc1, fifo_ln_acc1 
) .invoke( temporal_acc0, fifo_inst_acc0, fifo_inst_acc0, fifo_inst_switch_acc0, fifo_X_acc0, fifo_X_acc0, fifo_W_acc0, fifo_W_acc0, fifo_from_acc1_to_acc0, fifo_acc0_to_sfu, fifo_cont_to_acc0, fifo_gelu_in, fifo_reduce_acc0, fifo_reduce_acc0, fifo_x0, fifo_x1, fifo_y0, fifo_y1, fifo_bound_x0, fifo_bound_x1, fifo_drain_acc0 ) // Systolic array 2x2 .invoke( PE_GEMM_acc0, 0, 0, fifo_bound_x0, fifo_bound_x0, fifo_x0, fifo_x0, fifo_y0, fifo_y0, fifo_drain_acc0 ) .invoke( PE_GEMM_acc0, 0, 1, fifo_bound_x0, fifo_bound_x0, fifo_x0, fifo_x0, fifo_y1, fifo_y1, fifo_drain_acc0 ) .invoke( PE_GEMM_acc0, 1, 0, fifo_bound_x1, fifo_bound_x1, fifo_x1, fifo_x1, fifo_y0, fifo_y0, fifo_drain_acc0 ) .invoke( PE_GEMM_acc0, 1, 1, fifo_bound_x1, fifo_bound_x1, fifo_x1, fifo_x1, fifo_y1, fifo_y1, fifo_drain_acc0 ) .invoke( black_hole_ap_uint_8, fifo_bound_x0 ) .invoke( black_hole_ap_uint_8, fifo_bound_x1 ) .invoke( black_hole_ap_uint_576, fifo_y0 ) .invoke( black_hole_ap_uint_576, fifo_y1 ) .invoke( black_hole_ap_uint_576, fifo_x0 ) .invoke( black_hole_ap_uint_576, fifo_x1 ) .invoke( temporal_acc0, fifo_inst_acc0, fifo_inst_acc0, fifo_inst_switch_acc0, fifo_X_acc0, fifo_X_acc0, fifo_W_acc0, fifo_W_acc0, fifo_from_acc1_to_acc0, fifo_acc0_to_sfu, fifo_cont_to_acc0, fifo_gelu_in, fifo_reduce_acc0, fifo_reduce_acc0, fifo_x0, fifo_x1, fifo_y0, fifo_y1, fifo_bound_x0, fifo_bound_x1, fifo_drain_acc0 ) // Systolic array 2x2 .invoke( PE_GEMM_acc0, 0, 0, fifo_bound_x0, fifo_bound_x0, fifo_x0, fifo_x0, fifo_y0, fifo_y0, fifo_drain_acc0 ) .invoke( PE_GEMM_acc0, 0, 1, fifo_bound_x0, fifo_bound_x0, fifo_x0, fifo_x0, fifo_y1, fifo_y1, fifo_drain_acc0 ) .invoke( PE_GEMM_acc0, 1, 0, fifo_bound_x1, fifo_bound_x1, fifo_x1, fifo_x1, fifo_y0, fifo_y0, fifo_drain_acc0 ) .invoke( PE_GEMM_acc0, 1, 1, fifo_bound_x1, fifo_bound_x1, fifo_x1, fifo_x1, fifo_y1, fifo_y1, fifo_drain_acc0 ) .invoke( black_hole_ap_uint_8, fifo_bound_x0 ) .invoke( black_hole_ap_uint_8, fifo_bound_x1 ) .invoke( black_hole_ap_uint_576, 
fifo_y0 ) .invoke( black_hole_ap_uint_576, fifo_y1 ) .invoke( black_hole_ap_uint_576, fifo_x0 ) .invoke( black_hole_ap_uint_576, fifo_x1 ) .invoke( temporal_acc0, fifo_inst_acc0, fifo_inst_acc0, fifo_inst_switch_acc0, fifo_X_acc0, fifo_X_acc0, fifo_W_acc0, fifo_W_acc0, fifo_from_acc1_to_acc0, fifo_acc0_to_sfu, fifo_cont_to_acc0, fifo_gelu_in, fifo_reduce_acc0, fifo_reduce_acc0, fifo_x0, fifo_x1, fifo_y0, fifo_y1, fifo_bound_x0, fifo_bound_x1, fifo_drain_acc0 ) // Systolic array 2x2 .invoke( PE_GEMM_acc0, 0, 0, fifo_bound_x0, fifo_bound_x0, fifo_x0, fifo_x0, fifo_y0, fifo_y0, fifo_drain_acc0 ) .invoke( PE_GEMM_acc0, 0, 1, fifo_bound_x0, fifo_bound_x0, fifo_x0, fifo_x0, fifo_y1, fifo_y1, fifo_drain_acc0 ) .invoke( PE_GEMM_acc0, 1, 0, fifo_bound_x1, fifo_bound_x1, fifo_x1, fifo_x1, fifo_y0, fifo_y0, fifo_drain_acc0 ) .invoke( PE_GEMM_acc0, 1, 1, fifo_bound_x1, fifo_bound_x1, fifo_x1, fifo_x1, fifo_y1, fifo_y1, fifo_drain_acc0 ) .invoke( black_hole_ap_uint_8, fifo_bound_x0 ) .invoke( black_hole_ap_uint_8, fifo_bound_x1 ) .invoke( black_hole_ap_uint_576, fifo_y0 ) .invoke( black_hole_ap_uint_576, fifo_y1 ) .invoke( black_hole_ap_uint_576, fifo_x0 ) .invoke( black_hole_ap_uint_576, fifo_x1 ) // end systolic array .invoke( temporal_acc1, fifo_inst_acc1, fifo_inst_acc1, fifo_inst_switch_acc1, fifo_X_acc1, fifo_X_acc1, fifo_W_acc1, fifo_W_acc1, fifo_from_acc1_to_acc0, fifo_from_sfu_to_acc1, fifo_context, fifo_cont_to_acc1, fifo_reduce_acc1, fifo_reduce_acc1, fifo_gelu_full, fifo_x0, fifo_x1, fifo_y0, fifo_y1, fifo_bound_x0, fifo_bound_x1, fifo_drain_acc1 ) // Systolic array 2x2 .invoke( PE_GEMM_acc1, 0, 0, fifo_bound_x0, fifo_bound_x0, fifo_x0, fifo_x0, fifo_y0, fifo_y0, fifo_drain_acc1 ) .invoke( PE_GEMM_acc1, 0, 1, fifo_bound_x0, fifo_bound_x0, fifo_x0, fifo_x0, fifo_y1, fifo_y1, fifo_drain_acc1 ) .invoke( PE_GEMM_acc1, 1, 0, fifo_bound_x1, fifo_bound_x1, fifo_x1, fifo_x1, fifo_y0, fifo_y0, fifo_drain_acc1 ) .invoke( PE_GEMM_acc1, 1, 1, fifo_bound_x1, fifo_bound_x1, 
fifo_x1, fifo_x1, fifo_y1, fifo_y1, fifo_drain_acc1 ) .invoke( black_hole_ap_uint_8, fifo_bound_x0 ) .invoke( black_hole_ap_uint_8, fifo_bound_x1 ) .invoke( black_hole_ap_uint_576, fifo_y0 ) .invoke( black_hole_ap_uint_576, fifo_y1 ) .invoke( black_hole_ap_uint_576, fifo_x0 ) .invoke( black_hole_ap_uint_576, fifo_x1 ) .invoke( temporal_acc1, fifo_inst_acc1, fifo_inst_acc1, fifo_inst_switch_acc1, fifo_X_acc1, fifo_X_acc1, fifo_W_acc1, fifo_W_acc1, fifo_from_acc1_to_acc0, fifo_from_sfu_to_acc1, fifo_context, fifo_cont_to_acc1, fifo_reduce_acc1, fifo_reduce_acc1, fifo_gelu_full, fifo_x0, fifo_x1, fifo_y0, fifo_y1, fifo_bound_x0, fifo_bound_x1, fifo_drain_acc1 ) // Systolic array 2x2 .invoke( PE_GEMM_acc1, 0, 0, fifo_bound_x0, fifo_bound_x0, fifo_x0, fifo_x0, fifo_y0, fifo_y0, fifo_drain_acc1 ) .invoke( PE_GEMM_acc1, 0, 1, fifo_bound_x0, fifo_bound_x0, fifo_x0, fifo_x0, fifo_y1, fifo_y1, fifo_drain_acc1 ) .invoke( PE_GEMM_acc1, 1, 0, fifo_bound_x1, fifo_bound_x1, fifo_x1, fifo_x1, fifo_y0, fifo_y0, fifo_drain_acc1 ) .invoke( PE_GEMM_acc1, 1, 1, fifo_bound_x1, fifo_bound_x1, fifo_x1, fifo_x1, fifo_y1, fifo_y1, fifo_drain_acc1 ) .invoke( black_hole_ap_uint_8, fifo_bound_x0 ) .invoke( black_hole_ap_uint_8, fifo_bound_x1 ) .invoke( black_hole_ap_uint_576, fifo_y0 ) .invoke( black_hole_ap_uint_576, fifo_y1 ) .invoke( black_hole_ap_uint_576, fifo_x0 ) .invoke( black_hole_ap_uint_576, fifo_x1 ) .invoke( temporal_acc1, fifo_inst_acc1, fifo_inst_acc1, fifo_inst_switch_acc1, fifo_X_acc1, fifo_X_acc1, fifo_W_acc1, fifo_W_acc1, fifo_from_acc1_to_acc0, fifo_from_sfu_to_acc1, fifo_context, fifo_cont_to_acc1, fifo_reduce_acc1, fifo_reduce_acc1, fifo_gelu_full, fifo_x0, fifo_x1, fifo_y0, fifo_y1, fifo_bound_x0, fifo_bound_x1, fifo_drain_acc1 ) // Systolic array 2x2 .invoke( PE_GEMM_acc1, 0, 0, fifo_bound_x0, fifo_bound_x0, fifo_x0, fifo_x0, fifo_y0, fifo_y0, fifo_drain_acc1 ) .invoke( PE_GEMM_acc1, 0, 1, fifo_bound_x0, fifo_bound_x0, fifo_x0, fifo_x0, fifo_y1, fifo_y1, fifo_drain_acc1 
) .invoke( PE_GEMM_acc1, 1, 0, fifo_bound_x1, fifo_bound_x1, fifo_x1, fifo_x1, fifo_y0, fifo_y0, fifo_drain_acc1 ) .invoke( PE_GEMM_acc1, 1, 1, fifo_bound_x1, fifo_bound_x1, fifo_x1, fifo_x1, fifo_y1, fifo_y1, fifo_drain_acc1 ) .invoke( black_hole_ap_uint_8, fifo_bound_x0 ) .invoke( black_hole_ap_uint_8, fifo_bound_x1 ) .invoke( black_hole_ap_uint_576, fifo_y0 ) .invoke( black_hole_ap_uint_576, fifo_y1 ) .invoke( black_hole_ap_uint_576, fifo_x0 ) .invoke( black_hole_ap_uint_576, fifo_x1 ) // end systolic array .invoke(packet_switch_acc, fifo_inst_switch_acc0, fifo_inst_switch_sfu, fifo_inst_switch_gelu) .invoke(packet_switch_acc, fifo_inst_switch_acc1, fifo_inst_switch_context, fifo_inst_norm) .invoke(write_zero, seq_len, D_write_zero_acc0, fifo_reduce_acc0) .invoke(write_zero, seq_len, D_write_zero_acc1, fifo_reduce_acc1) .invoke( sfu_acc_exp, fifo_inst_switch_sfu, fifo_acc0_to_sfu, fifo_sfu_buf_in, fifo_inst_sfu_buffer ) .invoke( sfu_buffer_slr0, fifo_inst_sfu_buffer, fifo_sfu_buf_in, fifo_ln_acc0, fifo_res2, fifo_sfu_buf_out ) .invoke( sfu_buffer_slr0, fifo_inst_sfu_buffer, fifo_sfu_buf_in, fifo_ln_acc1, fifo_res2, fifo_sfu_buf_out ) .invoke( sfu_buffer, fifo_inst_sfu_buffer, fifo_sfu_buf_in, fifo_sfu_buf_out ) .invoke( sfu_norm_slr0, fifo_inst_norm, fifo_sfu_buf_out, fifo_from_sfu_to_acc1, fifo_ffn_buffer_in, fifo_acc1_out ) .invoke( sfu_norm, fifo_inst_norm, fifo_sfu_buf_out, fifo_from_sfu_to_acc1 ) .invoke( ffn_buffer, seq_len, fifo_ffn_buffer_in, fifo_ffn_buffer_out, fifo_skip_x ) .invoke( ffn_residual, seq_len, fifo_skip_x, fifo_ffn2, fifo_res2 ) .invoke( context_buffer, fifo_inst_switch_context, fifo_context, fifo_cont_to_acc0, fifo_cont_to_acc1 ) .invoke( sfu_gelu, fifo_inst_switch_gelu, fifo_inst_data_pack, fifo_gelu_in, fifo_gelu_out ) .invoke( data_packing, fifo_inst_data_pack, fifo_gelu_out, fifo_gelu_full ) // .invoke(write_attention, seq_len, acc0_out, fifo_acc0_out) .invoke(write_mtx, L_out, acc0_out, fifo_acc1_out, fifo_fin) // .invoke(write_mtx, 
L_out, acc1_out, fifo_acc1_out) .invoke(measure_cycle, fifo_fin, cycle_count) .invoke(black_hole_inst, fifo_inst_acc0) .invoke(black_hole_inst, fifo_inst_acc1) .invoke(black_hole_ap_uint_1024, fifo_X_acc0) .invoke(black_hole_ap_uint_1024, fifo_X_acc1) .invoke(black_hole_ap_uint_512, fifo_W_acc0) .invoke(black_hole_ap_uint_512, fifo_W_acc1); }