#!/usr/bin/env python """ modular_graph_and_candidates.py ================================ Create **one** rich view that combines 1. The *dependency graph* between existing **modular_*.py** implementations in πŸ€—Β Transformers (blue/🟑) **and** 2. The list of *missing* modular models (full‑red nodes) **plus** similarity edges (full‑red links) between highly‑overlapping modelling files – the output of *find_modular_candidates.py* – so you can immediately spot good refactor opportunities. ––– Usage ––– ```bash python modular_graph_and_candidates.py /path/to/transformers \ --multimodal # keep only models whose modelling code mentions # "pixel_values" β‰₯Β 3 times --sim-threshold 0.5 # Jaccard cutoff (default 0.50) --out graph.html # output HTML file name ``` Colour legend in the generated HTML: * 🟑 **base model**Β β€” has modular shards *imported* by others but no parent * πŸ”΅Β **derived modular model**Β β€” has a `modular_*.py` and inherits from β‰₯β€―1 model * πŸ”΄Β **candidate**Β β€” no `modular_*.py` yet (and/or very similar to another) * red edges = high‑Jaccard similarity links (potential to factorise) """ from __future__ import annotations import argparse import ast import json import re import subprocess import tokenize from collections import Counter, defaultdict from itertools import combinations from pathlib import Path from typing import Dict, List, Set, Tuple from sentence_transformers import SentenceTransformer, util from tqdm import tqdm import numpy as np import spaces import torch from datetime import datetime # ──────────────────────────────────────────────────────────────────────────────── # CONFIG # ─────────────────────────────────────────────────────────────────────────────── SIM_DEFAULT = 0.5 # similarity threshold PIXEL_MIN_HITS = 0 # multimodal trigger ("pixel_values") HTML_DEFAULT = "d3_modular_graph.html" # ──────────────────────────────────────────────────────────────────────────────── # 1) Helpers to analyse *modelling* files (for 
# Patterns are compiled once at import time; the helpers below run on every
# modeling file in the repository.
_TRIPLE_QUOTED_RE = re.compile(r'("""|\'\'\').*?\1', re.DOTALL)
_COMMENT_RE = re.compile(r"#.*")
_IMPORT_LINE_RE = re.compile(r"\s*(from|import)\s+")
_IDENTIFIER_RE = re.compile(r"\b[a-zA-Z_][a-zA-Z0-9_]*\b")


def _strip_source(code: str) -> str:
    """Remove doc-strings, comments and import lines to keep only the core code.

    This is a purely textual, best-effort clean-up: every triple-quoted span
    is dropped, everything from a ``#`` to the end of its line is dropped
    (including a ``#`` that sits inside a string literal), and lines starting
    with ``from`` / ``import`` are removed.
    """
    code = _TRIPLE_QUOTED_RE.sub("", code)  # doc-strings
    code = _COMMENT_RE.sub("", code)  # comments
    return "\n".join(ln for ln in code.splitlines() if not _IMPORT_LINE_RE.match(ln))


def _tokenise(code: str) -> Set[str]:
    """Return the set of identifier-like words found in *code*.

    A regex is used instead of ``tokenize`` so that malformed source still
    yields tokens.
    """
    return set(_IDENTIFIER_RE.findall(code))


def build_token_bags(models_root: Path) -> Tuple[Dict[str, List[Set[str]]], Dict[str, int]]:
    """Return token-bags of every ``modeling_*.py`` plus a pixel-value counter.

    Args:
        models_root: directory whose immediate sub-directories are model folders.

    Returns:
        ``(bags, pixel_hits)`` where ``bags[model]`` holds one identifier set
        per modeling file found (recursively) under that model folder, and
        ``pixel_hits[model]`` is the total number of occurrences of the
        substring ``"pixel_values"`` in those files.  Both are
        ``defaultdict`` instances; models without a readable modeling file
        have no entry.
    """
    bags: Dict[str, List[Set[str]]] = defaultdict(list)
    pixel_hits: Dict[str, int] = defaultdict(int)
    for mdl_dir in sorted(p for p in models_root.iterdir() if p.is_dir()):
        for py in mdl_dir.rglob("modeling_*.py"):
            try:
                # Plain ASCII codec name; the previous spelling used a
                # non-breaking hyphen and relied on codec-name normalisation.
                text = py.read_text(encoding="utf-8")
            except (OSError, UnicodeError) as e:
                # Best effort: an unreadable file is reported and skipped.
                print(f"⚠️ Skipped {py}: {e}")
                continue
            pixel_hits[mdl_dir.name] += text.count("pixel_values")
            bags[mdl_dir.name].append(_tokenise(_strip_source(text)))
    return bags, pixel_hits


def _jaccard(a: Set[str], b: Set[str]) -> float:
    """Return the Jaccard index of two sets; 0.0 when either set is empty."""
    if not a or not b:
        return 0.0
    return len(a & b) / len(a | b)


def similarity_clusters(bags: Dict[str, List[Set[str]]], thr: float = 0.1) -> Dict[Tuple[str, str], float]:
    """Return ``{(model_a, model_b): jaccard}`` for every pair scoring at least *thr*.

    Each model is represented by its largest token set; models with an empty
    list of bags are ignored.  Pair keys are ordered alphabetically.
    """
    largest = {m: max(ts, key=len) for m, ts in bags.items() if ts}
    out: Dict[Tuple[str, str], float] = {}
    for m1, m2 in combinations(sorted(largest), 2):
        s = _jaccard(largest[m1], largest[m2])
        if s >= thr:
            out[(m1, m2)] = s
    return out
