|
import networkx as nx |
|
import numpy as np |
|
from Bio.Phylo import to_networkx |
|
from networkx.drawing.nx_agraph import graphviz_layout |
|
import plotly.graph_objects as go |
|
import plotly.express as px |
|
from Bio.Phylo.TreeConstruction import DistanceTreeConstructor, DistanceCalculator, _DistanceMatrix |
|
|
|
from tools import compute_ordered_matrix,compute_umap |
|
from phylogeny import prepare_tree |
|
from constants import UNKNOWN_COLOR, DEFAULT_COLOR, UNKNOWN_COLOR_RGB, DEFAULT_COLOR_RGB |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def plot_sim_matrix_fig(ordered_sim_matrix, ordered_model_names, families, colors):
    """Build a grayscale heatmap of a model-similarity matrix.

    Two invisible rectangle shapes named 'rectX' and 'rectY' are added so that
    update_sim_matrix_fig can later highlight a searched column / row by
    updating them in place.

    Parameters:
    - ordered_sim_matrix: 2-D similarity values, shown on a fixed [0, 1]
      color range.
    - ordered_model_names: labels for both axes (tick labels are hidden).
    - families, colors: currently unused; kept for API compatibility.

    Returns:
    - fig: Plotly figure.
    """
    fig = px.imshow(
        ordered_sim_matrix,
        x=ordered_model_names,
        y=ordered_model_names,
        zmin=0, zmax=1,
        color_continuous_scale='gray',
    )

    # px.imshow attaches the heatmap to the shared layout ``coloraxis``. When a
    # trace uses a coloraxis, Plotly reads the colorbar from
    # ``layout.coloraxis.colorbar`` and ignores the trace-level ``colorbar``,
    # so all colorbar styling has to live here to take effect.
    fig.update_layout(
        coloraxis_colorbar=dict(
            title='Similarity',
            thickness=20,
            len=0.75,
            xanchor="right",
            x=1.02,
        ),
        margin=dict(l=0, r=0, t=0, b=0),
        autosize=True,
    )

    fig.update_xaxes(showticklabels=False, showgrid=False, zeroline=False, constrain='range')
    fig.update_yaxes(showticklabels=False, showgrid=False, zeroline=False, constrain='range')

    # Placeholder highlight rectangles: zero-sized and fully transparent until
    # update_sim_matrix_fig positions them and raises their opacity.
    for shape_name in ('rectX', 'rectY'):
        fig.add_shape(
            go.layout.Shape(
                type="rect",
                xref="x", yref="y",
                x0=0, y0=0,
                x1=0, y1=0,
                line=dict(color="red", width=1),
                fillcolor="rgba(0,0,0,0)",
                name=shape_name,
                opacity=0,
            )
        )

    return fig
|
|
|
def update_sim_matrix_fig(fig, ordered_model_names, model_search_x=None, model_search_y=None):
    """Show or hide the column/row highlight rectangles of a similarity heatmap.

    The 'rectX' shape outlines the full column of ``model_search_x`` and the
    'rectY' shape outlines the full row of ``model_search_y``. A search value
    not present in ``ordered_model_names`` hides the corresponding shape.

    Parameters:
    - fig: figure from plot_sim_matrix_fig (modified in place).
    - ordered_model_names: axis labels in heatmap order (must support .index).
    - model_search_x, model_search_y: model names to highlight, or None.

    Returns:
    - fig
    """

    def _highlight(shape_name, target, vertical):
        selector = dict(name=shape_name)
        if target not in ordered_model_names:
            fig.update_shapes(selector=selector, opacity=0)
            return
        cell = ordered_model_names.index(target)
        far_edge = len(ordered_model_names) - 0.5
        if vertical:
            # Column band spanning every row.
            bounds = dict(x0=cell - 0.5, y0=-0.5, x1=cell + 0.5, y1=far_edge)
        else:
            # Row band spanning every column.
            bounds = dict(x0=-0.5, y0=cell - 0.5, x1=far_edge, y1=cell + 0.5)
        fig.update_shapes(selector=selector, opacity=0.7, **bounds)

    _highlight('rectX', model_search_x, vertical=True)
    _highlight('rectY', model_search_y, vertical=False)
    return fig
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def alpha_scaling(val):
    """Sharpen a similarity value by raising it to the power 1 / 0.36.

    Values in [0, 1] stay in [0, 1] but are pushed towards 0, so only strong
    similarities keep a large weight.
    """
    base = 0.35
    exponent = 1 / (base + 1 / 100)
    return pow(val, exponent)
