Commit
·
0577e3b
1
Parent(s):
30fdbbc
model/select endpoint warmup fix
Browse files
app.py
CHANGED
|
@@ -498,33 +498,37 @@ def model_checkpoints(repo_id: str, revision: str = "main"):
|
|
| 498 |
|
| 499 |
@app.post("/model/select")
|
| 500 |
def model_select(req: ModelSelect):
|
| 501 |
-
|
| 502 |
-
|
| 503 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 504 |
success, validation_result = model_selector.validate_selection(req)
|
| 505 |
if not success:
|
| 506 |
if "error" in validation_result:
|
| 507 |
raise HTTPException(status_code=400, detail=validation_result["error"])
|
| 508 |
return {"ok": False, **validation_result}
|
| 509 |
-
|
| 510 |
-
#
|
| 511 |
validation_result["active_jam"] = _any_jam_running()
|
| 512 |
-
|
| 513 |
-
#
|
| 514 |
if req.dry_run:
|
| 515 |
return {"ok": True, "dry_run": True, **validation_result}
|
| 516 |
|
| 517 |
-
# Handle jam policy
|
| 518 |
if _any_jam_running():
|
| 519 |
if req.stop_active:
|
| 520 |
_stop_all_jams()
|
| 521 |
else:
|
| 522 |
raise HTTPException(status_code=409, detail="A jam is running; retry with stop_active=true")
|
| 523 |
|
| 524 |
-
#
|
| 525 |
env_changes = model_selector.prepare_env_changes(req, validation_result)
|
| 526 |
-
|
| 527 |
-
#
|
| 528 |
old_env = {
|
| 529 |
"MRT_SIZE": os.getenv("MRT_SIZE"),
|
| 530 |
"MRT_CKPT_REPO": os.getenv("MRT_CKPT_REPO"),
|
|
@@ -532,52 +536,64 @@ def model_select(req: ModelSelect):
|
|
| 532 |
"MRT_CKPT_STEP": os.getenv("MRT_CKPT_STEP"),
|
| 533 |
"MRT_ASSETS_REPO": os.getenv("MRT_ASSETS_REPO"),
|
| 534 |
}
|
| 535 |
-
|
| 536 |
try:
|
| 537 |
-
# Apply
|
| 538 |
for key, value in env_changes.items():
|
| 539 |
if value is None:
|
| 540 |
os.environ.pop(key, None)
|
| 541 |
else:
|
| 542 |
os.environ[key] = str(value)
|
| 543 |
|
| 544 |
-
# Force model
|
| 545 |
with _MRT_LOCK:
|
| 546 |
_MRT = None
|
|
|
|
|
|
|
| 547 |
|
| 548 |
-
# Load finetune assets if requested
|
| 549 |
if req.sync_assets and validation_result.get("assets_repo"):
|
| 550 |
ok, msg = asset_manager.load_finetune_assets_from_hf(
|
| 551 |
-
validation_result["assets_repo"],
|
| 552 |
-
|
| 553 |
)
|
| 554 |
if ok:
|
| 555 |
-
# Sync globals after successful asset loading
|
| 556 |
_MEAN_EMBED = asset_manager.mean_embed
|
| 557 |
_CENTROIDS = asset_manager.centroids
|
| 558 |
_ASSETS_REPO_ID = asset_manager.assets_repo_id
|
|
|
|
|
|
|
| 559 |
|
| 560 |
-
#
|
|
|
|
|
|
|
| 561 |
if req.prewarm:
|
| 562 |
-
get_mrt()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 563 |
|
| 564 |
-
return {"ok": True, **validation_result}
|
| 565 |
-
|
| 566 |
except Exception as e:
|
| 567 |
-
#
|
| 568 |
for k, v in old_env.items():
|
| 569 |
if v is None:
|
| 570 |
os.environ.pop(k, None)
|
| 571 |
else:
|
| 572 |
os.environ[k] = v
|
|
|
|
| 573 |
with _MRT_LOCK:
|
| 574 |
_MRT = None
|
| 575 |
-
|
| 576 |
-
|
| 577 |
-
|
| 578 |
-
|
| 579 |
-
pass
|
| 580 |
-
raise HTTPException(status_code=500, detail=f"Swap failed: {e}")
|
| 581 |
|
| 582 |
|
| 583 |
|
|
|
|
| 498 |
|
| 499 |
def _apply_env_overrides(overrides):
    """Write *overrides* into ``os.environ``; a ``None`` value unsets the key."""
    for key, value in overrides.items():
        if value is None:
            os.environ.pop(key, None)
        else:
            os.environ[key] = str(value)


