Spaces:
Runtime error
Runtime error
// Copyright (C) 2018-2022 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
// Returns bit number `bit` of the packed byte `val` as 0 or 1
// (bit 0 is the least significant bit).
int extract_weights(uchar val, int bit)
{
    const int shifted = val >> bit;
    return shifted & 1;
}
// Binary (1-bit) 2D convolution; each work-item produces one output element.
//
// Activations are binarized on the fly (value > 0 -> bit 1, otherwise bit 0),
// weights are read as packed bits, and every kernel tap contributes
// (activation_bit XOR weight_bit) to a mismatch counter. The stored result is
//     taps - 2 * mismatches
// i.e. (number of agreeing taps) - (number of disagreeing taps).
//
// The output extents are taken from the NDRange: global size 0/1/2 is used as
// output width / height / channel count, and the global id along each axis
// selects the output element.
//
// Parameters:
//   src_data      input activations, indexed as channel * IH * IW + ih * IW + iw
//   weights_data  packed weight bits, 8 per byte, bit 0 of each byte first
//   dst_data      output buffer, indexed as channel * OH * OW + oh * OW + ow
//   pad_value     value assumed outside the input; only its sign test (> 0) is used
//   IW, IH, IC    input width, height and channel count
//   DW, DH        multiplier applied to the kernel tap index (dilation) per axis
//   GC            number of groups; IC / GC input channels are visited per output
//   KW, KH        kernel width and height
//   PW, PH        offset subtracted from the input coordinate (padding) per axis
//   SW, SH        multiplier applied to the output coordinate (stride) per axis
__kernel void binary_convolution(
    const __global half *restrict src_data,
    const __global uchar *restrict weights_data,
    __global half *restrict dst_data,
    float pad_value,
    int IW,
    int IH,
    int IC,
    int DW,
    int DH,
    int GC,
    int KW,
    int KH,
    int PW,
    int PH,
    int SW,
    int SH)
{
    // Out-of-range taps use the binarized padding value.
    const int pad_bit = (pad_value > 0.f) ? 1 : 0;

    // Output coordinates and extents come straight from the NDRange.
    const int ow = get_global_id(0);
    const int oh = get_global_id(1);
    const int c = get_global_id(2);
    const int OW = get_global_size(0);
    const int OH = get_global_size(1);
    const int OC = get_global_size(2);

    // The depth axis is fixed to a single slice, so every depth loop below
    // runs exactly once and the depth coordinate is always 0.
    const int KD = 1;
    const int SD = 0;
    const int DD = 0;
    const int PD = 0;
    const int ID = 1;
    const int OD = 1;

    const int bits_per_byte = 8;

    // Split the flat channel id into a group index and a channel-in-group index.
    const int g = c % GC;
    const int oc = c / GC;

    // Loop-invariant terms. The multiply/divide order is kept as written so
    // integer truncation behaves the same for every input.
    const int channels_per_group = IC / GC;
    const int taps = IC / GC * KD * KH * KW;
    const int w_base = g * OC / GC * IC / GC * KD * KH * KW + oc * IC / GC * KD * KH * KW;
    const int src_group_base = g * IC / GC * ID * IH * IW;
    const int dst_base = g * OC / GC * OD * OH * OW + oc * OD * OH * OW;

    for (int od = 0; od < OD; ++od) {
        int mismatches = 0;

        for (int ic = 0; ic < channels_per_group; ++ic) {
            for (int kd = 0; kd < KD; ++kd) {
                const int id = od * SD - PD + kd * DD;
                const int depth_ok = (id >= 0) && (id < (int)ID);

                for (int kh = 0; kh < KH; ++kh) {
                    const int ih = oh * SH - PH + kh * DH;
                    const int row_ok = depth_ok && (ih >= 0) && (ih < (int)IH);
                    const int w_row = w_base + ic * KD * KH * KW + kd * KH * KW + kh * KW;

                    for (int kw = 0; kw < KW; ++kw) {
                        // Weight bit: byte widx / 8, bit widx % 8 within that byte.
                        const int widx = w_row + kw;
                        const int w_bit = extract_weights(weights_data[widx / bits_per_byte],
                                                          (widx % bits_per_byte));

                        // Activation bit: padding by default, real input when in range.
                        const int iw = ow * SW - PW + kw * DW;
                        int s_bit = pad_bit;
                        if (row_ok && (iw >= 0) && (iw < (int)IW)) {
                            const int iidx = src_group_base + ic * ID * IH * IW + id * IH * IW
                                             + ih * IW + iw;
                            s_bit = (src_data[iidx] > 0.f) ? 1 : 0;
                        }

                        mismatches += s_bit ^ w_bit;
                    }
                }
            }
        }

        const int oidx = dst_base + od * OH * OW + oh * OW + ow;
        dst_data[oidx] = (half)(taps - 2 * mismatches);
    }
}