daltron commited on
Commit
fb73398
·
verified ·
1 Parent(s): fe6f7ca

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +40 -58
app.py CHANGED
@@ -4,43 +4,45 @@ import torch
4
  from transformers import pipeline
5
  from openai import OpenAI
6
 
7
- # ------------------------------------------------------------
8
- # Model registry
9
- # ------------------------------------------------------------
10
  MODEL_OPTIONS = [
11
- "GPT-1 (OpenAI GPT) - local",
12
- "GPT-2 (small) - local",
13
- "GPT-3 (text-davinci-003) - OpenAI",
14
  "GPT-3.5 (gpt-3.5-turbo) - OpenAI",
15
  ]
16
 
17
  MODEL_MAP = {
18
- "GPT-1 (OpenAI GPT) - local": {"kind": "hf", "id": "openai-gpt"},
19
- "GPT-2 (small) - local": {"kind": "hf", "id": "gpt2"},
20
- "GPT-3 (text-davinci-003) - OpenAI": {"kind": "openai-completion", "id": "text-davinci-003"},
21
  "GPT-3.5 (gpt-3.5-turbo) - OpenAI": {"kind": "openai-chat", "id": "gpt-3.5-turbo"},
22
  }
23
 
24
- # Cache for local HF pipelines
25
  HF_PIPELINES = {}
26
 
27
- # OpenAI client (only initialized if key exists)
28
  OPENAI_KEY = os.getenv("OPENAI_API_KEY")
29
  OPENAI_CLIENT = OpenAI(api_key=OPENAI_KEY) if OPENAI_KEY else None
30
 
31
 
32
  def get_hf_pipeline(model_id: str):
33
- """Create or fetch a cached text-generation pipeline for a HF model."""
34
  if model_id in HF_PIPELINES:
35
  return HF_PIPELINES[model_id]
36
 
37
  device = 0 if torch.cuda.is_available() else -1
38
- gen = pipeline("text-generation", model=model_id, device=device)
 
 
 
 
39
  HF_PIPELINES[model_id] = gen
40
  return gen
41
 
42
 
43
- def generate(model_choice: str, prompt: str, max_new_tokens: int, temperature: float, top_p: float, seed: int):
44
  if not prompt.strip():
45
  return "Please enter a prompt."
46
 