@app.post("/model/select")
def model_select(req: ModelSelect):
    """
    Swap model/checkpoint/assets. If req.prewarm is True, run the full bar-aligned warmup
    (_mrt_warmup) synchronously so we only report warmed once the new model is actually ready.

    Returns:
        A dict with "ok" plus the fields of the validation result. A dry run adds
        "dry_run": True and changes nothing; a real swap adds "warmup_done".

    Raises:
        HTTPException(400): validation failed and reported an "error".
        HTTPException(409): a jam is running and req.stop_active is false.
        HTTPException(500): the swap failed; the environment is rolled back first.
    """
    global _MRT, _MEAN_EMBED, _CENTROIDS, _ASSETS_REPO_ID, _WARMED

    # 1) Validate the request (no side-effects)
    success, validation_result = model_selector.validate_selection(req)
    if not success:
        if "error" in validation_result:
            raise HTTPException(status_code=400, detail=validation_result["error"])
        return {"ok": False, **validation_result}

    # Augment response surface
    validation_result["active_jam"] = _any_jam_running()

    # Dry-run path: report what would happen without touching any state
    if req.dry_run:
        return {"ok": True, "dry_run": True, **validation_result}

    # 2) Handle jam policy
    if _any_jam_running():
        if req.stop_active:
            _stop_all_jams()
        else:
            raise HTTPException(status_code=409, detail="A jam is running; retry with stop_active=true")

    # 3) Compute environment changes (no mutation yet)
    env_changes = model_selector.prepare_env_changes(req, validation_result)

    # Keep current env for rollback. Besides the well-known MRT_* keys, snapshot
    # every key the swap is about to touch, so a failed swap never leaves a
    # half-applied environment behind.
    # NOTE(review): the previous snapshot was a fixed key list; confirm no other
    # MRT_* key needs restoring that env_changes does not already contain.
    rollback_keys = {
        "MRT_SIZE",
        "MRT_CKPT_REPO",
        "MRT_CKPT_STEP",
        "MRT_ASSETS_REPO",
    }
    rollback_keys.update(env_changes)
    old_env = {key: os.getenv(key) for key in rollback_keys}

    try:
        # 4) Apply env changes
        _apply_env_overrides(env_changes)

        # 5) Force rebuild of the model and reset warmup state
        with _MRT_LOCK:
            _MRT = None
        with _WARMUP_LOCK:
            _WARMED = False  # critical: don't leak previous model's warmed state

        # 6) Load finetune assets if requested (mean/centroids)
        if req.sync_assets and validation_result.get("assets_repo"):
            ok, msg = asset_manager.load_finetune_assets_from_hf(
                validation_result["assets_repo"],
                None  # don't implicitly instantiate model here; we'll do it below
            )
            if ok:
                # Sync module globals with what the asset manager just loaded
                _MEAN_EMBED = asset_manager.mean_embed
                _CENTROIDS = asset_manager.centroids
                _ASSETS_REPO_ID = asset_manager.assets_repo_id
            else:
                logging.warning("Asset sync skipped/failed: %s", msg)

        # 7) Prewarm behavior:
        #    - If prewarm=True, run the *real* bar-aligned warmup synchronously.
        #    - This will instantiate the new MRT and set _WARMED=True on success.
        #    Without prewarm, _MRT stays None here; presumably the next get_mrt()
        #    call rebuilds it lazily — TODO confirm.
        if req.prewarm:
            _mrt_warmup()  # builds MRT internally via get_mrt(), runs generate_chunk, sets _WARMED

        return {
            "ok": True,
            **validation_result,
            "warmup_done": bool(_WARMED),
        }

    except Exception as e:
        # 8) Roll back env on failure
        _apply_env_overrides(old_env)
        # Also reset model pointer & warmed flag to a safe state
        with _MRT_LOCK:
            _MRT = None
        with _WARMUP_LOCK:
            _WARMED = False
        # NOTE(review): asset globals synced in step 6 are not restored here —
        # confirm whether that is intended.
        logging.exception("Model select failed: %s", e)
        raise HTTPException(status_code=500, detail=f"Model select failed: {e}") from e
|
|
|
|
|
|
|
| 597 |
|
| 598 |
|
| 599 |
|