Daetheys committed on
Commit
3d6ba31
·
1 Parent(s): 77fc2c1

First version gradio

Browse files
Files changed (11) hide show
  1. app.py +396 -0
  2. constants.py +8 -0
  3. family_table.json +1 -0
  4. inputs/math.json +1 -0
  5. llm_run.py +48 -0
  6. loading.py +132 -0
  7. packages.txt +2 -0
  8. phylogeny.py +114 -0
  9. plotting.py +522 -0
  10. requirements.txt +15 -0
  11. tools.py +80 -0
app.py ADDED
@@ -0,0 +1,396 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import os
3
+ import numpy as np
4
+ import ujson as json
5
+
6
+ from loading import load_data, save_git
7
+ from tools import compute_ordered_matrix
8
+ from plotting import plot_sim_matrix_fig, plot_umap_fig, plot_tree, update_sim_matrix_fig, update_umap_fig, update_tree_fig
9
+ from llm_run import download_llm_to_cache, load_model, llm_run
10
+
11
def reload_figures():
    """Re-render the cached figures from the current search and alpha state.

    Reads the module-level search selections, alpha settings and cached
    figures, and returns the updated figures as
    ``[similarity matrix, UMAP, tree]``.
    """
    def alphas(fig_key):
        # Keyword arguments shared by the UMAP and tree updaters.
        return {
            'alpha_edges': ALPHA_EDGES[fig_key],
            'alpha_names': ALPHA_NAMES[fig_key],
            'alpha_markers': ALPHA_MARKERS[fig_key],
        }

    matrix_fig = update_sim_matrix_fig(
        FIGS['fig1'],
        ORDERED_MODEL_NAMES,
        model_search_x=MODEL_SEARCHED_X,
        model_search_y=MODEL_SEARCHED_Y,
    )
    umap_fig = update_umap_fig(
        FIGS['fig2'],
        DIST_MATRIX,
        MODEL_NAMES,
        FAMILIES,
        COLORS,
        model_search_x=MODEL_SEARCHED_X,
        **alphas('fig2'),
    )
    tree_fig = update_tree_fig(
        FIGS['fig4'],
        MODEL_NAMES,
        model_search=MODEL_SEARCHED_X,
        **alphas('fig4'),
    )
    return [matrix_fig, umap_fig, tree_fig]
17
+
18
def search_bar_changeX(value):
    """Remember ``value`` as the X-axis search selection and redraw the figures."""
    global MODEL_SEARCHED_X
    MODEL_SEARCHED_X = value
    figures = reload_figures()
    return figures
22
+
23
def search_bar_changeY(value):
    """Remember ``value`` as the Y-axis search selection and redraw the figures."""
    global MODEL_SEARCHED_Y
    MODEL_SEARCHED_Y = value
    figures = reload_figures()
    return figures
27
+
28
def slider_changeAlphaMarkers(value,key):
    """Store the marker alpha for figure ``key`` and redraw the figures."""
    # The dict is mutated in place, so no ``global`` declaration is required.
    ALPHA_MARKERS[key] = value
    figures = reload_figures()
    return figures
32
+
33
def slider_changeAlphaNames(value,key):
    """Store the name-label alpha for figure ``key`` and redraw the figures."""
    # The dict is mutated in place, so no ``global`` declaration is required.
    ALPHA_NAMES[key] = value
    figures = reload_figures()
    return figures
37
+
38
def slider_changeAlphaEdges(value,key):
    """Store the edge alpha for figure ``key`` and redraw the figures."""
    # The dict is mutated in place, so no ``global`` declaration is required.
    ALPHA_EDGES[key] = value
    figures = reload_figures()
    return figures
42
+
43
def search_bar_gr(model_names,slider=True,double_search=False,key=None):
    """Create the control column for one figure and the column that hosts its plot.

    The narrow column (scale 1) holds a search dropdown, an optional second
    dropdown, and optional alpha sliders. A wide column (scale 5) is created
    but left empty so the caller can fill it with ``with col2:``.

    Args:
        model_names: Choices offered by the search dropdown(s).
        slider: When true, add the edge / name / marker alpha sliders.
        double_search: When true, add a second "Search Y" dropdown and label
            the first one "Search X".
        key: Figure identifier (e.g. ``'fig2'``); used as a suffix for the
            component keys and to look up the initial slider values.

    Returns:
        A list ``[plot_column, search_x, (search_y), (alpha_edges,
        alpha_names, alpha_markers)]``; the parenthesised entries are present
        only when the matching option is enabled.
    """
    def preselected(name):
        # Keep the remembered selection when it is still a valid choice;
        # otherwise fall back to the first model instead of raising.
        return name if name in model_names else model_names[0]

    # (label, component-key prefix, per-figure value store), in return order.
    slider_specs = (
        ('Alpha Edges', 'alpha_edges_', ALPHA_EDGES),
        ('Alpha Names', 'alpha_names_', ALPHA_NAMES),
        ('Alpha Markers', 'alpha_markers_', ALPHA_MARKERS),
    )

    controls = []
    with gr.Column(scale=1):
        with gr.Group():
            controls.append(gr.Dropdown(
                label='Search'+(' X' if double_search else ''),
                choices=model_names,
                value=preselected(MODEL_SEARCHED_X),
                key='model_search_x_'+key,
                interactive=True,
            ))
            if double_search:
                controls.append(gr.Dropdown(
                    label='Search Y',
                    choices=model_names,
                    value=preselected(MODEL_SEARCHED_Y),
                    key='model_search_y_'+key,
                    interactive=True,
                ))
        if slider:
            with gr.Group():
                for label, key_prefix, store in slider_specs:
                    controls.append(gr.Slider(
                        label=label,
                        minimum=0,
                        maximum=1,
                        step=0.05,
                        value=store[key],
                        key=key_prefix+key,
                        interactive=True,
                    ))
    # Created after the control column so it sits beside it, not inside it.
    plot_column = gr.Column(scale=5)
    return [plot_column, *controls]
103
+
104
+ import spaces
105
@spaces.GPU(duration=300)
def _run(path,genes,N,progress_bar):
    """Load the model at ``path`` and generate completions for every gene.

    Args:
        path: Model identifier passed to ``load_model``.
        genes: Sequence of prompt strings.
        N: Number of completions requested per prompt.
        progress_bar: Callable used to report progress
            (``progress_bar(fraction, desc=..., total=...)``).

    Returns:
        The last value yielded by ``llm_run``, or ``None`` when the model
        could not be loaded or nothing was generated. Callers treat ``None``
        as "run failed, the user has already been warned".
    """
    progress_bar(0.20, desc="Loading Model...",total=100)
    try:
        model,tokenizer = load_model(path)
    except ValueError as e:
        print(f"Error loading model '{path}': {e}")
        gr.Warning("Model couldn't load. This space currently only works with AutoModelForCausalLM models. Please check the model architecture and try again.")
        return None
    except OSError as e:
        print(f"Error loading model '{path}': {e}")
        gr.Warning("Model doesn't seem to exist on the HuggingFace Hub. Please check the model name and try again.")
        return None
    except RuntimeError as e:
        print(f"Error loading model '{path}': {e}")
        if 'out of memory' in str(e):
            gr.Warning("Loading the model triggered an out of memory error. It may be too big for the GPU (80Go RAM). Please try again with a smaller model.")
        else:
            gr.Warning("Model couldn't be loaded. Please check the logs or report an issue.")
        return None
    except Exception as e:
        # Last-resort boundary: report the failure instead of crashing the UI callback.
        print(f"Error loading model '{path}': {e}")
        gr.Warning("Model couldn't be loaded. Please check logs or report an issue.")
        return None

    progress_bar(0.25, desc="Generating data...",total=100)
    # Pre-bind the result so an empty ``genes`` (no iteration at all) returns
    # the failure sentinel instead of raising UnboundLocalError.
    output = None
    for i,output in enumerate(llm_run(model,tokenizer,genes,N)):
        # Generation spans the 25%..95% range of the progress bar.
        progress_bar(0.25 + i*(70/len(genes))/100, desc=f"Generating data... {i+1}/{len(genes)}",total=100)
    if output is None:
        print(f"No output generated for model '{path}'")
        gr.Warning("No prompt was provided, so nothing was generated.")
    return output
136
+
137
def run(path,progress_bar):
    """Download, run and store the results for the model at ``path``.

    Steps: download the model to the local cache, read the prompts ("genes")
    from ``inputs/math.json``, generate completions through ``_run``, cut
    each completion down to its allele, and persist everything with
    ``save_git``.

    Args:
        path: Model identifier on the Hugging Face Hub.
        progress_bar: Callable used to report progress
            (``progress_bar(fraction, desc=..., total=...)``).

    Returns:
        Always ``None``. Failures are reported to the user through
        ``gr.Warning`` and end the run early.
    """
    family = DEFAULT_FAMILY_NAME
    N = PHYLOLM_N

    progress_bar(0, desc="Downloading model...",total=100)
    try:
        # Download the model to cache
        if download_llm_to_cache(path) is None:
            gr.Warning("Model not found on Hugging Face Hub. Please check the model name and try again.")
            return None
    except OSError as e:
        print(f"Error downloading model: {e}")
        gr.Warning("Model not found on Hugging Face Hub. Please check the model name and try again.")
        return None

    progress_bar(0.10, desc="Loading contexts...",total=100)

    # Explicit encoding: the default depends on the platform locale.
    with open('inputs/math.json', 'r', encoding='utf-8') as f:
        genes = json.load(f)

    progress_bar(0.15, desc="Waiting for GPU...",total=100)

    try:
        output = _run(path,genes,N,progress_bar)
        if output is None:
            # _run has already warned the user.
            return None
    except Exception as e:
        print(f"Error running model: {e}")
        gr.Warning("Something unexpected happened during the run or the loading of the model. Please check the logs or report an issue.")
        return None

    progress_bar(0.95, desc="Saving data ...",total=100)

    # An allele is the first 4 characters generated after the prompt text.
    alleles = [
        [completion['generated_text'][len(gene):][:4] for completion in completions]
        for gene,completions in zip(genes,output)
    ]
    save_git(alleles,genes,path,family)

    progress_bar(1, desc="Done!",total=100)
