daltron committed
Commit 44db4cf · verified · 1 Parent(s): 0388965

Update app.py

Files changed (1)
  1. app.py  +16 -38
app.py CHANGED
@@ -10,18 +10,14 @@ from transformers import (
 )
 from openai import OpenAI
 
-# -------------------------
-# Runtime tuning for small CPU Spaces
-# -------------------------
+# -------- Runtime tuning for tiny CPU Spaces --------
 try:
     torch.set_num_threads(min(2, os.cpu_count() or 2))
     torch.set_num_interop_threads(1)
 except Exception:
     pass
 
-# -------------------------
-# Model choices
-# -------------------------
+# -------- Model choices --------
 MODEL_OPTIONS = [
     "GPT-1 (openai-gpt) - local",
     "GPT-2 (gpt2) - local",
@@ -43,7 +39,7 @@ OPENAI_CLIENT = OpenAI(api_key=OPENAI_KEY) if OPENAI_KEY else None
 
 
 def get_hf_pipeline(model_id: str):
-    """Create/fetch a lightweight text-generation pipeline for CPU/GPU with cached weights."""
+    """Create/fetch a text-generation pipeline; cache to avoid reloads."""
     if model_id in HF_PIPELINES:
         return HF_PIPELINES[model_id]
 
@@ -53,10 +49,10 @@ def get_hf_pipeline(model_id: str):
     mdl = AutoModelForCausalLM.from_pretrained(
         model_id,
         low_cpu_mem_usage=True,
-        torch_dtype=torch.float32,
+        torch_dtype=torch.float32,  # CPU-safe
     )
 
-    # Some older models (e.g., GPT-1/2) have no pad token
+    # Older GPT models lack pad_token; map to EOS
     if tok.pad_token_id is None and tok.eos_token_id is not None:
         tok.pad_token = tok.eos_token
 
@@ -66,7 +62,7 @@ def get_hf_pipeline(model_id: str):
 
 
 def generate_stream(model_choice, prompt, max_new_tokens, temperature, top_p, seed):
-    """Stream tokens for both HF and OpenAI for faster perceived latency."""
+    """Stream tokens for both HF and OpenAI to improve perceived latency."""
     prompt = (prompt or "").strip()
     if not prompt:
         yield "Please enter a prompt."
@@ -113,7 +109,7 @@ def generate_stream(model_choice, prompt, max_new_tokens, temperature, top_p, se
 
     if kind == "openai-chat":
         if OPENAI_CLIENT is None:
-            yield "⚠️ To use GPT-3.5, set OPENAI_API_KEY in your Space (Settings → Variables & secrets)."
+            yield "⚠️ To use GPT-3.5, set OPENAI_API_KEY in Space (Settings → Variables & secrets)."
            return
 
        stream = OPENAI_CLIENT.chat.completions.create(
@@ -151,10 +147,10 @@ def maybe_warn(choice):
     return gr.update(visible=False)
 
 
-# -------------------------
-# UI
-# -------------------------
-with gr.Blocks(title="Mini GPT Playground") as demo:
+# -------- UI --------
+CSS = ".gradio-container{max-width:960px;margin:0 auto;}"
+
+with gr.Blocks(title="Mini GPT Playground", css=CSS) as demo:
     gr.Markdown(
         """
         # Mini GPT Playground
@@ -180,34 +176,16 @@ with gr.Blocks(title="Mini GPT Playground") as demo:
     output = gr.Textbox(lines=12, label="Output")
 
     model_choice.change(maybe_warn, inputs=[model_choice], outputs=[warn])
-
     generate_btn.click(
         fn=generate_stream,
         inputs=[model_choice, prompt, max_new_tokens, temperature, top_p, seed],
         outputs=[output],
     )
 
-# -------------------------
-# Spaces-friendly init
-# -------------------------
-# 1) Expose FastAPI app if running in a Python SDK Space (optional)
-app = None
 try:
-    # If FastAPI is available, we provide an app so Python SDK Spaces can import it.
-    from fastapi import FastAPI
-    app = FastAPI()
-    try:
-        demo = demo.queue(max_size=8)
-    except TypeError:
-        pass
-    app = gr.mount_gradio_app(app, demo, path="/")
-except Exception:
-    app = None  # fine on Gradio SDK Spaces
+    demo = demo.queue(max_size=8)  # keep small on 2 vCPU
+except TypeError:
+    pass
 
-# 2) For local runs / Gradio SDK Spaces: DO NOT set server_port; let Gradio pick the env port.
-if __name__ == "__main__":
-    try:
-        demo = demo.queue(max_size=8)
-    except TypeError:
-        pass
-    demo.launch()  # no server_port — avoids port collisions on Spaces
+demo.launch()  # don't pass server_port; Spaces sets it
 
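For the `openai-chat` branch, `generate_stream` consumes the stream returned by `OPENAI_CLIENT.chat.completions.create(..., stream=True)` and yields progressively longer text so Gradio can refresh the output box. A hedged sketch of such a consumer follows; the model name `gpt-3.5-turbo`, the message wrapping, and the default sampling values are assumptions.

```python
import os
from openai import OpenAI

client = OpenAI(api_key=os.environ["OPENAI_API_KEY"])

def stream_chat(prompt: str, max_new_tokens: int = 128,
                temperature: float = 0.7, top_p: float = 0.9):
    """Yield the growing response text chunk by chunk."""
    stream = client.chat.completions.create(
        model="gpt-3.5-turbo",  # assumed model name for illustration
        messages=[{"role": "user", "content": prompt}],
        max_tokens=max_new_tokens,
        temperature=temperature,
        top_p=top_p,
        stream=True,
    )
    text = ""
    for chunk in stream:
        if not chunk.choices:
            continue  # some chunks carry no choices
        delta = chunk.choices[0].delta.content or ""  # content may be None
        text += delta
        yield text  # each yield refreshes the Gradio Textbox

# Example (requires a valid key):
# for partial in stream_chat("Write a haiku about CPUs"):
#     print(partial)
```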
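The final hunk drops the FastAPI mounting logic in favor of a plain `queue(...)` + `launch()` at module level, letting Gradio pick the port from the environment on Spaces. A small sketch of that pattern, with a placeholder Blocks body standing in for the real UI:

```python
import gradio as gr

with gr.Blocks(title="Mini GPT Playground") as demo:
    gr.Markdown("placeholder UI")  # stand-in for the real controls

# Bound the request queue so a small CPU Space is not overwhelmed
try:
    demo = demo.queue(max_size=8)
except TypeError:
    pass  # older Gradio releases do not accept max_size

# Let Gradio choose the port (Spaces sets GRADIO_SERVER_PORT); never hard-code it
demo.launch()
```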