# PhyloLM / plotting.py
# Author: Daetheys
# Commit message: First version gradio
# Commit: 3d6ba31
import networkx as nx
import numpy as np
from Bio.Phylo import to_networkx
from networkx.drawing.nx_agraph import graphviz_layout
import plotly.graph_objects as go
import plotly.express as px
from Bio.Phylo.TreeConstruction import DistanceTreeConstructor, DistanceCalculator, _DistanceMatrix
from tools import compute_ordered_matrix,compute_umap
from phylogeny import prepare_tree
from constants import UNKNOWN_COLOR, DEFAULT_COLOR, UNKNOWN_COLOR_RGB, DEFAULT_COLOR_RGB
# ------------------------------------------------------------------------------------------------
#
# Sim Matrix Plotting
#
# ------------------------------------------------------------------------------------------------
def plot_sim_matrix_fig(ordered_sim_matrix,ordered_model_names,families,colors):
    """Render the ordered similarity matrix as a grayscale heatmap.

    Two fully transparent red rectangles named 'rectX' and 'rectY' are
    attached to the figure so that update_sim_matrix_fig can later move
    them over a column / row and make them visible.

    Parameters:
    - ordered_sim_matrix: matrix handed to px.imshow; the colour range is
      pinned to [0, 1]
    - ordered_model_names: labels used for both the x and the y axis
    - families, colors: accepted but not used by this function
    Returns:
    - the Plotly figure
    """
    heatmap = px.imshow(
        ordered_sim_matrix,
        x=ordered_model_names,
        y=ordered_model_names,
        zmin=0,
        zmax=1,
        color_continuous_scale='gray',
    )
    heatmap.update_layout(
        coloraxis_colorbar=dict(title='Similarity'),
        margin=dict(l=0, r=0, t=0, b=0),
        autosize=True,
    )

    colorbar_style = dict(thickness=20, len=0.75, xanchor="right", x=1.02)
    heatmap.update_traces(colorbar=colorbar_style)

    # Both axes are stripped of ticks, grid and zero line in the same way.
    bare_axis = dict(showticklabels=False, showgrid=False, zeroline=False, constrain='range')
    heatmap.update_xaxes(**bare_axis)
    heatmap.update_yaxes(**bare_axis)

    # Placeholder highlight rectangles: zero-sized and invisible until a
    # model is searched for.
    for shape_name in ('rectX', 'rectY'):
        placeholder = go.layout.Shape(
            type="rect",
            xref="x", yref="y",
            x0=0, y0=0,
            x1=0, y1=0,
            line=dict(color="red", width=1),
            fillcolor="rgba(0,0,0,0)",
            name=shape_name,
            opacity=0,
        )
        heatmap.add_shape(placeholder)
    return heatmap
def update_sim_matrix_fig(fig, ordered_model_names, model_search_x=None, model_search_y=None):
    """Show or hide the column / row highlight rectangles of the heatmap.

    Parameters:
    - fig: figure carrying shapes named 'rectX' and 'rectY'
    - ordered_model_names: model names in matrix order
    - model_search_x: model whose column is outlined ('rectX'); the
      rectangle is hidden when the name is not in ordered_model_names
    - model_search_y: same for the row outline ('rectY')
    Returns:
    - the same fig object, updated in place
    """
    # Column highlight
    if model_search_x not in ordered_model_names:
        fig.update_shapes(selector=dict(name='rectX'), opacity=0)
    else:
        col = ordered_model_names.index(model_search_x)
        # Cells are centred on integer coordinates, hence the +/- 0.5.
        fig.update_shapes(
            selector=dict(name='rectX'),
            x0=col - 0.5,
            y0=-0.5,
            x1=col + 0.5,
            y1=len(ordered_model_names) - 0.5,
            opacity=0.7,
        )

    # Row highlight
    if model_search_y not in ordered_model_names:
        fig.update_shapes(selector=dict(name='rectY'), opacity=0)
    else:
        row = ordered_model_names.index(model_search_y)
        fig.update_shapes(
            selector=dict(name='rectY'),
            x0=-0.5,
            y0=row - 0.5,
            x1=len(ordered_model_names) - 0.5,
            y1=row + 0.5,
            opacity=0.7,
        )
    return fig
