Neural Cellular Automata (based on https://distill.pub/2020/growing-ca/) implemented in JAX (Flax)
Installation
from source
git clone git@github.com:shyamsn97/jax-nca.git
cd jax-nca
python setup.py install
from PyPI
pip install jax-nca
How do NCAs work?
For more information, see the article https://distill.pub/2020/growing-ca/ (Mordvintsev et al., "Growing Neural Cellular Automata", Distill, 2020).
The following figure describes a single update step: https://github.com/distillpub/post--growing-ca/blob/master/public/figures/model.svg
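As a rough illustration of the perception stage of that update step, the sketch below applies an identity filter and two Sobel filters to every channel of the state, following the description in the Distill article. It is not this library's implementation: the function name `perceive`, the NHWC layout, and the 1/8 normalization of the Sobel filters are assumptions made for the example.

```python
import jax
import jax.numpy as jnp


def perceive(state):
    """Hypothetical perception step: (N, H, W, C) -> (N, H, W, 3 * C)."""
    num_channels = state.shape[-1]

    identity = jnp.array([[0.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 0.0]])
    sobel_x = jnp.array([[-1.0, 0.0, 1.0], [-2.0, 0.0, 2.0], [-1.0, 0.0, 1.0]]) / 8.0
    sobel_y = sobel_x.T

    # Stack the three 3x3 filters, then repeat them once per input channel.
    filters = jnp.stack([identity, sobel_x, sobel_y], axis=-1)  # (3, 3, 3)
    kernel = jnp.tile(filters[:, :, None, :], (1, 1, 1, num_channels))  # (3, 3, 1, 3 * C)

    # Depthwise convolution: each input channel gets its own group of three filters.
    return jax.lax.conv_general_dilated(
        state,
        kernel,
        window_strides=(1, 1),
        padding="SAME",
        dimension_numbers=("NHWC", "HWIO", "NHWC"),
        feature_group_count=num_channels,
    )


state = jax.random.normal(jax.random.PRNGKey(0), (1, 8, 8, 16))
perception = perceive(state)  # shape (1, 8, 8, 48)
```

In the article, each cell's perception vector is then passed through a small per-cell network that produces the state update.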
Why Jax?
Note: This project served as a nice introduction to JAX, so its performance can probably be improved.
NCAs are autoregressive models like RNNs, where new states are calculated from previous ones. With JAX, we can make these operations much more performant using jax.lax.scan (https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.scan.html) and jax.jit.
Instead of writing the NCA growth process as a Python loop:
def multi_step(params, nca, current_state, num_steps):
    # params: parameters for NCA
    # nca: Flax Module describing NCA
    # current_state: Current NCA state
    # num_steps: number of steps to run
    for i in range(num_steps):
        current_state = nca.apply({"params": params}, current_state)
    return current_state
we can write it with jax.lax.scan:
import jax

def multi_step(params, nca, current_state, num_steps):
    # params: parameters for NCA
    # nca: Flax Module describing NCA
    # current_state: Current NCA state
    # num_steps: number of steps to run
    def forward(carry, inp):
        # carry holds the NCA state; inp is unused because no per-step inputs are scanned over
        carry = nca.apply({"params": params}, carry)
        return carry, carry
    # nca_states stacks the state after every step along a new leading axis
    final_state, nca_states = jax.lax.scan(forward, current_state, None, length=num_steps)
    return final_state
The actual multi_step implementation can be found here: https://github.com/shyamsn97/jax-nca/blob/main/jax_nca/nca.py#L103
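To check that the loop and scan formulations agree, here is a self-contained toy version. The `step` function is a hypothetical stand-in for `nca.apply` (it is not part of jax-nca), so the snippet runs without a model; `loop_steps` and `scan_steps` are likewise names invented for this example.

```python
from functools import partial

import jax
import jax.numpy as jnp


def step(state):
    # Toy stand-in for one NCA update: blend each cell with its neighbour.
    return 0.9 * state + 0.1 * jnp.roll(state, 1, axis=0)


def loop_steps(state, num_steps):
    # Plain Python loop: under jit this would be unrolled num_steps times.
    for _ in range(num_steps):
        state = step(state)
    return state


@partial(jax.jit, static_argnames="num_steps")
def scan_steps(state, num_steps):
    def forward(carry, _):
        carry = step(carry)
        return carry, carry  # (next carry, per-step output)

    # Returns the final state and the stacked history of all intermediate states.
    return jax.lax.scan(forward, state, None, length=num_steps)


state = jnp.arange(8.0)
final_loop = loop_steps(state, 16)
final_scan, history = scan_steps(state, num_steps=16)
```

Because `num_steps` sets the length of the scan, it is marked static here, so calling `scan_steps` with a different step count triggers a recompile.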
Usage
See notebooks/Gecko.ipynb for a full example
There is currently a bug in the stochastic update, so only cell_fire_rate = 1.0 works.
Creating and using NCA
import flax.linen as nn

class NCA(nn.Module):
    """
    num_hidden_channels: Number of hidden channels for each cell to use
    num_target_channels: Number of target channels to be used
    alpha_living_threshold: threshold to determine whether a cell lives or dies
    cell_fire_rate: probability that a cell receives an update per step
    trainable_perception: if true, instead of using sobel filters use a trainable conv net
    alpha: scalar value to be multiplied to updates
    """
    num_hidden_channels: int
    num_target_channels: int = 3
    alpha_living_threshold: float = 0.1
    cell_fire_rate: float = 1.0
    trainable_perception: bool = False
    alpha: float = 1.0

    ...
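The roles of cell_fire_rate and alpha_living_threshold can be sketched as follows, using the scheme from the Distill article. This is a minimal sketch and not the library's code: the helper names `alive_mask` and `masked_update`, the position of the alpha channel (index 3, directly after three target channels), and the NHWC layout are all assumptions.

```python
import jax
import jax.numpy as jnp


def alive_mask(state, alpha_index=3, threshold=0.1):
    # A cell counts as alive if any cell in its 3x3 neighbourhood has alpha > threshold.
    alpha = state[..., alpha_index:alpha_index + 1]  # (N, H, W, 1)
    pooled = jax.lax.reduce_window(
        alpha, -jnp.inf, jax.lax.max, (1, 3, 3, 1), (1, 1, 1, 1), "SAME"
    )
    return pooled > threshold


def masked_update(key, state, delta, cell_fire_rate=1.0):
    pre_alive = alive_mask(state)
    # Each cell independently receives the update with probability cell_fire_rate.
    fire = jax.random.bernoulli(key, cell_fire_rate, state.shape[:-1] + (1,))
    new_state = state + delta * fire.astype(state.dtype)
    post_alive = alive_mask(new_state)
    # Zero out every cell that was not alive both before and after the update.
    return new_state * (pre_alive & post_alive).astype(state.dtype)


# A single seed cell in the middle of a 5x5 grid with 16 channels.
seed = jnp.zeros((1, 5, 5, 16)).at[0, 2, 2, 3].set(1.0)
delta = jnp.full(seed.shape, 0.5)
updated = masked_update(jax.random.PRNGKey(0), seed, delta, cell_fire_rate=1.0)
```

With a single seed cell, only its 3x3 neighbourhood is alive before the update, so every cell outside that neighbourhood stays at zero regardless of `delta`.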
import jax
from jax_nca.nca import NCA

# usage
nca = NCA(
    num_hidden_channels = 16,
    num_target_channels = 3,
    trainable_perception = False,
    cell_fire_rate = 1.0,
    alpha_living_threshold = 0.1
)
nca_seed = nca.create_seed(
    nca.num_hidden_channels, nca.num_target_channels, shape=(64,64), batch_size=1
)
rng = jax.random.PRNGKey(0)
params = nca.init(rng, nca_seed, rng)["params"]
# single step
update = nca.apply({"params":params}, nca_seed, jax.random.PRNGKey(10))
# multi step
final_state, nca_states = nca.multi_step(params, nca_seed, jax.random.PRNGKey(10), num_steps=32)
To train the NCA
from jax_nca.dataset import ImageDataset
from jax_nca.nca import NCA
from jax_nca.trainer import EmojiTrainer

dataset = ImageDataset(emoji='🦎', img_size=64)
nca = NCA(
    num_hidden_channels = 16,
    num_target_channels = 3,
    trainable_perception = False,
    cell_fire_rate = 1.0,
    alpha_living_threshold = 0.1
)
trainer = EmojiTrainer(dataset, nca, n_damage=0)
trainer.train(100000, batch_size=8, seed=10, lr=2e-4, min_steps=64, max_steps=96)

# to access train state:
state = trainer.state

# save
nca.save(state.params, "saved_params")

# load params
loaded_params = nca.load("saved_params")