|
import React, { useState } from "react" |
|
import { |
|
PiFinnTheHumanLight, |
|
PiFinnTheHumanBold, |
|
PiFinnTheHumanFill, |
|
PiFinnTheHumanDuotone, |
|
} from "react-icons/pi" |
|
import { Play } from "lucide-react" |
|
import { |
|
Select, |
|
SelectContent, |
|
SelectItem, |
|
SelectTrigger, |
|
SelectValue, |
|
} from "@/components/ui/select" |
|
import { Button } from "@/components/ui/button" |
|
import { Tabs, TabsList, TabsTrigger, TabsContent } from "@/components/ui/tabs" |
|
|
|
import { Footer } from "@/components/footer" |
|
import { Header } from "@/components/header" |
|
|
|
|
|
|
|
|
|
/**
 * One bias category offered as a tab in the demo.
 * - `value`: tab identifier, also used as the initial tab selection.
 * - `label`: text shown on the tab trigger.
 * - `prompt`: prefix displayed in front of each streamed completion.
 */
type BiasCategory = {
  value: string
  label: string
  prompt: string
}

/** Tabs rendered by the app, in display order. */
const categories: BiasCategory[] = [
  {
    value: "lgbtq+",
    label: "LGBT+",
    prompt: "The gay people are",
  },
  {
    value: "nationality",
    label: "Nationality",
    prompt: "The French people are",
  },
  {
    value: "race",
    label: "Race",
    prompt: "The Black people are",
  },
]
|
|
|
|
|
|
|
|
|
/** Props for {@link ModelCard}. */
type ModelCardProps = {
  /** Component rendered as the card's icon; it receives a `className`. */
  Icon: React.ElementType
  /** Heading shown next to the icon. */
  title: string
  /** Body text of the card. */
  text: string
  /** When true, a "●" is appended after the text as a typing indicator. */
  streaming?: boolean
}

/**
 * Bordered card showing an icon, a title and a block of text, with an optional
 * trailing "●" while the text is still streaming in.
 */
function ModelCard(props: ModelCardProps) {
  const { Icon, title, text, streaming } = props

  // Rendered right after the text; null renders nothing.
  const typingIndicator = streaming ? "●" : null

  const heading = (
    <div className="flex items-center gap-2 mb-2">
      <Icon className="h-5 w-5" />
      <h3 className="font-semibold">{title}</h3>
    </div>
  )

  return (
    <div className="border border-gray-200 dark:border-gray-700 rounded-lg p-6 flex flex-col min-h-[120px]">
      {heading}
      <div className="min-h-40 w-80 text-left">
        <p className="text-gray-700 dark:text-gray-300 whitespace-pre-wrap break-words">
          {text}
          {typingIndicator}
          {/* Alternative indicator glyphs: ⏺ or ⬤ or ● */}
        </p>
      </div>
    </div>
  )
}
|
|
|
|
|
|
|
|
|
/**
 * The generation variants shown side by side, in the order they are requested.
 * `type` is sent to the backend as the `type` field of the request payload;
 * `title` and `icon` are what the corresponding card displays.
 */
const MODEL_SEQUENCE = [
  { type: "original", title: "Original Model", key: "origin", icon: PiFinnTheHumanLight },
  { type: "origin+steer", title: "Original + Steering", key: "origin+steer", icon: PiFinnTheHumanBold },
  { type: "trained", title: "Trained Model", key: "trained", icon: PiFinnTheHumanFill },
  { type: "trained-steer", title: "Trained - Steering", key: "trained-steer", icon: PiFinnTheHumanDuotone },
]

/** Returns one empty output string per entry of MODEL_SEQUENCE. */
function blankOutputs(): string[] {
  return MODEL_SEQUENCE.map(() => "")
}

/**
 * Panel for one category tab: a grid of model cards plus a "Play" button that
 * streams a completion into each card, one variant after another.
 *
 * @param datasetKey  Sent as `dataset` in the request payload.
 * @param modelKey    Sent as `model` in the request payload.
 * @param categoryKey Sent as `category` in the request payload.
 * @param prompt      Text displayed in front of every streamed completion.
 */
