# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""A collection of common Haiku modules for use in protein folding."""
import haiku as hk
import jax.numpy as jnp


class Linear(hk.Module):
  """Protein folding specific Linear Module.

  This differs from the standard Haiku Linear in a few ways:
    * It supports inputs of arbitrary rank
    * Initializers are specified by strings
  """

  def __init__(self,
               num_output: int,
               initializer: str = 'linear',
               use_bias: bool = True,
               bias_init: float = 0.,
               name: str = 'linear'):
    """Constructs Linear Module.

    Args:
      num_output: number of output channels.
      initializer: What initializer to use, should be one of {'linear', 'relu',
        'zeros'}
      use_bias: Whether to include trainable bias
      bias_init: Value used to initialize bias.
      name: name of module, used for name scopes.
    """
    super().__init__(name=name)
    self.num_output = num_output
    self.initializer = initializer
    self.use_bias = use_bias
    self.bias_init = bias_init

  def __call__(self, inputs: jnp.ndarray) -> jnp.ndarray:
    """Connects Module.

    Args:
      inputs: Tensor of shape [..., num_channel]

    Returns:
      output of shape [..., num_output]
    """
    n_channels = int(inputs.shape[-1])

    weight_shape = [n_channels, self.num_output]
    if self.initializer == 'linear':
      weight_init = hk.initializers.VarianceScaling(mode='fan_in', scale=1.)
    elif self.initializer == 'relu':
      weight_init = hk.initializers.VarianceScaling(mode='fan_in', scale=2.)
    elif self.initializer == 'zeros':
      weight_init = hk.initializers.Constant(0.0)

    weights = hk.get_parameter('weights', weight_shape, inputs.dtype,
                               weight_init)

    # this is equivalent to einsum('...c,cd->...d', inputs, weights)
    # but turns out to be slightly faster
    inputs = jnp.swapaxes(inputs, -1, -2)
    output = jnp.einsum('...cb,cd->...db', inputs, weights)
    output = jnp.swapaxes(output, -1, -2)

    if self.use_bias:
      bias = hk.get_parameter('bias', [self.num_output], inputs.dtype,
                              hk.initializers.Constant(self.bias_init))
      output += bias

    return output
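
# ---------------------------------------------------------------------------
# Usage sketch (an illustrative addition, not part of the original AlphaFold
# file). It wraps Linear in hk.transform and applies it to a rank-3 input to
# exercise the arbitrary-rank support, then numerically checks the einsum
# equivalence claimed in the comment inside __call__. The function name
# `forward`, the module name 'example_linear', the shapes, and the RNG seed
# are assumptions chosen for this example.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
  import jax

  def forward(x):
    return Linear(num_output=8, initializer='relu', name='example_linear')(x)

  fwd = hk.without_apply_rng(hk.transform(forward))
  x = jnp.ones([4, 7, 16])                     # [..., num_channel], num_channel = 16
  params = fwd.init(jax.random.PRNGKey(0), x)  # parameters under the 'example_linear' scope
  out = fwd.apply(params, x)
  print(out.shape)                             # (4, 7, 8)

  # Sanity check of the comment in __call__: the swapaxes + '...cb,cd->...db'
  # path matches the simpler einsum('...c,cd->...d').
  w = jnp.ones([16, 8])
  direct = jnp.einsum('...c,cd->...d', x, w)
  swapped = jnp.swapaxes(
      jnp.einsum('...cb,cd->...db', jnp.swapaxes(x, -1, -2), w), -1, -2)
  assert jnp.allclose(direct, swapped)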