|
|
|
def plot_umap_fig(dist_matrix, sim_matrix, model_names, families, colors, key='fig2', alpha_edges=None, alpha_names=None, alpha_markers=None):
    """Plot a 2-D UMAP embedding of the models, linked by similarity edges.

    Parameters:
    - dist_matrix: pairwise distances passed to compute_umap.
    - sim_matrix: pairwise similarities; for each pair i < j an edge is drawn
      when ``alpha_scaling(sim_matrix[i][j]) > 0.1``, using that value as the
      line width and the color of model i's family.
    - model_names: model labels, in matrix order.
    - families: family of each model (same order as model_names).
    - colors: mapping from family to a Plotly color.
    - key: unused; kept for API compatibility.
    - alpha_edges, alpha_markers: edge / marker opacity (None leaves the
      Plotly default).
    - alpha_names: label opacity; None means fully opaque.

    Returns:
    - fig: Plotly figure with traces named '_edge', '_node' and a hidden
      'node' search marker, as expected by update_umap_fig.
    """
    embedding = compute_umap(dist_matrix, d=2)
    # 'rgba(0,0,0,None)' is not a valid Plotly color, so None maps to opaque.
    text_alpha = 1 if alpha_names is None else alpha_names

    fig = go.Figure()

    # Similarity edges.
    edge_traces = []
    n_models = len(model_names)
    for i in range(n_models):
        for j in range(i + 1, n_models):
            val = alpha_scaling(sim_matrix[i][j])
            if val > 0.1:
                edge_traces.append(
                    go.Scatter(
                        x=[embedding[i, 0], embedding[j, 0]],
                        y=[embedding[i, 1], embedding[j, 1]],
                        mode='lines',
                        name='_edge',
                        line=dict(color=colors[families[i]], width=val),
                        opacity=alpha_edges,
                        showlegend=False,
                        hoverinfo='skip',
                    )
                )
    if edge_traces:
        # One bulk call instead of one add_trace call per edge.
        fig.add_traces(edge_traces)

    # All model nodes in a single trace, colored by family.
    fig.add_trace(
        go.Scatter(
            x=embedding[:, 0],
            y=embedding[:, 1],
            text=model_names,
            mode='markers+text',
            textposition='top center',
            hoverinfo='text',
            hoveron='points+fills',
            showlegend=False,
            name='_node',
            marker=dict(
                color=[colors[f] for f in families],
                size=8,
                line_width=2,
                opacity=alpha_markers,
            ),
            textfont=dict(
                color=f'rgba(0,0,0,{text_alpha})',
                size=8,
                family="Arial Black",
            ),
        )
    )

    # Legend-only traces, one per family. dict.fromkeys keeps first-appearance
    # order; iterating a set would give an order that varies between runs.
    fig.add_traces([
        go.Scatter(
            x=[None],
            y=[None],
            mode='markers',
            marker=dict(color=colors[f], size=8, line_width=2, opacity=1),
            name=f,
        )
        for f in dict.fromkeys(families)
    ])

    # Hidden search marker, moved and revealed by update_umap_fig.
    fig.add_trace(
        go.Scatter(
            x=[0],
            y=[0],
            mode='markers+text',
            textposition='top center',
            textfont=dict(color='red', size=16, family="Arial Black"),
            marker=dict(
                color='red',
                size=12,
                symbol='circle',
                line=dict(color='red', width=3),
            ),
            showlegend=False,
            name='node',
            opacity=0,
        )
    )

    fig.update_layout(
        margin=dict(l=0, r=0, t=0, b=0),
        autosize=True,
    )
    fig.update_xaxes(showticklabels=False, showgrid=False, zeroline=False, constrain='range')
    fig.update_yaxes(showticklabels=False, showgrid=False, zeroline=False, constrain='range')

    return fig
