Spaces:
Runtime error
Runtime error
| // Copyright (C) 2018-2022 Intel Corporation | |
| // SPDX-License-Identifier: Apache-2.0 | |
| // | |
| #pragma OPENCL EXTENSION cl_khr_fp16 : enable | |
| #pragma OPENCL EXTENSION cl_khr_extended_async_copies : enable | |
// Direct 3x3 convolution on fp16 data. Each work-group produces one output row
// (OW values) of one output channel:
//   * dimension 0 group id selects the output row: the input tile starts at row
//     get_group_id(0) * stride_y and the result lands at row get_group_id(0);
//   * dimension 1 group id selects the output channel: weights are read from
//     w + get_group_id(1) * IC * 9 and the result lands in plane get_group_id(1).
// All global<->local traffic goes through async_work_group_copy*, so the
// work-group cooperates on the copies.
//
// Layout/contract assumptions the code relies on but does not check
// (NOTE(review): confirm against the host-side setup):
//   * in_param: IC planes of IW * IH halves; out: planes of OW * OH halves;
//     w: per output channel, IC * 3 * 3 weights ordered [ic][ky][kx].
//   * pad_y is never applied: the 3-row tile starts at row oh * stride_y, so the
//     input is presumably already padded vertically.
//   * The vectorised path hard-codes the horizontal taps at x - 1, x, x + 1
//     relative to `in = in_local + 1`, i.e. it behaves as pad_x == 1 with
//     dilation_x == 1, regardless of the pad_x / dilation_x arguments.
//   * Only the 3 rows starting at oh * stride_y are staged per channel, so the
//     scalar path stays inside its channel's tile only when dilation_y == 1.
//   * KX and KY are accepted but unused; the filter size is fixed to 3x3.
//   * 3 * IW * IC, IC * 9 and OW must each fit in 8 * 1024 halves (the local
//     buffer sizes below); nothing here enforces that.
__kernel void Convolution3x3(
    const __global half *in_param,
    __global half *out, // written by async_work_group_copy, so not const
    const __global half *w,
    int IW,
    int IH,
    int IC,
    int OW,
    int OH,
    int OC,
    int KX,
    int KY,
    int stride_x,
    int stride_y,
    int pad_x,
    int pad_y,
    int dilation_x,
    int dilation_y)
{
    __local half in_local[8 * 1024];
    __local half out_local[8 * 1024];
    __local half w_local[8 * 1024];

    // Stage the three input rows needed for this output row, for every input
    // channel, packed back to back (channel ic starts at ic * 3 * IW).
    // One source line per channel: after each 3 * IW run, the src stride skips
    // the remainder of the IW * IH plane (presumably the extension treats line
    // strides as gaps, as this arithmetic implies).
    event_t e1 = async_work_group_copy_2D2D(
        in_local,                                    // dst
        in_param + get_group_id(0) * stride_y * IW,  // src
        3 * IW,                                      // num_elements_per_line
        IC,                                          // num_lines
        IW * IH - 3 * IW,                            // src_line_stride
        0,                                           // dst_line_stride
        0);
    wait_group_events(1, &e1);

    // Stage this output channel's IC * 3 * 3 weights.
    const int sizeWeight = IC * 3 * 3;
    e1 = async_work_group_copy(w_local, w + get_group_id(1) * sizeWeight, sizeWeight, 0);
    wait_group_events(1, &e1);

    // Offset by one element: tap kx of output column x reads in[x - 1 + kx]
    // in the vector path and in[x * stride_x - pad_x + kx * dilation_x] in the
    // scalar path.
    __local half *in = in_local + 1;

    // Vector path selection:
    //   vec_mode 1: stride 1, each block of 8 results is stored as-is;
    //   vec_mode 2: stride 2, every other lane of the 8 stride-1 results is
    //               kept, giving 4 outputs per block;
    //   vec_mode 0: unsupported stride, vector path skipped.
    // Both values start at 0 so an unsupported stride never uses an
    // indeterminate loop bound.
    int num_vec_blocks = 0;
    int vec_mode = 0;
    if ((stride_x == 1) && (stride_y == 1)) {
        num_vec_blocks = OW / 8;
        vec_mode = 1;
    }
    if ((stride_x == 2) && (stride_y == 2)) {
        num_vec_blocks = OW / 4;
        vec_mode = 2;
    }

    for (int ow = 0; ow < num_vec_blocks; ow++) {
        float8 val = (float8)(0.0f);
        for (int ic = 0; ic < IC; ++ic) {
            // First of the 8 centre columns handled by this block, row 0.
            __local half *src = (__local half *)((__local half8 *)(in + ic * IW * 3) + ow);
            __local half *k = w_local + ic * 3 * 3;

            // Previous, current and next 8-element blocks of each of the 3 rows,
            // reinterpreted as short8 for the alignvec builtin.
            short8 in00 = as_short8(*((__local half8 *)src - 1));
            short8 in01 = as_short8(*((__local half8 *)src + 0));
            short8 in02 = as_short8(*((__local half8 *)src + 1));
            short8 in10 = as_short8(*((__local half8 *)(src + IW) - 1));
            short8 in11 = as_short8(*((__local half8 *)(src + IW) + 0));
            short8 in12 = as_short8(*((__local half8 *)(src + IW) + 1));
            short8 in20 = as_short8(*((__local half8 *)(src + IW * 2) - 1));
            short8 in21 = as_short8(*((__local half8 *)(src + IW * 2) + 0));
            short8 in22 = as_short8(*((__local half8 *)(src + IW * 2) + 1));

            // Left (x - 1), centre (x) and right (x + 1) neighbour vectors.
            // NOTE(review): SHAVE-specific builtin; presumably it returns 8 shorts
            // starting N bytes into the concatenation (a, b), so offset 14 yields
            // {a[7], b[0..6]} and offset 2 yields {a[1..7], b[0]}.
            half8 aux00 = as_half8(__builtin_shave_cmu_alignvec_rri_short8(in00, in01, 14));
            half8 aux01 = as_half8(in01);
            half8 aux02 = as_half8(__builtin_shave_cmu_alignvec_rri_short8(in01, in02, 2));
            half8 aux10 = as_half8(__builtin_shave_cmu_alignvec_rri_short8(in10, in11, 14));
            half8 aux11 = as_half8(in11);
            half8 aux12 = as_half8(__builtin_shave_cmu_alignvec_rri_short8(in11, in12, 2));
            half8 aux20 = as_half8(__builtin_shave_cmu_alignvec_rri_short8(in20, in21, 14));
            half8 aux21 = as_half8(in21);
            half8 aux22 = as_half8(__builtin_shave_cmu_alignvec_rri_short8(in21, in22, 2));

            // Broadcast each of the 9 weights across all 8 lanes.
            half8 w00 = (half8)(k[0]);
            half8 w01 = (half8)(k[1]);
            half8 w02 = (half8)(k[2]);
            half8 w10 = (half8)(k[3]);
            half8 w11 = (half8)(k[4]);
            half8 w12 = (half8)(k[5]);
            half8 w20 = (half8)(k[6]);
            half8 w21 = (half8)(k[7]);
            half8 w22 = (half8)(k[8]);

            // Accumulate in float, tap order [ky][kx] (same order as the scalar path).
            val += convert_float8(aux00) * convert_float8(w00);
            val += convert_float8(aux01) * convert_float8(w01);
            val += convert_float8(aux02) * convert_float8(w02);
            val += convert_float8(aux10) * convert_float8(w10);
            val += convert_float8(aux11) * convert_float8(w11);
            val += convert_float8(aux12) * convert_float8(w12);
            val += convert_float8(aux20) * convert_float8(w20);
            val += convert_float8(aux21) * convert_float8(w21);
            val += convert_float8(aux22) * convert_float8(w22);
        }
        if (vec_mode == 2) {
            *((__local half4 *)(out_local) + ow) = convert_half4(val.s0246);
        }
        if (vec_mode == 1) {
            *((__local half8 *)(out_local) + ow) = convert_half8(val);
        }
    }

    // Scalar path for the outputs the vector path did not produce.
    // For the supported strides it handles the remainder past the last multiple
    // of 8 (for stride 2 this may recompute a few outputs already written, using
    // the same taps in the same accumulation order).
    // For any other stride the vector path was skipped, so it covers the whole row.
    const int scalar_start = (vec_mode != 0) ? (OW & ~0x7) : 0;
    for (int ow = scalar_start; ow < OW; ow++) {
        float val = 0.0f;
        for (int ic = 0; ic < IC; ++ic) {
            for (int ky = 0; ky < 3; ++ky) {
                for (int kx = 0; kx < 3; ++kx) {
                    int iw = ow * stride_x - pad_x + kx * dilation_x;
                    val += convert_float(in[ic * IW * 3 + (ky * dilation_y) * IW + iw])
                         * convert_float(w_local[ic * 3 * 3 + ky * 3 + kx]);
                }
            }
        }
        out_local[ow] = convert_half(val);
    }

    // Make the out_local writes visible before the work-group copies them out.
    barrier(CLK_LOCAL_MEM_FENCE);
    event_t e2 = async_work_group_copy(
        out + get_group_id(1) * OW * OH + get_group_id(0) * OW,
        out_local,
        OW,
        0);
    wait_group_events(1, &e2);
}