import tensorflow as tf import time as tm import sys # import numpy as np # import cudnn as cd # from tensorflow.keras import datasets, layers, models # import matplotlib.pyplot as plt # from tensorflow.python.client import device_lib class Integrator_layer(tf.keras.layers.Layer): def __init__(self, n_steps=100, integration_window=100, time_constant=1.0, leakyness=0.0, V_m_threshold = 1.0, refractory_period=0, amplitude=1.0, V_m_min=0, V_cm=2.5, device='cuda', name='I&F'): super(Integrator_layer, self).__init__(name=name) # self.threshold = nn.Threshold(V_m_threshold, 0) # self.zero = torch.tensor(0, dtype=torch.float, device=device) self.Vm_threshold = V_m_threshold self.integration_window = integration_window self.refractory_period = refractory_period self.time_constant = time_constant # self.epsilon = 0.001 self.epsilon = tf.keras.backend.epsilon self.amplitude = amplitude # self.threshold = nn.Threshold(V_m_threshold - self.epsilon, 0) ### Thresholding function # self.threshold = tf.nn.relu(V_m_threshold - self.epsilon, 0) ### Thresholding function self.V_m_min = V_m_min self.device = device @tf.function def chunk_sizes(self, length, chunk_size): chunks = [chunk_size for x in range(length//chunk_size)] if length % chunk_size != 0: chunks.append(length % chunk_size) return chunks def build(self, input_shape): self.batch_size = input_shape[0] self.timesteps = input_shape[1] self.image_shape = input_shape[2:] self.image_rank = len(input_shape) # self.chunk_sizes = self.chunk_sizes(self.timesteps, self.integration_window) self.chunk_sizes = self.chunk_sizes(input_shape[1], self.integration_window) self.tensor_invariance = [None for i in range(self.image_rank)] # self.list_of_indices = [[x, 0] for x in tf.range(input_shape[0])] # self.list_of_indices = tf.range(input_shape[0]) @tf.function def call(self, inputs): ### List of indices - list of indices to replace very first timestep with zero after the roll operation # list_of_indices = tf.pad(tf.expand_dims(tf.range(tf.shape(inputs)[0]), axis=1), # paddings=[[0, 0], [0, 1]], # mode="CONSTANT") list_of_indices = tf.pad(tf.expand_dims(tf.range(tf.shape(inputs)[0]), axis=1), paddings=[[0, 0], [0, 1]], mode="CONSTANT") roll_padding = tf.zeros([self.image_rank, 2], dtype=tf.int32) roll_padding = tf.tensor_scatter_nd_update(roll_padding, indices=[[1, 0]], updates= [1]) images_chunks = tf.split(inputs, self.chunk_sizes, axis=1) ### Fragment current sample into multiple chunks with length equal to the integration window first_chunk = True # zero = torch.tensor(0, dtype=torch.float, device=self.device) for chunk, n_timesteps in zip(images_chunks, self.chunk_sizes): ### n_timesteps - the number of timesteps for current chunk of integration window Spikes_out = tf.zeros([tf.shape(chunk)[0], *self.image_shape, n_timesteps + 1]) ### V_m_out - array for storing membrane potential V_m_out = tf.zeros_like(chunk) # V_m_temp = tf.zeros_like(chunk) V_m_temp = tf.ones_like(chunk) # V_m_temp = tf.tensor_scatter_nd_update(V_m_temp, indices=list_of_indices, # updates=tf.ones([1, *self.image_shape])) while tf.math.count_nonzero(V_m_temp) != 0: tf.autograph.experimental.set_loop_options(shape_invariants=[(V_m_temp, tf.TensorShape(self.tensor_invariance))]) ### V_m_chunk - cumulative summation (integration) along time dimension V_m_chunk = tf.math.cumsum(tf.math.multiply(chunk, self.time_constant), axis=1) ### Thresholding chunks, all values bellow threshold value are zeroed V_m_temp = tf.nn.relu(V_m_chunk - self.Vm_threshold) # V_m_temp = tf.print(V_m_temp, [V_m_temp], 'breaking') if tf.math.count_nonzero(V_m_temp) == 0: ### if Vm did not cross threshold, break the cycle # V_m_out = V_m_out + V_m_chunk # V_m_out = tf.print(V_m_out, [V_m_out], 'breaking') break ### Cumsum of the thresholded cumsum - to avoid any future threshold crossings (additional zeroes) that can occur after threshold is hit: V_m_temp = tf.math.cumsum(V_m_temp, axis=1) ### V_m_temp == 0 The amount of zero values before function crosses the threshold. Used to calculated how many timesteps it took for an integrator to fire an output spike Spikes_out = Spikes_out + tf.one_hot(tf.reduce_sum(tf.cast((V_m_temp == 0), tf.int32), axis=1), depth=n_timesteps + 1)### One hotted zero counts ### TF roll operation is used to shift the vector values by 1, other timestep which crossed threshold is not included: V_m_temp = tf.pad(V_m_temp, paddings=roll_padding, mode="CONSTANT") V_m_temp, _ = tf.split(V_m_temp, [n_timesteps, 1], axis=1) # V_m_temp = tf.roll(V_m_temp, shift=1, axis=1) ###Since roll operation will shift the last value to the first place, the first value should be 0'ed for a proper counting of 0 in the next code fragments. # __, V_m_temp = tf.split(V_m_temp, [1, n_timesteps - 1], axis=1) # V_m_temp = tf.concat((V_m_temp, tf.zeros_like(__)), axis=1) # V_m_temp = tf.tensor_scatter_nd_update(V_m_temp, indices=list_of_indices, # updates=tf.zeros([tf.shape(chunk)[0], *self.image_shape])) # V_m_out = tf.where(V_m_temp == 0, V_m_out + V_m_chunk, 0) ### Resets V_m to 0 after firing if self.refractory_period!=0: ### Resets (=0) number of timesteps after output spike is fired V_m_temp = tf.roll(chunk, shift=self.refractory_period, axis=1) V_m_temp[:, 0:(self.refractory_period-1), :, :, :] = 0 chunk = tf.where(V_m_temp == 0.0, 0.0, chunk) ### Removes spikes before firing. So new V_m can be calculated for a next spike. # Spikes_out = torch.narrow(Spikes_out, dim=-1, start=0, length= n_timesteps) Spikes_out, _ = tf.split(Spikes_out, [n_timesteps, 1], axis=-1) ### Onehot operation adds back time dimension to the last place, so it must be popped out if first_chunk: # V_m_final = V_m_out Spikes_out_final = Spikes_out first_chunk = False else: V_m_final = tf.concat((V_m_final, V_m_out), axis=1) Spikes_out_final = tf.concat((Spikes_out_final, Spikes_out), axis=-1) ### Onehotting puts time as the last tensor dimension. 'movedim' moves time dimension to the 2nd place, after the batch number, as it was before. Spikes_out_final = tf.experimental.numpy.moveaxis(Spikes_out_final, source=-1, destination=1) # return V_m_final, Spikes_out_final # print('LIF forward end:') # print(f'{datetime.now().time().replace(microsecond=0)} --- ') # print(Spikes_out.type()) if self.amplitude !=1.0: Spikes_out_final = Spikes_out_final*self.amplitude return Spikes_out_final def sparse_data_generator_non_spiking(input_images, input_labels, batch_size=32, nb_steps=100, shuffle=True, flatten= False): """ This generator takes datasets in analog format and generates network input as constant currents. If repeat=True, encoding is rate-based, otherwise it is a latency encoding Args: X: The data ( sample x event x 2 ) the last dim holds (time,neuron) tuples y: The labels """ # data_loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=shuffle, num_workers=n_workers) data_loader_original = tf.data.Dataset.from_tensor_slices((tf.cast(input_images, tf.float32), input_labels)) if shuffle: data_loader_original = data_loader_original.shuffle(buffer_size=100) data_loader = data_loader_original.batch(batch_size=batch_size, drop_remainder=False) number_of_batches = input_labels.__len__() // batch_size counter = 0 time = tm.time() for X, y in data_loader: if flatten: X = X.reshape(X.shape[0], -1) # sample_dims = np.array(X.shape[1:], dtype=int) sample_dims = X.shape[1:] # X = torch.unsqueeze(X, dim=1) X = tf.expand_dims(X, axis=1) X = tf.repeat(X, repeats=nb_steps, axis=1) time_taken = tm.time() - time time = tm.time() ETA = time_taken * (number_of_batches - counter) sys.stdout.write( "\rBatch: {0}/{1}, Progress: {2:0.2f}%, Time to process last batch: {3:0.2f} seconds, Estimated time to finish epoch: {4:0.2f} seconds | {5}:{6} minutes".format( counter, number_of_batches, (counter / number_of_batches) * 100, time_taken, ETA, int(ETA // 60), int(ETA % 60))) sys.stdout.flush() # X_batch = torch.tensor(X, device=device, dtype=torch.float) # yield X.expand(-1, nb_steps, *sample_dims).to(device), y.to(device) ### Returns this values after each batch counter += 1 yield X, y ### Returns this values after each batch # return argument_free_generator() class Reduce_sum(tf.keras.layers.Layer): def __init__(self, name=None): super(Reduce_sum, self).__init__(name=name) def call(self, inputs): return tf.math.reduce_sum(inputs, axis=1, keepdims=False) # """ # class Integrator_layer(tf.keras.layers.Layer): # def __init__(self, n_steps=100, integration_window=100, time_constant=1.0, leakyness=0.0, # V_m_threshold = 2.0, refractory_period=0, amplitude=1.0, V_m_min=0, V_cm=2.5, device='cuda', name='I&F'): # super(Integrator_layer, self).__init__(name=name) # # self.threshold = nn.Threshold(V_m_threshold, 0) # # self.zero = torch.tensor(0, dtype=torch.float, device=device) # self.Vm_threshold = V_m_threshold # self.integration_window = integration_window # self.refractory_period = refractory_period # self.time_constant = time_constant # # self.epsilon = 0.001 # self.epsilon = tf.keras.backend.epsilon # self.amplitude = amplitude # # self.threshold = nn.Threshold(V_m_threshold - self.epsilon, 0) ### Thresholding function # # self.threshold = tf.nn.relu(V_m_threshold - self.epsilon, 0) ### Thresholding function # self.V_m_min = V_m_min # self.device = device # # @tf.function # def chunk_sizes(self, length, chunk_size): # chunks = [chunk_size for x in range(length//chunk_size)] # if length % chunk_size != 0: # chunks.append(length % chunk_size) # return chunks # # def build(self, input_shape): # self.batch_size = input_shape[0] # self.timesteps = input_shape[1] # self.image_shape = input_shape[2:] # # self.chunk_sizes = self.chunk_sizes(self.timesteps, self.integration_window) # self.chunk_sizes = self.chunk_sizes(input_shape[1], self.integration_window) # ### # ### # ### # # self.list_of_indices = [[x, 0] for x in tf.range(input_shape[0])] # # self.list_of_indices = tf.range(input_shape[0]) # # @tf.function # def call(self, inputs): # ### List of indices - list of indices to replace very first timestep with zero after the roll operation # list_of_indices = tf.pad(tf.expand_dims(tf.range(tf.shape(inputs)[0]), axis=1), # paddings=[[0, 0], [0, 1]], # mode="CONSTANT") # images_chunks = tf.split(inputs, self.chunk_sizes, axis=1) ### Fragment current sample into multiple chunks with length equal to the integration window # first_chunk = True # # zero = torch.tensor(0, dtype=torch.float, device=self.device) # for chunk, n_timesteps in zip(images_chunks, self.chunk_sizes): # ### n_timesteps - the number of timesteps for current chunk of integration window # Spikes_out = tf.zeros([tf.shape(chunk)[0], *self.image_shape, n_timesteps + 1]) # ### V_m_out - array for storing membrane potential # V_m_out = tf.zeros_like(chunk) # V_m_temp = tf.zeros_like(chunk) # while tf.math.count_nonzero(V_m_temp) != 0: # ### V_m_chunk - cumulative summation (integration) along time dimension # V_m_chunk = tf.math.cumsum(tf.math.multiply(chunk, self.time_constant), axis=1) # ### Thresholding chunks, all values bellow threshold value are zeroed # V_m_temp = tf.nn.relu(V_m_chunk - self.Vm_threshold) # if tf.math.count_nonzero(V_m_temp) == 0: ### if Vm did not cross threshold, break the cycle # V_m_out = V_m_out + V_m_chunk # break # ### Cumsum of the thresholded cumsum - to avoid any future threshold crossings (additional zeroes) that can occur after threshold is hit: # V_m_temp = tf.math.cumsum(V_m_temp, axis=1) # ### V_m_temp == 0 The amount of zero values before function crosses the threshold. Used to calculated how many timesteps it took for an integrator to fire an output spike # Spikes_out = Spikes_out + tf.one_hot(tf.reduce_sum(tf.cast((V_m_temp == 0), tf.int32), axis=1), depth=n_timesteps + 1)### One hotted zero counts # ### TF roll operation is used to shift the vector values by 1, other timestep which crossed threshold is not included: # V_m_temp = tf.roll(V_m_temp, shift=1, axis=1) # ###Since roll operation will shift the last value to the first place, the first value should be 0'ed for a proper counting of 0 in the next code fragments. # V_m_temp = tf.tensor_scatter_nd_update(V_m_temp, indices=list_of_indices, # updates=tf.zeros([1, *self.image_shape])) # V_m_out = tf.where(V_m_temp == 0, V_m_out + V_m_chunk, 0) ### Resets V_m to 0 after firing # if self.refractory_period!=0: ### Resets (=0) number of timesteps after output spike is fired # V_m_temp = tf.roll(chunk, shift=self.refractory_period, axis=1) # V_m_temp[:, 0:(self.refractory_period-1), :, :, :] = 0 # chunk = tf.where(V_m_temp == 0.0, 0.0, chunk) ### Removes spikes before firing. So new V_m can be calculated for a next spike. # # Spikes_out = torch.narrow(Spikes_out, dim=-1, start=0, length= n_timesteps) # Spikes_out, _ = tf.split(Spikes_out, [n_timesteps, 1], axis=-1) ### Onehot operation adds back time dimension to the last place, so it must be popped out # if first_chunk: # V_m_final = V_m_out # Spikes_out_final = Spikes_out # first_chunk = False # else: # V_m_final = tf.concat((V_m_final, V_m_out), axis=1) # Spikes_out_final = tf.concat((Spikes_out_final, Spikes_out), axis=-1) # # # Spikes_out_final = torch.movedim(Spikes_out_final, source=-1, destination=1) # Spikes_out_final = tf.experimental.numpy.swapaxes(Spikes_out_final, axis1=-1, # axis2=1) ### Onehotting puts time as the last tensor dimension. 'movedim' moves time dimension to the 2nd place, after the batch number, as it was before. # # # return V_m_final, Spikes_out_final # # print('LIF forward end:') # # print(f'{datetime.now().time().replace(microsecond=0)} --- ') # # print(Spikes_out.type()) # if self.amplitude !=1.0: # Spikes_out_final = Spikes_out_final*self.amplitude # return Spikes_out_final # """ # # def sparse_data_generator_non_spiking(input_images, input_labels, batch_size=32, nb_steps=100, shuffle=True, flatten= False): # """ This generator takes datasets in analog format and generates network input as constant currents. # If repeat=True, encoding is rate-based, otherwise it is a latency encoding # Args: # X: The data ( sample x event x 2 ) the last dim holds (time,neuron) tuples # y: The labels # """ # # # def argument_free_generator(): # # data_loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=shuffle, num_workers=n_workers) # data_loader_original = tf.data.Dataset.from_tensor_slices((tf.cast(input_images, tf.float32), input_labels)) # if shuffle: # data_loader_original = data_loader_original.shuffle(buffer_size=100) # data_loader = data_loader_original.batch(batch_size=batch_size, drop_remainder=False) # # number_of_batches = input_labels.__len__() // batch_size # counter = 0 # time = tm.time() # # for X, y in data_loader: # if flatten: # X = X.reshape(X.shape[0], -1) # # sample_dims = np.array(X.shape[1:], dtype=int) # sample_dims = X.shape[1:] # # X = torch.unsqueeze(X, dim=1) # X = tf.expand_dims(X, axis=1) # X = tf.repeat(X, repeats=nb_steps, axis=1) # time_taken = tm.time() - time # time = tm.time() # ETA = time_taken * (number_of_batches - counter) # sys.stdout.write( # "\rBatch: {0}/{1}, Progress: {2:0.2f}%, Time to process last batch: {3:0.2f} seconds, Estimated time to finish epoch: {4:0.2f} seconds | {5}:{6} minutes".format( # counter, number_of_batches, (counter / number_of_batches) * 100, time_taken, ETA, int(ETA // 60), # int(ETA % 60))) # sys.stdout.flush() # # X_batch = torch.tensor(X, device=device, dtype=torch.float) # # yield X.expand(-1, nb_steps, *sample_dims).to(device), y.to(device) ### Returns this values after each batch # counter += 1 # yield X, y ### Returns this values after each batch # # # return argument_free_generator() # # # class Reduce_sum(tf.keras.layers.Layer): # def __init__(self, name=None): # super(Reduce_sum, self).__init__(name=name) # # def call(self, inputs): # return tf.math.reduce_sum(inputs, axis=1, keepdims=False) # #