scnn
Browse files
SCNN.py
ADDED
@@ -0,0 +1,315 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import tensorflow as tf
|
2 |
+
import time as tm
|
3 |
+
import sys
|
4 |
+
# import numpy as np
|
5 |
+
# import cudnn as cd
|
6 |
+
# from tensorflow.keras import datasets, layers, models
|
7 |
+
# import matplotlib.pyplot as plt
|
8 |
+
# from tensorflow.python.client import device_lib
|
9 |
+
|
10 |
+
class Integrator_layer(tf.keras.layers.Layer):
|
11 |
+
def __init__(self, n_steps=100, integration_window=100, time_constant=1.0, leakyness=0.0,
|
12 |
+
V_m_threshold = 1.0, refractory_period=0, amplitude=1.0, V_m_min=0, V_cm=2.5, device='cuda', name='I&F'):
|
13 |
+
super(Integrator_layer, self).__init__(name=name)
|
14 |
+
# self.threshold = nn.Threshold(V_m_threshold, 0)
|
15 |
+
# self.zero = torch.tensor(0, dtype=torch.float, device=device)
|
16 |
+
self.Vm_threshold = V_m_threshold
|
17 |
+
self.integration_window = integration_window
|
18 |
+
self.refractory_period = refractory_period
|
19 |
+
self.time_constant = time_constant
|
20 |
+
# self.epsilon = 0.001
|
21 |
+
self.epsilon = tf.keras.backend.epsilon
|
22 |
+
self.amplitude = amplitude
|
23 |
+
# self.threshold = nn.Threshold(V_m_threshold - self.epsilon, 0) ### Thresholding function
|
24 |
+
# self.threshold = tf.nn.relu(V_m_threshold - self.epsilon, 0) ### Thresholding function
|
25 |
+
self.V_m_min = V_m_min
|
26 |
+
self.device = device
|
27 |
+
|
28 |
+
@tf.function
|
29 |
+
def chunk_sizes(self, length, chunk_size):
|
30 |
+
chunks = [chunk_size for x in range(length//chunk_size)]
|
31 |
+
if length % chunk_size != 0:
|
32 |
+
chunks.append(length % chunk_size)
|
33 |
+
return chunks
|
34 |
+
|
35 |
+
def build(self, input_shape):
|
36 |
+
self.batch_size = input_shape[0]
|
37 |
+
self.timesteps = input_shape[1]
|
38 |
+
self.image_shape = input_shape[2:]
|
39 |
+
self.image_rank = len(input_shape)
|
40 |
+
|
41 |
+
# self.chunk_sizes = self.chunk_sizes(self.timesteps, self.integration_window)
|
42 |
+
self.chunk_sizes = self.chunk_sizes(input_shape[1], self.integration_window)
|
43 |
+
self.tensor_invariance = [None for i in range(self.image_rank)]
|
44 |
+
# self.list_of_indices = [[x, 0] for x in tf.range(input_shape[0])]
|
45 |
+
# self.list_of_indices = tf.range(input_shape[0])
|
46 |
+
|
47 |
+
@tf.function
|
48 |
+
def call(self, inputs):
|
49 |
+
### List of indices - list of indices to replace very first timestep with zero after the roll operation
|
50 |
+
# list_of_indices = tf.pad(tf.expand_dims(tf.range(tf.shape(inputs)[0]), axis=1),
|
51 |
+
# paddings=[[0, 0], [0, 1]],
|
52 |
+
# mode="CONSTANT")
|
53 |
+
list_of_indices = tf.pad(tf.expand_dims(tf.range(tf.shape(inputs)[0]), axis=1),
|
54 |
+
paddings=[[0, 0], [0, 1]],
|
55 |
+
mode="CONSTANT")
|
56 |
+
roll_padding = tf.zeros([self.image_rank, 2], dtype=tf.int32)
|
57 |
+
roll_padding = tf.tensor_scatter_nd_update(roll_padding, indices=[[1, 0]], updates= [1])
|
58 |
+
images_chunks = tf.split(inputs, self.chunk_sizes, axis=1) ### Fragment current sample into multiple chunks with length equal to the integration window
|
59 |
+
first_chunk = True
|
60 |
+
# zero = torch.tensor(0, dtype=torch.float, device=self.device)
|
61 |
+
for chunk, n_timesteps in zip(images_chunks, self.chunk_sizes):
|
62 |
+
### n_timesteps - the number of timesteps for current chunk of integration window
|
63 |
+
Spikes_out = tf.zeros([tf.shape(chunk)[0], *self.image_shape, n_timesteps + 1])
|
64 |
+
### V_m_out - array for storing membrane potential
|
65 |
+
V_m_out = tf.zeros_like(chunk)
|
66 |
+
# V_m_temp = tf.zeros_like(chunk)
|
67 |
+
V_m_temp = tf.ones_like(chunk)
|
68 |
+
# V_m_temp = tf.tensor_scatter_nd_update(V_m_temp, indices=list_of_indices,
|
69 |
+
# updates=tf.ones([1, *self.image_shape]))
|
70 |
+
while tf.math.count_nonzero(V_m_temp) != 0:
|
71 |
+
tf.autograph.experimental.set_loop_options(shape_invariants=[(V_m_temp, tf.TensorShape(self.tensor_invariance))])
|
72 |
+
### V_m_chunk - cumulative summation (integration) along time dimension
|
73 |
+
V_m_chunk = tf.math.cumsum(tf.math.multiply(chunk, self.time_constant), axis=1)
|
74 |
+
### Thresholding chunks, all values bellow threshold value are zeroed
|
75 |
+
V_m_temp = tf.nn.relu(V_m_chunk - self.Vm_threshold)
|
76 |
+
# V_m_temp = tf.print(V_m_temp, [V_m_temp], 'breaking')
|
77 |
+
if tf.math.count_nonzero(V_m_temp) == 0: ### if Vm did not cross threshold, break the cycle
|
78 |
+
# V_m_out = V_m_out + V_m_chunk
|
79 |
+
# V_m_out = tf.print(V_m_out, [V_m_out], 'breaking')
|
80 |
+
break
|
81 |
+
### Cumsum of the thresholded cumsum - to avoid any future threshold crossings (additional zeroes) that can occur after threshold is hit:
|
82 |
+
V_m_temp = tf.math.cumsum(V_m_temp, axis=1)
|
83 |
+
### V_m_temp == 0 The amount of zero values before function crosses the threshold. Used to calculated how many timesteps it took for an integrator to fire an output spike
|
84 |
+
Spikes_out = Spikes_out + tf.one_hot(tf.reduce_sum(tf.cast((V_m_temp == 0), tf.int32), axis=1), depth=n_timesteps + 1)### One hotted zero counts
|
85 |
+
### TF roll operation is used to shift the vector values by 1, other timestep which crossed threshold is not included:
|
86 |
+
V_m_temp = tf.pad(V_m_temp, paddings=roll_padding, mode="CONSTANT")
|
87 |
+
V_m_temp, _ = tf.split(V_m_temp, [n_timesteps, 1], axis=1)
|
88 |
+
# V_m_temp = tf.roll(V_m_temp, shift=1, axis=1)
|
89 |
+
###Since roll operation will shift the last value to the first place, the first value should be 0'ed for a proper counting of 0 in the next code fragments.
|
90 |
+
# __, V_m_temp = tf.split(V_m_temp, [1, n_timesteps - 1], axis=1)
|
91 |
+
# V_m_temp = tf.concat((V_m_temp, tf.zeros_like(__)), axis=1)
|
92 |
+
# V_m_temp = tf.tensor_scatter_nd_update(V_m_temp, indices=list_of_indices,
|
93 |
+
# updates=tf.zeros([tf.shape(chunk)[0], *self.image_shape]))
|
94 |
+
# V_m_out = tf.where(V_m_temp == 0, V_m_out + V_m_chunk, 0) ### Resets V_m to 0 after firing
|
95 |
+
if self.refractory_period!=0: ### Resets (=0) number of timesteps after output spike is fired
|
96 |
+
V_m_temp = tf.roll(chunk, shift=self.refractory_period, axis=1)
|
97 |
+
V_m_temp[:, 0:(self.refractory_period-1), :, :, :] = 0
|
98 |
+
chunk = tf.where(V_m_temp == 0.0, 0.0, chunk) ### Removes spikes before firing. So new V_m can be calculated for a next spike.
|
99 |
+
# Spikes_out = torch.narrow(Spikes_out, dim=-1, start=0, length= n_timesteps)
|
100 |
+
Spikes_out, _ = tf.split(Spikes_out, [n_timesteps, 1], axis=-1) ### Onehot operation adds back time dimension to the last place, so it must be popped out
|
101 |
+
if first_chunk:
|
102 |
+
# V_m_final = V_m_out
|
103 |
+
Spikes_out_final = Spikes_out
|
104 |
+
first_chunk = False
|
105 |
+
else:
|
106 |
+
V_m_final = tf.concat((V_m_final, V_m_out), axis=1)
|
107 |
+
Spikes_out_final = tf.concat((Spikes_out_final, Spikes_out), axis=-1)
|
108 |
+
|
109 |
+
### Onehotting puts time as the last tensor dimension. 'movedim' moves time dimension to the 2nd place, after the batch number, as it was before.
|
110 |
+
Spikes_out_final = tf.experimental.numpy.moveaxis(Spikes_out_final, source=-1, destination=1)
|
111 |
+
# return V_m_final, Spikes_out_final
|
112 |
+
# print('LIF forward end:')
|
113 |
+
# print(f'{datetime.now().time().replace(microsecond=0)} --- ')
|
114 |
+
# print(Spikes_out.type())
|
115 |
+
if self.amplitude !=1.0:
|
116 |
+
Spikes_out_final = Spikes_out_final*self.amplitude
|
117 |
+
return Spikes_out_final
|
118 |
+
|
119 |
+
def sparse_data_generator_non_spiking(input_images, input_labels, batch_size=32, nb_steps=100, shuffle=True, flatten= False):
|
120 |
+
""" This generator takes datasets in analog format and generates network input as constant currents.
|
121 |
+
If repeat=True, encoding is rate-based, otherwise it is a latency encoding
|
122 |
+
Args:
|
123 |
+
X: The data ( sample x event x 2 ) the last dim holds (time,neuron) tuples
|
124 |
+
y: The labels
|
125 |
+
"""
|
126 |
+
# data_loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=shuffle, num_workers=n_workers)
|
127 |
+
data_loader_original = tf.data.Dataset.from_tensor_slices((tf.cast(input_images, tf.float32), input_labels))
|
128 |
+
if shuffle:
|
129 |
+
data_loader_original = data_loader_original.shuffle(buffer_size=100)
|
130 |
+
data_loader = data_loader_original.batch(batch_size=batch_size, drop_remainder=False)
|
131 |
+
|
132 |
+
number_of_batches = input_labels.__len__() // batch_size
|
133 |
+
counter = 0
|
134 |
+
time = tm.time()
|
135 |
+
|
136 |
+
for X, y in data_loader:
|
137 |
+
if flatten:
|
138 |
+
X = X.reshape(X.shape[0], -1)
|
139 |
+
# sample_dims = np.array(X.shape[1:], dtype=int)
|
140 |
+
sample_dims = X.shape[1:]
|
141 |
+
# X = torch.unsqueeze(X, dim=1)
|
142 |
+
X = tf.expand_dims(X, axis=1)
|
143 |
+
X = tf.repeat(X, repeats=nb_steps, axis=1)
|
144 |
+
time_taken = tm.time() - time
|
145 |
+
time = tm.time()
|
146 |
+
ETA = time_taken * (number_of_batches - counter)
|
147 |
+
sys.stdout.write(
|
148 |
+
"\rBatch: {0}/{1}, Progress: {2:0.2f}%, Time to process last batch: {3:0.2f} seconds, Estimated time to finish epoch: {4:0.2f} seconds | {5}:{6} minutes".format(
|
149 |
+
counter, number_of_batches, (counter / number_of_batches) * 100, time_taken, ETA, int(ETA // 60),
|
150 |
+
int(ETA % 60)))
|
151 |
+
sys.stdout.flush()
|
152 |
+
# X_batch = torch.tensor(X, device=device, dtype=torch.float)
|
153 |
+
# yield X.expand(-1, nb_steps, *sample_dims).to(device), y.to(device) ### Returns this values after each batch
|
154 |
+
counter += 1
|
155 |
+
yield X, y ### Returns this values after each batch
|
156 |
+
|
157 |
+
# return argument_free_generator()
|
158 |
+
|
159 |
+
class Reduce_sum(tf.keras.layers.Layer):
|
160 |
+
def __init__(self, name=None):
|
161 |
+
super(Reduce_sum, self).__init__(name=name)
|
162 |
+
|
163 |
+
|
164 |
+
def call(self, inputs):
|
165 |
+
return tf.math.reduce_sum(inputs, axis=1, keepdims=False)
|
166 |
+
|
167 |
+
# """
|
168 |
+
# class Integrator_layer(tf.keras.layers.Layer):
|
169 |
+
# def __init__(self, n_steps=100, integration_window=100, time_constant=1.0, leakyness=0.0,
|
170 |
+
# V_m_threshold = 2.0, refractory_period=0, amplitude=1.0, V_m_min=0, V_cm=2.5, device='cuda', name='I&F'):
|
171 |
+
# super(Integrator_layer, self).__init__(name=name)
|
172 |
+
# # self.threshold = nn.Threshold(V_m_threshold, 0)
|
173 |
+
# # self.zero = torch.tensor(0, dtype=torch.float, device=device)
|
174 |
+
# self.Vm_threshold = V_m_threshold
|
175 |
+
# self.integration_window = integration_window
|
176 |
+
# self.refractory_period = refractory_period
|
177 |
+
# self.time_constant = time_constant
|
178 |
+
# # self.epsilon = 0.001
|
179 |
+
# self.epsilon = tf.keras.backend.epsilon
|
180 |
+
# self.amplitude = amplitude
|
181 |
+
# # self.threshold = nn.Threshold(V_m_threshold - self.epsilon, 0) ### Thresholding function
|
182 |
+
# # self.threshold = tf.nn.relu(V_m_threshold - self.epsilon, 0) ### Thresholding function
|
183 |
+
# self.V_m_min = V_m_min
|
184 |
+
# self.device = device
|
185 |
+
#
|
186 |
+
# @tf.function
|
187 |
+
# def chunk_sizes(self, length, chunk_size):
|
188 |
+
# chunks = [chunk_size for x in range(length//chunk_size)]
|
189 |
+
# if length % chunk_size != 0:
|
190 |
+
# chunks.append(length % chunk_size)
|
191 |
+
# return chunks
|
192 |
+
#
|
193 |
+
# def build(self, input_shape):
|
194 |
+
# self.batch_size = input_shape[0]
|
195 |
+
# self.timesteps = input_shape[1]
|
196 |
+
# self.image_shape = input_shape[2:]
|
197 |
+
# # self.chunk_sizes = self.chunk_sizes(self.timesteps, self.integration_window)
|
198 |
+
# self.chunk_sizes = self.chunk_sizes(input_shape[1], self.integration_window)
|
199 |
+
# ###
|
200 |
+
# ###
|
201 |
+
# ###
|
202 |
+
# # self.list_of_indices = [[x, 0] for x in tf.range(input_shape[0])]
|
203 |
+
# # self.list_of_indices = tf.range(input_shape[0])
|
204 |
+
#
|
205 |
+
# @tf.function
|
206 |
+
# def call(self, inputs):
|
207 |
+
# ### List of indices - list of indices to replace very first timestep with zero after the roll operation
|
208 |
+
# list_of_indices = tf.pad(tf.expand_dims(tf.range(tf.shape(inputs)[0]), axis=1),
|
209 |
+
# paddings=[[0, 0], [0, 1]],
|
210 |
+
# mode="CONSTANT")
|
211 |
+
# images_chunks = tf.split(inputs, self.chunk_sizes, axis=1) ### Fragment current sample into multiple chunks with length equal to the integration window
|
212 |
+
# first_chunk = True
|
213 |
+
# # zero = torch.tensor(0, dtype=torch.float, device=self.device)
|
214 |
+
# for chunk, n_timesteps in zip(images_chunks, self.chunk_sizes):
|
215 |
+
# ### n_timesteps - the number of timesteps for current chunk of integration window
|
216 |
+
# Spikes_out = tf.zeros([tf.shape(chunk)[0], *self.image_shape, n_timesteps + 1])
|
217 |
+
# ### V_m_out - array for storing membrane potential
|
218 |
+
# V_m_out = tf.zeros_like(chunk)
|
219 |
+
# V_m_temp = tf.zeros_like(chunk)
|
220 |
+
# while tf.math.count_nonzero(V_m_temp) != 0:
|
221 |
+
# ### V_m_chunk - cumulative summation (integration) along time dimension
|
222 |
+
# V_m_chunk = tf.math.cumsum(tf.math.multiply(chunk, self.time_constant), axis=1)
|
223 |
+
# ### Thresholding chunks, all values bellow threshold value are zeroed
|
224 |
+
# V_m_temp = tf.nn.relu(V_m_chunk - self.Vm_threshold)
|
225 |
+
# if tf.math.count_nonzero(V_m_temp) == 0: ### if Vm did not cross threshold, break the cycle
|
226 |
+
# V_m_out = V_m_out + V_m_chunk
|
227 |
+
# break
|
228 |
+
# ### Cumsum of the thresholded cumsum - to avoid any future threshold crossings (additional zeroes) that can occur after threshold is hit:
|
229 |
+
# V_m_temp = tf.math.cumsum(V_m_temp, axis=1)
|
230 |
+
# ### V_m_temp == 0 The amount of zero values before function crosses the threshold. Used to calculated how many timesteps it took for an integrator to fire an output spike
|
231 |
+
# Spikes_out = Spikes_out + tf.one_hot(tf.reduce_sum(tf.cast((V_m_temp == 0), tf.int32), axis=1), depth=n_timesteps + 1)### One hotted zero counts
|
232 |
+
# ### TF roll operation is used to shift the vector values by 1, other timestep which crossed threshold is not included:
|
233 |
+
# V_m_temp = tf.roll(V_m_temp, shift=1, axis=1)
|
234 |
+
# ###Since roll operation will shift the last value to the first place, the first value should be 0'ed for a proper counting of 0 in the next code fragments.
|
235 |
+
# V_m_temp = tf.tensor_scatter_nd_update(V_m_temp, indices=list_of_indices,
|
236 |
+
# updates=tf.zeros([1, *self.image_shape]))
|
237 |
+
# V_m_out = tf.where(V_m_temp == 0, V_m_out + V_m_chunk, 0) ### Resets V_m to 0 after firing
|
238 |
+
# if self.refractory_period!=0: ### Resets (=0) number of timesteps after output spike is fired
|
239 |
+
# V_m_temp = tf.roll(chunk, shift=self.refractory_period, axis=1)
|
240 |
+
# V_m_temp[:, 0:(self.refractory_period-1), :, :, :] = 0
|
241 |
+
# chunk = tf.where(V_m_temp == 0.0, 0.0, chunk) ### Removes spikes before firing. So new V_m can be calculated for a next spike.
|
242 |
+
# # Spikes_out = torch.narrow(Spikes_out, dim=-1, start=0, length= n_timesteps)
|
243 |
+
# Spikes_out, _ = tf.split(Spikes_out, [n_timesteps, 1], axis=-1) ### Onehot operation adds back time dimension to the last place, so it must be popped out
|
244 |
+
# if first_chunk:
|
245 |
+
# V_m_final = V_m_out
|
246 |
+
# Spikes_out_final = Spikes_out
|
247 |
+
# first_chunk = False
|
248 |
+
# else:
|
249 |
+
# V_m_final = tf.concat((V_m_final, V_m_out), axis=1)
|
250 |
+
# Spikes_out_final = tf.concat((Spikes_out_final, Spikes_out), axis=-1)
|
251 |
+
#
|
252 |
+
# # Spikes_out_final = torch.movedim(Spikes_out_final, source=-1, destination=1)
|
253 |
+
# Spikes_out_final = tf.experimental.numpy.swapaxes(Spikes_out_final, axis1=-1,
|
254 |
+
# axis2=1) ### Onehotting puts time as the last tensor dimension. 'movedim' moves time dimension to the 2nd place, after the batch number, as it was before.
|
255 |
+
#
|
256 |
+
# # return V_m_final, Spikes_out_final
|
257 |
+
# # print('LIF forward end:')
|
258 |
+
# # print(f'{datetime.now().time().replace(microsecond=0)} --- ')
|
259 |
+
# # print(Spikes_out.type())
|
260 |
+
# if self.amplitude !=1.0:
|
261 |
+
# Spikes_out_final = Spikes_out_final*self.amplitude
|
262 |
+
# return Spikes_out_final
|
263 |
+
# """
|
264 |
+
#
|
265 |
+
# def sparse_data_generator_non_spiking(input_images, input_labels, batch_size=32, nb_steps=100, shuffle=True, flatten= False):
|
266 |
+
# """ This generator takes datasets in analog format and generates network input as constant currents.
|
267 |
+
# If repeat=True, encoding is rate-based, otherwise it is a latency encoding
|
268 |
+
# Args:
|
269 |
+
# X: The data ( sample x event x 2 ) the last dim holds (time,neuron) tuples
|
270 |
+
# y: The labels
|
271 |
+
# """
|
272 |
+
#
|
273 |
+
# # def argument_free_generator():
|
274 |
+
# # data_loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=shuffle, num_workers=n_workers)
|
275 |
+
# data_loader_original = tf.data.Dataset.from_tensor_slices((tf.cast(input_images, tf.float32), input_labels))
|
276 |
+
# if shuffle:
|
277 |
+
# data_loader_original = data_loader_original.shuffle(buffer_size=100)
|
278 |
+
# data_loader = data_loader_original.batch(batch_size=batch_size, drop_remainder=False)
|
279 |
+
#
|
280 |
+
# number_of_batches = input_labels.__len__() // batch_size
|
281 |
+
# counter = 0
|
282 |
+
# time = tm.time()
|
283 |
+
#
|
284 |
+
# for X, y in data_loader:
|
285 |
+
# if flatten:
|
286 |
+
# X = X.reshape(X.shape[0], -1)
|
287 |
+
# # sample_dims = np.array(X.shape[1:], dtype=int)
|
288 |
+
# sample_dims = X.shape[1:]
|
289 |
+
# # X = torch.unsqueeze(X, dim=1)
|
290 |
+
# X = tf.expand_dims(X, axis=1)
|
291 |
+
# X = tf.repeat(X, repeats=nb_steps, axis=1)
|
292 |
+
# time_taken = tm.time() - time
|
293 |
+
# time = tm.time()
|
294 |
+
# ETA = time_taken * (number_of_batches - counter)
|
295 |
+
# sys.stdout.write(
|
296 |
+
# "\rBatch: {0}/{1}, Progress: {2:0.2f}%, Time to process last batch: {3:0.2f} seconds, Estimated time to finish epoch: {4:0.2f} seconds | {5}:{6} minutes".format(
|
297 |
+
# counter, number_of_batches, (counter / number_of_batches) * 100, time_taken, ETA, int(ETA // 60),
|
298 |
+
# int(ETA % 60)))
|
299 |
+
# sys.stdout.flush()
|
300 |
+
# # X_batch = torch.tensor(X, device=device, dtype=torch.float)
|
301 |
+
# # yield X.expand(-1, nb_steps, *sample_dims).to(device), y.to(device) ### Returns this values after each batch
|
302 |
+
# counter += 1
|
303 |
+
# yield X, y ### Returns this values after each batch
|
304 |
+
#
|
305 |
+
# # return argument_free_generator()
|
306 |
+
#
|
307 |
+
#
|
308 |
+
# class Reduce_sum(tf.keras.layers.Layer):
|
309 |
+
# def __init__(self, name=None):
|
310 |
+
# super(Reduce_sum, self).__init__(name=name)
|
311 |
+
#
|
312 |
+
# def call(self, inputs):
|
313 |
+
# return tf.math.reduce_sum(inputs, axis=1, keepdims=False)
|
314 |
+
#
|
315 |
+
#
|