Spaces:
Running
Running
Update app.py
Browse files
app.py
CHANGED
|
@@ -504,32 +504,47 @@ def start_training(
|
|
| 504 |
@app.post("/train-from-hf")
|
| 505 |
def auto_run_lora_from_repo():
|
| 506 |
try:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 507 |
local_dir = Path(f"/tmp/{LORA_NAME}-{uuid.uuid4()}")
|
| 508 |
os.makedirs(local_dir, exist_ok=True)
|
| 509 |
|
|
|
|
| 510 |
snapshot_path = snapshot_download(
|
| 511 |
-
|
| 512 |
-
|
| 513 |
-
|
| 514 |
-
|
| 515 |
-
|
| 516 |
-
|
| 517 |
|
| 518 |
-
#
|
| 519 |
image_dir = Path(snapshot_path) / FOLDER_IN_REPO
|
| 520 |
|
| 521 |
-
|
| 522 |
-
image_paths = list(image_dir.rglob("*.jpg")) +
|
|
|
|
|
|
|
| 523 |
|
| 524 |
if not image_paths:
|
| 525 |
-
return JSONResponse(
|
|
|
|
|
|
|
|
|
|
| 526 |
|
|
|
|
| 527 |
captions = [
|
| 528 |
-
f"Autogenerated caption for {img.stem} in the {CONCEPT_SENTENCE} [trigger]"
|
|
|
|
| 529 |
]
|
| 530 |
|
|
|
|
| 531 |
dataset_path = create_dataset(image_paths, *captions)
|
| 532 |
|
|
|
|
| 533 |
result = start_training(
|
| 534 |
lora_name=LORA_NAME,
|
| 535 |
concept_sentence=CONCEPT_SENTENCE,
|
|
@@ -556,5 +571,7 @@ augmentation:
|
|
| 556 |
|
| 557 |
return {"message": result}
|
| 558 |
|
|
|
|
|
|
|
| 559 |
except Exception as e:
|
| 560 |
-
return JSONResponse(status_code=500, content={"error": str(e)})
|
|
|
|
| 504 |
@app.post("/train-from-hf")
|
| 505 |
def auto_run_lora_from_repo():
|
| 506 |
try:
|
| 507 |
+
# Set HF cache path if not already set
|
| 508 |
+
os.environ["HF_HOME"] = "/tmp/hf_cache"
|
| 509 |
+
os.makedirs("/tmp/hf_cache", exist_ok=True)
|
| 510 |
+
|
| 511 |
+
# Create temporary directory to hold downloaded files
|
| 512 |
local_dir = Path(f"/tmp/{LORA_NAME}-{uuid.uuid4()}")
|
| 513 |
os.makedirs(local_dir, exist_ok=True)
|
| 514 |
|
| 515 |
+
# Download snapshot from model repo using allow_patterns
|
| 516 |
snapshot_path = snapshot_download(
|
| 517 |
+
repo_id=REPO_ID,
|
| 518 |
+
repo_type="model",
|
| 519 |
+
local_dir=local_dir,
|
| 520 |
+
local_dir_use_symlinks=False,
|
| 521 |
+
allow_patterns=[f"{FOLDER_IN_REPO}/*"], # only that folder
|
| 522 |
+
)
|
| 523 |
|
| 524 |
+
# Target subfolder inside the snapshot
|
| 525 |
image_dir = Path(snapshot_path) / FOLDER_IN_REPO
|
| 526 |
|
| 527 |
+
# Collect all image files (recursively)
|
| 528 |
+
image_paths = list(image_dir.rglob("*.jpg")) + \
|
| 529 |
+
list(image_dir.rglob("*.jpeg")) + \
|
| 530 |
+
list(image_dir.rglob("*.png"))
|
| 531 |
|
| 532 |
if not image_paths:
|
| 533 |
+
return JSONResponse(
|
| 534 |
+
status_code=400,
|
| 535 |
+
content={"error": "No images found in the HF repo folder."}
|
| 536 |
+
)
|
| 537 |
|
| 538 |
+
# Create auto captions
|
| 539 |
captions = [
|
| 540 |
+
f"Autogenerated caption for {img.stem} in the {CONCEPT_SENTENCE} [trigger]"
|
| 541 |
+
for img in image_paths
|
| 542 |
]
|
| 543 |
|
| 544 |
+
# Prepare dataset
|
| 545 |
dataset_path = create_dataset(image_paths, *captions)
|
| 546 |
|
| 547 |
+
# Start training
|
| 548 |
result = start_training(
|
| 549 |
lora_name=LORA_NAME,
|
| 550 |
concept_sentence=CONCEPT_SENTENCE,
|
|
|
|
| 571 |
|
| 572 |
return {"message": result}
|
| 573 |
|
| 574 |
+
except PermissionError as pe:
|
| 575 |
+
return JSONResponse(status_code=500, content={"error": f"Permission denied: {pe}"})
|
| 576 |
except Exception as e:
|
| 577 |
+
return JSONResponse(status_code=500, content={"error": str(e)})
|