import tensorflow as tf
import time as tm
import sys
class Integrator_layer(tf.keras.layers.Layer):
    """Integrate-and-fire (I&F) layer: input currents are accumulated along the time axis and an
    output spike is emitted at each timestep where the membrane potential crosses V_m_threshold."""
    def __init__(self, n_steps=100, integration_window=100, time_constant=1.0, leakyness=0.0,
                 V_m_threshold=1.0, refractory_period=0, amplitude=1.0, V_m_min=0, V_cm=2.5, device='cuda', name='I&F'):
        super(Integrator_layer, self).__init__(name=name)
        self.Vm_threshold = V_m_threshold
        self.integration_window = integration_window
        self.refractory_period = refractory_period
        self.time_constant = time_constant
        self.epsilon = tf.keras.backend.epsilon()
        self.amplitude = amplitude
        self.V_m_min = V_m_min
        self.device = device
    def chunk_sizes(self, length, chunk_size):
        ### Split `length` timesteps into integration windows of `chunk_size` steps;
        ### the last chunk holds the remainder if `length` is not a multiple of `chunk_size`.
        chunks = [chunk_size for _ in range(length // chunk_size)]
        if length % chunk_size != 0:
            chunks.append(length % chunk_size)
        return chunks
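    # Worked example with hypothetical values: chunk_sizes(250, 100) returns [100, 100, 50],
    # i.e. two full integration windows followed by a 50-step remainder chunk.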
    def build(self, input_shape):
        self.batch_size = input_shape[0]
        self.timesteps = input_shape[1]
        self.image_shape = input_shape[2:]
        self.image_rank = len(input_shape)
        ### Sizes of the integration-window chunks along the time axis
        self.window_sizes = self.chunk_sizes(input_shape[1], self.integration_window)
        ### Shape invariants for the while loop in call(); every dimension is allowed to vary
        self.tensor_invariance = [None for _ in range(self.image_rank)]
@tf.function
def call(self, inputs):
        ### Padding specification used below to shift tensors by one step along the time axis (axis 1)
        roll_padding = tf.zeros([self.image_rank, 2], dtype=tf.int32)
        roll_padding = tf.tensor_scatter_nd_update(roll_padding, indices=[[1, 0]], updates=[1])
        images_chunks = tf.split(inputs, self.window_sizes, axis=1)  ### Fragment the current sample into chunks whose length matches the integration window
first_chunk = True
        for chunk, n_timesteps in zip(images_chunks, self.window_sizes):
### n_timesteps - the number of timesteps for current chunk of integration window
Spikes_out = tf.zeros([tf.shape(chunk)[0], *self.image_shape, n_timesteps + 1])
### V_m_out - array for storing membrane potential
            V_m_out = tf.zeros_like(chunk)
            ### V_m_temp is initialised to ones so the integration loop below runs at least once
            V_m_temp = tf.ones_like(chunk)
while tf.math.count_nonzero(V_m_temp) != 0:
tf.autograph.experimental.set_loop_options(shape_invariants=[(V_m_temp, tf.TensorShape(self.tensor_invariance))])
### V_m_chunk - cumulative summation (integration) along time dimension
V_m_chunk = tf.math.cumsum(tf.math.multiply(chunk, self.time_constant), axis=1)
                    ### Thresholding chunks: all values below the threshold are zeroed
V_m_temp = tf.nn.relu(V_m_chunk - self.Vm_threshold)
if tf.math.count_nonzero(V_m_temp) == 0: ### if Vm did not cross threshold, break the cycle
break
### Cumsum of the thresholded cumsum - to avoid any future threshold crossings (additional zeroes) that can occur after threshold is hit:
V_m_temp = tf.math.cumsum(V_m_temp, axis=1)
                    ### (V_m_temp == 0) counts the zero values before the trace crosses the threshold, i.e. how many timesteps the integrator needed to fire an output spike
Spikes_out = Spikes_out + tf.one_hot(tf.reduce_sum(tf.cast((V_m_temp == 0), tf.int32), axis=1), depth=n_timesteps + 1)### One hotted zero counts
                    ### Shift the thresholded trace one timestep forward along time (pad a zero at the front,
                    ### drop the last step) so the timestep that crossed the threshold is not counted again:
                    V_m_temp = tf.pad(V_m_temp, paddings=roll_padding, mode="CONSTANT")
                    V_m_temp, _ = tf.split(V_m_temp, [n_timesteps, 1], axis=1)
                    if self.refractory_period != 0:  ### Suppress the input for a number of timesteps after an output spike is fired
                        V_m_temp = tf.roll(chunk, shift=self.refractory_period, axis=1)
                        ### tf.Tensor does not support item assignment, so the leading timesteps that wrap
                        ### around after tf.roll are zeroed with a mask instead:
                        refractory_mask = tf.concat(
                            [tf.zeros_like(V_m_temp[:, :(self.refractory_period - 1)]),
                             tf.ones_like(V_m_temp[:, (self.refractory_period - 1):])], axis=1)
                        V_m_temp = V_m_temp * refractory_mask
                    chunk = tf.where(V_m_temp == 0.0, 0.0, chunk)  ### Removes the input before the firing timestep, so a new V_m can be accumulated for the next spike
            Spikes_out, _ = tf.split(Spikes_out, [n_timesteps, 1], axis=-1)  ### Drop the extra one-hot bin (a zero count of n_timesteps means no spike in this chunk)
            if first_chunk:
                V_m_final = V_m_out
                Spikes_out_final = Spikes_out
                first_chunk = False
            else:
                V_m_final = tf.concat((V_m_final, V_m_out), axis=1)
                Spikes_out_final = tf.concat((Spikes_out_final, Spikes_out), axis=-1)
        ### one_hot places time as the last tensor dimension; move it back to the 2nd position (after the batch dimension), as in the input
        Spikes_out_final = tf.experimental.numpy.moveaxis(Spikes_out_final, source=-1, destination=1)
        if self.amplitude != 1.0:
            Spikes_out_final = Spikes_out_final * self.amplitude
return Spikes_out_final
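# Minimal sketch (not part of the original training code) of the cumsum/threshold/one-hot
# trick used inside Integrator_layer.call, on a hypothetical 1-D current trace:
#   i = tf.constant([[0.4, 0.4, 0.4, 0.4]])                           # constant input current
#   v = tf.math.cumsum(i, axis=1)                                     # [[0.4, 0.8, 1.2, 1.6]]
#   crossed = tf.nn.relu(v - 1.0)                                     # [[0.0, 0.0, 0.2, 0.6]]
#   t_spike = tf.reduce_sum(tf.cast(crossed == 0, tf.int32), axis=1)  # [2] zeros before crossing
#   spike = tf.one_hot(t_spike, depth=5)                              # one-hot spike at timestep 2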
def sparse_data_generator_non_spiking(input_images, input_labels, batch_size=32, nb_steps=100, shuffle=True, flatten=False):
    """ This generator takes a dataset in analog (image) format and generates network input as constant currents:
    each sample is expanded with a time dimension and repeated for nb_steps timesteps (rate-based encoding).
    Args:
    input_images: The analog input data (sample x height x width x channels)
    input_labels: The labels
    """
    data_loader_original = tf.data.Dataset.from_tensor_slices((tf.cast(input_images, tf.float32), input_labels))
    if shuffle:
        data_loader_original = data_loader_original.shuffle(buffer_size=100)
    data_loader = data_loader_original.batch(batch_size=batch_size, drop_remainder=False)
    number_of_batches = len(input_labels) // batch_size
counter = 0
time = tm.time()
for X, y in data_loader:
        if flatten:
            X = tf.reshape(X, [tf.shape(X)[0], -1])  ### Flatten each sample to a vector
        sample_dims = X.shape[1:]
        ### Add a time dimension and repeat the analog frame for nb_steps timesteps (constant-current encoding)
        X = tf.expand_dims(X, axis=1)
        X = tf.repeat(X, repeats=nb_steps, axis=1)
        time_taken = tm.time() - time
        time = tm.time()
        ETA = time_taken * (number_of_batches - counter)
        sys.stdout.write(
            "\rBatch: {0}/{1}, Progress: {2:0.2f}%, Time to process last batch: {3:0.2f} seconds, Estimated time to finish epoch: {4:0.2f} seconds ({5}:{6:02d} min:sec)".format(
                counter, number_of_batches, (counter / number_of_batches) * 100, time_taken, ETA,
                int(ETA // 60), int(ETA % 60)))
        sys.stdout.flush()
counter += 1
yield X, y ### Returns this values after each batch
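# If a tf.data pipeline is preferred, the generator can also be wrapped with
# tf.data.Dataset.from_generator; a hypothetical sketch (train_images / train_labels,
# shapes and dtypes below are assumptions, not defined in this file):
#   ds = tf.data.Dataset.from_generator(
#       lambda: sparse_data_generator_non_spiking(train_images, train_labels, batch_size=32, nb_steps=100),
#       output_signature=(tf.TensorSpec(shape=(None, 100, 28, 28, 1), dtype=tf.float32),
#                         tf.TensorSpec(shape=(None,), dtype=tf.int64)))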
class Reduce_sum(tf.keras.layers.Layer):
    """Sums the spike tensor over the time dimension (axis 1), yielding per-neuron spike counts."""
    def __init__(self, name=None):
        super(Reduce_sum, self).__init__(name=name)
    def call(self, inputs):
        return tf.math.reduce_sum(inputs, axis=1, keepdims=False)
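if __name__ == "__main__":
    # Minimal end-to-end sketch on random data; a hedged usage example, not part of the
    # original training pipeline. Sample count, image size, threshold and window are assumptions.
    n_samples, nb_steps, side = 8, 20, 8
    images = tf.random.uniform([n_samples, side, side, 1], maxval=0.2)  # weak constant input currents
    labels = tf.range(n_samples) % 2                                    # dummy binary labels

    spiking_layer = Integrator_layer(integration_window=10, V_m_threshold=1.0)
    spike_counter = Reduce_sum()

    for X, y in sparse_data_generator_non_spiking(images, labels, batch_size=4, nb_steps=nb_steps):
        spikes = spiking_layer(X)        # (batch, time, height, width, channels) spike trains
        counts = spike_counter(spikes)   # spikes summed over the time dimension
        print("\nspikes:", spikes.shape, "summed over time:", counts.shape)
        break                            # one batch is enough for this sketch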
# """
# class Integrator_layer(tf.keras.layers.Layer):
# def __init__(self, n_steps=100, integration_window=100, time_constant=1.0, leakyness=0.0,
# V_m_threshold = 2.0, refractory_period=0, amplitude=1.0, V_m_min=0, V_cm=2.5, device='cuda', name='I&F'):
# super(Integrator_layer, self).__init__(name=name)
# # self.threshold = nn.Threshold(V_m_threshold, 0)
# # self.zero = torch.tensor(0, dtype=torch.float, device=device)
# self.Vm_threshold = V_m_threshold
# self.integration_window = integration_window
# self.refractory_period = refractory_period
# self.time_constant = time_constant
# # self.epsilon = 0.001
# self.epsilon = tf.keras.backend.epsilon
# self.amplitude = amplitude
# # self.threshold = nn.Threshold(V_m_threshold - self.epsilon, 0) ### Thresholding function
# # self.threshold = tf.nn.relu(V_m_threshold - self.epsilon, 0) ### Thresholding function
# self.V_m_min = V_m_min
# self.device = device
#
# @tf.function
# def chunk_sizes(self, length, chunk_size):
# chunks = [chunk_size for x in range(length//chunk_size)]
# if length % chunk_size != 0:
# chunks.append(length % chunk_size)
# return chunks
#
# def build(self, input_shape):
# self.batch_size = input_shape[0]
# self.timesteps = input_shape[1]
# self.image_shape = input_shape[2:]
# # self.chunk_sizes = self.chunk_sizes(self.timesteps, self.integration_window)
# self.chunk_sizes = self.chunk_sizes(input_shape[1], self.integration_window)
# ###
# ###
# ###
# # self.list_of_indices = [[x, 0] for x in tf.range(input_shape[0])]
# # self.list_of_indices = tf.range(input_shape[0])
#
# @tf.function
# def call(self, inputs):
# ### List of indices - list of indices to replace very first timestep with zero after the roll operation
# list_of_indices = tf.pad(tf.expand_dims(tf.range(tf.shape(inputs)[0]), axis=1),
# paddings=[[0, 0], [0, 1]],
# mode="CONSTANT")
# images_chunks = tf.split(inputs, self.chunk_sizes, axis=1) ### Fragment current sample into multiple chunks with length equal to the integration window
# first_chunk = True
# # zero = torch.tensor(0, dtype=torch.float, device=self.device)
# for chunk, n_timesteps in zip(images_chunks, self.chunk_sizes):
# ### n_timesteps - the number of timesteps for current chunk of integration window
# Spikes_out = tf.zeros([tf.shape(chunk)[0], *self.image_shape, n_timesteps + 1])
# ### V_m_out - array for storing membrane potential
# V_m_out = tf.zeros_like(chunk)
# V_m_temp = tf.zeros_like(chunk)
# while tf.math.count_nonzero(V_m_temp) != 0:
# ### V_m_chunk - cumulative summation (integration) along time dimension
# V_m_chunk = tf.math.cumsum(tf.math.multiply(chunk, self.time_constant), axis=1)
# ### Thresholding chunks, all values bellow threshold value are zeroed
# V_m_temp = tf.nn.relu(V_m_chunk - self.Vm_threshold)
# if tf.math.count_nonzero(V_m_temp) == 0: ### if Vm did not cross threshold, break the cycle
# V_m_out = V_m_out + V_m_chunk
# break
# ### Cumsum of the thresholded cumsum - to avoid any future threshold crossings (additional zeroes) that can occur after threshold is hit:
# V_m_temp = tf.math.cumsum(V_m_temp, axis=1)
# ### V_m_temp == 0 The amount of zero values before function crosses the threshold. Used to calculated how many timesteps it took for an integrator to fire an output spike
# Spikes_out = Spikes_out + tf.one_hot(tf.reduce_sum(tf.cast((V_m_temp == 0), tf.int32), axis=1), depth=n_timesteps + 1)### One hotted zero counts
# ### TF roll operation is used to shift the vector values by 1, other timestep which crossed threshold is not included:
# V_m_temp = tf.roll(V_m_temp, shift=1, axis=1)
# ###Since roll operation will shift the last value to the first place, the first value should be 0'ed for a proper counting of 0 in the next code fragments.
# V_m_temp = tf.tensor_scatter_nd_update(V_m_temp, indices=list_of_indices,
# updates=tf.zeros([1, *self.image_shape]))
# V_m_out = tf.where(V_m_temp == 0, V_m_out + V_m_chunk, 0) ### Resets V_m to 0 after firing
# if self.refractory_period!=0: ### Resets (=0) number of timesteps after output spike is fired
# V_m_temp = tf.roll(chunk, shift=self.refractory_period, axis=1)
# V_m_temp[:, 0:(self.refractory_period-1), :, :, :] = 0
# chunk = tf.where(V_m_temp == 0.0, 0.0, chunk) ### Removes spikes before firing. So new V_m can be calculated for a next spike.
# # Spikes_out = torch.narrow(Spikes_out, dim=-1, start=0, length= n_timesteps)
# Spikes_out, _ = tf.split(Spikes_out, [n_timesteps, 1], axis=-1) ### Onehot operation adds back time dimension to the last place, so it must be popped out
# if first_chunk:
# V_m_final = V_m_out
# Spikes_out_final = Spikes_out
# first_chunk = False
# else:
# V_m_final = tf.concat((V_m_final, V_m_out), axis=1)
# Spikes_out_final = tf.concat((Spikes_out_final, Spikes_out), axis=-1)
#
# # Spikes_out_final = torch.movedim(Spikes_out_final, source=-1, destination=1)
# Spikes_out_final = tf.experimental.numpy.swapaxes(Spikes_out_final, axis1=-1,
# axis2=1) ### Onehotting puts time as the last tensor dimension. 'movedim' moves time dimension to the 2nd place, after the batch number, as it was before.
#
# # return V_m_final, Spikes_out_final
# # print('LIF forward end:')
# # print(f'{datetime.now().time().replace(microsecond=0)} --- ')
# # print(Spikes_out.type())
# if self.amplitude !=1.0:
# Spikes_out_final = Spikes_out_final*self.amplitude
# return Spikes_out_final
# """
#
# def sparse_data_generator_non_spiking(input_images, input_labels, batch_size=32, nb_steps=100, shuffle=True, flatten= False):
# """ This generator takes datasets in analog format and generates network input as constant currents.
# If repeat=True, encoding is rate-based, otherwise it is a latency encoding
# Args:
# X: The data ( sample x event x 2 ) the last dim holds (time,neuron) tuples
# y: The labels
# """
#
# # def argument_free_generator():
# # data_loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=shuffle, num_workers=n_workers)
# data_loader_original = tf.data.Dataset.from_tensor_slices((tf.cast(input_images, tf.float32), input_labels))
# if shuffle:
# data_loader_original = data_loader_original.shuffle(buffer_size=100)
# data_loader = data_loader_original.batch(batch_size=batch_size, drop_remainder=False)
#
# number_of_batches = input_labels.__len__() // batch_size
# counter = 0
# time = tm.time()
#
# for X, y in data_loader:
# if flatten:
# X = X.reshape(X.shape[0], -1)
# # sample_dims = np.array(X.shape[1:], dtype=int)
# sample_dims = X.shape[1:]
# # X = torch.unsqueeze(X, dim=1)
# X = tf.expand_dims(X, axis=1)
# X = tf.repeat(X, repeats=nb_steps, axis=1)
# time_taken = tm.time() - time
# time = tm.time()
# ETA = time_taken * (number_of_batches - counter)
# sys.stdout.write(
# "\rBatch: {0}/{1}, Progress: {2:0.2f}%, Time to process last batch: {3:0.2f} seconds, Estimated time to finish epoch: {4:0.2f} seconds | {5}:{6} minutes".format(
# counter, number_of_batches, (counter / number_of_batches) * 100, time_taken, ETA, int(ETA // 60),
# int(ETA % 60)))
# sys.stdout.flush()
# # X_batch = torch.tensor(X, device=device, dtype=torch.float)
# # yield X.expand(-1, nb_steps, *sample_dims).to(device), y.to(device) ### Returns this values after each batch
# counter += 1
# yield X, y ### Returns this values after each batch
#
# # return argument_free_generator()
#
#
# class Reduce_sum(tf.keras.layers.Layer):
# def __init__(self, name=None):
# super(Reduce_sum, self).__init__(name=name)
#
# def call(self, inputs):
# return tf.math.reduce_sum(inputs, axis=1, keepdims=False)
#
#