import gradio as gr
import torch
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModelForCausalLM
from safetensors import safe_open
import os
import requests
import json
import math
import numpy as np
from sklearn.decomposition import PCA
import logging
import time
from dotenv import load_dotenv
from huggingface_hub import hf_hub_download
import spaces
import traceback
from graphviz import Digraph
from PIL import Image, ImageDraw, ImageFont
from io import BytesIO
import functools
# Set up custom logger
custom_logger = logging.getLogger("custom_logger")
custom_logger.setLevel(logging.INFO)
# Prevent the root logger from duplicating messages
custom_logger.propagate = False
# Set up custom handler and formatter
custom_handler = logging.StreamHandler()
custom_handler.setFormatter(logging.Formatter('%(message)s'))
custom_logger.addHandler(custom_handler)
# Load environment variables
load_dotenv()
# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
logger.info(f"HF_TOKEN_GEMMA set: {'HF_TOKEN_GEMMA' in os.environ}")
logger.info(f"HF_TOKEN_EMBEDDINGS set: {'HF_TOKEN_EMBEDDINGS' in os.environ}")
class Config:
def __init__(self):
self.MODEL_NAME = "google/gemma-2b"
self.ACCESS_TOKEN = os.environ.get("HF_TOKEN_GEMMA")
self.EMBEDDINGS_TOKEN = os.environ.get("HF_TOKEN_EMBEDDINGS")
self.DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
self.DTYPE = torch.float32
self.TOPK = 5
self.CUTOFF = 0.00001 # Cumulative probability cutoff for tree branches
self.OUTPUT_LENGTH = 20
self.SUB_TOKEN_ID = 23070 # Arbitrary token ID to overwrite with embedding (token = "OSS")
self.LOG_BASE = 10
def get_sub_token_string(self, tokenizer):
return tokenizer.decode([self.SUB_TOKEN_ID])
config = Config()
def load_tokenizer():
try:
logger.info(f"Attempting to load tokenizer with token: {config.ACCESS_TOKEN[:5]}...")
tokenizer = AutoTokenizer.from_pretrained(config.MODEL_NAME, token=config.ACCESS_TOKEN)
logger.info("Tokenizer loaded successfully")
return tokenizer
except Exception as e:
logger.error(f"Error loading tokenizer: {str(e)}")
return None
def load_model():
try:
logger.info(f"Attempting to load model with token: {config.ACCESS_TOKEN[:5]}...")
model = AutoModelForCausalLM.from_pretrained(config.MODEL_NAME, device_map="auto", token=config.ACCESS_TOKEN)
logger.info("Model loaded successfully")
return model
except Exception as e:
logger.error(f"Error loading model: {str(e)}")
return None
def load_token_embeddings():
try:
logger.info(f"Attempting to load token embeddings with token: {config.EMBEDDINGS_TOKEN[:5]}...")
embeddings_path = hf_hub_download(
repo_id="mwatkins1970/gemma-2b-embeddings",
filename="gemma_2b_embeddings.pt",
token=config.EMBEDDINGS_TOKEN
)
logger.info(f"Embeddings downloaded to: {embeddings_path}")
embeddings = torch.load(embeddings_path, map_location=config.DEVICE, weights_only=True)
logger.info("Embeddings loaded successfully")
return embeddings.to(dtype=config.DTYPE)
except Exception as e:
logger.error(f"Error loading token embeddings: {str(e)}")
return None
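# Download (and cache locally) the SAE encoder/decoder weights for the selected layer
# from the jbloom/Gemma-2b-Residual-Stream-SAEs repository.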
def load_sae_weights(sae_name):
start_time = time.time()
base_url = 'https://huggingface.co/jbloom/Gemma-2b-Residual-Stream-SAEs/resolve/main/'
sae_urls = {
"Gemma-2B layer 6": "gemma_2b_blocks.6.hook_resid_post_16384_anthropic_fast_lr/sae_weights.safetensors",
"Gemma-2B layer 0": "gemma_2b_blocks.0.hook_resid_post_16384_anthropic/sae_weights.safetensors",
"Gemma-2B layer 10": "gemma_2b_blocks.10.hook_resid_post_16384/sae_weights.safetensors",
"Gemma-2B layer 12": "gemma_2b_blocks.12.hook_resid_post_16384/sae_weights.safetensors"
}
if sae_name not in sae_urls:
raise ValueError(f"Unknown SAE: {sae_name}")
url = f'{base_url}{sae_urls[sae_name]}?download=true'
local_filename = f'sae_{sae_name.replace(" ", "_").lower()}.safetensors'
if not os.path.exists(local_filename):
try:
response = requests.get(url)
response.raise_for_status()
with open(local_filename, 'wb') as f:
f.write(response.content)
logger.info(f'SAE weights for {sae_name} downloaded successfully!')
except requests.RequestException as e:
logger.error(f"Failed to download SAE weights for {sae_name}: {str(e)}")
return None, None
try:
with safe_open(local_filename, framework="pt") as f:
w_dec = f.get_tensor("W_dec").to(device=config.DEVICE, dtype=config.DTYPE)
w_enc = f.get_tensor("W_enc").to(device=config.DEVICE, dtype=config.DTYPE)
logger.info(f"Successfully loaded weights for {sae_name}")
logger.info(f"Time taken to load weights: {time.time() - start_time:.2f} seconds")
return w_enc, w_dec
except Exception as e:
logger.error(f"Error loading SAE weights for {sae_name}: {str(e)}")
return None, None
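# Build a feature vector from either an encoder column or a decoder row of the SAE,
# optionally rescaled to a fixed distance from the token-embedding centroid.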
@torch.no_grad()
def create_feature_vector(w_enc, w_dec, feature_number, weight_type, token_centroid, use_token_centroid, scaling_factor):
if weight_type == "encoder":
feature_vector = w_enc[:, feature_number]
else:
feature_vector = w_dec[feature_number]
if use_token_centroid:
feature_vector = token_centroid + scaling_factor * (feature_vector - token_centroid) / torch.norm(feature_vector - token_centroid)
return feature_vector
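# Compute the first principal component of the token embeddings (used as an optional
# extra direction when constructing a "ghost token").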
def perform_pca(_embeddings):
try:
logger.info(f"Starting PCA. Embeddings shape: {_embeddings.shape}")
pca = PCA(n_components=1)
embeddings_cpu = _embeddings.detach().cpu().numpy()
logger.info(f"Embeddings converted to numpy. Shape: {embeddings_cpu.shape}")
pca.fit(embeddings_cpu)
logger.info("PCA fit completed")
pca_direction = torch.tensor(pca.components_[0], dtype=config.DTYPE, device=config.DEVICE)
logger.info(f"PCA direction calculated. Shape: {pca_direction.shape}")
normalized_direction = F.normalize(pca_direction, p=2, dim=0)
logger.info(f"PCA direction normalized. Shape: {normalized_direction.shape}")
return normalized_direction
except Exception as e:
logger.error(f"Error in perform_pca: {str(e)}")
logger.error(f"Embeddings stats - min: {_embeddings.min()}, max: {_embeddings.max()}, mean: {_embeddings.mean()}, std: {_embeddings.std()}")
logger.error(traceback.format_exc())
raise RuntimeError(f"PCA calculation failed: {str(e)}")
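# Blend the feature direction with the first PCA direction and place the result at a
# target distance from the token centroid.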
@torch.no_grad()
def create_ghost_token(_feature_vector, _token_centroid, _pca_direction, target_distance, pca_weight):
feature_direction = F.normalize(_feature_vector - _token_centroid, p=2, dim=0)
combined_direction = (1 - pca_weight) * feature_direction + pca_weight * _pca_direction
combined_direction = F.normalize(combined_direction, p=2, dim=0)
return _token_centroid + target_distance * combined_direction
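# Rank tokens by the ratio (cosine distance to the embedding)^num_exp /
# (cosine distance to the token centroid)^denom_exp, smallest first.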
@torch.no_grad()
def find_closest_tokens(_emb, _token_embeddings, _tokenizer, top_k=500, num_exp=1.4, denom_exp=1.0):
token_centroid = torch.mean(_token_embeddings, dim=0)
emb_norm = F.normalize(_emb.view(1, -1), p=2, dim=1)
centroid_norm = F.normalize(token_centroid.view(1, -1), p=2, dim=1)
normalized_embeddings = F.normalize(_token_embeddings, p=2, dim=1)
similarities_emb = torch.mm(emb_norm, normalized_embeddings.t()).squeeze()
similarities_centroid = torch.mm(centroid_norm, normalized_embeddings.t()).squeeze()
distances_emb = torch.pow(1 - similarities_emb, num_exp)
distances_centroid = torch.pow(1 - similarities_centroid, denom_exp)
ratios = distances_emb / distances_centroid
top_ratios, top_indices = torch.topk(ratios, k=top_k, largest=False)
closest_tokens = [_tokenizer.decode([idx.item()]) for idx in top_indices]
return list(zip(closest_tokens, top_ratios.tolist()))
def get_neuronpedia_url(layer, feature):
return f"https://neuronpedia.org/gemma-2b/{layer}-res-jb/{feature}?embed=true&embedexplanation=true&embedplots=false&embedtest=false&height=300"
# New functions for tree generation and visualization
def update_token_embedding(model, token_id, new_embedding):
new_embedding = new_embedding.to(model.get_input_embeddings().weight.device)
model.get_input_embeddings().weight.data[token_id] = new_embedding
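# Run the model on the current prompt and return the top-k next-token ids and
# probabilities, with the substitute token masked out.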
def produce_next_token_ids(input_ids, model, topk, sub_token_id):
input_ids = input_ids.to(model.device)
with torch.no_grad():
outputs = model(input_ids)
logits = outputs.logits
last_logits = logits[:, -1, :]
last_logits[:, sub_token_id] = float('-inf')
softmax_probs = torch.softmax(last_logits, dim=-1)
top_k_probs, top_k_ids = torch.topk(softmax_probs, k=topk, dim=-1)
return top_k_ids[0], top_k_probs[0]
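# Recursively expand the definition tree, following the top-k continuations up to
# max_depth and pruning branches whose cumulative probability falls below config.CUTOFF.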
def build_def_tree(input_ids, data, base_prompt, model, tokenizer, config, depth=0, max_depth=25, cumulative_prob=1.0):
if depth >= max_depth or cumulative_prob < config.CUTOFF:
return
current_prompt = tokenizer.decode(input_ids[0], skip_special_tokens=True)
# Extract only the part that extends the base prompt
extended_prompt = current_prompt[len(base_prompt):].strip()
extended_prompt = extended_prompt.replace("\n", "|") # Replace \n with |
# Format the line to align "PROB:..." vertically, with additional padding
formatted_line = f"Depth {depth}: {extended_prompt:<45} PROB: {cumulative_prob:.4f}"
# Log only the formatted line without the "INFO:custom_logger" prefix
custom_logger.info(formatted_line)
top_k_ids, top_k_probs = produce_next_token_ids(input_ids, model, config.TOPK, config.SUB_TOKEN_ID)
for idx, token_id in enumerate(top_k_ids.tolist()):
if token_id == config.SUB_TOKEN_ID:
continue # Skip the substitute token to avoid circular definitions
token_id_tensor = torch.tensor([token_id], dtype=torch.long).to(model.device)
new_input_ids = torch.cat([input_ids, token_id_tensor.view(1, 1)], dim=-1)
new_cumulative_prob = cumulative_prob * top_k_probs[idx].item()
if new_cumulative_prob < config.CUTOFF:
continue
token_str = tokenizer.decode([token_id], skip_special_tokens=True)
new_child = {
"token_id": token_id,
"token": token_str,
"cumulative_prob": new_cumulative_prob,
"children": []
}
data['children'].append(new_child)
yield from build_def_tree(new_input_ids, new_child, base_prompt, model, tokenizer, config, depth=depth+1, max_depth=max_depth, cumulative_prob=new_cumulative_prob)
def generate_definition_tree(base_prompt, embedding, model, tokenizer, config):
logger.info(f"Starting generate_definition_tree with base_prompt: {base_prompt}")
results_dict = {"token": "", "cumulative_prob": 1, "children": []}
# Reset the token embedding
token_embedding = torch.unsqueeze(embedding, dim=0).to(model.device)
update_token_embedding(model, config.SUB_TOKEN_ID, token_embedding)
# Clear the model's cache if it has one
if hasattr(model, 'reset_cache'):
model.reset_cache()
input_ids = tokenizer.encode(base_prompt, return_tensors="pt").to(model.device)
logger.info(f"Encoded input_ids: {input_ids}")
for item in build_def_tree(input_ids, results_dict, base_prompt, model, tokenizer, config):
yield item
logger.info("Finished building tree, yielding results_dict")
yield results_dict
def find_max_min_cumulative_weight(node, current_max=0, current_min=float('inf')):
current_max = max(current_max, node.get('cumulative_prob', 0))
if node.get('cumulative_prob', 1) > 0:
current_min = min(current_min, node.get('cumulative_prob', 1))
for child in node.get('children', []):
current_max, current_min = find_max_min_cumulative_weight(child, current_max, current_min)
return current_max, current_min
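# Map a branch's cumulative probability to an edge thickness on a log scale,
# amplified so that high-probability branches stand out.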
def scale_edge_width(cumulative_weight, max_weight, min_weight, log_base, max_thickness=33, min_thickness=1):
cumulative_weight = max(cumulative_weight, min_weight)
log_weight = math.log(cumulative_weight, log_base) - math.log(min_weight, log_base)
log_max = math.log(max_weight, log_base) - math.log(min_weight, log_base)
amplified_weight = (log_weight / log_max) ** 2.5
scaled_weight = (amplified_weight * (max_thickness - min_thickness)) + min_thickness
return scaled_weight
def add_nodes_edges(dot, node, config, max_weight, min_weight, parent=None, is_root=True, depth=0, trim_cutoff=0):
node_id = str(id(node))
token = node.get('token', '').strip()
cumulative_prob = node.get('cumulative_prob', 1)
if cumulative_prob < trim_cutoff and not is_root:
return
if is_root or token:
if parent and not is_root:
edge_weight = scale_edge_width(cumulative_prob, max_weight, min_weight, config.LOG_BASE)
dot.edge(parent, node_id, arrowhead='dot', arrowsize='1', color='darkblue', penwidth=str(edge_weight))
label = "*" if is_root else token
dot.node(node_id, label=label, shape='plaintext', fontsize="36", fontname='Helvetica')
for child in node.get('children', []):
add_nodes_edges(dot, child, config, max_weight, min_weight, parent=node_id, is_root=False, depth=depth+1, trim_cutoff=trim_cutoff)
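# Render the tree with Graphviz, then paste the PNG onto a fixed-height white
# background and return it as a BytesIO object.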
def create_tree_diagram(data, config, max_weight, min_weight, trim_cutoff=0):
dot = Digraph(comment='Definition Tree', format='png')
dot.attr(rankdir='LR', size='5040,5000', margin='0.06', nodesep='0.06', ranksep='1', dpi='120', bgcolor='white')
add_nodes_edges(dot, data, config, max_weight, min_weight, trim_cutoff=trim_cutoff)
# Save to a temporary file first
temp_filename = "temp_tree_diagram"
dot.render(temp_filename, format='png', cleanup=True)
# Read the file back into a BytesIO object
with open(f"{temp_filename}.png", "rb") as f:
output = BytesIO(f.read())
# Add white background
with Image.open(output) as img:
bg = Image.new("RGB", (img.width, 5000), (255, 255, 255))
y_offset = (5000 - img.height) // 2
bg.paste(img, (0, y_offset))
final_output = BytesIO()
bg.save(final_output, 'PNG')
final_output.seek(0)
return final_output
# Global variables to store loaded resources
tokenizer = None
model = None
token_embeddings = None
w_enc_dict = {}
w_dec_dict = {}
@functools.lru_cache(maxsize=None)
def cached_load_tokenizer():
return load_tokenizer()
@functools.lru_cache(maxsize=None)
def cached_load_model():
return load_model()
@functools.lru_cache(maxsize=None)
def cached_load_token_embeddings():
return load_token_embeddings()
def initialize_resources():
global tokenizer, model, token_embeddings
logger.info("Initializing resources...")
tokenizer = cached_load_tokenizer()
if tokenizer is None:
raise RuntimeError("Failed to load tokenizer.")
model = cached_load_model()
if model is None:
raise RuntimeError("Failed to load model.")
token_embeddings = cached_load_token_embeddings()
if token_embeddings is None:
raise RuntimeError("Failed to load token embeddings.")
logger.info("Resources initialized successfully.")
@spaces.GPU
def process_input(selected_sae, feature_number, weight_type, use_token_centroid, scaling_factor, use_pca, pca_weight, num_exp, denom_exp, mode, top_500=False, progress=None):
global w_enc_dict, w_dec_dict, model, tokenizer, token_embeddings
try:
logger.info(f"Processing input: SAE={selected_sae}, feature_number={feature_number}, mode={mode}")
# Load the SAE weights if they are not already loaded
if selected_sae not in w_enc_dict or selected_sae not in w_dec_dict:
logger.info("Loading SAE weights for {}".format(selected_sae))
w_enc, w_dec = load_sae_weights(selected_sae)
if w_enc is None or w_dec is None:
error_message = f"Failed to load SAE weights for {selected_sae}. Please try a different SAE or check your connection."
logger.error(error_message)
return error_message, None
w_enc_dict[selected_sae] = w_enc
w_dec_dict[selected_sae] = w_dec
else:
w_enc, w_dec = w_enc_dict[selected_sae], w_dec_dict[selected_sae]
# Create the feature vector
token_centroid = torch.mean(token_embeddings, dim=0)
feature_vector = create_feature_vector(w_enc, w_dec, int(feature_number), weight_type, token_centroid, use_token_centroid, scaling_factor)
logger.info(f"Feature vector created. Shape: {feature_vector.shape}")
# Apply PCA if requested
if use_pca:
pca_direction = perform_pca(token_embeddings)
feature_vector = create_ghost_token(feature_vector, token_centroid, pca_direction, scaling_factor, pca_weight)
logger.info(f"PCA applied. New feature vector shape: {feature_vector.shape}")
if mode == "cosine distance token lists":
logger.info("Generating cosine distance token list")
closest_tokens_with_values = find_closest_tokens(
feature_vector, token_embeddings, tokenizer,
top_k=500, num_exp=num_exp, denom_exp=denom_exp
)
if top_500:
# Generate the top 500 list
result = ", ".join([f"'{token}': {value:.4f}" for token, value in closest_tokens_with_values])
logger.info("Returning top 500 list")
return result, None
else:
# Generate the top 100 list
token_list = [token for token, _ in closest_tokens_with_values[:100]]
result = f"100 tokens whose embeddings produce the smallest ratio (cos distance to feature vector)^m/(cos distance to token centroid)^n:\n\n"
result += f"[{', '.join(repr(token) for token in token_list)}]\n"
logger.info("Returning top 100 tokens")
return result, None
elif mode == "definition tree generation":
logger.info("Generating definition tree")
base_prompt = f'A typical definition of "{config.get_sub_token_string(tokenizer)}" would be "'
tree_generator = generate_definition_tree(base_prompt, feature_vector, model, tokenizer, config)
# Collect the log output
log_output = []
tree_data = None
for item in tree_generator:
if isinstance(item, str):
log_output.append(item)
else:
tree_data = item
# Join the log output into a single string
log_text = "\n".join(log_output)
# Generate the tree image
if tree_data:
logger.info("Generating tree image")
max_weight, min_weight = find_max_min_cumulative_weight(tree_data)
tree_image = create_tree_diagram(tree_data, config, max_weight, min_weight)
logger.info("Tree image generated successfully")
return log_text, tree_image
else:
logger.error("Failed to generate tree data")
return "Error: Failed to generate tree data.", None
return "Mode not recognized or not implemented in this step.", None
except Exception as e:
logger.error(f"Error in process_input: {str(e)}")
return f"Error: {str(e)}", None
    finally:
        # Guard the cleanup: these locals may not exist if an error or early return
        # occurred before they were assigned.
        if 'feature_vector' in locals():
            del feature_vector
        if 'token_centroid' in locals():
            del token_centroid
        if use_pca and 'pca_direction' in locals():
            del pca_direction
        torch.cuda.empty_cache()
def trim_tree(trim_cutoff, tree_data):
max_weight, min_weight = find_max_min_cumulative_weight(tree_data)
trimmed_tree_image = create_tree_diagram(tree_data, config, max_weight, min_weight, trim_cutoff=float(trim_cutoff))
return trimmed_tree_image
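# Build the Gradio UI: controls for SAE/feature selection and feature-vector options,
# plus outputs for the token list, tree diagram, and Neuronpedia embed.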
def gradio_interface():
def update_visibility(mode):
if mode == "definition tree generation":
return gr.update(visible=True), gr.update(visible=True), gr.update(visible=True), gr.update(visible=False), gr.update(visible=False)
else:
return gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), gr.update(visible=True), gr.update(visible=True)
def update_neuronpedia(selected_sae, feature_number):
layer_number = int(selected_sae.split()[-1])
url = get_neuronpedia_url(layer_number, feature_number)
return f'<iframe src="{url}" width="100%" height="300px"></iframe>'
@spaces.GPU
def update_output(selected_sae, feature_number, weight_type, use_token_centroid, scaling_factor, use_pca, pca_weight, num_exp, denom_exp, mode, progress=gr.Progress()):
global w_enc_dict, w_dec_dict, model, tokenizer, token_embeddings
try:
logger.info(f"Processing input: SAE={selected_sae}, feature_number={feature_number}, mode={mode}")
# Load the SAE weights if they are not already loaded
if selected_sae not in w_enc_dict or selected_sae not in w_dec_dict:
logger.info("Loading SAE weights for {}".format(selected_sae))
w_enc, w_dec = load_sae_weights(selected_sae)
if w_enc is None or w_dec is None:
error_message = f"Failed to load SAE weights for {selected_sae}. Please try a different SAE or check your connection."
logger.error(error_message)
return error_message, None
w_enc_dict[selected_sae] = w_enc
w_dec_dict[selected_sae] = w_dec
else:
w_enc, w_dec = w_enc_dict[selected_sae], w_dec_dict[selected_sae]
# Create the feature vector
token_centroid = torch.mean(token_embeddings, dim=0)
feature_vector = create_feature_vector(w_enc, w_dec, int(feature_number), weight_type, token_centroid, use_token_centroid, scaling_factor)
logger.info(f"Feature vector created. Shape: {feature_vector.shape}")
# Apply PCA if requested
if use_pca:
pca_direction = perform_pca(token_embeddings)
feature_vector = create_ghost_token(feature_vector, token_centroid, pca_direction, scaling_factor, pca_weight)
logger.info(f"PCA applied. New feature vector shape: {feature_vector.shape}")
if mode == "cosine distance token lists":
logger.info("Generating cosine distance token list")
closest_tokens_with_values = find_closest_tokens(
feature_vector, token_embeddings, tokenizer,
top_k=500, num_exp=num_exp, denom_exp=denom_exp
)
token_list = [token for token, _ in closest_tokens_with_values[:100]]
result = f"100 tokens whose embeddings produce the smallest ratio (cos distance to feature vector)^m/(cos distance to token centroid)^n:\n\n"
result += f"[{', '.join(repr(token) for token in token_list)}]\n"
logger.info("Returning top 100 tokens")
return result, None
elif mode == "definition tree generation":
logger.info("Generating definition tree")
base_prompt = f'A typical definition of "{config.get_sub_token_string(tokenizer)}" would be "'
tree_generator = generate_definition_tree(base_prompt, feature_vector, model, tokenizer, config)
# Collect the log output
log_output = []
tree_data = None
for item in tree_generator:
if isinstance(item, str):
log_output.append(item)
logger.info(item) # Log each step
else:
tree_data = item
logger.info("Received tree data")
# Join the log output into a single string
log_text = "\n".join(log_output)
# Generate the tree image
if tree_data:
logger.info("Generating tree image")
max_weight, min_weight = find_max_min_cumulative_weight(tree_data)
tree_image = create_tree_diagram(tree_data, config, max_weight, min_weight)
logger.info("Tree image generated successfully")
return log_text, tree_image
else:
logger.error("Failed to generate tree data")
return "Error: Failed to generate tree data.", None
return "Mode not recognized or not implemented in this step.", None
except Exception as e:
logger.error(f"Error in process_input: {str(e)}")
return f"Error: {str(e)}", None
        finally:
            # Guard the cleanup: these locals may not exist if an error or early return
            # occurred before they were assigned.
            if 'feature_vector' in locals():
                del feature_vector
            if 'token_centroid' in locals():
                del token_centroid
            if use_pca and 'pca_direction' in locals():
                del pca_direction
            torch.cuda.empty_cache()
@spaces.GPU
def generate_top_500(selected_sae, feature_number, weight_type, use_token_centroid, scaling_factor, use_pca, pca_weight, num_exp, denom_exp, mode):
# Call process_input with top_500=True to generate the full list
return process_input(selected_sae, feature_number, weight_type, use_token_centroid, scaling_factor, use_pca, pca_weight, num_exp, denom_exp, mode, top_500=True)
def trim_tree(trim_cutoff, tree_data):
if tree_data is None:
return None
max_weight, min_weight = find_max_min_cumulative_weight(tree_data)
trimmed_tree_image = create_tree_diagram(tree_data, config, max_weight, min_weight, trim_cutoff=float(trim_cutoff))
return trimmed_tree_image
with gr.Blocks() as demo:
gr.Markdown("# Gemma-2B SAE Feature Explorer")
with gr.Row():
with gr.Column(scale=2):
selected_sae = gr.Dropdown(choices=["Gemma-2B layer 0", "Gemma-2B layer 6", "Gemma-2B layer 10", "Gemma-2B layer 12"], label="Select SAE")
feature_number = gr.Number(label="Select feature number", minimum=0, maximum=16383, value=0)
mode = gr.Radio(
choices=["cosine distance token lists", "definition tree generation"],
label="Select mode",
value="cosine distance token lists"
)
weight_type = gr.Radio(["encoder", "decoder"], label="Select weight type for feature vector construction", value="encoder")
use_token_centroid = gr.Checkbox(label="Use token centroid offset", value=True)
scaling_factor = gr.Slider(minimum=0.1, maximum=10.0, value=3.8, label="Scaling factor (3.8 is mean distance from token embeddings to token centroid)")
num_exp = gr.Slider(minimum=0.1, maximum=5.0, value=1.4, label="Numerator exponent m")
denom_exp = gr.Slider(minimum=0.1, maximum=5.0, value=1.0, label="Denominator exponent n")
use_pca = gr.Checkbox(label="Introduce first PCA component")
pca_weight = gr.Slider(minimum=0.0, maximum=1.0, value=0.5, label="PCA weight")
with gr.Column(scale=3):
generate_btn = gr.Button("Generate Output")
output_stream = gr.Textbox(label="Output", lines=20)
output_image = gr.Image(label="Tree Diagram", visible=False)
generate_top_500_btn = gr.Button("Generate Top 500 Tokens and Power Ratios", visible=True)
output_500_text = gr.Textbox(label="Top 500 Output", lines=20, visible=False)
trim_slider = gr.Slider(minimum=0.00001, maximum=0.1, value=0.00001, label="Trim cutoff for cumulative probability", visible=False)
trim_btn = gr.Button("Trim Tree", visible=False)
tree_data_state = gr.State()
neuronpedia_html = gr.HTML(label="Neuronpedia")
inputs = [selected_sae, feature_number, weight_type, use_token_centroid, scaling_factor, use_pca, pca_weight, num_exp, denom_exp, mode]
generate_btn.click(
update_output,
inputs=inputs,
outputs=[output_stream, output_image],
show_progress="full"
).then(lambda: gr.update(visible=False, value=""), None, [output_500_text])
generate_top_500_btn.click(
generate_top_500,
inputs=inputs,
outputs=[output_500_text],
show_progress="full"
).then(lambda: gr.update(visible=True), None, [output_500_text])
trim_btn.click(trim_tree, inputs=[trim_slider, tree_data_state], outputs=[output_image])
mode.change(update_visibility, inputs=[mode], outputs=[output_image, trim_slider, trim_btn, generate_top_500_btn, output_500_text])
selected_sae.change(update_neuronpedia, inputs=[selected_sae, feature_number], outputs=[neuronpedia_html])
feature_number.change(update_neuronpedia, inputs=[selected_sae, feature_number], outputs=[neuronpedia_html])
output_stream.change(
lambda text: (gr.update(visible=True), gr.update(visible=True)) if "100 tokens" in text else (gr.update(visible=False), gr.update(visible=False)),
inputs=[output_stream],
outputs=[generate_top_500_btn, output_500_text]
)
return demo
if __name__ == "__main__":
try:
logger.info("Starting application initialization...")
initialize_resources()
logger.info("Creating Gradio interface...")
iface = gradio_interface()
logger.info("Launching Gradio interface...")
iface.launch()
logger.info("Gradio interface launched successfully")
except Exception as e:
logger.error(f"Error during application startup: {str(e)}")
logger.error(traceback.format_exc())