@spaces.GPU
def embedding_similarity_clusters(models_root: Path, missing: List[str], thr: float = 0.1) -> Dict[Tuple[str, str], float]:
    """Return ``{(model_a, model_b): cosine}`` for pairs of *missing* models scoring at least *thr*.

    The stripped modeling source of every model is embedded with
    ``microsoft/codebert-base`` and pairs are compared by cosine similarity.

    Two ``.npz`` caches in the current working directory are used:

    * ``embeddings_cache.npz`` - final cache; when its stored names equal the
      names to process, similarities are computed from it directly and the
      embedding model is never loaded.
    * ``temp_embeddings.npz`` - written after every batch so an interrupted
      run can resume; deleted once the final cache has been written.

    Args:
        models_root: directory containing one sub-directory per model.
        missing: model names to embed.  Names containing ``mobilebert`` or
            ``lxmert`` are skipped.
        thr: minimum cosine similarity for a pair to be reported.

    Returns:
        Mapping of name pairs to similarity; empty when there is nothing to embed.
    """
    # ── 1) read sources (cheap, and needed to validate the caches) ─────────────
    texts: Dict[str, str] = {}
    for name in tqdm(missing, desc="Reading modeling files"):
        if any(skip in name.lower() for skip in ["mobilebert", "lxmert"]):
            print(f"Skipping {name} (causes GPU abort)")
            continue
        parts: List[str] = []
        for py in (models_root / name).rglob("modeling_*.py"):
            try:
                parts.append(_strip_source(py.read_text(encoding="utf-8")) + "\n")
            except (OSError, UnicodeError) as e:
                # Best effort: an unreadable file contributes nothing.
                print(f"⚠️ Skipped {py}: {e}")
        # A single space keeps the encoder input non-empty.
        texts[name] = "".join(parts).strip() or " "

    names = list(texts)
    if not names:
        # np.vstack([]) below would raise on an empty list of batches.
        return {}

    # ── 2) two-stage caching: temp (for resume) + permanent (for reuse) ────────
    temp_cache_path = Path("temp_embeddings.npz")  # for resuming computation
    final_cache_path = Path("embeddings_cache.npz")  # for permanent storage

    if final_cache_path.exists():
        try:
            # NOTE(review): allow_pickle=True executes pickled data on load;
            # only safe while the cache file is produced by this script.
            with np.load(final_cache_path, allow_pickle=True) as cached:
                cached_names = list(cached["names"])
            if names == cached_names:
                print(f"✅ Using final embeddings cache ({len(cached_names)} models)")
                return compute_similarities_from_cache(thr)
        except Exception as e:
            print(f"⚠️ Failed to load final cache: {e}")

    # ── 3) load the encoder only when something must be embedded ───────────────
    model = SentenceTransformer("microsoft/codebert-base", trust_remote_code=True)
    try:
        cfg = model[0].auto_model.config
        pos_limit = int(getattr(cfg, "n_positions", getattr(cfg, "max_position_embeddings")))
    except Exception:
        pos_limit = 1024
    seq_len = min(pos_limit, 2048)
    model.max_seq_length = seq_len
    model[0].max_seq_length = seq_len
    model[0].tokenizer.model_max_length = seq_len

    emb_dim = getattr(model, "get_sentence_embedding_dimension", lambda: 768)()
    batch_size = 4
    all_embeddings = []
    start_idx = 0
    print(f"Encoding embeddings for {len(names)} models...")

    if temp_cache_path.exists():
        try:
            with np.load(temp_cache_path, allow_pickle=True) as cached:
                cached_names = list(cached["names"])
                # Resume only when the cached names are a prefix of the current list.
                if names[:len(cached_names)] == cached_names:
                    loaded = cached["embeddings"].astype(np.float32)
                    all_embeddings.append(loaded)
                    start_idx = len(cached_names)
                    print(f"🔄 Resuming from temp cache: {start_idx}/{len(names)} models")
        except Exception as e:
            print(f"⚠️ Failed to load temp cache: {e}")

    # ── 4) encode in small batches, checkpointing after each one ───────────────
    for i in tqdm(range(start_idx, len(names), batch_size), desc="Batches", leave=False):
        batch_names = names[i:i + batch_size]
        batch_texts = [texts[name] for name in batch_names]
        try:
            print(f"Processing batch: {batch_names}")
            emb = model.encode(batch_texts, convert_to_numpy=True, show_progress_bar=False)
        except Exception as e:
            # A failed batch is replaced by zero vectors so the run can go on.
            print(f"⚠️ GPU worker error for batch {batch_names}: {type(e).__name__}: {e}")
            emb = np.zeros((len(batch_names), emb_dim), dtype=np.float32)
        all_embeddings.append(emb)

        try:
            cur = np.vstack(all_embeddings).astype(np.float32)
            np.savez(
                temp_cache_path,
                embeddings=cur,
                names=np.array(names[:i + len(batch_names)], dtype=object),
            )
        except Exception as e:
            print(f"⚠️ Failed to write temp cache: {e}")

        if (i - start_idx) % (3 * batch_size) == 0 and torch.cuda.is_available():
            torch.cuda.empty_cache()
            torch.cuda.synchronize()
            print(f"🧹 Cleared GPU cache after batch {(i - start_idx)//batch_size + 1}")

    # ── 5) cosine similarity of every pair ─────────────────────────────────────
    embeddings = np.vstack(all_embeddings).astype(np.float32)
    norms = np.linalg.norm(embeddings, axis=1, keepdims=True) + 1e-12  # avoid /0 on zero rows
    embeddings = embeddings / norms

    print("Computing pairwise similarities...")
    sims_mat = embeddings @ embeddings.T
    matrix_size = embeddings.shape[0]
    processed_names = names[:matrix_size]
    rows, cols = np.triu_indices(matrix_size, k=1)  # i < j, row-major order
    # float64 so the comparison with `thr` matches a plain Python float compare.
    scores = sims_mat[rows, cols].astype(np.float64)
    keep = scores >= thr
    out = {
        (processed_names[i], processed_names[j]): s
        for i, j, s in zip(rows[keep].tolist(), cols[keep].tolist(), scores[keep].tolist())
    }

    # ── 6) promote to the final cache and drop the checkpoint ──────────────────
    try:
        np.savez(final_cache_path, embeddings=embeddings, names=np.array(names, dtype=object))
        print(f"💾 Final embeddings saved to {final_cache_path}")
        if temp_cache_path.exists():
            temp_cache_path.unlink()
            print(f"🧹 Cleaned up temp cache")
    except Exception as e:
        print(f"⚠️ Failed to save final cache: {e}")

    return out
