import os
import gradio as gr
import torch
from threading import Thread
from transformers import (
    pipeline,
    AutoTokenizer,
    AutoModelForCausalLM,
    TextIteratorStreamer,
)
from openai import OpenAI
# -------- Runtime tuning for tiny CPU Spaces --------
try:
    torch.set_num_threads(min(2, os.cpu_count() or 2))
    torch.set_num_interop_threads(1)
except Exception:
    pass
# -------- Model choices --------
MODEL_OPTIONS = [
    "GPT-1 (openai-gpt) - local",
    "GPT-2 (gpt2) - local",
    "DistilGPT-2 (distilgpt2) - local (fast)",
    "GPT-3.5 (gpt-3.5-turbo) - OpenAI",
]

MODEL_MAP = {
    "GPT-1 (openai-gpt) - local": {"kind": "hf", "id": "openai-gpt"},
    "GPT-2 (gpt2) - local": {"kind": "hf", "id": "gpt2"},
    "DistilGPT-2 (distilgpt2) - local (fast)": {"kind": "hf", "id": "distilgpt2"},
    "GPT-3.5 (gpt-3.5-turbo) - OpenAI": {"kind": "openai-chat", "id": "gpt-3.5-turbo"},
}
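# Pipeline cache and optional OpenAI client (created only when OPENAI_API_KEY is set).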
HF_PIPELINES = {}
OPENAI_KEY = os.getenv("OPENAI_API_KEY")
OPENAI_CLIENT = OpenAI(api_key=OPENAI_KEY) if OPENAI_KEY else None
def get_hf_pipeline(model_id: str):
    """Create/fetch a text-generation pipeline; cache to avoid reloads."""
    if model_id in HF_PIPELINES:
        return HF_PIPELINES[model_id]
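    # device=0 puts the pipeline on the first GPU; -1 keeps everything on CPU.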
    device = 0 if torch.cuda.is_available() else -1
    tok = AutoTokenizer.from_pretrained(model_id, use_fast=True)
    mdl = AutoModelForCausalLM.from_pretrained(
        model_id,
        low_cpu_mem_usage=True,
        torch_dtype=torch.float32,  # CPU-safe
    )
    # Older GPT models lack pad_token; map to EOS
    if tok.pad_token_id is None and tok.eos_token_id is not None:
        tok.pad_token = tok.eos_token
    gen = pipeline("text-generation", model=mdl, tokenizer=tok, device=device)
    HF_PIPELINES[model_id] = gen
    return gen
def generate_stream(model_choice, prompt, max_new_tokens, temperature, top_p, seed):
    """Stream tokens for both HF and OpenAI to improve perceived latency."""
    prompt = (prompt or "").strip()
    if not prompt:
        yield "Please enter a prompt."
        return

    info = MODEL_MAP[model_choice]
    kind = info["kind"]
    model_id = info["id"]

    try:
        if seed is not None and int(seed) >= 0:
            torch.manual_seed(int(seed))
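        # Local Hugging Face path: stream tokens with TextIteratorStreamer while generate() runs in a worker thread.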
if kind == "hf": | |
gen = get_hf_pipeline(model_id) | |
tok = gen.tokenizer | |
mdl = gen.model | |
streamer = TextIteratorStreamer(tok, skip_prompt=True, skip_special_tokens=True) | |
inputs = tok(prompt, return_tensors="pt") | |
if torch.cuda.is_available(): | |
inputs = {k: v.to("cuda") for k, v in inputs.items()} | |
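            # Sample only when temperature > 0; otherwise decoding is effectively greedy.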
            generate_kwargs = dict(
                **inputs,
                max_new_tokens=int(max_new_tokens),
                do_sample=float(temperature) > 0.0,
                temperature=max(1e-6, float(temperature)),
                top_p=float(top_p),
                pad_token_id=tok.eos_token_id,
                eos_token_id=tok.eos_token_id,
                streamer=streamer,
            )
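            # generate() blocks, so run it in a background thread and yield partial text as the streamer produces it.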
            thread = Thread(target=mdl.generate, kwargs=generate_kwargs)
            thread.start()

            out = ""
            for token_text in streamer:
                out += token_text
                yield out
            return
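        # OpenAI path: use the Chat Completions streaming API and accumulate content deltas.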
if kind == "openai-chat": | |
if OPENAI_CLIENT is None: | |
yield "⚠️ To use GPT-3.5, set OPENAI_API_KEY in Space (Settings → Variables & secrets)." | |
return | |
stream = OPENAI_CLIENT.chat.completions.create( | |
model=model_id, | |
messages=[{"role": "user", "content": prompt}], | |
max_tokens=int(max_new_tokens), | |
temperature=float(temperature), | |
top_p=float(top_p), | |
stream=True, | |
) | |
out = "" | |
for chunk in stream: | |
delta = "" | |
try: | |
delta = chunk.choices[0].delta.content or "" | |
except Exception: | |
delta = getattr(chunk.choices[0], "text", "") or "" | |
if delta: | |
out += delta | |
yield out | |
return | |
yield f"Unknown model kind: {kind}" | |
except Exception as e: | |
yield f"❌ Error from {model_choice} ({model_id}): {str(e)}" | |
def maybe_warn(choice):
    info = MODEL_MAP[choice]
    needs_key = (info["kind"] == "openai-chat") and (OPENAI_CLIENT is None)
    if needs_key:
        return gr.update(value="⚠️ GPT-3.5 requires OPENAI_API_KEY in Space secrets.", visible=True)
    return gr.update(visible=False)
# -------- UI --------
CSS = ".gradio-container{max-width:960px;margin:0 auto;}"

with gr.Blocks(title="Mini GPT Playground", css=CSS) as demo:
    gr.Markdown(
        """
        # Mini GPT Playground
        Type a prompt and choose a model.
        **Local (HF):** GPT-1 / GPT-2 / DistilGPT-2 — run in this Space container.
        **OpenAI (API):** GPT-3.5 — requires `OPENAI_API_KEY`.
        *(Tip: DistilGPT-2 is much faster on CPU.)*
        """
    )

    with gr.Row():
        model_choice = gr.Dropdown(MODEL_OPTIONS, value="DistilGPT-2 (distilgpt2) - local (fast)", label="Model")
        max_new_tokens = gr.Slider(1, 512, value=96, step=1, label="Max new tokens")

    with gr.Row():
        temperature = gr.Slider(0.0, 2.0, value=0.8, step=0.05, label="Temperature")
        top_p = gr.Slider(0.1, 1.0, value=0.9, step=0.01, label="Top-p")
        seed = gr.Number(value=42, precision=0, label="Seed (≥0 to fix sampling)")

    prompt = gr.Textbox(lines=6, label="Prompt", placeholder="Write a short story about a curious robot…")
    warn = gr.Markdown("", visible=False)
    generate_btn = gr.Button("Generate", variant="primary")
    output = gr.Textbox(lines=12, label="Output")
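    # Wire events: warn when an API-key model is selected; stream generations into the output box.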
    model_choice.change(maybe_warn, inputs=[model_choice], outputs=[warn])
    generate_btn.click(
        fn=generate_stream,
        inputs=[model_choice, prompt, max_new_tokens, temperature, top_p, seed],
        outputs=[output],
    )
# -------- Spaces-friendly launch (no custom port) --------
try:
    demo = demo.queue(max_size=8)  # keep small on 2 vCPU
except TypeError:
    pass

demo.launch()  # don't pass server_port; Spaces sets it