PhyloLM / phylogeny.py
Daetheys's picture
First version gradio
3d6ba31
import numpy as np
from constants import UNKNOWN_COLOR, DEFAULT_COLOR, UNKNOWN_COLOR_RGB, DEFAULT_COLOR_RGB
def compute_P(alleles):
    """Build the population matrix P(allele | gene) from observed alleles.

    ``alleles`` is iterated gene by gene; each item is the collection of
    alleles observed for that gene. The result is a list with one dict per
    gene, mapping every distinct allele (as returned by ``np.unique``) to its
    relative frequency among the observations of that gene.
    """
    population = []
    for observed in alleles:
        # np.unique yields the distinct values together with their counts.
        values, occurrences = np.unique(observed, return_counts=True)
        frequencies = occurrences / len(observed)
        population.append(
            {value: frequency for value, frequency in zip(values, frequencies)}
        )
    return population
def compute_all_P(data, models):
    """Return ``{model: population matrix}`` for every entry of ``models``.

    The alleles of the model at position ``i`` in ``models`` are read from
    ``data[i]`` and turned into a population matrix with ``compute_P``.
    """
    return {
        model: compute_P(data[position])
        for position, model in enumerate(models)
    }
def compute_sim_matrix(models, all_P):
    """Compute the pairwise similarity matrix between all ``models``.

    For models ``a`` and ``b`` the similarity is

        sum_k <P_a[k], P_b[k]> / sqrt(sum_k |P_a[k]|^2 * sum_k |P_b[k]|^2)

    where ``P_m[k]`` is the allele-frequency vector of model ``m`` for gene
    ``k`` (alleles missing from a model count as frequency 0).

    Args:
        models: sequence of model identifiers, used as keys into ``all_P``.
        all_P: mapping from model identifier to its population matrix, i.e. a
            list with one ``{allele: frequency}`` dict per gene. The number of
            genes is taken from the first model.

    Returns:
        A ``(len(models), len(models))`` NumPy array. An empty ``models``
        yields an empty ``(0, 0)`` array. A model whose frequency tables are
        all empty has a zero norm, so its entries are the result of ``0 / 0``.
    """
    n_models = len(models)
    if n_models == 0:
        # Nothing to compare; avoids indexing models[0] below.
        return np.zeros((0, 0))

    n_genes = len(all_P[models[0]])

    total_numerator = np.zeros((n_models, n_models))
    # Numerator and denominator are symmetric, so one vector of squared
    # norms is enough for both sides of the outer product.
    squared_norms = np.zeros(n_models)

    for k in range(n_genes):
        gene_tables = [all_P[m][k] for m in models]

        # Assign one column per distinct allele seen at this gene.
        column_of = {}
        for table in gene_tables:
            for allele in table:
                column_of.setdefault(allele, len(column_of))

        # Fill each row from the model's own entries only; absent alleles
        # keep their zero frequency.
        freq_matrix = np.zeros((n_models, len(column_of)))
        for row, table in enumerate(gene_tables):
            for allele, frequency in table.items():
                freq_matrix[row, column_of[allele]] = frequency

        # Numerator: dot products between every pair of frequency vectors.
        total_numerator += np.dot(freq_matrix, freq_matrix.T)
        # Denominator terms: squared norm of each model's frequency vector.
        squared_norms += np.sum(freq_matrix**2, axis=1)

    denominator_matrix = np.sqrt(np.outer(squared_norms, squared_norms))
    return total_numerator / denominator_matrix
def _family_from_children(children_families):
    """Return the family a clade inherits from its children's families, or None."""
    families = set(children_families)
    if len(families) == 1:
        # All children agree: inherit their family.
        return next(iter(families))
    if len(families) == 2 and '?' in families:
        # One known family mixed with unknown ('?') children: keep the known one.
        return next(family for family in families if family != '?')
    # No children, or children of different families: no family.
    return None


def prepare_tree(tree, model_names, origins, colors):
    """Prepare and color the phylogenetic tree based on model families.

    Every clade of ``tree`` receives three attributes, set in place:
    ``family``, ``flag`` (True once the clade is resolved) and ``color``.

    Clades whose name appears in ``model_names`` take the family stored at the
    same position in ``origins``. Every other clade derives its family from
    its children, which are always resolved first: the shared family when all
    children agree, the known family when the only other one is ``'?'``, and
    no family otherwise (including unnamed/unknown clades without children).

    Args:
        tree: object exposing ``find_clades()``; each yielded clade must have
            ``name`` and ``clades`` attributes (presumably a Bio.Phylo tree;
            confirm against the caller).
        model_names: names of the models, matched against ``clade.name``.
        origins: family of each model, aligned with ``model_names``.
        colors: mapping from family to the color assigned to ``clade.color``.
            Clades without a family get ``UNKNOWN_COLOR``.

    Raises:
        KeyError: if a resolved family is missing from ``colors``.
    """
    # First occurrence wins, which matches list.index() on duplicate names.
    index_of = {}
    for position, name in enumerate(model_names):
        index_of.setdefault(name, position)

    all_clades = list(tree.find_clades())

    # Clades that correspond to a known model are resolved immediately.
    for clade in all_clades:
        position = None if clade.name is None else index_of.get(clade.name)
        if position is None:
            clade.family = None
            clade.flag = False
        else:
            clade.family = origins[position]
            clade.flag = True

    # Resolve the remaining clades bottom-up. An explicit stack is used so
    # that a clade is only decided once all of its children are decided,
    # whatever order find_clades() produced, and without recursion limits.
    for start in all_clades:
        stack = [start]
        while stack:
            clade = stack[-1]
            pending = [child for child in clade.clades if not child.flag]
            if pending:
                stack.extend(pending)
                continue
            stack.pop()
            if clade.flag:
                continue
            clade.family = _family_from_children(
                child.family for child in clade.clades
            )
            clade.flag = True

    # Set the color associated with its family on each clade.
    for clade in all_clades:
        if clade.family is None:
            clade.color = UNKNOWN_COLOR
        else:
            clade.color = colors[clade.family]