def compute_similarities_from_cache(threshold: float) -> Dict[Tuple[str, str], float]:
    """Compute similarities from cached embeddings without reprocessing.

    Reads ``embeddings_cache.npz`` from the current working directory,
    L2-normalises the stored embeddings and returns the cosine similarity of
    every pair of names (``i < j`` in stored order) scoring at least *threshold*.

    Returns:
        ``{(name_i, name_j): similarity}``; an empty dict when the cache file
        is missing or cannot be processed (the error is printed, not raised).
    """
    embeddings_path = Path("embeddings_cache.npz")
    if not embeddings_path.exists():
        return {}

    try:
        # NOTE(review): allow_pickle=True executes pickled data on load; only
        # safe while the cache file is produced by this script.
        # The context manager closes the underlying file handle.
        with np.load(embeddings_path, allow_pickle=True) as cached:
            embeddings = cached["embeddings"].astype(np.float32)
            names = list(cached["names"])

        # Normalize embeddings (epsilon guards all-zero rows)
        norms = np.linalg.norm(embeddings, axis=1, keepdims=True) + 1e-12
        embeddings = embeddings / norms

        # Compute similarities for the strict upper triangle, row-major order
        sims_mat = embeddings @ embeddings.T
        rows, cols = np.triu_indices(len(names), k=1)
        # float64 so the comparison matches a plain Python float compare.
        scores = sims_mat[rows, cols].astype(np.float64)
        keep = scores >= threshold
        out = {
            (names[i], names[j]): s
            for i, j, s in zip(rows[keep].tolist(), cols[keep].tolist(), scores[keep].tolist())
        }

        print(f"⚡ Computed {len(out)} similarities from cache (threshold: {threshold})")
        return out

    except Exception as e:
        print(f"⚠️ Failed to compute from cache: {e}")
        return {}


def filter_similarities_by_threshold(similarities: Dict[Tuple[str, str], float], threshold: float) -> Dict[Tuple[str, str], float]:
    """Return the entries of *similarities* whose score is at least *threshold*."""
    return {pair: score for pair, score in similarities.items() if score >= threshold}


def filter_graph_by_threshold(graph_data: dict, threshold: float) -> dict:
    """Return a copy of *graph_data* without weak candidate (similarity) links.

    Links whose ``"cand"`` flag is falsy are always kept.  A candidate link is
    kept when its ``"label"`` (a percentage such as ``"83.2%"``) is at least
    *threshold* (a 0-1 fraction).  Candidate links whose label is missing or
    cannot be parsed are kept as well.  Nodes and any extra top-level keys are
    passed through unchanged.
    """
    filtered_links = []
    for link in graph_data["links"]:
        if not link.get("cand", False):
            filtered_links.append(link)
            continue
        try:
            score = float(link["label"].rstrip('%')) / 100.0
        except (KeyError, ValueError, AttributeError):
            # Unknown score: keep the link rather than hide it.
            filtered_links.append(link)
            continue
        if score >= threshold:
            filtered_links.append(link)

    return {
        "nodes": graph_data["nodes"],
        "links": filtered_links,
        **{k: v for k, v in graph_data.items() if k not in ["nodes", "links"]}
    }


# ────────────────────────────────────────────────────────────────────────────────
# 2) Scan *modular_*.py* files to build an import-dependency graph
#    – only **modeling_*** imports are considered (skip configuration / processing)
# ────────────────────────────────────────────────────────────────────────────────

def modular_files(models_root: Path) -> List[Path]:
    """Return every ``modular_*.py`` file found recursively under *models_root*."""
    return [p for p in models_root.rglob("modular_*.py") if p.suffix == ".py"]
def dependency_graph(modular_files: List[Path], models_root: Path) -> Dict[str, List[Dict[str,str]]]:
    """Return ``{derived_model: [{source, imported_class}, ...]}``.

    Every ``from ... import ...`` statement of each modular file is inspected.
    Only ``modeling_*`` imports are kept; anything coming from
    configuration / processing / image-processing modules or from
    ``modeling_attn_mask_utils`` is ignored, so the graph focuses strictly on
    modelling code.  The source model is the first dotted component of the
    module path that is not ``models`` or ``transformers``; edges to a source
    that is the derived model itself, or that is not a directory directly
    under *models_root*, are dropped.

    Args:
        modular_files: paths of the ``modular_*.py`` files to scan; the
            derived model name is the file's parent directory name.
        models_root: directory containing one sub-directory per model.

    Returns:
        A plain dict with one entry per derived model that has at least one
        kept import, holding one item per imported name.  Files that cannot
        be read or parsed are reported and skipped.
    """
    model_names = {p.name for p in models_root.iterdir() if p.is_dir()}
    deps: Dict[str, List[Dict[str,str]]] = defaultdict(list)
    for fp in modular_files:
        derived = fp.parent.name
        try:
            # Plain ASCII codec name; the previous spelling used a
            # non-breaking hyphen and relied on codec-name normalisation.
            tree = ast.parse(fp.read_text(encoding="utf-8"), filename=str(fp))
        except (OSError, SyntaxError, ValueError) as e:
            # ValueError also covers UnicodeDecodeError.
            print(f"⚠️ AST parse failed for {fp}: {e}")
            continue
        for node in ast.walk(tree):
            if not isinstance(node, ast.ImportFrom) or not node.module:
                continue
            mod = node.module
            # keep only *modeling_* imports, drop anything else
            if ("modeling_" not in mod or
                "configuration_" in mod or
                "processing_" in mod or
                "image_processing" in mod or
                "modeling_attn_mask_utils" in mod):
                continue
            parts = re.split(r"[./]", mod)
            src = next((p for p in parts if p not in {"", "models", "transformers"}), "")
            if not src or src == derived or src not in model_names:
                continue
            for alias in node.names:
                deps[derived].append({"source": src, "imported_class": alias.name})
    return dict(deps)


# modular_graph_and_candidates.py (top-level)

def get_missing_models(models_root: Path, multimodal: bool = False) -> Tuple[List[str], Dict[str, List[Set[str]]], Dict[str, int]]:
    """Get list of models missing modular implementations.

    Args:
        models_root: directory containing one sub-directory per model.
        multimodal: when true, keep only models whose ``pixel_values`` hit
            count is at least the module-level ``PIXEL_MIN_HITS``.

    Returns:
        ``(missing, bags, pix_hits)``: the names of models that have modeling
        files but no ``modular_*.py``, followed by the two mappings returned
        by ``build_token_bags`` (for all models, not only the missing ones).
    """
    bags, pix_hits = build_token_bags(models_root)
    mod_files = modular_files(models_root)
    models_with_modular = {p.parent.name for p in mod_files}
    missing = [m for m in bags if m not in models_with_modular]

    if multimodal:
        missing = [m for m in missing if pix_hits[m] >= PIXEL_MIN_HITS]

    return missing, bags, pix_hits
