Spaces:
Sleeping
Sleeping
| # Copyright 2021 DeepMind Technologies Limited | |
| # | |
| # Licensed under the Apache License, Version 2.0 (the "License"); | |
| # you may not use this file except in compliance with the License. | |
| # You may obtain a copy of the License at | |
| # | |
| # http://www.apache.org/licenses/LICENSE-2.0 | |
| # | |
| # Unless required by applicable law or agreed to in writing, software | |
| # distributed under the License is distributed on an "AS IS" BASIS, | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| # See the License for the specific language governing permissions and | |
| # limitations under the License. | |
| """lDDT protein distance score.""" | |
| import jax.numpy as jnp | |
def lddt(predicted_points,
         true_points,
         true_points_mask,
         cutoff=15.,
         per_residue=False):
  """Measure (approximate) lDDT for a batch of coordinates.

  lDDT reference:
  Mariani, V., Biasini, M., Barbato, A. & Schwede, T. lDDT: A local
  superposition-free score for comparing protein structures and models using
  distance difference tests. Bioinformatics 29, 2722–2728 (2013).

  lDDT is a measure of the difference between the true distance matrix and the
  distance matrix of the predicted points. The difference is computed only on
  points closer than cutoff *in the true structure*.

  This function does not compute the exact lDDT value that the original paper
  describes because it does not include terms for physical feasibility
  (e.g. bond length violations). Therefore this is only an approximate
  lDDT score.

  Args:
    predicted_points: (batch, length, 3) array of predicted 3D points.
    true_points: (batch, length, 3) array of true 3D points.
    true_points_mask: (batch, length, 1) binary-valued float array. This mask
      should be 1 for points that exist in the true points.
    cutoff: Maximum distance for a pair of points to be included. Pairs are
      selected by their distance in the *true* structure (strictly less than
      cutoff).
    per_residue: If true, return score for each residue. Note that the overall
      lDDT is not exactly the mean of the per_residue lDDT's because some
      residues have more contacts than others.

  Returns:
    An (approximate, see above) lDDT score in the range 0-1. The shape is
    (batch,) by default, or (batch, length) when per_residue is true. A
    residue (or a whole structure) with no scored pairs receives a score of
    ~1.0, because a small epsilon is added to both the numerator and the
    denominator of the normalization.
  """
  assert len(predicted_points.shape) == 3
  assert predicted_points.shape[-1] == 3
  assert true_points_mask.shape[-1] == 1
  assert len(true_points_mask.shape) == 3

  # Compute true and predicted distance matrices, shape (batch, length, length).
  # The 1e-10 keeps sqrt away from exactly zero (e.g. on the diagonal), where
  # its derivative is unbounded.
  dmat_true = jnp.sqrt(1e-10 + jnp.sum(
      (true_points[:, :, None] - true_points[:, None, :])**2, axis=-1))

  dmat_predicted = jnp.sqrt(1e-10 + jnp.sum(
      (predicted_points[:, :, None] -
       predicted_points[:, None, :])**2, axis=-1))

  # Weight of 1 for every pair (i, j) that is scored: closer than cutoff in
  # the true structure, both points present per the mask, and i != j.
  dists_to_score = (
      (dmat_true < cutoff).astype(jnp.float32) * true_points_mask *
      jnp.transpose(true_points_mask, [0, 2, 1]) *
      (1. - jnp.eye(dmat_true.shape[1]))  # Exclude self-interaction.
  )

  # Absolute error of every pairwise distance. Unscored pairs are not altered
  # here; they are removed by the dists_to_score weights during normalization.
  dist_l1 = jnp.abs(dmat_true - dmat_predicted)

  # True lDDT uses a number of fixed bins.
  # We ignore the physical plausibility correction to lDDT, though.
  score = 0.25 * ((dist_l1 < 0.5).astype(jnp.float32) +
                  (dist_l1 < 1.0).astype(jnp.float32) +
                  (dist_l1 < 2.0).astype(jnp.float32) +
                  (dist_l1 < 4.0).astype(jnp.float32))

  # Normalize over the appropriate axes: per row i for per-residue scores,
  # over all (i, j) pairs otherwise. The epsilon in the numerator makes an
  # empty selection evaluate to ~1.0 rather than 0.
  reduce_axes = (-1,) if per_residue else (-2, -1)
  norm = 1. / (1e-10 + jnp.sum(dists_to_score, axis=reduce_axes))
  score = norm * (1e-10 + jnp.sum(dists_to_score * score, axis=reduce_axes))

  return score