import networkx as nx
import numpy as np
from Bio.Phylo import to_networkx
from networkx.drawing.nx_agraph import graphviz_layout
import plotly.graph_objects as go
import plotly.express as px
from Bio.Phylo.TreeConstruction import DistanceTreeConstructor, DistanceCalculator, _DistanceMatrix
from tools import compute_ordered_matrix, compute_umap
from phylogeny import prepare_tree
from constants import UNKNOWN_COLOR, DEFAULT_COLOR, UNKNOWN_COLOR_RGB, DEFAULT_COLOR_RGB

# ------------------------------------------------------------------------------------------------ #
#                                       Sim Matrix Plotting                                        #
# ------------------------------------------------------------------------------------------------ #


def _hidden_highlight_rect(name):
    """Return an invisible, outline-only red rectangle shape identified by *name*.

    The rectangle starts collapsed at the origin with ``opacity=0``; update_sim_matrix_fig
    later moves it over a column/row and makes it visible.
    """
    return go.layout.Shape(
        type="rect",
        xref="x",
        yref="y",
        x0=0,
        y0=0,
        x1=0,
        y1=0,
        line=dict(color="red", width=1),
        fillcolor="rgba(0,0,0,0)",
        name=name,
        opacity=0,
    )


def plot_sim_matrix_fig(ordered_sim_matrix, ordered_model_names, families, colors):
    """Plot an already-ordered similarity matrix as a grayscale heatmap.

    Parameters:
    - ordered_sim_matrix: square similarity matrix whose rows/columns are in display order;
      the gray scale spans [0, 1].
    - ordered_model_names: model names labelling the rows/columns, in the same order.
    - families, colors: accepted for signature parity with the other plotting functions;
      not used here.

    Returns:
    - fig: Plotly figure containing two hidden highlight rectangles named 'rectX' and 'rectY'.
    """
    fig = px.imshow(
        ordered_sim_matrix,
        x=ordered_model_names,
        y=ordered_model_names,
        zmin=0,
        zmax=1,
        color_continuous_scale='gray',
    )
    # px.imshow colours the heatmap through the shared layout coloraxis, so the colour bar must
    # be configured there; a trace-level ``colorbar`` is ignored when a coloraxis is used.
    fig.update_layout(
        coloraxis_colorbar=dict(
            title='Similarity',
            thickness=20,
            len=0.75,
            xanchor="right",
            x=1.02,
        ),
        margin=dict(l=0, r=0, t=0, b=0),
        autosize=True,
    )
    fig.update_xaxes(showticklabels=False, showgrid=False, zeroline=False, constrain='range')
    fig.update_yaxes(showticklabels=False, showgrid=False, zeroline=False, constrain='range')

    # Rectangles used to highlight the searched models (column for X, row for Y).
    fig.add_shape(_hidden_highlight_rect('rectX'))
    fig.add_shape(_hidden_highlight_rect('rectY'))
    return fig


def update_sim_matrix_fig(fig, ordered_model_names, model_search_x=None, model_search_y=None):
    """Show or hide the highlight rectangles of a figure built by plot_sim_matrix_fig.

    Parameters:
    - fig: figure returned by plot_sim_matrix_fig (mutated in place).
    - ordered_model_names: model names in matrix order (any sequence).
    - model_search_x: model whose column is outlined by 'rectX'; hidden if not found.
    - model_search_y: model whose row is outlined by 'rectY'; hidden if not found.

    Returns:
    - fig: the same figure object.
    """
    names = list(ordered_model_names)  # list() so .index() also works for tuples/arrays
    n = len(names)

    if model_search_x in names:
        idx_x = names.index(model_search_x)
        # Heatmap cells are centred on integer indices, hence the +/- 0.5 edges.
        fig.update_shapes(
            selector=dict(name='rectX'),
            x0=idx_x - 0.5,
            y0=-0.5,
            x1=idx_x + 0.5,
            y1=n - 0.5,
            opacity=0.7,
        )
    else:
        fig.update_shapes(selector=dict(name='rectX'), opacity=0)

    if model_search_y in names:
        idx_y = names.index(model_search_y)
        fig.update_shapes(
            selector=dict(name='rectY'),
            x0=-0.5,
            y0=idx_y - 0.5,
            x1=n - 0.5,
            y1=idx_y + 0.5,
            opacity=0.7,
        )
    else:
        fig.update_shapes(selector=dict(name='rectY'), opacity=0)

    return fig


# ------------------------------------------------------------------------------------------------ #
#                                         2D UMAP Plotting                                         #
# ------------------------------------------------------------------------------------------------ #


def alpha_scaling(val, base=0.35):
    """Return ``val ** (1 / (base + 0.01))``.

    With the default ``base`` the exponent is about 2.78, so for values in [0, 1] small
    similarities are pushed towards 0 while values near 1 stay near 1.
    """
    return val ** (1 / (base + 1 / 100))


