import os
import gradio as gr
import torch
from threading import Thread
from transformers import (
    pipeline,
    AutoTokenizer,
    AutoModelForCausalLM,
    TextIteratorStreamer,
)
from openai import OpenAI

# -------- Runtime tuning for tiny CPU Spaces --------
try:
    torch.set_num_threads(min(2, os.cpu_count() or 2))
    torch.set_num_interop_threads(1)
except Exception:
    pass

# -------- Model choices --------
MODEL_OPTIONS = [
    "GPT-1 (openai-gpt) - local",
    "GPT-2 (gpt2) - local",
    "DistilGPT-2 (distilgpt2) - local (fast)",
    "GPT-3.5 (gpt-3.5-turbo) - OpenAI",
]

MODEL_MAP = {
    "GPT-1 (openai-gpt) - local": {"kind": "hf", "id": "openai-gpt"},
    "GPT-2 (gpt2) - local": {"kind": "hf", "id": "gpt2"},
    "DistilGPT-2 (distilgpt2) - local (fast)": {"kind": "hf", "id": "distilgpt2"},
    "GPT-3.5 (gpt-3.5-turbo) - OpenAI": {"kind": "openai-chat", "id": "gpt-3.5-turbo"},
}
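
# "hf" entries load a local transformers pipeline; "openai-chat" entries call the OpenAI API.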
HF_PIPELINES = {}  # model_id -> cached text-generation pipeline

OPENAI_KEY = os.getenv("OPENAI_API_KEY")
OPENAI_CLIENT = OpenAI(api_key=OPENAI_KEY) if OPENAI_KEY else None


def get_hf_pipeline(model_id: str):
"""Create/fetch a text-generation pipeline; cache to avoid reloads."""
if model_id in HF_PIPELINES:
return HF_PIPELINES[model_id]
device = 0 if torch.cuda.is_available() else -1
tok = AutoTokenizer.from_pretrained(model_id, use_fast=True)
mdl = AutoModelForCausalLM.from_pretrained(
model_id,
low_cpu_mem_usage=True,
torch_dtype=torch.float32, # CPU-safe
)
# Older GPT models lack pad_token; map to EOS
if tok.pad_token_id is None and tok.eos_token_id is not None:
tok.pad_token = tok.eos_token
gen = pipeline("text-generation", model=mdl, tokenizer=tok, device=device)
HF_PIPELINES[model_id] = gen
return gen
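

# Example usage (reference only; the app itself calls this from generate_stream below):
#   pipe = get_hf_pipeline("distilgpt2")
#   pipe("Hello world", max_new_tokens=20)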


def generate_stream(model_choice, prompt, max_new_tokens, temperature, top_p, seed):
    """Stream tokens for both HF and OpenAI to improve perceived latency."""
    prompt = (prompt or "").strip()
    if not prompt:
        yield "Please enter a prompt."
        return

    info = MODEL_MAP[model_choice]
    kind = info["kind"]
    model_id = info["id"]

    try:
        if seed is not None and int(seed) >= 0:
            torch.manual_seed(int(seed))

        if kind == "hf":
            gen = get_hf_pipeline(model_id)
            tok = gen.tokenizer
            mdl = gen.model
            streamer = TextIteratorStreamer(tok, skip_prompt=True, skip_special_tokens=True)
            inputs = tok(prompt, return_tensors="pt")
            if torch.cuda.is_available():
                inputs = {k: v.to("cuda") for k, v in inputs.items()}
            generate_kwargs = dict(
                **inputs,
                max_new_tokens=int(max_new_tokens),
                do_sample=float(temperature) > 0.0,
                temperature=max(1e-6, float(temperature)),
                top_p=float(top_p),
                pad_token_id=tok.eos_token_id,
                eos_token_id=tok.eos_token_id,
                streamer=streamer,
            )
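            # mdl.generate() runs on a worker thread; TextIteratorStreamer yields the
            # decoded text pieces back into this generator as tokens are produced.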
            thread = Thread(target=mdl.generate, kwargs=generate_kwargs)
            thread.start()
            out = ""
            for token_text in streamer:
                out += token_text
                yield out
            return

        if kind == "openai-chat":
            if OPENAI_CLIENT is None:
                yield "⚠️ To use GPT-3.5, set OPENAI_API_KEY in the Space settings (Settings → Variables & secrets)."
                return
            stream = OPENAI_CLIENT.chat.completions.create(
                model=model_id,
                messages=[{"role": "user", "content": prompt}],
                max_tokens=int(max_new_tokens),
                temperature=float(temperature),
                top_p=float(top_p),
                stream=True,
            )
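            # Each chunk carries a small text delta; accumulate it and yield the
            # running output so the textbox fills in progressively.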
out = ""
for chunk in stream:
delta = ""
try:
delta = chunk.choices[0].delta.content or ""
except Exception:
delta = getattr(chunk.choices[0], "text", "") or ""
if delta:
out += delta
yield out
return
yield f"Unknown model kind: {kind}"
except Exception as e:
yield f"❌ Error from {model_choice} ({model_id}): {str(e)}"


def maybe_warn(choice):
    """Show a hint when the selected model needs an OpenAI key that is not configured."""
    info = MODEL_MAP[choice]
    needs_key = (info["kind"] == "openai-chat") and (OPENAI_CLIENT is None)
    if needs_key:
        return gr.update(value="⚠️ GPT-3.5 requires OPENAI_API_KEY in the Space secrets.", visible=True)
    return gr.update(visible=False)


# -------- UI --------
CSS = ".gradio-container{max-width:960px;margin:0 auto;}"

with gr.Blocks(title="Mini GPT Playground", css=CSS) as demo:
    gr.Markdown(
        """
# Mini GPT Playground
Type a prompt and choose a model.

**Local (HF):** GPT-1 / GPT-2 / DistilGPT-2 — runs in this Space container.
**OpenAI (API):** GPT-3.5 — requires `OPENAI_API_KEY`.

*(Tip: DistilGPT-2 is much faster on CPU.)*
"""
    )
    with gr.Row():
        model_choice = gr.Dropdown(MODEL_OPTIONS, value="DistilGPT-2 (distilgpt2) - local (fast)", label="Model")
        max_new_tokens = gr.Slider(1, 512, value=96, step=1, label="Max new tokens")

    with gr.Row():
        temperature = gr.Slider(0.0, 2.0, value=0.8, step=0.05, label="Temperature")
        top_p = gr.Slider(0.1, 1.0, value=0.9, step=0.01, label="Top-p")
        seed = gr.Number(value=42, precision=0, label="Seed (≥0 to fix sampling)")

    prompt = gr.Textbox(lines=6, label="Prompt", placeholder="Write a short story about a curious robot…")
    warn = gr.Markdown("", visible=False)
    generate_btn = gr.Button("Generate", variant="primary")
    output = gr.Textbox(lines=12, label="Output")
    model_choice.change(maybe_warn, inputs=[model_choice], outputs=[warn])
    generate_btn.click(
        fn=generate_stream,
        inputs=[model_choice, prompt, max_new_tokens, temperature, top_p, seed],
        outputs=[output],
    )

# -------- Spaces-friendly launch (no custom port) --------
try:
    demo = demo.queue(max_size=8)  # keep small on 2 vCPU
except TypeError:
    pass

demo.launch()  # don't pass server_port; Spaces sets it