function TabPanel({
  datasetKey,
  modelKey,
  categoryKey,
  prompt,
}: {
  datasetKey: string
  modelKey: string
  categoryKey: string
  prompt: string
}) {
  // Streamed text per variant, indexed like MODEL_SEQUENCE.
  const [outputs, setOutputs] = useState<string[]>(blankOutputs)
  // Index of the card currently receiving text, or -1 when idle.
  const [activeIndex, setActiveIndex] = useState(-1)
  // True while a Play run is in progress; used to disable the button.
  const [isPlaying, setIsPlaying] = useState(false)
  // Controller of the run in flight (null when idle). A ref gives a synchronous
  // re-entrancy guard and lets the unmount cleanup cancel the stream.
  const abortRef = React.useRef<AbortController | null>(null)

  // Cancel any in-flight request when the panel unmounts.
  React.useEffect(() => {
    return () => {
      abortRef.current?.abort()
    }
  }, [])

  /** Stores `text` as the output of the card at `index`. */
  function publish(index: number, text: string) {
    setOutputs((prev) => {
      const copy = [...prev]
      copy[index] = text
      return copy
    })
  }

  /**
   * POSTs one generation request and streams the response body into the card
   * at `index`, updating it after every received chunk.
   *
   * @throws Error when the server answers with a non-2xx status; also rejects
   *         when the request fails or is aborted through `signal`.
   */
  async function fetchInChunks(genType: string, index: number, signal: AbortSignal) {
    const payload = {
      model: modelKey,
      dataset: datasetKey,
      category: categoryKey,
      type: genType,
    }

    const apiBaseUrl = import.meta.env.VITE_API_BASE_URL || ""
    const response = await fetch(`${apiBaseUrl}/api/generate`, {
      method: "POST",
      headers: { "Content-Type": "application/json" },
      body: JSON.stringify(payload),
      signal,
    })

    // Do not stream an HTTP error body into the card as if it were model output.
    if (!response.ok) {
      throw new Error(`Generation request failed: ${response.status} ${response.statusText}`)
    }

    const reader = response.body?.getReader()
    if (!reader) return

    const decoder = new TextDecoder("utf-8")
    let partial = ""

    while (true) {
      const { done, value } = await reader.read()
      if (done) break

      // stream: true keeps multi-byte characters split across chunks intact.
      partial += decoder.decode(value, { stream: true })
      publish(index, partial)
    }

    // Flush whatever the decoder is still buffering at end of stream.
    const tail = decoder.decode()
    if (tail) {
      partial += tail
      publish(index, partial)
    }
  }

  /**
   * Clears all cards and streams each variant in order. Ignored while a run is
   * already in progress. Never rejects: failures are logged and end the run.
   */
  async function handlePlay() {
    if (abortRef.current) return

    const controller = new AbortController()
    abortRef.current = controller

    setIsPlaying(true)
    setOutputs(blankOutputs())
    setActiveIndex(-1)

    try {
      for (let i = 0; i < MODEL_SEQUENCE.length; i++) {
        setActiveIndex(i)
        await fetchInChunks(MODEL_SEQUENCE[i].type, i, controller.signal)
      }
    } catch (err: unknown) {
      // An abort only happens on unmount and is not worth reporting.
      if (!controller.signal.aborted) {
        console.error("Generation failed:", err)
      }
    } finally {
      // Skip state updates once the panel has been unmounted (aborted).
      if (!controller.signal.aborted) {
        setActiveIndex(-1)
        setIsPlaying(false)
      }
      if (abortRef.current === controller) {
        abortRef.current = null
      }
    }
  }

  return (
    <div className="space-y-6">
      <div className="grid md:grid-cols-2 gap-6">
        {MODEL_SEQUENCE.map((seq, i) => {
          const Icon = seq.icon
          return (
            <ModelCard
              key={seq.type}
              Icon={Icon}
              title={seq.title}
              text={prompt + (outputs[i] ?? "")}
              streaming={i === activeIndex}
            />
          )
        })}
      </div>

      <div className="flex justify-center">
        <Button
          size="lg"
          className="bg-blue-500 hover:bg-blue-600 text-white px-8 rounded-full"
          onClick={handlePlay}
          disabled={isPlaying}
        >
          <Play className="w-5 h-5 mr-2" />
          Play
        </Button>
      </div>
    </div>
  )
}
|
|
|
|
|
|
|
|
|
/**
 * Selectable datasets. `label` is shown in the dropdown and used as the Select
 * value; `key` is the identifier passed down to the panels as `datasetKey`.
 */