# ------------------------------------------------------------------------------------------------
#
# 2D UMAP Plotting
#
# ------------------------------------------------------------------------------------------------
def alpha_scaling(val):
    """Raise val to the power 1 / (0.35 + 1/100).

    The exponent is greater than 1, so values between 0 and 1 are pushed
    towards 0 while 0 and 1 themselves are left unchanged.
    """
    base = 0.35
    exponent = 1 / (base + 1 / 100)
    return val ** exponent
def plot_umap_fig(dist_matrix, sim_matrix, model_names, families, colors, key='fig2', alpha_edges=None, alpha_names=None, alpha_markers=None):
    """Build the 2D UMAP scatter plot of the models.

    The figure contains, in this order: one line trace per sufficiently
    similar pair of models (name '_edge'), one marker+text trace holding
    every model (name '_node'), one empty trace per family that only
    serves as a legend entry, and an invisible trace named 'node' that
    update_umap_fig later moves onto the searched model.

    Parameters:
    - dist_matrix: distance matrix passed to compute_umap
    - sim_matrix: similarity matrix, read as sim_matrix[i][j]
    - model_names: list of model names (one per row of the matrices)
    - families: family of each model, aligned with model_names
    - colors: mapping family -> colour
    - key: accepted but not used by this function
    - alpha_edges: opacity of the edge traces (None keeps Plotly's default)
    - alpha_names: alpha of the label text; None means fully opaque
    - alpha_markers: opacity of the node markers (None keeps Plotly's default)
    Returns:
    - fig: the Plotly figure
    """
    # presumably an array with one (x, y) row per model - it is indexed as
    # embedding[i, 0] / embedding[:, 1] below; verify against compute_umap
    embedding = compute_umap(dist_matrix, d=2)

    # Formatting None into the rgba() string would yield 'rgba(0,0,0,None)',
    # which is not a valid colour, so None is treated as "fully opaque".
    text_alpha = 1 if alpha_names is None else alpha_names

    fig = go.Figure()

    #-- EDGES
    # One trace per pair (i < j) whose scaled similarity exceeds 0.1; the
    # scaled value doubles as the line width and the edge takes the colour
    # of the first model's family.
    n_models = len(model_names)
    edge_traces = []
    for i in range(n_models):
        for j in range(i + 1, n_models):
            val = alpha_scaling(sim_matrix[i][j])
            if val > 0.1:
                edge_traces.append(
                    go.Scatter(
                        x=[embedding[i, 0], embedding[j, 0]],
                        y=[embedding[i, 1], embedding[j, 1]],
                        mode='lines',
                        name='_edge',
                        line=dict(color=colors[families[i]], width=val),
                        opacity=alpha_edges,
                        showlegend=False,
                        hoverinfo='skip',
                    )
                )
    # Add all edges in a single call instead of one figure mutation per edge.
    if edge_traces:
        fig.add_traces(edge_traces)

    #-- NODES
    marker_colors = [colors[f] for f in families]
    fig.add_trace(
        go.Scatter(
            x=embedding[:, 0],
            y=embedding[:, 1],
            text=model_names,
            mode='markers+text',
            textposition='top center',
            hoverinfo='text',
            hoveron='points+fills',
            showlegend=False,
            name='_node',
            marker=dict(
                color=marker_colors,
                size=8,
                line_width=2,
                opacity=alpha_markers,
            ),
            textfont=dict(
                color=f'rgba(0,0,0,{text_alpha})',
                size=8,
                family="Arial Black",
            ),
        )
    )

    #-- LEGEND
    # dict.fromkeys de-duplicates while keeping first-seen order, so the
    # legend order is the same on every run (a set gives no such guarantee).
    legends = [
        go.Scatter(
            x=[None],
            y=[None],
            mode='markers',
            marker=dict(
                color=colors[f],
                size=8,
                line_width=2,
                opacity=1,
            ),
            name=f,
        )
        for f in dict.fromkeys(families)
    ]
    fig.add_traces(legends)

    #-- HIGHLIGHTED NODE (invisible placeholder until a model is searched)
    fig.add_trace(
        go.Scatter(
            x=[0],
            y=[0],
            mode='markers+text',
            textposition='top center',
            textfont=dict(color='red', size=16, family="Arial Black"),
            marker=dict(
                color='red',
                size=12,
                symbol='circle',
                line=dict(color='red', width=3),
            ),
            showlegend=False,
            name='node',
            opacity=0,
        )
    )

    #-- LAYOUT
    fig.update_layout(
        margin=dict(l=0, r=0, t=0, b=0),
        autosize=True,
    )
    fig.update_xaxes(showticklabels=False, showgrid=False, zeroline=False, constrain='range')
    fig.update_yaxes(showticklabels=False, showgrid=False, zeroline=False, constrain='range')
    return fig