def compute_similarities(models_root: Path, missing: List[str], bags: Dict[str, List[Set[str]]],
                        threshold: float, sim_method: str) -> Dict[Tuple[str, str], float]:
    """Return pairwise similarities between *missing* models.

    Similarities are always computed with a fixed floor of 0.1; *threshold*
    is accepted for interface compatibility but not used here (callers filter
    the resulting graph afterwards).

    Args:
        models_root: directory containing one sub-directory per model.
        missing: model names to compare.
        bags: token bags per model, used by the ``"jaccard"`` method.
        threshold: unused, see above.
        sim_method: ``"jaccard"`` for token-set overlap; any other value
            selects embedding similarity, served from
            ``embeddings_cache.npz`` when that file yields a non-empty result.
    """
    min_threshold = 0.1
    if sim_method == "jaccard":
        return similarity_clusters({m: bags[m] for m in missing}, min_threshold)

    embeddings_path = Path("embeddings_cache.npz")
    if embeddings_path.exists():
        cached_sims = compute_similarities_from_cache(min_threshold)
        if cached_sims:
            return cached_sims

    return embedding_similarity_clusters(models_root, missing, min_threshold)


def _assemble_graph(deps: Dict[str, List[Dict[str, str]]], sims: Dict[Tuple[str, str], float],
                    missing: List[str], *, scale_by_degree: bool) -> dict:
    """Build the ``{"nodes": [...], "links": [...]}`` structure shared by both build paths.

    One non-candidate link is emitted per imported name (labelled with the
    number of names imported from that source), followed by one candidate
    link per similarity pair.  A node is ``"cand"`` when it is in *missing*
    and takes part in no dependency link, ``"base"`` when it is only ever a
    dependency source, and ``"derived"`` otherwise.  With *scale_by_degree*
    the node size is ``1 + 2 * degree / max_degree``; without it every node
    has size 1.
    """
    nodes: Set[str] = set(missing)
    links: List[dict] = []

    for drv, lst in deps.items():
        # Count once per derived model instead of re-scanning `lst` per item.
        per_source = Counter(d["source"] for d in lst)
        for d in lst:
            links.append({
                "source": d["source"],
                "target": drv,
                "label": f"{per_source[d['source']]} imports",
                "cand": False
            })
            nodes.update({d["source"], drv})

    for (a, b), s in sims.items():
        links.append({"source": a, "target": b, "label": f"{s*100:.1f}%", "cand": True})
        nodes.update({a, b})

    deg = Counter()
    for lk in links:
        deg[lk["source"]] += 1
        deg[lk["target"]] += 1
    max_deg = max(deg.values() or [1])

    targets = {lk["target"] for lk in links if not lk["cand"]}
    sources = {lk["source"] for lk in links if not lk["cand"]}
    missing_only = {m for m in missing if m not in sources and m not in targets}

    nodelist = []
    for n in sorted(nodes):
        if n in missing_only:
            cls = "cand"
        elif n in sources and n not in targets:
            cls = "base"
        else:
            cls = "derived"
        sz = 1 + 2*(deg[n]/max_deg) if scale_by_degree else 1
        nodelist.append({"id": n, "cls": cls, "sz": sz})

    return {"nodes": nodelist, "links": links}


def build_graph_json(
    transformers_dir: Path,
    threshold: float = SIM_DEFAULT,
    multimodal: bool = False,
    sim_method: str = "jaccard",
) -> dict:
    """Return the {nodes, links} dict that D3 needs.

    Args:
        transformers_dir: root of a Transformers checkout; models are read
            from ``src/transformers/models`` below it.
        threshold: candidate links scoring below this fraction are removed
            when it is greater than 0.1 (the floor used when computing them).
        multimodal: forwarded to ``get_missing_models``; not applied on the
            cache-only path, which takes its model list from the cache file.
        sim_method: ``"jaccard"`` or ``"embedding"``.

    With ``sim_method == "embedding"`` and an existing
    ``embeddings_cache.npz`` the graph is built from the cached embeddings
    (all nodes get size 1).  If that yields nothing or fails, the full build
    runs instead and node sizes are scaled by degree.
    """
    models_root = transformers_dir / "src/transformers/models"

    # Check if we can use cached embeddings only
    embeddings_cache = Path("embeddings_cache.npz")
    print(f"🔍 Cache file exists: {embeddings_cache.exists()}, sim_method: {sim_method}")

    if sim_method == "embedding" and embeddings_cache.exists():
        try:
            cached_sims = compute_similarities_from_cache(0.1)
            print(f"🔍 Got {len(cached_sims)} cached similarities")

            if cached_sims:
                # NOTE(review): allow_pickle=True executes pickled data on
                # load; only safe while the cache is produced by this script.
                with np.load(embeddings_cache, allow_pickle=True) as cached_data:
                    missing = list(cached_data["names"])

                # Modular dependencies still come from the repository.
                deps = dependency_graph(modular_files(models_root), models_root)

                graph = _assemble_graph(deps, cached_sims, missing, scale_by_degree=False)
                print(f"⚡ Built graph from cache: {len(graph['nodes'])} nodes, {len(graph['links'])} links")

                if threshold > 0.1:
                    graph = filter_graph_by_threshold(graph, threshold)
                return graph
        except Exception as e:
            print(f"⚠️ Cache-only build failed: {e}, falling back to full build")

    # Full build with repository access
    missing, bags, pix_hits = get_missing_models(models_root, multimodal)
    deps = dependency_graph(modular_files(models_root), models_root)
    sims = compute_similarities(models_root, missing, bags, threshold, sim_method)

    graph = _assemble_graph(deps, sims, missing, scale_by_degree=True)

    if threshold > 0.1:
        graph = filter_graph_by_threshold(graph, threshold)

    return graph
# ────────────────────────────────────────────────────────────────────────────────
# Timeline functions for chronological visualization
# ────────────────────────────────────────────────────────────────────────────────

