|
import numpy as np |
|
|
|
from constants import UNKNOWN_COLOR, DEFAULT_COLOR, UNKNOWN_COLOR_RGB, DEFAULT_COLOR_RGB |
|
|
|
def compute_P(alleles):

    '''Compute the population matrix P(allele|gene) from the [alleles] given in input.

    [alleles] is an iterable with one entry per gene; each entry is the
    sequence of alleles observed for that gene.

    Returns a list with one dict per gene, mapping each distinct allele (as
    returned by np.unique) to its relative frequency within that gene.
    '''

    def _gene_frequencies(gene_alleles):
        # np.unique gives the distinct values and how often each occurs.
        values, occurrences = np.unique(gene_alleles, return_counts=True)
        fractions = occurrences / len(gene_alleles)
        return {value: fraction for value, fraction in zip(values, fractions)}

    return [_gene_frequencies(gene_alleles) for gene_alleles in alleles]
|
|
|
def compute_all_P(data, models):

    '''Compute all population matrices from a given list of [models] on the data.

    [data] is indexed positionally: data[i] holds the per-gene alleles of
    models[i]. Returns a dict mapping each model to compute_P(data[i]).
    '''

    return {model: compute_P(data[position]) for position, model in enumerate(models)}
|
|
|
def compute_sim_matrix(models, all_P):

    '''Compute the entire similarity matrix in one go.

    For each pair of models (i, j) the similarity is

        sum_k <p_ik, p_jk> / sqrt( sum_k ||p_ik||^2 * sum_k ||p_jk||^2 )

    where p_ik is the allele-frequency vector of model i at gene k, laid out
    over the union of alleles that any model shows at gene k.

    [models] : sequence of keys into [all_P]; fixes the row/column order.
    [all_P]  : mapping model -> list (one entry per gene) of {allele: frequency}
               dicts, e.g. as built by compute_all_P. The number of genes is
               taken from the first model.

    Returns an (n_models, n_models) float array. Entries involving a model
    whose frequency vectors are all zero are NaN (0/0). An empty [models]
    yields a (0, 0) array.
    '''

    n_models = len(models)
    if n_models == 0:
        return np.zeros((0, 0))

    n_genes = len(all_P[models[0]])

    numerator = np.zeros((n_models, n_models))
    # Squared norm of each model's frequency vector, accumulated over genes.
    # The same vector serves both sides of the denominator's outer product.
    sq_norms = np.zeros(n_models)

    for k in range(n_genes):
        gene_P = [all_P[m][k] for m in models]

        # Union of alleles seen at this gene by any model (dicts iterate keys).
        all_alleles = list(set().union(*gene_P))

        # Row i = model i's frequencies over all_alleles; absent alleles are 0.
        freq_matrix = np.array(
            [[p.get(allele, 0.0) for allele in all_alleles] for p in gene_P],
            dtype=float,
        ).reshape(n_models, len(all_alleles))

        numerator += freq_matrix @ freq_matrix.T
        sq_norms += np.sum(freq_matrix ** 2, axis=1)

    denominator = np.sqrt(np.outer(sq_norms, sq_norms))

    # A zero denominator only arises when a model's vector is all zeros, so the
    # numerator is 0 too: yield NaN explicitly instead of a 0/0 RuntimeWarning.
    return np.divide(
        numerator,
        denominator,
        out=np.full_like(numerator, np.nan),
        where=denominator > 0,
    )
|
|
|
def prepare_tree(tree, model_names, origins, colors):

    """Prepare and color the phylogenetic tree based on model families.

    Sets three attributes on every clade yielded by ``tree.find_clades``:

    - ``family``: for a clade whose name is in [model_names], the element of
      [origins] at the position of that name (first occurrence). Any other
      clade derives its family from its direct children:
        * all children share one family        -> that family;
        * exactly two families, one being '?'  -> the other family;
        * otherwise (no children, or mixed)    -> None.
    - ``flag``: True once the clade's family is decided (True for all clades
      on return).
    - ``color``: ``colors[family]``, or UNKNOWN_COLOR when family is None.
      [colors] must therefore contain every non-None family that occurs.
    """

    # Name -> origin lookup; setdefault keeps the first occurrence, matching
    # the list.index semantics used for positional lookup.
    origin_of = {}
    for name, origin in zip(model_names, origins):
        origin_of.setdefault(name, origin)

    # Post-order guarantees every child is resolved before its parent is
    # examined, so a single pass suffices (and always terminates).
    all_clades = list(tree.find_clades(order='postorder'))

    for clade in all_clades:
        if clade.name is not None and clade.name in origin_of:
            clade.family = origin_of[clade.name]
        else:
            families = {child.family for child in clade.clades}
            if len(families) == 1:
                clade.family = families.pop()
            elif len(families) == 2 and '?' in families:
                # A '?' child does not prevent inheriting its sibling's family.
                clade.family = (families - {'?'}).pop()
            else:
                # Childless clade not listed in model_names, or mixed families.
                clade.family = None
        clade.flag = True

    for clade in all_clades:
        if clade.family is None:
            clade.color = UNKNOWN_COLOR
        else:
            clade.color = colors[clade.family]
|
|