177
+
178
+
179
def prepare_run(model_name,progress_bar=gr.Progress()):
    """Handle a click on "Run PhyloLM" for ``model_name``.

    A model that is not yet in ``MODEL_NAMES`` is processed with ``run``.
    A model that is already known only triggers a warning and becomes the
    current X search selection.
    """
    global MODEL_SEARCHED_X
    if model_name not in MODEL_NAMES:
        run(model_name,progress_bar)
        return
    gr.Warning('Model already exists in the database.')
    MODEL_SEARCHED_X = model_name
    reload_figures()
187
+
188
def reload_env():
    """Reload the dataset and rebuild every figure and search dropdown.

    Calls ``load_data()``, recomputes the distance matrix, refreshes the
    module-level data globals and the cached figures in ``FIGS``, and creates
    fresh dropdowns listing the reloaded model names.

    Returns:
        A 7-tuple ``(fig1, fig2, fig4, search_x, search_y, viz_search,
        tree_search)``.
    """
    # ORDERED_MODEL_NAMES must be declared global: without it the assignment
    # below only creates a local and reload_figures() keeps the stale order.
    global MODEL_NAMES, FAMILIES, COLORS, SIM_MATRIX, DIST_MATRIX
    global ORDERED_MODEL_NAMES

    # Load models for the dropdown
    data, model_names, families, sim_matrix, colors = load_data()

    # Replace exact zeros by machine epsilon so the logarithm stays finite.
    sim_matrix_safe = np.where(sim_matrix == 0, np.finfo(np.float64).eps, sim_matrix)
    dist_matrix = -np.log(sim_matrix_safe)

    # Set globals
    MODEL_NAMES = model_names
    FAMILIES = families
    COLORS = colors
    SIM_MATRIX = sim_matrix
    DIST_MATRIX = dist_matrix

    # Update figures (FIGS is mutated in place)
    ordered_sim_matrix, ordered_model_names = compute_ordered_matrix(sim_matrix,dist_matrix, model_names)
    ORDERED_MODEL_NAMES = ordered_model_names
    FIGS['fig1'] = plot_sim_matrix_fig(ordered_sim_matrix, ordered_model_names, families, colors)
    FIGS['fig2'] = plot_umap_fig(
        dist_matrix, sim_matrix, model_names, families, colors,
        alpha_edges=ALPHA_EDGES['fig2'],
        alpha_names=ALPHA_NAMES['fig2'],
        alpha_markers=ALPHA_MARKERS['fig2'],
    )
    FIGS['fig4'] = plot_tree(
        sim_matrix, model_names, families, colors,
        alpha_edges=ALPHA_EDGES['fig4'],
        alpha_names=ALPHA_NAMES['fig4'],
        alpha_markers=ALPHA_MARKERS['fig4'],
    )

    # Update search bars
    sim_mat_search_x = gr.Dropdown(label='Search X',choices=model_names,value=model_names[0],key='model_search_x_fig1',interactive=True)
    sim_mat_search_y = gr.Dropdown(label='Search Y',choices=model_names,value=model_names[0],key='model_search_y_fig1',interactive=True)
    viz_search = gr.Dropdown(label='Search',choices=model_names,value=model_names[0],key='model_search_fig2',interactive=True)
    tree_search = gr.Dropdown(label='Search',choices=model_names,value=model_names[0],key='model_search_fig4',interactive=True)

    return FIGS['fig1'], FIGS['fig2'], FIGS['fig4'], sim_mat_search_x, sim_mat_search_y, viz_search, tree_search
221
+
222
+
223
+
224
+ # Load environment variables
225
+
226
# Credentials are required: a missing variable raises KeyError at import time.
USERNAME = os.environ['GITHUB_USERNAME']
TOKEN = os.environ['GITHUB_TOKEN']
MAIL = os.environ['GITHUB_MAIL']

# Current search selections (None until the user picks a model).
MODEL_SEARCHED_X = MODEL_SEARCHED_Y = None

# Per-figure alpha settings, keyed by figure id.
ALPHA_EDGES = {'fig2':0.05, 'fig3':0.05,'fig4':1.0}
ALPHA_NAMES = {'fig2':0.0, 'fig3':0.0,'fig4':0.0}
ALPHA_MARKERS = {'fig2':0.8, 'fig3':0.8,'fig4':1.0}

# Cached figures by id, and the three plot components shown in the tabs.
FIGS = dict.fromkeys(('fig1', 'fig2', 'fig3', 'fig4'))
FIGS_OBJECTS = [None] * 3

# Dataset state, filled in once the data has been loaded.
MODEL_NAMES = None
FAMILIES = None
COLORS = None
ORDERED_MODEL_NAMES = None
SIM_MATRIX = None
DIST_MATRIX = None

# Family assigned to newly submitted models, and completions per prompt.
DEFAULT_FAMILY_NAME = '?'
PHYLOLM_N = 32