def get_model_creation_dates(transformers_dir: Path) -> Dict[str, datetime]:
    """Get creation dates for all model directories by finding the earliest add of the directory path.

    For each directory under ``src/transformers/models`` the oldest git
    commit that *added* a file below it is looked up.  When that yields
    nothing, each tracked file of the directory is queried individually and
    the earliest date wins.  Progress is printed per model.

    Args:
        transformers_dir: root of the git checkout; git commands run there.

    Returns:
        ``{model_name: datetime}`` (midnight, naive) for every model whose
        date could be determined.  Empty when the models directory is
        missing.  A git command that cannot be started or times out is
        treated like a failed command instead of aborting the scan.
    """
    models_root = transformers_dir / "src/transformers/models"
    creation_dates: Dict[str, datetime] = {}

    if not models_root.exists():
        return creation_dates

    def run_git(args: list[str]) -> subprocess.CompletedProcess:
        cmd = ["git"] + args
        try:
            return subprocess.run(
                cmd,
                cwd=transformers_dir,
                capture_output=True,
                text=True,
                timeout=120,
            )
        except (OSError, subprocess.TimeoutExpired) as e:
            # Missing git binary or a hung command: report as a failure so
            # the remaining models are still processed.
            print(f"⚠️ git {args[0]} failed: {e}")
            return subprocess.CompletedProcess(cmd, returncode=1, stdout="", stderr=str(e))

    # Oldest -> newest, additions only, one "hash|YYYY-MM-DD" line per commit.
    log_args = [
        "log",
        "--reverse",
        "--diff-filter=A",
        "--date=short",
        "--format=%H|%ad",
        "--",
    ]

    # Ensure full history; shallow clones make every path look newly added "today".
    shallow = run_git(["rev-parse", "--is-shallow-repository"])
    if shallow.returncode == 0 and shallow.stdout.strip() == "true":
        # Best effort: if both fetches fail we still proceed.
        unshallow = run_git(["fetch", "--unshallow", "--tags", "--prune"])
        if unshallow.returncode != 0:
            # Fallback if server forbids --unshallow
            run_git(["fetch", "--depth=100000", "--tags", "--prune"])

    for model_dir in models_root.iterdir():
        if not model_dir.is_dir():
            continue

        rel = f"src/transformers/models/{model_dir.name}/"
        proc = run_git(log_args + [rel])

        if proc.returncode != 0 or not proc.stdout.strip():
            # Fallback: earliest add commit of any tracked file under the
            # directory.  This can catch files that were moved (renamed)
            # rather than added.
            ls = run_git(["ls-files", rel])
            files = [ln for ln in ls.stdout.splitlines() if ln.strip()]
            best_date: datetime | None = None
            for fp in files:
                proc_file = run_git(log_args + [fp])
                line = proc_file.stdout.splitlines()[0].strip() if proc_file.stdout else ""
                if not line or "|" not in line:
                    continue
                _, d = line.split("|", 1)
                try:
                    dt = datetime.strptime(d.strip(), "%Y-%m-%d")
                except ValueError:
                    continue  # unparseable date: ignore this file
                if best_date is None or dt < best_date:
                    best_date = dt
            if best_date is not None:
                creation_dates[model_dir.name] = best_date
                print(f"✅ {model_dir.name}: {best_date.strftime('%Y-%m-%d')}")
            else:
                print(f"❌ {model_dir.name}: no add commit found")
            continue

        first_line = proc.stdout.splitlines()[0].strip()  # oldest add
        if "|" in first_line:
            _, date_str = first_line.split("|", 1)
            try:
                creation_dates[model_dir.name] = datetime.strptime(date_str.strip(), "%Y-%m-%d")
                print(f"✅ {model_dir.name}: {date_str.strip()}")
            except ValueError:
                print(f"❌ {model_dir.name}: bad date format: {date_str!r}")
        else:
            print(f"❌ {model_dir.name}: unexpected log format: {first_line!r}")

    return creation_dates
def build_timeline_json(
    transformers_dir: Path,
    threshold: float = SIM_DEFAULT,
    multimodal: bool = False,
    sim_method: str = "jaccard",
) -> dict:
    """Build chronological timeline with modular connections.

    The graph from ``build_graph_json`` is extended in place: every node
    gains ``"date"`` (ISO string), ``"year"`` and ``"timestamp"`` taken from
    ``get_model_creation_dates``.  Nodes without a known date get the
    placeholder 2020-01-01.  A ``"timeline_meta"`` entry is added with the
    year range of the plotted nodes, the total node count and the number of
    nodes that have a real (non-placeholder) date.

    Args:
        transformers_dir: root of the Transformers git checkout.
        threshold, multimodal, sim_method: forwarded to ``build_graph_json``.
    """
    # Get the standard dependency graph for connections
    graph = build_graph_json(transformers_dir, threshold, multimodal, sim_method)

    # Get creation dates for chronological positioning
    creation_dates = get_model_creation_dates(transformers_dir)

    placeholder = datetime(2020, 1, 1)  # fallback for models without date info
    dated_models = 0
    for node in graph["nodes"]:
        when = creation_dates.get(node["id"])
        if when is None:
            when = placeholder
        else:
            dated_models += 1
        node.update({
            "date": when.isoformat(),
            "year": when.year,
            "timestamp": when.timestamp()
        })

    # Year range covers every plotted node, placeholder ones included.
    plotted = [n for n in graph["nodes"] if n["timestamp"] > 0]
    if plotted:
        graph["timeline_meta"] = {
            "min_year": min(n["year"] for n in plotted),
            "max_year": max(n["year"] for n in plotted),
            "total_models": len(graph["nodes"]),
            # Previously every node was counted here, because the
            # placeholder timestamp is positive too.
            "dated_models": dated_models
        }
    else:
        graph["timeline_meta"] = {
            "min_year": 2018,
            "max_year": 2024,
            "total_models": len(graph["nodes"]),
            "dated_models": 0
        }

    return graph


def generate_html(graph: dict) -> str:
    """Return the full HTML string with inlined CSS/JS + graph JSON.

    The compact JSON form of *graph* replaces ``__GRAPH_DATA__`` in the ``JS``
    template; the result and the ``CSS`` template are then substituted into
    the ``__JS__`` and ``__CSS__`` placeholders of the ``HTML`` template.
    """
    js = JS.replace("__GRAPH_DATA__", json.dumps(graph, separators=(",", ":")))
    return HTML.replace("__CSS__", CSS).replace("__JS__", js)


def generate_timeline_html(timeline: dict) -> str:
    """Return the full HTML string for chronological timeline visualization.

    Same substitution scheme as ``generate_html``, using the ``TIMELINE_*``
    templates and the ``__TIMELINE_DATA__`` / ``__TIMELINE_CSS__`` /
    ``__TIMELINE_JS__`` placeholders.
    """
    js = TIMELINE_JS.replace("__TIMELINE_DATA__", json.dumps(timeline, separators=(",", ":")))
    return TIMELINE_HTML.replace("__TIMELINE_CSS__", TIMELINE_CSS).replace("__TIMELINE_JS__", js)
