// Hosting-page metadata (not part of the kernel source): Spaces status "Runtime error"; file size 2,524 bytes; revision 2cc8629.
// Copyright (C) 2018-2022 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
/* Returns bit number `bit` of `val` as 0 or 1; bit 0 is the least-significant bit. */
int extract_weights(uchar val, int bit)
{
    const int shifted = val >> bit;
    return shifted & 1;
}
/*
 * Binary convolution; each work-item writes exactly one element of dst_data.
 *
 * The work-item's global id (0, 1, 2) is used as the output (column, row,
 * channel) and the global size in each dimension as the output extent.
 *
 * src_data      input values; each one is reduced to a bit (1 if > 0, else 0)
 * weights_data  weights packed one bit each, eight per byte, lowest bit first
 * dst_data      output values
 * pad_value     reduced to a bit (1 if > 0, else 0) and used for every tap
 *               that falls outside the input
 * IW, IH, IC    input width, height and channel count
 * DW, DH        dilation along width / height
 * GC            group count (the channel id is split as id % GC, id / GC)
 * KW, KH        kernel width / height
 * PW, PH        padding along width / height
 * SW, SH        stride along width / height
 *
 * For each output element the kernel counts the taps whose input bit differs
 * from the weight bit and stores (number of taps - 2 * mismatches).
 */
__kernel void binary_convolution(
    const __global half *restrict src_data,
    const __global uchar *restrict weights_data,
    __global half *restrict dst_data,
    float pad_value,
    int IW,
    int IH,
    int IC,
    int DW,
    int DH,
    int GC,
    int KW,
    int KH,
    int PW,
    int PH,
    int SW,
    int SH)
{
    // Output coordinate and output extents both come from the NDRange.
    const int col = get_global_id(0);
    const int row = get_global_id(1);
    const int chan = get_global_id(2);
    const int out_w = get_global_size(0);
    const int out_h = get_global_size(1);
    const int out_c = get_global_size(2);

    // The depth axis is fixed to a single slice.
    const int kernel_d = 1;
    const int stride_d = 0;
    const int dilation_d = 0;
    const int pad_d = 0;
    const int in_d = 1;
    const int out_d = 1;

    const int bits_per_byte = 8;
    const int pad_bit = (pad_value > 0.f);

    const int group = chan % GC;
    const int group_chan = chan / GC;

    // Sizes reused by the index arithmetic below.
    const int out_plane = out_h * out_w;
    const int in_plane = IH * IW;
    const int in_volume = in_d * in_plane;
    const int kernel_plane = KH * KW;
    const int kernel_volume = kernel_d * kernel_plane;
    const int chans_per_group = IC / GC;

    // Loop-invariant parts of the weight and input offsets. The chained
    // "* X / GC" form is evaluated left to right, on purpose.
    const int weight_base = group * out_c / GC * IC / GC * kernel_volume
                          + group_chan * IC / GC * kernel_volume;
    const int input_base = group * IC / GC * in_volume;

    for (int oz = 0; oz < out_d; ++oz) {
        const int dst_idx = group * out_c / GC * out_d * out_plane
                          + group_chan * out_d * out_plane
                          + oz * out_plane + row * out_w + col;
        int mismatches = 0;

        for (int ci = 0; ci < chans_per_group; ++ci) {
            for (int kz = 0; kz < kernel_d; ++kz) {
                const int iz = oz * stride_d - pad_d + kz * dilation_d;

                for (int ky = 0; ky < KH; ++ky) {
                    const int iy = row * SH - PH + ky * DH;

                    for (int kx = 0; kx < KW; ++kx) {
                        const int ix = col * SW - PW + kx * DW;

                        // Fetch the weight bit for this tap from the packed array.
                        const int w_idx = weight_base + ci * kernel_volume
                                        + kz * kernel_plane + ky * KW + kx;
                        const int w_bit = extract_weights(weights_data[w_idx / bits_per_byte],
                                                          w_idx % bits_per_byte);

                        // Taps outside the input take the pad bit.
                        int s_bit = pad_bit;
                        if (ix >= 0 && ix < IW && iy >= 0 && iy < IH && iz >= 0 && iz < in_d) {
                            const int src_idx = input_base + ci * in_volume
                                              + iz * in_plane + iy * IW + ix;
                            s_bit = (src_data[src_idx] > 0.f);
                        }

                        mismatches += s_bit ^ w_bit;
                    }
                }
            }
        }

        // taps - 2 * mismatches == (matching taps) - (mismatching taps)
        dst_data[dst_idx] = (half)(chans_per_group * kernel_volume - 2 * mismatches);
    }
}
|