# Search dropdown components, assigned after the interface is built.
SIM_MAT_SEARCH_X = None
SIM_MAT_SEARCH_Y = None
VIZ_SEARCH = None
TREE_SEARCH = None
252
+
253
+ # Build the Gradio interface
254
with gr.Blocks(title="PhyloLM", theme=gr.themes.Default()) as demo:
    gr.Markdown("# PhyloLM: Phylogenetic Mapping of Language Models")

    # NOTE: adjacent string literals are concatenated as-is, so each sentence
    # must carry its own trailing space.
    gr.Markdown(
        "Welcome to PhyloLM ([paper](https://arxiv.org/abs/2404.04671) - [code](https://github.com/Nicolas-Yax/PhyloLM)) — a tool for comparing language models based on their **behavioral similarity**, inspired by methods from comparative genomics. "
        "Instead of architecture or weights, we use output behavior on diagnostic prompts as a behavioral fingerprint to compute a distance metric, akin to how biologists compare species using genetic data. This makes it possible to draw a unique map of all LLMs (various architectures, gated and non gated, ...). "
        "The goal of this space is to create a collaborative space where everyone can visualize these maps and extend them with models of their choice. "
    )

    gr.Markdown("## Explore Maps of Models")

    gr.Markdown(
        "This interactive space allows users to explore model similarities through four types of visualizations:\n"
        "- A similarity matrix (values range from 0 = dissimilar to 1 = highly similar). \n"
        "- 2D and 3D scatter plots representing how close or far from each other LLMs are (plotted using UMAP). \n"
        "- A tree to visualize distances between models (distance from leaf A to leaf B in the tree is similar to the distance between the two models)\n\n"
    )

    # Load models for the dropdown
    data, model_names, families, sim_matrix, colors = load_data()

    # Replace exact zeros by machine epsilon so the logarithm stays finite.
    sim_matrix_safe = np.where(sim_matrix == 0, np.finfo(np.float64).eps, sim_matrix)
    dist_matrix = -np.log(sim_matrix_safe)

    # Set globals
    MODEL_NAMES = model_names
    FAMILIES = families
    COLORS = colors
    SIM_MATRIX = sim_matrix
    DIST_MATRIX = dist_matrix

    # Create the tabs
    tab_state = gr.State(value="Similarity Matrix") # Default tab
    tabs = gr.Tabs(["Similarity Matrix", "2D Visualization","Tree Visualization"])
    with tabs:
        with gr.TabItem("Similarity Matrix"):
            # Similarity matrix visualization
            with gr.Row():
                col2,sim_mat_search_x,sim_mat_search_y = search_bar_gr(model_names,slider=False,double_search=True,key='fig1')
                with col2:
                    ordered_sim_matrix, ordered_model_names = compute_ordered_matrix(sim_matrix,dist_matrix, model_names)
                    fig = plot_sim_matrix_fig(ordered_sim_matrix, ordered_model_names, families, colors)
                    sim_matrix_output = gr.Plot(fig,label="Similarity Matrix")
                    FIGS['fig1'] = fig
                    ORDERED_MODEL_NAMES = ordered_model_names
                    FIGS_OBJECTS[0] = sim_matrix_output
        with gr.TabItem("2D Visualization"):
            # 2D visualization
            with gr.Row():
                col2,viz_search,viz_alpha_edge,viz_alpha_name,viz_alpha_marker = search_bar_gr(model_names,slider=True,double_search=False,key='fig2')
                with col2:
                    fig = plot_umap_fig(dist_matrix, sim_matrix, model_names, families, colors,
                                        alpha_edges=ALPHA_EDGES['fig2'],alpha_names=ALPHA_NAMES['fig2'],alpha_markers=ALPHA_MARKERS['fig2'])
                    plot_output = gr.Plot(fig,label="2D Visualization")
                    FIGS['fig2'] = fig
                    FIGS_OBJECTS[1] = plot_output
        with gr.TabItem("Tree Visualization"):
            # Tree visualization
            with gr.Row():
                col2,tree_search,tree_alpha_edge,tree_alpha_name,tree_alpha_marker = search_bar_gr(model_names,slider=True,double_search=False,key='fig4')
                with col2:
                    fig = plot_tree(sim_matrix, model_names, families, colors,alpha_edges=ALPHA_EDGES['fig4'],alpha_names=ALPHA_NAMES['fig4'],alpha_markers=ALPHA_MARKERS['fig4'])
                    tree_output = gr.Plot(fig,label="Tree Visualization")
                    FIGS['fig4'] = fig
                    FIGS_OBJECTS[2] = tree_output

    # Submit model section
    gr.Markdown("## Submitting a Model")

    gr.Markdown(
        "You may contribute new models to this collaborative space using compute resources. "
        "Once processed, the model will be compared to existing ones, and its results added to a shared public database. "
        "Model families (e.g., LLaMA, OPT, Mistral) are extracted from Hugging Face model cards and used only for visualization (e.g., coloring plots); they are **not** involved in the computation of similarity."
    )

    gr.Markdown(
        "**To add a new model:**\n"
        "1. Enter the name of a model hosted on Hugging Face (e.g., `'mistralai/Mistral-7B-Instruct-v0.3'`).\n"
        "2. Click on the **Run PhyloLM** button.\n"
        "- If the model has already been processed, you'll be notified and no new run will start.\n"
        "- If it hasn't been processed, it will be downloaded and be evaluated.\n\n"
        "⚠️ Be careful when submitting large LLMs (typically >15B parameters) as they may exceed the GPU RAM or the time limit, leading to failed runs."
    )

    with gr.Group():
        model_input = gr.Textbox(label="Model", interactive=True)
        submit_btn = gr.Button("Run PhyloLM", variant="primary")

    # Disclaimer and citation
    gr.Markdown("## Disclaimer")
    gr.Markdown(
        "This is a research prototype and may contain bugs or limitations. "
        "All computed data are public and hosted on [GitHub](https://github.com/PhyloLM/Data). "
        "If you'd like to contribute additional models — especially for gated or large models that cannot be processed via the web interface — "
        "you are welcome to submit a pull request to the repository cited above. "
        "All results are computed on the 'Math' set of genes used in the original paper."
    )

    gr.Markdown("## Citation")
    gr.Markdown("If you find this project useful for your research, please consider citing the following paper:")

    # bibtex
    gr.Code('''@inproceedings{
yax2025phylolm,
title={Phylo{LM}: Inferring the Phylogeny of Large Language Models and Predicting their Performances in Benchmarks},
author={Nicolas Yax and Pierre-Yves Oudeyer and Stefano Palminteri},
booktitle={The Thirteenth International Conference on Learning Representations},
year={2025},
url={https://openreview.net/forum?id=rTQNGQxm4K}
}''',language=None)

    # Change actions from search bars: every handler returns all three figures.
    sim_mat_search_x.change(fn=search_bar_changeX, inputs=sim_mat_search_x, outputs=FIGS_OBJECTS)
    sim_mat_search_y.change(fn=search_bar_changeY, inputs=sim_mat_search_y, outputs=FIGS_OBJECTS)

    viz_search.change(fn=search_bar_changeX, inputs=viz_search, outputs=FIGS_OBJECTS)

    tree_search.change(fn=search_bar_changeX, inputs=tree_search, outputs=FIGS_OBJECTS)

    # Change actions from sliders
    viz_alpha_edge.change(fn=lambda x : slider_changeAlphaEdges(x,'fig2'), inputs=viz_alpha_edge, outputs=FIGS_OBJECTS)
    viz_alpha_name.change(fn=lambda x : slider_changeAlphaNames(x,'fig2'), inputs=viz_alpha_name, outputs=FIGS_OBJECTS)
    viz_alpha_marker.change(fn=lambda x : slider_changeAlphaMarkers(x,'fig2'), inputs=viz_alpha_marker, outputs=FIGS_OBJECTS)

    tree_alpha_edge.change(fn=lambda x : slider_changeAlphaEdges(x,'fig4'), inputs=tree_alpha_edge, outputs=FIGS_OBJECTS)
    tree_alpha_name.change(fn=lambda x : slider_changeAlphaNames(x,'fig4'), inputs=tree_alpha_name, outputs=FIGS_OBJECTS)
    tree_alpha_marker.change(fn=lambda x : slider_changeAlphaMarkers(x,'fig4'), inputs=tree_alpha_marker, outputs=FIGS_OBJECTS)

    # Run PhyloLM button: process the model, then reload the data and the UI.
    submit_btn.click(fn=prepare_run, inputs=[model_input], outputs=[model_input]).then(fn=reload_env, inputs=[], outputs=FIGS_OBJECTS+ [sim_mat_search_x, sim_mat_search_y, viz_search, tree_search])

# Set more globals
SIM_MAT_SEARCH_X = sim_mat_search_x
SIM_MAT_SEARCH_Y = sim_mat_search_y
VIZ_SEARCH = viz_search
TREE_SEARCH = tree_search


if __name__ == "__main__":
    demo.launch()
constants.py ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ from matplotlib import colors as mcolors
2
+
3
# Colour names, plus the same colours as 0-255 integer RGB tuples.
UNKNOWN_COLOR = 'gray'
UNKNOWN_COLOR_RGB = tuple([int(255 * channel) for channel in mcolors.to_rgb(UNKNOWN_COLOR)])

