#include #include #include #include #include #include #include #include #include constexpr int D = 1024; constexpr int D_ffn = 4096; constexpr int N_head = 16; constexpr int MAX_SEQ_LEN = 1024; constexpr int NUM_SLR = 3; constexpr int NUM_DUM_SLR = 4; constexpr int D_head = D / N_head; constexpr int FFN_WEIGHT_SIZE = D * D_ffn; constexpr int OUT_WEIGHT_SIZE = D * D; constexpr int QKV_WEIGHT_SIZE = D * D / N_head * NUM_DUM_SLR * 2; // multi-head attention using std::vector; using int_v16 = tapa::vec_t; using int4_v128 = tapa::vec_t, 128>; using int8_v64 = tapa::vec_t, 64>; void opt_kernel( const int L, const int L_out, const int seq_len, // tapa::mmap inst, // inst[0] = L, inst[1] = reload_weight tapa::mmap> X_acc0, tapa::mmap> X_acc1, tapa::mmap> W_acc0, tapa::mmap> W_acc1, tapa::mmap> acc0_out, tapa::mmap> acc1_out, tapa::mmap cycle_count ); template using aligned_vector = std::vector>; DEFINE_string(bitstream, "", "path to bitstream file"); int main(int argc, char *argv[]){ gflags::ParseCommandLineFlags(&argc, &argv, true); const int L = argc > 1 ? atoll(argv[1]) : MAX_SEQ_LEN; srand((unsigned)time(nullptr)); // data preparation aligned_vector inst = {L, 1}; aligned_vector> X_acc0(L * D); aligned_vector> X_acc1(L * D); aligned_vector> W_acc0(D * D_head * NUM_DUM_SLR * 10); aligned_vector> W_acc1(D * D_head * NUM_DUM_SLR * 10); aligned_vector> acc0_out(NUM_SLR * L * D / 8); // aligned_vector> acc0_out(NUM_SLR, aligned_vector>(L * L / 16)); aligned_vector> acc1_out(NUM_SLR * L * D / 8); aligned_vector cycle_count(1); vector X_copy(L * D); vector> W_acc0_split(NUM_DUM_SLR, vector(D * D_head * 10)); vector> W_acc1_split(NUM_DUM_SLR, vector(D * D_head * 10)); vector> W_k_split(NUM_DUM_SLR, vector(D * D_head * 10)); vector> q_golden(NUM_DUM_SLR, aligned_vector(L * D_head)); vector> k_golden(NUM_DUM_SLR, aligned_vector(L * D_head)); vector> attn_golden(NUM_DUM_SLR, aligned_vector(L * L)); vector> acc1_out_golden(NUM_DUM_SLR, aligned_vector(L * D_head)); for(int i = 0; i < L * D; i++){ int val = (rand() % 8) + 1; ap_int<32> full = tapa::bit_cast>(val); X_copy[i] = val; X_acc0[i] = ap_int<8>(full(7, 0)); X_acc1[i] = ap_int<8>(full(7, 0)); } for(int i = 0; i < D * D_head * NUM_DUM_SLR * 5; i++){ int val = (rand() % 6) - 1; ap_int<32> full = tapa::bit_cast>(val); W_acc0[i/2]((i%2+1)*4-1, (i%2)*4) = ap_int<4>(full(3, 0)); W_acc0_split[(i / 32) % 4][(i / 128) * 32 + (i % 32)] = val; } for(int i = 0; i < D * D_head * NUM_DUM_SLR * 5; i++){ int val = (rand() % 6) - 1; ap_int<32> full = tapa::bit_cast>(val); W_acc1[i/2]((i%2+1)*4-1, (i%2)*4) = ap_int<4>(full(3, 0)); W_acc1_split[(i / 32) % 4][(i / 128) * 32 + (i % 32)] = val; } for(int i = D * D_head * NUM_DUM_SLR * 5; i < D * D_head * NUM_DUM_SLR * 15; i++){ int val = (rand() % 6) - 1; int ind = i - D * D_head * NUM_DUM_SLR * 5; ap_int<32> full = tapa::bit_cast>(val); W_acc0[i/2]((i%2+1)*4-1, (i%2)*4) = ap_int<4>(full(3, 0)); W_acc1[i/2]((i%2+1)*4-1, (i%2)*4) = ap_int<4>(full(3, 0)); W_k_split[(ind / 32) % 4][(ind / 128) * 32 + (ind % 32)] = val; } // cpu for(int i = 0; i < NUM_SLR; i++){ // WqX for(int j = 0; j < L; j++){ for(int k = 0; k < D_head; k++){ int acc = 0; for(int l = 0; l < D; l++){ acc += X_copy[j*D+l] * W_acc0_split[i][l*D_head + k]; } q_golden[i][j * D_head + k] = std::min(std::max((acc >> 8), -128), 127); } } //WvX for(int j = 0; j < L; j++){ for(int k = 0; k < D_head; k++){ int acc = 0; for(int l = 0; l < D; l++){ acc += X_copy[j*D+l] * W_acc1_split[i][l*D_head + k]; } acc1_out_golden[i][j * D_head + k] = std::min(std::max((acc >> 8), -128), 127); } } //WkX for(int j = 0; j < L; j++){ for(int k = 0; k < D_head; k++){ int acc = 0; for(int l = 0; l < D; l++){ acc += X_copy[j*D+l] * W_k_split[i][l*D_head + k]; } k_golden[i][j * D_head + k] = std::min(std::max((acc >> 8), -128), 127); } } // QK^T for(int j = 0; j < L; j++){ for(int k = 0; k < L; k++){ int acc = 0; for(int l = 0; l < D_head; l++){ acc += q_golden[i][k*D_head+l] * k_golden[i][j*D_head+l]; } attn_golden[i][j*D_head+k] = acc; } } } // invoke the kernel int64_t kernel_time_ns = 0; for(int i = 0; i < 24; i++){ kernel_time_ns += tapa::invoke(opt_kernel, FLAGS_bitstream, L * D, L * D / 8, L, // tapa::read_only_mmap(inst), tapa::read_only_mmap>(X_acc0).reinterpret>(), tapa::read_only_mmap>(X_acc1).reinterpret>(), tapa::read_only_mmap>(W_acc0).reinterpret>(), tapa::read_only_mmap>(W_acc1).reinterpret>(), tapa::write_only_mmap>(acc0_out), tapa::write_only_mmap>(acc1_out), tapa::write_only_mmap(cycle_count)); } // std::clog << "cycle time: " << cycle_count[0] << std::endl; std::clog << "kernel time: " << kernel_time_ns * 2e-9 << " s" << std::endl; int error = 0; // compare // for(int i = 0; i < NUM_SLR; i++){ // for(int j = 0; j < 4; j++){ // for(int k = 0; k < 16; k++){ // if(tapa::bit_cast(ap_int<32>(acc0_out[i][j](k*32+31,k*32)))-attn_golden[i][j*16+k] != 0){ // std::clog << "slr: " << i << ", index: " << j << ", actual: " << tapa::bit_cast(ap_int<32>(acc0_out[i][j](k*32+31,k*32))) << ", expect: " << attn_golden[i][j*16+k] << std::endl; // error++; // } // } // } // } if (error == 0) { std::clog << "PASSED" << std::endl; } else { std::clog << "FAILED" << std::endl; return 1; } return 0; }