import numpy as np

from constants import UNKNOWN_COLOR, DEFAULT_COLOR, UNKNOWN_COLOR_RGB, DEFAULT_COLOR_RGB


def compute_P(alleles):
    """Compute the population matrix P(allele|gene) from the [alleles] given in input.

    Parameters
    ----------
    alleles : iterable
        One sequence of observed alleles per gene position.

    Returns
    -------
    list of dict
        For each gene position, a dict mapping every distinct allele observed
        there to its relative frequency (count / number of observations).
    """
    P = []
    for gene_alleles in alleles:
        # np.unique returns the distinct alleles (sorted) together with their counts.
        unique_alleles, counts = np.unique(gene_alleles, return_counts=True)
        P.append(dict(zip(unique_alleles, counts / len(gene_alleles))))
    return P


def compute_all_P(data, models):
    """Compute all population matrices from a given list of [models] on the data.

    ``data[i]`` holds the alleles of ``models[i]``; the result maps each model
    to ``compute_P(data[i])``.
    """
    return {m: compute_P(data[mi]) for mi, m in enumerate(models)}


def compute_sim_matrix(models, all_P):
    """Compute the entire similarity matrix in one go.

    At every gene position k the allele-frequency dicts of all models are laid
    out as vectors p_a^k over the union of alleles seen at k. The result is

        sim[a, b] = sum_k <p_a^k, p_b^k>
                    / sqrt(sum_k |p_a^k|^2 * sum_k |p_b^k|^2)

    Parameters
    ----------
    models : sequence
        Model identifiers; fixes the row/column order of the result.
    all_P : dict
        model -> list of per-gene frequency dicts (as built by compute_all_P).
        The number of gene positions is taken from ``models[0]``; every other
        model is assumed to have at least that many.

    Returns
    -------
    numpy.ndarray of shape (len(models), len(models))
        Entries involving a model whose frequency vectors are all zero are NaN
        (0 / 0). An empty ``models`` gives a (0, 0) array.
    """
    n_models = len(models)
    if n_models == 0:
        return np.zeros((0, 0))
    n_genes = len(all_P[models[0]])

    numerator = np.zeros((n_models, n_models))
    # Squared norm of each model's frequency vectors, summed over genes. The
    # normalisation is symmetric, so one accumulator serves both sides.
    sq_norms = np.zeros(n_models)

    for k in range(n_genes):
        gene_P = [all_P[m][k] for m in models]

        # Assign a column to every allele seen at this position in any model.
        # Dict insertion order is deterministic, whereas iterating a set of
        # strings depends on the hash seed, which would change the
        # floating-point summation order from one run to the next.
        column = {}
        for freqs in gene_P:
            for allele in freqs:
                column.setdefault(allele, len(column))

        freq_matrix = np.zeros((n_models, len(column)))
        for i, freqs in enumerate(gene_P):
            for allele, freq in freqs.items():
                freq_matrix[i, column[allele]] = freq

        numerator += np.dot(freq_matrix, freq_matrix.T)
        sq_norms += np.sum(freq_matrix ** 2, axis=1)

    return numerator / np.sqrt(np.outer(sq_norms, sq_norms))


def _merge_child_families(families):
    """Return the family a clade inherits from its (already decided) children.

    A single shared family is inherited as is; exactly one family next to the
    unknown marker '?' yields that family; anything else (several distinct
    families, or no children at all) yields None, i.e. no colour.
    """
    distinct = set(families)
    if len(distinct) == 1:
        return families[0]
    if len(distinct) == 2 and '?' in distinct:
        return next(f for f in distinct if f != '?')
    return None


def _propagate_families(all_clades):
    """Decide ``family``/``flag`` for every undecided clade, children first.

    Iterative (explicit stack), so deep trees do not hit Python's recursion
    limit, and independent of the order in which ``all_clades`` is listed.
    """
    for start in all_clades:
        stack = [start]
        while stack:
            clade = stack[-1]
            if clade.flag:
                stack.pop()
                continue
            pending = [child for child in clade.clades if not child.flag]
            if pending:
                # Decide the children first; this clade is revisited afterwards.
                stack.extend(pending)
            else:
                clade.family = _merge_child_families([c.family for c in clade.clades])
                clade.flag = True
                stack.pop()


def prepare_tree(tree, model_names, origins, colors):
    """Prepare and color the phylogenetic tree based on model families.

    Sets, on every clade returned by ``tree.find_clades()``:

    * ``family``: the clade's family, or None when it has none;
    * ``flag``: True once the family is decided (True for all clades on return);
    * ``color``: ``colors[family]``, or UNKNOWN_COLOR when family is None.

    A clade whose name appears in ``model_names`` takes the family stored at
    the same position in ``origins`` (first occurrence wins, as with
    ``list.index``). Every other clade is decided from its children, bottom-up:
    a single shared family is inherited, one family next to '?' yields that
    family, and anything else leaves the clade uncoloured (family None).

    The tree is modified in place; nothing is returned.

    Parameters
    ----------
    tree
        Tree exposing ``find_clades()``, whose clades carry ``name`` and
        ``clades`` (presumably a Bio.Phylo tree; TODO confirm).
    model_names : list
        Names of the leaves to colour.
    origins : sequence
        ``origins[i]`` is the family of ``model_names[i]``.
    colors : mapping
        Family -> colour, looked up for every clade whose family is not None.
    """
    # Map each name to its first position, i.e. what list.index would return.
    first_index = {}
    for idx, name in enumerate(model_names):
        first_index.setdefault(name, idx)

    all_clades = list(tree.find_clades())

    # Seed: clades named after a model are decided; everything else is not yet.
    for clade in all_clades:
        if clade.name is not None and clade.name in first_index:
            clade.family = origins[first_index[clade.name]]
            clade.flag = True
        else:
            clade.family = None
            clade.flag = False

    # A parent may only be decided once all of its children are: deciding it
    # from still-undecided children would lock it to None prematurely.
    _propagate_families(all_clades)

    for clade in all_clades:
        clade.color = UNKNOWN_COLOR if clade.family is None else colors[clade.family]