DEFAULT_COLOR = 'black'
DEFAULT_COLOR_RGB = tuple([int(255 * channel) for channel in mcolors.to_rgb(DEFAULT_COLOR)])
family_table.json ADDED
@@ -0,0 +1 @@
 
 
1
+ {"JosephusCheung\/Guanaco":"Llama","Intel\/neural-chat-7b-v3":"Mistral","Intel\/neural-chat-7b-v3-1":"Mistral","teknium\/OpenHermes-2-Mistral-7B":"Mistral","teknium\/OpenHermes-2.5-Mistral-7B":"Mistral","teknium\/OpenHermes-13B":"Llama","teknium\/OpenHermes-7B":"Llama","mistralai\/Mistral-7B-Instruct-v0.2":"Mistral","mistralai\/Mixtral-8x7B-Instruct-v0.1":"Mistral","mistralai\/Mistral-7B-v0.1":"Mistral","mistralai\/Mixtral-8x7B-v0.1":"Mistral","mistralai\/Mistral-7B-Instruct-v0.1":"Mistral","chavinlo\/alpaca-native":"Llama","CausalLM\/14B":"Qwen","CausalLM\/7B":"Qwen","bigscience\/bloom-3b":"Bloom","bigscience\/bloom-7b1":"Bloom","bigscience\/bloomz-3b":"Bloom","bigscience\/bloom":"Bloom","bigscience\/bloomz-7b1":"Bloom","berkeley-nest\/Starling-LM-7B-alpha":"Mistral","EleutherAI\/pythia-6.9b":"Pythia","EleutherAI\/pythia-1.4b":"Pythia","EleutherAI\/pythia-12b":"Pythia","EleutherAI\/pythia-2.8b":"Pythia","EleutherAI\/pythia-410m":"Pythia","EleutherAI\/pythia-70m":"Pythia","EleutherAI\/pythia-160m":"Pythia","roneneldan\/TinyStories-1M":"TinyStories","lmsys\/vicuna-13b-v1.5":"Llama","lmsys\/vicuna-7b-v1.1":"Llama","lmsys\/vicuna-13b-v1.3":"Llama","lmsys\/vicuna-7b-v1.5":"Llama","lmsys\/vicuna-13b-v1.1":"Llama","lmsys\/vicuna-7b-v1.3":"Llama","google\/gemma-7b":"Gemma","google\/codegemma-7b":"Gemma","google\/gemma-2b-it":"Gemma","google\/codegemma-2b":"Gemma","google\/codegemma-7b-it":"Gemma","google\/gemma-1.1-7b-it":"Gemma","google\/gemma-2b":"Gemma","google\/gemma-1.1-2b-it":"Gemma","google\/gemma-7b-it":"Gemma","microsoft\/Orca-2-13b":"Llama","microsoft\/Orca-2-7b":"Llama","Imran1\/MedChat3.5":"Mistral","tenyx\/TenyxChat-7B-v1":"Mistral","databricks\/dolly-v2-7b":"Pythia","databricks\/dolly-v2-3b":"Pythia","databricks\/dolly-v2-12b":"Pythia","Qwen\/Qwen-1_8B":"Qwen","Qwen\/Qwen1.5-0.5B":"Qwen","Qwen\/Qwen1.5-72B-Chat":"Qwen","Qwen\/Qwen1.5-7B-Chat":"Qwen","Qwen\/Qwen1.5-2B-Chat":"Qwen","Qwen\/Qwen1.5-7B":"Qwen","Qwen\/Qwen1.5-72B":"Qwen","Qwen\/Qwen1.5-32B-Chat":
"Qwen","Qwen\/Qwen1.5-4B-Chat":"Qwen","Qwen\/Qwen1.5-1.8B":"Qwen","Qwen\/Qwen1.5-14B-Chat":"Qwen","Qwen\/Qwen1.5-0.5B-Chat":"Qwen","Qwen\/Qwen-72B":"Qwen","Qwen\/Qwen-14B":"Qwen","Qwen\/Qwen1.5-4B":"Qwen","Qwen\/Qwen1.5-14B":"Qwen","Qwen\/Qwen-7B":"Qwen","Qwen\/Qwen1.5-32B":"Qwen","OpenAssistant\/oasst-sft-4-pythia-12b-epoch-3.5":"Pythia","mlabonne\/NeuralHermes-2.5-Mistral-7B":"Mistral","facebook\/opt-6.7b":"OPT","facebook\/opt-125m":"OPT","facebook\/opt-66b":"OPT","facebook\/opt-13b":"OPT","facebook\/opt-1.3b":"OPT","facebook\/opt-30b":"OPT","facebook\/opt-350m":"OPT","facebook\/opt-2.7b":"OPT","HuggingFaceH4\/zephyr-7b-beta":"Mistral","HuggingFaceH4\/zephyr-7b-alpha":"Mistral","openchat\/openchat_v2":"Llama","openchat\/openchat_v2_w":"Llama","openchat\/openchat_v3.2":"Llama","openchat\/openchat_v3.1":"Llama","openchat\/openchat_3.5":"Mistral","openchat\/openchat_v3.2_super":"Llama","TigerResearch\/tigerbot-13b-base-v2":"Llama","TigerResearch\/tigerbot-7b-base-v2":"Bloom","TigerResearch\/tigerbot-7b-chat":"Llama","TigerResearch\/tigerbot-13b-chat-v1":"Llama","TigerResearch\/tigerbot-7b-sft-v1":"Bloom","TigerResearch\/tigerbot-7b-sft-v2":"Bloom","TigerResearch\/tigerbot-13b-chat-v2":"Llama","TigerResearch\/tigerbot-13b-chat-v3":"Llama","TigerResearch\/tigerbot-13b-chat-v4":"Llama","TigerResearch\/tigerbot-7b-base-v1":"Bloom","TigerResearch\/tigerbot-13b-base-v1":"Llama","TigerResearch\/tigerbot-7b-base":"Llama","fxmarty\/tiny-llama-fast-tokenizer":"fxmarty","project-baize\/baize-v2-7b":"Llama","huggyllama\/llama-7b":"Llama","huggyllama\/llama-13b":"Llama","Arc53\/docsgpt-7b-mistral":"Mistral","meta-llama\/Llama-2-7b-hf":"Llama","meta-llama\/Llama-2-13b-hf":"Llama","meta-llama\/Llama-2-7b":"Llama"}
inputs/math.json ADDED
@@ -0,0 +1 @@
 
 
1
+ ["# In observing a Tetrahedron...", "# Strong and Weak Form Solution \u2013 FEA\n\nPartial Differential Equations \u2013 PDE is called \u201cstro", "# Loci Browse Articles\n\nDisplaying 41 - 50 of 323\n\nThis article describes methods for cr", "Paul's Online Notes\nHome / Calculus I / Review / Trig Functions\nShow Mobile Notice Show All Not", "# Recursive formula for joint moments in free probability\n\nSuppose $\\mathfrak{A}$ is an algebra ", "# Is the oxygen molecule $O_2$ fermion or boson?\n\nI ask something makes me confused.\n", "# Python Indices of numbers greater than K\n\nIn this tutorial, we are going to ", "To find total cost, to the nearest cent, to cool the house for this 24-hour p", "# Propositional Logic", "Modeling the train reservation kata -", "[texhax] environment ", "# stability \u2026 boring old and simple stability\n\n[xxx@yyy~]$uptime 15:07:51 up 505 days, 47 min", "By accessing our 180 Days of Math for Sixth Grade Answers Key Day 72 regularly, students can get b", "E. Square Root of Permutation\ntime limit per test\n2 seconds\nmemory l", "# Skills\n\nThe Discworld skill model is large and complex. It is broken up into eight branches.\n", "# Just Some Division\n\nNumber Theory Level 1\n\n$$N$$ is a positive integer such that $$10", "MathSciNet bibliographic data MR343259 55B25 (55G35 57E15) Matumoto, Takao Eq", "InTech uses cookies to offer you the best online experience. By continuing to use our sit", "## Train Tracks\n\nConsider a segment of", "## Stream: general\n\n### Topic: detecti", "This vignette discusses data.table\u2019s reference semant", "A cell made up of two hydrogen electrodes. 
The positive electrode is in cont", "# Better way to calculate coordinates in Tikz?\n\nI am great fan of pgf and tikz in general to pro", "# Physics (Version 8.4", "# Abc conjecture\n\n\ufeff\nAbc conjecture\n\nThe abc conjecture is a conjecture", "# No offense intended, but\u2026\n\nI have", "Select Board & Class\n\nAreas Related to Circles\n\nTo und", "Determine whether $\\sum_{n=1}^\\infty \\frac{\\sin^2 n}{n^2}$ conv", "# Calculating Derivative \u2013 Third root \u2013 Exercise 1106\n\nExercise\n\nFind the derivati", "# BE Thesis\n\nSenast inlagda poster:\n2018-09-06\n", "# All Questions\n\n1,360 questions\nFilter by\nSorted by\nTagged with\n1answer\n113 views\n\n###", "# Ignatius and the P", "## anonymous one year ago What is the fifth term of the sequence who", "size - Maple Help\n\nMTM\n\n size\n", "http://en.wikipedia.org/wiki/Taylor_series\n\n## Taylor series in several var", "# I Using determinant to find constraints on equation\n\nTags:\n1. Jan 15, 2017\n\n### TheDemx27\n", "# Decimal Numbers\n\nDecimal numbers are similar to fractions in basic principle.\n\nThey are often u", "Sales Toll Free No: 1-800-481-2338\n\n# How to Divide Monomial by a Non Zero Constant?\n\nTopPolynomial", "# Transformation of continuous uniform distribution\n", "CTF Team at the University of British Columbia\n\n# [corCTF 2021] smogofwar\n\n25 Aug 2021 ", "0\nResearch Papers\n\n# Analytic and Geometric ", "# Difference between revisions of \"1984 AIME Problems/Problem 11\"\n\n## Problem\n\nA ", "# Image Mosaicking\u00b6\n\n#", "# When two smaller atoms combine into a larger atom what has occurred?", "# Bessel functions\n\n(diff) \u2190 Older revision | La", "## WeBWorK Main Forum\n\n### Why Giga newton is not ", "### Homes\n\nThere are ", "## Return to Question\n\n2 deleted 256 characters in body\n\nA complex manifold $X$ is said ", "# Algebra Examples\n\nFind Pivot Positions and Pivot Co", "# The Physics Behind the American Death Tr", "Data\n\n1. 
Title: Webs and $q$-Howe dualities in types $\\mathbf{B}\\mathbf{C}\\math", "# All Questions\n\n1,524 questions\n", "# Pydon'ts\n\n## Improve your Python programming skills\n\n### Start here.\n\n294\n###", "# Math Help - 2-norm of a ma", "[texhax] \\mid Description\n\n", "# Q : 3\u00a0\u00a0\u00a0\u00a0\u00a0\u00a0\u00a0Find the area of the region bou", "as.epidata {EpiILM} R Documentation\n\nDiscrete Tim", "# Basic Matrix Row Operations Tips\n\nThere are four basic op", "# Nets within nets from the Grothendieck\u00a0const", "Free Version\nEasy\n\n# Same Base Exponential Equations\n\n", "# Preserving Plant Diversity\n\n## Location:", "# Crash on HWNDComponentPeer::destroyWindowCallback\n\nIt just feels like you\u2019r", "# SNR\u00b6\n\nclass gammapy.astro.source.SNR(e_sn='1e51 erg', theta=<Quantity 0.1>, n_ISM=<Qua", "# Forecasting Pseudo Random Numbers Using Deep Learning\n\nPublisher: IEEE\n\n", "# Including both transformed and original data (untransformed) in a multivariable linear regressio", "# Technical Fridays\n\nFriday, September 1, 2017\n\nIn 1973, the Un", "# Representation of simple groups\n\nLet $G$ be a finite simple group, prove th", "# The Impact of Meteorological Facto", "Advertisement Remove all ads\n\n# If\u00a0y=log[x+sqrt(x^2+a^2)]\u00a0show that\u00a0(x^2+a^2)(d^2y)/(dx^2)+xdy", "2020 | Book\n\n# Principles of Data M", "# Thread: Math problem Parenthesis & PEMDAS\n\n1. ## Math problem Parenthesis & PEMDAS\n\nhi i n", "# Math Help - Physics problem; Find expression for veloci", "# Base class for finite field element", "# 0.1 Review exercises (ch 3-13) \u00a0(Page 11/12)\n\n Page 11 / 12\n\n130. Out", "All Rights Reserved. 
However, a compass needle will not be steady in the magneti", "# How is rest mass $m_0$ in $E=m_0c^2$ related to mass $m$ in $F=ma$?\n\nA p", "# All Questions\n\n7 views\n\n### Curve fitting of a list\n\nI have list obtained using a", "1 $\\begingroup$ Close", "# Annual income of A a", "## Seminars and Colloquia by Series\n\n### Geometric Equations for Matroid Varieties\n\nSeries\nSIAM S", "# How to prove that $C=\\{x: Ax\\le ", "# Phase locked Loop in Demodulation\n\nCan someone please clarify how a PLL works and how it can th", "# What is forward difference interpolation?\n\n## What is forward difference interpolati", "# Math Help - matlab code hel", "Question\n\n# What is the speed of the sound in a perfectly rigid rod?\n\nOpen in App\nSoluti", "# If $\\Sigma \\models \\phi$, then for some finite $\\Delta \\subset\\Sigma$, $\\Delta \\models", "# Need some help on a proof!\n\n1. Dec 12, 2004\n\n### MathematicalMatt\n\nHowdy, I just stumbled o", "## Calculus (3rd Edition)\n\n$f(x)=[x]$ has a jump discontinuity at $x=n$.\nThe function ", "Orthogonal functions\n\nOrthogonality\n\nTwo fu", "## MacKenzie's fundamental principle of greenkeeping\n\n##### 05 May 2017\n\nI taught two semi", "2 added 246 characters in body\n\n\"If you are walking between two policemen goin", "# \u3010BZOJ 4571\u3011[SCOI2016] \u7f8e\u5473\n\n#include<bits/stdc++.h>\n#define LL long long\nus", "# Math Help - confidence interval help!!!!\n\n1. ## confide", "# Alternatives\n\n## Summing Squares: Finding or Proving a Formula\n", "Thank you for visiting nature", "# How do I keep a string of text together without", "# Q6. 
In a model of a ship, the mast", "# How does the Taylor Series converge at all points f", "# Revealed preference\n\nRevealed preference theory, pioneered by American ec", "Previous issue \u00b7\u00a0 Next issue \u00b7\u00a0", "Bits - Maple Programming Help\n\nHome : Su", "# Difference between revisions of \"", "## [POJ2411]Mondriaan\\'s Dream\n\n \u6210", "# Period ofWeeks() method in Java\n\nJava 8Object Oriented ProgrammingProgramming\n\nThe", "Debugging graphs with ease: experimental Visual Studio plugin\n\nRevis", "k-mer overrepresentation of WGS Illumina reads\n0\n0\nEntering edit mode\n3.8 years a", "# Hydrometeorology Research Group\n\nIn\u00a0[5]:\nfrom IPython.display import HTM", "# \ud83d\udd35\u26aa\ud83d\udd34 Bioinspired tough gel sheath for robust and versatile surface functionalization \u2013 Content Mar", "# Recurring Decimal To Fraction Cal", "# Integral related to the modified Bessel function\n\nI would like to solve t", "# Viewpoint: Particle Decays Point to an Arrow of Time\n\n\u2022 Michael Ze", "CiteULike is a free online bibliography manager. 
Register a", "# Social Media, Misinformation, and Voting Decisions\n\nWorking Pap", "# Test Video\n\nAller \u00e0 : Navigation, rechercher\n\nThis is an", "Enterprise Multiples\n\n.\n\nEV/SALES\n\nEqu", "PLANET Discussion: Daeridune\n\nDiscussion in 'The Manaverse W", "# Tag Info\n\n## Hot answers tagged total\n\n6\n\nDon't subtrac", "# Welcome to grmpy\u2019s documentation!\u00b6\n\ngrmpy is an open-source packag", "# Solve the linear equatio", "## February 25, 2008\n\n### A Questio", "# Concept and expression of a real function\n\n## Concept o", "# Optimization Week 9: Convex conjugate (Fenche", "# Inversions of Insertion Sort and Bubble Sort\n\nAn array with bubblesort time $$\\Theta(n)$$ is noth", "## Simple power series\n\nHey all,\n\nDoes anybody know if $x^{\\alpha}$ can be written in terms of an ", "# HiggsTools\n\nStephen Jones (MPI Muni", "# Mathematics 1010 online\n\n## Complex Numbers\n\nRecall how we built the numbe", "# string.replace.regex\n\nSynt", "# Math Help - Odd Integers\n\n1. ## Odd Integers\n\nWhat is the product if the largest of three cons"]
llm_run.py ADDED
@@ -0,0 +1,48 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import transformers
2
+ from huggingface_hub import snapshot_download,constants
3
+
4
def download_llm_to_cache(model_name, revision="main", cache_dir=None):
    """
    Fetch a model snapshot from the Hugging Face Hub into the cache, without loading it.

    Args:
        model_name (str): Repository id on the Hub (e.g., "meta-llama/Llama-2-7b-hf").
        revision (str, optional): Model revision to fetch. Defaults to "main".
        cache_dir (str, optional): Cache directory. When None, the hub default
            (``constants.HUGGINGFACE_HUB_CACHE``) is used.

    Returns:
        str: Path of the cached snapshot, or None when the download raised.
    """
    target_dir = constants.HUGGINGFACE_HUB_CACHE if cache_dir is None else cache_dir

    try:
        snapshot_path = snapshot_download(
            repo_id=model_name,
            revision=revision,
            cache_dir=target_dir,
            local_files_only=False,  # Set to True if you want to check local cache only
        )
        print(f"Model '{model_name}' is available in cache at: {snapshot_path}")
        return snapshot_path
    except Exception as e:
        # Best effort: report the failure and let the caller handle a None path.
        print(f"Error downloading model '{model_name}': {e}")
        return None
35
+
36
def load_model(path, cache_dir=None):
    """Load a causal language model and its tokenizer.

    Args:
        path: Model identifier or local path accepted by ``from_pretrained``.
        cache_dir: Optional cache directory forwarded to both loaders.

    Returns:
        tuple: ``(model, tokenizer)``.
    """
    model = transformers.AutoModelForCausalLM.from_pretrained(
        path, cache_dir=cache_dir, device_map='auto'
    )
    # device_map is a weight-placement option for models only, so it is not
    # forwarded to the tokenizer loader.
    tokenizer = transformers.AutoTokenizer.from_pretrained(path, cache_dir=cache_dir)
    return model, tokenizer
40
+
41
def llm_run(model, tokenizer, genes, N):
    """Generate N sampled 4-token continuations for each gene, yielding progress.

    This is a generator: after each gene it yields the same, growing list of
    results, so consumers can report progress. The final list is also the
    generator's return value.

    Args:
        model: Loaded model handed to the text-generation pipeline.
        tokenizer: Tokenizer handed to the text-generation pipeline.
        genes: Iterable of prompt strings.
        N: Number of sequences sampled per prompt.
    """
    text_generator = transformers.pipeline(
        'text-generation', model=model, tokenizer=tokenizer, device_map='auto'
    )
    completions = []
    for prompt in genes:
        batch = text_generator(
            [prompt],
            min_new_tokens=4,
            max_new_tokens=4,
            do_sample=True,
            num_return_sequences=N,
        )
        # One prompt was submitted, so only the first entry of the batch is kept.
        completions.append(batch[0])
        yield completions
    return completions
loading.py ADDED
@@ -0,0 +1,132 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import ujson as json
3
+ import pygit2
4
+
5
+ from phylogeny import compute_all_P, compute_sim_matrix
6
+ from plotting import get_color, UNKNOWN_COLOR, DEFAULT_COLOR
7
+ # ------------------------------------------------------------------------------------------------
8
+ #
9
+ # Loading data
10
+ #
11
+ # ------------------------------------------------------------------------------------------------
12
+ def load_data():
13
+ global UNKNOWN_COLOR, DEFAULT_COLOR, MODEL_SEARCHED_X
14
+ data, model_names,families = load_git()
15
+ if data is None:
16
+ return
17
+
18
+ #Rename families if needed
19
+ with open('family_table.json','r') as f:
20
+ rename_table = json.load(f)
21
+
22
+ for i in range(len(model_names)):
23
+ try:
24
+ families[i] = rename_table[model_names[i]]
25
+ except KeyError:
26
+ pass
27
+
28
+ all_P = compute_all_P(data, model_names)
29
+ sim_matrix = compute_sim_matrix(model_names, all_P)
30
+
31
+ k = list(all_P.keys())[0]
32
+
33
+ unknown_color = UNKNOWN_COLOR
34
+
35
+ unique_families = list(set([f for f in families]))
36
+ colors = {}
37
+ idx = 0
38
+ for i, family in enumerate(unique_families):
39
+ color = get_color(idx)
40
+ idx += 1
41
+ while color == unknown_color: # Avoid using the unknown color for a family
42
+ color = get_color(idx)
43
+ idx += 1
44
+ colors[family] = color
45
+
46
+ colors['?'] = unknown_color # Assign the unknown color to the unknown family
47
+
48
+ return data, model_names, families, sim_matrix, colors
49
+
50
+ def load_git():
51
+ cred = pygit2.UserPass(os.environ['GITHUB_USERNAME'], os.environ['GITHUB_TOKEN'])
52
+ if os.path.exists('Data'):
53
+ repo = pygit2.Repository('Data')
54
+ remote = repo.remotes['origin'] # Use named reference instead of index
55
+ remote.fetch()
56
+
57
+ # Get the current branch name
58
+ branch_name = repo.head.shorthand
59
+
60
+ # Find the reference to the remote branch
61
+ remote_ref_name = f'refs/remotes/origin/{branch_name}'
62
+
63
+ # Merge the changes into the current branch
64
+ remote_commit = repo.lookup_reference(remote_ref_name).target
65
+
66
+ else:
67
+ repo = pygit2.clone_repository('https://github.com/PhyloLM/Data', './Data', bare=False, callbacks=GitHubRemoteCallbacks(os.environ['GITHUB_USERNAME'], os.environ['GITHUB_TOKEN']))
68
+
69
+ data_array = []
70
+ model_names = []
71
+ families = []
72
+ for foname in os.listdir('Data/math'):
73
+ #check if it is a directory
74
+ if not os.path.isdir(os.path.join('Data/math',foname)):
75
+ continue
76
+ for fname in os.listdir('Data/math/'+foname):
77
+ if not fname.endswith('.json'):
78
+ continue
79
+ with open(os.path.join('Data/math',foname,fname),'r') as f:
80
+ d = json.load(f)
81
+ families.append(d['family'])
82
+ model_names.append(foname+'/'+fname[:-5])
83
+ data_array.append(d['alleles'])
84
+
85
+ if data_array == []:
86
+ return None,[],[]
87
+ return data_array,model_names,families
88
+
89
+ # ------------------------------------------------------------------------------------------------
90
+ #
91
+ # Git functions
92
+ #
93
+ # ------------------------------------------------------------------------------------------------
94
+
95
+ class GitHubRemoteCallbacks(pygit2.RemoteCallbacks):
96
+ def __init__(self, username, token):
97
+ self.username = username
98
+ self.token = token
99
+ super().__init__()
100
+
101
+ def credentials(self, url, username_from_url, allowed_types):
102
+ return pygit2.UserPass(self.username, self.token)
103
+
104
+ # ------------------------------------------------------------------------------------------------
105
+ #
106
+ # Saving data
107
+ #
108
+ # ------------------------------------------------------------------------------------------------
109
+
110
+ def save_git(alleles,genes,model,family):
111
+ repo = pygit2.Repository('Data')
112
+ remo = repo.remotes['origin']
113
+
114
+ d = {'family':family,'alleles':alleles}
115
+ model_name = model
116
+ data_path = f'math/{model_name}.json'
117
+ path = os.path.join('Data',data_path)
118
+ #create the file folder path
119
+ if not os.path.exists(os.path.dirname(path)):
120
+ os.makedirs(os.path.dirname(path), exist_ok=True)
121
+ #Open the file
122
+ with open(path,'w') as f:
123
+ json.dump(d,f)
124
+
125
+ repo.index.add(data_path)
126
+ repo.index.write()
127
+ reference='HEAD'
128
+ tree = repo.index.write_tree()
129
+ author = pygit2.Signature(os.environ['GITHUB_USERNAME'], os.environ['GITHUB_MAIL'])
130
+ commiter = pygit2.Signature(os.environ['GITHUB_USERNAME'], os.environ['GITHUB_MAIL'])
131
+ oid = repo.create_commit(reference, author, commiter, f'Add data for model {model}', tree, [repo.head.target])
132
+ remo.push(['refs/heads/main'],callbacks=GitHubRemoteCallbacks(os.environ['GITHUB_USERNAME'],os.environ['GITHUB_TOKEN']))
packages.txt ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ graphviz
2
+ graphviz-dev
phylogeny.py ADDED
@@ -0,0 +1,114 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+
3
+ from constants import UNKNOWN_COLOR, DEFAULT_COLOR, UNKNOWN_COLOR_RGB, DEFAULT_COLOR_RGB
4
+
5
+ def compute_P(alleles):
6
+ '''Compute the population matrix P(allele|gene) from the [alleles] given in input'''
7
+ P = []
8
+ # Process each gene position
9
+ for gene_alleles in alleles:
10
+ # Use Counter for more efficient counting
11
+ unique_alleles, counts = np.unique(gene_alleles, return_counts=True)
12
+ # Create frequency dictionary directly
13
+ d = dict(zip(unique_alleles, counts / len(gene_alleles)))
14
+ P.append(d)
15
+ return P
16
+
17
+ def compute_all_P(data, models):
18
+ '''Compute all population matrices from a given list of [models] on the data'''
19
+ all_P = {}
20
+ for mi, m in enumerate(models):
21
+ alleles = data[mi]
22
+ P = compute_P(alleles)
23
+ all_P[m] = P
24
+ return all_P
25
+
26
+ def compute_sim_matrix(models,all_P):
27
+ '''Compute the entire similarity matrix in one go'''
28
+ n_models = len(models)
29
+ n_genes = len(all_P[models[0]])
30
+
31
+ # Initialize matrices to store numerator and denominator terms
32
+ total_numerator = np.zeros((n_models, n_models))
33
+ left_denominators = np.zeros(n_models)
34
+ right_denominators = np.zeros(n_models)
35
+
36
+ # Process each gene position
37
+ for k in range(n_genes):
38
+ # Collect all alleles for this gene position
39
+ all_alleles = set()
40
+ for m in models:
41
+ all_alleles.update(all_P[m][k].keys())
42
+ all_alleles = list(all_alleles)
43
+
44
+ # Create frequency vectors for each model
45
+ freq_matrix = np.zeros((n_models, len(all_alleles)))
46
+ for i, m in enumerate(models):
47
+ for j, allele in enumerate(all_alleles):
48
+ if allele in all_P[m][k]:
49
+ freq_matrix[i, j] = all_P[m][k][allele]
50
+
51
+ # Update numerator: dot product of frequency vectors
52
+ total_numerator += np.dot(freq_matrix, freq_matrix.T)
53
+
54
+ # Update denominators: sum of squared frequencies
55
+ squared_sums = np.sum(freq_matrix**2, axis=1)
56
+ left_denominators += squared_sums
57
+ right_denominators += squared_sums
58
+
59
+ # Calculate final similarity matrix
60
+ denominator_matrix = np.sqrt(np.outer(left_denominators, right_denominators))
61
+ sim_matrix = total_numerator / denominator_matrix
62
+
63
+ return sim_matrix
64
+
65
+ def prepare_tree(tree, model_names, origins, colors):
66
+ """Prepare and color the phylogenetic tree based on model families."""
67
+ # Remove inner node names and color leaf nodes
68
+ for clade in tree.find_clades():
69
+ if clade.name and (clade.name.startswith('Inner') or clade.name.startswith('Clade')):
70
+ #clade.name = None
71
+ pass
72
+ if clade.name == None or clade.name not in model_names:
73
+ clade.family = None
74
+ clade.flag = False
75
+ continue
76
+ # Color the clades if it is a leaf
77
+ index = model_names.index(clade.name)
78
+ clade.family = origins[index]
79
+ clade.flag = True
80
+
81
+ # Propagate colors up the tree when all children have the same color
82
+ all_clades = list(tree.find_clades())
83
+ clades = [clade for clade in all_clades if clade.flag is False]
84
+
85
+ # Iterate this process until there are no more clades to color
86
+ i = 0
87
+ while clades:
88
+ clade = clades[i % len(clades)]
89
+ children_families = [c.family for c in clade.clades]
90
+ children_families_set = set(children_families)
91
+ children_flags = [c.flag for c in clade.clades]
92
+ children_flags_set = set(children_flags)
93
+
94
+ if len(children_families_set) == 1: # If all children have the same color : this clade is locked with the same color
95
+ clade.family = children_families[0]
96
+ clade.flag = True
97
+ del clades[i % len(clades)]
98
+ elif len(children_families_set) == 2 and '?' in children_families_set: # If children have different colors and one is unknown : this clade is locked with the known color
99
+ clade.family = [f for f in children_families_set if f != '?'][0]
100
+ clade.flag = True
101
+ del clades[i % len(clades)]
102
+ elif len(children_flags_set) == 1: # If children have different colors : this clade is locked with no color
103
+ clade.flag = True
104
+ del clades[i % len(clades)]
105
+ elif clade.flag == True: #Sholdn't happen
106
+ del clades[i % len(clades)]
107
+ i += 1
108
+
109
+ #Set color associated with family to each clade
110
+ for clade in all_clades:
111
+ if clade.family is None:
112
+ clade.color = UNKNOWN_COLOR
113
+ else:
114
+ clade.color = colors[clade.family]
plotting.py ADDED
@@ -0,0 +1,522 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import networkx as nx
2
+ import numpy as np
3
+ from Bio.Phylo import to_networkx
4
+ from networkx.drawing.nx_agraph import graphviz_layout
5
+ import plotly.graph_objects as go
6
+ import plotly.express as px
7
+ from Bio.Phylo.TreeConstruction import DistanceTreeConstructor, DistanceCalculator, _DistanceMatrix
8
+
9
+ from tools import compute_ordered_matrix,compute_umap
10
+ from phylogeny import prepare_tree
11
+ from constants import UNKNOWN_COLOR, DEFAULT_COLOR, UNKNOWN_COLOR_RGB, DEFAULT_COLOR_RGB
12
+
13
+ # ------------------------------------------------------------------------------------------------
14
+ #
15
+ # Sim Matrix Plotting
16
+ #
17
+ # ------------------------------------------------------------------------------------------------
18
+
19
def plot_sim_matrix_fig(ordered_sim_matrix, ordered_model_names, families, colors):
    """Build the grayscale heatmap of the ordered similarity matrix.

    Two invisible red rectangles named 'rectX' and 'rectY' are added so that
    ``update_sim_matrix_fig`` can later move and reveal them. ``families`` and
    ``colors`` are accepted but not used here.
    """
    fig = px.imshow(
        ordered_sim_matrix,
        x=ordered_model_names,
        y=ordered_model_names,
        zmin=0, zmax=1,
        color_continuous_scale='gray',
    )

    fig.update_layout(
        coloraxis_colorbar=dict(title='Similarity'),
        margin=dict(l=0, r=0, t=0, b=0),
        autosize=True,
    )
    fig.update_traces(
        colorbar=dict(thickness=20, len=0.75, xanchor="right", x=1.02)
    )

    # Same bare-axis styling on both axes
    axis_style = dict(showticklabels=False, showgrid=False, zeroline=False, constrain='range')
    fig.update_xaxes(**axis_style)
    fig.update_yaxes(**axis_style)

    # Placeholder rectangles for highlighted models (hidden until a search hits)
    for shape_name in ('rectX', 'rectY'):
        fig.add_shape(
            go.layout.Shape(
                type="rect",
                xref="x", yref="y",
                x0=0, y0=0,
                x1=0, y1=0,
                line=dict(color="red", width=1),
                fillcolor="rgba(0,0,0,0)",
                name=shape_name,
                opacity=0,
            )
        )

    return fig
70
+
71
def update_sim_matrix_fig(fig, ordered_model_names, model_search_x=None, model_search_y=None):
    """Move or hide the highlight rectangles of the similarity heatmap.

    'rectX' outlines the column of ``model_search_x`` and 'rectY' the row of
    ``model_search_y``. A rectangle is hidden (opacity 0) when its model is
    not in ``ordered_model_names``. Returns ``fig``.
    """
    # Column highlight
    if model_search_x not in ordered_model_names:
        fig.update_shapes(selector=dict(name='rectX'), opacity=0)
    else:
        col = ordered_model_names.index(model_search_x)
        fig.update_shapes(
            selector=dict(name='rectX'),
            x0=col - 0.5, y0=-0.5,
            x1=col + 0.5, y1=len(ordered_model_names) - 0.5,
            opacity=0.7,
        )

    # Row highlight
    if model_search_y not in ordered_model_names:
        fig.update_shapes(selector=dict(name='rectY'), opacity=0)
    else:
        row = ordered_model_names.index(model_search_y)
        fig.update_shapes(
            selector=dict(name='rectY'),
            x0=-0.5, y0=row - 0.5,
            x1=len(ordered_model_names) - 0.5, y1=row + 0.5,
            opacity=0.7,
        )

    return fig
99
+
100
+ # ------------------------------------------------------------------------------------------------
101
+ #
102
+ # 2D UMAP Plotting
103
+ #
104
+ # ------------------------------------------------------------------------------------------------
105
+
106
def alpha_scaling(val):
    """Raise ``val`` to the power ``1 / (0.35 + 1/100)``.

    The exponent is greater than 1, so values between 0 and 1 are pushed towards 0.
    """
    base = 0.35
    exponent = 1 / (base + 1/100)
    return val ** exponent
109
+
110
def plot_umap_fig(dist_matrix, sim_matrix, model_names, families, colors, key='fig2', alpha_edges=None, alpha_names=None, alpha_markers=None):
    """Build the 2D UMAP scatter figure of the models.

    Traces are added in this order: one '_edge' line per sufficiently similar
    pair, a single '_node' trace with every model, one legend-only trace per
    family, and a hidden 'node' trace used later to highlight a searched model.
    ``key`` is accepted but not used.
    """
    embedding = compute_umap(dist_matrix, d=2)
    fig = go.Figure()

    #-- EDGES
    # One line per unordered pair whose scaled similarity exceeds 0.1,
    # colored after the first model's family.
    n_models = len(model_names)
    for a in range(n_models):
        for b in range(a + 1, n_models):
            strength = alpha_scaling(sim_matrix[a][b])
            if strength > 0.1:
                fig.add_trace(
                    go.Scatter(
                        x=[embedding[a, 0], embedding[b, 0]],
                        y=[embedding[a, 1], embedding[b, 1]],
                        mode='lines',
                        name='_edge',
                        line=dict(color=colors[families[a]], width=strength),
                        opacity=alpha_edges,
                        showlegend=False,
                        hoverinfo='skip',
                    )
                )

    #-- NODES
    fig.add_trace(
        go.Scatter(
            x=embedding[:, 0],
            y=embedding[:, 1],
            text=model_names,
            mode='markers+text',
            textposition='top center',
            hoverinfo='text',
            hoveron='points+fills',
            showlegend=False,
            name='_node',
            marker=dict(
                color=[colors[family] for family in families],
                size=8,
                line_width=2,
                opacity=alpha_markers,
            ),
            textfont=dict(
                color=f'rgba(0,0,0,{alpha_names})',
                size=8,
                family="Arial Black",
            ),
        )
    )

    #-- LEGEND
    # Point-less traces whose only role is to produce one legend entry per family
    fig.add_traces([
        go.Scatter(
            x=[None],
            y=[None],
            mode='markers',
            marker=dict(color=colors[family], size=8, line_width=2, opacity=1),
            name=family,
        )
        for family in set(families)
    ])

    # Hidden placeholder for the highlighted (searched) model
    fig.add_trace(
        go.Scatter(
            x=[0],
            y=[0],
            mode='markers+text',
            textposition='top center',
            textfont=dict(color='red', size=16, family="Arial Black"),
            marker=dict(
                color='red',
                size=12,
                symbol='circle',
                line=dict(color='red', width=3),
            ),
            showlegend=False,
            name='node',
            opacity=0,
        )
    )

    # Layout: no margins, bare axes
    fig.update_layout(margin=dict(l=0, r=0, t=0, b=0), autosize=True)
    axis_style = dict(showticklabels=False, showgrid=False, zeroline=False, constrain='range')
    fig.update_xaxes(**axis_style)
    fig.update_yaxes(**axis_style)

    return fig
215
+
216
def update_umap_fig(fig, dist_matrix, model_names, families, colors, model_search_x=None, alpha_names=None, alpha_markers=None, alpha_edges=None, key='fig2'):
    """Refresh opacities of the UMAP figure and place the highlight marker.

    The 'node' trace is moved onto ``model_search_x`` and shown when that model
    is in ``model_names``; otherwise it is parked at (0, 0) and hidden.
    ``key`` is accepted but not used. Returns ``fig``.
    """
    # Nodes: label alpha and marker opacity
    fig.update_traces(
        selector=dict(name='_node'),
        textfont=dict(color=f'rgba(0,0,0,{alpha_names})'),
        marker=dict(opacity=alpha_markers),
    )

    # Edges: every line trace gets width 1 and the requested opacity
    fig.update_traces(
        selector=dict(mode='lines'),
        line=dict(width=1),
        opacity=alpha_edges,
    )

    # Highlighted node
    if model_search_x in model_names:
        hit = model_names.index(model_search_x)
        embedding = compute_umap(dist_matrix, d=2)  # coordinates obtained again through compute_umap
        highlight = dict(
            x=[embedding[hit, 0]],
            y=[embedding[hit, 1]],
            text=[model_search_x],
            marker=dict(color=colors[families[hit]]),
            hovertext=model_search_x,
            opacity=1,
        )
    else:
        highlight = dict(x=[0], y=[0], text=[''], opacity=0)
    fig.update_traces(selector=dict(name='node'), **highlight)

    return fig
259
+
260
+ # ------------------------------------------------------------------------------------------------
261
+ #
262
+ # Phylogenetic Tree Plotting
263
+ #
264
+ # ------------------------------------------------------------------------------------------------
265
+
266
def draw_graphviz(tree, label_func=str, prog='twopi', args='',
                  node_size=15, edge_width=0.0, alpha_edges=None, alpha_names=None, alpha_markers=None, **kwargs):
    """Display a tree or clade as a Plotly figure, with layout from the graphviz engine.

    The tree is converted to a NetworkX graph, positioned with ``prog``, and
    drawn as one '_edge' trace per edge, one '_node' trace per labelled node
    and one legend-only trace per family. Extra ``kwargs`` are accepted but
    not used.
    """
    # Convert the tree to a graph with integer nodes; the original clade
    # object is kept under each node's 'label' attribute.
    graph = nx.convert_node_labels_to_integers(to_networkx(tree), label_attribute='label')

    # Apply the Graphviz layout
    pos = graphviz_layout(graph, prog=prog, args=args)

    # Prepare node labels for display
    def get_label_mapping(G, selection):
        for node, data in G.nodes(data=True):
            if (selection is None) or (node in selection):
                try:
                    label = label_func(data.get('label', node))
                    if label not in (None, node.__class__.__name__):
                        yield (node, label)
                except (LookupError, AttributeError, ValueError):
                    pass

    labels = dict(get_label_mapping(graph, None))
    nodelist = list(labels.keys())

    # Color (as an 'rgb(r,g,b)' string) and family of every node
    fallback_rgb = (0, 0, 0)
    node_colors = {}
    node_families = {}
    for node in graph.nodes():
        clade = graph.nodes[node].get('label')
        if hasattr(clade, 'color'):
            rgb = clade.color.to_rgb() if not (clade.color is None) else fallback_rgb
        else:
            rgb = fallback_rgb
        node_colors[node] = f'rgb({rgb[0]},{rgb[1]},{rgb[2]})'
        node_families[node] = clade.family if hasattr(clade, 'family') else None

    # One line trace per edge, colored after the child node
    edge_traces = []
    for parent, child in graph.edges():
        x0, y0 = pos[parent]
        x1, y1 = pos[child]

        edge_color = node_colors[child]
        # NOTE(review): edge_color is an 'rgb(...)' string here, so list() splits
        # it into characters; whether this test can ever match depends on how
        # UNKNOWN_COLOR_RGB is defined in constants - confirm.
        if list(edge_color) == list(UNKNOWN_COLOR_RGB):
            edge_color = tuple(DEFAULT_COLOR_RGB)
        edge_traces.append(
            go.Scatter(
                x=[x0, x1, None],
                y=[y0, y1, None],
                line=dict(width=edge_width, color=edge_color),
                hoverinfo='none',
                mode='lines',
                showlegend=False,
                name='_edge',
                opacity=alpha_edges,
            )
        )

    # One marker+text trace per labelled node
    node_traces = []
    for node in nodelist:
        x, y = pos[node]
        node_traces.append(
            go.Scatter(
                x=[x],
                y=[y],
                text=labels.get(node, None),
                mode='markers+text',
                textposition='top center',
                hoverinfo='text',
                showlegend=False,
                name='_node',
                marker=dict(
                    color=node_colors.get(node, None),
                    size=node_size,
                    line_width=2,
                    opacity=alpha_markers,
                ),
                textfont=dict(
                    color=f'rgba(0,0,0,{alpha_names})',
                    size=8,
                    family="Arial Black",
                ),
            )
        )

    # Family -> color (the last node seen for a family wins)
    colors = {}
    families = []
    for node, family in node_families.items():
        if family is not None:
            families.append(family)
            colors[family] = node_colors.get(node, DEFAULT_COLOR)
        else:
            colors[family] = DEFAULT_COLOR
    families = set(families)

    # Custom legend: point-less traces, one per family
    legends = [
        go.Scatter(
            x=[None],
            y=[None],
            mode='markers',
            marker=dict(color=colors[family], size=8, line_width=2, opacity=1),
            name=family,
        )
        for family in families
    ]

    # Create the figure
    fig = go.Figure(
        data=edge_traces + node_traces,
        layout=go.Layout(
            showlegend=True,
            hovermode='closest',
            margin=dict(b=1, l=1, r=1, t=1),
            xaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
            yaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
            legend=dict(yanchor="top", y=0.99, xanchor="right", x=0.99),
        ),
    )
    fig.add_traces(legends)

    return fig
425
+
426
def get_color(index):
    """Return the entry of plotly's qualitative 'Plotly' palette at ``index``, wrapping around."""
    palette = px.colors.qualitative.Plotly
    wrapped = index % len(palette)
    return palette[wrapped]
430
+
431
def plot_tree(sim_matrix, models, families, colors, alpha_names=None, alpha_markers=None, alpha_edges=None):
    """
    Plot a phylogenetic tree based on a similarity matrix.

    Parameters:
    - sim_matrix: similarity matrix between models
    - models: list of model names
    - families: list of family names for each model
    - colors: mapping from family to color, forwarded to prepare_tree
    - alpha_names, alpha_markers, alpha_edges: opacities forwarded to draw_graphviz

    Returns:
    - fig: Plotly figure object with the phylogenetic tree
    """
    # Distances are -log(similarity); similarities are floored to avoid log(0)
    dist_matrix = -np.log(np.maximum(sim_matrix, 1e-10))

    # Bio.Phylo expects the lower triangle, diagonal included
    lower_triangle = [list(row[:i + 1]) for i, row in enumerate(dist_matrix)]
    distances = _DistanceMatrix(names=models, matrix=lower_triangle)

    # Setup Bio.Phylo and build the neighbor-joining tree
    constructor = DistanceTreeConstructor(DistanceCalculator('identity'), 'nj')
    nj_tree = constructor.nj(distances)
    nj_tree.ladderize(reverse=False)

    # Color the tree
    prepare_tree(nj_tree, models, families, colors)

    # Generate the plotly figure
    return draw_graphviz(
        nj_tree,
        node_size=15,
        edge_width=6,
        alpha_names=alpha_names,
        alpha_markers=alpha_markers,
        alpha_edges=alpha_edges,
    )
467
+
468
def update_tree_fig(fig, model_names, model_search=None, alpha_names=None, alpha_markers=None, alpha_edges=None):
    """Refresh opacities of the tree figure and highlight the searched model.

    A highlighted trace is marked by renaming it from '_node' to 'node'; any
    previously highlighted trace is first renamed back. Returns ``fig``.

    Args:
        fig: figure exposing ``update_traces``.
        model_names: list of known model names.
        model_search: model to highlight, if present in ``model_names``.
        alpha_names, alpha_markers, alpha_edges: opacities for labels, markers and edges.
    """
    # Update nodes
    fig.update_traces(
        selector=dict(name='_node'),
        marker=dict(opacity=alpha_markers),
        textfont=dict(color=f'rgba(0,0,0,{alpha_names})'),
    )

    # Update edges
    fig.update_traces(
        selector=dict(name='_edge'),
        opacity=alpha_edges,
    )

    # Turn any previously highlighted node back into a regular one.
    # NOTE(review): this leaves the label at size 16 while regular nodes are
    # created with size 8 in draw_graphviz - confirm whether that is intended.
    fig.update_traces(
        selector=dict(name='node'),
        marker=dict(size=15, line=None),
        textfont=dict(color=f'rgba(0,0,0,{alpha_names})', size=16, family="Arial Black"),
        name='_node',
    )

    # Highlight the searched model: bigger marker, red border and label
    if model_search in model_names:
        fig.update_traces(
            selector=dict(name='_node', text=model_search),
            marker=dict(size=22, line=dict(color='red', width=4)),
            textfont=dict(color='red', size=16, family="Arial Black"),
            name='node',
        )

    return fig
requirements.txt ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ matplotlib
2
+ networkx
3
+ numpy==1.23.0
4
+ biopython
5
+ plotly
6
+ scikit-learn
7
+ streamlit
8
+ transformers
9
+ torch
10
+ pygit2
11
+ fastcluster
12
+ pygraphviz
13
+ accelerate
14
+ umap-learn
15
+ ujson
tools.py ADDED
@@ -0,0 +1,80 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ from scipy.spatial.distance import squareform
3
+ from fastcluster import linkage
4
+ import umap
5
+
6
+ # ------------------------------------------------------------------------------------------------
7
+ #
8
+ # Sim Matrix Ordering
9
+ #
10
+ # ------------------------------------------------------------------------------------------------
11
+
12
def seriation(Z, N, cur_index):
    '''
    input:
        - Z is a hierarchical tree (dendrogram)
        - N is the number of points given to the clustering process
        - cur_index is the position in the tree from which the traversal starts
    output:
        - order implied by the hierarchical tree Z

    seriation computes the order implied by a hierarchical tree (dendrogram).
    The traversal is iterative, so deep (chain-like) dendrograms do not hit
    Python's recursion limit.
    '''
    order = []
    stack = [cur_index]
    while stack:
        index = stack.pop()
        if index < N:
            # Indices below N are original points (leaves)
            order.append(index)
        else:
            # Push right first so that the left subtree is emitted first
            stack.append(int(Z[index - N, 1]))
            stack.append(int(Z[index - N, 0]))
    return order
29
+
30
def compute_serial_matrix(dist_mat, method="ward"):
    '''
    input:
        - dist_mat is a distance matrix
        - method = ["ward","single","average","complete"]
    output:
        - seriated_dist is the input dist_mat, with rows and columns
          re-ordered according to the order implied by the hierarchical tree
        - res_order is the order implied by the hierarchical tree
        - res_linkage is the hierarchical tree (dendrogram)

    compute_serial_matrix transforms a distance matrix into a sorted distance
    matrix according to the order implied by the hierarchical tree (dendrogram)
    '''
    N = len(dist_mat)
    condensed = squareform(dist_mat)
    res_linkage = linkage(condensed, method=method, preserve_input=True)
    # The root of the dendrogram is the last cluster created: index N + (N-2)
    res_order = seriation(res_linkage, N, N + N - 2)

    # Fill the upper triangle from the re-ordered input, then mirror it
    seriated_dist = np.zeros((N, N))
    upper_rows, upper_cols = np.triu_indices(N, k=1)
    source_rows = [res_order[i] for i in upper_rows]
    source_cols = [res_order[j] for j in upper_cols]
    seriated_dist[upper_rows, upper_cols] = dist_mat[source_rows, source_cols]
    seriated_dist[upper_cols, upper_rows] = seriated_dist[upper_rows, upper_cols]

    return seriated_dist, res_order, res_linkage
58
+
59
def compute_ordered_matrix(sim_matrix, dist_matrix, model_names):
    """Reorder the similarity matrix and model names by hierarchical clustering.

    With fewer than two models the inputs are returned unchanged.

    Returns:
        tuple: ``(ordered_sim_matrix, ordered_model_names)``.
    """
    if len(sim_matrix) < 2:
        # Nothing to cluster
        return sim_matrix, model_names

    # Only the order from the hierarchical clustering is needed here
    _, order, _ = compute_serial_matrix(dist_matrix)
    ordered_sim_matrix = sim_matrix[order][:, order]
    ordered_model_names = [model_names[i] for i in order]
    return ordered_sim_matrix, ordered_model_names
70
+
71
+
72
+ # ------------------------------------------------------------------------------------------------
73
+ #
74
+ # UMAP computation
75
+ #
76
+ # ------------------------------------------------------------------------------------------------
77
+
78
# Small memo of embeddings keyed by the matrix content and target dimension.
_UMAP_CACHE = {}
_UMAP_CACHE_MAX = 8


def compute_umap(dist_matrix, d=2):
    """Embed a precomputed distance matrix in ``d`` dimensions with UMAP (densMAP).

    The fit uses a fixed ``random_state``. Results are memoized on the matrix
    content, so repeated calls with the same matrix do not refit. The returned
    array is shared with the cache and must not be modified in place.

    Args:
        dist_matrix: square distance matrix (``metric='precomputed'``).
        d: number of embedding dimensions.

    Returns:
        The embedding returned by ``fit_transform``.
    """
    matrix = np.asarray(dist_matrix)
    cache_key = (d, matrix.shape, matrix.dtype.str, matrix.tobytes())
    cached = _UMAP_CACHE.get(cache_key)
    if cached is not None:
        return cached

    embedding = umap.UMAP(densmap=True, n_components=d, metric='precomputed', random_state=42).fit_transform(dist_matrix)

    # Keep the memo bounded: drop the oldest entry once full
    if len(_UMAP_CACHE) >= _UMAP_CACHE_MAX:
        _UMAP_CACHE.pop(next(iter(_UMAP_CACHE)))
    _UMAP_CACHE[cache_key] = embedding
    return embedding