// CorrSteer / demo / src / App.tsx
// Last commit: 6b7d17f by seonglae ("fix: add demo")
import React, { useState } from "react"
import {
PiFinnTheHumanLight,
PiFinnTheHumanBold,
PiFinnTheHumanFill,
PiFinnTheHumanDuotone,
} from "react-icons/pi"
import { Play } from "lucide-react"
import {
Select,
SelectContent,
SelectItem,
SelectTrigger,
SelectValue,
} from "@/components/ui/select"
import { Button } from "@/components/ui/button"
import { Tabs, TabsList, TabsTrigger, TabsContent } from "@/components/ui/tabs"
import { Footer } from "@/components/footer"
import { Header } from "@/components/header"
// ----------------------------------------------------
// Category + Prompt Mapping
// ----------------------------------------------------

/** One selectable bias category and the prompt its generations start from. */
interface Category {
  /** Category key; used as the tab value and sent to the server. */
  value: string
  /** Human-readable tab label. */
  label: string
  /** Prompt text displayed in front of every generated completion. */
  prompt: string
}

/** Categories shown as tabs, in display order. */
const categories: Category[] = [
  {
    value: "lgbtq+",
    label: "LGBT+",
    prompt: "The gay people are",
  },
  {
    value: "nationality",
    label: "Nationality",
    prompt: "The French people are",
  },
  {
    value: "race",
    label: "Race",
    prompt: "The Black people are",
  },
]
// ----------------------------------------------------
// Minimal card
// ----------------------------------------------------

/** Props accepted by ModelCard. */
type ModelCardProps = {
  /** Icon component rendered before the title (it receives a className). */
  Icon: React.ElementType
  /** Heading text for the card. */
  title: string
  /** Body text; whitespace and line breaks are preserved when rendered. */
  text: string
  /** When true, a trailing "●" is appended as a typing indicator. */
  streaming?: boolean
}

/** Bordered card showing an icon, a title, and a (possibly still streaming) text body. */
function ModelCard(props: ModelCardProps) {
  const { Icon, title, text, streaming } = props
  // Mimics a typing cursor while this card's text is still arriving
  const cursor = streaming ? "●" : null

  return (
    <div className="border border-gray-200 dark:border-gray-700 rounded-lg p-6 flex flex-col min-h-[120px]">
      <div className="flex items-center gap-2 mb-2">
        <Icon className="h-5 w-5" />
        <h3 className="font-semibold">{title}</h3>
      </div>
      <div className="min-h-40 w-80 text-left">
        <p className="text-gray-700 dark:text-gray-300 whitespace-pre-wrap break-words">
          {text}
          {cursor}
        </p>
      </div>
    </div>
  )
}
// ----------------------------------------------------
// Tab Panel that holds 4 model cards + "Play" button
// ----------------------------------------------------

/**
 * The four generation "modes", streamed one after another when Play is pressed.
 * `type` is sent to the server as the generation type; `title`/`icon` label the card.
 * Hoisted to module scope because it never changes between renders.
 */
const modelSequence = [
  { type: "original", title: "Original Model", key: "origin", icon: PiFinnTheHumanLight },
  { type: "origin+steer", title: "Original + Steering", key: "origin+steer", icon: PiFinnTheHumanBold },
  { type: "trained", title: "Trained Model", key: "trained", icon: PiFinnTheHumanFill },
  { type: "trained-steer", title: "Trained - Steering", key: "trained-steer", icon: PiFinnTheHumanDuotone },
]

/** One empty output slot per entry in `modelSequence`. */
function emptyOutputs(): string[] {
  return modelSequence.map(() => "")
}

/**
 * Panel for one category: renders one card per generation mode plus a Play button
 * that streams each mode's completion from `POST /api/generate`, one mode at a time.
 *
 * @param datasetKey - server-side dataset key sent in the request body
 * @param modelKey - server-side model key sent in the request body
 * @param categoryKey - category key sent in the request body
 * @param prompt - prompt text displayed in front of every card's output
 */
