# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Code for constructing the model."""
from typing import Any, Mapping, Optional, Union

from absl import logging
from alphafold.common import confidence
from alphafold.model import features
from alphafold.model import modules
import haiku as hk
import jax
import ml_collections
import numpy as np
import tensorflow.compat.v1 as tf
import tree


def get_confidence_metrics(
    prediction_result: Mapping[str, Any]) -> Mapping[str, Any]:
  """Post processes prediction_result to get confidence metrics."""

  confidence_metrics = {}
  confidence_metrics['plddt'] = confidence.compute_plddt(
      prediction_result['predicted_lddt']['logits'])
  if 'predicted_aligned_error' in prediction_result:
    confidence_metrics.update(confidence.compute_predicted_aligned_error(
        prediction_result['predicted_aligned_error']['logits'],
        prediction_result['predicted_aligned_error']['breaks']))
    confidence_metrics['ptm'] = confidence.predicted_tm_score(
        prediction_result['predicted_aligned_error']['logits'],
        prediction_result['predicted_aligned_error']['breaks'])

  return confidence_metrics


class RunModel:
  """Container for JAX model."""

  def __init__(self,
               config: ml_collections.ConfigDict,
               params: Optional[Mapping[str, Mapping[str, np.ndarray]]] = None,
               is_training: bool = True,
               return_representations: bool = True):
    self.config = config
    self.params = params

    def _forward_fn(batch):
      model = modules.AlphaFold(self.config.model)
      return model(
          batch,
          is_training=is_training,
          compute_loss=False,
          ensemble_representations=False,
          return_representations=return_representations)

    # Haiku-transform the forward pass, then JIT-compile both the apply
    # function and the parameter initializer.
    self.apply = jax.jit(hk.transform(_forward_fn).apply)
    self.init = jax.jit(hk.transform(_forward_fn).init)

  def init_params(self, feat: features.FeatureDict, random_seed: int = 0):
    """Initializes the model parameters.

    If none were provided when this class was instantiated, the parameters
    are randomly initialized.

    Args:
      feat: A dictionary of NumPy feature arrays as output by
        RunModel.process_features.
      random_seed: A random seed to use to initialize the parameters if none
        were set when this class was initialized.
    """
    if not self.params:
      # Init params randomly.
      rng = jax.random.PRNGKey(random_seed)
      self.params = hk.data_structures.to_mutable_dict(
          self.init(rng, feat))
      logging.warning('Initialized parameters randomly')

  def process_features(
      self,
      raw_features: Union[tf.train.Example, features.FeatureDict],
      random_seed: int) -> features.FeatureDict:
    """Processes features to prepare for feeding them into the model.

    Args:
      raw_features: The output of the data pipeline either as a dict of NumPy
        arrays or as a tf.train.Example.
      random_seed: The random seed to use when processing the features.

    Returns:
      A dict of NumPy feature arrays suitable for feeding into the model.
    """
    if isinstance(raw_features, dict):
      return features.np_example_to_features(
          np_example=raw_features,
          config=self.config,
          random_seed=random_seed)
    else:
      return features.tf_example_to_features(
          tf_example=raw_features,
          config=self.config,
          random_seed=random_seed)

  def eval_shape(self, feat: features.FeatureDict) -> jax.ShapeDtypeStruct:
    self.init_params(feat)
    logging.info('Running eval_shape with shape(feat) = %s',
                 tree.map_structure(lambda x: x.shape, feat))
    shape = jax.eval_shape(self.apply, self.params, jax.random.PRNGKey(0), feat)
    logging.info('Output shape was %s', shape)
    return shape

  def predict(self, feat: features.FeatureDict) -> Mapping[str, Any]:
    """Makes a prediction by running inference on the provided features.

    Args:
      feat: A dictionary of NumPy feature arrays as output by
        RunModel.process_features.

    Returns:
      A dictionary of model outputs.
    """
    self.init_params(feat)
    logging.info('Running predict with shape(feat) = %s',
                 tree.map_structure(lambda x: x.shape, feat))
    result = self.apply(self.params, jax.random.PRNGKey(0), feat)
    # This block is to ensure benchmark timings are accurate. Some blocking is
    # already happening when computing get_confidence_metrics, and this ensures
    # all outputs are blocked on.
    jax.tree_map(lambda x: x.block_until_ready(), result)
    if self.config.use_struct:
      result.update(get_confidence_metrics(result))
    logging.info('Output shape was %s',
                 tree.map_structure(lambda x: x.shape, result))
    return result
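

# A minimal usage sketch, assuming the standard AlphaFold package layout where
# `alphafold.model.config.model_config` and
# `alphafold.model.data.get_model_haiku_params` are available, downloaded
# parameters live under DATA_DIR, and `features.pkl` was written by the data
# pipeline. The model name, paths, and pickle file are placeholders, and the
# config is assumed to carry any extra keys this module expects
# (e.g. `use_struct`); adapt everything to your setup.
if __name__ == '__main__':
  import pickle

  from alphafold.model import config as af_config
  from alphafold.model import data as af_data

  DATA_DIR = '/path/to/alphafold/data'  # Placeholder; must contain params/.

  model_config = af_config.model_config('model_1')
  model_params = af_data.get_model_haiku_params(
      model_name='model_1', data_dir=DATA_DIR)
  model_runner = RunModel(model_config, model_params)

  # Load raw features produced by the data pipeline, then process and predict.
  with open('features.pkl', 'rb') as f:
    raw_features = pickle.load(f)
  processed_features = model_runner.process_features(
      raw_features, random_seed=0)
  prediction = model_runner.predict(processed_features)

  # Print the shape of every (possibly nested) output array.
  print(tree.map_structure(lambda x: x.shape, prediction))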