rahul7star commited on
Commit
8bafa56
·
verified ·
1 Parent(s): 41bbeeb

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +199 -0
app.py CHANGED
@@ -292,3 +292,202 @@ def call_other_space():
292
 
293
  except Exception as e:
294
  return {"error": str(e)}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
292
 
293
  except Exception as e:
294
  return {"error": str(e)}
295
+
296
+
297
+
298
+
299
+
300
+
301
+
302
+ # ========== TRAIN CONFIGURATION ==========
303
+ REPO_ID = "rahul7star/ohamlab"
304
+ FOLDER_IN_REPO = "filter-demo/upload_20250708_041329_9c5c81"
305
+ CONCEPT_SENTENCE = "ohamlab style"
306
+ LORA_NAME = "ohami_filter_autorun"
307
+
308
+ # ========== FASTAPI APP ==========
309
+
310
+
311
+ # ========== HELPERS ==========
312
+ def create_dataset(images, *captions):
313
+ destination_folder = f"datasets_{uuid.uuid4()}"
314
+ os.makedirs(destination_folder, exist_ok=True)
315
+
316
+ jsonl_file_path = os.path.join(destination_folder, "metadata.jsonl")
317
+ with open(jsonl_file_path, "a") as jsonl_file:
318
+ for index, image in enumerate(images):
319
+ new_image_path = shutil.copy(str(image), destination_folder)
320
+ caption = captions[index]
321
+ file_name = os.path.basename(new_image_path)
322
+ data = {"file_name": file_name, "prompt": caption}
323
+ jsonl_file.write(json.dumps(data) + "\n")
324
+
325
+ return destination_folder
326
+
327
+ def recursive_update(d, u):
328
+ for k, v in u.items():
329
+ if isinstance(v, dict) and v:
330
+ d[k] = recursive_update(d.get(k, {}), v)
331
+ else:
332
+ d[k] = v
333
+ return d
334
+
335
+ def start_training(
336
+ lora_name,
337
+ concept_sentence,
338
+ steps,
339
+ lr,
340
+ rank,
341
+ model_to_train,
342
+ low_vram,
343
+ dataset_folder,
344
+ sample_1,
345
+ sample_2,
346
+ sample_3,
347
+ use_more_advanced_options,
348
+ more_advanced_options,
349
+ ):
350
+ try:
351
+ user = whoami()
352
+ username = user.get("name", "anonymous")
353
+ push_to_hub = True
354
+ except:
355
+ username = "anonymous"
356
+ push_to_hub = False
357
+
358
+ slugged_lora_name = lora_name.replace(" ", "_").lower()
359
+
360
+ # Load base config
361
+ config = {
362
+ "config": {
363
+ "name": slugged_lora_name,
364
+ "process": [
365
+ {
366
+ "model": {
367
+ "low_vram": low_vram,
368
+ "is_flux": True,
369
+ "quantize": True,
370
+ "name_or_path": "black-forest-labs/FLUX.1-dev"
371
+ },
372
+ "network": {
373
+ "linear": rank,
374
+ "linear_alpha": rank,
375
+ "type": "lora"
376
+ },
377
+ "train": {
378
+ "steps": steps,
379
+ "lr": lr,
380
+ "skip_first_sample": True,
381
+ "batch_size": 1,
382
+ "dtype": "bf16",
383
+ "gradient_accumulation_steps": 1,
384
+ "gradient_checkpointing": True,
385
+ "noise_scheduler": "flowmatch",
386
+ "optimizer": "adamw8bit",
387
+ "ema_config": {
388
+ "use_ema": True,
389
+ "ema_decay": 0.99
390
+ }
391
+ },
392
+ "datasets": [
393
+ {"folder_path": dataset_folder}
394
+ ],
395
+ "save": {
396
+ "dtype": "float16",
397
+ "save_every": 10000,
398
+ "push_to_hub": push_to_hub,
399
+ "hf_repo_id": f"{username}/{slugged_lora_name}",
400
+ "hf_private": True,
401
+ "max_step_saves_to_keep": 4
402
+ },
403
+ "sample": {
404
+ "guidance_scale": 3.5,
405
+ "sample_every": steps,
406
+ "sample_steps": 28,
407
+ "width": 1024,
408
+ "height": 1024,
409
+ "walk_seed": True,
410
+ "seed": 42,
411
+ "sampler": "flowmatch",
412
+ "prompts": [p for p in [sample_1, sample_2, sample_3] if p]
413
+ },
414
+ "trigger_word": concept_sentence
415
+ }
416
+ ]
417
+ }
418
+ }
419
+
420
+ # Apply advanced YAML overrides if any
421
+ if use_more_advanced_options and more_advanced_options:
422
+ advanced_config = yaml.safe_load(more_advanced_options)
423
+ config["config"]["process"][0] = recursive_update(config["config"]["process"][0], advanced_config)
424
+
425
+ # Save YAML config
426
+ os.makedirs("tmp_configs", exist_ok=True)
427
+ config_path = f"tmp_configs/{uuid.uuid4()}_{slugged_lora_name}.yaml"
428
+ with open(config_path, "w") as f:
429
+ yaml.dump(config, f)
430
+
431
+ # Simulate training
432
+ print(f"[INFO] Starting training with config: {config_path}")
433
+ print(json.dumps(config, indent=2))
434
+ return f"Training started successfully with config: {config_path}"
435
+
436
+ # ========== MAIN ENDPOINT ==========
437
+ @app.post("/train-from-hf")
438
+ def auto_run_lora_from_repo():
439
+ try:
440
+ local_dir = Path(f"/tmp/{LORA_NAME}-{uuid.uuid4()}")
441
+ os.makedirs(local_dir, exist_ok=True)
442
+
443
+ hf_hub_download(
444
+ repo_id=REPO_ID,
445
+ repo_type="dataset",
446
+ subfolder=FOLDER_IN_REPO,
447
+ local_dir=local_dir,
448
+ local_dir_use_symlinks=False,
449
+ force_download=False,
450
+ etag_timeout=10,
451
+ allow_patterns=["*.jpg", "*.png", "*.jpeg"],
452
+ )
453
+
454
+ image_dir = local_dir / FOLDER_IN_REPO
455
+ image_paths = list(image_dir.rglob("*.jpg")) + list(image_dir.rglob("*.jpeg")) + list(image_dir.rglob("*.png"))
456
+
457
+ if not image_paths:
458
+ return JSONResponse(status_code=400, content={"error": "No images found in the HF repo folder."})
459
+
460
+ captions = [
461
+ f"Autogenerated caption for {img.stem} in the {CONCEPT_SENTENCE} [trigger]" for img in image_paths
462
+ ]
463
+
464
+ dataset_path = create_dataset(image_paths, *captions)
465
+
466
+ result = start_training(
467
+ lora_name=LORA_NAME,
468
+ concept_sentence=CONCEPT_SENTENCE,
469
+ steps=1000,
470
+ lr=4e-4,
471
+ rank=16,
472
+ model_to_train="dev",
473
+ low_vram=True,
474
+ dataset_folder=dataset_path,
475
+ sample_1=f"A stylized portrait using {CONCEPT_SENTENCE}",
476
+ sample_2=f"A cat in the {CONCEPT_SENTENCE}",
477
+ sample_3=f"A selfie processed in {CONCEPT_SENTENCE}",
478
+ use_more_advanced_options=True,
479
+ more_advanced_options="""
480
+ training:
481
+ seed: 42
482
+ precision: bf16
483
+ batch_size: 2
484
+ augmentation:
485
+ flip: true
486
+ color_jitter: true
487
+ """
488
+ )
489
+
490
+ return {"message": result}
491
+
492
+ except Exception as e:
493
+ return JSONResponse(status_code=500, content={"error": str(e)})