@@ -48,46 +50,34 @@ def generate(model_choice: str, prompt: str, max_new_tokens: int, temperature: f
48
  kind = info["kind"]
49
  model_id = info["id"]
50
 
51
- if seed is not None and seed >= 0:
52
- torch.manual_seed(seed)
53
 
54
  try:
55
  if kind == "hf":
56
  gen = get_hf_pipeline(model_id)
57
  out = gen(
58
  prompt,
59
- max_new_tokens=max_new_tokens,
60
  do_sample=temperature > 0,
61
- temperature=max(1e-6, temperature),
62
- top_p=top_p,
63
  pad_token_id=gen.tokenizer.eos_token_id,
64
- return_full_text=False,
65
  )
66
  return out[0]["generated_text"]
67
 
68
- if kind == "openai-completion":
69
- if OPENAI_CLIENT is None:
70
- return "⚠️ OPENAI_API_KEY not set. Add it in your Space secrets to use GPT-3."
71
- resp = OPENAI_CLIENT.completions.create(
72
- model=model_id,
73
- prompt=prompt,
74
- max_tokens=max_new_tokens,
75
- temperature=temperature,
76
- top_p=top_p,
77
- )
78
- return resp.choices[0].text.strip()
79
-
80
  if kind == "openai-chat":
81
  if OPENAI_CLIENT is None:
82
- return "⚠️ OPENAI_API_KEY not set. Add it in your Space secrets to use GPT-3.5."
83
  resp = OPENAI_CLIENT.chat.completions.create(
84
  model=model_id,
85
  messages=[{"role": "user", "content": prompt}],
86
- max_tokens=max_new_tokens,
87
- temperature=temperature,
88
- top_p=top_p,
89
  )
90
- return resp.choices[0].message.content.strip()
91
 
92
  return f"Unknown model kind: {kind}"
93
 
@@ -95,20 +85,22 @@ def generate(model_choice: str, prompt: str, max_new_tokens: int, temperature: f
95
  return f"❌ Error from {model_choice} ({model_id}): {str(e)}"
96
 
97
 
98
- def toggle_openai_visibility(choice):
99
- """Show a helpful banner if OpenAI key is missing and user picked an OpenAI model."""
100
  info = MODEL_MAP[choice]
101
- if "openai" in info["kind"] and OPENAI_CLIENT is None:
102
- return gr.update(value="⚠️ To use GPT-3 / GPT-3.5, set OPENAI_API_KEY in your Space secrets.", visible=True)
 
103
  return gr.update(visible=False)
104
 
105
 
106
- with gr.Blocks(title="GPT Playground: GPT-1 / GPT-2 / GPT-3 / GPT-3.5") as demo:
107
  gr.Markdown(
108
  """
109
- # Simple GPT Playground
110
- Type a prompt, pick a model, and generate a continuation or reply.
111
- **Local models** (GPT-1 / GPT-2) run with `transformers`. **OpenAI models** (GPT-3 / GPT-3.5) require `OPENAI_API_KEY`.
 
112
  """
113
  )
114
 
@@ -118,29 +110,19 @@ with gr.Blocks(title="GPT Playground: GPT-1 / GPT-2 / GPT-3 / GPT-3.5") as demo:
118
  with gr.Row():
119
  temperature = gr.Slider(0.0, 2.0, value=0.8, step=0.05, label="Temperature")
120
  top_p = gr.Slider(0.1, 1.0, value=0.95, step=0.01, label="Top-p")
121
- seed = gr.Number(value=42, precision=0, label="Seed (set ≥0 to fix sampling)")
122
 
123
  prompt = gr.Textbox(lines=6, label="Prompt", placeholder="Write a short story about a curious robot...")
124
-
125
- warning_md = gr.Markdown("", visible=False)
126
 
127
  generate_btn = gr.Button("Generate", variant="primary")
128
  output = gr.Textbox(lines=12, label="Output")
129
 
130
- model_choice.change(toggle_openai_visibility, inputs=[model_choice], outputs=[warning_md])
131
  generate_btn.click(
132
  generate,
133
  inputs=[model_choice, prompt, max_new_tokens, temperature, top_p, seed],
134
  outputs=[output],
135
  )
136
 
137
- gr.Markdown(
138
- """
139
- ---
140
- **Tips**
141
- - If you see an error on GPT-3: it may no longer be enabled on your account. Try GPT-3.5.
142
- - Local models here are the small baseline versions (`openai-gpt`, `gpt2`) to keep Spaces lightweight.
143
- """
144
- )
145
-
146
  demo.queue(max_size=16).launch()
 
4
  from transformers import pipeline
5
  from openai import OpenAI
6
 
7
+ # -------------------------
8
+ # Model choices
9
+ # -------------------------
10
  MODEL_OPTIONS = [
11
+ "GPT-1 (openai-gpt) - local",
12
+ "GPT-2 (gpt2) - local",
 
13
  "GPT-3.5 (gpt-3.5-turbo) - OpenAI",
14
  ]
15
 
16
  MODEL_MAP = {
17
+ "GPT-1 (openai-gpt) - local": {"kind": "hf", "id": "openai-gpt"},
18
+ "GPT-2 (gpt2) - local": {"kind": "hf", "id": "gpt2"},
 
19
  "GPT-3.5 (gpt-3.5-turbo) - OpenAI": {"kind": "openai-chat", "id": "gpt-3.5-turbo"},
20
  }
21
 
22
+ # Cache pipelines for HF models so we only load once
23
  HF_PIPELINES = {}
24
 
25
+ # OpenAI client (only if key exists)
26
  OPENAI_KEY = os.getenv("OPENAI_API_KEY")
27
  OPENAI_CLIENT = OpenAI(api_key=OPENAI_KEY) if OPENAI_KEY else None
28
 
29
 
30
  def get_hf_pipeline(model_id: str):
31
+ """Create/fetch a lightweight text-generation pipeline for CPU/GPU."""
32
  if model_id in HF_PIPELINES:
33
  return HF_PIPELINES[model_id]
34
 
35
  device = 0 if torch.cuda.is_available() else -1
36
+ gen = pipeline(
37
+ "text-generation",
38
+ model=model_id,
39
+ device=device,
40
+ )
41
  HF_PIPELINES[model_id] = gen
42
  return gen
43
 
44
 
45
+ def generate(model_choice, prompt, max_new_tokens, temperature, top_p, seed):
46
  if not prompt.strip():
47
  return "Please enter a prompt."
48
 
 
50
  kind = info["kind"]
51
  model_id = info["id"]
52
 
53
+ if seed is not None and int(seed) >= 0:
54
+ torch.manual_seed(int(seed))
55
 
56
  try:
57
  if kind == "hf":
58
  gen = get_hf_pipeline(model_id)
59
  out = gen(
60
  prompt,
61
+ max_new_tokens=int(max_new_tokens),
62
  do_sample=temperature > 0,
63
+ temperature=max(1e-6, float(temperature)),
64
+ top_p=float(top_p),
65
  pad_token_id=gen.tokenizer.eos_token_id,
66
+ return_full_text=False, # don't echo the prompt
67
  )
68
  return out[0]["generated_text"]
69
 
 
 
 
 
 
 
 
 
 
 
 
 
70
  if kind == "openai-chat":
71
  if OPENAI_CLIENT is None:
72
+ return "⚠️ To use GPT-3.5, set OPENAI_API_KEY in your Space (Settings Variables & secrets)."
73
  resp = OPENAI_CLIENT.chat.completions.create(
74
  model=model_id,
75
  messages=[{"role": "user", "content": prompt}],
76
+ max_tokens=int(max_new_tokens),
77
+ temperature=float(temperature),
78
+ top_p=float(top_p),
79
  )
80
+ return (resp.choices[0].message.content or "").strip()
81
 
82
  return f"Unknown model kind: {kind}"
83
 
 
85
  return f"❌ Error from {model_choice} ({model_id}): {str(e)}"
86
 
87
 
88
+ def maybe_warn(choice):
89
+ """Show a small banner if user picked GPT-3.5 without an API key set."""
90
  info = MODEL_MAP[choice]
91
+ needs_key = (info["kind"] == "openai-chat") and (OPENAI_CLIENT is None)
92
+ if needs_key:
93
+ return gr.update(value="⚠️ GPT-3.5 requires OPENAI_API_KEY in Space secrets.", visible=True)
94
  return gr.update(visible=False)
95
 
96
 
97
+ with gr.Blocks(title="Mini GPT Playground") as demo:
98
  gr.Markdown(
99
  """
100
+ # Mini GPT Playground
101
+ Type a prompt and choose a model.
102
+ **Local (HF):** GPT-1 / GPT-2 runs in this Space container with `transformers`.
103
+ **OpenAI (API):** GPT-3.5 — requires `OPENAI_API_KEY`.
104
  """
105
  )
106
 
 
110
  with gr.Row():
111
  temperature = gr.Slider(0.0, 2.0, value=0.8, step=0.05, label="Temperature")
112
  top_p = gr.Slider(0.1, 1.0, value=0.95, step=0.01, label="Top-p")
113
+ seed = gr.Number(value=42, precision=0, label="Seed (≥0 to fix sampling)")
114
 
115
  prompt = gr.Textbox(lines=6, label="Prompt", placeholder="Write a short story about a curious robot...")
116
+ warn = gr.Markdown("", visible=False)
 
117
 
118
  generate_btn = gr.Button("Generate", variant="primary")
119
  output = gr.Textbox(lines=12, label="Output")
120
 
121
+ model_choice.change(maybe_warn, inputs=[model_choice], outputs=[warn])
122
  generate_btn.click(
123
  generate,
124
  inputs=[model_choice, prompt, max_new_tokens, temperature, top_p, seed],
125
  outputs=[output],
126
  )
127
 
 
 
 
 
 
 
 
 
 
128
  demo.queue(max_size=16).launch()