def plot_umap_fig(dist_matrix, sim_matrix, model_names, families, colors, key='fig2',
                  alpha_edges=None, alpha_names=None, alpha_markers=None):
    """Plot a 2D UMAP embedding of the models with similarity edges.

    Parameters:
    - dist_matrix: distance matrix passed to compute_umap (d=2).
    - sim_matrix: pairwise similarities, indexed as ``sim_matrix[i][j]``; they set edge widths.
    - model_names: model names, one per row of the matrices.
    - families: family name of each model (same order as model_names).
    - colors: mapping family -> Plotly colour.
    - key: not used in this function.
    - alpha_edges, alpha_names, alpha_markers: opacities of edges, label text and markers.

    Returns:
    - fig: Plotly figure with '_edge' traces, one '_node' trace, legend entries and a hidden
      highlight trace named 'node'.
    """
    embedding = compute_umap(dist_matrix, d=2)
    fig = go.Figure()

    # -- EDGES
    # One edge per unordered pair (i < j); its width is the contrast-boosted similarity and
    # weak pairs (scaled value <= 0.1) are dropped.
    # NOTE(review): this loop body was garbled in the source and has been reconstructed
    # around alpha_scaling and the surviving 0.1 threshold - confirm against history.
    n_models = len(model_names)
    edges = []
    for i in range(n_models):
        for j in range(i + 1, n_models):
            val = alpha_scaling(sim_matrix[i][j])
            if val > 0.1:
                edges.append((i, j, val, colors[families[i]]))

    for i, j, val, color in edges:
        fig.add_trace(
            go.Scatter(
                x=[embedding[i, 0], embedding[j, 0]],
                y=[embedding[i, 1], embedding[j, 1]],
                mode='lines',
                name='_edge',
                line=dict(color=color, width=val),
                opacity=alpha_edges,
                showlegend=False,
                hoverinfo='skip',
            )
        )

    # -- NODES
    marker_colors = [colors[f] for f in families]
    fig.add_trace(
        go.Scatter(
            x=embedding[:, 0],
            y=embedding[:, 1],
            text=model_names,
            mode='markers+text',
            textposition='top center',
            hoverinfo='text',
            hoveron='points+fills',
            showlegend=False,
            name='_node',
            marker=dict(
                color=marker_colors,
                size=8,
                line_width=2,
                opacity=alpha_markers,
            ),
            textfont=dict(
                color=f'rgba(0,0,0,{alpha_names})',
                size=8,
                family="Arial Black",
            ),
        )
    )

    # -- LEGEND: invisible traces, one per family, in order of first appearance so the legend
    # is stable from run to run (iterating a set of strings is not).
    legends = [
        go.Scatter(
            x=[None],
            y=[None],
            mode='markers',
            marker=dict(color=colors[f], size=8, line_width=2, opacity=1),
            name=f,
        )
        for f in dict.fromkeys(families)
    ]
    fig.add_traces(legends)

    # -- HIGHLIGHTED NODE: hidden placeholder that update_umap_fig moves onto the searched model.
    fig.add_trace(
        go.Scatter(
            x=[0],
            y=[0],
            mode='markers+text',
            textposition='top center',
            textfont=dict(color='red', size=16, family="Arial Black"),
            marker=dict(
                color='red',
                size=12,
                symbol='circle',
                line=dict(color='red', width=3),
            ),
            showlegend=False,
            name='node',
            opacity=0,
        )
    )

    fig.update_layout(margin=dict(l=0, r=0, t=0, b=0), autosize=True)
    fig.update_xaxes(showticklabels=False, showgrid=False, zeroline=False, constrain='range')
    fig.update_yaxes(showticklabels=False, showgrid=False, zeroline=False, constrain='range')
    return fig


def update_umap_fig(fig, dist_matrix, model_names, families, colors, model_search_x=None,
                    alpha_names=None, alpha_markers=None, alpha_edges=None, key='fig2'):
    """Update opacities and the highlighted model of a figure built by plot_umap_fig.

    Parameters:
    - fig: figure returned by plot_umap_fig (mutated in place).
    - dist_matrix: distance matrix used to (re)obtain the UMAP embedding.
    - model_names, families, colors: same meaning as in plot_umap_fig.
    - model_search_x: model to highlight; the highlight is hidden if it is not found.
    - alpha_names, alpha_markers, alpha_edges: new label-text, marker and edge opacities.
    - key: not used in this function.

    Returns:
    - fig: the same figure object.
    """
    names = list(model_names)

    fig.update_traces(
        selector=dict(name='_node'),
        textfont=dict(color=f'rgba(0,0,0,{alpha_names})'),
        marker=dict(opacity=alpha_markers),
    )

    # NOTE(review): width=1 overwrites the per-edge similarity widths set in plot_umap_fig;
    # confirm that resetting them on every update is intended.
    fig.update_traces(
        selector=dict(mode='lines'),
        line=dict(width=1),
        opacity=alpha_edges,
    )

    if model_search_x in names:
        searched_idx = names.index(model_search_x)
        embedding = compute_umap(dist_matrix, d=2)  # expected to be cached by compute_umap
        fig.update_traces(
            selector=dict(name='node'),
            x=[embedding[searched_idx, 0]],
            y=[embedding[searched_idx, 1]],
            text=[model_search_x],
            marker=dict(color=colors[families[searched_idx]]),
            hovertext=model_search_x,
            opacity=1,
        )
    else:
        fig.update_traces(
            selector=dict(name='node'),
            x=[0],
            y=[0],
            text=[''],
            opacity=0,
        )
    return fig


