tasal9 committed · verified
Commit b88833e · Parent(s): 0448449

Update app.py

Files changed (1): app.py +102 -30
app.py CHANGED
@@ -15,7 +15,7 @@ import importlib
 
 # ---------------- Configuration ----------------
 MODEL_ID = os.getenv("MODEL_ID", "tasal9/ZamAI-mT5-Pashto")
-CACHE_DIR = os.getenv("HF_HOME", None)
+CACHE_DIR = os.getenv("HF_HOME", None)  # optional cache dir for transformers
 HEALTH_PORT = int(os.getenv("HEALTH_PORT", "8080"))
 GRADIO_HOST = os.getenv("GRADIO_HOST", "0.0.0.0")
 GRADIO_PORT = int(os.getenv("GRADIO_PORT", "7860"))
@@ -23,7 +23,8 @@ DEFAULT_MAX_NEW_TOKENS = int(os.getenv("DEFAULT_MAX_NEW_TOKENS", "128"))
 
 
 # ---------------- Logging ----------------
-logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s")
+LOG_LEVEL = os.getenv("LOG_LEVEL", "INFO").upper()
+logging.basicConfig(level=LOG_LEVEL, format="%(asctime)s %(levelname)s %(message)s")
 logger = logging.getLogger("zamai-app")
 
 
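
Note: logging.basicConfig accepts level names as strings as well as numeric constants, so the LOG_LEVEL value read from the environment can be passed straight through. A minimal standard-library check of that behavior:

    import logging

    # "debug" from the environment upper-cases to "DEBUG", which the logging
    # module resolves to logging.DEBUG internally.
    logging.basicConfig(level="DEBUG", format="%(asctime)s %(levelname)s %(message)s")
    logging.getLogger("zamai-app").debug("visible when LOG_LEVEL=DEBUG")
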
@@ -40,6 +41,7 @@ SAMPLE_INSTRUCTIONS = [
 
 
 def _start_health_server(port: int):
+    """Start a tiny HTTP server that responds 200 to /health on a background thread."""
     class HealthHandler(http.server.SimpleHTTPRequestHandler):
         def do_GET(self):
             if self.path == "/health":
@@ -52,14 +54,19 @@ def _start_health_server(port: int):
             self.end_headers()
 
     def _serve():
-        with socketserver.TCPServer(("", int(port)), HealthHandler) as httpd:
-            logger.info("Health endpoint listening on port %s", port)
-            httpd.serve_forever()
+        try:
+            with socketserver.TCPServer(("", int(port)), HealthHandler) as httpd:
+                logger.info("Health endpoint listening on port %s", port)
+                httpd.serve_forever()
+        except Exception as e:
+            logger.exception("Health server failed: %s", e)
 
-    threading.Thread(target=_serve, daemon=True).start()
+    t = threading.Thread(target=_serve, daemon=True)
+    t.start()
 
 
 def _detect_device() -> int:
+    # return device id for transformers pipeline: -1 for CPU or 0..N for CUDA
     try:
         if torch.cuda.is_available():
             logger.info("CUDA available; using GPU device 0")
@@ -75,8 +82,40 @@ def get_generator(model_id: str = MODEL_ID, cache_dir: Optional[str] = CACHE_DIR
     device = _detect_device()
     logger.info("Loading tokenizer and model: %s (device=%s)", model_id, device)
 
-    tokenizer = AutoTokenizer.from_pretrained(model_id, cache_dir=cache_dir, use_fast=True)
-    gen = pipeline("text2text-generation", model=model_id, tokenizer=tokenizer, device=device)
+    tokenizer = None
+    local_model_path = None
+    try:
+        hf = importlib.import_module("huggingface_hub")
+        snapshot_download = getattr(hf, "snapshot_download", None)
+        if snapshot_download:
+            try:
+                logger.info("Attempting to snapshot_download model %s to cache_dir=%s", model_id, cache_dir)
+                local_model_path = snapshot_download(repo_id=model_id, cache_dir=cache_dir, repo_type="model")
+                if local_model_path:
+                    local_model_path = str(local_model_path)
+                    logger.info("Model snapshot downloaded to %s", local_model_path)
+            except Exception as e:
+                logger.warning("snapshot_download failed for %s: %s", model_id, e)
+                local_model_path = None
+    except Exception:
+        logger.debug("huggingface_hub not available; falling back to AutoTokenizer.from_pretrained")
+
+    try:
+        if local_model_path:
+            tokenizer = AutoTokenizer.from_pretrained(local_model_path, use_fast=False, cache_dir=cache_dir)
+        else:
+            tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=False, cache_dir=cache_dir)
+        logger.info("Loaded tokenizer for %s", model_id)
+    except Exception as e2:
+        logger.exception("Failed to load tokenizer for %s: %s", model_id, e2)
+        raise
+
+    gen = pipeline(
+        "text2text-generation",
+        model=model_id,
+        tokenizer=tokenizer,
+        device=device,
+    )
    return gen
 
 
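
predict() below refers to this as the cached pipeline, so the load cost should be paid once (the memoization itself sits outside this hunk). A hypothetical direct use of the generator, assuming the model weights are reachable:

    gen = get_generator()  # first call downloads and loads tasal9/ZamAI-mT5-Pashto
    out = gen("سلام", max_new_tokens=32)  # "hello"
    print(out[0]["generated_text"])
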
@@ -88,28 +127,47 @@ def predict(instruction: str,
             temperature: float,
             top_p: float,
             num_return_sequences: int):
-
+    """Generate text using the cached pipeline and return output or error message."""
     if not instruction or not instruction.strip():
-        return "⚠️ مهرباني وکړئ یوه لارښوونه ولیکئ."
+        return "⚠️ مهرباني وکړئ یوه لارښوونه ولیکئ."  # please provide an instruction
 
-    # Just concatenate instruction + input if provided
+    # Build a simple prompt: instruction (+ input if provided)
     prompt = instruction.strip()
-    if input_text:
+    if input_text and input_text.strip():
         prompt += "\n" + input_text.strip()
 
+    def _filter_generation_kwargs(kwargs: dict) -> dict:
+        allowed = {
+            "max_new_tokens",
+            "num_beams",
+            "do_sample",
+            "temperature",
+            "top_p",
+            "num_return_sequences",
+        }
+        return {k: v for k, v in kwargs.items() if k in allowed}
+
     try:
         gen = get_generator()
-        outputs = gen(
-            prompt,
-            max_new_tokens=int(max_new_tokens),
-            num_beams=int(num_beams) if not do_sample else 1,
-            do_sample=do_sample,
-            temperature=float(temperature),
-            top_p=float(top_p),
-            num_return_sequences=max(1, int(num_return_sequences)),
-        )
-
-        texts = [out["generated_text"].strip() for out in outputs]
+        gen_kwargs = {
+            "max_new_tokens": int(max_new_tokens),
+            "num_beams": int(num_beams) if not do_sample else 1,
+            "do_sample": bool(do_sample),
+            "temperature": float(temperature),
+            "top_p": float(top_p),
+            "num_return_sequences": max(1, int(num_return_sequences)),
+        }
+
+        gen_kwargs = _filter_generation_kwargs(gen_kwargs)
+        outputs = gen(prompt, **gen_kwargs)
+
+        texts = []
+        for out in outputs if isinstance(outputs, list) else [outputs]:
+            text = out.get("generated_text", "").strip()
+            texts.append(text)
+
+        if not texts:
+            return "⚠️ No response generated."
         return "\n\n---\n\n".join(texts)
 
     except Exception as e:
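
The _filter_generation_kwargs allowlist keeps UI-supplied options from leaking unsupported keys into the pipeline call, and num_beams is forced to 1 whenever sampling is on, since beam search and sampling are treated as alternatives here. A hypothetical call that bypasses the Gradio UI, using keyword arguments so the parameter order in the full signature (not all of it is visible in this hunk) does not matter:

    reply = predict(
        instruction="یوه لنډه کیسه ولیکه",  # "write a short story"
        input_text="",
        max_new_tokens=64,
        num_beams=1,
        do_sample=True,
        temperature=0.9,
        top_p=0.9,
        num_return_sequences=1,
    )
    print(reply)
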
@@ -119,7 +177,13 @@ def predict(instruction: str,
 
 def build_ui():
     with gr.Blocks() as demo:
-        gr.Markdown("# ZamAI mT5 Pashto Demo")
+        gr.Markdown(
+            """
+            # ZamAI mT5 Pashto Demo
+            اپلیکیشن **ZamAI-mT5-Pashto** د پښتو لارښوونو لپاره.
+            لاندې تنظیمات بدل کړئ او لارښوونه ولیکئ ترڅو ځواب ترلاسه کړئ.
+            """
+        )
 
         with gr.Row():
             with gr.Column(scale=2):
@@ -129,17 +193,21 @@ def build_ui():
                     value=SAMPLE_INSTRUCTIONS[0],
                     interactive=True,
                 )
-                instruction_textbox = gr.Textbox(lines=3, placeholder="دلته لارښوونه ولیکئ...", label="لارښوونه")
+                instruction_textbox = gr.Textbox(
+                    lines=3,
+                    placeholder="دلته لارښوونه ولیکئ...",
+                    label="لارښوونه",
+                )
                 input_text = gr.Textbox(lines=2, placeholder="اختیاري متن...", label="متن")
                 output = gr.Textbox(label="ځواب", interactive=False, lines=8)
                 generate_btn = gr.Button("جوړول", variant="primary")
 
             with gr.Column(scale=1):
                 gr.Markdown("### د تولید تنظیمات")
-                max_new_tokens = gr.Slider(16, 512, value=DEFAULT_MAX_NEW_TOKENS, step=1, label="اعظمي نوي ټوکنونه")
-                num_beams = gr.Slider(1, 8, value=2, step=1, label="شمیر شعاعونه")
-                do_sample = gr.Checkbox(label="نمونې فعال کړئ", value=True)
-                temperature = gr.Slider(0.1, 2.0, value=1.0, step=0.05, label="تودوخه")
+                max_new_tokens = gr.Slider(16, 512, value=DEFAULT_MAX_NEW_TOKENS, step=1, label="اعظمي نوي ټوکنونه (max_new_tokens)")
+                num_beams = gr.Slider(1, 8, value=2, step=1, label="شمیر شعاعونه (num_beams)")
+                do_sample = gr.Checkbox(label="نمونې فعال کړئ (do_sample)", value=True)
+                temperature = gr.Slider(0.1, 2.0, value=1.0, step=0.05, label="تودوخه (temperature)")
                 top_p = gr.Slider(0.1, 1.0, value=0.9, step=0.01, label="Top-p")
                 num_return_sequences = gr.Slider(1, 4, value=1, step=1, label="د راګرځېدونکو تسلسلو شمېر")
 
@@ -156,6 +224,10 @@ def build_ui():
 
 if __name__ == "__main__":
     logger.info("Starting ZamAI mT5 Pashto Demo (model=%s)", MODEL_ID)
-    _start_health_server(HEALTH_PORT)
+    try:
+        _start_health_server(HEALTH_PORT)
+    except Exception:
+        logger.exception("Failed to start health server")
+
     demo = build_ui()
     demo.launch(server_name=GRADIO_HOST, server_port=GRADIO_PORT)
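
With this change a failure to start the health server (for example a port conflict on HEALTH_PORT) is logged instead of aborting startup, so the Gradio app on GRADIO_PORT still comes up. Running locally stays `python app.py`, with MODEL_ID, LOG_LEVEL, HEALTH_PORT, GRADIO_HOST, and GRADIO_PORT all overridable via the environment.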
 