Spaces:
Running
Running
Update app.py
Browse files
app.py
CHANGED
@@ -292,3 +292,202 @@ def call_other_space():
|
|
292 |
|
293 |
except Exception as e:
|
294 |
return {"error": str(e)}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
292 |
|
293 |
except Exception as e:
|
294 |
return {"error": str(e)}
|
295 |
+
|
296 |
+
|
297 |
+
|
298 |
+
|
299 |
+
|
300 |
+
|
301 |
+
|
302 |
+
# ========== TRAIN CONFIGURATION ==========
|
303 |
+
REPO_ID = "rahul7star/ohamlab"
|
304 |
+
FOLDER_IN_REPO = "filter-demo/upload_20250708_041329_9c5c81"
|
305 |
+
CONCEPT_SENTENCE = "ohamlab style"
|
306 |
+
LORA_NAME = "ohami_filter_autorun"
|
307 |
+
|
308 |
+
# ========== FASTAPI APP ==========
|
309 |
+
|
310 |
+
|
311 |
+
# ========== HELPERS ==========
|
312 |
+
def create_dataset(images, *captions):
|
313 |
+
destination_folder = f"datasets_{uuid.uuid4()}"
|
314 |
+
os.makedirs(destination_folder, exist_ok=True)
|
315 |
+
|
316 |
+
jsonl_file_path = os.path.join(destination_folder, "metadata.jsonl")
|
317 |
+
with open(jsonl_file_path, "a") as jsonl_file:
|
318 |
+
for index, image in enumerate(images):
|
319 |
+
new_image_path = shutil.copy(str(image), destination_folder)
|
320 |
+
caption = captions[index]
|
321 |
+
file_name = os.path.basename(new_image_path)
|
322 |
+
data = {"file_name": file_name, "prompt": caption}
|
323 |
+
jsonl_file.write(json.dumps(data) + "\n")
|
324 |
+
|
325 |
+
return destination_folder
|
326 |
+
|
327 |
+
def recursive_update(d, u):
|
328 |
+
for k, v in u.items():
|
329 |
+
if isinstance(v, dict) and v:
|
330 |
+
d[k] = recursive_update(d.get(k, {}), v)
|
331 |
+
else:
|
332 |
+
d[k] = v
|
333 |
+
return d
|
334 |
+
|
335 |
+
def start_training(
|
336 |
+
lora_name,
|
337 |
+
concept_sentence,
|
338 |
+
steps,
|
339 |
+
lr,
|
340 |
+
rank,
|
341 |
+
model_to_train,
|
342 |
+
low_vram,
|
343 |
+
dataset_folder,
|
344 |
+
sample_1,
|
345 |
+
sample_2,
|
346 |
+
sample_3,
|
347 |
+
use_more_advanced_options,
|
348 |
+
more_advanced_options,
|
349 |
+
):
|
350 |
+
try:
|
351 |
+
user = whoami()
|
352 |
+
username = user.get("name", "anonymous")
|
353 |
+
push_to_hub = True
|
354 |
+
except:
|
355 |
+
username = "anonymous"
|
356 |
+
push_to_hub = False
|
357 |
+
|
358 |
+
slugged_lora_name = lora_name.replace(" ", "_").lower()
|
359 |
+
|
360 |
+
# Load base config
|
361 |
+
config = {
|
362 |
+
"config": {
|
363 |
+
"name": slugged_lora_name,
|
364 |
+
"process": [
|
365 |
+
{
|
366 |
+
"model": {
|
367 |
+
"low_vram": low_vram,
|
368 |
+
"is_flux": True,
|
369 |
+
"quantize": True,
|
370 |
+
"name_or_path": "black-forest-labs/FLUX.1-dev"
|
371 |
+
},
|
372 |
+
"network": {
|
373 |
+
"linear": rank,
|
374 |
+
"linear_alpha": rank,
|
375 |
+
"type": "lora"
|
376 |
+
},
|
377 |
+
"train": {
|
378 |
+
"steps": steps,
|
379 |
+
"lr": lr,
|
380 |
+
"skip_first_sample": True,
|
381 |
+
"batch_size": 1,
|
382 |
+
"dtype": "bf16",
|
383 |
+
"gradient_accumulation_steps": 1,
|
384 |
+
"gradient_checkpointing": True,
|
385 |
+
"noise_scheduler": "flowmatch",
|
386 |
+
"optimizer": "adamw8bit",
|
387 |
+
"ema_config": {
|
388 |
+
"use_ema": True,
|
389 |
+
"ema_decay": 0.99
|
390 |
+
}
|
391 |
+
},
|
392 |
+
"datasets": [
|
393 |
+
{"folder_path": dataset_folder}
|
394 |
+
],
|
395 |
+
"save": {
|
396 |
+
"dtype": "float16",
|
397 |
+
"save_every": 10000,
|
398 |
+
"push_to_hub": push_to_hub,
|
399 |
+
"hf_repo_id": f"{username}/{slugged_lora_name}",
|
400 |
+
"hf_private": True,
|
401 |
+
"max_step_saves_to_keep": 4
|
402 |
+
},
|
403 |
+
"sample": {
|
404 |
+
"guidance_scale": 3.5,
|
405 |
+
"sample_every": steps,
|
406 |
+
"sample_steps": 28,
|
407 |
+
"width": 1024,
|
408 |
+
"height": 1024,
|
409 |
+
"walk_seed": True,
|
410 |
+
"seed": 42,
|
411 |
+
"sampler": "flowmatch",
|
412 |
+
"prompts": [p for p in [sample_1, sample_2, sample_3] if p]
|
413 |
+
},
|
414 |
+
"trigger_word": concept_sentence
|
415 |
+
}
|
416 |
+
]
|
417 |
+
}
|
418 |
+
}
|
419 |
+
|
420 |
+
# Apply advanced YAML overrides if any
|
421 |
+
if use_more_advanced_options and more_advanced_options:
|
422 |
+
advanced_config = yaml.safe_load(more_advanced_options)
|
423 |
+
config["config"]["process"][0] = recursive_update(config["config"]["process"][0], advanced_config)
|
424 |
+
|
425 |
+
# Save YAML config
|
426 |
+
os.makedirs("tmp_configs", exist_ok=True)
|
427 |
+
config_path = f"tmp_configs/{uuid.uuid4()}_{slugged_lora_name}.yaml"
|
428 |
+
with open(config_path, "w") as f:
|
429 |
+
yaml.dump(config, f)
|
430 |
+
|
431 |
+
# Simulate training
|
432 |
+
print(f"[INFO] Starting training with config: {config_path}")
|
433 |
+
print(json.dumps(config, indent=2))
|
434 |
+
return f"Training started successfully with config: {config_path}"
|
435 |
+
|
436 |
+
# ========== MAIN ENDPOINT ==========
|
437 |
+
@app.post("/train-from-hf")
|
438 |
+
def auto_run_lora_from_repo():
|
439 |
+
try:
|
440 |
+
local_dir = Path(f"/tmp/{LORA_NAME}-{uuid.uuid4()}")
|
441 |
+
os.makedirs(local_dir, exist_ok=True)
|
442 |
+
|
443 |
+
hf_hub_download(
|
444 |
+
repo_id=REPO_ID,
|
445 |
+
repo_type="dataset",
|
446 |
+
subfolder=FOLDER_IN_REPO,
|
447 |
+
local_dir=local_dir,
|
448 |
+
local_dir_use_symlinks=False,
|
449 |
+
force_download=False,
|
450 |
+
etag_timeout=10,
|
451 |
+
allow_patterns=["*.jpg", "*.png", "*.jpeg"],
|
452 |
+
)
|
453 |
+
|
454 |
+
image_dir = local_dir / FOLDER_IN_REPO
|
455 |
+
image_paths = list(image_dir.rglob("*.jpg")) + list(image_dir.rglob("*.jpeg")) + list(image_dir.rglob("*.png"))
|
456 |
+
|
457 |
+
if not image_paths:
|
458 |
+
return JSONResponse(status_code=400, content={"error": "No images found in the HF repo folder."})
|
459 |
+
|
460 |
+
captions = [
|
461 |
+
f"Autogenerated caption for {img.stem} in the {CONCEPT_SENTENCE} [trigger]" for img in image_paths
|
462 |
+
]
|
463 |
+
|
464 |
+
dataset_path = create_dataset(image_paths, *captions)
|
465 |
+
|
466 |
+
result = start_training(
|
467 |
+
lora_name=LORA_NAME,
|
468 |
+
concept_sentence=CONCEPT_SENTENCE,
|
469 |
+
steps=1000,
|
470 |
+
lr=4e-4,
|
471 |
+
rank=16,
|
472 |
+
model_to_train="dev",
|
473 |
+
low_vram=True,
|
474 |
+
dataset_folder=dataset_path,
|
475 |
+
sample_1=f"A stylized portrait using {CONCEPT_SENTENCE}",
|
476 |
+
sample_2=f"A cat in the {CONCEPT_SENTENCE}",
|
477 |
+
sample_3=f"A selfie processed in {CONCEPT_SENTENCE}",
|
478 |
+
use_more_advanced_options=True,
|
479 |
+
more_advanced_options="""
|
480 |
+
training:
|
481 |
+
seed: 42
|
482 |
+
precision: bf16
|
483 |
+
batch_size: 2
|
484 |
+
augmentation:
|
485 |
+
flip: true
|
486 |
+
color_jitter: true
|
487 |
+
"""
|
488 |
+
)
|
489 |
+
|
490 |
+
return {"message": result}
|
491 |
+
|
492 |
+
except Exception as e:
|
493 |
+
return JSONResponse(status_code=500, content={"error": str(e)})
|