# ------------------------------------------------------------------------------------------------ #
#                                    Phylogenetic Tree Plotting                                    #
# ------------------------------------------------------------------------------------------------ #


def _as_plotly_color(color):
    """Return *color* in a form Plotly accepts.

    Strings are returned unchanged; a 3-component sequence ``(r, g, b)`` becomes
    ``'rgb(r,g,b)'`` (no spaces, the same format used for node colours below).
    """
    if isinstance(color, str):
        return color
    r, g, b = color
    return f'rgb({r},{g},{b})'


def draw_graphviz(tree, label_func=str, prog='twopi', args='', node_size=15, edge_width=0.0,
                  alpha_edges=None, alpha_names=None, alpha_markers=None, **kwargs):
    """Display a tree or clade as a Plotly graph, with layout from the graphviz engine.

    Parameters:
    - tree: Bio.Phylo tree or clade. Node objects may carry ``color`` (with ``to_rgb()``) and
      ``family`` attributes - presumably set by phylogeny.prepare_tree; confirm there.
    - label_func: maps a node object to its display label. Labels that are None or equal to
      the node object's class name (e.g. the 'Clade' placeholder) are not drawn.
    - prog, args: graphviz program and extra arguments forwarded to graphviz_layout.
    - node_size: marker size of labelled nodes.
    - edge_width: line width of the edges.
    - alpha_edges, alpha_names, alpha_markers: opacities of edges, label text and markers.
    - **kwargs: accepted and ignored.

    Returns:
    - fig: Plotly figure with '_edge' and '_node' traces and one legend entry per family.
    """
    # Convert the Bio.Phylo tree to NetworkX and relabel nodes with integers, keeping the
    # original node objects under the 'label' attribute.
    G = to_networkx(tree)
    Gi = nx.convert_node_labels_to_integers(G, label_attribute='label')
    pos = graphviz_layout(Gi, prog=prog, args=args)

    # Display labels. The placeholder check uses the class of the original node object, as
    # Bio.Phylo's own draw_graphviz does: the graph nodes are ints here, so testing
    # node.__class__.__name__ would compare against 'int' and let placeholders through.
    labels = {}
    for node, data in Gi.nodes(data=True):
        obj = data.get('label', node)
        try:
            label = label_func(obj)
            if label not in (None, obj.__class__.__name__):
                labels[node] = label
        except (LookupError, AttributeError, ValueError):
            pass

    # Colour ('rgb(r,g,b)' string) and family of every node.
    default_color = (0, 0, 0)
    node_colors = {}
    node_families = {}
    for node in Gi.nodes():
        node_data = Gi.nodes[node].get('label')
        color = getattr(node_data, 'color', None)
        rgb = color.to_rgb() if color is not None else default_color
        node_colors[node] = _as_plotly_color(rgb)
        node_families[node] = getattr(node_data, 'family', None)

    # Edges take the child's colour; edges into unknown-coloured nodes use the default colour.
    # Both constants are normalised so the comparison works whether they are stored as
    # (r, g, b) sequences or as 'rgb(...)' strings.
    unknown_color = _as_plotly_color(UNKNOWN_COLOR_RGB)
    fallback_edge_color = _as_plotly_color(DEFAULT_COLOR_RGB)
    edge_traces = []
    for parent, child in Gi.edges():
        x0, y0 = pos[parent]
        x1, y1 = pos[child]
        edge_color = node_colors[child]
        if edge_color == unknown_color:
            edge_color = fallback_edge_color
        edge_traces.append(
            go.Scatter(
                x=[x0, x1, None],
                y=[y0, y1, None],
                line=dict(width=edge_width, color=edge_color),
                hoverinfo='none',
                mode='lines',
                showlegend=False,
                name='_edge',
                opacity=alpha_edges,
            )
        )

    # One trace per labelled node, so update_tree_fig can select a node by its text.
    node_traces = []
    for node, text in labels.items():
        x, y = pos[node]
        node_traces.append(
            go.Scatter(
                x=[x],
                y=[y],
                text=text,
                mode='markers+text',
                textposition='top center',
                hoverinfo='text',
                showlegend=False,
                name='_node',
                marker=dict(
                    color=node_colors[node],
                    size=node_size,
                    line_width=2,
                    opacity=alpha_markers,
                ),
                textfont=dict(
                    color=f'rgba(0,0,0,{alpha_names})',
                    size=8,
                    family="Arial Black",
                ),
            )
        )

    # Family -> colour (the last node seen for a family wins); dict order gives a stable legend.
    colors = {}
    for node, family in node_families.items():
        if family is not None:
            colors[family] = node_colors[node]

    legends = [
        go.Scatter(
            x=[None],
            y=[None],
            mode='markers',
            marker=dict(color=color, size=8, line_width=2, opacity=1),
            name=family,
        )
        for family, color in colors.items()
    ]

    fig = go.Figure(
        data=edge_traces + node_traces,
        layout=go.Layout(
            showlegend=True,
            hovermode='closest',
            margin=dict(b=1, l=1, r=1, t=1),
            xaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
            yaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
            legend=dict(yanchor="top", y=0.99, xanchor="right", x=0.99),
        ),
    )
    fig.add_traces(legends)
    return fig