def update_umap_fig(fig, dist_matrix, model_names, families, colors, model_search_x=None, alpha_names=None, alpha_markers=None, alpha_edges=None, key='fig2'):
    """Refresh the opacities and the highlighted model of a UMAP figure.

    Parameters:
    - fig: figure holding traces named '_node', 'node' and line-mode edges
    - dist_matrix: distance matrix passed to compute_umap when a model is
      highlighted
    - model_names: list of model names
    - families: family of each model, aligned with model_names
    - colors: mapping family -> colour
    - model_search_x: model to highlight; the highlight trace is hidden
      when the name is not in model_names
    - alpha_names: alpha of the label text; None means fully opaque
    - alpha_markers: opacity of the node markers
    - alpha_edges: opacity of the edge traces
    - key: accepted but not used by this function
    Returns:
    - the same fig object, updated in place
    """
    # Formatting None into the rgba() string would yield 'rgba(0,0,0,None)',
    # which is not a valid colour, so None is treated as "fully opaque".
    text_alpha = 1 if alpha_names is None else alpha_names

    #Update nodes
    fig.update_traces(
        selector=dict(name='_node'),
        textfont=dict(
            color=f'rgba(0,0,0,{text_alpha})',
        ),
        marker=dict(
            opacity=alpha_markers
        ),
    )
    #Update edges
    # NOTE(review): this forces every edge to width 1 and therefore discards
    # any per-edge width set when the figure was built - confirm intended.
    fig.update_traces(
        selector=dict(mode='lines'),
        line=dict(width=1),
        opacity=alpha_edges
    )
    #Update highlighted node
    if model_search_x in model_names:
        searched_idx = model_names.index(model_search_x)
        embedding = compute_umap(dist_matrix, d=2) #Cached computation
        fig.update_traces(
            selector=dict(name='node'),
            x=[embedding[searched_idx, 0]],
            y=[embedding[searched_idx, 1]],
            text=[model_search_x],
            marker=dict(
                color=colors[families[searched_idx]],
            ),
            hovertext=model_search_x,
            opacity=1
        )
    else:
        # Nothing to highlight: park the trace at the origin and hide it.
        fig.update_traces(
            selector=dict(name='node'),
            x=[0],
            y=[0],
            text=[''],
            opacity=0
        )
    return fig
