#!/usr/bin/env python
"""
modular_graph_and_candidates.py
================================
Create **one** rich view that combines
1. The *dependency graph* between existing **modular_*.py** implementations in
πŸ€—Β Transformers (blue/🟑) **and**
2. The list of *missing* modular models (full‑red nodes) **plus** similarity
edges (full‑red links) between highly‑overlapping modelling files – the
output of *find_modular_candidates.py* – so you can immediately spot good
refactor opportunities.
––– Usage –––
```bash
python modular_graph_and_candidates.py /path/to/transformers \
--multimodal # keep only models whose modelling code mentions
# "pixel_values" β‰₯Β 3 times
--sim-threshold 0.5 # Jaccard cutoff (default 0.50)
--out graph.html # output HTML file name
```
Colour legend in the generated HTML:
* 🟑 **base model**Β β€” has modular shards *imported* by others but no parent
* πŸ”΅Β **derived modular model**Β β€” has a `modular_*.py` and inherits from β‰₯β€―1 model
* πŸ”΄Β **candidate**Β β€” no `modular_*.py` yet (and/or very similar to another)
* red edges = high‑Jaccard similarity links (potential to factorise)
"""
from __future__ import annotations

import argparse
import ast
import json
import re
import subprocess
from collections import Counter, defaultdict
from datetime import datetime
from itertools import combinations
from pathlib import Path
from typing import Dict, List, Set, Tuple

import numpy as np
import spaces
import torch
from sentence_transformers import SentenceTransformer
from tqdm import tqdm
# ──────────────────────────────────────────────────────────────────────────────
# CONFIG
# ──────────────────────────────────────────────────────────────────────────────
SIM_DEFAULT = 0.5          # similarity threshold
PIXEL_MIN_HITS = 0         # min "pixel_values" mentions for --multimodal (0 keeps every model)
HTML_DEFAULT = "d3_modular_graph.html"
# ────────────────────────────────────────────────────────────────────────────────
# 1) Helpers to analyse *modelling* files (for similarity & multimodal filter)
# ────────────────────────────────────────────────────────────────────────────────
def _strip_source(code: str) -> str:
    """Remove docstrings, comments and import lines, keeping only the core code."""
    code = re.sub(r'("""|\'\'\')(?:.|\n)*?\1', "", code)  # docstrings
    code = re.sub(r"#.*", "", code)                       # comments
    return "\n".join(ln for ln in code.splitlines()
                     if not re.match(r"\s*(from|import)\s+", ln))
def _tokenise(code: str) -> Set[str]:
    """Extract identifiers with a regex - more robust than the stdlib tokenizer on malformed code."""
    return set(re.findall(r"\b[a-zA-Z_][a-zA-Z0-9_]*\b", code))
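# Illustrative: _tokenise("hidden = self.act(hidden)") == {"hidden", "self", "act"}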
def build_token_bags(models_root: Path) -> Tuple[Dict[str, List[Set[str]]], Dict[str, int]]:
    """Return token bags for every `modeling_*.py` plus a "pixel_values" counter per model."""
    bags: Dict[str, List[Set[str]]] = defaultdict(list)
    pixel_hits: Dict[str, int] = defaultdict(int)
    for mdl_dir in sorted(p for p in models_root.iterdir() if p.is_dir()):
        for py in mdl_dir.rglob("modeling_*.py"):
            try:
                text = py.read_text(encoding="utf-8")
                pixel_hits[mdl_dir.name] += text.count("pixel_values")
                bags[mdl_dir.name].append(_tokenise(_strip_source(text)))
            except Exception as e:
                print(f"⚠️ Skipped {py}: {e}")
    return bags, pixel_hits
def _jaccard(a: Set[str], b: Set[str]) -> float:
return 0.0 if (not a or not b) else len(a & b) / len(a | b)
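# Worked example (illustrative): for a = {"forward", "hidden"} and b = {"forward", "attn"},
# |a & b| = 1 and |a | b| = 3, so _jaccard(a, b) == 1/3.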
def similarity_clusters(bags: Dict[str, List[Set[str]]], thr: float) -> Dict[Tuple[str, str], float]:
    """Return {(modelA, modelB): score} for pairs with Jaccard ≥ *thr*."""
    # Represent each model by its largest token bag (usually its main modeling file)
    largest = {m: max(ts, key=len) for m, ts in bags.items() if ts}
    out: Dict[Tuple[str, str], float] = {}
for m1, m2 in combinations(sorted(largest.keys()), 2):
s = _jaccard(largest[m1], largest[m2])
if s >= thr:
out[(m1, m2)] = s
return out
@spaces.GPU
def embedding_similarity_clusters(models_root: Path, missing: List[str], thr: float) -> Dict[Tuple[str, str], float]:
model = SentenceTransformer("codesage/codesage-large-v2", device="cuda", trust_remote_code=True)
    try:
        cfg = model[0].auto_model.config
        # NB: the two-arg getattr evaluates its fallback eagerly, so chain with
        # `or` instead; otherwise a config lacking max_position_embeddings would
        # raise even when n_positions exists.
        pos_limit = int(getattr(cfg, "n_positions", None)
                        or getattr(cfg, "max_position_embeddings", None)
                        or 1024)
    except Exception:
        pos_limit = 1024
seq_len = min(pos_limit, 2048)
model.max_seq_length = seq_len
model[0].max_seq_length = seq_len
model[0].tokenizer.model_max_length = seq_len
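    # Cap the sequence length on the SentenceTransformer wrapper, its transformer
    # module and the tokenizer so long files are truncated consistently instead
    # of overflowing the position embeddings.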
texts = {}
for name in tqdm(missing, desc="Reading modeling files"):
if any(skip in name.lower() for skip in ["mobilebert", "lxmert"]):
print(f"Skipping {name} (causes GPU abort)")
continue
code = ""
for py in (models_root / name).rglob("modeling_*.py"):
try:
code += _strip_source(py.read_text(encoding="utf-8")) + "\n"
except Exception:
continue
        texts[name] = code.strip() or " "  # placeholder keeps names and embeddings aligned
names = list(texts)
all_embeddings = []
print(f"Encoding embeddings for {len(names)} models...")
    batch_size = 4  # small batches keep peak GPU memory low
# ── two-stage caching: temp (for resume) + permanent (for reuse) ─────────────
temp_cache_path = Path("temp_embeddings.npz") # For resuming computation
final_cache_path = Path("embeddings_cache.npz") # For permanent storage
start_idx = 0
    emb_dim = getattr(model, "get_sentence_embedding_dimension", lambda: None)() or 768  # fallback if unreported
# Try to load from permanent cache first
    if final_cache_path.exists():
        try:
            cached = np.load(final_cache_path, allow_pickle=True)
            cached_names = list(cached["names"])
            if names == cached_names:  # Exact match - use final cache
                print(f"✅ Using final embeddings cache ({len(cached_names)} models)")
                return compute_similarities_from_cache(thr)
        except Exception as e:
            print(f"⚠️ Failed to load final cache: {e}")
# Try to resume from temp cache
    if temp_cache_path.exists():
        try:
            cached = np.load(temp_cache_path, allow_pickle=True)
            cached_names = list(cached["names"])
            if names[:len(cached_names)] == cached_names:
                loaded = cached["embeddings"].astype(np.float32)
                all_embeddings.append(loaded)
                start_idx = len(cached_names)
                print(f"🔄 Resuming from temp cache: {start_idx}/{len(names)} models")
        except Exception as e:
            print(f"⚠️ Failed to load temp cache: {e}")
# ───────────────────────────────────────────────────────────────────────────
for i in tqdm(range(start_idx, len(names), batch_size), desc="Batches", leave=False):
batch_names = names[i:i+batch_size]
batch_texts = [texts[name] for name in batch_names]
try:
print(f"Processing batch: {batch_names}")
emb = model.encode(batch_texts, convert_to_numpy=True, show_progress_bar=False)
        except Exception as e:
            print(f"⚠️ GPU worker error for batch {batch_names}: {type(e).__name__}: {e}")
            emb = np.zeros((len(batch_names), emb_dim), dtype=np.float32)
all_embeddings.append(emb)
# save to temp cache after each batch (for resume)
try:
cur = np.vstack(all_embeddings).astype(np.float32)
np.savez(
temp_cache_path,
embeddings=cur,
names=np.array(names[:i+len(batch_names)], dtype=object),
)
        except Exception as e:
            print(f"⚠️ Failed to write temp cache: {e}")
        if (i - start_idx) % (3 * batch_size) == 0 and torch.cuda.is_available():
            torch.cuda.empty_cache()
            torch.cuda.synchronize()
            print(f"🧹 Cleared GPU cache after batch {(i - start_idx)//batch_size + 1}")
embeddings = np.vstack(all_embeddings).astype(np.float32)
norms = np.linalg.norm(embeddings, axis=1, keepdims=True) + 1e-12
embeddings = embeddings / norms
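    # Rows are now unit-normalised, so the matrix product below yields cosine similarities.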
print("Computing pairwise similarities...")
sims_mat = embeddings @ embeddings.T
out = {}
matrix_size = embeddings.shape[0]
processed_names = names[:matrix_size]
for i in range(matrix_size):
for j in range(i + 1, matrix_size):
s = float(sims_mat[i, j])
if s >= thr:
out[(processed_names[i], processed_names[j])] = s
    # Save to final cache when complete
    try:
        np.savez(final_cache_path, embeddings=embeddings, names=np.array(names, dtype=object))
        print(f"💾 Final embeddings saved to {final_cache_path}")
        # Clean up the temp cache now that the final artifact exists
        if temp_cache_path.exists():
            temp_cache_path.unlink()
            print("🧹 Cleaned up temp cache")
    except Exception as e:
        print(f"⚠️ Failed to save final cache: {e}")
return out
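# Cache layout (a sketch of what the .npz files hold): both store two arrays,
#   embeddings: float32 [n_models, emb_dim]   and   names: object [n_models];
# temp_embeddings.npz is a rolling checkpoint, embeddings_cache.npz the final artifact.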
def compute_similarities_from_cache(threshold: float) -> Dict[Tuple[str, str], float]:
"""Compute similarities from cached embeddings without reprocessing."""
embeddings_path = Path("embeddings_cache.npz")
if not embeddings_path.exists():
return {}
try:
cached = np.load(embeddings_path, allow_pickle=True)
embeddings = cached["embeddings"].astype(np.float32)
names = list(cached["names"])
# Normalize embeddings
norms = np.linalg.norm(embeddings, axis=1, keepdims=True) + 1e-12
embeddings = embeddings / norms
# Compute similarities
sims_mat = embeddings @ embeddings.T
out = {}
for i in range(len(names)):
for j in range(i + 1, len(names)):
s = float(sims_mat[i, j])
if s >= threshold:
out[(names[i], names[j])] = s
print(f"⚑ Computed {len(out)} similarities from cache (threshold: {threshold})")
return out
    except Exception as e:
        print(f"⚠️ Failed to compute from cache: {e}")
        return {}
# ────────────────────────────────────────────────────────────────────────────────
# 2) Scan *modular_*.py* files to build an import-dependency graph
#    - only **modeling_*** imports are considered (skip configuration / processing)
# ────────────────────────────────────────────────────────────────────────────────
def modular_files(models_root: Path) -> List[Path]:
    # rglob's pattern already guarantees the ".py" suffix
    return list(models_root.rglob("modular_*.py"))
def dependency_graph(modular_files: List[Path], models_root: Path) -> Dict[str, List[Dict[str,str]]]:
"""Return {derived_model: [{source, imported_class}, ...]}
Only `modeling_*` imports are kept; anything coming from configuration/processing/
image* utils is ignored so the visual graph focuses strictly on modelling code.
Excludes edges to sources whose model name is not a model dir.
"""
model_names = {p.name for p in models_root.iterdir() if p.is_dir()}
deps: Dict[str, List[Dict[str,str]]] = defaultdict(list)
for fp in modular_files:
derived = fp.parent.name
        try:
            tree = ast.parse(fp.read_text(encoding="utf-8"), filename=str(fp))
        except Exception as e:
            print(f"⚠️ AST parse failed for {fp}: {e}")
            continue
for node in ast.walk(tree):
if not isinstance(node, ast.ImportFrom) or not node.module:
continue
mod = node.module
# keep only *modeling_* imports, drop anything else
if ("modeling_" not in mod or
"configuration_" in mod or
"processing_" in mod or
"image_processing" in mod or
"modeling_attn_mask_utils" in mod):
continue
parts = re.split(r"[./]", mod)
src = next((p for p in parts if p not in {"", "models", "transformers"}), "")
if not src or src == derived or src not in model_names:
continue
for alias in node.names:
deps[derived].append({"source": src, "imported_class": alias.name})
return dict(deps)
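# Example return shape (illustrative model names):
#   {"gemma": [{"source": "llama", "imported_class": "LlamaMLP"}, ...]}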
def get_missing_models(models_root: Path, multimodal: bool = False) -> Tuple[List[str], Dict[str, List[Set[str]]], Dict[str, int]]:
"""Get list of models missing modular implementations."""
bags, pix_hits = build_token_bags(models_root)
mod_files = modular_files(models_root)
models_with_modular = {p.parent.name for p in mod_files}
missing = [m for m in bags if m not in models_with_modular]
if multimodal:
missing = [m for m in missing if pix_hits[m] >= PIXEL_MIN_HITS]
return missing, bags, pix_hits
def compute_similarities(models_root: Path, missing: List[str], bags: Dict[str, List[Set[str]]],
threshold: float, sim_method: str) -> Dict[Tuple[str, str], float]:
"""Compute similarities between missing models using specified method."""
if sim_method == "jaccard":
return similarity_clusters({m: bags[m] for m in missing}, threshold)
else:
# Try to use cached embeddings first
embeddings_path = Path("embeddings_cache.npz")
if embeddings_path.exists():
cached_sims = compute_similarities_from_cache(threshold)
if cached_sims: # Cache exists and worked
return cached_sims
# Fallback to full computation
return embedding_similarity_clusters(models_root, missing, threshold)
def build_graph_json(
transformers_dir: Path,
threshold: float = SIM_DEFAULT,
multimodal: bool = False,
sim_method: str = "jaccard",
) -> dict:
"""Return the {nodes, links} dict that D3 needs."""
# Check if we can use cached embeddings only
embeddings_cache = Path("embeddings_cache.npz")
print(f"πŸ” Cache file exists: {embeddings_cache.exists()}, sim_method: {sim_method}")
if sim_method == "embedding" and embeddings_cache.exists():
try:
# Try to compute from cache without accessing repo
cached_sims = compute_similarities_from_cache(threshold)
print(f"πŸ” Got {len(cached_sims)} cached similarities")
if cached_sims:
# Create graph with cached similarities + modular dependencies
cached_data = np.load(embeddings_cache, allow_pickle=True)
missing = list(cached_data["names"])
# Still need to get modular dependencies from repo
models_root = transformers_dir / "src/transformers/models"
mod_files = modular_files(models_root)
deps = dependency_graph(mod_files, models_root)
# Build full graph structure
nodes = set(missing) # Start with cached models
links = []
# Add dependency links
for drv, lst in deps.items():
for d in lst:
links.append({
"source": d["source"],
"target": drv,
"label": f"{sum(1 for x in lst if x['source'] == d['source'])} imports",
"cand": False
})
nodes.update({d["source"], drv})
# Add similarity links
for (a, b), s in cached_sims.items():
links.append({"source": a, "target": b, "label": f"{s*100:.1f}%", "cand": True})
# Create node list with proper classification
targets = {lk["target"] for lk in links if not lk["cand"]}
sources = {lk["source"] for lk in links if not lk["cand"]}
nodelist = []
for n in sorted(nodes):
if n in missing and n not in sources and n not in targets:
cls = "cand"
elif n in sources and n not in targets:
cls = "base"
else:
cls = "derived"
nodelist.append({"id": n, "cls": cls, "sz": 1})
print(f"⚑ Built graph from cache: {len(nodelist)} nodes, {len(links)} links")
return {"nodes": nodelist, "links": links}
        except Exception as e:
            print(f"⚠️ Cache-only build failed: {e}, falling back to full build")
# Full build with repository access
models_root = transformers_dir / "src/transformers/models"
# Get missing models and their data
missing, bags, pix_hits = get_missing_models(models_root, multimodal)
# Build dependency graph
mod_files = modular_files(models_root)
deps = dependency_graph(mod_files, models_root)
# Compute similarities
sims = compute_similarities(models_root, missing, bags, threshold, sim_method)
# ---- assemble nodes & links ----
nodes: Set[str] = set()
links: List[dict] = []
for drv, lst in deps.items():
for d in lst:
links.append({
"source": d["source"],
"target": drv,
"label": f"{sum(1 for x in lst if x['source'] == d['source'])} imports",
"cand": False
})
nodes.update({d["source"], drv})
for (a, b), s in sims.items():
links.append({"source": a, "target": b, "label": f"{s*100:.1f}%", "cand": True})
nodes.update({a, b})
nodes.update(missing)
deg = Counter()
for lk in links:
deg[lk["source"]] += 1
deg[lk["target"]] += 1
max_deg = max(deg.values() or [1])
targets = {lk["target"] for lk in links if not lk["cand"]}
sources = {lk["source"] for lk in links if not lk["cand"]}
missing_only = [m for m in missing if m not in sources and m not in targets]
nodes.update(missing_only)
nodelist = []
for n in sorted(nodes):
if n in missing_only:
cls = "cand"
elif n in sources and n not in targets:
cls = "base"
else:
cls = "derived"
nodelist.append({"id": n, "cls": cls, "sz": 1 + 2*(deg[n]/max_deg)})
graph = {"nodes": nodelist, "links": links}
return graph
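# Resulting shape consumed by the D3 templates (illustrative values):
#   {"nodes": [{"id": "llama", "cls": "base", "sz": 1.8}, ...],
#    "links": [{"source": "llama", "target": "gemma", "label": "4 imports", "cand": False},
#              {"source": "beit", "target": "deit", "label": "73.1%", "cand": True}, ...]}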
# ────────────────────────────────────────────────────────────────────────────────
# Timeline functions for chronological visualization
# ────────────────────────────────────────────────────────────────────────────────
def get_model_creation_dates(transformers_dir: Path) -> Dict[str, datetime]:
"""Get creation dates for all model directories by finding the earliest add of the directory path."""
models_root = transformers_dir / "src/transformers/models"
creation_dates: Dict[str, datetime] = {}
if not models_root.exists():
return creation_dates
def run_git(args: list[str]) -> subprocess.CompletedProcess:
return subprocess.run(
["git"] + args,
cwd=transformers_dir,
capture_output=True,
text=True,
timeout=120,
)
# Ensure full history; shallow clones make every path look newly added "today".
shallow = run_git(["rev-parse", "--is-shallow-repository"])
if shallow.returncode == 0 and shallow.stdout.strip() == "true":
# Try best-effort unshallow; if it fails, we still proceed.
run_git(["fetch", "--unshallow", "--tags", "--prune"]) # ignore return code
# Fallback if server forbids --unshallow
run_git(["fetch", "--depth=100000", "--tags", "--prune"])
for model_dir in models_root.iterdir():
if not model_dir.is_dir():
continue
rel = f"src/transformers/models/{model_dir.name}/"
# Earliest commit that ADDED something under this directory.
# Use a stable delimiter to avoid locale/spacing issues.
proc = run_git([
"log",
"--reverse", # oldest β†’ newest
"--diff-filter=A", # additions only
"--date=short", # YYYY-MM-DD
'--format=%H|%ad', # hash|date
"--",
rel,
])
if proc.returncode != 0 or not proc.stdout.strip():
# As a fallback, look at the earliest commit touching any tracked file under the dir.
# This can catch cases where files were moved (rename) rather than added.
ls = run_git(["ls-files", rel])
files = [ln for ln in ls.stdout.splitlines() if ln.strip()]
best_date: datetime | None = None
if files:
for fp in files:
proc_file = run_git([
"log",
"--reverse",
"--diff-filter=A",
"--date=short",
"--format=%H|%ad",
"--",
fp,
])
line = proc_file.stdout.splitlines()[0].strip() if proc_file.stdout else ""
if line and "|" in line:
_, d = line.split("|", 1)
try:
dt = datetime.strptime(d.strip(), "%Y-%m-%d")
if best_date is None or dt < best_date:
best_date = dt
except ValueError:
pass
            if best_date is not None:
                creation_dates[model_dir.name] = best_date
                print(f"✅ {model_dir.name}: {best_date.strftime('%Y-%m-%d')}")
            else:
                print(f"❌ {model_dir.name}: no add commit found")
            continue
        first_line = proc.stdout.splitlines()[0].strip()  # oldest add
        if "|" in first_line:
            _, date_str = first_line.split("|", 1)
            try:
                creation_dates[model_dir.name] = datetime.strptime(date_str.strip(), "%Y-%m-%d")
                print(f"✅ {model_dir.name}: {date_str.strip()}")
            except ValueError:
                print(f"❌ {model_dir.name}: bad date format: {date_str!r}")
        else:
            print(f"❌ {model_dir.name}: unexpected log format: {first_line!r}")
return creation_dates
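# Equivalent CLI query per directory (a sketch; assumes a full, unshallowed clone):
#   git log --reverse --diff-filter=A --date=short --format='%H|%ad' -- src/transformers/models/llama/
# The first output line is the oldest commit that added files under that directory.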
def build_timeline_json(
transformers_dir: Path,
threshold: float = SIM_DEFAULT,
multimodal: bool = False,
sim_method: str = "jaccard",
) -> dict:
"""Build chronological timeline with modular connections."""
# Get the standard dependency graph for connections
graph = build_graph_json(transformers_dir, threshold, multimodal, sim_method)
# Get creation dates for chronological positioning
creation_dates = get_model_creation_dates(transformers_dir)
# Enhance nodes with chronological data
for node in graph["nodes"]:
model_name = node["id"]
if model_name in creation_dates:
creation_date = creation_dates[model_name]
node.update({
"date": creation_date.isoformat(),
"year": creation_date.year,
"timestamp": creation_date.timestamp()
})
        else:
            # Fallback for models without a git-derived date; kept > 0 so the
            # D3 layout still places them on the time axis
            node.update({
                "date": "2020-01-01T00:00:00",  # Default date
                "year": 2020,
                "timestamp": datetime(2020, 1, 1).timestamp()
            })
    # Add timeline metadata (count only models with a real git-derived date;
    # every node gets a fallback timestamp, so filtering on timestamp > 0 would count all of them)
    valid_dates = [n for n in graph["nodes"] if n["id"] in creation_dates]
if valid_dates:
min_year = min(n["year"] for n in valid_dates)
max_year = max(n["year"] for n in valid_dates)
graph["timeline_meta"] = {
"min_year": min_year,
"max_year": max_year,
"total_models": len(graph["nodes"]),
"dated_models": len(valid_dates)
}
else:
graph["timeline_meta"] = {
"min_year": 2018,
"max_year": 2024,
"total_models": len(graph["nodes"]),
"dated_models": 0
}
return graph
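# Enriched node shape (illustrative values only):
#   {"id": "llama", "cls": "base", "sz": 1.8,
#    "date": "2023-03-16T00:00:00", "year": 2023, "timestamp": 1678924800.0}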
def generate_html(graph: dict) -> str:
"""Return the full HTML string with inlined CSS/JS + graph JSON."""
js = JS.replace("__GRAPH_DATA__", json.dumps(graph, separators=(",", ":")))
return HTML.replace("__CSS__", CSS).replace("__JS__", js)
def generate_timeline_html(timeline: dict) -> str:
"""Return the full HTML string for chronological timeline visualization."""
js = TIMELINE_JS.replace("__TIMELINE_DATA__", json.dumps(timeline, separators=(",", ":")))
return TIMELINE_HTML.replace("__TIMELINE_CSS__", TIMELINE_CSS).replace("__TIMELINE_JS__", js)
# ────────────────────────────────────────────────────────────────────────────────
# 3) HTML (D3.js) boilerplate - CSS + JS templates (unchanged design)
# ────────────────────────────────────────────────────────────────────────────────
CSS = """
@import url('https://fonts.googleapis.com/css2?family=Inter:wght@400;600&display=swap');
:root{
--bg:#ffffff;
--text:#222222;
--muted:#555555;
--outline:#ffffff;
}
@media (prefers-color-scheme: dark){
:root{
--bg:#0b0d10;
--text:#e8e8e8;
--muted:#c8c8c8;
--outline:#000000;
}
}
body{ margin:0; font-family:'Inter',Arial,sans-serif; background:var(--bg); overflow:hidden; }
svg{ width:100vw; height:100vh; }
.link{ stroke:#999; stroke-opacity:.6; }
.link.cand{ stroke:#e63946; stroke-width:2.5; }
.node-label{
fill:var(--text);
pointer-events:none;
text-anchor:middle;
font-weight:600;
paint-order:stroke fill;
stroke:var(--outline);
stroke-width:3px;
}
.link-label{
fill:var(--muted);
pointer-events:none;
text-anchor:middle;
font-size:10px;
paint-order:stroke fill;
stroke:var(--bg);
stroke-width:2px;
}
.node.base image{ width:60px; height:60px; transform:translate(-30px,-30px); }
.node.derived circle{ fill:#1f77b4; }
.node.cand circle, .node.cand path{ fill:#e63946; }
#legend{
position:fixed; top:18px; left:18px;
background:rgba(255,255,255,.92);
padding:18px 28px; border-radius:10px; border:1.5px solid #bbb;
font-size:18px; box-shadow:0 2px 8px rgba(0,0,0,.08);
}
@media (prefers-color-scheme: dark){
#legend{ background:rgba(20,22,25,.92); color:#e8e8e8; border-color:#444; }
}
"""
JS = """
function updateVisibility() {
const show = document.getElementById('toggleRed').checked;
svg.selectAll('.link.cand').style('display', show ? null : 'none');
svg.selectAll('.node.cand').style('display', show ? null : 'none');
svg.selectAll('.link-label').filter(d => d.cand).style('display', show ? null : 'none');
}
document.getElementById('toggleRed').addEventListener('change', updateVisibility);
const graph = __GRAPH_DATA__;
const W = innerWidth, H = innerHeight;
const svg = d3.select('#dependency').call(d3.zoom().on('zoom', e => g.attr('transform', e.transform)));
const g = svg.append('g');
const link = g.selectAll('line')
.data(graph.links)
.join('line')
.attr('class', d => d.cand ? 'link cand' : 'link');
const linkLbl = g.selectAll('text.link-label')
.data(graph.links)
.join('text')
.attr('class', 'link-label')
.text(d => d.label);
const node = g.selectAll('g.node')
.data(graph.nodes)
.join('g')
.attr('class', d => `node ${d.cls}`)
.call(d3.drag().on('start', dragStart).on('drag', dragged).on('end', dragEnd));
const baseSel = node.filter(d => d.cls === 'base');
baseSel.append('circle').attr('r', d => 22*d.sz).attr('fill', '#ffbe0b');
node.filter(d => d.cls !== 'base').append('circle').attr('r', d => 20*d.sz);
node.append('text')
.attr('class','node-label')
.attr('dy','-2.4em')
.style('font-size', d => d.cls === 'base' ? '32px' : '28px')
.style('font-weight', d => d.cls === 'base' ? 'bold' : 'normal')
.text(d => d.id);
const sim = d3.forceSimulation(graph.nodes)
.force('link', d3.forceLink(graph.links).id(d => d.id).distance(520))
.force('charge', d3.forceManyBody().strength(-600))
.force('center', d3.forceCenter(W / 2, H / 2))
  .force('collide', d3.forceCollide(50));
sim.on('tick', () => {
link.attr('x1', d=>d.source.x).attr('y1', d=>d.source.y)
.attr('x2', d=>d.target.x).attr('y2', d=>d.target.y);
linkLbl.attr('x', d=> (d.source.x+d.target.x)/2)
.attr('y', d=> (d.source.y+d.target.y)/2);
node.attr('transform', d=>`translate(${d.x},${d.y})`);
});
function dragStart(e,d){ if(!e.active) sim.alphaTarget(.3).restart(); d.fx=d.x; d.fy=d.y; }
function dragged(e,d){ d.fx=e.x; d.fy=e.y; }
function dragEnd(e,d){ if(!e.active) sim.alphaTarget(0); d.fx=d.fy=null; }
"""
HTML = """
<!DOCTYPE html>
<html lang='en'><head><meta charset='UTF-8'>
<title>Transformers modular graph</title>
<style>__CSS__</style></head><body>
<div id='legend'>
🟡 base<br>🔵 modular<br>🔴 candidate<br>red edge = high similarity (Jaccard or embedding)<br><br>
<label><input type="checkbox" id="toggleRed" checked> Show candidate edges and nodes</label>
</div>
<svg id='dependency'></svg>
<script src='https://d3js.org/d3.v7.min.js'></script>
<script>__JS__</script></body></html>
"""
# ────────────────────────────────────────────────────────────────────────────────
# Timeline HTML Templates
# ────────────────────────────────────────────────────────────────────────────────
TIMELINE_CSS = """
@import url('https://fonts.googleapis.com/css2?family=Inter:wght@400;600&display=swap');
:root{
--bg:#ffffff;
--text:#222222;
--muted:#555555;
--outline:#ffffff;
--timeline-line:#dee2e6;
--base-color:#ffbe0b;
--derived-color:#1f77b4;
--candidate-color:#e63946;
}
@media (prefers-color-scheme: dark){
:root{
--bg:#0b0d10;
--text:#e8e8e8;
--muted:#c8c8c8;
--outline:#000000;
--timeline-line:#343a40;
}
}
body{
margin:0;
font-family:'Inter',Arial,sans-serif;
background:var(--bg);
overflow:hidden;
}
svg{ width:100vw; height:100vh; }
/* Enhanced link styles for chronological flow */
.link{
stroke:#4a90e2;
stroke-opacity:0.6;
stroke-width:1.5;
fill:none;
transition: stroke-opacity 0.3s ease;
}
.link.cand{
stroke:var(--candidate-color);
stroke-width:2.5;
stroke-opacity:0.8;
stroke-dasharray: 4,4;
}
.link:hover{
stroke-opacity:1;
stroke-width:3;
}
/* Improved node label styling */
.node-label{
fill:var(--text);
pointer-events:none;
text-anchor:middle;
font-weight:600;
font-size:13px;
paint-order:stroke fill;
stroke:var(--outline);
stroke-width:3px;
cursor:default;
}
/* Enhanced node styling with better visual hierarchy */
.node.base circle{
fill:var(--base-color);
stroke:#d4a000;
stroke-width:2;
}
.node.derived circle{
fill:var(--derived-color);
stroke:#1565c0;
stroke-width:2;
}
.node.cand circle{
fill:var(--candidate-color);
stroke:#c62828;
stroke-width:2;
}
.node circle{
transition: r 0.3s ease, stroke-width 0.3s ease;
cursor:grab;
}
.node:hover circle{
r:22;
stroke-width:3;
}
.node:active{
cursor:grabbing;
}
/* Timeline axis styling */
.timeline-axis {
stroke: var(--timeline-line);
stroke-width: 3px;
stroke-opacity: 0.8;
}
.timeline-tick {
stroke: var(--timeline-line);
stroke-width: 2px;
stroke-opacity: 0.6;
}
.timeline-label {
fill: var(--muted);
font-size: 14px;
font-weight: 600;
text-anchor: middle;
}
/* Enhanced controls panel */
#controls{
position:fixed; top:20px; left:20px;
background:rgba(255,255,255,.95);
padding:20px 26px; border-radius:12px; border:1.5px solid #e0e0e0;
font-size:14px; box-shadow:0 4px 16px rgba(0,0,0,.12);
z-index: 100;
backdrop-filter: blur(8px);
max-width: 280px;
}
@media (prefers-color-scheme: dark){
#controls{
background:rgba(20,22,25,.95);
color:#e8e8e8;
border-color:#404040;
}
}
#controls label{
display:flex;
align-items:center;
margin-top:10px;
cursor:pointer;
}
#controls input[type="checkbox"]{
margin-right:8px;
cursor:pointer;
}
"""
TIMELINE_JS = """
function updateVisibility() {
const show = document.getElementById('toggleRed').checked;
svg.selectAll('.link.cand').style('display', show ? null : 'none');
svg.selectAll('.node.cand').style('display', show ? null : 'none');
}
document.getElementById('toggleRed').addEventListener('change', updateVisibility);
const timeline = __TIMELINE_DATA__;
const W = innerWidth, H = innerHeight;
// Enhanced timeline configuration for maximum horizontal spread
const MARGIN = { top: 60, right: 200, bottom: 120, left: 200 };
const CONTENT_HEIGHT = H - MARGIN.top - MARGIN.bottom;
const VERTICAL_LANES = 4; // Number of horizontal lanes for better organization
// Create SVG with zoom behavior
const svg = d3.select('#timeline-svg');
const zoomBehavior = d3.zoom()
.scaleExtent([0.1, 8])
.on('zoom', handleZoom);
svg.call(zoomBehavior);
const g = svg.append('g');
// Time scale for chronological positioning with much wider spread
const timeExtent = d3.extent(timeline.nodes.filter(d => d.timestamp > 0), d => d.timestamp);
let timeScale;
if (timeExtent[0] && timeExtent[1]) {
// Much wider timeline for maximum horizontal spread
const timeWidth = Math.max(W * 8, 8000);
timeScale = d3.scaleTime()
.domain(timeExtent.map(t => new Date(t * 1000)))
.range([MARGIN.left, timeWidth - MARGIN.right]);
// Timeline axis at the bottom
const timelineG = g.append('g').attr('class', 'timeline');
const timelineY = H - 80;
timelineG.append('line')
.attr('class', 'timeline-axis')
.attr('x1', MARGIN.left)
.attr('y1', timelineY)
.attr('x2', timeWidth - MARGIN.right)
.attr('y2', timelineY);
// Enhanced year markers with better spacing
const years = d3.timeYear.range(new Date(timeExtent[0] * 1000), new Date(timeExtent[1] * 1000 + 365*24*60*60*1000));
timelineG.selectAll('.timeline-tick')
.data(years)
.join('line')
.attr('class', 'timeline-tick')
.attr('x1', d => timeScale(d))
.attr('y1', timelineY - 15)
.attr('x2', d => timeScale(d))
.attr('y2', timelineY + 15);
timelineG.selectAll('.timeline-label')
.data(years)
.join('text')
.attr('class', 'timeline-label')
.attr('x', d => timeScale(d))
.attr('y', timelineY + 30)
.text(d => d.getFullYear());
}
function handleZoom(event) {
const { transform } = event;
g.attr('transform', transform);
}
// Enhanced curved links for better chronological flow visualization
const link = g.selectAll('path.link')
.data(timeline.links)
.join('path')
.attr('class', d => d.cand ? 'link cand' : 'link')
.attr('fill', 'none')
.attr('stroke-width', d => d.cand ? 2.5 : 1.5);
// Nodes with improved positioning strategy
const node = g.selectAll('g.node')
.data(timeline.nodes)
.join('g')
.attr('class', d => `node ${d.cls}`)
.call(d3.drag().on('start', dragStart).on('drag', dragged).on('end', dragEnd));
const baseSel = node.filter(d => d.cls === 'base');
baseSel.append('circle').attr('r', 20).attr('fill', '#ffbe0b');
node.filter(d => d.cls !== 'base').append('circle').attr('r', 18);
node.append('text')
.attr('class', 'node-label')
.attr('dy', '-2.2em')
.text(d => d.id);
// Organize nodes by chronological lanes for better vertical distribution
timeline.nodes.forEach((d, i) => {
if (d.timestamp > 0) {
// Assign lane based on chronological order within similar timeframes
const yearNodes = timeline.nodes.filter(n =>
n.timestamp > 0 &&
Math.abs(n.timestamp - d.timestamp) < 365*24*60*60
);
d.lane = yearNodes.indexOf(d) % VERTICAL_LANES;
} else {
d.lane = i % VERTICAL_LANES;
}
});
// Enhanced force simulation for optimal horizontal chronological layout
const sim = d3.forceSimulation(timeline.nodes)
.force('link', d3.forceLink(timeline.links).id(d => d.id)
.distance(d => d.cand ? 100 : 200)
.strength(d => d.cand ? 0.1 : 0.3))
.force('charge', d3.forceManyBody().strength(-300))
  .force('collide', d3.forceCollide(45).strength(0.8));
// Very strong chronological X positioning for proper horizontal spread
if (timeScale) {
sim.force('chronological', d3.forceX(d => {
if (d.timestamp > 0) {
return timeScale(new Date(d.timestamp * 1000));
}
// Place undated models at the end
return timeScale.range()[1] + 100;
}).strength(0.95));
}
// Organized Y positioning using lanes instead of random spread
sim.force('lanes', d3.forceY(d => {
const centerY = H / 2 - 100; // Position above timeline
const laneHeight = (H - 200) / (VERTICAL_LANES + 1); // Account for timeline space
const targetY = centerY - ((H - 200) / 2) + (d.lane + 1) * laneHeight;
return targetY;
}).strength(0.7));
// Add center force to prevent rightward drift
sim.force('center', d3.forceCenter(timeScale ? (timeScale.range()[0] + timeScale.range()[1]) / 2 : W / 2, H / 2 - 100).strength(0.1));
// Custom path generator for curved links that follow chronological flow
function linkPath(d) {
const sourceX = d.source.x || 0;
const sourceY = d.source.y || 0;
const targetX = d.target.x || 0;
const targetY = d.target.y || 0;
  // Curved arc whose radius scales with node distance, for better visual flow
  const dx = targetX - sourceX;
  const dy = targetY - sourceY;
  const dr = Math.sqrt(dx * dx + dy * dy) * 0.3;
  return `M${sourceX},${sourceY}A${dr},${dr} 0 0,1 ${targetX},${targetY}`;
}
sim.on('tick', () => {
link.attr('d', linkPath);
node.attr('transform', d => `translate(${d.x},${d.y})`);
});
function dragStart(e, d) {
if (!e.active) sim.alphaTarget(.3).restart();
d.fx = d.x;
d.fy = d.y;
}
function dragged(e, d) {
d.fx = e.x;
d.fy = e.y;
}
function dragEnd(e, d) {
if (!e.active) sim.alphaTarget(0);
d.fx = d.fy = null;
}
// Initialize
updateVisibility();
// Auto-fit timeline view with better zoom for horizontal spread
setTimeout(() => {
if (timeScale && timeExtent[0] && timeExtent[1]) {
const timeWidth = timeScale.range()[1] - timeScale.range()[0];
const scale = Math.min((W * 0.9) / timeWidth, 1);
const translateX = (W - timeWidth * scale) / 2;
const translateY = 0;
svg.transition()
.duration(2000)
.call(zoomBehavior.transform,
d3.zoomIdentity.translate(translateX, translateY).scale(scale));
}
}, 1500);
"""
TIMELINE_HTML = """
<!DOCTYPE html>
<html lang='en'><head><meta charset='UTF-8'>
<title>Transformers Chronological Timeline</title>
<style>__TIMELINE_CSS__</style></head><body>
<div id='controls'>
<div style='font-weight:600; margin-bottom:8px;'>Chronological Timeline</div>
🟡 base<br>🔵 modular<br>🔴 candidate<br>
<label><input type="checkbox" id="toggleRed" checked> Show candidates</label>
<div style='margin-top:10px; font-size:11px; color:var(--muted);'>
Models positioned by creation date<br>
Scroll & zoom to explore timeline
</div>
</div>
<svg id='timeline-svg'></svg>
<script src='https://d3js.org/d3.v7.min.js'></script>
<script>__TIMELINE_JS__</script></body></html>
"""
# ────────────────────────────────────────────────────────────────────────────────
# HTML writer
# ────────────────────────────────────────────────────────────────────────────────
def write_html(graph_data: dict, path: Path):
path.write_text(generate_html(graph_data), encoding="utf-8")
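# Usage sketch (library-style, mirrors main() below):
#   graph = build_graph_json(Path("~/transformers").expanduser().resolve())
#   write_html(graph, Path("graph.html"))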
# ────────────────────────────────────────────────────────────────────────────────
# MAIN
# ────────────────────────────────────────────────────────────────────────────────
def main():
    ap = argparse.ArgumentParser(description="Visualise modular dependencies + candidates")
    ap.add_argument("transformers", help="Path to local 🤗 transformers repo root")
    ap.add_argument("--multimodal", action="store_true",
                    help=f"keep only models with >= {PIXEL_MIN_HITS} 'pixel_values' mentions")
    ap.add_argument("--sim-threshold", type=float, default=SIM_DEFAULT)
    ap.add_argument("--out", default=HTML_DEFAULT)
    ap.add_argument("--sim-method", choices=["jaccard", "embedding"], default="jaccard",
                    help="Similarity method: 'jaccard' or 'embedding'")
args = ap.parse_args()
graph = build_graph_json(
transformers_dir=Path(args.transformers).expanduser().resolve(),
threshold=args.sim_threshold,
multimodal=args.multimodal,
sim_method=args.sim_method,
)
write_html(graph, Path(args.out).expanduser())
if __name__ == "__main__":
main()