daltron committed on
Commit 8e477e5 · verified · 1 Parent(s): fb73398

Update app.py

Files changed (1): app.py +105 -30
app.py CHANGED
@@ -1,25 +1,42 @@
 import os
 import gradio as gr
 import torch
-from transformers import pipeline
+from threading import Thread
+from transformers import (
+    pipeline,
+    AutoTokenizer,
+    AutoModelForCausalLM,
+    TextIteratorStreamer,
+)
 from openai import OpenAI

+# -------------------------
+# Runtime tuning for 2 vCPU Spaces
+# -------------------------
+try:
+    torch.set_num_threads(min(2, os.cpu_count() or 2))
+    torch.set_num_interop_threads(1)
+except Exception:
+    pass
+
 # -------------------------
 # Model choices
 # -------------------------
 MODEL_OPTIONS = [
     "GPT-1 (openai-gpt) - local",
     "GPT-2 (gpt2) - local",
+    "DistilGPT-2 (distilgpt2) - local (fast)",
     "GPT-3.5 (gpt-3.5-turbo) - OpenAI",
 ]

 MODEL_MAP = {
     "GPT-1 (openai-gpt) - local": {"kind": "hf", "id": "openai-gpt"},
     "GPT-2 (gpt2) - local": {"kind": "hf", "id": "gpt2"},
+    "DistilGPT-2 (distilgpt2) - local (fast)": {"kind": "hf", "id": "distilgpt2"},
     "GPT-3.5 (gpt-3.5-turbo) - OpenAI": {"kind": "openai-chat", "id": "gpt-3.5-turbo"},
 }

-# Cache pipelines for HF models so we only load once
+# Cache for loaded Hugging Face models/pipelines
 HF_PIPELINES = {}

 # OpenAI client (only if key exists)
@@ -28,65 +45,118 @@ OPENAI_CLIENT = OpenAI(api_key=OPENAI_KEY) if OPENAI_KEY else None


 def get_hf_pipeline(model_id: str):
-    """Create/fetch a lightweight text-generation pipeline for CPU/GPU."""
+    """Create/fetch a lightweight text-generation pipeline for CPU/GPU with cached weights."""
     if model_id in HF_PIPELINES:
         return HF_PIPELINES[model_id]

     device = 0 if torch.cuda.is_available() else -1
+
+    # Prefer safetensors, load once
+    tok = AutoTokenizer.from_pretrained(model_id, use_fast=True)
+    mdl = AutoModelForCausalLM.from_pretrained(
+        model_id,
+        low_cpu_mem_usage=True,
+        torch_dtype=torch.float32, # CPU-safe
+    )
+
+    # Some older models (e.g., GPT-1/2) have no pad token
+    if tok.pad_token_id is None and tok.eos_token_id is not None:
+        tok.pad_token = tok.eos_token
+
     gen = pipeline(
         "text-generation",
-        model=model_id,
+        model=mdl,
+        tokenizer=tok,
         device=device,
     )
     HF_PIPELINES[model_id] = gen
     return gen


-def generate(model_choice, prompt, max_new_tokens, temperature, top_p, seed):
-    if not prompt.strip():
-        return "Please enter a prompt."
+def generate_stream(model_choice, prompt, max_new_tokens, temperature, top_p, seed):
+    """Stream tokens for both HF and OpenAI for faster perceived latency."""
+    prompt = (prompt or "").strip()
+    if not prompt:
+        yield "Please enter a prompt."
+        return

     info = MODEL_MAP[model_choice]
     kind = info["kind"]
     model_id = info["id"]

-    if seed is not None and int(seed) >= 0:
-        torch.manual_seed(int(seed))
-
     try:
+        if seed is not None and int(seed) >= 0:
+            torch.manual_seed(int(seed))
+
         if kind == "hf":
             gen = get_hf_pipeline(model_id)
-            out = gen(
-                prompt,
+            tok = gen.tokenizer
+            mdl = gen.model
+
+            streamer = TextIteratorStreamer(
+                tok, skip_prompt=True, skip_special_tokens=True
+            )
+
+            inputs = tok(prompt, return_tensors="pt")
+            if torch.cuda.is_available():
+                inputs = {k: v.to("cuda") for k, v in inputs.items()}
+
+            generate_kwargs = dict(
+                **inputs,
                 max_new_tokens=int(max_new_tokens),
-                do_sample=temperature > 0,
+                do_sample=float(temperature) > 0.0,
                 temperature=max(1e-6, float(temperature)),
                 top_p=float(top_p),
-                pad_token_id=gen.tokenizer.eos_token_id,
-                return_full_text=False, # don't echo the prompt
+                pad_token_id=tok.eos_token_id,
+                eos_token_id=tok.eos_token_id,
+                streamer=streamer,
             )
-            return out[0]["generated_text"]
+
+            # Run generation in a thread so we can iterate streamer
+            thread = Thread(target=mdl.generate, kwargs=generate_kwargs)
+            thread.start()
+
+            out = ""
+            for token_text in streamer:
+                out += token_text
+                yield out
+            return

         if kind == "openai-chat":
             if OPENAI_CLIENT is None:
-                return "⚠️ To use GPT-3.5, set OPENAI_API_KEY in your Space (Settings → Variables & secrets)."
-            resp = OPENAI_CLIENT.chat.completions.create(
+                yield "⚠️ To use GPT-3.5, set OPENAI_API_KEY in your Space (Settings → Variables & secrets)."
+                return
+
+            stream = OPENAI_CLIENT.chat.completions.create(
                 model=model_id,
                 messages=[{"role": "user", "content": prompt}],
                 max_tokens=int(max_new_tokens),
                 temperature=float(temperature),
                 top_p=float(top_p),
+                stream=True,
             )

-            return (resp.choices[0].message.content or "").strip()
+            out = ""
+            for chunk in stream:
+                delta = ""
+                try:
+                    # v1 SDK streaming shape
+                    delta = chunk.choices[0].delta.content or ""
+                except Exception:
+                    # fallback if SDK variant differs
+                    delta = getattr(chunk.choices[0], "text", "") or ""
+                if delta:
+                    out += delta
+                    yield out
+            return

-        return f"Unknown model kind: {kind}"
+        yield f"Unknown model kind: {kind}"

     except Exception as e:
-        return f"❌ Error from {model_choice} ({model_id}): {str(e)}"
+        yield f"❌ Error from {model_choice} ({model_id}): {str(e)}"


 def maybe_warn(choice):
-    """Show a small banner if user picked GPT-3.5 without an API key set."""
     info = MODEL_MAP[choice]
     needs_key = (info["kind"] == "openai-chat") and (OPENAI_CLIENT is None)
     if needs_key:
@@ -99,30 +169,35 @@ with gr.Blocks(title="Mini GPT Playground") as demo:
         """
         # Mini GPT Playground
         Type a prompt and choose a model.
-        **Local (HF):** GPT-1 / GPT-2 — runs in this Space container with `transformers`.
-        **OpenAI (API):** GPT-3.5 — requires `OPENAI_API_KEY`.
+        **Local (HF):** GPT-1 / GPT-2 / DistilGPT-2 — runs in this Space container.
+        **OpenAI (API):** GPT-3.5 — requires `OPENAI_API_KEY`.
+        *(Tip: DistilGPT-2 is much faster on CPU.)*
         """
     )

     with gr.Row():
-        model_choice = gr.Dropdown(MODEL_OPTIONS, value=MODEL_OPTIONS[1], label="Model")
-        max_new_tokens = gr.Slider(1, 512, value=128, step=1, label="Max new tokens")
+        model_choice = gr.Dropdown(MODEL_OPTIONS, value="DistilGPT-2 (distilgpt2) - local (fast)", label="Model")
+        max_new_tokens = gr.Slider(1, 512, value=96, step=1, label="Max new tokens") # lower default for speed
     with gr.Row():
         temperature = gr.Slider(0.0, 2.0, value=0.8, step=0.05, label="Temperature")
-        top_p = gr.Slider(0.1, 1.0, value=0.95, step=0.01, label="Top-p")
+        top_p = gr.Slider(0.1, 1.0, value=0.9, step=0.01, label="Top-p")
         seed = gr.Number(value=42, precision=0, label="Seed (≥0 to fix sampling)")

-    prompt = gr.Textbox(lines=6, label="Prompt", placeholder="Write a short story about a curious robot...")
+    prompt = gr.Textbox(lines=6, label="Prompt", placeholder="Write a short story about a curious robot")
     warn = gr.Markdown("", visible=False)

     generate_btn = gr.Button("Generate", variant="primary")
     output = gr.Textbox(lines=12, label="Output")

     model_choice.change(maybe_warn, inputs=[model_choice], outputs=[warn])
+
+    # Streamed generation
     generate_btn.click(
-        generate,
+        fn=generate_stream,
         inputs=[model_choice, prompt, max_new_tokens, temperature, top_p, seed],
         outputs=[output],
     )

-demo.queue(max_size=16).launch()
+# Keep concurrency low on 2 vCPU; smaller queue reduces tail latency
+demo.queue(concurrency_count=1, max_size=8, status_update_rate=75).launch()
+
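The heart of this change is threaded generation with transformers' TextIteratorStreamer. A minimal standalone sketch of that pattern, assuming distilgpt2 on CPU (the model name, prompt, and sampling values below are illustrative, not taken verbatim from the commit):

from threading import Thread
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

tok = AutoTokenizer.from_pretrained("distilgpt2")
mdl = AutoModelForCausalLM.from_pretrained("distilgpt2")

# skip_prompt=True keeps the echoed prompt out of the streamed text
streamer = TextIteratorStreamer(tok, skip_prompt=True, skip_special_tokens=True)
inputs = tok("Write a short story about a curious robot", return_tensors="pt")

# generate() blocks until finished, so it runs in a background thread
# while the main thread reads decoded chunks off the streamer
kwargs = dict(
    **inputs,
    max_new_tokens=64,
    do_sample=True,
    temperature=0.8,
    top_p=0.9,
    pad_token_id=tok.eos_token_id,
    streamer=streamer,
)
Thread(target=mdl.generate, kwargs=kwargs).start()

text = ""
for piece in streamer:
    text += piece
    print(piece, end="", flush=True)

In the Space itself, generate_stream plays the role of this loop: because it is a generator, Gradio streams each yielded string into the Output textbox as it arrives.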