# ------------------------------------------------------------------------------------------------
#
# Phylogenetic Tree Plotting
#
# ------------------------------------------------------------------------------------------------
def draw_graphviz(tree, label_func=str, prog='twopi', args='',
                  node_size=15, edge_width=0.0, alpha_edges=None, alpha_names=None, alpha_markers=None, **kwargs):
    """Display a tree or clade as a Plotly graph, laid out by a graphviz engine.

    Parameters:
    - tree: Bio.Phylo tree (or clade) converted with to_networkx
    - label_func: callable turning a node's original object into its label
    - prog, args: graphviz program and extra arguments for graphviz_layout
    - node_size: marker size of the node traces
    - edge_width: line width of the edge traces
    - alpha_edges: opacity of the edge traces (None keeps Plotly's default)
    - alpha_names: alpha of the label text; None means fully opaque
    - alpha_markers: opacity of the node markers (None keeps Plotly's default)
    - **kwargs: accepted but not used by this function
    Returns:
    - fig: Plotly figure with one '_edge' trace per edge, one '_node' trace
      per labelled node and one legend-only trace per family
    """
    def _rgb_string(rgb):
        # Format an (r, g, b) triple as a Plotly colour string.
        r, g, b = rgb
        return f'rgb({r},{g},{b})'

    # Formatting None into the rgba() string would yield 'rgba(0,0,0,None)',
    # which is not a valid colour, so None is treated as "fully opaque".
    text_alpha = 1 if alpha_names is None else alpha_names

    # Convert the Bio.Phylo tree to a NetworkX graph
    G = to_networkx(tree)
    # Relabel nodes using integers while keeping original labels
    Gi = nx.convert_node_labels_to_integers(G, label_attribute='label')
    # Apply the Graphviz layout
    pos = graphviz_layout(Gi, prog=prog, args=args)

    # Prepare node labels for display
    def get_label_mapping(G, selection):
        for node, data in G.nodes(data=True):
            if (selection is None) or (node in selection):
                try:
                    label = label_func(data.get('label', node))
                    # NOTE(review): nodes are integers at this point, so the
                    # class name compared against is 'int', not the class of
                    # the original tree element - confirm this is intended.
                    if label not in (None, node.__class__.__name__):
                        yield (node, label)
                except (LookupError, AttributeError, ValueError):
                    pass

    # Extract labels
    labels = dict(get_label_mapping(Gi, None))
    nodelist = list(labels.keys())

    # Colour and family of every node. The raw RGB triple is kept next to
    # the formatted string so it can still be compared with
    # UNKNOWN_COLOR_RGB when colouring the edges.
    fallback_rgb = (0, 0, 0)
    node_rgb = {}
    node_colors = {}
    node_families = {}
    for node in Gi.nodes():
        node_data = Gi.nodes[node].get('label')
        branch_color = getattr(node_data, 'color', None)
        rgb = fallback_rgb if branch_color is None else branch_color.to_rgb()
        node_rgb[node] = rgb
        node_colors[node] = _rgb_string(rgb)
        node_families[node] = getattr(node_data, 'family', None)

    # Create edge traces
    # An edge takes the colour of its second endpoint, except when that
    # node carries the "unknown" colour, in which case the default colour
    # is used. (Previously the comparison ran on the already formatted
    # 'rgb(...)' string against the raw constant and so could not match.)
    # assumes UNKNOWN_COLOR_RGB / DEFAULT_COLOR_RGB are (r, g, b) sequences
    # - TODO confirm against constants
    unknown_rgb = list(UNKNOWN_COLOR_RGB)
    edge_traces = []
    for src, dst in Gi.edges():
        x0, y0 = pos[src]
        x1, y1 = pos[dst]
        if list(node_rgb[dst]) == unknown_rgb:
            edge_color = _rgb_string(DEFAULT_COLOR_RGB)
        else:
            edge_color = node_colors[dst]
        edge_traces.append(
            go.Scatter(
                x=[x0, x1, None],
                y=[y0, y1, None],
                line=dict(width=edge_width, color=edge_color),
                hoverinfo='none',
                mode='lines',
                showlegend=False,
                name='_edge',
                opacity=alpha_edges,
            )
        )

    # Create node traces (one trace per labelled node; `text` is kept as a
    # plain string so traces can later be selected by their text)
    node_traces = []
    for node in nodelist:
        x, y = pos[node]
        node_traces.append(
            go.Scatter(
                x=[x],
                y=[y],
                text=labels[node],
                mode='markers+text',
                textposition='top center',
                hoverinfo='text',
                showlegend=False,
                name='_node',
                marker=dict(
                    color=node_colors[node],
                    size=node_size,
                    line_width=2,
                    opacity=alpha_markers,
                ),
                textfont=dict(
                    color=f'rgba(0,0,0,{text_alpha})',
                    size=8,
                    family="Arial Black",
                ),
            )
        )

    # Colour of each family (the last node seen for a family wins). Dict
    # insertion order gives a legend order that is the same on every run.
    family_colors = {}
    for node, family in node_families.items():
        if family is not None:
            family_colors[family] = node_colors[node]

    #Custom legend
    legends = [
        go.Scatter(
            x=[None],
            y=[None],
            mode='markers',
            marker=dict(
                color=color,
                size=8,
                line_width=2,
                opacity=1,
            ),
            name=family,
        )
        for family, color in family_colors.items()
    ]

    # Create the figure
    fig = go.Figure(
        data=edge_traces + node_traces,
        layout=go.Layout(
            showlegend=True,
            hovermode='closest',
            margin=dict(b=1, l=1, r=1, t=1),
            xaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
            yaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
            legend=dict(
                yanchor="top",
                y=0.99,
                xanchor="right",
                x=0.99,
            ),
        ),
    )
    fig.add_traces(legends)
    return fig
