petergpt committed
Commit c333b0b · verified · 1 Parent(s): e041428

Update app.py

Files changed (1)
  1. app.py +45 -36
app.py CHANGED
@@ -5,13 +5,13 @@ from PIL import Image
 from torchvision import transforms
 import gradio as gr
 
-# Load the model from Hugging Face
+# Load model
 birefnet = AutoModelForImageSegmentation.from_pretrained('zhengpeng7/BiRefNet_lite', trust_remote_code=True)
 device = 'cuda' if torch.cuda.is_available() else 'cpu'
 birefnet.to(device)
 birefnet.eval()
 
-# Define the transform to preprocess the input images
+# Preprocessing
 image_size = (1024, 1024)
 transform_image = transforms.Compose([
     transforms.Resize(image_size),
@@ -19,63 +19,72 @@ transform_image = transforms.Compose([
     transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
 ])
 
-def extract_objects(filepaths):
-    # Open all images from the uploaded file paths
-    images = [Image.open(path).convert("RGB") for path in filepaths]
-
-    start_time = time.time()
+def process_batch(img_batch):
     inputs = []
     original_sizes = []
-    for img in images:
+    for img in img_batch:
         original_sizes.append(img.size)
         inputs.append(transform_image(img))
     input_tensor = torch.stack(inputs).to(device)
 
-    # Inference
-    inf_start = time.time()
-    with torch.no_grad():
-        preds = birefnet(input_tensor)[-1].sigmoid().cpu()
-    inf_end = time.time()
+    try:
+        with torch.no_grad():
+            preds = birefnet(input_tensor)[-1].sigmoid().cpu()
+    except torch.OutOfMemoryError:
+        torch.cuda.empty_cache()
+        return None
 
-    # Post-process results
     results = []
-    image_times = []
-    for i, img in enumerate(images):
-        t_start = time.time()
+    for i, img in enumerate(img_batch):
         pred = preds[i].squeeze()
         pred_pil = transforms.ToPILImage()(pred)
         mask = pred_pil.resize(original_sizes[i])
-
-        # Create a transparent background image
         result = Image.new("RGBA", original_sizes[i], (0, 0, 0, 0))
         result.paste(img, mask=mask)
         results.append(result)
-        t_end = time.time()
-        image_times.append(t_end - t_start)
 
-    end_time = time.time()
-    total_time = end_time - start_time
-    inference_time = inf_end - inf_start
-    prep_post_time = total_time - inference_time
+    return results
+
+def extract_objects(filepaths):
+    # Open all images from the uploaded file paths
+    images = [Image.open(path).convert("RGB") for path in filepaths]
+
+    # You can define a batch size here (e.g., batch_size = 5)
+    # This prevents trying to process all images at once if too large
+    batch_size = 5
+    batches = [images[i:i+batch_size] for i in range(0, len(images), batch_size)]
+
+    total_start = time.time()
+    all_results = []
+    batch_times = []
+    for b_idx, batch in enumerate(batches):
+        b_start = time.time()
+        res = process_batch(batch)
+        if res is None:
+            # Handle OOM gracefully
+            all_results.extend([Image.new("RGBA", (100, 100), (255,0,0,255)) for _ in batch])
+            batch_times.append(f"Batch {b_idx+1}: OOM Error")
+        else:
+            all_results.extend(res)
+        b_end = time.time()
+        batch_times.append(f"Batch {b_idx+1}: {(b_end - b_start):.2f} s")
+    total_end = time.time()
 
-    # Create a summary of timings
     summary = (
-        f"Total request time: {total_time:.2f} s\n"
-        f"Inference time (batch): {inference_time:.2f} s\n"
-        f"Pre/Post-processing time: {prep_post_time:.2f} s\n"
-        "Per-image post-processing times:\n" +
-        "\n".join([f" Image {i+1}: {t:.2f} s" for i, t in enumerate(image_times)])
+        f"Total request time: {total_end - total_start:.2f} s\n"
+        "Batch times:\n" + "\n".join(batch_times)
     )
 
-    return results, summary
+    return all_results, summary
 
 iface = gr.Interface(
     fn=extract_objects,
     inputs=gr.Files(label="Upload Multiple Images", type="filepath", file_count="multiple"),
     outputs=[gr.Gallery(label="Processed Images"), gr.Textbox(label="Timing Info")],
-    title="BiRefNet Bulk Background Removal",
-    description="Upload multiple images and process them in one request. Timing information for the full request and per-image processing is provided."
+    title="BiRefNet Bulk Background Removal with Queue & Batch",
+    description="Upload multiple images. The request is queued and processed in batches to avoid OOM errors."
 )
 
-if __name__ == "__main__":
-    iface.launch()
+# Enable the queue with defined concurrency to prevent multiple large requests at once
+# You can adjust concurrency_count and max_size as needed.
+iface.queue(concurrency_count=1, max_size=10).launch()
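Two API details in the new code are worth flagging. torch.OutOfMemoryError is only exposed at the top level in recent PyTorch releases; torch.cuda.OutOfMemoryError is the longer-standing name, and very old builds raise a plain RuntimeError on CUDA OOM. Below is a minimal, version-tolerant sketch of the same guard; the names OOM_ERRORS and run_inference are illustrative and not part of the commit, and the model is passed in rather than taken from the module-level birefnet.

import torch

# Collect whichever OOM exception classes this PyTorch build exposes;
# fall back to RuntimeError, which older builds raise on CUDA OOM.
OOM_ERRORS = tuple(
    exc for exc in (
        getattr(torch, "OutOfMemoryError", None),
        getattr(torch.cuda, "OutOfMemoryError", None),
    )
    if exc is not None
) or (RuntimeError,)

def run_inference(model, input_tensor):
    # Forward pass that returns None instead of crashing when the GPU runs out of memory.
    try:
        with torch.no_grad():
            return model(input_tensor)[-1].sigmoid().cpu()
    except OOM_ERRORS:
        torch.cuda.empty_cache()
        return None

Similarly, queue(concurrency_count=1, max_size=10) matches the Gradio 3.x signature; if the Space is later moved to Gradio 4.x, that keyword may need updating, since the queue API was reworked there.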