.link-label{ fill:var(--muted); pointer-events:none; text-anchor:middle; font-size:12px; paint-order:stroke fill; stroke:var(--bg); stroke-width:2px; } .node.base image{ width:60px; height:60px; transform:translate(-30px,-30px); } .node.derived circle{ fill:#1f77b4; } .node.cand circle, .node.cand path{ fill:#e63946; } #legend{ position:fixed; top:18px; left:18px; background:rgba(255,255,255,.92); padding:18px 28px; border-radius:10px; border:1.5px solid #bbb; font-size:22px; box-shadow:0 2px 8px rgba(0,0,0,.08); } @media (prefers-color-scheme: dark){ #legend{ background:rgba(20,22,25,.92); color:#e8e8e8; border-color:#444; } } """ JS = """ function updateVisibility() { const show = document.getElementById('toggleRed').checked; svg.selectAll('.link.cand').style('display', show ? null : 'none'); svg.selectAll('.node.cand').style('display', show ? null : 'none'); svg.selectAll('.link-label').filter(d => d.cand).style('display', show ? null : 'none'); } document.getElementById('toggleRed').addEventListener('change', updateVisibility); const graph = __GRAPH_DATA__; const W = innerWidth, H = innerHeight; const svg = d3.select('#dependency').call(d3.zoom().on('zoom', e => g.attr('transform', e.transform))); const g = svg.append('g'); const link = g.selectAll('line') .data(graph.links) .join('line') .attr('class', d => d.cand ? 'link cand' : 'link'); const linkLbl = g.selectAll('text.link-label') .data(graph.links) .join('text') .attr('class', 'link-label') .text(d => d.label); const node = g.selectAll('g.node') .data(graph.nodes) .join('g') .attr('class', d => `node ${d.cls}`) .call(d3.drag().on('start', dragStart).on('drag', dragged).on('end', dragEnd)); const baseSel = node.filter(d => d.cls === 'base'); baseSel.append('circle').attr('r', d => 22*d.sz).attr('fill', '#ffbe0b'); node.filter(d => d.cls !== 'base').append('circle').attr('r', d => 20*d.sz); node.append('text') .attr('class','node-label') .attr('dy','-2.4em') .style('font-size', d => d.cls === 'base' ? 
'110px' : '70px') .style('font-weight', d => d.cls === 'base' ? 'bold' : 'normal') .text(d => d.id); const sim = d3.forceSimulation(graph.nodes) .force('link', d3.forceLink(graph.links).id(d => d.id).distance(520)) .force('charge', d3.forceManyBody().strength(-600)) .force('center', d3.forceCenter(W / 2, H / 2)) .force('collide', d3.forceCollide(d => 50)); sim.on('tick', () => { link.attr('x1', d=>d.source.x).attr('y1', d=>d.source.y) .attr('x2', d=>d.target.x).attr('y2', d=>d.target.y); linkLbl.attr('x', d=> (d.source.x+d.target.x)/2) .attr('y', d=> (d.source.y+d.target.y)/2); node.attr('transform', d=>`translate(${d.x},${d.y})`); }); function dragStart(e,d){ if(!e.active) sim.alphaTarget(.3).restart(); d.fx=d.x; d.fy=d.y; } function dragged(e,d){ d.fx=e.x; d.fy=e.y; } function dragEnd(e,d){ if(!e.active) sim.alphaTarget(0); d.fx=d.fy=null; } """ HTML = """ Transformers modular graph
🟑 base
πŸ”΅ modular
πŸ”΄ candidate
red edgeΒ = high embedding similarity

