not-lain commited on
Commit
fe5d262
·
1 Parent(s): c1f05bf

reset pytorch matrix multiplication precision for rmbg

Browse files
Files changed (1) hide show
  1. app.py +13 -15
app.py CHANGED
@@ -10,21 +10,6 @@ from sam2.sam2_image_predictor import SAM2ImagePredictor
10
  import numpy as np
11
  from simple_lama_inpainting import SimpleLama
12
 
13
- torch.set_float32_matmul_precision(["high", "highest"][0])
14
-
15
- birefnet = AutoModelForImageSegmentation.from_pretrained(
16
- "ZhengPeng7/BiRefNet", trust_remote_code=True
17
- )
18
- birefnet.to("cuda")
19
-
20
-
21
- transform_image = transforms.Compose(
22
- [
23
- transforms.Resize((1024, 1024)),
24
- transforms.ToTensor(),
25
- transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
26
- ]
27
- )
28
 
29
  pipe = FluxFillPipeline.from_pretrained(
30
  "black-forest-labs/FLUX.1-Fill-dev", torch_dtype=torch.bfloat16
@@ -113,6 +98,18 @@ def rmbg(image=None, url=None):
113
  image = url
114
  image = load_img(image).convert("RGB")
115
  image_size = image.size
 
 
 
 
 
 
 
 
 
 
 
 
116
  input_images = transform_image(image).unsqueeze(0).to("cuda")
117
  # Prediction
118
  with torch.no_grad():
@@ -121,6 +118,7 @@ def rmbg(image=None, url=None):
121
  pred_pil = transforms.ToPILImage()(pred)
122
  mask = pred_pil.resize(image_size)
123
  image.putalpha(mask)
 
124
  return image
125
 
126
 
 
10
  import numpy as np
11
  from simple_lama_inpainting import SimpleLama
12
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
13
 
14
  pipe = FluxFillPipeline.from_pretrained(
15
  "black-forest-labs/FLUX.1-Fill-dev", torch_dtype=torch.bfloat16
 
98
  image = url
99
  image = load_img(image).convert("RGB")
100
  image_size = image.size
101
+ torch.set_float32_matmul_precision(["high", "highest"][0])
102
+ birefnet = AutoModelForImageSegmentation.from_pretrained(
103
+ "ZhengPeng7/BiRefNet", trust_remote_code=True
104
+ )
105
+ birefnet.to("cuda")
106
+ transform_image = transforms.Compose(
107
+ [
108
+ transforms.Resize((1024, 1024)),
109
+ transforms.ToTensor(),
110
+ transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
111
+ ]
112
+ )
113
  input_images = transform_image(image).unsqueeze(0).to("cuda")
114
  # Prediction
115
  with torch.no_grad():
 
118
  pred_pil = transforms.ToPILImage()(pred)
119
  mask = pred_pil.resize(image_size)
120
  image.putalpha(mask)
121
+ torch.set_float32_matmul_precision(["high", "highest"][1])
122
  return image
123
 
124