File size: 4,094 Bytes
2cc8629
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
// Copyright (C) 2018-2022 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#pragma OPENCL EXTENSION cl_khr_fp16 : enable
#pragma OPENCL EXTENSION cl_khr_extended_async_copies : enable

// Convolution1x1_NCHW: 1x1 convolution on fp16 planar (channel-major) data.
//
// Per work-group: row get_group_id(0) of all IC input channels is staged in
// local memory. For every output column, a dot product over channels with the
// weight row of oc = get_global_id(1) is computed. The resulting OW values are
// copied to out starting at get_group_id(1) * OW * OH + get_group_id(0) * OW.
//
// in  - input; consecutive channel planes are IW * IH elements apart.
// out - destination of the OW output values (offset above). It is declared
//       const but is written through an explicit cast; the qualifier is left
//       unchanged so the kernel signature stays the same.
// w   - weights; output channel oc reads w[oc * IC] .. w[oc * IC + IC - 1].
// OC  - not read by this kernel.
//
// NOTE(review) assumptions this kernel does not check; confirm on the host side:
//  - work-groups are 1x1: data offsets come from group ids while the weight
//    row comes from get_global_id(1), and all work-items of a group would
//    write the same out_local;
//  - OW == IW (stride 1, no padding), since output column ow reads input
//    column ow;
//  - IC * IW and OW fit the 8 * 1024-element local buffers, with headroom for
//    the 8-wide loads of a partial last column block, which can read a few
//    (< 8) elements past the staged data.
__kernel void Convolution1x1_NCHW(
    const __global half *in,
    const __global half *out,
    const __global half *w,
    int IW,
    int IH,
    int IC,
    int OW,
    int OH,
    int OC)
{
    __local half in_local[8 * 1024];
    __local half out_local[8 * 1024];

    // Gather row get_group_id(0) of every input channel into in_local as IC
    // consecutive lines of IW elements. The source line stride is the gap
    // between the end of one line and the start of the next, i.e. one
    // channel plane minus one row.
    event_t e1 = async_work_group_copy_2D2D(
        in_local, // dst
        in + get_group_id(0) * IW, // src
        IW, // num_elements_per_line,
        IC, // num_lines,
        IW * IH - IW, // src_line_stride,
        0, // dst_line_stride,
        0);
    wait_group_events(1, &e1);

    const int oc = get_global_id(1);

    // Weight row of this output channel. vload8/vstore8 are used instead of
    // half8 pointer casts: vloadn/vstoren only require element alignment,
    // while dereferencing a half8 pointer requires 16-byte alignment, which
    // oc * IC, ic * IW and the local array base do not guarantee.
    const __global half *w_row = w + oc * IC;
    const int ic8 = IC & ~0x7;

    // One iteration per block of 8 output columns. The last block may be
    // partial; if OW is a multiple of 8 there is no extra block, so nothing
    // is loaded beyond column OW.
    for (int ow = 0; ow < OW; ow += 8) {
        half8 acc = 0.0f;

        // Channels in groups of 8: each lane accumulates channels in order.
        for (int ic = 0; ic < ic8; ic += 8) {
            const half8 wv = vload8(0, w_row + ic);
            const __local half *row = in_local + ow + ic * IW;

            acc += vload8(0, row + 0 * IW) * ((half8)wv.s0);
            acc += vload8(0, row + 1 * IW) * ((half8)wv.s1);
            acc += vload8(0, row + 2 * IW) * ((half8)wv.s2);
            acc += vload8(0, row + 3 * IW) * ((half8)wv.s3);
            acc += vload8(0, row + 4 * IW) * ((half8)wv.s4);
            acc += vload8(0, row + 5 * IW) * ((half8)wv.s5);
            acc += vload8(0, row + 6 * IW) * ((half8)wv.s6);
            acc += vload8(0, row + 7 * IW) * ((half8)wv.s7);
        }

        // Remaining IC % 8 channels, one at a time.
        for (int ic = ic8; ic < IC; ++ic) {
            acc += vload8(0, in_local + ow + ic * IW) * ((half8)w_row[ic]);
        }

        if (ow + 8 <= OW) {
            vstore8(acc, 0, out_local + ow);
        } else {
            // Partial last block: only the first OW - ow lanes are outputs.
            half lanes[8];
            vstore8(acc, 0, lanes);
            for (int j = 0; j < OW - ow; ++j) {
                out_local[ow + j] = lanes[j];
            }
        }
    }

    barrier(CLK_LOCAL_MEM_FENCE);

    // `out` is declared const in the signature but is the output buffer;
    // the cast makes the write explicit instead of silently discarding const.
    event_t e2 = async_work_group_copy(
        (__global half *)out + get_group_id(1) * OW * OH + get_group_id(0) * OW,
        out_local,
        OW,
        0);
    wait_group_events(1, &e2);
}