Tsumugii24 commited on
Commit
9a27fc0
·
1 Parent(s): 8af039a

add model auto downloads

Browse files
Files changed (1) hide show
  1. app.py +40 -0
app.py CHANGED
@@ -17,6 +17,7 @@ ROOT_PATH = sys.path[0] # 项目根目录
17
 
18
  fonts_list = ["SimSun.ttf", "TimesNewRoman.ttf", "malgun.ttf"] # 字体列表
19
  fonts_directory_path = Path(ROOT_PATH, "fonts")
 
20
 
21
  data_url_dict = {
22
  "SimSun.ttf": "https://raw.githubusercontent.com/Tsumugii24/Typora-images/main/files/SimSun.ttf",
@@ -24,6 +25,13 @@ data_url_dict = {
24
  "malgun.ttf": "https://raw.githubusercontent.com/Tsumugii24/Typora-images/main/files/malgun.ttf",
25
  }
26
 
 
 
 
 
 
 
 
27
 
28
  # 判断字体文件是否存在
29
  def is_fonts(fonts_dir):
@@ -42,6 +50,24 @@ def is_fonts(fonts_dir):
42
  # 本地字体库不存在,创建字体库
43
  print("[bold red]Local fonts library does not exist, creating now...[/bold red]")
44
  download_fonts(fonts_list) # 创建字体库
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
45
 
46
  # 下载字体
47
  def download_fonts(font_diff):
@@ -56,7 +82,21 @@ def download_fonts(font_diff):
56
  # 下载字体文件
57
  wget.download(v, file_path)
58
 
 
 
 
 
 
 
 
 
 
 
 
 
 
59
  is_fonts(fonts_directory_path)
 
60
 
61
  # --------------------- 字体库 ---------------------
62
  SimSun_path = f"{ROOT_PATH}/fonts/SimSun.ttf" # 宋体文件路径
 
17
 
18
  fonts_list = ["SimSun.ttf", "TimesNewRoman.ttf", "malgun.ttf"] # 字体列表
19
  fonts_directory_path = Path(ROOT_PATH, "fonts")
20
+ models_directory_path = Path(ROOT_PATH) # 模型存放在项目的根目录
21
 
22
  data_url_dict = {
23
  "SimSun.ttf": "https://raw.githubusercontent.com/Tsumugii24/Typora-images/main/files/SimSun.ttf",
 
25
  "malgun.ttf": "https://raw.githubusercontent.com/Tsumugii24/Typora-images/main/files/malgun.ttf",
26
  }
27
 
28
+ model_url_dict = {
29
+ "cnn_se.pt": "https://huggingface.co/Tsumugii/lesion-cells-det/raw/main/cnn_se.pt",
30
+ "detr_based.pt": "https://huggingface.co/Tsumugii/lesion-cells-det/raw/main/detr_based.pt",
31
+ "vit_based.pt": "https://huggingface.co/Tsumugii/lesion-cells-det/raw/main/vit_based.pt",
32
+ "yolov5_based.pt": "https://huggingface.co/Tsumugii/lesion-cells-det/raw/main/yolov5_based.pt",
33
+ "yolov8_based.pt": "https://huggingface.co/Tsumugii/lesion-cells-det/raw/main/yolov8_based.pt",
34
+ }
35
 
36
  # 判断字体文件是否存在
37
  def is_fonts(fonts_dir):
 
50
  # 本地字体库不存在,创建字体库
51
  print("[bold red]Local fonts library does not exist, creating now...[/bold red]")
52
  download_fonts(fonts_list) # 创建字体库
53
+
54
+ # 判断模型文件是否存在
55
+ def is_models(models_dir):
56
+ if models_dir.is_dir():
57
+ # 如果本地模型库存在
58
+ local_list = os.listdir(models_dir) # 本地模型库
59
+
60
+ model_diff = list(set(model_url_dict.keys()).difference(set(local_list)))
61
+
62
+ if model_diff != []:
63
+ # 缺失模型
64
+ download_models(model_diff) # 下载缺失的模型
65
+ else:
66
+ print(f"{model_url_dict.keys()}[bold green]Required models already downloaded![/bold green]")
67
+ else:
68
+ # 本地模型库不存在,创建模型库
69
+ print("[bold red]Local models library does not exist, creating now...[/bold red]")
70
+ download_models(model_url_dict.keys()) # 创建模型库
71
 
72
  # 下载字体
73
  def download_fonts(font_diff):
 
82
  # 下载字体文件
83
  wget.download(v, file_path)
84
 
85
+ # 下载模型
86
+ def download_models(model_diff):
87
+ global model_name
88
+
89
+ for k in model_diff:
90
+ v = model_url_dict[k]
91
+ model_name = v.split("/")[-1] # 模型名称
92
+
93
+ file_path = f"{ROOT_PATH}/{model_name}" # 模型路径
94
+ # 下载模型文件
95
+ wget.download(v, file_path)
96
+
97
+
98
  is_fonts(fonts_directory_path)
99
+ is_models(models_directory_path)
100
 
101
  # --------------------- 字体库 ---------------------
102
  SimSun_path = f"{ROOT_PATH}/fonts/SimSun.ttf" # 宋体文件路径