juliozhao commited on
Commit
1be1c2a
·
verified ·
1 Parent(s): 3f02d20

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +10 -3
app.py CHANGED
@@ -12,7 +12,7 @@ from huggingface_hub import snapshot_download
12
  from visualization import visualize_bbox
13
 
14
  # == download weights ==
15
- model_dir = snapshot_download('juliozhao/DocLayout-YOLO-DocStructBench', local_dir='./models/DocLayout-YOLO-DocStructBench')
16
  # == select device ==
17
  device = 'cuda' if torch.cuda.is_available() else 'cpu'
18
 
@@ -33,7 +33,7 @@ id_to_names = {
33
  def recognize_image(input_img, conf_threshold, iou_threshold):
34
  det_res = model.predict(
35
  input_img,
36
- imgsz=1024,
37
  conf=conf_threshold,
38
  device=device,
39
  )[0]
@@ -60,7 +60,14 @@ if __name__ == "__main__":
60
  # == load model ==
61
  from doclayout_yolo import YOLOv10
62
  print(f"Using device: {device}")
63
- model = YOLOv10(os.path.join(os.path.dirname(__file__), "models", "DocLayout-YOLO-DocStructBench", "doclayout_yolo_docstructbench_imgsz1024.pt")) # load an official model
 
 
 
 
 
 
 
64
 
65
  with open("header.html", "r") as file:
66
  header = file.read()
 
12
  from visualization import visualize_bbox
13
 
14
  # == download weights ==
15
+ model_dir = snapshot_download('juliozhao/DocLayout-YOLO-DocStructBench-imgsz1280-2501', local_dir='./models/DocLayout-YOLO-DocStructBench-imgsz1280-2501')
16
  # == select device ==
17
  device = 'cuda' if torch.cuda.is_available() else 'cpu'
18
 
 
33
  def recognize_image(input_img, conf_threshold, iou_threshold):
34
  det_res = model.predict(
35
  input_img,
36
+ imgsz=1280,
37
  conf=conf_threshold,
38
  device=device,
39
  )[0]
 
60
  # == load model ==
61
  from doclayout_yolo import YOLOv10
62
  print(f"Using device: {device}")
63
+ model = YOLOv10(
64
+ os.path.join(
65
+ os.path.dirname(__file__),
66
+ "models",
67
+ "DocLayout-YOLO-DocStructBench-imgsz1280-2501",
68
+ "doclayout_yolo_docstructbench_imgsz1280_2501.pt"
69
+ )
70
+ ) # load an official model
71
 
72
  with open("header.html", "r") as file:
73
  header = file.read()