daltron committed
Commit 738f792 · verified · Parent: 8a0c1c1

Update app.py

Files changed (1)
  1. app.py +18 -33
app.py CHANGED
@@ -36,10 +36,8 @@ MODEL_MAP = {
     "GPT-3.5 (gpt-3.5-turbo) - OpenAI": {"kind": "openai-chat", "id": "gpt-3.5-turbo"},
 }
 
-# Cache for loaded Hugging Face models/pipelines
 HF_PIPELINES = {}
 
-# OpenAI client (only if key exists)
 OPENAI_KEY = os.getenv("OPENAI_API_KEY")
 OPENAI_CLIENT = OpenAI(api_key=OPENAI_KEY) if OPENAI_KEY else None
 
@@ -55,19 +53,14 @@ def get_hf_pipeline(model_id: str):
     mdl = AutoModelForCausalLM.from_pretrained(
         model_id,
         low_cpu_mem_usage=True,
-        torch_dtype=torch.float32,  # CPU-safe
+        torch_dtype=torch.float32,
     )
 
     # Some older models (e.g., GPT-1/2) have no pad token
     if tok.pad_token_id is None and tok.eos_token_id is not None:
         tok.pad_token = tok.eos_token
 
-    gen = pipeline(
-        "text-generation",
-        model=mdl,
-        tokenizer=tok,
-        device=device,
-    )
+    gen = pipeline("text-generation", model=mdl, tokenizer=tok, device=device)
     HF_PIPELINES[model_id] = gen
     return gen
 
@@ -92,9 +85,7 @@ def generate_stream(model_choice, prompt, max_new_tokens, temperature, top_p, seed):
     tok = gen.tokenizer
     mdl = gen.model
 
-    streamer = TextIteratorStreamer(
-        tok, skip_prompt=True, skip_special_tokens=True
-    )
+    streamer = TextIteratorStreamer(tok, skip_prompt=True, skip_special_tokens=True)
 
     inputs = tok(prompt, return_tensors="pt")
     if torch.cuda.is_available():
@@ -111,7 +102,6 @@ def generate_stream(model_choice, prompt, max_new_tokens, temperature, top_p, seed):
         streamer=streamer,
     )
 
-    # Run generation in a thread so we can iterate streamer
    thread = Thread(target=mdl.generate, kwargs=generate_kwargs)
     thread.start()
 
@@ -191,7 +181,6 @@ with gr.Blocks(title="Mini GPT Playground") as demo:
 
     model_choice.change(maybe_warn, inputs=[model_choice], outputs=[warn])
 
-    # Streamed generation
     generate_btn.click(
         fn=generate_stream,
         inputs=[model_choice, prompt, max_new_tokens, temperature, top_p, seed],
@@ -199,30 +188,26 @@ with gr.Blocks(title="Mini GPT Playground") as demo:
     )
 
 # -------------------------
-# Robust initialization for HF Spaces
+# Spaces-friendly init
 # -------------------------
-# 1) Try to mount into a FastAPI app (works for "Python" Spaces)
-# 2) Otherwise, fall back to launching Gradio directly
+# 1) Expose a FastAPI app if running in a Python SDK Space (optional)
 app = None
-if os.getenv("SPACE_ID"):
+try:
+    # If FastAPI is available, provide an app so Python SDK Spaces can import it.
+    from fastapi import FastAPI
+    app = FastAPI()
     try:
-        from fastapi import FastAPI
-        app = FastAPI()
-        # Optional: small queue (if available in your Gradio version)
-        try:
-            demo = demo.queue(max_size=8)
-        except TypeError:
-            pass
-        app = gr.mount_gradio_app(app, demo, path="/")
-    except Exception:
-        # FastAPI not available or mount failed; we'll rely on launch below when run locally
-        app = None
-
-# For local dev or plain Gradio Spaces
+        demo = demo.queue(max_size=8)
+    except TypeError:
+        pass
+    app = gr.mount_gradio_app(app, demo, path="/")
+except Exception:
+    app = None  # fine on Gradio SDK Spaces
+
+# 2) For local runs / Gradio SDK Spaces: do not set server_port; let Gradio pick the env port.
 if __name__ == "__main__":
-    # Version-safe queue/launch
     try:
         demo = demo.queue(max_size=8)
     except TypeError:
         pass
-    demo.launch(server_name="0.0.0.0", server_port=int(os.getenv("PORT", "7860")))
+    demo.launch()  # no server_port; avoids port collisions on Spaces
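
A few notes on the code this commit touches. The pad-token fallback kept in get_hf_pipeline exists because GPT-2-family tokenizers ship without a pad token, and several generation paths expect pad_token_id to be set. A minimal check (the gpt2 checkpoint here is only an example, not one of the app's MODEL_MAP entries):

from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("gpt2")  # example checkpoint only
print(tok.pad_token_id)  # None: GPT-2 has no pad token out of the box

# Same fallback as in app.py: reuse the EOS token for padding.
if tok.pad_token_id is None and tok.eos_token_id is not None:
    tok.pad_token = tok.eos_token
print(tok.pad_token_id == tok.eos_token_id)  # True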
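
The streaming hunks rely on the fact that generate() blocks until it finishes: app.py runs it in a worker thread and iterates a TextIteratorStreamer on the caller's side, yielding decoded text as tokens arrive. A self-contained sketch of the same pattern, with sshleifer/tiny-gpt2 as an illustrative stand-in model:

from threading import Thread

from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

model_id = "sshleifer/tiny-gpt2"  # tiny illustrative model, not from the app
tok = AutoTokenizer.from_pretrained(model_id)
mdl = AutoModelForCausalLM.from_pretrained(model_id)

streamer = TextIteratorStreamer(tok, skip_prompt=True, skip_special_tokens=True)
inputs = tok("Hello", return_tensors="pt")

# generate() blocks, so run it in a thread and consume the streamer here.
thread = Thread(
    target=mdl.generate,
    kwargs=dict(**inputs, max_new_tokens=20, streamer=streamer),
)
thread.start()
for chunk in streamer:
    print(chunk, end="", flush=True)
thread.join()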
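
The init rewrite means the module now always tries to expose a FastAPI app instead of gating on SPACE_ID, and the launch path drops the hard-coded server_name/server_port so Gradio can pick up whatever port Spaces assigns. A hypothetical local smoke test for the mount (assumes uvicorn is installed and the file is saved as app.py; neither is part of the commit):

import uvicorn  # assumed installed; not referenced by the commit

from app import app, demo

if app is not None:
    # The FastAPI mount succeeded: the Gradio UI is served at "/".
    uvicorn.run(app, host="0.0.0.0", port=7860)
else:
    # FastAPI was unavailable: fall back to plain Gradio, as app.py itself does.
    demo.launch()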