Sreeja123 committed on
Commit b89907e · 1 Parent(s): 2760440
Files changed (1)
  1. SCNN.py +315 -0
SCNN.py ADDED
@@ -0,0 +1,315 @@
+ import tensorflow as tf
+ import time as tm
+ import sys
+ # import numpy as np
+ # import cudnn as cd
+ # from tensorflow.keras import datasets, layers, models
+ # import matplotlib.pyplot as plt
+ # from tensorflow.python.client import device_lib
+
+ class Integrator_layer(tf.keras.layers.Layer):
+     def __init__(self, n_steps=100, integration_window=100, time_constant=1.0, leakyness=0.0,
+                  V_m_threshold=1.0, refractory_period=0, amplitude=1.0, V_m_min=0, V_cm=2.5, device='cuda', name='I&F'):
+         super(Integrator_layer, self).__init__(name=name)
+         self.Vm_threshold = V_m_threshold
+         self.integration_window = integration_window
+         self.refractory_period = refractory_period
+         self.time_constant = time_constant
+         self.epsilon = tf.keras.backend.epsilon()  # epsilon is a function; call it to store the float value
+         self.amplitude = amplitude
+         self.V_m_min = V_m_min
+         self.device = device
+
+     def _chunk_sizes(self, length, chunk_size):
+         # Plain-Python helper: builds the list of split sizes that tf.split expects.
+         # It runs once inside build(), so it does not need @tf.function.
+         chunks = [chunk_size for x in range(length // chunk_size)]
+         if length % chunk_size != 0:
+             chunks.append(length % chunk_size)
+         return chunks
+
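+     # Illustrative example: _chunk_sizes(250, 100) returns [100, 100, 50],
+     # i.e. two full integration windows plus a 50-step remainder.
+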
+     def build(self, input_shape):
+         self.batch_size = input_shape[0]
+         self.timesteps = input_shape[1]
+         self.image_shape = input_shape[2:]
+         self.image_rank = len(input_shape)
+
+         # Precompute the per-chunk lengths once; tf.split reuses this list on every call
+         self.chunk_sizes = self._chunk_sizes(input_shape[1], self.integration_window)
+         # Shape invariant for the while loop in call(); every dimension may vary between iterations
+         self.tensor_invariance = [None for i in range(self.image_rank)]
+
+     @tf.function
+     def call(self, inputs):
+         ### list_of_indices (unused in this version) - indices for zeroing the very first
+         ### timestep after a roll operation:
+         # list_of_indices = tf.pad(tf.expand_dims(tf.range(tf.shape(inputs)[0]), axis=1),
+         #                          paddings=[[0, 0], [0, 1]],
+         #                          mode="CONSTANT")
+         ### roll_padding pads a single zero in front of the time axis (axis 1), so a pad-then-split below acts as a shift by one step along time
+         roll_padding = tf.zeros([self.image_rank, 2], dtype=tf.int32)
+         roll_padding = tf.tensor_scatter_nd_update(roll_padding, indices=[[1, 0]], updates=[1])
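+         ### e.g. for a rank-5 input (batch, time, H, W, C), roll_padding becomes
+         ### [[0, 0], [1, 0], [0, 0], [0, 0], [0, 0]] (illustrative shapes)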
+         ### Fragment the current sample into chunks whose length equals the integration window
+         images_chunks = tf.split(inputs, self.chunk_sizes, axis=1)
+         first_chunk = True
+         for chunk, n_timesteps in zip(images_chunks, self.chunk_sizes):
+             ### n_timesteps - the number of timesteps in the current integration-window chunk
+             Spikes_out = tf.zeros([tf.shape(chunk)[0], *self.image_shape, n_timesteps + 1])
+             ### V_m_out - array for storing the membrane potential
+             V_m_out = tf.zeros_like(chunk)
+             ### Initialised non-zero so the while loop below runs at least once
+             V_m_temp = tf.ones_like(chunk)
+             while tf.math.count_nonzero(V_m_temp) != 0:
+                 tf.autograph.experimental.set_loop_options(shape_invariants=[(V_m_temp, tf.TensorShape(self.tensor_invariance))])
+                 ### V_m_chunk - cumulative summation (integration) along the time dimension
+                 V_m_chunk = tf.math.cumsum(tf.math.multiply(chunk, self.time_constant), axis=1)
+                 ### Threshold the chunks; all values below the threshold value are zeroed
+                 V_m_temp = tf.nn.relu(V_m_chunk - self.Vm_threshold)
+                 if tf.math.count_nonzero(V_m_temp) == 0:  ### if Vm never crossed the threshold, break the cycle
+                     break
+                 ### Cumsum of the thresholded cumsum - avoids extra zeroes from any further threshold crossings that can occur after the first threshold hit:
+                 V_m_temp = tf.math.cumsum(V_m_temp, axis=1)
+                 ### (V_m_temp == 0) counts the zero values before the threshold crossing, i.e. how many timesteps the integrator took to fire an output spike
+                 Spikes_out = Spikes_out + tf.one_hot(tf.reduce_sum(tf.cast((V_m_temp == 0), tf.int32), axis=1), depth=n_timesteps + 1)  ### one-hotted zero counts
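+                 ### Worked example (hypothetical numbers): with threshold 1.0 and a constant
+                 ### input of 0.3 over 4 steps, cumsum = [0.3, 0.6, 0.9, 1.2] and
+                 ### relu(cumsum - 1.0) = [0, 0, 0, 0.2]; three zeros are counted, so
+                 ### one_hot marks a spike at timestep index 3 (the 4th step).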
+                 ### A pad followed by a split shifts the values by one step along time, so the timestep that crossed the threshold is excluded:
+                 V_m_temp = tf.pad(V_m_temp, paddings=roll_padding, mode="CONSTANT")
+                 V_m_temp, _ = tf.split(V_m_temp, [n_timesteps, 1], axis=1)
+                 if self.refractory_period != 0:  ### Zeroes the input for a number of timesteps after an output spike is fired
+                     V_m_temp = tf.roll(chunk, shift=self.refractory_period, axis=1)
+                     ### tf.Tensor does not support item assignment, so the wrapped-around
+                     ### leading refractory_period timesteps are zeroed via slicing and concat:
+                     V_m_temp = tf.concat([tf.zeros_like(V_m_temp[:, :self.refractory_period]),
+                                           V_m_temp[:, self.refractory_period:]], axis=1)
+                 chunk = tf.where(V_m_temp == 0.0, 0.0, chunk)  ### Removes spikes before firing, so a new V_m can be calculated for the next spike
+             Spikes_out, _ = tf.split(Spikes_out, [n_timesteps, 1], axis=-1)  ### one_hot uses depth n_timesteps + 1; the extra final slot (no threshold crossing within this chunk) must be popped off
+             if first_chunk:
+                 # V_m_final = V_m_out  ### membrane-potential output is disabled in this version
+                 Spikes_out_final = Spikes_out
+                 first_chunk = False
+             else:
+                 # V_m_final = tf.concat((V_m_final, V_m_out), axis=1)  ### kept disabled: V_m_final is never initialised above, so this would raise a NameError
+                 Spikes_out_final = tf.concat((Spikes_out_final, Spikes_out), axis=-1)
+
+         ### One-hotting puts time as the last tensor dimension; moveaxis moves the time dimension back to 2nd place, after the batch dimension, as it was before
+         Spikes_out_final = tf.experimental.numpy.moveaxis(Spikes_out_final, source=-1, destination=1)
+         if self.amplitude != 1.0:
+             Spikes_out_final = Spikes_out_final * self.amplitude
+         return Spikes_out_final
+
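+ # A minimal usage sketch (illustrative, not part of the original commit): the layer
+ # consumes a tensor shaped (batch, timesteps, H, W, C) and emits a spike train with the
+ # same layout; Reduce_sum (defined below) then collapses time into spike counts.
+ # All shapes and hyperparameters here are assumptions for the example.
+ #
+ # model = tf.keras.Sequential([
+ #     tf.keras.Input(shape=(100, 28, 28, 1)),  # (timesteps, H, W, C)
+ #     Integrator_layer(integration_window=100, V_m_threshold=1.0),
+ #     Reduce_sum(name='spike_count'),
+ #     tf.keras.layers.Flatten(),
+ #     tf.keras.layers.Dense(10),
+ # ])
+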
+ def sparse_data_generator_non_spiking(input_images, input_labels, batch_size=32, nb_steps=100, shuffle=True, flatten=False):
+     """This generator takes a dataset in analog format and generates network input as
+     constant currents: each static sample is repeated for nb_steps timesteps
+     (a constant-current, rate-style encoding).
+     Args:
+         input_images: The analog data, shape (sample, ...feature dims)
+         input_labels: The labels
+     """
+     data_loader_original = tf.data.Dataset.from_tensor_slices((tf.cast(input_images, tf.float32), input_labels))
+     if shuffle:
+         data_loader_original = data_loader_original.shuffle(buffer_size=100)
+     data_loader = data_loader_original.batch(batch_size=batch_size, drop_remainder=False)
+
+     number_of_batches = len(input_labels) // batch_size
+     counter = 0
+     time = tm.time()
+
+     for X, y in data_loader:
+         if flatten:
+             X = tf.reshape(X, [tf.shape(X)[0], -1])  # tf.Tensor has no .reshape method
+         X = tf.expand_dims(X, axis=1)
+         X = tf.repeat(X, repeats=nb_steps, axis=1)
+         time_taken = tm.time() - time
+         time = tm.time()
+         ETA = time_taken * (number_of_batches - counter)
+         sys.stdout.write(
+             "\rBatch: {0}/{1}, Progress: {2:0.2f}%, Time to process last batch: {3:0.2f} seconds, Estimated time to finish epoch: {4:0.2f} seconds | {5}:{6} minutes".format(
+                 counter, number_of_batches, (counter / number_of_batches) * 100, time_taken, ETA, int(ETA // 60),
+                 int(ETA % 60)))
+         sys.stdout.flush()
+         counter += 1
+         yield X, y  ### Yields one encoded batch at a time
+
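+ # Minimal usage sketch (illustrative; the MNIST shapes below are assumptions, not part
+ # of this commit):
+ #
+ # (x_train, y_train), _ = tf.keras.datasets.mnist.load_data()
+ # for X, y in sparse_data_generator_non_spiking(x_train[:64], y_train[:64],
+ #                                               batch_size=32, nb_steps=100):
+ #     print(X.shape, y.shape)  # (32, 100, 28, 28) (32,)
+ #     break
+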
+ class Reduce_sum(tf.keras.layers.Layer):
+     def __init__(self, name=None):
+         super(Reduce_sum, self).__init__(name=name)
+
+     def call(self, inputs):
+         return tf.math.reduce_sum(inputs, axis=1, keepdims=False)
+
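+ # Reduce_sum collapses the time axis (axis 1), turning a spike train of shape
+ # (batch, time, ...) into per-neuron spike counts of shape (batch, ...), i.e. a simple
+ # rate decoding. Quick illustrative check (hypothetical tensor):
+ #
+ # layer = Reduce_sum(name='rate')
+ # spikes = tf.ones([2, 100, 4])  # (batch, time, features)
+ # print(layer(spikes).shape)     # (2, 4); every entry equals 100.0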