rahul7star commited on
Commit
69af2b6
·
verified ·
1 Parent(s): 63ac92a

Update flux_train_ui.py

Browse files
Files changed (1) hide show
  1. flux_train_ui.py +55 -1
flux_train_ui.py CHANGED
@@ -3,6 +3,11 @@ from huggingface_hub import whoami
3
  os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
4
  import sys
5
  import spaces
 
 
 
 
 
6
  # Add the current working directory to the Python path
7
  sys.path.insert(0, os.getcwd())
8
 
@@ -21,6 +26,26 @@ sys.path.insert(0, "ai-toolkit")
21
  from toolkit.job import get_job
22
 
23
  MAX_IMAGES = 150
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
24
 
25
  def load_captioning(uploaded_files, concept_sentence):
26
  uploaded_images = [file for file in uploaded_files if not file.endswith('.txt')]
@@ -427,5 +452,34 @@ with gr.Blocks(theme=theme, css=css) as demo:
427
 
428
  do_captioning.click(fn=run_captioning, inputs=[images, concept_sentence] + caption_list, outputs=caption_list)
429
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
430
  if __name__ == "__main__":
431
- demo.launch(share=True, show_error=True)
 
 
 
 
 
 
 
3
  os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
4
  import sys
5
  import spaces
6
+ import gradio as gr
7
+ from fastapi import FastAPI, Request
8
+ import uvicorn
9
+ from fastapi import Request
10
+
11
  # Add the current working directory to the Python path
12
  sys.path.insert(0, os.getcwd())
13
 
 
26
  from toolkit.job import get_job
27
 
28
  MAX_IMAGES = 150
29
+
30
+
31
+
32
+
33
+
34
+
35
+
36
+
37
+
38
+
39
+
40
+
41
+
42
+
43
+
44
+
45
+
46
+
47
+
48
+
49
 
50
  def load_captioning(uploaded_files, concept_sentence):
51
  uploaded_images = [file for file in uploaded_files if not file.endswith('.txt')]
 
452
 
453
  do_captioning.click(fn=run_captioning, inputs=[images, concept_sentence] + caption_list, outputs=caption_list)
454
 
455
+
456
+
457
+
458
+
459
+
460
+
461
+ def train_model_ui(data: str):
462
+ print(f"🚀 Triggered with input: {data}")
463
+ return f"✅ Received: {data}"
464
+
465
+ def add_trigger_endpoint(app):
466
+ @app.post("/trigger")
467
+ async def trigger(request: Request):
468
+ try:
469
+ body = await request.json()
470
+ input_data = body.get("input", "")
471
+ print(f"🔁 API Trigger: {input_data}")
472
+ result = train_model_ui(input_data)
473
+ return {"result": result}
474
+ except Exception as e:
475
+ return {"error": str(e)}
476
+
477
+ # ⬇️ DO NOT remove this since you're keeping the manual launch
478
  if __name__ == "__main__":
479
+ # Launch Gradio, then hook in the FastAPI route
480
+ demo.launch(
481
+ share=True,
482
+ show_error=True,
483
+ after_start=add_trigger_endpoint # ✅ Hook the FastAPI logic here
484
+ )
485
+