First version gradio
Browse files- app.py +396 -0
- constants.py +8 -0
- family_table.json +1 -0
- inputs/math.json +1 -0
- llm_run.py +48 -0
- loading.py +132 -0
- packages.txt +2 -0
- phylogeny.py +114 -0
- plotting.py +522 -0
- requirements.txt +15 -0
- tools.py +80 -0
app.py
ADDED
@@ -0,0 +1,396 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gradio as gr
|
2 |
+
import os
|
3 |
+
import numpy as np
|
4 |
+
import ujson as json
|
5 |
+
|
6 |
+
from loading import load_data, save_git
|
7 |
+
from tools import compute_ordered_matrix
|
8 |
+
from plotting import plot_sim_matrix_fig, plot_umap_fig, plot_tree, update_sim_matrix_fig, update_umap_fig, update_tree_fig
|
9 |
+
from llm_run import download_llm_to_cache, load_model, llm_run
|
10 |
+
|
11 |
+
def reload_figures():
|
12 |
+
global MODEL_SEARCHED_X, MODEL_SEARCHED_Y, ALPHA_EDGES, ALPHA_NAMES, ALPHA_MARKERS, FIGS, ORDERED_MODEL_NAMES
|
13 |
+
fig1 = update_sim_matrix_fig(FIGS['fig1'], ORDERED_MODEL_NAMES, model_search_x=MODEL_SEARCHED_X, model_search_y=MODEL_SEARCHED_Y)
|
14 |
+
fig2 = update_umap_fig(FIGS['fig2'], DIST_MATRIX, MODEL_NAMES, FAMILIES, COLORS, model_search_x=MODEL_SEARCHED_X, alpha_edges=ALPHA_EDGES['fig2'], alpha_names=ALPHA_NAMES['fig2'], alpha_markers=ALPHA_MARKERS['fig2'])
|
15 |
+
fig4 = update_tree_fig(FIGS['fig4'], MODEL_NAMES, model_search=MODEL_SEARCHED_X, alpha_edges=ALPHA_EDGES['fig4'], alpha_names=ALPHA_NAMES['fig4'], alpha_markers=ALPHA_MARKERS['fig4'])
|
16 |
+
return [fig1,fig2,fig4]
|
17 |
+
|
18 |
+
def search_bar_changeX(value):
|
19 |
+
global MODEL_SEARCHED_X
|
20 |
+
MODEL_SEARCHED_X = value
|
21 |
+
return reload_figures()
|
22 |
+
|
23 |
+
def search_bar_changeY(value):
|
24 |
+
global MODEL_SEARCHED_Y
|
25 |
+
MODEL_SEARCHED_Y = value
|
26 |
+
return reload_figures()
|
27 |
+
|
28 |
+
def slider_changeAlphaMarkers(value,key):
|
29 |
+
global ALPHA_MARKERS
|
30 |
+
ALPHA_MARKERS[key] = value
|
31 |
+
return reload_figures()
|
32 |
+
|
33 |
+
def slider_changeAlphaNames(value,key):
|
34 |
+
global ALPHA_NAMES
|
35 |
+
ALPHA_NAMES[key] = value
|
36 |
+
return reload_figures()
|
37 |
+
|
38 |
+
def slider_changeAlphaEdges(value,key):
|
39 |
+
global ALPHA_EDGES
|
40 |
+
ALPHA_EDGES[key] = value
|
41 |
+
return reload_figures()
|
42 |
+
|
43 |
+
def search_bar_gr(model_names,slider=True,double_search=False,key=None):
|
44 |
+
global MODEL_SEARCHED_X,MODEL_SEARCHED_Y,ALPHA_EDGES,ALPHA_NAMES, ALPHA_MARKERS
|
45 |
+
#col1,col2 = gr.Row([0.2,0.8])
|
46 |
+
ret = []
|
47 |
+
with gr.Column(scale=1) as col1:
|
48 |
+
with gr.Group():
|
49 |
+
if MODEL_SEARCHED_X is None:
|
50 |
+
index = 0
|
51 |
+
else:
|
52 |
+
index = model_names.index(MODEL_SEARCHED_X)
|
53 |
+
ms_x = gr.Dropdown(label='Search'+(' X' if double_search else ''),choices=model_names,value=model_names[index],key='model_search_x_'+key,interactive=True)
|
54 |
+
#set MODEL_SEARCH_X
|
55 |
+
ret.append(ms_x)
|
56 |
+
if double_search:
|
57 |
+
if MODEL_SEARCHED_Y is None:
|
58 |
+
index = 0
|
59 |
+
else:
|
60 |
+
index = model_names.index(MODEL_SEARCHED_Y)
|
61 |
+
ms_y = gr.Dropdown(label='Search Y',choices=model_names,value=model_names[index],key='model_search_y_'+key,interactive=True)
|
62 |
+
ret.append(ms_y)
|
63 |
+
if slider:
|
64 |
+
with gr.Group():
|
65 |
+
values = np.arange(0, 1.05,0.05)
|
66 |
+
#truncate values to the 100th
|
67 |
+
values = np.round(values,2)
|
68 |
+
alpha_edges = gr.Slider(label='Alpha Edges',
|
69 |
+
minimum=0,
|
70 |
+
maximum=1,
|
71 |
+
step=0.05,
|
72 |
+
value=ALPHA_EDGES[key],
|
73 |
+
key='alpha_edges_'+key,
|
74 |
+
interactive=True)
|
75 |
+
|
76 |
+
values = np.arange(0, 1.05,0.05)
|
77 |
+
#truncate values to the 100th
|
78 |
+
values = np.round(values,2)
|
79 |
+
alpha_names = gr.Slider(label='Alpha Names',
|
80 |
+
minimum=0,
|
81 |
+
maximum=1,
|
82 |
+
step=0.05,
|
83 |
+
value=ALPHA_NAMES[key],
|
84 |
+
key='alpha_names_'+key,
|
85 |
+
interactive=True)
|
86 |
+
|
87 |
+
values = np.arange(0, 1.05,0.05)
|
88 |
+
#truncate values to the 100th
|
89 |
+
values = np.round(values,2)
|
90 |
+
alpha_markers = gr.Slider(label='Alpha Markers',
|
91 |
+
minimum=0,
|
92 |
+
maximum=1,
|
93 |
+
step=0.05,
|
94 |
+
value=ALPHA_MARKERS[key],
|
95 |
+
key='alpha_markers_'+key,
|
96 |
+
interactive=True)
|
97 |
+
ret.append(alpha_edges)
|
98 |
+
ret.append(alpha_names)
|
99 |
+
ret.append(alpha_markers)
|
100 |
+
col2 = gr.Column(scale=5)
|
101 |
+
ret.insert(0,col2)
|
102 |
+
return ret
|
103 |
+
|
104 |
+
import spaces
|
105 |
+
@spaces.GPU(duration=300)
|
106 |
+
def _run(path,genes,N,progress_bar):
|
107 |
+
#Load the model
|
108 |
+
progress_bar(0.20, desc="Loading Model...",total=100)
|
109 |
+
try:
|
110 |
+
model,tokenizer = load_model(path)
|
111 |
+
except ValueError as e:
|
112 |
+
print(f"Error loading model '{path}': {e}")
|
113 |
+
gr.Warning("Model couldn't load. This space currently only works with AutoModelForCausalLM models. Please check the model architecture and try again.")
|
114 |
+
return None
|
115 |
+
except OSError as e:
|
116 |
+
print(f"Error loading model '{path}': {e}")
|
117 |
+
gr.Warning("Model doesn't seem to exist on the HuggingFace Hub. Please check the model name and try again.")
|
118 |
+
return None
|
119 |
+
except RuntimeError as e:
|
120 |
+
if 'out of memory' in str(e):
|
121 |
+
print(f"Error loading model '{path}': {e}")
|
122 |
+
gr.Warning("Loading the model triggered an out of memory error. It may be too big for the GPU (80Go RAM). Please try again with a smaller model.")
|
123 |
+
return None
|
124 |
+
else:
|
125 |
+
print(f"Error loading model '{path}': {e}")
|
126 |
+
gr.Warning("Model couldn't be loaded. Please check the logs or report an issue.")
|
127 |
+
return None
|
128 |
+
except Exception as e:
|
129 |
+
print(f"Error loading model '{path}': {e}")
|
130 |
+
gr.Warning("Model couldn't be loaded. Please check logs or report an issue.")
|
131 |
+
return None
|
132 |
+
progress_bar(0.25, desc="Generating data...",total=100)
|
133 |
+
for i,output in enumerate(llm_run(model,tokenizer,genes,N)):
|
134 |
+
progress_bar(0.25 + i*(70/len(genes))/100, desc=f"Generating data... {i+1}/{len(genes)}",total=100)
|
135 |
+
return output
|
136 |
+
|
137 |
+
def run(path,progress_bar):
|
138 |
+
global DEFAULT_FAMILY_NAME, PHYLOLM_N
|
139 |
+
family = DEFAULT_FAMILY_NAME
|
140 |
+
N = PHYLOLM_N
|
141 |
+
#Loading bar
|
142 |
+
progress_bar(0, desc="Downloading model...",total=100)
|
143 |
+
try:
|
144 |
+
# Download the model to cache
|
145 |
+
if download_llm_to_cache(path) is None:
|
146 |
+
gr.Warning("Model not found on Hugging Face Hub. Please check the model name and try again.")
|
147 |
+
return None
|
148 |
+
except OSError as e:
|
149 |
+
print(f"Error downloading model: {e}")
|
150 |
+
gr.Warning("Model not found on Hugging Face Hub. Please check the model name and try again.")
|
151 |
+
return None
|
152 |
+
|
153 |
+
# Load the model
|
154 |
+
progress_bar(0.10, desc="Loading contexts...",total=100)
|
155 |
+
|
156 |
+
with open('inputs/math.json', 'r') as f:
|
157 |
+
genes = json.load(f)
|
158 |
+
|
159 |
+
# Load the model and run
|
160 |
+
progress_bar(0.15, desc="Waiting for GPU...",total=100)
|
161 |
+
|
162 |
+
try:
|
163 |
+
output = _run(path,genes,N,progress_bar)
|
164 |
+
if output is None:
|
165 |
+
return None
|
166 |
+
except Exception as e:
|
167 |
+
print(f"Error running model: {e}")
|
168 |
+
gr.Warning("Something unexpected happened during the run or the loading of the model. Please check the logs or report an issue.")
|
169 |
+
return None
|
170 |
+
|
171 |
+
progress_bar(0.95, desc="Saving data ...",total=100)
|
172 |
+
|
173 |
+
alleles = [[compl[j]['generated_text'][len(gene):][:4] for j in range(len(compl))] for gene,compl in zip(genes,output)]
|
174 |
+
save_git(alleles,genes,path,family)
|
175 |
+
|
176 |
+
progress_bar(1, desc="Done!",total=100)
|
177 |
+
|
178 |
+
|
179 |
+
def prepare_run(model_name,progress_bar=gr.Progress()):
|
180 |
+
global MODEL_SEARCHED_X,MODEL_NAMES
|
181 |
+
if model_name in MODEL_NAMES:
|
182 |
+
gr.Warning('Model already exists in the database.')
|
183 |
+
MODEL_SEARCHED_X = model_name
|
184 |
+
reload_figures()
|
185 |
+
return
|
186 |
+
run(model_name,progress_bar)
|
187 |
+
|
188 |
+
def reload_env():
|
189 |
+
global SIM_MAT_SEARCH_X, SIM_MAT_SEARCH_Y, VIZ_SEARCH, TREE_SEARCH
|
190 |
+
global MODEL_NAMES, FAMILIES, COLORS, SIM_MATRIX, DIST_MATRIX
|
191 |
+
global FIGS, FIGS_OBJECTS
|
192 |
+
|
193 |
+
# Load models for the dropdown
|
194 |
+
data, model_names, families, sim_matrix, colors = load_data()
|
195 |
+
|
196 |
+
sim_matrix_safe = np.where(sim_matrix == 0, np.finfo(np.float64).eps, sim_matrix)
|
197 |
+
dist_matrix = -np.log(sim_matrix_safe)
|
198 |
+
|
199 |
+
#Set globals
|
200 |
+
MODEL_NAMES = model_names
|
201 |
+
FAMILIES = families
|
202 |
+
COLORS = colors
|
203 |
+
SIM_MATRIX = sim_matrix
|
204 |
+
DIST_MATRIX = dist_matrix
|
205 |
+
|
206 |
+
#Update Figs
|
207 |
+
ordered_sim_matrix, ordered_model_names = compute_ordered_matrix(sim_matrix,dist_matrix, model_names)
|
208 |
+
ORDERED_MODEL_NAMES = ordered_model_names
|
209 |
+
FIGS['fig1'] = plot_sim_matrix_fig(ordered_sim_matrix, ordered_model_names, families, colors)
|
210 |
+
FIGS['fig2'] = plot_umap_fig(dist_matrix, sim_matrix, model_names, families, colors,
|
211 |
+
alpha_edges=ALPHA_EDGES['fig2'],alpha_names=ALPHA_NAMES['fig2'],alpha_markers=ALPHA_MARKERS['fig2'])
|
212 |
+
FIGS['fig4'] = plot_tree(sim_matrix, model_names, families, colors,alpha_edges=ALPHA_EDGES['fig4'],alpha_names=ALPHA_NAMES['fig4'],alpha_markers=ALPHA_MARKERS['fig4'])
|
213 |
+
|
214 |
+
#Update search bars
|
215 |
+
sim_mat_search_x = gr.Dropdown(label='Search X',choices=model_names,value=model_names[0],key='model_search_x_fig1',interactive=True)
|
216 |
+
sim_mat_search_y = gr.Dropdown(label='Search Y',choices=model_names,value=model_names[0],key='model_search_y_fig1',interactive=True)
|
217 |
+
viz_search = gr.Dropdown(label='Search',choices=model_names,value=model_names[0],key='model_search_fig2',interactive=True)
|
218 |
+
tree_search = gr.Dropdown(label='Search',choices=model_names,value=model_names[0],key='model_search_fig4',interactive=True)
|
219 |
+
|
220 |
+
return FIGS['fig1'], FIGS['fig2'], FIGS['fig4'], sim_mat_search_x, sim_mat_search_y, viz_search, tree_search
|
221 |
+
|
222 |
+
|
223 |
+
|
224 |
+
# Load environment variables
|
225 |
+
|
226 |
+
USERNAME = os.environ['GITHUB_USERNAME']
|
227 |
+
TOKEN = os.environ['GITHUB_TOKEN']
|
228 |
+
MAIL = os.environ['GITHUB_MAIL']
|
229 |
+
|
230 |
+
MODEL_SEARCHED_X = None
|
231 |
+
MODEL_SEARCHED_Y = None
|
232 |
+
ALPHA_EDGES = {'fig2':0.05, 'fig3':0.05,'fig4':1.0}
|
233 |
+
ALPHA_NAMES = {'fig2':0.0, 'fig3':0.0,'fig4':0.0}
|
234 |
+
ALPHA_MARKERS = {'fig2':0.8, 'fig3':0.8,'fig4':1.0}
|
235 |
+
|
236 |
+
FIGS = {'fig1':None,'fig2':None,'fig3':None,'fig4':None}
|
237 |
+
FIGS_OBJECTS = [None,None,None]
|
238 |
+
MODEL_NAMES = None
|
239 |
+
FAMILIES = None
|
240 |
+
COLORS = None
|
241 |
+
ORDERED_MODEL_NAMES = None
|
242 |
+
SIM_MATRIX = None
|
243 |
+
DIST_MATRIX = None
|
244 |
+
|
245 |
+
DEFAULT_FAMILY_NAME = '?'
|
246 |
+
PHYLOLM_N = 32
|
247 |
+
|
248 |
+
SIM_MAT_SEARCH_X = None
|
249 |
+
SIM_MAT_SEARCH_Y = None
|
250 |
+
VIZ_SEARCH = None
|
251 |
+
TREE_SEARCH = None
|
252 |
+
|
253 |
+
# Build the Gradio interface
|
254 |
+
with gr.Blocks(title="PhyloLM", theme=gr.themes.Default()) as demo:
|
255 |
+
gr.Markdown("# PhyloLM: Phylogenetic Mapping of Language Models")
|
256 |
+
|
257 |
+
gr.Markdown(
|
258 |
+
"Welcome to PhyloLM ([paper](https://arxiv.org/abs/2404.04671) - [code](https://github.com/Nicolas-Yax/PhyloLM)) — a tool for comparing language models based on their **behavioral similarity**, inspired by methods from comparative genomics. "
|
259 |
+
"Instead of architecture or weights, we use output behavior on diagnostic prompts as a behavioral fingerprint to compute a distance metric, akin to how biologists compare species using genetic data. This makes it possible to draw a unique map of all LLMs (various architectures, gated and non gated, ...)."
|
260 |
+
"The goal of this space is to create a collaborative space where everyone can visualize these maps and extend them with models of their choice. "
|
261 |
+
)
|
262 |
+
|
263 |
+
gr.Markdown("## Explore Maps of Models")
|
264 |
+
|
265 |
+
gr.Markdown(
|
266 |
+
"This interactive space allows users to explore model similarities through four types of visualizations:\n"
|
267 |
+
"- A similarity matrix (values range from 0 = dissimilar to 1 = highly similar). \n"
|
268 |
+
"- 2D and 3D scatter plots representing how close or far from each other LLMs are (plotted using UMAP). \n"
|
269 |
+
"- A tree to visualize distances between models (distance from leaf A to leaf B in the tree is similar to the distance between the two models)\n\n"
|
270 |
+
)
|
271 |
+
|
272 |
+
# Load models for the dropdown
|
273 |
+
data, model_names, families, sim_matrix, colors = load_data()
|
274 |
+
|
275 |
+
sim_matrix_safe = np.where(sim_matrix == 0, np.finfo(np.float64).eps, sim_matrix)
|
276 |
+
dist_matrix = -np.log(sim_matrix_safe)
|
277 |
+
|
278 |
+
#Set globals
|
279 |
+
MODEL_NAMES = model_names
|
280 |
+
FAMILIES = families
|
281 |
+
COLORS = colors
|
282 |
+
SIM_MATRIX = sim_matrix
|
283 |
+
DIST_MATRIX = dist_matrix
|
284 |
+
|
285 |
+
# Create the tabs
|
286 |
+
tab_state = gr.State(value="Similarity Matrix") # Default tab
|
287 |
+
tabs = gr.Tabs(["Similarity Matrix", "2D Visualization","Tree Visualization"])
|
288 |
+
with tabs:
|
289 |
+
with gr.TabItem("Similarity Matrix"):
|
290 |
+
# Similarity matrix visualization
|
291 |
+
with gr.Row():
|
292 |
+
col2,sim_mat_search_x,sim_mat_search_y = search_bar_gr(model_names,slider=False,double_search=True,key='fig1')
|
293 |
+
with col2:
|
294 |
+
ordered_sim_matrix, ordered_model_names = compute_ordered_matrix(sim_matrix,dist_matrix, model_names)
|
295 |
+
fig = plot_sim_matrix_fig(ordered_sim_matrix, ordered_model_names, families, colors)
|
296 |
+
sim_matrix_output = gr.Plot(fig,label="Similarity Matrix")
|
297 |
+
FIGS['fig1'] = fig
|
298 |
+
ORDERED_MODEL_NAMES = ordered_model_names
|
299 |
+
FIGS_OBJECTS[0] = sim_matrix_output
|
300 |
+
with gr.TabItem("2D Visualization"):
|
301 |
+
# 2D visualization
|
302 |
+
with gr.Row():
|
303 |
+
col2,viz_search,viz_alpha_edge,viz_alpha_name,viz_alpha_marker = search_bar_gr(model_names,slider=True,double_search=False,key='fig2')
|
304 |
+
with col2:
|
305 |
+
fig = plot_umap_fig(dist_matrix, sim_matrix, model_names, families, colors,
|
306 |
+
alpha_edges=ALPHA_EDGES['fig2'],alpha_names=ALPHA_NAMES['fig2'],alpha_markers=ALPHA_MARKERS['fig2'])
|
307 |
+
plot_output = gr.Plot(fig,label="2D Visualization")
|
308 |
+
FIGS['fig2'] = fig
|
309 |
+
FIGS_OBJECTS[1] = plot_output
|
310 |
+
with gr.TabItem("Tree Visualization"):
|
311 |
+
# Tree visualization
|
312 |
+
with gr.Row():
|
313 |
+
col2,tree_search,tree_alpha_edge,tree_alpha_name,tree_alpha_marker = search_bar_gr(model_names,slider=True,double_search=False,key='fig4')
|
314 |
+
with col2:
|
315 |
+
fig = plot_tree(sim_matrix, model_names, families, colors,alpha_edges=ALPHA_EDGES['fig4'],alpha_names=ALPHA_NAMES['fig4'],alpha_markers=ALPHA_MARKERS['fig4'])
|
316 |
+
tree_output = gr.Plot(fig,label="Tree Visualization")
|
317 |
+
FIGS['fig4'] = fig
|
318 |
+
FIGS_OBJECTS[2] = tree_output
|
319 |
+
|
320 |
+
|
321 |
+
# Submit model section
|
322 |
+
gr.Markdown("## Submitting a Model")
|
323 |
+
|
324 |
+
gr.Markdown(
|
325 |
+
"You may contribute new models to this collaborative space using compute resources. "
|
326 |
+
"Once processed, the model will be compared to existing ones, and its results added to a shared public database. "
|
327 |
+
"Model families (e.g., LLaMA, OPT, Mistral) are extracted from Hugging Face model cards and used only for visualization (e.g., coloring plots); they are **not** involved in the computation of similarity."
|
328 |
+
)
|
329 |
+
|
330 |
+
gr.Markdown(
|
331 |
+
"**To add a new model:**\n"
|
332 |
+
"1. Enter the name of a model hosted on Hugging Face (e.g., `'mistralai/Mistral-7B-Instruct-v0.3'`).\n"
|
333 |
+
"2. Click on the **Run PhyloLM** button.\n"
|
334 |
+
"- If the model has already been processed, you'll be notified and no new run will start.\n"
|
335 |
+
"- If it hasn't been processed, it will be downloaded and be evaluated.\n\n"
|
336 |
+
"⚠️ Be careful when submitting large LLMs (typically >15B parameters) as they may exceed the GPU RAM or the time limit, leading to failed runs."
|
337 |
+
)
|
338 |
+
|
339 |
+
with gr.Group():
|
340 |
+
model_input = gr.Textbox(label="Model", interactive=True)
|
341 |
+
submit_btn = gr.Button("Run PhyloLM", variant="primary")
|
342 |
+
|
343 |
+
|
344 |
+
# Disclaimer and citation
|
345 |
+
gr.Markdown("## Disclaimer")
|
346 |
+
gr.Markdown(
|
347 |
+
"This is a research prototype and may contain bugs or limitations. "
|
348 |
+
"All computed data are public and hosted on [GitHub](https://github.com/PhyloLM/Data). "
|
349 |
+
"If you'd like to contribute additional models — especially for gated or large models that cannot be processed via the web interface — "
|
350 |
+
"you are welcome to submit a pull request to the repository cited above. "
|
351 |
+
"All results are computed on the 'Math' set of genes used in the original paper."
|
352 |
+
)
|
353 |
+
|
354 |
+
gr.Markdown("## Citation")
|
355 |
+
gr.Markdown("If you find this project useful for your research, please consider citing the following paper:")
|
356 |
+
|
357 |
+
#bibtex
|
358 |
+
gr.Code('''@inproceedings{
|
359 |
+
yax2025phylolm,
|
360 |
+
title={Phylo{LM}: Inferring the Phylogeny of Large Language Models and Predicting their Performances in Benchmarks},
|
361 |
+
author={Nicolas Yax and Pierre-Yves Oudeyer and Stefano Palminteri},
|
362 |
+
booktitle={The Thirteenth International Conference on Learning Representations},
|
363 |
+
year={2025},
|
364 |
+
url={https://openreview.net/forum?id=rTQNGQxm4K}
|
365 |
+
}''',language=None)
|
366 |
+
|
367 |
+
# Change actions from search bars
|
368 |
+
sim_mat_search_x.change(fn=search_bar_changeX, inputs=sim_mat_search_x, outputs=FIGS_OBJECTS)
|
369 |
+
sim_mat_search_y.change(fn=search_bar_changeY, inputs=sim_mat_search_y, outputs=FIGS_OBJECTS)
|
370 |
+
|
371 |
+
viz_search.change(fn=search_bar_changeX, inputs=viz_search, outputs=FIGS_OBJECTS)
|
372 |
+
|
373 |
+
tree_search.change(fn=search_bar_changeX, inputs=tree_search, outputs=FIGS_OBJECTS)
|
374 |
+
|
375 |
+
# Change actions from sliders
|
376 |
+
viz_alpha_edge.change(fn=lambda x : slider_changeAlphaEdges(x,'fig2'), inputs=viz_alpha_edge, outputs=FIGS_OBJECTS)
|
377 |
+
viz_alpha_name.change(fn=lambda x : slider_changeAlphaNames(x,'fig2'), inputs=viz_alpha_name, outputs=FIGS_OBJECTS)
|
378 |
+
viz_alpha_marker.change(fn=lambda x : slider_changeAlphaMarkers(x,'fig2'), inputs=viz_alpha_marker, outputs=FIGS_OBJECTS)
|
379 |
+
|
380 |
+
tree_alpha_edge.change(fn=lambda x : slider_changeAlphaEdges(x,'fig4'), inputs=tree_alpha_edge, outputs=FIGS_OBJECTS)
|
381 |
+
tree_alpha_name.change(fn=lambda x : slider_changeAlphaNames(x,'fig4'), inputs=tree_alpha_name, outputs=FIGS_OBJECTS)
|
382 |
+
tree_alpha_marker.change(fn=lambda x : slider_changeAlphaMarkers(x,'fig4'), inputs=tree_alpha_marker, outputs=FIGS_OBJECTS)
|
383 |
+
|
384 |
+
# Run PhyloLM button
|
385 |
+
submit_btn.click(fn=prepare_run, inputs=[model_input], outputs=[model_input]).then(fn=reload_env, inputs=[], outputs=FIGS_OBJECTS+ [sim_mat_search_x, sim_mat_search_y, viz_search, tree_search])
|
386 |
+
|
387 |
+
#Set more globals
|
388 |
+
SIM_MAT_SEARCH_X = sim_mat_search_x
|
389 |
+
SIM_MAT_SEARCH_Y = sim_mat_search_y
|
390 |
+
VIZ_SEARCH = viz_search
|
391 |
+
TREE_SEARCH = tree_search
|
392 |
+
|
393 |
+
|
394 |
+
|
395 |
+
if __name__ == "__main__":
|
396 |
+
demo.launch()
|
constants.py
ADDED
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from matplotlib import colors as mcolors
|
2 |
+
|
3 |
+
UNKNOWN_COLOR = 'gray'
|
4 |
+
UNKNOWN_COLOR_RGB = mcolors.to_rgb(UNKNOWN_COLOR)
|
5 |
+
UNKNOWN_COLOR_RGB = tuple([int(255 * c) for c in UNKNOWN_COLOR_RGB])
|
6 |
+
DEFAULT_COLOR = 'black'
|
7 |
+
DEFAULT_COLOR_RGB = mcolors.to_rgb(DEFAULT_COLOR)
|
8 |
+
DEFAULT_COLOR_RGB = tuple([int(255 * c) for c in DEFAULT_COLOR_RGB])
|
family_table.json
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
{"JosephusCheung\/Guanaco":"Llama","Intel\/neural-chat-7b-v3":"Mistral","Intel\/neural-chat-7b-v3-1":"Mistral","teknium\/OpenHermes-2-Mistral-7B":"Mistral","teknium\/OpenHermes-2.5-Mistral-7B":"Mistral","teknium\/OpenHermes-13B":"Llama","teknium\/OpenHermes-7B":"Llama","mistralai\/Mistral-7B-Instruct-v0.2":"Mistral","mistralai\/Mixtral-8x7B-Instruct-v0.1":"Mistral","mistralai\/Mistral-7B-v0.1":"Mistral","mistralai\/Mixtral-8x7B-v0.1":"Mistral","mistralai\/Mistral-7B-Instruct-v0.1":"Mistral","chavinlo\/alpaca-native":"Llama","CausalLM\/14B":"Qwen","CausalLM\/7B":"Qwen","bigscience\/bloom-3b":"Bloom","bigscience\/bloom-7b1":"Bloom","bigscience\/bloomz-3b":"Bloom","bigscience\/bloom":"Bloom","bigscience\/bloomz-7b1":"Bloom","berkeley-nest\/Starling-LM-7B-alpha":"Mistral","EleutherAI\/pythia-6.9b":"Pythia","EleutherAI\/pythia-1.4b":"Pythia","EleutherAI\/pythia-12b":"Pythia","EleutherAI\/pythia-2.8b":"Pythia","EleutherAI\/pythia-410m":"Pythia","EleutherAI\/pythia-70m":"Pythia","EleutherAI\/pythia-160m":"Pythia","roneneldan\/TinyStories-1M":"TinyStories","lmsys\/vicuna-13b-v1.5":"Llama","lmsys\/vicuna-7b-v1.1":"Llama","lmsys\/vicuna-13b-v1.3":"Llama","lmsys\/vicuna-7b-v1.5":"Llama","lmsys\/vicuna-13b-v1.1":"Llama","lmsys\/vicuna-7b-v1.3":"Llama","google\/gemma-7b":"Gemma","google\/codegemma-7b":"Gemma","google\/gemma-2b-it":"Gemma","google\/codegemma-2b":"Gemma","google\/codegemma-7b-it":"Gemma","google\/gemma-1.1-7b-it":"Gemma","google\/gemma-2b":"Gemma","google\/gemma-1.1-2b-it":"Gemma","google\/gemma-7b-it":"Gemma","microsoft\/Orca-2-13b":"Llama","microsoft\/Orca-2-7b":"Llama","Imran1\/MedChat3.5":"Mistral","tenyx\/TenyxChat-7B-v1":"Mistral","databricks\/dolly-v2-7b":"Pythia","databricks\/dolly-v2-3b":"Pythia","databricks\/dolly-v2-12b":"Pythia","Qwen\/Qwen-1_8B":"Qwen","Qwen\/Qwen1.5-0.5B":"Qwen","Qwen\/Qwen1.5-72B-Chat":"Qwen","Qwen\/Qwen1.5-7B-Chat":"Qwen","Qwen\/Qwen1.5-2B-Chat":"Qwen","Qwen\/Qwen1.5-7B":"Qwen","Qwen\/Qwen1.5-72B":"Qwen","Qwen\/Qwen1.5-32B-Chat":"Qwen","Qwen\/Qwen1.5-4B-Chat":"Qwen","Qwen\/Qwen1.5-1.8B":"Qwen","Qwen\/Qwen1.5-14B-Chat":"Qwen","Qwen\/Qwen1.5-0.5B-Chat":"Qwen","Qwen\/Qwen-72B":"Qwen","Qwen\/Qwen-14B":"Qwen","Qwen\/Qwen1.5-4B":"Qwen","Qwen\/Qwen1.5-14B":"Qwen","Qwen\/Qwen-7B":"Qwen","Qwen\/Qwen1.5-32B":"Qwen","OpenAssistant\/oasst-sft-4-pythia-12b-epoch-3.5":"Pythia","mlabonne\/NeuralHermes-2.5-Mistral-7B":"Mistral","facebook\/opt-6.7b":"OPT","facebook\/opt-125m":"OPT","facebook\/opt-66b":"OPT","facebook\/opt-13b":"OPT","facebook\/opt-1.3b":"OPT","facebook\/opt-30b":"OPT","facebook\/opt-350m":"OPT","facebook\/opt-2.7b":"OPT","HuggingFaceH4\/zephyr-7b-beta":"Mistral","HuggingFaceH4\/zephyr-7b-alpha":"Mistral","openchat\/openchat_v2":"Llama","openchat\/openchat_v2_w":"Llama","openchat\/openchat_v3.2":"Llama","openchat\/openchat_v3.1":"Llama","openchat\/openchat_3.5":"Mistral","openchat\/openchat_v3.2_super":"Llama","TigerResearch\/tigerbot-13b-base-v2":"Llama","TigerResearch\/tigerbot-7b-base-v2":"Bloom","TigerResearch\/tigerbot-7b-chat":"Llama","TigerResearch\/tigerbot-13b-chat-v1":"Llama","TigerResearch\/tigerbot-7b-sft-v1":"Bloom","TigerResearch\/tigerbot-7b-sft-v2":"Bloom","TigerResearch\/tigerbot-13b-chat-v2":"Llama","TigerResearch\/tigerbot-13b-chat-v3":"Llama","TigerResearch\/tigerbot-13b-chat-v4":"Llama","TigerResearch\/tigerbot-7b-base-v1":"Bloom","TigerResearch\/tigerbot-13b-base-v1":"Llama","TigerResearch\/tigerbot-7b-base":"Llama","fxmarty\/tiny-llama-fast-tokenizer":"fxmarty","project-baize\/baize-v2-7b":"Llama","huggyllama\/llama-7b":"Llama","huggyllama\/llama-13b":"Llama","Arc53\/docsgpt-7b-mistral":"Mistral","meta-llama\/Llama-2-7b-hf":"Llama","meta-llama\/Llama-2-13b-hf":"Llama","meta-llama\/Llama-2-7b":"Llama"}
|
inputs/math.json
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
["# In observing a Tetrahedron...", "# Strong and Weak Form Solution \u2013 FEA\n\nPartial Differential Equations \u2013 PDE is called \u201cstro", "# Loci Browse Articles\n\nDisplaying 41 - 50 of 323\n\nThis article describes methods for cr", "Paul's Online Notes\nHome / Calculus I / Review / Trig Functions\nShow Mobile Notice Show All Not", "# Recursive formula for joint moments in free probability\n\nSuppose $\\mathfrak{A}$ is an algebra ", "# Is the oxygen molecule $O_2$ fermion or boson?\n\nI ask something makes me confused.\n", "# Python Indices of numbers greater than K\n\nIn this tutorial, we are going to ", "To find total cost, to the nearest cent, to cool the house for this 24-hour p", "# Propositional Logic", "Modeling the train reservation kata -", "[texhax] environment ", "# stability \u2026 boring old and simple stability\n\n[xxx@yyy~]$uptime 15:07:51 up 505 days, 47 min", "By accessing our 180 Days of Math for Sixth Grade Answers Key Day 72 regularly, students can get b", "E. Square Root of Permutation\ntime limit per test\n2 seconds\nmemory l", "# Skills\n\nThe Discworld skill model is large and complex. It is broken up into eight branches.\n", "# Just Some Division\n\nNumber Theory Level 1\n\n$$N$$ is a positive integer such that $$10", "MathSciNet bibliographic data MR343259 55B25 (55G35 57E15) Matumoto, Takao Eq", "InTech uses cookies to offer you the best online experience. By continuing to use our sit", "## Train Tracks\n\nConsider a segment of", "## Stream: general\n\n### Topic: detecti", "This vignette discusses data.table\u2019s reference semant", "A cell made up of two hydrogen electrodes. The positive electrode is in cont", "# Better way to calculate coordinates in Tikz?\n\nI am great fan of pgf and tikz in general to pro", "# Physics (Version 8.4", "# Abc conjecture\n\n\ufeff\nAbc conjecture\n\nThe abc conjecture is a conjecture", "# No offense intended, but\u2026\n\nI have", "Select Board & Class\n\nAreas Related to Circles\n\nTo und", "Determine whether $\\sum_{n=1}^\\infty \\frac{\\sin^2 n}{n^2}$ conv", "# Calculating Derivative \u2013 Third root \u2013 Exercise 1106\n\nExercise\n\nFind the derivati", "# BE Thesis\n\nSenast inlagda poster:\n2018-09-06\n", "# All Questions\n\n1,360 questions\nFilter by\nSorted by\nTagged with\n1answer\n113 views\n\n###", "# Ignatius and the P", "## anonymous one year ago What is the fifth term of the sequence who", "size - Maple Help\n\nMTM\n\n size\n", "http://en.wikipedia.org/wiki/Taylor_series\n\n## Taylor series in several var", "# I Using determinant to find constraints on equation\n\nTags:\n1. Jan 15, 2017\n\n### TheDemx27\n", "# Decimal Numbers\n\nDecimal numbers are similar to fractions in basic principle.\n\nThey are often u", "Sales Toll Free No: 1-800-481-2338\n\n# How to Divide Monomial by a Non Zero Constant?\n\nTopPolynomial", "# Transformation of continuous uniform distribution\n", "CTF Team at the University of British Columbia\n\n# [corCTF 2021] smogofwar\n\n25 Aug 2021 ", "0\nResearch Papers\n\n# Analytic and Geometric ", "# Difference between revisions of \"1984 AIME Problems/Problem 11\"\n\n## Problem\n\nA ", "# Image Mosaicking\u00b6\n\n#", "# When two smaller atoms combine into a larger atom what has occurred?", "# Bessel functions\n\n(diff) \u2190 Older revision | La", "## WeBWorK Main Forum\n\n### Why Giga newton is not ", "### Homes\n\nThere are ", "## Return to Question\n\n2 deleted 256 characters in body\n\nA complex manifold $X$ is said ", "# Algebra Examples\n\nFind Pivot Positions and Pivot Co", "# The Physics Behind the American Death Tr", "Data\n\n1. Title: Webs and $q$-Howe dualities in types $\\mathbf{B}\\mathbf{C}\\math", "# All Questions\n\n1,524 questions\n", "# Pydon'ts\n\n## Improve your Python programming skills\n\n### Start here.\n\n294\n###", "# Math Help - 2-norm of a ma", "[texhax] \\mid Description\n\n", "# Q : 3\u00a0\u00a0\u00a0\u00a0\u00a0\u00a0\u00a0Find the area of the region bou", "as.epidata {EpiILM} R Documentation\n\nDiscrete Tim", "# Basic Matrix Row Operations Tips\n\nThere are four basic op", "# Nets within nets from the Grothendieck\u00a0const", "Free Version\nEasy\n\n# Same Base Exponential Equations\n\n", "# Preserving Plant Diversity\n\n## Location:", "# Crash on HWNDComponentPeer::destroyWindowCallback\n\nIt just feels like you\u2019r", "# SNR\u00b6\n\nclass gammapy.astro.source.SNR(e_sn='1e51 erg', theta=<Quantity 0.1>, n_ISM=<Qua", "# Forecasting Pseudo Random Numbers Using Deep Learning\n\nPublisher: IEEE\n\n", "# Including both transformed and original data (untransformed) in a multivariable linear regressio", "# Technical Fridays\n\nFriday, September 1, 2017\n\nIn 1973, the Un", "# Representation of simple groups\n\nLet $G$ be a finite simple group, prove th", "# The Impact of Meteorological Facto", "Advertisement Remove all ads\n\n# If\u00a0y=log[x+sqrt(x^2+a^2)]\u00a0show that\u00a0(x^2+a^2)(d^2y)/(dx^2)+xdy", "2020 | Book\n\n# Principles of Data M", "# Thread: Math problem Parenthesis & PEMDAS\n\n1. ## Math problem Parenthesis & PEMDAS\n\nhi i n", "# Math Help - Physics problem; Find expression for veloci", "# Base class for finite field element", "# 0.1 Review exercises (ch 3-13) \u00a0(Page 11/12)\n\n Page 11 / 12\n\n130. Out", "All Rights Reserved. However, a compass needle will not be steady in the magneti", "# How is rest mass $m_0$ in $E=m_0c^2$ related to mass $m$ in $F=ma$?\n\nA p", "# All Questions\n\n7 views\n\n### Curve fitting of a list\n\nI have list obtained using a", "1 $\\begingroup$ Close", "# Annual income of A a", "## Seminars and Colloquia by Series\n\n### Geometric Equations for Matroid Varieties\n\nSeries\nSIAM S", "# How to prove that $C=\\{x: Ax\\le ", "# Phase locked Loop in Demodulation\n\nCan someone please clarify how a PLL works and how it can th", "# What is forward difference interpolation?\n\n## What is forward difference interpolati", "# Math Help - matlab code hel", "Question\n\n# What is the speed of the sound in a perfectly rigid rod?\n\nOpen in App\nSoluti", "# If $\\Sigma \\models \\phi$, then for some finite $\\Delta \\subset\\Sigma$, $\\Delta \\models", "# Need some help on a proof!\n\n1. Dec 12, 2004\n\n### MathematicalMatt\n\nHowdy, I just stumbled o", "## Calculus (3rd Edition)\n\n$f(x)=[x]$ has a jump discontinuity at $x=n$.\nThe function ", "Orthogonal functions\n\nOrthogonality\n\nTwo fu", "## MacKenzie's fundamental principle of greenkeeping\n\n##### 05 May 2017\n\nI taught two semi", "2 added 246 characters in body\n\n\"If you are walking between two policemen goin", "# \u3010BZOJ 4571\u3011[SCOI2016] \u7f8e\u5473\n\n#include<bits/stdc++.h>\n#define LL long long\nus", "# Math Help - confidence interval help!!!!\n\n1. ## confide", "# Alternatives\n\n## Summing Squares: Finding or Proving a Formula\n", "Thank you for visiting nature", "# How do I keep a string of text together without", "# Q6. In a model of a ship, the mast", "# How does the Taylor Series converge at all points f", "# Revealed preference\n\nRevealed preference theory, pioneered by American ec", "Previous issue \u00b7\u00a0 Next issue \u00b7\u00a0", "Bits - Maple Programming Help\n\nHome : Su", "# Difference between revisions of \"", "## [POJ2411]Mondriaan\\'s Dream\n\n \u6210", "# Period ofWeeks() method in Java\n\nJava 8Object Oriented ProgrammingProgramming\n\nThe", "Debugging graphs with ease: experimental Visual Studio plugin\n\nRevis", "k-mer overrepresentation of WGS Illumina reads\n0\n0\nEntering edit mode\n3.8 years a", "# Hydrometeorology Research Group\n\nIn\u00a0[5]:\nfrom IPython.display import HTM", "# \ud83d\udd35\u26aa\ud83d\udd34 Bioinspired tough gel sheath for robust and versatile surface functionalization \u2013 Content Mar", "# Recurring Decimal To Fraction Cal", "# Integral related to the modified Bessel function\n\nI would like to solve t", "# Viewpoint: Particle Decays Point to an Arrow of Time\n\n\u2022 Michael Ze", "CiteULike is a free online bibliography manager. Register a", "# Social Media, Misinformation, and Voting Decisions\n\nWorking Pap", "# Test Video\n\nAller \u00e0 : Navigation, rechercher\n\nThis is an", "Enterprise Multiples\n\n.\n\nEV/SALES\n\nEqu", "PLANET Discussion: Daeridune\n\nDiscussion in 'The Manaverse W", "# Tag Info\n\n## Hot answers tagged total\n\n6\n\nDon't subtrac", "# Welcome to grmpy\u2019s documentation!\u00b6\n\ngrmpy is an open-source packag", "# Solve the linear equatio", "## February 25, 2008\n\n### A Questio", "# Concept and expression of a real function\n\n## Concept o", "# Optimization Week 9: Convex conjugate (Fenche", "# Inversions of Insertion Sort and Bubble Sort\n\nAn array with bubblesort time $$\\Theta(n)$$ is noth", "## Simple power series\n\nHey all,\n\nDoes anybody know if $x^{\\alpha}$ can be written in terms of an ", "# HiggsTools\n\nStephen Jones (MPI Muni", "# Mathematics 1010 online\n\n## Complex Numbers\n\nRecall how we built the numbe", "# string.replace.regex\n\nSynt", "# Math Help - Odd Integers\n\n1. ## Odd Integers\n\nWhat is the product if the largest of three cons"]
|
llm_run.py
ADDED
@@ -0,0 +1,48 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import transformers
|
2 |
+
from huggingface_hub import snapshot_download,constants
|
3 |
+
|
4 |
+
def download_llm_to_cache(model_name, revision="main", cache_dir=None):
|
5 |
+
"""
|
6 |
+
Download an LLM from the Hugging Face Hub to the cache without loading it into memory.
|
7 |
+
|
8 |
+
Args:
|
9 |
+
model_name (str): The name of the model on Hugging Face Hub (e.g., "meta-llama/Llama-2-7b-hf")
|
10 |
+
revision (str, optional): The specific model version to use. Defaults to "main".
|
11 |
+
cache_dir (str, optional): The cache directory to use. If None, uses the default HF cache directory.
|
12 |
+
|
13 |
+
Returns:
|
14 |
+
str: Path to the model in cache
|
15 |
+
"""
|
16 |
+
# Get default cache dir if not specified
|
17 |
+
if cache_dir is None:
|
18 |
+
cache_dir = constants.HUGGINGFACE_HUB_CACHE
|
19 |
+
|
20 |
+
try:
|
21 |
+
# Download model to cache without loading into memory
|
22 |
+
cached_path = snapshot_download(
|
23 |
+
repo_id=model_name,
|
24 |
+
revision=revision,
|
25 |
+
cache_dir=cache_dir,
|
26 |
+
local_files_only=False # Set to True if you want to check local cache only
|
27 |
+
)
|
28 |
+
|
29 |
+
print(f"Model '{model_name}' is available in cache at: {cached_path}")
|
30 |
+
return cached_path
|
31 |
+
|
32 |
+
except Exception as e:
|
33 |
+
print(f"Error downloading model '{model_name}': {e}")
|
34 |
+
return None
|
35 |
+
|
36 |
+
def load_model(path,cache_dir=None):
|
37 |
+
model = transformers.AutoModelForCausalLM.from_pretrained(path,cache_dir=cache_dir,device_map='auto')
|
38 |
+
tokenizer = transformers.AutoTokenizer.from_pretrained(path,cache_dir=cache_dir,device_map='auto')
|
39 |
+
return model,tokenizer
|
40 |
+
|
41 |
+
def llm_run(model,tokenizer,genes,N):
|
42 |
+
generate = transformers.pipeline('text-generation',model=model, tokenizer=tokenizer,device_map='auto')
|
43 |
+
output = []
|
44 |
+
for i,gene in enumerate(genes):
|
45 |
+
out = generate([gene], min_new_tokens=4, max_new_tokens=4, do_sample=True, num_return_sequences=N)
|
46 |
+
output.append(out[0])
|
47 |
+
yield output
|
48 |
+
return output
|
loading.py
ADDED
@@ -0,0 +1,132 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import ujson as json
|
3 |
+
import pygit2
|
4 |
+
|
5 |
+
from phylogeny import compute_all_P, compute_sim_matrix
|
6 |
+
from plotting import get_color, UNKNOWN_COLOR, DEFAULT_COLOR
|
7 |
+
# ------------------------------------------------------------------------------------------------
|
8 |
+
#
|
9 |
+
# Loading data
|
10 |
+
#
|
11 |
+
# ------------------------------------------------------------------------------------------------
|
12 |
+
def load_data():
|
13 |
+
global UNKNOWN_COLOR, DEFAULT_COLOR, MODEL_SEARCHED_X
|
14 |
+
data, model_names,families = load_git()
|
15 |
+
if data is None:
|
16 |
+
return
|
17 |
+
|
18 |
+
#Rename families if needed
|
19 |
+
with open('family_table.json','r') as f:
|
20 |
+
rename_table = json.load(f)
|
21 |
+
|
22 |
+
for i in range(len(model_names)):
|
23 |
+
try:
|
24 |
+
families[i] = rename_table[model_names[i]]
|
25 |
+
except KeyError:
|
26 |
+
pass
|
27 |
+
|
28 |
+
all_P = compute_all_P(data, model_names)
|
29 |
+
sim_matrix = compute_sim_matrix(model_names, all_P)
|
30 |
+
|
31 |
+
k = list(all_P.keys())[0]
|
32 |
+
|
33 |
+
unknown_color = UNKNOWN_COLOR
|
34 |
+
|
35 |
+
unique_families = list(set([f for f in families]))
|
36 |
+
colors = {}
|
37 |
+
idx = 0
|
38 |
+
for i, family in enumerate(unique_families):
|
39 |
+
color = get_color(idx)
|
40 |
+
idx += 1
|
41 |
+
while color == unknown_color: # Avoid using the unknown color for a family
|
42 |
+
color = get_color(idx)
|
43 |
+
idx += 1
|
44 |
+
colors[family] = color
|
45 |
+
|
46 |
+
colors['?'] = unknown_color # Assign the unknown color to the unknown family
|
47 |
+
|
48 |
+
return data, model_names, families, sim_matrix, colors
|
49 |
+
|
50 |
+
def load_git():
|
51 |
+
cred = pygit2.UserPass(os.environ['GITHUB_USERNAME'], os.environ['GITHUB_TOKEN'])
|
52 |
+
if os.path.exists('Data'):
|
53 |
+
repo = pygit2.Repository('Data')
|
54 |
+
remote = repo.remotes['origin'] # Use named reference instead of index
|
55 |
+
remote.fetch()
|
56 |
+
|
57 |
+
# Get the current branch name
|
58 |
+
branch_name = repo.head.shorthand
|
59 |
+
|
60 |
+
# Find the reference to the remote branch
|
61 |
+
remote_ref_name = f'refs/remotes/origin/{branch_name}'
|
62 |
+
|
63 |
+
# Merge the changes into the current branch
|
64 |
+
remote_commit = repo.lookup_reference(remote_ref_name).target
|
65 |
+
|
66 |
+
else:
|
67 |
+
repo = pygit2.clone_repository('https://github.com/PhyloLM/Data', './Data', bare=False, callbacks=GitHubRemoteCallbacks(os.environ['GITHUB_USERNAME'], os.environ['GITHUB_TOKEN']))
|
68 |
+
|
69 |
+
data_array = []
|
70 |
+
model_names = []
|
71 |
+
families = []
|
72 |
+
for foname in os.listdir('Data/math'):
|
73 |
+
#check if it is a directory
|
74 |
+
if not os.path.isdir(os.path.join('Data/math',foname)):
|
75 |
+
continue
|
76 |
+
for fname in os.listdir('Data/math/'+foname):
|
77 |
+
if not fname.endswith('.json'):
|
78 |
+
continue
|
79 |
+
with open(os.path.join('Data/math',foname,fname),'r') as f:
|
80 |
+
d = json.load(f)
|
81 |
+
families.append(d['family'])
|
82 |
+
model_names.append(foname+'/'+fname[:-5])
|
83 |
+
data_array.append(d['alleles'])
|
84 |
+
|
85 |
+
if data_array == []:
|
86 |
+
return None,[],[]
|
87 |
+
return data_array,model_names,families
|
88 |
+
|
89 |
+
# ------------------------------------------------------------------------------------------------
|
90 |
+
#
|
91 |
+
# Git functions
|
92 |
+
#
|
93 |
+
# ------------------------------------------------------------------------------------------------
|
94 |
+
|
95 |
+
class GitHubRemoteCallbacks(pygit2.RemoteCallbacks):
|
96 |
+
def __init__(self, username, token):
|
97 |
+
self.username = username
|
98 |
+
self.token = token
|
99 |
+
super().__init__()
|
100 |
+
|
101 |
+
def credentials(self, url, username_from_url, allowed_types):
|
102 |
+
return pygit2.UserPass(self.username, self.token)
|
103 |
+
|
104 |
+
# ------------------------------------------------------------------------------------------------
|
105 |
+
#
|
106 |
+
# Saving data
|
107 |
+
#
|
108 |
+
# ------------------------------------------------------------------------------------------------
|
109 |
+
|
110 |
+
def save_git(alleles,genes,model,family):
|
111 |
+
repo = pygit2.Repository('Data')
|
112 |
+
remo = repo.remotes['origin']
|
113 |
+
|
114 |
+
d = {'family':family,'alleles':alleles}
|
115 |
+
model_name = model
|
116 |
+
data_path = f'math/{model_name}.json'
|
117 |
+
path = os.path.join('Data',data_path)
|
118 |
+
#create the file folder path
|
119 |
+
if not os.path.exists(os.path.dirname(path)):
|
120 |
+
os.makedirs(os.path.dirname(path), exist_ok=True)
|
121 |
+
#Open the file
|
122 |
+
with open(path,'w') as f:
|
123 |
+
json.dump(d,f)
|
124 |
+
|
125 |
+
repo.index.add(data_path)
|
126 |
+
repo.index.write()
|
127 |
+
reference='HEAD'
|
128 |
+
tree = repo.index.write_tree()
|
129 |
+
author = pygit2.Signature(os.environ['GITHUB_USERNAME'], os.environ['GITHUB_MAIL'])
|
130 |
+
commiter = pygit2.Signature(os.environ['GITHUB_USERNAME'], os.environ['GITHUB_MAIL'])
|
131 |
+
oid = repo.create_commit(reference, author, commiter, f'Add data for model {model}', tree, [repo.head.target])
|
132 |
+
remo.push(['refs/heads/main'],callbacks=GitHubRemoteCallbacks(os.environ['GITHUB_USERNAME'],os.environ['GITHUB_TOKEN']))
|
packages.txt
ADDED
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
1 |
+
graphviz
|
2 |
+
graphviz-dev
|
phylogeny.py
ADDED
@@ -0,0 +1,114 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
|
3 |
+
from constants import UNKNOWN_COLOR, DEFAULT_COLOR, UNKNOWN_COLOR_RGB, DEFAULT_COLOR_RGB
|
4 |
+
|
5 |
+
def compute_P(alleles):
|
6 |
+
'''Compute the population matrix P(allele|gene) from the [alleles] given in input'''
|
7 |
+
P = []
|
8 |
+
# Process each gene position
|
9 |
+
for gene_alleles in alleles:
|
10 |
+
# Use Counter for more efficient counting
|
11 |
+
unique_alleles, counts = np.unique(gene_alleles, return_counts=True)
|
12 |
+
# Create frequency dictionary directly
|
13 |
+
d = dict(zip(unique_alleles, counts / len(gene_alleles)))
|
14 |
+
P.append(d)
|
15 |
+
return P
|
16 |
+
|
17 |
+
def compute_all_P(data, models):
|
18 |
+
'''Compute all population matrices from a given list of [models] on the data'''
|
19 |
+
all_P = {}
|
20 |
+
for mi, m in enumerate(models):
|
21 |
+
alleles = data[mi]
|
22 |
+
P = compute_P(alleles)
|
23 |
+
all_P[m] = P
|
24 |
+
return all_P
|
25 |
+
|
26 |
+
def compute_sim_matrix(models,all_P):
|
27 |
+
'''Compute the entire similarity matrix in one go'''
|
28 |
+
n_models = len(models)
|
29 |
+
n_genes = len(all_P[models[0]])
|
30 |
+
|
31 |
+
# Initialize matrices to store numerator and denominator terms
|
32 |
+
total_numerator = np.zeros((n_models, n_models))
|
33 |
+
left_denominators = np.zeros(n_models)
|
34 |
+
right_denominators = np.zeros(n_models)
|
35 |
+
|
36 |
+
# Process each gene position
|
37 |
+
for k in range(n_genes):
|
38 |
+
# Collect all alleles for this gene position
|
39 |
+
all_alleles = set()
|
40 |
+
for m in models:
|
41 |
+
all_alleles.update(all_P[m][k].keys())
|
42 |
+
all_alleles = list(all_alleles)
|
43 |
+
|
44 |
+
# Create frequency vectors for each model
|
45 |
+
freq_matrix = np.zeros((n_models, len(all_alleles)))
|
46 |
+
for i, m in enumerate(models):
|
47 |
+
for j, allele in enumerate(all_alleles):
|
48 |
+
if allele in all_P[m][k]:
|
49 |
+
freq_matrix[i, j] = all_P[m][k][allele]
|
50 |
+
|
51 |
+
# Update numerator: dot product of frequency vectors
|
52 |
+
total_numerator += np.dot(freq_matrix, freq_matrix.T)
|
53 |
+
|
54 |
+
# Update denominators: sum of squared frequencies
|
55 |
+
squared_sums = np.sum(freq_matrix**2, axis=1)
|
56 |
+
left_denominators += squared_sums
|
57 |
+
right_denominators += squared_sums
|
58 |
+
|
59 |
+
# Calculate final similarity matrix
|
60 |
+
denominator_matrix = np.sqrt(np.outer(left_denominators, right_denominators))
|
61 |
+
sim_matrix = total_numerator / denominator_matrix
|
62 |
+
|
63 |
+
return sim_matrix
|
64 |
+
|
65 |
+
def prepare_tree(tree, model_names, origins, colors):
|
66 |
+
"""Prepare and color the phylogenetic tree based on model families."""
|
67 |
+
# Remove inner node names and color leaf nodes
|
68 |
+
for clade in tree.find_clades():
|
69 |
+
if clade.name and (clade.name.startswith('Inner') or clade.name.startswith('Clade')):
|
70 |
+
#clade.name = None
|
71 |
+
pass
|
72 |
+
if clade.name == None or clade.name not in model_names:
|
73 |
+
clade.family = None
|
74 |
+
clade.flag = False
|
75 |
+
continue
|
76 |
+
# Color the clades if it is a leaf
|
77 |
+
index = model_names.index(clade.name)
|
78 |
+
clade.family = origins[index]
|
79 |
+
clade.flag = True
|
80 |
+
|
81 |
+
# Propagate colors up the tree when all children have the same color
|
82 |
+
all_clades = list(tree.find_clades())
|
83 |
+
clades = [clade for clade in all_clades if clade.flag is False]
|
84 |
+
|
85 |
+
# Iterate this process until there are no more clades to color
|
86 |
+
i = 0
|
87 |
+
while clades:
|
88 |
+
clade = clades[i % len(clades)]
|
89 |
+
children_families = [c.family for c in clade.clades]
|
90 |
+
children_families_set = set(children_families)
|
91 |
+
children_flags = [c.flag for c in clade.clades]
|
92 |
+
children_flags_set = set(children_flags)
|
93 |
+
|
94 |
+
if len(children_families_set) == 1: # If all children have the same color : this clade is locked with the same color
|
95 |
+
clade.family = children_families[0]
|
96 |
+
clade.flag = True
|
97 |
+
del clades[i % len(clades)]
|
98 |
+
elif len(children_families_set) == 2 and '?' in children_families_set: # If children have different colors and one is unknown : this clade is locked with the known color
|
99 |
+
clade.family = [f for f in children_families_set if f != '?'][0]
|
100 |
+
clade.flag = True
|
101 |
+
del clades[i % len(clades)]
|
102 |
+
elif len(children_flags_set) == 1: # If children have different colors : this clade is locked with no color
|
103 |
+
clade.flag = True
|
104 |
+
del clades[i % len(clades)]
|
105 |
+
elif clade.flag == True: #Sholdn't happen
|
106 |
+
del clades[i % len(clades)]
|
107 |
+
i += 1
|
108 |
+
|
109 |
+
#Set color associated with family to each clade
|
110 |
+
for clade in all_clades:
|
111 |
+
if clade.family is None:
|
112 |
+
clade.color = UNKNOWN_COLOR
|
113 |
+
else:
|
114 |
+
clade.color = colors[clade.family]
|
plotting.py
ADDED
@@ -0,0 +1,522 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import networkx as nx
|
2 |
+
import numpy as np
|
3 |
+
from Bio.Phylo import to_networkx
|
4 |
+
from networkx.drawing.nx_agraph import graphviz_layout
|
5 |
+
import plotly.graph_objects as go
|
6 |
+
import plotly.express as px
|
7 |
+
from Bio.Phylo.TreeConstruction import DistanceTreeConstructor, DistanceCalculator, _DistanceMatrix
|
8 |
+
|
9 |
+
from tools import compute_ordered_matrix,compute_umap
|
10 |
+
from phylogeny import prepare_tree
|
11 |
+
from constants import UNKNOWN_COLOR, DEFAULT_COLOR, UNKNOWN_COLOR_RGB, DEFAULT_COLOR_RGB
|
12 |
+
|
13 |
+
# ------------------------------------------------------------------------------------------------
|
14 |
+
#
|
15 |
+
# Sim Matrix Plotting
|
16 |
+
#
|
17 |
+
# ------------------------------------------------------------------------------------------------
|
18 |
+
|
19 |
+
def plot_sim_matrix_fig(ordered_sim_matrix,ordered_model_names,families,colors):
|
20 |
+
fig = px.imshow(
|
21 |
+
ordered_sim_matrix,
|
22 |
+
x=ordered_model_names,
|
23 |
+
y=ordered_model_names,
|
24 |
+
zmin=0, zmax=1,
|
25 |
+
color_continuous_scale='gray',
|
26 |
+
)
|
27 |
+
|
28 |
+
fig.update_layout(coloraxis_colorbar=dict(title='Similarity'),
|
29 |
+
margin=dict(l=0, r=0, t=0, b=0),
|
30 |
+
autosize=True,
|
31 |
+
)
|
32 |
+
|
33 |
+
fig.update_traces(
|
34 |
+
colorbar=dict(
|
35 |
+
thickness=20,
|
36 |
+
len=0.75,
|
37 |
+
xanchor="right",
|
38 |
+
x=1.02
|
39 |
+
)
|
40 |
+
)
|
41 |
+
|
42 |
+
fig.update_xaxes(showticklabels=False, showgrid=False, zeroline=False,constrain='range')
|
43 |
+
fig.update_yaxes(showticklabels=False, showgrid=False, zeroline=False,constrain='range')
|
44 |
+
|
45 |
+
#Create rectangles for highlighted models
|
46 |
+
rectX = go.layout.Shape(
|
47 |
+
type="rect",
|
48 |
+
xref="x", yref="y",
|
49 |
+
x0=0, y0=0,
|
50 |
+
x1=0, y1=0,
|
51 |
+
line=dict(color="red", width=1),
|
52 |
+
fillcolor="rgba(0,0,0,0)",
|
53 |
+
name='rectX',
|
54 |
+
opacity=0,
|
55 |
+
)
|
56 |
+
fig.add_shape(rectX)
|
57 |
+
rectY = go.layout.Shape(
|
58 |
+
type="rect",
|
59 |
+
xref="x", yref="y",
|
60 |
+
x0=0, y0=0,
|
61 |
+
x1=0, y1=0,
|
62 |
+
line=dict(color="red", width=1),
|
63 |
+
fillcolor="rgba(0,0,0,0)",
|
64 |
+
name='rectY',
|
65 |
+
opacity=0,
|
66 |
+
)
|
67 |
+
fig.add_shape(rectY)
|
68 |
+
|
69 |
+
return fig
|
70 |
+
|
71 |
+
def update_sim_matrix_fig(fig, ordered_model_names, model_search_x=None, model_search_y=None):
|
72 |
+
if model_search_x in ordered_model_names:
|
73 |
+
idx_x = ordered_model_names.index(model_search_x)
|
74 |
+
fig.update_shapes(
|
75 |
+
selector=dict(name='rectX'),
|
76 |
+
x0=idx_x-0.5, y0=-0.5,
|
77 |
+
x1=idx_x+0.5, y1=len(ordered_model_names)-0.5,
|
78 |
+
opacity=0.7,
|
79 |
+
)
|
80 |
+
else:
|
81 |
+
fig.update_shapes(
|
82 |
+
selector=dict(name='rectX'),
|
83 |
+
opacity=0
|
84 |
+
)
|
85 |
+
if model_search_y in ordered_model_names:
|
86 |
+
idx_y = ordered_model_names.index(model_search_y)
|
87 |
+
fig.update_shapes(
|
88 |
+
selector=dict(name='rectY'),
|
89 |
+
x0=-0.5, y0=idx_y-0.5,
|
90 |
+
x1=len(ordered_model_names)-0.5, y1=idx_y+0.5,
|
91 |
+
opacity=0.7,
|
92 |
+
)
|
93 |
+
else:
|
94 |
+
fig.update_shapes(
|
95 |
+
selector=dict(name='rectY'),
|
96 |
+
opacity=0
|
97 |
+
)
|
98 |
+
return fig
|
99 |
+
|
100 |
+
# ------------------------------------------------------------------------------------------------
|
101 |
+
#
|
102 |
+
# 2D UMAP Plotting
|
103 |
+
#
|
104 |
+
# ------------------------------------------------------------------------------------------------
|
105 |
+
|
106 |
+
def alpha_scaling(val):
|
107 |
+
base = 0.35
|
108 |
+
return val**(1/(base+1/100))
|
109 |
+
|
110 |
+
def plot_umap_fig(dist_matrix, sim_matrix, model_names, families, colors,key='fig2',alpha_edges=None, alpha_names=None, alpha_markers=None):
|
111 |
+
embedding = compute_umap(dist_matrix,d=2)
|
112 |
+
|
113 |
+
fig = go.Figure()
|
114 |
+
|
115 |
+
#-- EDGES
|
116 |
+
# Calculate edge transparencies based on similarity
|
117 |
+
edges = []
|
118 |
+
for i in range(len(model_names)):
|
119 |
+
for j in range(i+1, len(model_names)): # Only process each pair once (i,j where i<j)
|
120 |
+
val = alpha_scaling(sim_matrix[i][j])
|
121 |
+
if val > 0.1:
|
122 |
+
edges.append((i, j, val, colors[families[i]]))
|
123 |
+
|
124 |
+
# Add all edges at once
|
125 |
+
for i, j, val, color in edges:
|
126 |
+
fig.add_trace(
|
127 |
+
go.Scatter(
|
128 |
+
x=[embedding[i,0], embedding[j,0]],
|
129 |
+
y=[embedding[i,1], embedding[j,1]],
|
130 |
+
mode='lines',
|
131 |
+
name='_edge',
|
132 |
+
line=dict(color=color, width=val),
|
133 |
+
opacity=alpha_edges,
|
134 |
+
showlegend=False,
|
135 |
+
hoverinfo='skip',
|
136 |
+
)
|
137 |
+
)
|
138 |
+
|
139 |
+
#-- NODES
|
140 |
+
marker_colors = [colors[f] for f in families]
|
141 |
+
fig.add_trace(
|
142 |
+
go.Scatter(
|
143 |
+
x=embedding[:,0],
|
144 |
+
y=embedding[:,1],
|
145 |
+
text=model_names,
|
146 |
+
mode='markers+text',
|
147 |
+
textposition='top center',
|
148 |
+
hoverinfo='text',
|
149 |
+
hoveron='points+fills',
|
150 |
+
showlegend=False,
|
151 |
+
name='_node',
|
152 |
+
marker=dict(
|
153 |
+
color=marker_colors,
|
154 |
+
size=8,
|
155 |
+
line_width=2,
|
156 |
+
opacity=alpha_markers,
|
157 |
+
),
|
158 |
+
textfont=dict(
|
159 |
+
color=f'rgba(0,0,0,{alpha_names})',
|
160 |
+
size=8,
|
161 |
+
family="Arial Black",
|
162 |
+
)
|
163 |
+
)
|
164 |
+
)
|
165 |
+
|
166 |
+
#-- LEGEND
|
167 |
+
legends = []
|
168 |
+
for f in set(families):
|
169 |
+
legends.append(
|
170 |
+
go.Scatter(
|
171 |
+
x=[None],
|
172 |
+
y=[None],
|
173 |
+
mode='markers',
|
174 |
+
marker=dict(
|
175 |
+
color=colors[f],
|
176 |
+
size=8,
|
177 |
+
line_width=2,
|
178 |
+
opacity=1
|
179 |
+
),
|
180 |
+
name=f,
|
181 |
+
|
182 |
+
)
|
183 |
+
)
|
184 |
+
fig.add_traces(legends)
|
185 |
+
|
186 |
+
#Add highlighted node
|
187 |
+
node = go.Scatter(
|
188 |
+
x=[0],
|
189 |
+
y=[0],
|
190 |
+
mode='markers+text',
|
191 |
+
textposition='top center',
|
192 |
+
textfont=dict(color='red', size=16, family="Arial Black"),
|
193 |
+
marker=dict(
|
194 |
+
color='red',
|
195 |
+
size=12,
|
196 |
+
symbol='circle',
|
197 |
+
line=dict(color='red', width=3)
|
198 |
+
),
|
199 |
+
showlegend=False,
|
200 |
+
name='node',
|
201 |
+
opacity=0,
|
202 |
+
)
|
203 |
+
fig.add_trace(node)
|
204 |
+
|
205 |
+
#Setup the layout
|
206 |
+
fig.update_layout(
|
207 |
+
margin=dict(l=0, r=0, t=0, b=0),
|
208 |
+
autosize=True,
|
209 |
+
)
|
210 |
+
|
211 |
+
fig.update_xaxes(showticklabels=False, showgrid=False, zeroline=False,constrain='range')
|
212 |
+
fig.update_yaxes(showticklabels=False, showgrid=False, zeroline=False,constrain='range')
|
213 |
+
|
214 |
+
return fig
|
215 |
+
|
216 |
+
def update_umap_fig(fig, dist_matrix, model_names, families, colors, model_search_x=None, alpha_names=None, alpha_markers=None, alpha_edges=None, key='fig2'):
|
217 |
+
#Update nodes
|
218 |
+
fig.update_traces(
|
219 |
+
selector=dict(name='_node'),
|
220 |
+
textfont=dict(
|
221 |
+
color=f'rgba(0,0,0,{alpha_names})',
|
222 |
+
),
|
223 |
+
marker=dict(
|
224 |
+
opacity=alpha_markers
|
225 |
+
),
|
226 |
+
)
|
227 |
+
|
228 |
+
#Update edges
|
229 |
+
fig.update_traces(
|
230 |
+
selector=dict(mode='lines'),
|
231 |
+
line=dict(width=1),
|
232 |
+
opacity=alpha_edges
|
233 |
+
)
|
234 |
+
|
235 |
+
#Update highlighted node
|
236 |
+
if model_search_x in model_names:
|
237 |
+
searched_idx = model_names.index(model_search_x)
|
238 |
+
embedding = compute_umap(dist_matrix,d=2) #Cached computation
|
239 |
+
fig.update_traces(
|
240 |
+
selector=dict(name='node'),
|
241 |
+
x=[embedding[searched_idx,0]],
|
242 |
+
y=[embedding[searched_idx,1]],
|
243 |
+
text=[model_search_x],
|
244 |
+
marker=dict(
|
245 |
+
color=colors[families[searched_idx]],
|
246 |
+
),
|
247 |
+
hovertext=model_search_x,
|
248 |
+
opacity=1
|
249 |
+
)
|
250 |
+
else:
|
251 |
+
fig.update_traces(
|
252 |
+
selector=dict(name='node'),
|
253 |
+
x=[0],
|
254 |
+
y=[0],
|
255 |
+
text=[''],
|
256 |
+
opacity=0
|
257 |
+
)
|
258 |
+
return fig
|
259 |
+
|
260 |
+
# ------------------------------------------------------------------------------------------------
|
261 |
+
#
|
262 |
+
# Phylogenetic Tree Plotting
|
263 |
+
#
|
264 |
+
# ------------------------------------------------------------------------------------------------
|
265 |
+
|
266 |
+
def draw_graphviz(tree, label_func=str, prog='twopi', args='',
|
267 |
+
node_size=15, edge_width=0.0, alpha_edges=None, alpha_names=None,alpha_markers=None, **kwargs):
|
268 |
+
#Display a tree or clade as a graph using Plotly, with layout from the graphviz engine.
|
269 |
+
|
270 |
+
global UNKNOWN_COLOR, DEFAULT_COLOR
|
271 |
+
# Convert the Bio.Phylo tree to a NetworkX graph
|
272 |
+
G = to_networkx(tree)
|
273 |
+
|
274 |
+
# Relabel nodes using integers while keeping original labels
|
275 |
+
Gi = nx.convert_node_labels_to_integers(G, label_attribute='label')
|
276 |
+
|
277 |
+
# Apply the Graphviz layout
|
278 |
+
pos = graphviz_layout(Gi, prog=prog, args=args)
|
279 |
+
|
280 |
+
# Prepare node labels for display
|
281 |
+
def get_label_mapping(G, selection):
|
282 |
+
for node, data in G.nodes(data=True):
|
283 |
+
if (selection is None) or (node in selection):
|
284 |
+
try:
|
285 |
+
label = label_func(data.get('label', node))
|
286 |
+
if label not in (None, node.__class__.__name__):
|
287 |
+
yield (node, label)
|
288 |
+
except (LookupError, AttributeError, ValueError):
|
289 |
+
pass
|
290 |
+
|
291 |
+
# Extract labels
|
292 |
+
labels = dict(get_label_mapping(Gi, None))
|
293 |
+
nodelist = list(labels.keys())
|
294 |
+
|
295 |
+
# Collect node colors and create edge traces
|
296 |
+
edge_traces = []
|
297 |
+
node_traces_by_family = {}
|
298 |
+
node_colors = {}
|
299 |
+
node_families = {}
|
300 |
+
|
301 |
+
# Track if we find the searched model and its position
|
302 |
+
searched_model_node = None
|
303 |
+
searched_model_pos = None
|
304 |
+
|
305 |
+
default_color = (0,0,0)
|
306 |
+
|
307 |
+
# Get colors and families for all nodes
|
308 |
+
for node in Gi.nodes():
|
309 |
+
node_data = Gi.nodes[node].get('label')
|
310 |
+
if hasattr(node_data, 'color'):
|
311 |
+
node_colors[node] = node_data.color.to_rgb() if not(node_data.color is None) else default_color
|
312 |
+
else:
|
313 |
+
node_colors[node] = default_color
|
314 |
+
node_colors[node] = f'rgb({node_colors[node][0]},{node_colors[node][1]},{node_colors[node][2]})'
|
315 |
+
|
316 |
+
if hasattr(node_data, 'family'):
|
317 |
+
node_families[node] = node_data.family
|
318 |
+
else:
|
319 |
+
node_families[node] = None
|
320 |
+
|
321 |
+
# Create edge traces
|
322 |
+
for edge in Gi.edges():
|
323 |
+
x0, y0 = pos[edge[0]]
|
324 |
+
x1, y1 = pos[edge[1]]
|
325 |
+
|
326 |
+
# Use the child node's color for the edge if available
|
327 |
+
edge_color = node_colors[edge[1]]
|
328 |
+
if list(edge_color) == list(UNKNOWN_COLOR_RGB): # Use the parent node's color for edge's color except if it's an unknown nodes
|
329 |
+
edge_color = tuple(DEFAULT_COLOR_RGB)
|
330 |
+
#edge_color = f'rgb({edge_color[0]},{edge_color[1]},{edge_color[2]})'
|
331 |
+
edge_trace = go.Scatter(
|
332 |
+
x=[x0, x1, None],
|
333 |
+
y=[y0, y1, None],
|
334 |
+
line=dict(width=edge_width, color=edge_color),
|
335 |
+
hoverinfo='none',
|
336 |
+
mode='lines',
|
337 |
+
showlegend=False,
|
338 |
+
name='_edge',
|
339 |
+
opacity=alpha_edges,
|
340 |
+
)
|
341 |
+
edge_traces.append(edge_trace)
|
342 |
+
|
343 |
+
# Create node traces
|
344 |
+
node_traces = []
|
345 |
+
for node in nodelist:
|
346 |
+
x,y = pos[node]
|
347 |
+
text = labels.get(node, None)
|
348 |
+
color = node_colors.get(node, None)
|
349 |
+
node_trace = go.Scatter(
|
350 |
+
x=[x],
|
351 |
+
y=[y],
|
352 |
+
text=text,
|
353 |
+
mode='markers+text',
|
354 |
+
textposition='top center',
|
355 |
+
hoverinfo='text',
|
356 |
+
showlegend=False,
|
357 |
+
name='_node',
|
358 |
+
marker=dict(
|
359 |
+
color=color,
|
360 |
+
size=node_size,
|
361 |
+
line_width=2,
|
362 |
+
opacity=alpha_markers,
|
363 |
+
),
|
364 |
+
textfont=dict(
|
365 |
+
color=f'rgba(0,0,0,{alpha_names})',
|
366 |
+
size=8,
|
367 |
+
family="Arial Black",
|
368 |
+
)
|
369 |
+
)
|
370 |
+
node_traces.append(node_trace)
|
371 |
+
|
372 |
+
# Get color dict
|
373 |
+
colors = {}
|
374 |
+
families = []
|
375 |
+
for node in node_families.keys():
|
376 |
+
family = node_families[node]
|
377 |
+
if family is not None:
|
378 |
+
families.append(family)
|
379 |
+
colors[family] = node_colors.get(node, DEFAULT_COLOR)
|
380 |
+
else:
|
381 |
+
colors[family] = DEFAULT_COLOR
|
382 |
+
|
383 |
+
families = set(families)
|
384 |
+
|
385 |
+
#Custom legend
|
386 |
+
legends = []
|
387 |
+
for f in families:
|
388 |
+
legends.append(
|
389 |
+
go.Scatter(
|
390 |
+
x=[None],
|
391 |
+
y=[None],
|
392 |
+
mode='markers',
|
393 |
+
marker=dict(
|
394 |
+
color=colors[f],
|
395 |
+
size=8,
|
396 |
+
line_width=2,
|
397 |
+
opacity=1
|
398 |
+
),
|
399 |
+
name=f,
|
400 |
+
|
401 |
+
)
|
402 |
+
)
|
403 |
+
|
404 |
+
# Create the figure
|
405 |
+
fig = go.Figure(
|
406 |
+
data=edge_traces + node_traces,
|
407 |
+
layout=go.Layout(
|
408 |
+
showlegend=True,
|
409 |
+
hovermode='closest',
|
410 |
+
margin=dict(b=1, l=1, r=1, t=1),
|
411 |
+
xaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
|
412 |
+
yaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
|
413 |
+
legend=dict(
|
414 |
+
yanchor="top",
|
415 |
+
y=0.99,
|
416 |
+
xanchor="right",
|
417 |
+
x=0.99
|
418 |
+
)
|
419 |
+
)
|
420 |
+
)
|
421 |
+
|
422 |
+
fig.add_traces(legends)
|
423 |
+
|
424 |
+
return fig
|
425 |
+
|
426 |
+
def get_color(index):
|
427 |
+
"""Get a color from plotly's qualitative color palette."""
|
428 |
+
colors = px.colors.qualitative.Plotly
|
429 |
+
return colors[index % len(colors)]
|
430 |
+
|
431 |
+
def plot_tree(sim_matrix, models, families,colors, alpha_names=None, alpha_markers=None, alpha_edges=None):
|
432 |
+
"""
|
433 |
+
Plot a phylogenetic tree based on a similarity matrix.
|
434 |
+
|
435 |
+
Parameters:
|
436 |
+
- sim_matrix: similarity matrix between models
|
437 |
+
- models: list of model names
|
438 |
+
- families: list of family names for each model
|
439 |
+
|
440 |
+
Returns:
|
441 |
+
- fig: Plotly figure object with the phylogenetic tree
|
442 |
+
"""
|
443 |
+
# Create color mapping for families
|
444 |
+
|
445 |
+
# Prepare the distance matrix
|
446 |
+
dist_matrix = -np.log(np.maximum(sim_matrix, 1e-10)) # Avoid log(0)
|
447 |
+
|
448 |
+
# Prepare the data for Bio.Phylo
|
449 |
+
low_triangle_kl_mean = [[dist_matrix[i][j] for j in range(i+1)] for i in range(len(dist_matrix))]
|
450 |
+
df = _DistanceMatrix(names=models, matrix=low_triangle_kl_mean)
|
451 |
+
|
452 |
+
# Setup Bio.Phylo
|
453 |
+
calculator = DistanceCalculator('identity')
|
454 |
+
constructor = DistanceTreeConstructor(calculator, 'nj')
|
455 |
+
|
456 |
+
# Build the tree
|
457 |
+
NJTree = constructor.nj(df)
|
458 |
+
NJTree.ladderize(reverse=False)
|
459 |
+
|
460 |
+
# Color the tree
|
461 |
+
prepare_tree(NJTree, models, families, colors)
|
462 |
+
|
463 |
+
# Generate the plotly figure
|
464 |
+
fig = draw_graphviz(NJTree, node_size=15, edge_width=6,alpha_names=alpha_names, alpha_markers=alpha_markers, alpha_edges=alpha_edges)
|
465 |
+
|
466 |
+
return fig
|
467 |
+
|
468 |
+
def update_tree_fig(fig, model_names, model_search=None,alpha_names=None, alpha_markers=None, alpha_edges=None):
|
469 |
+
#Update nodes
|
470 |
+
fig.update_traces(
|
471 |
+
selector=dict(name='_node'),
|
472 |
+
marker=dict(
|
473 |
+
opacity=alpha_markers,
|
474 |
+
),
|
475 |
+
textfont=dict(
|
476 |
+
color=f'rgba(0,0,0,{alpha_names})',
|
477 |
+
)
|
478 |
+
)
|
479 |
+
|
480 |
+
# Update edges
|
481 |
+
fig.update_traces(
|
482 |
+
selector=dict(name='_edge'),
|
483 |
+
opacity=alpha_edges,
|
484 |
+
)
|
485 |
+
|
486 |
+
for d in fig.data:
|
487 |
+
if d.name in ['_node','node']:
|
488 |
+
if d.text == 'mistralai/Mistral-7B-Instruct-v0.1':
|
489 |
+
print(d)
|
490 |
+
|
491 |
+
# Update highlighted node
|
492 |
+
fig.update_traces(
|
493 |
+
selector=dict(name='node'),
|
494 |
+
marker=dict(
|
495 |
+
size=15, # Bigger than normal nodes
|
496 |
+
line=None # Red border
|
497 |
+
),
|
498 |
+
textfont=dict(
|
499 |
+
color=f'rgba(0,0,0,{alpha_names})', size=16, family="Arial Black",
|
500 |
+
),
|
501 |
+
name='_node'
|
502 |
+
)
|
503 |
+
if model_search in model_names:
|
504 |
+
fig.update_traces(
|
505 |
+
selector=dict(name='_node',text=model_search),
|
506 |
+
marker=dict(
|
507 |
+
size=22, # Bigger than normal nodes
|
508 |
+
line=dict(color='red', width=4) # Red border
|
509 |
+
),
|
510 |
+
textfont=dict(
|
511 |
+
color='red', size=16, family="Arial Black",
|
512 |
+
),
|
513 |
+
name='node'
|
514 |
+
)
|
515 |
+
for d in fig.data:
|
516 |
+
if d.name in ['_node','node']:
|
517 |
+
if d.text == 'mistralai/Mistral-7B-Instruct-v0.1':
|
518 |
+
print(d)
|
519 |
+
else:
|
520 |
+
pass
|
521 |
+
|
522 |
+
return fig
|
requirements.txt
ADDED
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
matplotlib
|
2 |
+
networkx
|
3 |
+
numpy==1.23.0
|
4 |
+
biopython
|
5 |
+
plotly
|
6 |
+
scikit-learn
|
7 |
+
streamlit
|
8 |
+
transformers
|
9 |
+
torch
|
10 |
+
pygit2
|
11 |
+
fastcluster
|
12 |
+
pygraphviz
|
13 |
+
accelerate
|
14 |
+
umap-learn
|
15 |
+
ujson
|
tools.py
ADDED
@@ -0,0 +1,80 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
from scipy.spatial.distance import squareform
|
3 |
+
from fastcluster import linkage
|
4 |
+
import umap
|
5 |
+
|
6 |
+
# ------------------------------------------------------------------------------------------------
|
7 |
+
#
|
8 |
+
# Sim Matrix Ordering
|
9 |
+
#
|
10 |
+
# ------------------------------------------------------------------------------------------------
|
11 |
+
|
12 |
+
def seriation(Z,N,cur_index):
|
13 |
+
'''
|
14 |
+
input:
|
15 |
+
- Z is a hierarchical tree (dendrogram)
|
16 |
+
- N is the number of points given to the clustering process
|
17 |
+
- cur_index is the position in the tree for the recursive traversal
|
18 |
+
output:
|
19 |
+
- order implied by the hierarchical tree Z
|
20 |
+
|
21 |
+
seriation computes the order implied by a hierarchical tree (dendrogram)
|
22 |
+
'''
|
23 |
+
if cur_index < N:
|
24 |
+
return [cur_index]
|
25 |
+
else:
|
26 |
+
left = int(Z[cur_index-N,0])
|
27 |
+
right = int(Z[cur_index-N,1])
|
28 |
+
return (seriation(Z,N,left) + seriation(Z,N,right))
|
29 |
+
|
30 |
+
def compute_serial_matrix(dist_mat,method="ward"):
|
31 |
+
'''
|
32 |
+
input:
|
33 |
+
- dist_mat is a distance matrix
|
34 |
+
- method = ["ward","single","average","complete"]
|
35 |
+
output:
|
36 |
+
- seriated_dist is the input dist_mat,
|
37 |
+
but with re-ordered rows and columns
|
38 |
+
according to the seriation, i.e. the
|
39 |
+
order implied by the hierarchical tree
|
40 |
+
- res_order is the order implied by
|
41 |
+
the hierarhical tree
|
42 |
+
- res_linkage is the hierarhical tree (dendrogram)
|
43 |
+
|
44 |
+
compute_serial_matrix transforms a distance matrix into
|
45 |
+
a sorted distance matrix according to the order implied
|
46 |
+
by the hierarchical tree (dendrogram)
|
47 |
+
'''
|
48 |
+
N = len(dist_mat)
|
49 |
+
flat_dist_mat = squareform(dist_mat)
|
50 |
+
res_linkage = linkage(flat_dist_mat, method=method,preserve_input=True)
|
51 |
+
res_order = seriation(res_linkage, N, N + N-2)
|
52 |
+
seriated_dist = np.zeros((N,N))
|
53 |
+
a,b = np.triu_indices(N,k=1)
|
54 |
+
seriated_dist[a,b] = dist_mat[ [res_order[i] for i in a], [res_order[j] for j in b]]
|
55 |
+
seriated_dist[b,a] = seriated_dist[a,b]
|
56 |
+
|
57 |
+
return seriated_dist, res_order, res_linkage
|
58 |
+
|
59 |
+
def compute_ordered_matrix(sim_matrix,dist_matrix, model_names):
|
60 |
+
if len(sim_matrix) >= 2:
|
61 |
+
# Compute serial matrix (hierarchical clustering) for tab1
|
62 |
+
ordered_dist_matrix, order, Z = compute_serial_matrix(dist_matrix)
|
63 |
+
ordered_sim_matrix = sim_matrix[order][:, order]
|
64 |
+
ordered_model_names = [model_names[i] for i in order]
|
65 |
+
else:
|
66 |
+
ordered_sim_matrix = sim_matrix
|
67 |
+
ordered_model_names = model_names
|
68 |
+
|
69 |
+
return ordered_sim_matrix, ordered_model_names
|
70 |
+
|
71 |
+
|
72 |
+
# ------------------------------------------------------------------------------------------------
|
73 |
+
#
|
74 |
+
# UMAP computation
|
75 |
+
#
|
76 |
+
# ------------------------------------------------------------------------------------------------
|
77 |
+
|
78 |
+
def compute_umap(dist_matrix,d=2):
|
79 |
+
embedding = umap.UMAP(densmap=True,n_components=d, metric='precomputed',random_state=42).fit_transform(dist_matrix)
|
80 |
+
return embedding
|