|
|
|
def update_umap_fig(fig, dist_matrix, model_names, families, colors, model_search_x=None, alpha_names=None, alpha_markers=None, alpha_edges=None, key='fig2'):
    """Refresh opacities and the search marker of a figure from plot_umap_fig.

    Parameters:
    - fig: figure returned by plot_umap_fig (modified in place).
    - dist_matrix: distance matrix; only used to recompute the embedding when
      the figure has no usable '_node' trace to read coordinates from.
    - model_names: model labels in the order used to build the figure.
    - families, colors: used to color the search marker like its family.
    - model_search_x: model to highlight; a value not in model_names hides
      the marker.
    - alpha_names: label opacity; None means fully opaque.
    - alpha_markers, alpha_edges: marker / edge opacity.
    - key: unused; kept for API compatibility.

    Returns:
    - fig
    """
    # 'rgba(0,0,0,None)' is not a valid Plotly color, so None maps to opaque.
    text_alpha = 1 if alpha_names is None else alpha_names

    fig.update_traces(
        selector=dict(name='_node'),
        textfont=dict(color=f'rgba(0,0,0,{text_alpha})'),
        marker=dict(opacity=alpha_markers),
    )

    # NOTE(review): width=1 overrides the similarity-scaled edge widths set by
    # plot_umap_fig. Kept as-is; confirm whether that is intended.
    fig.update_traces(
        selector=dict(mode='lines'),
        line=dict(width=1),
        opacity=alpha_edges,
    )

    if model_search_x in model_names:
        searched_idx = model_names.index(model_search_x)

        # Take the position from the plotted '_node' trace so the marker lands
        # exactly on the drawn point. Recomputing the embedding is costly and
        # may not reproduce the plotted layout.
        node_trace = next((t for t in fig.data if t.name == '_node'), None)
        try:
            x_pos = node_trace.x[searched_idx]
            y_pos = node_trace.y[searched_idx]
        except (AttributeError, TypeError, IndexError):
            embedding = compute_umap(dist_matrix, d=2)
            x_pos, y_pos = embedding[searched_idx, 0], embedding[searched_idx, 1]

        fig.update_traces(
            selector=dict(name='node'),
            x=[x_pos],
            y=[y_pos],
            text=[model_search_x],
            marker=dict(color=colors[families[searched_idx]]),
            hovertext=model_search_x,
            opacity=1,
        )
    else:
        fig.update_traces(
            selector=dict(name='node'),
            x=[0],
            y=[0],
            text=[''],
            opacity=0,
        )

    return fig
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def draw_graphviz(tree, label_func=str, prog='twopi', args='',
                  node_size=15, edge_width=0.0, alpha_edges=None, alpha_names=None, alpha_markers=None, **kwargs):
    """Draw a Bio.Phylo tree as a Plotly figure using a Graphviz layout.

    Parameters:
    - tree: Bio.Phylo tree, converted with to_networkx.
    - label_func: maps a node's clade object to its display label. Nodes whose
      label is None, equals the node id's class name, or raises
      LookupError / AttributeError / ValueError get no marker.
    - prog, args: Graphviz program and extra arguments for the layout.
    - node_size: marker size of labelled nodes.
    - edge_width: line width of the edges.
    - alpha_edges, alpha_markers: edge / marker opacity (None leaves the
      Plotly default).
    - alpha_names: label opacity; None means fully opaque.
    - **kwargs: accepted for compatibility and ignored.

    Node colors come from a clade's ``color`` attribute (black when missing
    or None). Families come from its ``family`` attribute and drive the legend.

    Returns:
    - fig: Plotly figure. Edge traces are named '_edge' and each labelled
      node gets its own trace named '_node'.
    """
    G = to_networkx(tree)

    # Relabel nodes to integers for Graphviz; the original clade object is
    # kept under the 'label' node attribute.
    Gi = nx.convert_node_labels_to_integers(G, label_attribute='label')
    pos = graphviz_layout(Gi, prog=prog, args=args)

    def get_label_mapping(G, selection):
        """Yield (node, label) pairs for nodes in ``selection`` (all if None)."""
        for node, data in G.nodes(data=True):
            if (selection is None) or (node in selection):
                try:
                    label = label_func(data.get('label', node))
                    # Skip missing labels and labels that are just the class
                    # name of the (integer) node id.
                    if label not in (None, node.__class__.__name__):
                        yield (node, label)
                except (LookupError, AttributeError, ValueError):
                    # Best effort: nodes whose label cannot be computed stay unlabelled.
                    pass

    labels = dict(get_label_mapping(Gi, None))
    nodelist = list(labels.keys())

    # 'rgba(0,0,0,None)' is not a valid Plotly color, so None maps to opaque.
    text_alpha = 1 if alpha_names is None else alpha_names
    default_rgb = (0, 0, 0)

    def _rgb_string(rgb):
        """Format an (r, g, b) triple as a Plotly 'rgb(r,g,b)' string."""
        return f'rgb({rgb[0]},{rgb[1]},{rgb[2]})'

    # Per-node raw RGB triple, its Plotly color string, and its family.
    node_rgb = {}
    node_colors = {}
    node_families = {}
    for node in Gi.nodes():
        clade = Gi.nodes[node].get('label')
        color = getattr(clade, 'color', None)
        node_rgb[node] = color.to_rgb() if color is not None else default_rgb
        node_colors[node] = _rgb_string(node_rgb[node])
        node_families[node] = getattr(clade, 'family', None)

    def _has_unknown_color(node):
        """Return True if ``node`` carries UNKNOWN_COLOR_RGB."""
        # NOTE(review): UNKNOWN_COLOR_RGB lives in `constants` (not visible
        # here); accept either an 'rgb(...)' string or an (r, g, b) sequence.
        if isinstance(UNKNOWN_COLOR_RGB, str):
            return node_colors[node] == UNKNOWN_COLOR_RGB
        return tuple(node_rgb[node]) == tuple(UNKNOWN_COLOR_RGB)

    edge_traces = []
    for src, dst in Gi.edges():
        x0, y0 = pos[src]
        x1, y1 = pos[dst]

        # An edge takes the color of its second endpoint. Endpoints with the
        # "unknown" color are drawn in the default color instead.
        if _has_unknown_color(dst):
            edge_color = (DEFAULT_COLOR_RGB if isinstance(DEFAULT_COLOR_RGB, str)
                          else _rgb_string(DEFAULT_COLOR_RGB))
        else:
            edge_color = node_colors[dst]

        edge_traces.append(
            go.Scatter(
                x=[x0, x1, None],
                y=[y0, y1, None],
                line=dict(width=edge_width, color=edge_color),
                hoverinfo='none',
                mode='lines',
                showlegend=False,
                name='_edge',
                opacity=alpha_edges,
            )
        )

    node_traces = []
    for node in nodelist:
        x, y = pos[node]
        node_traces.append(
            go.Scatter(
                x=[x],
                y=[y],
                text=labels.get(node, None),
                mode='markers+text',
                textposition='top center',
                hoverinfo='text',
                showlegend=False,
                name='_node',
                marker=dict(
                    color=node_colors.get(node, None),
                    size=node_size,
                    line_width=2,
                    opacity=alpha_markers,
                ),
                textfont=dict(
                    color=f'rgba(0,0,0,{text_alpha})',
                    size=8,
                    family="Arial Black",
                ),
            )
        )

    # Family -> color of the last node seen in that family. Dict insertion
    # order gives a deterministic legend order; a set's order varies per run.
    family_colors = {}
    for node, family in node_families.items():
        if family is not None:
            family_colors[family] = node_colors[node]

    legends = [
        go.Scatter(
            x=[None],
            y=[None],
            mode='markers',
            marker=dict(color=family_color, size=8, line_width=2, opacity=1),
            name=family,
        )
        for family, family_color in family_colors.items()
    ]

    fig = go.Figure(
        data=edge_traces + node_traces,
        layout=go.Layout(
            showlegend=True,
            hovermode='closest',
            margin=dict(b=1, l=1, r=1, t=1),
            xaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
            yaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
            legend=dict(
                yanchor="top",
                y=0.99,
                xanchor="right",
                x=0.99,
            ),
        ),
    )

    fig.add_traces(legends)

    return fig
