Spaces:
Runtime error
Runtime error
Commit
·
2718a79
1
Parent(s):
b7546a7
Update app.py
Browse files
app.py
CHANGED
@@ -12,6 +12,7 @@ from model import DocGeoNet
|
|
12 |
from seg import U2NETP
|
13 |
import glob
|
14 |
|
|
|
15 |
warnings.filterwarnings('ignore')
|
16 |
|
17 |
class Net(nn.Module):
|
@@ -52,18 +53,15 @@ def reload_rec_model(model, path=""):
|
|
52 |
model.load_state_dict(model_dict)
|
53 |
return model
|
54 |
|
55 |
-
|
56 |
-
seg_model_path = './model_pretrained/preprocess.pth'
|
57 |
-
rec_model_path = './model_pretrained/DocGeoNet.pth'
|
58 |
-
reload_rec_model(net.DocTr, rec_model_path)
|
59 |
-
reload_seg_model(net.msk, seg_model_path)
|
60 |
-
|
61 |
-
# Compile models (assuming PyTorch 2.0)
|
62 |
-
net = torch.compile(net)
|
63 |
|
64 |
-
net
|
|
|
|
|
|
|
65 |
|
66 |
-
def rec(input_image):
|
67 |
im_ori = np.array(input_image)[:, :, :3] / 255. # read image 0-255 to 0-1
|
68 |
h, w, _ = im_ori.shape
|
69 |
im = cv2.resize(im_ori, (256, 256))
|
@@ -80,7 +78,7 @@ def rec(input_image):
|
|
80 |
bm1 = cv2.blur(bm1, (3, 3))
|
81 |
lbl = torch.from_numpy(np.stack([bm0, bm1], axis=2)).unsqueeze(0) # h * w * 2
|
82 |
out = F.grid_sample(torch.from_numpy(im_ori).permute(2, 0, 1).unsqueeze(0).float(), lbl, align_corners=True)
|
83 |
-
img_rec = ((out[0] * 255).permute(1, 2, 0).numpy())[
|
84 |
|
85 |
# Convert from BGR to RGB
|
86 |
img_rec = cv2.cvtColor(img_rec, cv2.COLOR_BGR2RGB)
|
@@ -92,3 +90,10 @@ demo_img_files = glob.glob('./distorted/*.[jJ][pP][gG]') + glob.glob('./distorte
|
|
92 |
# Gradio Interface
|
93 |
input_image = gr.inputs.Image()
|
94 |
output_image = gr.outputs.Image(type='pil')
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
12 |
from seg import U2NETP
|
13 |
import glob
|
14 |
|
15 |
+
|
16 |
warnings.filterwarnings('ignore')
|
17 |
|
18 |
class Net(nn.Module):
|
|
|
53 |
model.load_state_dict(model_dict)
|
54 |
return model
|
55 |
|
56 |
+
def rec(input_image):
|
57 |
+
seg_model_path = './model_pretrained/preprocess.pth'
|
58 |
+
rec_model_path = './model_pretrained/DocGeoNet.pth'
|
|
|
|
|
|
|
|
|
|
|
59 |
|
60 |
+
net = Net()
|
61 |
+
reload_rec_model(net.DocTr, rec_model_path)
|
62 |
+
reload_seg_model(net.msk, seg_model_path)
|
63 |
+
net.eval()
|
64 |
|
|
|
65 |
im_ori = np.array(input_image)[:, :, :3] / 255. # read image 0-255 to 0-1
|
66 |
h, w, _ = im_ori.shape
|
67 |
im = cv2.resize(im_ori, (256, 256))
|
|
|
78 |
bm1 = cv2.blur(bm1, (3, 3))
|
79 |
lbl = torch.from_numpy(np.stack([bm0, bm1], axis=2)).unsqueeze(0) # h * w * 2
|
80 |
out = F.grid_sample(torch.from_numpy(im_ori).permute(2, 0, 1).unsqueeze(0).float(), lbl, align_corners=True)
|
81 |
+
img_rec = ((out[0] * 255).permute(1, 2, 0).numpy())[:,:,::-1].astype(np.uint8)
|
82 |
|
83 |
# Convert from BGR to RGB
|
84 |
img_rec = cv2.cvtColor(img_rec, cv2.COLOR_BGR2RGB)
|
|
|
90 |
# Gradio Interface
|
91 |
input_image = gr.inputs.Image()
|
92 |
output_image = gr.outputs.Image(type='pil')
|
93 |
+
|
94 |
+
|
95 |
+
|
96 |
+
iface = gr.Interface(fn=rec, inputs=input_image, outputs=output_image, title="DocGeoNet",examples=demo_img_files)
|
97 |
+
|
98 |
+
#iface.launch(server_port=8821, server_name="0.0.0.0")
|
99 |
+
iface.launch()
|