""" # ──────────────────────────────────────────────────────────────────────────────── # Timeline HTML Templates # ──────────────────────────────────────────────────────────────────────────────── TIMELINE_CSS = """ @import url('https://fonts.googleapis.com/css2?family=Inter:wght@400;600&display=swap'); :root{ --bg:#ffffff; --text:#222222; --muted:#555555; --outline:#ffffff; --timeline-line:#dee2e6; --base-color:#ffbe0b; --derived-color:#1f77b4; --candidate-color:#e63946; } @media (prefers-color-scheme: dark){ :root{ --bg:#0b0d10; --text:#e8e8e8; --muted:#c8c8c8; --outline:#000000; --timeline-line:#343a40; } } body{ margin:0; font-family:'Inter',Arial,sans-serif; background:var(--bg); overflow:hidden; } svg{ width:100vw; height:100vh; } /* Enhanced link styles for chronological flow */ .link{ stroke:#4a90e2; stroke-opacity:0.6; stroke-width:1.5; fill:none; transition: stroke-opacity 0.3s ease; } .link.cand{ stroke:var(--candidate-color); stroke-width:2.5; stroke-opacity:0.8; stroke-dasharray: 4,4; } .link:hover{ stroke-opacity:1; stroke-width:3; } /* Improved node label styling */ .node-label{ fill:var(--text); pointer-events:none; text-anchor:middle; font-weight:600; font-size:50px; paint-order:stroke fill; stroke:var(--outline); stroke-width:3px; cursor:default; } /* Enhanced node styling with better visual hierarchy */ .node.base circle{ fill:var(--base-color); stroke:#d4a000; stroke-width:2; } .node.derived circle{ fill:var(--derived-color); stroke:#1565c0; stroke-width:2; } .node.cand circle{ fill:var(--candidate-color); stroke:#c62828; stroke-width:2; } .node circle{ transition: r 0.3s ease, stroke-width 0.3s ease; cursor:grab; } .node:hover circle{ r:22; stroke-width:3; } .node:active{ cursor:grabbing; } /* Timeline axis styling */ .timeline-axis { stroke: var(--timeline-line); stroke-width: 3px; stroke-opacity: 0.8; } .timeline-tick { stroke: var(--timeline-line); stroke-width: 2px; stroke-opacity: 0.6; } .timeline-month-tick { stroke: var(--timeline-line); 
stroke-width: 1px; stroke-opacity: 0.4; } .timeline-label { fill: var(--muted); font-size: 40px; font-weight: 600; text-anchor: middle; } .timeline-month-label { fill: var(--muted); font-size: 35px; font-weight: 400; text-anchor: middle; opacity: 0.7; } .modular-milestone { stroke: #ff6b35; stroke-width: 3px; stroke-opacity: 0.8; stroke-dasharray: 5,5; } .modular-milestone-label { fill: #ff6b35; font-size: 35px; font-weight: 600; text-anchor: middle; } /* Enhanced controls panel */ #controls{ position:fixed; top:20px; left:20px; background:rgba(255,255,255,.95); padding:20px 26px; border-radius:12px; border:1.5px solid #e0e0e0; font-size:24px; box-shadow:0 4px 16px rgba(0,0,0,.12); z-index: 100; backdrop-filter: blur(8px); max-width: 280px; } @media (prefers-color-scheme: dark){ #controls{ background:rgba(20,22,25,.95); color:#e8e8e8; border-color:#404040; } } #controls label{ display:flex; align-items:center; margin-top:10px; cursor:pointer; } #controls input[type="checkbox"]{ margin-right:8px; cursor:pointer; } """ TIMELINE_JS = """ function updateVisibility() { const show = document.getElementById('toggleRed').checked; svg.selectAll('.link.cand').style('display', show ? null : 'none'); svg.selectAll('.node.cand').style('display', show ? 
null : 'none'); } document.getElementById('toggleRed').addEventListener('change', updateVisibility); const timeline = __TIMELINE_DATA__; const W = innerWidth, H = innerHeight; // Create SVG with zoom behavior const svg = d3.select('#timeline-svg'); const g = svg.append('g'); // Enhanced timeline configuration for maximum horizontal spread const MARGIN = { top: 60, right: 200, bottom: 120, left: 200 }; const CONTENT_HEIGHT = H - MARGIN.top - MARGIN.bottom; const VERTICAL_LANES = 4; // Number of horizontal lanes for better organization const zoomBehavior = d3.zoom() .scaleExtent([0.1, 8]) .on('zoom', handleZoom); svg.call(zoomBehavior); svg.on("click", function(event) { if (event.target.tagName === "svg") { node.select("circle").style("opacity", 1); link.style("opacity", 1); g.selectAll(".node-label").style("opacity", 1); } }); // Time scale for chronological positioning with much wider spread const timeExtent = d3.extent(timeline.nodes.filter(d => d.timestamp > 0), d => d.timestamp); let timeScale; if (timeExtent[0] && timeExtent[1]) { // Much wider timeline for maximum horizontal spread const timeWidth = Math.max(W * 8, 8000); timeScale = d3.scaleTime() .domain(timeExtent.map(t => new Date(t * 1000))) .range([MARGIN.left, timeWidth - MARGIN.right]); // Timeline axis at the bottom const timelineG = g.append('g').attr('class', 'timeline'); const timelineY = H - 80; timelineG.append('line') .attr('class', 'timeline-axis') .attr('x1', MARGIN.left) .attr('y1', timelineY) .attr('x2', timeWidth - MARGIN.right) .attr('y2', timelineY); // Enhanced year markers with better spacing const years = d3.timeYear.range(new Date(timeExtent[0] * 1000), new Date(timeExtent[1] * 1000 + 365*24*60*60*1000)); const months = d3.timeMonth.range(new Date(timeExtent[0] * 1000), new Date(timeExtent[1] * 1000 + 365*24*60*60*1000)); timelineG.selectAll('.timeline-tick') .data(years) .join('line') .attr('class', 'timeline-tick') .attr('x1', d => timeScale(d)) .attr('y1', timelineY - 15) 
.attr('x2', d => timeScale(d)) .attr('y2', timelineY + 15); timelineG.selectAll('.timeline-month-tick') .data(months) .join('line') .attr('class', 'timeline-month-tick') .attr('x1', d => timeScale(d)) .attr('y1', timelineY - 8) .attr('x2', d => timeScale(d)) .attr('y2', timelineY + 8); timelineG.selectAll('.timeline-label') .data(years) .join('text') .attr('class', 'timeline-label') .attr('x', d => timeScale(d)) .attr('y', timelineY + 30) .text(d => d.getFullYear()); timelineG.selectAll('.timeline-month-label') .data(months.filter((d, i) => i % 3 === 0)) .join('text') .attr('class', 'timeline-month-label') .attr('x', d => timeScale(d)) .attr('y', timelineY + 45) .text(d => d.toLocaleDateString('en', { month: 'short' })); // Modular logic milestone marker - May 31, 2024 const modularDate = new Date(2024, 4, 31); timelineG.append('line') .attr('class', 'modular-milestone') .attr('x1', timeScale(modularDate)) .attr('y1', MARGIN.top) .attr('x2', timeScale(modularDate)) .attr('y2', H - MARGIN.bottom); timelineG.append('text') .attr('class', 'modular-milestone-label') .attr('x', timeScale(modularDate)) .attr('y', MARGIN.top - 10) .attr('text-anchor', 'middle') .text('Modular Logic Added'); } function handleZoom(event) { const { transform } = event; g.attr('transform', transform); } // Enhanced curved links for better chronological flow visualization const link = g.selectAll('path.link') .data(timeline.links) .join('path') .attr('class', d => d.cand ? 'link cand' : 'link') .attr('fill', 'none') .attr('stroke-width', d => d.cand ? 2.5 : 1.5); const linkedByIndex = {}; timeline.links.forEach(d => { const s = typeof d.source === 'object' ? d.source.id : d.source; const t = typeof d.target === 'object' ? 
d.target.id : d.target; linkedByIndex[`${s},${t}`] = true; linkedByIndex[`${t},${s}`] = true; }); function isConnected(a, b) { return linkedByIndex[`${a.id},${b.id}`] || a.id === b.id; } function isConnected(a, b) { return linkedByIndex[`${a.id},${b.id}`] || a.id === b.id; } // Nodes with improved positioning strategy const node = g.selectAll('g.node') .data(timeline.nodes) .join('g') .attr('class', d => `node ${d.cls}`) .call(d3.drag().on('start', dragStart).on('drag', dragged).on('end', dragEnd)); node.on("click", function(event, d) { event.stopPropagation(); node.select("circle").style("opacity", o => isConnected(d, o) ? 1 : 0.1); g.selectAll(".node-label").style("opacity", o => isConnected(d, o) ? 1 : 0.1); link.style("opacity", o => (o.source.id === d.id || o.target.id === d.id) ? 1 : 0.1); }); const baseSel = node.filter(d => d.cls === 'base'); baseSel.append('circle').attr('r', 20).attr('fill', '#ffbe0b'); node.filter(d => d.cls !== 'base').append('circle').attr('r', 18); node.append('text') .attr('class', 'node-label') .attr('dy', '-2.2em') .text(d => d.id); // Organize nodes by chronological lanes for better vertical distribution timeline.nodes.forEach((d, i) => { if (d.timestamp > 0) { // Assign lane based on chronological order within similar timeframes const yearNodes = timeline.nodes.filter(n => n.timestamp > 0 && Math.abs(n.timestamp - d.timestamp) < 365*24*60*60 ); d.lane = yearNodes.indexOf(d) % VERTICAL_LANES; } else { d.lane = i % VERTICAL_LANES; } }); // Enhanced force simulation for optimal horizontal chronological layout const sim = d3.forceSimulation(timeline.nodes) .force('link', d3.forceLink(timeline.links).id(d => d.id) .distance(d => d.cand ? 200 : 300) .strength(d => d.cand ? 
0.1 : 0.3)) .force('charge', d3.forceManyBody().strength(-800)) .force('collide', d3.forceCollide(d => 70).strength(1)) // Very strong chronological X positioning for proper horizontal spread if (timeScale) { sim.force('chronological', d3.forceX(d => { if (d.timestamp > 0) { return timeScale(new Date(d.timestamp * 1000)); } // Place undated models at the end return timeScale.range()[1] + 100; }).strength(0.75)); } // Organized Y positioning using lanes instead of random spread sim.force('lanes', d3.forceY(d => { const centerY = H / 2 - 100; // Position above timeline const laneHeight = (H - 200) / (VERTICAL_LANES + 1); // Account for timeline space const targetY = centerY - ((H - 200) / 2) + (d.lane + 1) * laneHeight; return targetY; }).strength(0.7)); // Add center force to prevent rightward drift sim.force('center', d3.forceCenter(timeScale ? (timeScale.range()[0] + timeScale.range()[1]) / 2 : W / 2, H / 2 - 100).strength(0.1)); // Custom path generator for curved links that follow chronological flow function linkPath(d) { const sourceX = d.source.x || 0; const sourceY = d.source.y || 0; const targetX = d.target.x || 0; const targetY = d.target.y || 0; // Create curved paths for better visual flow const dx = targetX - sourceX; const dy = targetY - sourceY; const dr = Math.sqrt(dx * dx + dy * dy) * 0.3; // Curve direction based on chronological order const curve = dx > 0 ? dr : -dr; return `M${sourceX},${sourceY}A${dr},${dr} 0 0,1 ${targetX},${targetY}`; } function idOf(x){ return typeof x === 'object' ? 
x.id : x; } function neighborsOf(id){ const out = new Set([id]); Object.keys(linkedByIndex).forEach(k=>{ const [a,b] = k.split(','); if(a===id) out.add(b); if(b===id) out.add(a); }); return out; } // Highlight matches + neighbors; empty query resets function applySearch(q){ q = (q || '').trim().toLowerCase(); if(!q){ node.select("circle").style("opacity", 1); g.selectAll(".node-label").style("opacity", 1); link.style("opacity", 1); return; } const matches = new Set(timeline.nodes.filter(n => n.id.toLowerCase().includes(q)).map(n=>n.id)); const keep = new Set(); matches.forEach(m => neighborsOf(m).forEach(x => keep.add(x))); node.select("circle").style("opacity", d => keep.has(d.id) ? 1 : 0.08); g.selectAll(".node-label").style("opacity", d => keep.has(d.id) ? 1 : 0.08); link.style("opacity", d => { const s = idOf(d.source), t = idOf(d.target); return (keep.has(s) && keep.has(t)) ? 1 : 0.08; }); } // wire it up document.getElementById('searchBox').addEventListener('input', e => applySearch(e.target.value)); sim.on('tick', () => { link.attr('d', linkPath); node.attr('transform', d => `translate(${d.x},${d.y})`); }); function dragStart(e, d) { if (!e.active) sim.alphaTarget(.3).restart(); d.fx = d.x; d.fy = d.y; } function dragged(e, d) { d.fx = e.x; d.fy = e.y; } function dragEnd(e, d) { if (!e.active) sim.alphaTarget(0); d.fx = d.fy = null; } // Initialize updateVisibility(); // Auto-fit timeline view with better zoom for horizontal spread setTimeout(() => { if (timeScale && timeExtent[0] && timeExtent[1]) { const timeWidth = timeScale.range()[1] - timeScale.range()[0]; const scale = Math.min((W * 0.9) / timeWidth, 1); const translateX = (W - timeWidth * scale) / 2; const translateY = 0; svg.transition() .duration(2000) .call(zoomBehavior.transform, d3.zoomIdentity.translate(translateX, translateY).scale(scale)); } }, 1500); """ TIMELINE_HTML = """ Transformers Chronological Timeline
Chronological Timeline
🟡 base
🔵 modular
🔴 candidate
Models positioned by creation date
Scroll & zoom to explore timeline
""" # ──────────────────────────────────────────────────────────────────────────────── # HTML writer # ──────────────────────────────────────────────────────────────────────────────── def write_html(graph_data: dict, path: Path): path.write_text(generate_html(graph_data), encoding="utf-8") # ──────────────────────────────────────────────────────────────────────────────── # MAIN # ──────────────────────────────────────────────────────────────────────────────── def main(): ap = argparse.ArgumentParser(description="Visualise modular dependencies + candidates") ap.add_argument("transformers", help="Path to local πŸ€— transformers repo root") ap.add_argument("--multimodal", action="store_true", help="filter to models with β‰₯3 'pixel_values'") ap.add_argument("--sim-threshold", type=float, default=SIM_DEFAULT) ap.add_argument("--out", default=HTML_DEFAULT) ap.add_argument("--sim-method", choices=["jaccard", "embedding"], default="jaccard", help="Similarity method: 'jaccard' or 'embedding'") args = ap.parse_args() graph = build_graph_json( transformers_dir=Path(args.transformers).expanduser().resolve(), threshold=args.sim_threshold, multimodal=args.multimodal, sim_method=args.sim_method, ) write_html(graph, Path(args.out).expanduser()) if __name__ == "__main__": main()