import os
import logging
import threading
import http.server
import socketserver
from functools import lru_cache
from typing import Optional
import gradio as gr
from transformers.pipelines import pipeline
from transformers import AutoTokenizer
import torch
import importlib
import time
# ---------------- Configuration ----------------
MODEL_ID = os.getenv("MODEL_ID", "tasal9/ZamAI-mT5-Pashto")
CACHE_DIR = os.getenv("HF_HOME", None) # optional cache dir for transformers
HEALTH_PORT = int(os.getenv("HEALTH_PORT", "8080"))
GRADIO_HOST = os.getenv("GRADIO_HOST", "0.0.0.0")
GRADIO_PORT = int(os.getenv("GRADIO_PORT", "7860"))
DEFAULT_MAX_NEW_TOKENS = int(os.getenv("DEFAULT_MAX_NEW_TOKENS", "128"))
ECHO_MODE = os.getenv("ECHO_MODE", "off").lower() # default env; UI can override at runtime
OFFLINE_FLAG = os.getenv("OFFLINE", "0").lower() in {"1", "true", "yes"}
if OFFLINE_FLAG:
os.environ["HF_HUB_OFFLINE"] = "1"
def _log_cache_env():
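    """Log the Hugging Face cache-related environment variables for debugging."""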
try:
import huggingface_hub as _hub
hub_cache = getattr(_hub.constants, 'HF_HUB_CACHE', None)
except Exception:
hub_cache = None
logging.info(
"Cache config: HF_HOME=%s TRANSFORMERS_CACHE=%s HF_HUB_OFFLINE=%s hub_cache=%s",
os.getenv("HF_HOME"), os.getenv("TRANSFORMERS_CACHE"), os.getenv("HF_HUB_OFFLINE"), hub_cache
)
# ---------------- Logging ----------------
LOG_LEVEL = os.getenv("LOG_LEVEL", "INFO").upper()
logging.basicConfig(level=LOG_LEVEL, format="%(asctime)s %(levelname)s %(message)s")
logger = logging.getLogger("zamai-app")
# Log the cache configuration only after logging is configured; calling it earlier
# would be swallowed by the root logger's default WARNING level.
_log_cache_env()
# Metrics storage for last real generation
LAST_METRICS: dict[str, float | int | str | None] = {
"latency_sec": None,
"input_tokens": None,
"output_tokens": None,
"num_sequences": None,
"mode": None,
}
# ---------------- Utilities ----------------
SAMPLE_INSTRUCTIONS = [
"په پښتو کې د خپل نوم او د عمر معلومات ولیکئ.",
"د هوا د حالت په اړه لنډ راپور ورکړئ.",
"په پښتو کې یوه لنډه کیسه ولیکئ چې د ښوونځي د ژوند په اړه وي.",
"د خپلو ملګرو لپاره د یوې کوچنۍ پیغام ولیکئ.",
"په پښتو کې د خپل خوښې خواړه تشریح کړئ او ووایاست ولې یې خوښوی.",
"د خپلې سیمې د تاریخي ځایونو په اړه لنډ معلومات ورکړئ.",
"یو ورځني کارنامه ولیکئ چې په کور کې څه کارونه ترسره کوئ."
]
def _start_health_server(port: int):
"""Start a tiny HTTP server that responds 200 to /health on a background thread."""
class HealthHandler(http.server.SimpleHTTPRequestHandler):
def do_GET(self):
if self.path == "/health":
self.send_response(200)
self.send_header("Content-type", "text/plain")
self.end_headers()
self.wfile.write(b"ok")
else:
self.send_response(404)
self.end_headers()
def _serve():
try:
with socketserver.TCPServer(("", int(port)), HealthHandler) as httpd:
logger.info("Health endpoint listening on port %s", port)
httpd.serve_forever()
except Exception as e:
logger.exception("Health server failed: %s", e)
t = threading.Thread(target=_serve, daemon=True)
t.start()
def _detect_device() -> int:
    """Return the device id for the transformers pipeline: 0 for the first CUDA GPU, -1 for CPU."""
try:
if torch.cuda.is_available():
logger.info("CUDA available; using GPU device 0")
return 0
except Exception:
logger.debug("torch.cuda check failed; falling back to CPU")
return -1
# ---------------- Generator (cached) ----------------
@lru_cache(maxsize=1)
def get_generator(model_id: str = MODEL_ID, cache_dir: Optional[str] = CACHE_DIR):
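    """Load and cache the text2text-generation pipeline for the configured model.

    Tries a huggingface_hub snapshot_download first (so offline runs can reuse the
    local cache), then falls back to letting transformers resolve the model id.
    """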
device = _detect_device()
logger.info("Loading tokenizer and model: %s (device=%s)", model_id, device)
tokenizer = None
local_model_path = None
try:
hf = importlib.import_module("huggingface_hub")
snapshot_download = getattr(hf, "snapshot_download", None)
if snapshot_download:
try:
logger.info("Attempting to snapshot_download model %s to cache_dir=%s", model_id, cache_dir)
local_model_path = snapshot_download(repo_id=model_id, cache_dir=cache_dir, repo_type="model")
if local_model_path:
local_model_path = str(local_model_path)
logger.info("Model snapshot downloaded to %s", local_model_path)
except Exception as e:
logger.warning("snapshot_download failed for %s: %s", model_id, e)
local_model_path = None
except Exception:
logger.debug("huggingface_hub not available; falling back to AutoTokenizer.from_pretrained")
try:
if local_model_path:
tokenizer = AutoTokenizer.from_pretrained(local_model_path, use_fast=False, cache_dir=cache_dir)
else:
tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=False, cache_dir=cache_dir)
logger.info("Loaded tokenizer for %s", model_id)
except Exception as e2:
logger.exception("Failed to load tokenizer for %s: %s", model_id, e2)
raise
    # Prefer the locally downloaded snapshot when available so generation also works
    # offline; otherwise let transformers resolve the model id via the cache/hub.
    gen = pipeline(
        "text2text-generation",
        model=local_model_path or model_id,
        tokenizer=tokenizer,
        device=device,
    )
return gen
def predict(instruction: str,
input_text: str,
max_new_tokens: int,
num_beams: int,
do_sample: bool,
temperature: float,
top_p: float,
num_return_sequences: int,
mode: str):
"""Generate text using the cached pipeline and return output or error message."""
if not instruction or not instruction.strip():
return "⚠️ مهرباني وکړئ یوه لارښوونه ولیکئ."
def build_prompt() -> str:
base = instruction.strip()
if input_text and input_text.strip():
return base + "\n" + input_text.strip()
return base
prompt = build_prompt()
active_mode = (mode or "").strip().lower() or ECHO_MODE
if active_mode in ("echo", "useless"):
if active_mode == "echo":
return f"### Prompt\n\n````\n{prompt}\n````\n\n### Output\n\n````\n{prompt}\n````"
return f"### Prompt\n\n````\n{prompt}\n````\n\n### Output\n\nThis is a useless placeholder response."
allowed_keys = {"max_new_tokens", "num_beams", "do_sample", "temperature", "top_p", "num_return_sequences"}
start = time.time()
try:
gen = get_generator()
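        # Map UI values onto generation kwargs; beam search is forced to 1 when
        # sampling so num_beams and do_sample do not conflict.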
raw_kwargs = {
"max_new_tokens": int(max_new_tokens),
"num_beams": int(num_beams) if not do_sample else 1,
"do_sample": bool(do_sample),
"temperature": float(temperature),
"top_p": float(top_p),
"num_return_sequences": max(1, int(num_return_sequences)),
}
gen_kwargs = {k: v for k, v in raw_kwargs.items() if k in allowed_keys}
outputs = gen(prompt, **gen_kwargs)
if not isinstance(outputs, list):
outputs = [outputs]
texts = []
for out in outputs:
if isinstance(out, dict):
text = out.get("generated_text", "").strip()
else:
text = str(out).strip()
if text:
texts.append(text)
if not texts:
LAST_METRICS.update({
"latency_sec": round(time.time() - start, 3),
"input_tokens": None,
"output_tokens": 0,
"num_sequences": 0,
"mode": active_mode,
})
return f"### Prompt\n\n````\n{prompt}\n````\n\n### Output\n\n⚠️ No response generated."
joined = "\n\n---\n\n".join(texts)
# Basic token counting via whitespace split (approximate)
input_tokens = len(prompt.split())
output_tokens = sum(len(t.split()) for t in texts)
LAST_METRICS.update({
"latency_sec": round(time.time() - start, 3),
"input_tokens": input_tokens,
"output_tokens": output_tokens,
"num_sequences": len(texts),
"mode": active_mode,
})
metrics_md = f"\n\n### Metrics\n- Latency: {LAST_METRICS['latency_sec']}s\n- Input tokens (approx): {input_tokens}\n- Output tokens (approx): {output_tokens}\n- Sequences: {len(texts)}"
return f"### Prompt\n\n````\n{prompt}\n````\n\n### Output\n\n{joined}{metrics_md}"
except Exception as e:
logger.exception("Generation failed: %s", e)
return f"⚠️ Generation failed: {e}"
def build_ui():
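    """Build the Gradio Blocks interface for the demo."""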
with gr.Blocks() as demo:
device_label = "GPU" if _detect_device() != -1 else "CPU"
gr.Markdown(
f"""
# ZamAI mT5 Pashto Demo
اپلیکیشن **ZamAI-mT5-Pashto** د پښتو لارښوونو لپاره.
**Device:** {device_label} | **Env Mode:** {ECHO_MODE} | **Offline:** {os.getenv('HF_HUB_OFFLINE','0')}
که د موډ بدلول غواړئ لاندې د Mode selector څخه استفاده وکړئ.
"""
)
with gr.Row():
with gr.Column(scale=2):
instruction_dropdown = gr.Dropdown(
choices=SAMPLE_INSTRUCTIONS,
label="نمونې لارښوونې",
value=SAMPLE_INSTRUCTIONS[0],
interactive=True,
)
instruction_textbox = gr.Textbox(
lines=3,
placeholder="دلته لارښوونه ولیکئ...",
label="لارښوونه",
)
input_text = gr.Textbox(lines=2, placeholder="اختیاري متن...", label="متن")
output = gr.Markdown(label="ځواب")
generate_btn = gr.Button("جوړول", variant="primary")
mode_selector = gr.Dropdown(
choices=["off", "echo", "useless"],
value=ECHO_MODE,
label="Mode (off=real, echo=return prompt, useless=fixed)",
interactive=True,
)
status_box = gr.Markdown(value="Loading status pending...", label="Status")
refresh_status = gr.Button("Refresh Status")
with gr.Column(scale=1):
gr.Markdown("### د تولید تنظیمات")
max_new_tokens = gr.Slider(16, 512, value=DEFAULT_MAX_NEW_TOKENS, step=1, label="اعظمي نوي ټوکنونه (max_new_tokens)")
num_beams = gr.Slider(1, 8, value=2, step=1, label="شمیر شعاعونه (num_beams)")
do_sample = gr.Checkbox(label="نمونې فعال کړئ (do_sample)", value=True)
temperature = gr.Slider(0.1, 2.0, value=1.0, step=0.05, label="تودوخه (temperature)")
top_p = gr.Slider(0.1, 1.0, value=0.9, step=0.01, label="Top-p")
num_return_sequences = gr.Slider(1, 4, value=1, step=1, label="د راګرځېدونکو تسلسلو شمېر")
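        # Copy the selected sample instruction into the editable instruction textbox.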
instruction_dropdown.change(lambda x: x, inputs=instruction_dropdown, outputs=instruction_textbox)
def refresh():
base = f"**Device:** {'GPU' if _detect_device() != -1 else 'CPU'} | **Offline:** {os.getenv('HF_HUB_OFFLINE','0')} | **Env Mode:** {ECHO_MODE}"
if LAST_METRICS.get('latency_sec') is not None:
base += (f"<br>**Last Gen:** latency={LAST_METRICS['latency_sec']}s, "
f"in≈{LAST_METRICS['input_tokens']}, out≈{LAST_METRICS['output_tokens']}, seqs={LAST_METRICS['num_sequences']}")
return base
refresh_status.click(fn=refresh, inputs=None, outputs=status_box)
generate_btn.click(
fn=predict,
inputs=[instruction_textbox, input_text, max_new_tokens, num_beams, do_sample, temperature, top_p, num_return_sequences, mode_selector],
outputs=output,
)
# Model load banner shown after interface loads (async)
def _post_load():
return "✅ Model interface ready. If this is the first run and model wasn't cached, initial generation may still warm up."
demo.load(_post_load, inputs=None, outputs=status_box)
return demo
if __name__ == "__main__":
logger.info("Starting ZamAI mT5 Pashto Demo (model=%s)", MODEL_ID)
try:
_start_health_server(HEALTH_PORT)
except Exception:
logger.exception("Failed to start health server")
demo = build_ui()
    logger.info("HF_HOME=%s TRANSFORMERS_CACHE=%s", os.getenv("HF_HOME"), os.getenv("TRANSFORMERS_CACHE"))
    demo.launch(server_name=GRADIO_HOST, server_port=GRADIO_PORT)