Spaces:
Runtime error
Runtime error
Commit
·
b7546a7
1
Parent(s):
b79aac9
Update app.py
Browse files
app.py
CHANGED
|
@@ -12,7 +12,6 @@ from model import DocGeoNet
|
|
| 12 |
from seg import U2NETP
|
| 13 |
import glob
|
| 14 |
|
| 15 |
-
|
| 16 |
warnings.filterwarnings('ignore')
|
| 17 |
|
| 18 |
class Net(nn.Module):
|
|
@@ -53,15 +52,18 @@ def reload_rec_model(model, path=""):
|
|
| 53 |
model.load_state_dict(model_dict)
|
| 54 |
return model
|
| 55 |
|
| 56 |
-
|
| 57 |
-
|
| 58 |
-
|
|
|
|
|
|
|
| 59 |
|
| 60 |
-
|
| 61 |
-
|
| 62 |
-
reload_seg_model(net.msk, seg_model_path)
|
| 63 |
-
net.eval()
|
| 64 |
|
|
|
|
|
|
|
|
|
|
| 65 |
im_ori = np.array(input_image)[:, :, :3] / 255. # read image 0-255 to 0-1
|
| 66 |
h, w, _ = im_ori.shape
|
| 67 |
im = cv2.resize(im_ori, (256, 256))
|
|
@@ -78,7 +80,7 @@ def rec(input_image):
|
|
| 78 |
bm1 = cv2.blur(bm1, (3, 3))
|
| 79 |
lbl = torch.from_numpy(np.stack([bm0, bm1], axis=2)).unsqueeze(0) # h * w * 2
|
| 80 |
out = F.grid_sample(torch.from_numpy(im_ori).permute(2, 0, 1).unsqueeze(0).float(), lbl, align_corners=True)
|
| 81 |
-
img_rec = ((out[0] * 255).permute(1, 2, 0).numpy())[
|
| 82 |
|
| 83 |
# Convert from BGR to RGB
|
| 84 |
img_rec = cv2.cvtColor(img_rec, cv2.COLOR_BGR2RGB)
|
|
@@ -90,10 +92,3 @@ demo_img_files = glob.glob('./distorted/*.[jJ][pP][gG]') + glob.glob('./distorte
|
|
| 90 |
# Gradio Interface
|
| 91 |
input_image = gr.inputs.Image()
|
| 92 |
output_image = gr.outputs.Image(type='pil')
|
| 93 |
-
|
| 94 |
-
|
| 95 |
-
|
| 96 |
-
iface = gr.Interface(fn=rec, inputs=input_image, outputs=output_image, title="DocGeoNet",examples=demo_img_files)
|
| 97 |
-
|
| 98 |
-
#iface.launch(server_port=8821, server_name="0.0.0.0")
|
| 99 |
-
iface.launch()
|
|
|
|
| 12 |
from seg import U2NETP
|
| 13 |
import glob
|
| 14 |
|
|
|
|
| 15 |
warnings.filterwarnings('ignore')
|
| 16 |
|
| 17 |
class Net(nn.Module):
|
|
|
|
| 52 |
model.load_state_dict(model_dict)
|
| 53 |
return model
|
| 54 |
|
| 55 |
+
net = Net()
|
| 56 |
+
seg_model_path = './model_pretrained/preprocess.pth'
|
| 57 |
+
rec_model_path = './model_pretrained/DocGeoNet.pth'
|
| 58 |
+
reload_rec_model(net.DocTr, rec_model_path)
|
| 59 |
+
reload_seg_model(net.msk, seg_model_path)
|
| 60 |
|
| 61 |
+
# Compile models (assuming PyTorch 2.0)
|
| 62 |
+
net = torch.compile(net)
|
|
|
|
|
|
|
| 63 |
|
| 64 |
+
net.eval()
|
| 65 |
+
|
| 66 |
+
def rec(input_image):
|
| 67 |
im_ori = np.array(input_image)[:, :, :3] / 255. # read image 0-255 to 0-1
|
| 68 |
h, w, _ = im_ori.shape
|
| 69 |
im = cv2.resize(im_ori, (256, 256))
|
|
|
|
| 80 |
bm1 = cv2.blur(bm1, (3, 3))
|
| 81 |
lbl = torch.from_numpy(np.stack([bm0, bm1], axis=2)).unsqueeze(0) # h * w * 2
|
| 82 |
out = F.grid_sample(torch.from_numpy(im_ori).permute(2, 0, 1).unsqueeze(0).float(), lbl, align_corners=True)
|
| 83 |
+
img_rec = ((out[0] * 255).permute(1, 2, 0).numpy())[:, :, ::-1].astype(np.uint8)
|
| 84 |
|
| 85 |
# Convert from BGR to RGB
|
| 86 |
img_rec = cv2.cvtColor(img_rec, cv2.COLOR_BGR2RGB)
|
|
|
|
| 92 |
# Gradio Interface
|
| 93 |
input_image = gr.inputs.Image()
|
| 94 |
output_image = gr.outputs.Image(type='pil')
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|