function TabPanel({
  datasetKey,
  modelKey,
  categoryKey,
  prompt,
}: {
  datasetKey: string
  modelKey: string
  categoryKey: string
  prompt: string
}) {
  // Partial or final streamed text for each slot (the prompt is not included)
  const [outputs, setOutputs] = useState(emptyOutputs)
  // Slot currently streaming, or -1 if none
  const [activeIndex, setActiveIndex] = useState(-1)
  // True while a run is in flight; disables Play so two runs cannot interleave their writes
  const [isPlaying, setIsPlaying] = useState(false)
  // Message from the last failed run, if any
  const [error, setError] = useState<string | null>(null)
  // Controller of the in-flight run, so it can be cancelled on unmount
  const abortRef = React.useRef<AbortController | null>(null)

  // Cancel any in-flight request when this panel unmounts
  React.useEffect(() => () => abortRef.current?.abort(), [])

  /** Replace the text of a single output slot. */
  function setSlot(index: number, text: string) {
    setOutputs((prev) => prev.map((old, i) => (i === index ? text : old)))
  }

  /**
   * POST one generation request and stream the decoded response body into slot `index`.
   *
   * @throws Error when the server answers with a non-2xx status; also rejects when
   *   the request is aborted through `signal` or the network fails.
   */
  async function fetchInChunks(genType: string, index: number, signal: AbortSignal): Promise<void> {
    const payload = {
      model: modelKey,
      dataset: datasetKey,
      category: categoryKey,
      type: genType,
    }
    const apiBaseUrl = import.meta.env.VITE_API_BASE_URL || ""
    const response = await fetch(`${apiBaseUrl}/api/generate`, {
      method: "POST",
      headers: { "Content-Type": "application/json" },
      body: JSON.stringify(payload),
      signal,
    })
    // Without this check an error page would be streamed into the card as if it were model output
    if (!response.ok) {
      throw new Error(`Generation failed: HTTP ${response.status} ${response.statusText}`.trim())
    }

    const reader = response.body?.getReader()
    if (!reader) return
    const decoder = new TextDecoder("utf-8")
    let partial = ""
    while (true) {
      const { done, value } = await reader.read()
      if (done) break
      // stream: true keeps an incomplete multi-byte sequence buffered until the next chunk
      partial += decoder.decode(value, { stream: true })
      setSlot(index, partial)
    }
    // Flush whatever the decoder still holds once the stream has ended
    const tail = decoder.decode()
    if (tail) setSlot(index, partial + tail)
  }

  /** Reset every slot, then stream each mode in order; stops at the first failure. */
  async function handlePlay(): Promise<void> {
    const controller = new AbortController()
    abortRef.current = controller
    setIsPlaying(true)
    setError(null)
    setOutputs(emptyOutputs())
    setActiveIndex(-1)
    try {
      for (const [i, step] of modelSequence.entries()) {
        setActiveIndex(i)
        await fetchInChunks(step.type, i, controller.signal)
      }
    } catch (err: unknown) {
      // An abort means the panel is going away; there is nobody to show the error to
      if (!controller.signal.aborted) {
        setError(err instanceof Error ? err.message : String(err))
      }
    } finally {
      if (abortRef.current === controller) abortRef.current = null
      // Clear the typing indicator even when a request failed part-way
      setActiveIndex(-1)
      setIsPlaying(false)
    }
  }

  return (
    <div className="space-y-6">
      <div className="grid md:grid-cols-2 gap-6">
        {/* Each card shows the prompt followed by that slot's (possibly partial) output */}
        {modelSequence.map((seq, i) => (
          <ModelCard
            key={seq.type}
            Icon={seq.icon}
            title={seq.title}
            text={prompt + (outputs[i] ?? "")}
            streaming={i === activeIndex}
          />
        ))}
      </div>
      {error && (
        <p role="alert" className="text-sm text-red-600 dark:text-red-400">
          {error}
        </p>
      )}
      <div className="flex justify-center">
        <Button
          size="lg"
          className="bg-blue-500 hover:bg-blue-600 text-white px-8 rounded-full"
          onClick={() => void handlePlay()}
          disabled={isPlaying}
        >
          <Play className="w-5 h-5 mr-2" />
          Play
        </Button>
      </div>
    </div>
  )
}
// ----------------------------------------------------
// Main App
// ----------------------------------------------------