def get_color(index):
    """Return the colour at `index` in Plotly's qualitative palette, wrapping around its length."""
    palette = px.colors.qualitative.Plotly
    wrapped = index % len(palette)
    return palette[wrapped]
def plot_tree(sim_matrix, models, families,colors, alpha_names=None, alpha_markers=None, alpha_edges=None):
    """
    Build and draw a neighbor-joining tree from a similarity matrix.
    Parameters:
    - sim_matrix: similarity matrix between models
    - models: list of model names
    - families: list of family names for each model
    - colors: passed to prepare_tree together with models and families
    - alpha_names, alpha_markers, alpha_edges: forwarded to draw_graphviz
    Returns:
    - fig: Plotly figure object with the phylogenetic tree
    """
    # Similarities become distances through -log; clipping at 1e-10 avoids log(0).
    clipped_sim = np.maximum(sim_matrix, 1e-10)
    dist_matrix = -np.log(clipped_sim)

    # Bio.Phylo takes the lower triangle (diagonal included) as nested lists.
    lower_triangle = []
    for i in range(len(dist_matrix)):
        lower_triangle.append([dist_matrix[i][j] for j in range(i + 1)])
    df = _DistanceMatrix(names=models, matrix=lower_triangle)

    # Set up the Bio.Phylo constructor and run neighbor joining.
    calculator = DistanceCalculator('identity')
    constructor = DistanceTreeConstructor(calculator, 'nj')
    NJTree = constructor.nj(df)
    NJTree.ladderize(reverse=False)

    # Let prepare_tree annotate the tree before it is drawn.
    prepare_tree(NJTree, models, families, colors)

    # Generate the plotly figure
    return draw_graphviz(
        NJTree,
        node_size=15,
        edge_width=6,
        alpha_names=alpha_names,
        alpha_markers=alpha_markers,
        alpha_edges=alpha_edges,
    )
def update_tree_fig(fig, model_names, model_search=None, alpha_names=None, alpha_markers=None, alpha_edges=None):
    """Refresh the opacities and the highlighted model of a tree figure.

    The highlighted node is tracked through its trace name: a regular node
    trace is named '_node' and is renamed to 'node' while it is highlighted.

    Parameters:
    - fig: figure holding traces named '_node' / 'node' and '_edge'
    - model_names: list of model names
    - model_search: model to highlight; nothing is highlighted when the
      name is not in model_names
    - alpha_names: alpha of the label text; None means fully opaque
    - alpha_markers: opacity of the node markers
    - alpha_edges: opacity of the edge traces
    Returns:
    - the same fig object, updated in place
    """
    # Formatting None into the rgba() string would yield 'rgba(0,0,0,None)',
    # which is not a valid colour, so None is treated as "fully opaque".
    text_alpha = 1 if alpha_names is None else alpha_names

    # Un-highlight the previously highlighted node first: restore the
    # regular node style (marker size 15, outline width 2 with default
    # colour, 8pt label) and rename it back to '_node' so that the
    # opacity update below applies to it as well.
    fig.update_traces(
        selector=dict(name='node'),
        marker=dict(
            size=15,
            line=dict(color=None, width=2),
        ),
        textfont=dict(
            size=8, family="Arial Black",
        ),
        name='_node'
    )
    # Update nodes
    fig.update_traces(
        selector=dict(name='_node'),
        marker=dict(
            opacity=alpha_markers,
        ),
        textfont=dict(
            color=f'rgba(0,0,0,{text_alpha})',
        )
    )
    # Update edges
    fig.update_traces(
        selector=dict(name='_edge'),
        opacity=alpha_edges,
    )
    # Highlight the searched model, if it exists
    if model_search in model_names:
        fig.update_traces(
            selector=dict(name='_node', text=model_search),
            marker=dict(
                size=22, # Bigger than normal nodes
                line=dict(color='red', width=4) # Red border
            ),
            textfont=dict(
                color='red', size=16, family="Arial Black",
            ),
            name='node'
        )
    return fig