|
|
|
def get_color(index):
    """Return the color at ``index`` in Plotly's qualitative palette, wrapping around."""
    palette = px.colors.qualitative.Plotly
    wrapped_index = index % len(palette)
    return palette[wrapped_index]
|
|
|
def plot_tree(sim_matrix, models, families, colors, alpha_names=None, alpha_markers=None, alpha_edges=None):
    """
    Plot a neighbour-joining phylogenetic tree based on a similarity matrix.

    Similarities are converted to distances with ``-log(max(s, 1e-10))``. The
    clamp keeps the logarithm finite for zero or negative similarities.

    Parameters:
    - sim_matrix: square similarity matrix between models; only its lower
      triangle (diagonal included) is used.
    - models: model names, in matrix order.
    - families: family name of each model.
    - colors: family -> color mapping, passed to prepare_tree.
    - alpha_names, alpha_markers, alpha_edges: opacities passed to
      draw_graphviz.

    Returns:
    - fig: Plotly figure object with the phylogenetic tree
    """
    dist_matrix = -np.log(np.maximum(sim_matrix, 1e-10))

    # Biopython's _DistanceMatrix type-checks its arguments: names must be a
    # list of str and the matrix a lower-triangular list of numeric lists.
    # Coerce to a plain list and Python floats so tuple / numpy inputs are
    # accepted.
    lower_triangle = [
        [float(dist_matrix[i][j]) for j in range(i + 1)]
        for i in range(len(dist_matrix))
    ]
    df = _DistanceMatrix(names=list(models), matrix=lower_triangle)

    # Neighbour joining works directly on the distance matrix; no
    # DistanceCalculator is needed.
    NJTree = DistanceTreeConstructor().nj(df)
    NJTree.ladderize(reverse=False)

    # prepare_tree comes from the phylogeny module; presumably it attaches the
    # color / family attributes that draw_graphviz reads. Confirm there.
    prepare_tree(NJTree, models, families, colors)

    return draw_graphviz(
        NJTree,
        node_size=15,
        edge_width=6,
        alpha_names=alpha_names,
        alpha_markers=alpha_markers,
        alpha_edges=alpha_edges,
    )