/** Dataset dropdown label -> server-side dataset key. Drives the dropdown items too. */
const datasetKeys: Record<string, string> = { "Bias (EMGSD)": "emgsd" }
/** Model dropdown label -> server-side model key. Drives the dropdown items too. */
const modelKeys: Record<string, string> = { "GPT-2": "gpt2" }

/**
 * Demo page: dataset and model pickers, plus one tab per category. Each tab holds a
 * TabPanel that streams generations for that category from the backend.
 */
export default function App() {
  const [dataset, setDataset] = useState("Bias (EMGSD)")
  const [model, setModel] = useState("GPT-2")
  const [category, setCategory] = useState(categories[0].value)
  // Convert front-end selection to server keys (fall back to the only supported option)
  const datasetKey = datasetKeys[dataset] ?? "emgsd"
  const modelKey = modelKeys[model] ?? "gpt2"
  return (
    <div className="min-h-screen flex flex-col dark:bg-transparent">
      <Header />
      {/* Main content */}
      <main className="flex-grow flex flex-col items-center justify-center">
        <div className="text-center space-y-8">
          <div className="space-y-4">
            <h1 className="text-7xl font-mono tracking-tighter text-black dark:text-white">
              CorrSteer
            </h1>
            <p className="text-xl leading-relaxed text-gray-700 dark:text-gray-300 italic px-6">
              Text Classification dataset can be used to <span className="font-bold">Steer</span> LLMs,
              <br />
              <span className="font-bold">Corr</span>elating with SAE features
            </p>
          </div>
          {/* Dropdowns */}
          <div className="grid md:grid-cols-2 gap-8">
            <div className="space-y-2 px-6 mx-8">
              <label htmlFor="dataset-select" className="text-sm font-medium dark:text-gray-300">
                Dataset
              </label>
              <Select value={dataset} onValueChange={setDataset}>
                <SelectTrigger id="dataset-select" className="dark:bg-gray-800 dark:text-white">
                  <SelectValue />
                </SelectTrigger>
                <SelectContent className="dark:bg-gray-800 dark:text-white">
                  {Object.keys(datasetKeys).map((label) => (
                    <SelectItem key={label} value={label}>
                      {label}
                    </SelectItem>
                  ))}
                </SelectContent>
              </Select>
            </div>
            <div className="space-y-2 px-6 mx-8">
              <label htmlFor="model-select" className="text-sm font-medium dark:text-gray-300">
                Language Model
              </label>
              <Select value={model} onValueChange={setModel}>
                <SelectTrigger id="model-select" className="dark:bg-gray-800 dark:text-white">
                  <SelectValue />
                </SelectTrigger>
                <SelectContent className="dark:bg-gray-800 dark:text-white">
                  {Object.keys(modelKeys).map((label) => (
                    <SelectItem key={label} value={label}>
                      {label}
                    </SelectItem>
                  ))}
                </SelectContent>
              </Select>
            </div>
          </div>
          {/* Tabs: one per category -> each has its own content */}
          <Tabs value={category} onValueChange={setCategory}>
            <TabsList className="gap-1 bg-transparent">
              {categories.map((cat) => (
                <TabsTrigger
                  key={cat.value}
                  value={cat.value}
                  className="data-[state=active]:bg-blue-400 dark:data-[state=active]:bg-blue-500 data-[state=inactive]:bg-slate-200 dark:data-[state=inactive]:bg-slate-800 data-[state=active]:border-gray-300 px-4 py-2 text-sm"
                >
                  {cat.label}
                </TabsTrigger>
              ))}
            </TabsList>
            {categories.map((cat) => (
              <TabsContent key={cat.value} value={cat.value} className="p-6">
                {/* Pass the panel's own category so the request always matches its prompt */}
                <TabPanel
                  datasetKey={datasetKey}
                  modelKey={modelKey}
                  categoryKey={cat.value}
                  prompt={cat.prompt}
                />
              </TabsContent>
            ))}
          </Tabs>
        </div>
      </main>
      <Footer />
    </div>
  )
}