import os
import logging
import threading
import http.server
import socketserver
import importlib
import time
from functools import lru_cache
from typing import Optional

import gradio as gr
import torch
from transformers import AutoTokenizer
from transformers.pipelines import pipeline
# ---------------- Configuration ----------------
MODEL_ID = os.getenv("MODEL_ID", "tasal9/ZamAI-mT5-Pashto")
CACHE_DIR = os.getenv("HF_HOME", None)  # optional cache dir for transformers
HEALTH_PORT = int(os.getenv("HEALTH_PORT", "8080"))
GRADIO_HOST = os.getenv("GRADIO_HOST", "0.0.0.0")
GRADIO_PORT = int(os.getenv("GRADIO_PORT", "7860"))
DEFAULT_MAX_NEW_TOKENS = int(os.getenv("DEFAULT_MAX_NEW_TOKENS", "128"))
ECHO_MODE = os.getenv("ECHO_MODE", "off").lower()  # default mode from env; the UI can override it at runtime
OFFLINE_FLAG = os.getenv("OFFLINE", "0").lower() in {"1", "true", "yes"}

if OFFLINE_FLAG:
    os.environ["HF_HUB_OFFLINE"] = "1"
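
# Example local configuration (illustrative values only; every variable above has a working default):
#   export MODEL_ID=tasal9/ZamAI-mT5-Pashto
#   export OFFLINE=1        # serve only from the local Hugging Face cache
#   export ECHO_MODE=echo   # smoke-test the UI without loading the model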

# ---------------- Logging ----------------
LOG_LEVEL = os.getenv("LOG_LEVEL", "INFO").upper()
logging.basicConfig(level=LOG_LEVEL, format="%(asctime)s %(levelname)s %(message)s")
logger = logging.getLogger("zamai-app")


def _log_cache_env():
    """Log the effective Hugging Face cache configuration at startup."""
    try:
        import huggingface_hub as _hub
        hub_cache = getattr(_hub.constants, "HF_HUB_CACHE", None)
    except Exception:
        hub_cache = None
    logger.info(
        "Cache config: HF_HOME=%s TRANSFORMERS_CACHE=%s HF_HUB_OFFLINE=%s hub_cache=%s",
        os.getenv("HF_HOME"), os.getenv("TRANSFORMERS_CACHE"), os.getenv("HF_HUB_OFFLINE"), hub_cache,
    )


_log_cache_env()

# Metrics storage for the last real generation
LAST_METRICS: dict[str, float | int | str | None] = {
    "latency_sec": None,
    "input_tokens": None,
    "output_tokens": None,
    "num_sequences": None,
    "mode": None,
}
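# LAST_METRICS is updated by predict() after each real generation and read by the
# status refresh in the UI; values stay None until the first non-echo request.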

# ---------------- Utilities ----------------
# Sample Pashto instructions shown in the dropdown (English glosses in comments).
SAMPLE_INSTRUCTIONS = [
    "په پښتو کې د خپل نوم او د عمر معلومات ولیکئ.",  # Write your name and age in Pashto.
    "د هوا د حالت په اړه لنډ راپور ورکړئ.",  # Give a short report on the weather.
    "په پښتو کې یوه لنډه کیسه ولیکئ چې د ښوونځي د ژوند په اړه وي.",  # Write a short story in Pashto about school life.
    "د خپلو ملګرو لپاره د یوې کوچنۍ پیغام ولیکئ.",  # Write a short message for your friends.
    "په پښتو کې د خپل خوښې خواړه تشریح کړئ او ووایاست ولې یې خوښوی.",  # Describe your favourite food in Pashto and say why you like it.
    "د خپلې سیمې د تاریخي ځایونو په اړه لنډ معلومات ورکړئ.",  # Give brief information about the historical places of your area.
    "یو ورځني کارنامه ولیکئ چې په کور کې څه کارونه ترسره کوئ.",  # Write a daily routine of the chores you do at home.
]

def _start_health_server(port: int):
    """Start a tiny HTTP server that responds 200 to /health on a background thread."""

    class HealthHandler(http.server.SimpleHTTPRequestHandler):
        def do_GET(self):
            if self.path == "/health":
                self.send_response(200)
                self.send_header("Content-type", "text/plain")
                self.end_headers()
                self.wfile.write(b"ok")
            else:
                self.send_response(404)
                self.end_headers()

    def _serve():
        try:
            with socketserver.TCPServer(("", int(port)), HealthHandler) as httpd:
                logger.info("Health endpoint listening on port %s", port)
                httpd.serve_forever()
        except Exception as e:
            logger.exception("Health server failed: %s", e)

    t = threading.Thread(target=_serve, daemon=True)
    t.start()
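
# Liveness probe example (assuming the default HEALTH_PORT of 8080):
#   curl http://localhost:8080/health   # responds with "ok"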

def _detect_device() -> int:
    """Return the device id for the transformers pipeline: -1 for CPU, 0..N for CUDA."""
    try:
        if torch.cuda.is_available():
            logger.info("CUDA available; using GPU device 0")
            return 0
    except Exception:
        logger.debug("torch.cuda check failed; falling back to CPU")
    return -1

# ---------------- Generator (cached) ----------------
@lru_cache(maxsize=2)
def get_generator(model_id: str = MODEL_ID, cache_dir: Optional[str] = CACHE_DIR):
    """Build (and memoise) the text2text-generation pipeline for the given model."""
    device = _detect_device()
    logger.info("Loading tokenizer and model: %s (device=%s)", model_id, device)
    tokenizer = None
    local_model_path = None
    try:
        hf = importlib.import_module("huggingface_hub")
        snapshot_download = getattr(hf, "snapshot_download", None)
        if snapshot_download:
            try:
                logger.info("Attempting to snapshot_download model %s to cache_dir=%s", model_id, cache_dir)
                local_model_path = snapshot_download(repo_id=model_id, cache_dir=cache_dir, repo_type="model")
                if local_model_path:
                    local_model_path = str(local_model_path)
                    logger.info("Model snapshot downloaded to %s", local_model_path)
            except Exception as e:
                logger.warning("snapshot_download failed for %s: %s", model_id, e)
                local_model_path = None
    except Exception:
        logger.debug("huggingface_hub not available; falling back to AutoTokenizer.from_pretrained")

    try:
        if local_model_path:
            tokenizer = AutoTokenizer.from_pretrained(local_model_path, use_fast=False, cache_dir=cache_dir)
        else:
            tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=False, cache_dir=cache_dir)
        logger.info("Loaded tokenizer for %s", model_id)
    except Exception as e2:
        logger.exception("Failed to load tokenizer for %s: %s", model_id, e2)
        raise

    gen = pipeline(
        "text2text-generation",
        model=local_model_path or model_id,  # prefer the downloaded snapshot when available
        tokenizer=tokenizer,
        device=device,
    )
    return gen
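
# Optional warm-up (sketch): loading the pipeline in a background thread at startup
# would hide the first-request latency, e.g.
#   threading.Thread(target=get_generator, daemon=True).start()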

def predict(instruction: str,
            input_text: str,
            max_new_tokens: int,
            num_beams: int,
            do_sample: bool,
            temperature: float,
            top_p: float,
            num_return_sequences: int,
            mode: str):
    """Generate text using the cached pipeline and return the output or an error message."""
    if not instruction or not instruction.strip():
        return "⚠️ مهرباني وکړئ یوه لارښوونه ولیکئ."  # "Please write an instruction."

    def build_prompt() -> str:
        base = instruction.strip()
        if input_text and input_text.strip():
            return base + "\n" + input_text.strip()
        return base

    prompt = build_prompt()
    active_mode = (mode or "").strip().lower() or ECHO_MODE
    if active_mode in ("echo", "useless"):
        if active_mode == "echo":
            return f"### Prompt\n\n````\n{prompt}\n````\n\n### Output\n\n````\n{prompt}\n````"
        return f"### Prompt\n\n````\n{prompt}\n````\n\n### Output\n\nThis is a useless placeholder response."

    allowed_keys = {"max_new_tokens", "num_beams", "do_sample", "temperature", "top_p", "num_return_sequences"}
    start = time.time()
    try:
        gen = get_generator()
        raw_kwargs = {
            "max_new_tokens": int(max_new_tokens),
            "num_beams": int(num_beams) if not do_sample else 1,
            "do_sample": bool(do_sample),
            "temperature": float(temperature),
            "top_p": float(top_p),
            "num_return_sequences": max(1, int(num_return_sequences)),
        }
        gen_kwargs = {k: v for k, v in raw_kwargs.items() if k in allowed_keys}
        outputs = gen(prompt, **gen_kwargs)
        if not isinstance(outputs, list):
            outputs = [outputs]
        texts = []
        for out in outputs:
            if isinstance(out, dict):
                text = out.get("generated_text", "").strip()
            else:
                text = str(out).strip()
            if text:
                texts.append(text)
        if not texts:
            LAST_METRICS.update({
                "latency_sec": round(time.time() - start, 3),
                "input_tokens": None,
                "output_tokens": 0,
                "num_sequences": 0,
                "mode": active_mode,
            })
            return f"### Prompt\n\n````\n{prompt}\n````\n\n### Output\n\n⚠️ No response generated."
        joined = "\n\n---\n\n".join(texts)
        # Basic token counting via whitespace split (approximate)
        input_tokens = len(prompt.split())
        output_tokens = sum(len(t.split()) for t in texts)
        LAST_METRICS.update({
            "latency_sec": round(time.time() - start, 3),
            "input_tokens": input_tokens,
            "output_tokens": output_tokens,
            "num_sequences": len(texts),
            "mode": active_mode,
        })
        metrics_md = (
            f"\n\n### Metrics\n- Latency: {LAST_METRICS['latency_sec']}s\n"
            f"- Input tokens (approx): {input_tokens}\n"
            f"- Output tokens (approx): {output_tokens}\n"
            f"- Sequences: {len(texts)}"
        )
        return f"### Prompt\n\n````\n{prompt}\n````\n\n### Output\n\n{joined}{metrics_md}"
    except Exception as e:
        logger.exception("Generation failed: %s", e)
        return f"⚠️ Generation failed: {e}"

def build_ui():
    with gr.Blocks() as demo:
        device_label = "GPU" if _detect_device() != -1 else "CPU"
        # Header text is in Pashto: "The ZamAI-mT5-Pashto app for Pashto instructions.
        # To change the mode, use the Mode selector below."
        gr.Markdown(
            f"""
            # ZamAI mT5 Pashto Demo

            اپلیکیشن **ZamAI-mT5-Pashto** د پښتو لارښوونو لپاره.

            **Device:** {device_label} | **Env Mode:** {ECHO_MODE} | **Offline:** {os.getenv('HF_HUB_OFFLINE','0')}

            که د موډ بدلول غواړئ لاندې د Mode selector څخه استفاده وکړئ.
            """
        )
        with gr.Row():
            with gr.Column(scale=2):
                instruction_dropdown = gr.Dropdown(
                    choices=SAMPLE_INSTRUCTIONS,
                    label="نمونې لارښوونې",
                    value=SAMPLE_INSTRUCTIONS[0],
                    interactive=True,
                )
                instruction_textbox = gr.Textbox(
                    lines=3,
                    placeholder="دلته لارښوونه ولیکئ...",
                    label="لارښوونه",
                )
                input_text = gr.Textbox(lines=2, placeholder="اختیاري متن...", label="متن")
                output = gr.Markdown(label="ځواب")
                generate_btn = gr.Button("جوړول", variant="primary")
                mode_selector = gr.Dropdown(
                    choices=["off", "echo", "useless"],
                    value=ECHO_MODE,
                    label="Mode (off=real, echo=return prompt, useless=fixed)",
                    interactive=True,
                )
                status_box = gr.Markdown(value="Loading status pending...", label="Status")
                refresh_status = gr.Button("Refresh Status")
            with gr.Column(scale=1):
                gr.Markdown("### د تولید تنظیمات")
                max_new_tokens = gr.Slider(16, 512, value=DEFAULT_MAX_NEW_TOKENS, step=1, label="اعظمي نوي ټوکنونه (max_new_tokens)")
                num_beams = gr.Slider(1, 8, value=2, step=1, label="شمیر شعاعونه (num_beams)")
                do_sample = gr.Checkbox(label="نمونې فعال کړئ (do_sample)", value=True)
                temperature = gr.Slider(0.1, 2.0, value=1.0, step=0.05, label="تودوخه (temperature)")
                top_p = gr.Slider(0.1, 1.0, value=0.9, step=0.01, label="Top-p")
                num_return_sequences = gr.Slider(1, 4, value=1, step=1, label="د راګرځېدونکو تسلسلو شمېر")

        # Copy the selected sample instruction into the instruction textbox.
        instruction_dropdown.change(lambda x: x, inputs=instruction_dropdown, outputs=instruction_textbox)

        def refresh():
            base = f"**Device:** {'GPU' if _detect_device() != -1 else 'CPU'} | **Offline:** {os.getenv('HF_HUB_OFFLINE','0')} | **Env Mode:** {ECHO_MODE}"
            if LAST_METRICS.get('latency_sec') is not None:
                base += (f"<br>**Last Gen:** latency={LAST_METRICS['latency_sec']}s, "
                         f"in≈{LAST_METRICS['input_tokens']}, out≈{LAST_METRICS['output_tokens']}, seqs={LAST_METRICS['num_sequences']}")
            return base

        refresh_status.click(fn=refresh, inputs=None, outputs=status_box)
        generate_btn.click(
            fn=predict,
            inputs=[instruction_textbox, input_text, max_new_tokens, num_beams, do_sample, temperature, top_p, num_return_sequences, mode_selector],
            outputs=output,
        )

        # Model load banner shown after the interface loads (async)
        def _post_load():
            return "✅ Model interface ready. If this is the first run and the model wasn't cached, the initial generation may still need to warm up."

        demo.load(_post_load, inputs=None, outputs=status_box)
    return demo
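
# Running locally (sketch; assumes this file is saved as app.py):
#   pip install gradio transformers torch huggingface_hub
#   python app.py   # UI on http://localhost:7860, health check on :8080/health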

if __name__ == "__main__":
    logger.info("Starting ZamAI mT5 Pashto Demo (model=%s)", MODEL_ID)
    try:
        _start_health_server(HEALTH_PORT)
    except Exception:
        logger.exception("Failed to start health server")
    logger.info("HF_HOME=%s TRANSFORMERS_CACHE=%s", os.getenv("HF_HOME"), os.getenv("TRANSFORMERS_CACHE"))
    demo = build_ui()
    demo.launch(server_name=GRADIO_HOST, server_port=GRADIO_PORT)