def get_color(index):
    """Get a color from plotly's qualitative color palette (cycling past its end)."""
    palette = px.colors.qualitative.Plotly
    return palette[index % len(palette)]


def plot_tree(sim_matrix, models, families, colors, alpha_names=None, alpha_markers=None,
              alpha_edges=None):
    """
    Plot a neighbour-joining phylogenetic tree built from a similarity matrix.

    Parameters:
    - sim_matrix: square similarity matrix between models; distances are -log(similarity)
      (similarities are floored at 1e-10 to avoid log(0)). Only the lower triangle is read.
    - models: list of model names (one per row of sim_matrix).
    - families: list of family names for each model.
    - colors: family colour mapping, forwarded to prepare_tree.
    - alpha_names, alpha_markers, alpha_edges: opacities forwarded to draw_graphviz.

    Returns:
    - fig: Plotly figure object with the phylogenetic tree
    """
    dist_matrix = -np.log(np.maximum(sim_matrix, 1e-10))

    # Bio.Phylo expects the lower triangle (diagonal included) as nested lists of floats.
    lower_triangle = [
        [float(dist_matrix[i][j]) for j in range(i + 1)]
        for i in range(len(dist_matrix))
    ]
    distances = _DistanceMatrix(names=list(models), matrix=lower_triangle)

    calculator = DistanceCalculator('identity')
    constructor = DistanceTreeConstructor(calculator, 'nj')
    nj_tree = constructor.nj(distances)
    nj_tree.ladderize(reverse=False)

    # Attach colours/families to the clades.
    prepare_tree(nj_tree, models, families, colors)

    return draw_graphviz(
        nj_tree,
        node_size=15,
        edge_width=6,
        alpha_names=alpha_names,
        alpha_markers=alpha_markers,
        alpha_edges=alpha_edges,
    )


def update_tree_fig(fig, model_names, model_search=None, alpha_names=None, alpha_markers=None,
                    alpha_edges=None):
    """Update opacities and the highlighted model of a figure built by plot_tree.

    The highlighted node is the '_node' trace whose text equals ``model_search``; it is
    renamed 'node' so the next call can revert it.

    Parameters:
    - fig: figure returned by plot_tree (mutated in place).
    - model_names: names of the models in the tree.
    - model_search: model to highlight; nothing is highlighted if it is not in model_names.
    - alpha_names, alpha_markers, alpha_edges: new label-text, marker and edge opacities.

    Returns:
    - fig: the same figure object.
    """
    # Revert any previously highlighted node to the styling draw_graphviz gives nodes
    # (node_size=15 as passed by plot_tree, marker line width 2 without colour, 8pt labels).
    # Done first so the reverted node also receives the opacity update below.
    fig.update_traces(
        selector=dict(name='node'),
        marker=dict(size=15, line=dict(color=None, width=2)),
        textfont=dict(
            color=f'rgba(0,0,0,{alpha_names})',
            size=8,
            family="Arial Black",
        ),
        name='_node',
    )

    fig.update_traces(
        selector=dict(name='_node'),
        marker=dict(opacity=alpha_markers),
        textfont=dict(color=f'rgba(0,0,0,{alpha_names})'),
    )

    fig.update_traces(selector=dict(name='_edge'), opacity=alpha_edges)

    if model_search in model_names:
        fig.update_traces(
            selector=dict(name='_node', text=model_search),
            marker=dict(
                size=22,  # bigger than normal nodes
                line=dict(color='red', width=4),  # red border
            ),
            textfont=dict(color='red', size=16, family="Arial Black"),
            name='node',
        )
    return fig