""" |
|
|
Cognitive Proxy - Brain-Steered Language Model |
|
|
Hugging Face Spaces deployment |
|
|
Author: Sandro Andric |
|
|
""" |
|
|
|
|
|
import gradio as gr
import torch
import torch.nn as nn
import numpy as np
import pickle
import os
from pathlib import Path
from sklearn.decomposition import PCA
from transformers import AutoTokenizer, AutoModelForCausalLM
import plotly.graph_objects as go
import plotly.express as px
import spaces


# Resolve the script directory; fall back to the working directory if __file__ is unavailable.
SCRIPT_DIR = Path(__file__).parent if "__file__" in globals() else Path.cwd()

# Locate the brain atlas and trained adapter weights, whether they live in results/ or next to this script.
if (SCRIPT_DIR / "results" / "final_atlas_256_vocab.pkl").exists():
    ATLAS_PATH = str(SCRIPT_DIR / "results" / "final_atlas_256_vocab.pkl")
    ADAPTER_PATH = str(SCRIPT_DIR / "results" / "tinyllama_adapter_direct.pt")
elif (SCRIPT_DIR / "final_atlas_256_vocab.pkl").exists():
    ATLAS_PATH = str(SCRIPT_DIR / "final_atlas_256_vocab.pkl")
    ADAPTER_PATH = str(SCRIPT_DIR / "tinyllama_adapter_direct.pt")
else:
    # Relative fallback; load_system() degrades gracefully if the files are missing.
    ATLAS_PATH = "results/final_atlas_256_vocab.pkl"
    ADAPTER_PATH = "results/tinyllama_adapter_direct.pt"

print(f"Atlas path: {ATLAS_PATH}")
print(f"Adapter path: {ADAPTER_PATH}")

MODEL_ID = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
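

# TinyLlamaAdapterDirect maps a TinyLlama hidden state (2048-d) to a predicted MEG
# phase-locking (PLV) connectivity pattern: its 65536-d output is a flattened 256x256
# sensor-pair matrix (reshaped back to 256x256 for the heatmap in analyze_text).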
class TinyLlamaAdapterDirect(nn.Module):
    def __init__(self, input_dim=2048, hidden_dim=1024, output_dim=65536):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.LayerNorm(hidden_dim),
            nn.GELU(),
            nn.Dropout(0.1),
            nn.Linear(hidden_dim, hidden_dim),
            nn.LayerNorm(hidden_dim),
            nn.GELU(),
            nn.Dropout(0.1),
            nn.Linear(hidden_dim, hidden_dim // 2),
            nn.LayerNorm(hidden_dim // 2),
            nn.GELU(),
            nn.Linear(hidden_dim // 2, output_dim),
        )

    def forward(self, x):
        return self.net(x)


system = None
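

# load_system() runs once per process and caches its result: it loads TinyLlama and the
# adapter, then derives the steering axis as the first principal component (PC1) of the
# word-level PLV atlas, along with the atlas mean used to center projections.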
def load_system():
    global system
    if system is not None:
        return system

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
    tokenizer.pad_token = tokenizer.eos_token

    dtype = torch.float16 if torch.cuda.is_available() else torch.float32
    try:
        # Newer transformers releases accept `dtype`.
        model = AutoModelForCausalLM.from_pretrained(MODEL_ID, dtype=dtype).to(device)
    except TypeError:
        # Older releases still expect `torch_dtype`.
        model = AutoModelForCausalLM.from_pretrained(MODEL_ID, torch_dtype=dtype).to(device)
    model.eval()

    adapter = TinyLlamaAdapterDirect().to(device).to(dtype)
    if os.path.exists(ADAPTER_PATH):
        adapter.load_state_dict(torch.load(ADAPTER_PATH, map_location=device, weights_only=True))
    adapter.eval()

    if os.path.exists(ATLAS_PATH):
        print(f"Loading atlas from {ATLAS_PATH}")
        with open(ATLAS_PATH, 'rb') as f:
            data = pickle.load(f)
        if isinstance(data, dict):
            print(f"Atlas data keys: {list(data.keys())[:5]}")
            if 'means' in data:
                atlas = data['means']
                print(f"Using 'means' key, got {len(atlas) if isinstance(atlas, dict) else 'not a dict'} items")
            else:
                atlas = data
                print(f"Using data directly, got {len(atlas) if isinstance(atlas, dict) else 'not a dict'} items")
        else:
            atlas = data
            print(f"Atlas is not a dict, type: {type(data)}")
    else:
        print(f"Atlas file not found at {ATLAS_PATH}")
        atlas = {}

    if not atlas or not isinstance(atlas, dict):
        print("Warning: Atlas is empty or invalid, using fallback")
        atlas = {'word1': np.random.randn(256, 256), 'word2': np.random.randn(256, 256)}

    words = list(atlas.keys())
    print(f"Loaded atlas with {len(words)} words")
    if len(words) < 2:
        print(f"Warning: Not enough words in atlas ({len(words)}), using fallback")
        atlas = {'word1': np.random.randn(256, 256), 'word2': np.random.randn(256, 256)}
        words = list(atlas.keys())

    # One flattened PLV pattern per word.
    first_val = np.array(atlas[words[0]])
    if first_val.shape == (256, 256):
        plv_matrix = np.array([np.array(atlas[w]).flatten() for w in words])
    else:
        plv_matrix = np.array([np.array(atlas[w]) for w in words])

    if len(plv_matrix.shape) == 1 or plv_matrix.shape[0] < 2:
        print(f"Warning: Invalid PLV matrix shape {plv_matrix.shape}, using fallback")
        plv_matrix = np.random.randn(10, 65536)

    # The unit-norm PC1 of the word-level PLV patterns serves as the steering axis.
    pca = PCA(n_components=min(10, plv_matrix.shape[0] - 1))
    pca.fit(plv_matrix)
    pc1_axis = pca.components_[0]
    pc1_axis = pc1_axis / np.linalg.norm(pc1_axis)
    global_mean = plv_matrix.mean(axis=0)

    system = {
        'model': model,
        'tokenizer': tokenizer,
        'adapter': adapter,
        'axis': torch.tensor(pc1_axis, dtype=torch.float32).to(device),
        'global_mean': torch.tensor(global_mean, dtype=torch.float32).to(device),
        'device': device
    }
    return system
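

# Steering mechanism: at each decoding step, take the last hidden state h, run the adapter to
# get a predicted PLV pattern, project it onto the brain PC1 axis, and nudge h by
# alpha * (normalized gradient of that projection). alpha < 0 pushes generation toward the
# semantic end of the axis, alpha > 0 toward the syntactic end, and alpha = 0 is the baseline.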
@spaces.GPU(duration=60)
def generate_variants(prompt, scenario, max_tokens):
    """Generate semantic-steered, baseline, and syntactic-steered variants of one prompt."""
    sys = load_system()
    max_tokens = int(max_tokens)  # slider values may arrive as floats

    if scenario in ("Educational", "Technical writing"):
        # Chat-formatted prompt and a stronger push for the guided scenarios.
        prompt_formatted = f"<|user|>\n{prompt}\n<|assistant|>\n"
        alpha_strength = 5.0
    else:
        prompt_formatted = prompt
        alpha_strength = 3.0

    outputs = []
    for alpha in [-alpha_strength, 0, alpha_strength]:
        inputs = sys['tokenizer'](prompt_formatted, return_tensors='pt').to(sys['device'])
        generated_ids = inputs.input_ids.clone()

        for _ in range(max_tokens):
            outputs_model = sys['model'](generated_ids, output_hidden_states=True)
            hidden = outputs_model.hidden_states[-1][:, -1, :]

            # Match the adapter's dtype (fp16 on GPU, fp32 on CPU).
            adapter_dtype = next(sys['adapter'].parameters()).dtype
            hidden = hidden.to(adapter_dtype)

            if alpha != 0:
                hidden = hidden.detach().requires_grad_(True)
                plv_pred = sys['adapter'](hidden)
                score = torch.sum(plv_pred * sys['axis'].to(adapter_dtype))
                grad = torch.autograd.grad(score, hidden, retain_graph=False)[0]
                grad = grad / (grad.norm() + 1e-8)
                hidden = hidden.detach() + alpha * grad.detach()

            with torch.no_grad():
                logits = sys['model'].lm_head(sys['model'].model.norm(hidden))
                probs = torch.softmax(logits / 0.8, dim=-1)
                next_token = torch.multinomial(probs, num_samples=1)
                generated_ids = torch.cat([generated_ids, next_token], dim=-1)
                if next_token.item() == sys['tokenizer'].eos_token_id:
                    break

        text = sys['tokenizer'].decode(generated_ids[0], skip_special_tokens=True)
        if "<|assistant|>" in text:
            text = text.split("<|assistant|>")[-1].strip()
        outputs.append(text)

    return outputs[0], outputs[1], outputs[2]
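

# Projection score for the Inspect tab: score = (adapter(h_last) - atlas_mean) . PC1_axis.
# Positive values indicate syntactic/structural dominance, negative values semantic/meaning
# dominance, and values near zero (within +/-5) a balanced mix.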
@spaces.GPU(duration=30)
def analyze_text(text):
    """Project text onto the brain axis and return the gauge, interpretation, score, and PLV heatmap."""
    sys = load_system()

    with torch.no_grad():
        inputs = sys['tokenizer'](text, return_tensors='pt').to(sys['device'])
        out = sys['model'](**inputs, output_hidden_states=True)
        last_hidden = out.hidden_states[-1][0, -1, :]

        adapter_dtype = next(sys['adapter'].parameters()).dtype
        last_hidden = last_hidden.to(adapter_dtype)
        plv_pred = sys['adapter'](last_hidden.unsqueeze(0))
        plv_flat = plv_pred[0]
        plv_centered = plv_flat - sys['global_mean'].to(adapter_dtype)
        score = (plv_centered * sys['axis'].to(adapter_dtype)).sum().item()

    # Widen the gauge if the score falls outside the typical -300..300 range.
    gauge_min = min(-300, score - 50)
    gauge_max = max(300, score + 50)

    fig = go.Figure(go.Indicator(
        mode="number+gauge",
        value=score,
        gauge={
            'shape': "angular",
            'axis': {'range': [gauge_min, gauge_max], 'tickwidth': 0.5, 'tickcolor': '#ccc'},
            'bar': {'color': "#333", 'thickness': 0.15},
            'bgcolor': "white",
            'borderwidth': 1,
            'bordercolor': "#e0e0e0",
            'steps': [
                {'range': [gauge_min, -5], 'color': "#e8f5e9"},
                {'range': [-5, 5], 'color': "#fafafa"},
                {'range': [5, gauge_max], 'color': "#fff3e0"}
            ],
        },
        number={'font': {'size': 36, 'color': '#000'}}
    ))

    fig.update_layout(
        height=300,
        width=400,
        margin={'l': 30, 'r': 30, 't': 50, 'b': 30},
        paper_bgcolor='white',
        font={'color': '#666'}
    )

    if score > 5:
        interpretation = "**Syntactic dominance**  \nText patterns match brain activity during grammatical processing"
    elif score < -5:
        interpretation = "**Semantic dominance**  \nText patterns match brain activity during meaning comprehension"
    else:
        interpretation = "**Balanced**  \nMixed patterns - both structure and meaning equally present"

    # Render the predicted PLV pattern as a 256x256 sensor-pair heatmap.
    plv_np = plv_pred[0].cpu().numpy()
    plv_matrix = plv_np[:65536].reshape(256, 256)

    fig_plv = px.imshow(
        plv_matrix,
        color_continuous_scale='Viridis',
        aspect='auto'
    )
    fig_plv.update_layout(
        coloraxis_showscale=True,
        coloraxis=dict(
            colorbar=dict(
                thickness=10,
                len=0.7,
                title=dict(text="Synchrony", side="right"),
                tickfont=dict(size=10)
            )
        ),
        margin={'l': 0, 'r': 40, 't': 10, 'b': 0},
        height=300
    )
    fig_plv.update_xaxes(visible=False)
    fig_plv.update_yaxes(visible=False)

    return fig, interpretation, score, fig_plv
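

# Same steering loop as generate_variants, but with a single user-chosen alpha from the
# Steer tab instead of the fixed {-strength, 0, +strength} sweep.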
@spaces.GPU(duration=60)
def generate_steered(prompt, alpha, max_tokens):
    """Generate with a user-chosen steering strength (alpha)."""
    sys = load_system()
    max_tokens = int(max_tokens)  # slider values may arrive as floats

    inputs = sys['tokenizer'](prompt, return_tensors='pt').to(sys['device'])
    generated_ids = inputs.input_ids.clone()

    for _ in range(max_tokens):
        outputs_model = sys['model'](generated_ids, output_hidden_states=True)
        hidden = outputs_model.hidden_states[-1][:, -1, :]

        # Match the adapter's dtype (fp16 on GPU, fp32 on CPU).
        adapter_dtype = next(sys['adapter'].parameters()).dtype
        hidden = hidden.to(adapter_dtype)

        if alpha != 0:
            hidden = hidden.detach().requires_grad_(True)
            plv_pred = sys['adapter'](hidden)
            score = torch.sum(plv_pred * sys['axis'].to(adapter_dtype))
            grad = torch.autograd.grad(score, hidden, retain_graph=False)[0]
            grad = grad / (grad.norm() + 1e-8)
            hidden = hidden.detach() + alpha * grad.detach()

        with torch.no_grad():
            logits = sys['model'].lm_head(sys['model'].model.norm(hidden))
            probs = torch.softmax(logits / 0.8, dim=-1)
            next_token = torch.multinomial(probs, num_samples=1)
            generated_ids = torch.cat([generated_ids, next_token], dim=-1)
            if next_token.item() == sys['tokenizer'].eos_token_id:
                break

    return sys['tokenizer'].decode(generated_ids[0], skip_special_tokens=True)
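

# Custom CSS passed to gr.Blocks(css=...): Inter font, flat monochrome styling, and a hidden
# Gradio footer.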
custom_css = """
@import url('https://fonts.googleapis.com/css2?family=Inter:wght@300;400;500;600&display=swap');

/* Global font */
.gradio-container, .gradio-container * {
    font-family: 'Inter', -apple-system, BlinkMacSystemFont, 'Segoe UI', sans-serif !important;
}

/* Clean header */
.main-header {
    font-size: 14px;
    font-weight: 300;
    letter-spacing: 2px;
    text-transform: uppercase;
    color: #666;
    margin-bottom: 8px;
}

.main-title {
    font-size: 48px;
    font-weight: 300;
    line-height: 1.1;
    letter-spacing: -1px;
    margin-bottom: 16px;
}

.subtitle {
    font-size: 18px;
    font-weight: 300;
    color: #666;
    line-height: 1.6;
}

/* Clean tabs like Streamlit */
.tabs {
    border-bottom: 1px solid #e0e0e0 !important;
}

.tab-nav button {
    background: none !important;
    border: none !important;
    border-bottom: 2px solid transparent !important;
    color: #666 !important;
    font-weight: 400 !important;
    font-size: 14px !important;
    padding: 8px 16px !important;
    text-transform: none !important;
}

.tab-nav button.selected {
    color: #000 !important;
    border-bottom-color: #000 !important;
}

/* Minimal buttons */
button.primary {
    background: white !important;
    border: 1px solid #000 !important;
    color: #000 !important;
    font-weight: 400 !important;
    padding: 10px 20px !important;
    transition: all 0.2s !important;
}

button.primary:hover {
    background: #000 !important;
    color: white !important;
}

/* Clean textboxes */
textarea, input[type="text"] {
    border: 1px solid #e0e0e0 !important;
    border-radius: 0 !important;
    font-size: 14px !important;
}

/* Section titles */
.section-title {
    font-size: 11px;
    font-weight: 500;
    letter-spacing: 1.5px;
    text-transform: uppercase;
    color: #999;
    margin: 24px 0 16px 0;
}

/* Value labels */
.value-label {
    font-size: 12px;
    color: #999;
    margin-bottom: 4px;
}

/* Remove gradio branding */
footer { display: none !important; }
"""


DEFAULT_PROMPTS = {
    "Technical writing": "Draft a short SMS to the customer informing them their payment has failed.",
    "Educational": "Explain in 2 sentences what the butterfly effect is.",
    "Free form": "Brainstorm creative uses of brain-steered language models in five bullet points."
}
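

# Per-scenario labels and captions for the three columns in the Compare tab.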
SCENARIO_AXIS_TEXT = {
    "Technical writing": {
        "left_label": "Semantic / Content (meaning-heavy, concrete) [empathetic/actionable tone]",
        "baseline_label": "Baseline",
        "right_label": "Syntactic / Function (structure-heavy, abstract) [formal/policy tone]",
        "left_caption": "*Steered toward meaning (brain semantic side)*",
        "baseline_caption": "*No brain steering*",
        "right_caption": "*Steered toward structure (brain syntactic side)*",
    },
    "Educational": {
        "left_label": "Semantic / Content (meaning-heavy, concrete) [analogy/concrete style]",
        "baseline_label": "Baseline",
        "right_label": "Syntactic / Function (structure-heavy, abstract) [definition/logical style]",
        "left_caption": "*Steered toward meaning (brain semantic side)*",
        "baseline_caption": "*No brain steering*",
        "right_caption": "*Steered toward structure (brain syntactic side)*",
    },
    "Free form": {
        "left_label": "Semantic / Content (meaning-heavy, concrete)",
        "baseline_label": "Baseline",
        "right_label": "Syntactic / Function (structure-heavy, abstract)",
        "left_caption": "*Steered toward meaning (brain semantic side)*",
        "baseline_caption": "*No brain steering*",
        "right_caption": "*Steered toward structure (brain syntactic side)*",
    },
}
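

# UI: three tabs. "Compare" generates semantic / baseline / syntactic variants side by side,
# "Inspect" projects arbitrary text onto the brain axis, and "Steer" exposes a raw alpha slider.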
with gr.Blocks(
    title="Cognitive Proxy",
    theme=gr.themes.Base(
        primary_hue="gray",
        neutral_hue="gray",
        text_size="md",
        spacing_size="lg",
        radius_size="none",
    ),
    css=custom_css
) as demo:

    gr.HTML("""
    <div>
        <div class="main-header">Neural Language Interface</div>
        <div class="main-title">Cognitive Proxy</div>
        <div class="subtitle">Steering language models through brain-derived coordinate spaces.<br>
        Using MEG phase-locking patterns from 21 subjects as control geometry.</div>
        <div style="color: #999; font-size: 13px; margin-top: 16px;">Sandro Andric</div>
        <div style="color: #999; font-size: 11px; margin-top: 8px;">Demo model: TinyLlama-1.1B-Chat</div>
        <div style="margin-top: 12px;"><a href="https://arxiv.org/abs/2512.19399" style="color: #666; font-size: 12px;">📄 Read our latest research on brain-LLM alignment</a></div>
    </div>
    """)

    with gr.Accordion("How this works", open=False):
        gr.Markdown("""
        **What makes this special:** This AI is controlled by real human brain data.
        We recorded brain activity from 21 people listening to stories, discovered how their brains organize language,
        and now use those patterns to steer what the AI generates.

        **Try this:**
        1. Start with the **Compare** tab and choose **Educational**
        2. Click "Generate all variants" to see three versions side by side
        3. Notice how the left (concrete) version uses analogies while the right (abstract) uses logic
        4. The difference comes from steering along brain axes discovered from MEG recordings

        **The science:** Different brain regions activate for grammar vs meaning.
        We project the AI's internal states into this brain coordinate system and steer along the axis.
        """)

    with gr.Tabs():

        with gr.TabItem("Compare"):
            gr.HTML('<div class="section-title">Comparative Analysis</div>')

            gr.Markdown("""
            See how brain steering affects AI output. Try **Educational** to see the difference between
            abstract explanations and concrete analogies, or **Technical writing** to compare formal and friendly tones.
            All controlled by brain patterns from 21 human subjects.
            """)

            with gr.Row():
                scenario = gr.Dropdown(
                    choices=["Technical writing", "Educational", "Free form"],
                    value="Technical writing",
                    label="Scenario",
                    container=False
                )

            prompt = gr.Textbox(
                value=DEFAULT_PROMPTS["Technical writing"],
                label="",
                placeholder="Enter your prompt...",
                lines=4
            )

            def update_prompt(selected):
                return DEFAULT_PROMPTS.get(selected, DEFAULT_PROMPTS["Free form"])

            scenario.change(
                update_prompt,
                inputs=[scenario],
                outputs=[prompt]
            )

            with gr.Row():
                max_tokens = gr.Slider(20, 150, 80, label="Max tokens", container=False)
                generate_btn = gr.Button("Generate all variants", variant="primary")

            gr.HTML('<div style="margin-top: 24px;"></div>')

            with gr.Row():
                with gr.Column():
                    axis_text = SCENARIO_AXIS_TEXT["Technical writing"]
                    left_label = gr.HTML(f'<div class="value-label">{axis_text["left_label"]}</div>')
                    output_semantic = gr.Textbox(
                        label="",
                        lines=10,
                        interactive=False,
                        container=False
                    )
                    left_caption = gr.Markdown(axis_text["left_caption"], elem_classes=["caption"])

                with gr.Column():
                    baseline_label = gr.HTML(f'<div class="value-label">{axis_text["baseline_label"]}</div>')
                    output_baseline = gr.Textbox(
                        label="",
                        lines=10,
                        interactive=False,
                        container=False
                    )
                    baseline_caption = gr.Markdown(axis_text["baseline_caption"], elem_classes=["caption"])

                with gr.Column():
                    right_label = gr.HTML(f'<div class="value-label">{axis_text["right_label"]}</div>')
                    output_syntactic = gr.Textbox(
                        label="",
                        lines=10,
                        interactive=False,
                        container=False
                    )
                    right_caption = gr.Markdown(axis_text["right_caption"], elem_classes=["caption"])

            def update_axis_labels(selected):
                data = SCENARIO_AXIS_TEXT.get(selected, SCENARIO_AXIS_TEXT["Free form"])
                # Wrap the HTML labels so the value-label styling survives scenario changes.
                return (
                    f'<div class="value-label">{data["left_label"]}</div>',
                    f'<div class="value-label">{data["baseline_label"]}</div>',
                    f'<div class="value-label">{data["right_label"]}</div>',
                    data["left_caption"],
                    data["baseline_caption"],
                    data["right_caption"],
                )

            scenario.change(
                update_axis_labels,
                inputs=[scenario],
                outputs=[left_label, baseline_label, right_label, left_caption, baseline_caption, right_caption],
            )

            generate_btn.click(
                generate_variants,
                inputs=[prompt, scenario, max_tokens],
                outputs=[output_semantic, output_baseline, output_syntactic]
            )

        with gr.TabItem("Inspect"):
            gr.HTML('<div class="section-title">Brain Space Projection</div>')

            gr.Markdown("""
            Enter any text to see how it aligns with brain patterns. The meter shows whether your text
            activates brain regions associated with grammar/structure (positive) or meaning/content (negative).
            """)

            with gr.Row():
                with gr.Column():
                    text_input = gr.Textbox(
                        value="The scientist discovered",
                        label="",
                        placeholder="Enter text to analyze...",
                        lines=6
                    )
                    analyze_btn = gr.Button("Project", variant="primary")

                with gr.Column():
                    gauge_plot = gr.Plot(label="")
                    interpretation = gr.Markdown("")

            with gr.Accordion("What the number means", open=False):
                gr.Markdown("""
                - **Negative values (green)** = semantic/meaning focus
                - **Positive values (amber)** = syntactic/grammar focus
                - **Larger magnitude** = stronger pattern
                - **Range** typically -300 to +300
                """)

            with gr.Accordion("View brain connectivity pattern", open=False):
                gr.Markdown("""
                Phase-Locking Value (PLV) shows how synchronized different brain regions are.
                Brighter colors = stronger synchronization between sensor pairs.
                Each pixel represents connectivity between two of 256 MEG sensors.
                """)
                plv_plot = gr.Plot(label="")

            def analyze_text_wrapper(text):
                fig, interp, _, fig_plv = analyze_text(text)
                return fig, interp, fig_plv

            analyze_btn.click(
                analyze_text_wrapper,
                inputs=[text_input],
                outputs=[gauge_plot, interpretation, plv_plot]
            )

        with gr.TabItem("Steer"):
            gr.HTML('<div class="section-title">Neural Steering</div>')

            with gr.Row():
                with gr.Column(scale=2):
                    prompt_steer = gr.Textbox(
                        value="The scientist discovered",
                        label="",
                        placeholder="Enter prompt...",
                        lines=5
                    )

                with gr.Column(scale=1):
                    gr.HTML('<div class="value-label">Tokens</div>')
                    tokens_steer = gr.Slider(20, 150, 60, label="", container=False)

                    gr.HTML('<div class="value-label">Alpha</div>')
                    alpha_steer = gr.Slider(-5.0, 5.0, 0.0, step=0.5, label="", container=False)
                    gr.Markdown("*negative → semantic | positive → syntactic*", elem_classes=["caption"])

            steer_btn = gr.Button("Generate", variant="primary")

            gr.HTML('<div class="section-title">Output</div>')
            output_steer = gr.Textbox(label="", lines=8, interactive=False, container=False)

            steer_btn.click(
                generate_steered,
                inputs=[prompt_steer, alpha_steer, tokens_steer],
                outputs=[output_steer]
            )

    gr.HTML("""
    <div style="text-align: center; color: #999; font-size: 12px; padding: 40px 0 20px 0; border-top: 1px solid #e0e0e0; margin-top: 40px;">
        © 2025 Sandro Andric | <a href="https://ainthusiast.com" style="color: #999;">Ainthusiast.com</a>
    </div>
    """)

demo.launch()