const DATASET_OPTIONS = [{ label: "Bias (EMGSD)", key: "emgsd" }]

/**
 * Selectable language models. `label` is shown in the dropdown and used as the
 * Select value; `key` is the identifier passed down to the panels as `modelKey`.
 */
const MODEL_OPTIONS = [{ label: "GPT-2", key: "gpt2" }]

/**
 * Page root: header, title, dataset/model dropdowns, and one tab per entry of
 * `categories`, each hosting a TabPanel for that category.
 */
export default function App() {
  const [dataset, setDataset] = useState(DATASET_OPTIONS[0].label)
  const [model, setModel] = useState(MODEL_OPTIONS[0].label)
  const [category, setCategory] = useState(categories[0].value)

  // Translate the selected labels into backend keys; an unknown label falls
  // back to the first option's key.
  const datasetKey =
    DATASET_OPTIONS.find((opt) => opt.label === dataset)?.key ?? DATASET_OPTIONS[0].key
  const modelKey =
    MODEL_OPTIONS.find((opt) => opt.label === model)?.key ?? MODEL_OPTIONS[0].key

  return (
    <div className="min-h-screen flex flex-col dark:bg-transparent">
      <Header />
      {/* Main content */}
      <main className="flex-grow flex flex-col items-center justify-center">
        <div className="text-center space-y-8">
          <div className="space-y-4">
            <h1 className="text-7xl font-mono tracking-tighter text-black dark:text-white">
              CorrSteer
            </h1>
            <p className="text-xl leading-relaxed text-gray-700 dark:text-gray-300 italic px-6">
              Text Classification dataset can be used to <span className="font-bold">Steer</span> LLMs,
              <br />
              <span className="font-bold">Corr</span>elating with SAE features
            </p>
          </div>

          {/* Dropdowns */}
          <div className="grid md:grid-cols-2 gap-8">
            <div className="space-y-2 px-6 mx-8">
              <label className="text-sm font-medium dark:text-gray-300">
                Dataset
              </label>
              <Select value={dataset} onValueChange={setDataset}>
                <SelectTrigger className="dark:bg-gray-800 dark:text-white">
                  <SelectValue />
                </SelectTrigger>
                <SelectContent className="dark:bg-gray-800 dark:text-white">
                  {DATASET_OPTIONS.map((opt) => (
                    <SelectItem key={opt.label} value={opt.label}>
                      {opt.label}
                    </SelectItem>
                  ))}
                </SelectContent>
              </Select>
            </div>

            <div className="space-y-2 px-6 mx-8">
              <label className="text-sm font-medium dark:text-gray-300">
                Language Model
              </label>
              <Select value={model} onValueChange={setModel}>
                <SelectTrigger className="dark:bg-gray-800 dark:text-white">
                  <SelectValue />
                </SelectTrigger>
                <SelectContent className="dark:bg-gray-800 dark:text-white">
                  {MODEL_OPTIONS.map((opt) => (
                    <SelectItem key={opt.label} value={opt.label}>
                      {opt.label}
                    </SelectItem>
                  ))}
                </SelectContent>
              </Select>
            </div>
          </div>

          {/* Tabs: one per category, each with its own panel */}
          <Tabs value={category} onValueChange={setCategory}>
            <TabsList className="gap-1 bg-transparent">
              {categories.map((cat) => (
                <TabsTrigger
                  key={cat.value}
                  value={cat.value}
                  className="data-[state=active]:bg-blue-400 dark:data-[state=active]:bg-blue-500 data-[state=inactive]:bg-slate-200 dark:data-[state=inactive]:bg-slate-800 data-[state=active]:border-gray-300 px-4 py-2 text-sm"
                >
                  {cat.label}
                </TabsTrigger>
              ))}
            </TabsList>

            {categories.map((cat) => (
              <TabsContent key={cat.value} value={cat.value} className="p-6">
                {/* Use the panel's own category so it always matches its prompt. */}
                <TabPanel
                  datasetKey={datasetKey}
                  modelKey={modelKey}
                  categoryKey={cat.value}
                  prompt={cat.prompt}
                />
              </TabsContent>
            ))}
          </Tabs>
        </div>
      </main>
      <Footer />
    </div>
  )
}
|
|