|
|
|
def update_tree_fig(fig, model_names, model_search=None, alpha_names=None, alpha_markers=None, alpha_edges=None):
    """Refresh opacities and the search highlight of a figure from plot_tree.

    Regular node traces are named '_node'. The highlighted node's trace is
    renamed 'node' so the next call can find it and restore it to the
    regular style.

    Parameters:
    - fig: figure returned by plot_tree (modified in place).
    - model_names: model labels; model_search is highlighted only if present.
    - model_search: label of the node to highlight, or None.
    - alpha_names: label opacity; None means fully opaque.
    - alpha_markers, alpha_edges: marker / edge opacity.

    Returns:
    - fig
    """
    # 'rgba(0,0,0,None)' is not a valid Plotly color, so None maps to opaque.
    text_alpha = 1 if alpha_names is None else alpha_names
    label_color = f'rgba(0,0,0,{text_alpha})'

    # Restore the previously highlighted node first, to the style plot_tree /
    # draw_graphviz give regular nodes (size-15 marker, width-2 outline,
    # size-8 label). Doing it before the opacity update means this node also
    # receives the current alpha_markers.
    fig.update_traces(
        selector=dict(name='node'),
        marker=dict(size=15, line=dict(color=None, width=2)),
        textfont=dict(color=label_color, size=8, family="Arial Black"),
        name='_node',
    )

    fig.update_traces(
        selector=dict(name='_node'),
        marker=dict(opacity=alpha_markers),
        textfont=dict(color=label_color),
    )

    fig.update_traces(
        selector=dict(name='_edge'),
        opacity=alpha_edges,
    )

    if model_search in model_names:
        # Each node trace carries its label as a scalar `text`, so the
        # selector matches exactly the searched node.
        fig.update_traces(
            selector=dict(name='_node', text=model_search),
            marker=dict(size=22, line=dict(color='red', width=4)),
            textfont=dict(color='red', size=16, family="Arial Black"),
            name='node',
        )

    return fig