File size: 4,627 Bytes
3d6ba31
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
import numpy as np

from constants import UNKNOWN_COLOR, DEFAULT_COLOR, UNKNOWN_COLOR_RGB, DEFAULT_COLOR_RGB

def compute_P(alleles):
    '''Build the population matrix P(allele|gene) from the [alleles] given in input.

    [alleles] is iterated once; each item is the collection of alleles
    observed at one gene position. The result is a list with one dict per
    gene position, mapping every distinct allele (as returned by
    np.unique) to its relative frequency, i.e. its count divided by the
    number of observations at that position.
    '''
    def gene_frequencies(observed):
        # np.unique yields the distinct values together with their counts.
        values, occurrences = np.unique(observed, return_counts=True)
        # Normalise the counts by the number of observations for this gene.
        return dict(zip(values, occurrences / len(observed)))

    return [gene_frequencies(observed) for observed in alleles]

def compute_all_P(data, models):
    '''Build one population matrix per entry of [models] from the data.

    The model at position i in [models] is paired with data[i]; the
    returned dict maps each model to compute_P(data[i]). If a model
    appears more than once, the last occurrence wins.
    '''
    return {
        model: compute_P(data[position])
        for position, model in enumerate(models)
    }

def compute_sim_matrix(models, all_P):
    '''Compute the entire model-by-model similarity matrix in one go.

    For every gene position, each model's allele frequencies are laid out
    as a vector over the union of alleles seen by any model. Entry (i, j)
    of the result is the sum over gene positions of the dot products of
    the vectors of models i and j, divided by the square root of the
    product of the two models' summed squared frequencies.

    Parameters:
        models: sequence of model keys; its order defines the row and
            column order of the result.
        all_P: mapping from model key to a list of per-gene dicts
            {allele: frequency}. The number of gene positions is taken
            from the first model.

    Returns:
        A NumPy array of shape (len(models), len(models)). An empty
        [models] gives an empty (0, 0) array. A model whose frequencies
        are all zero (or zero gene positions) leads to a 0/0 division and
        therefore NaN entries.
    '''
    n_models = len(models)
    if n_models == 0:
        # Nothing to compare; avoid indexing models[0] below.
        return np.zeros((0, 0))
    n_genes = len(all_P[models[0]])

    # Running sums of the cross products and of each model's squared norm.
    total_numerator = np.zeros((n_models, n_models))
    squared_norms = np.zeros(n_models)

    for k in range(n_genes):
        gene_tables = [all_P[m][k] for m in models]

        # Give each distinct allele seen at this gene position its own column.
        column_of = {}
        for table in gene_tables:
            for allele in table:
                column_of.setdefault(allele, len(column_of))

        # Fill only the alleles each model actually has; the rest stay 0.
        freq_matrix = np.zeros((n_models, len(column_of)))
        for i, table in enumerate(gene_tables):
            for allele, frequency in table.items():
                freq_matrix[i, column_of[allele]] = frequency

        # Pairwise dot products between the models' frequency vectors.
        total_numerator += np.dot(freq_matrix, freq_matrix.T)
        # Per-model sum of squared frequencies (used for normalisation).
        squared_norms += np.sum(freq_matrix**2, axis=1)

    # The same per-model norm is used on both sides of the normalisation.
    denominator_matrix = np.sqrt(np.outer(squared_norms, squared_norms))
    return total_numerator / denominator_matrix

def prepare_tree(tree, model_names, origins, colors):
    """Prepare and color the phylogenetic tree based on model families.

    The tree is modified in place; nothing is returned. Every clade
    yielded by tree.find_clades() receives three attributes:

    - family: for a clade whose name is in [model_names], the entry of
      [origins] at the same position. For any other clade, the family
      shared by all of its children, or None when the children disagree
      or there are no children. A child whose family is '?' does not
      count as a disagreement when exactly one other family is present.
    - flag: True once the clade's family has been decided.
    - color: colors[family], or UNKNOWN_COLOR when family is None.

    Parameters:
        tree: object exposing find_clades(); each yielded clade must
            expose `name` and `clades` (its children).
        model_names: sequence of clade names that have a known origin.
        origins: sequence parallel to [model_names] giving each family.
        colors: mapping from family to the color to assign.
    """
    # Map each name to its first position, mirroring list.index() when a
    # name is duplicated, while avoiding a linear scan per clade.
    index_by_name = {}
    for position, name in enumerate(model_names):
        index_by_name.setdefault(name, position)

    all_clades = list(tree.find_clades())

    # Label the clades that correspond to a known model.
    for clade in all_clades:
        position = None
        if clade.name is not None:
            position = index_by_name.get(clade.name)
        if position is None:
            clade.family = None
            clade.flag = False
        else:
            clade.family = origins[position]
            clade.flag = True

    # Propagate families up the tree. A clade is decided only after all of
    # its children are, so the result does not depend on traversal order.
    pending = [clade for clade in all_clades if not clade.flag]
    while pending:
        deferred = []
        # Walking backwards resolves everything in a single sweep if
        # find_clades() yields parents before their descendants
        # (NOTE(review): confirm traversal order; correctness does not
        # rely on it, only the number of sweeps does).
        for clade in reversed(pending):
            children = clade.clades
            if not all(child.flag for child in children):
                deferred.append(clade)
                continue
            families = {child.family for child in children}
            if len(families) == 1:
                # All children agree: inherit their family.
                clade.family = next(iter(families))
            elif len(families) == 2 and '?' in families:
                # One known family plus unknown ('?'): keep the known one.
                clade.family = next(f for f in families if f != '?')
            # Otherwise (no children, or conflicting families) the clade
            # is locked with no family.
            clade.flag = True
        if len(deferred) == len(pending):
            # No progress in a full sweep; stop instead of looping forever.
            break
        pending = deferred

    # Set the color associated with its family on each clade.
    for clade in all_clades:
        if clade.family is None:
            clade.color = UNKNOWN_COLOR
        else:
            clade.color = colors[clade.family]