Spaces:
Running
on
Zero
Running
on
Zero
Commit
·
4a07334
1
Parent(s):
b3212f3
Fix bug on model preloadling
Browse files- model/model_manager.py +2 -2
- model/pre_download.py +8 -7
model/model_manager.py
CHANGED
|
@@ -7,7 +7,7 @@ import spaces
|
|
| 7 |
from PIL import Image
|
| 8 |
from .models import IMAGE_GENERATION_MODELS, IMAGE_EDITION_MODELS, VIDEO_GENERATION_MODELS, MUSEUM_UNSUPPORTED_MODELS, DESIRED_APPEAR_MODEL, load_pipeline
|
| 9 |
from .fetch_museum_results import draw_from_imagen_museum, draw2_from_imagen_museum, draw_from_videogen_museum, draw2_from_videogen_museum
|
| 10 |
-
from .pre_download import pre_download_all_models,
|
| 11 |
from transformers import AutoTokenizer, AutoModelForCausalLM
|
| 12 |
import torch
|
| 13 |
import re
|
|
@@ -82,7 +82,7 @@ class ModelManager:
|
|
| 82 |
self.load_guard(enable_nsfw)
|
| 83 |
self.loaded_models = {}
|
| 84 |
if do_pre_download:
|
| 85 |
-
pre_download_all_models()
|
| 86 |
if do_debug_packages:
|
| 87 |
debug_packages()
|
| 88 |
|
|
|
|
| 7 |
from PIL import Image
|
| 8 |
from .models import IMAGE_GENERATION_MODELS, IMAGE_EDITION_MODELS, VIDEO_GENERATION_MODELS, MUSEUM_UNSUPPORTED_MODELS, DESIRED_APPEAR_MODEL, load_pipeline
|
| 9 |
from .fetch_museum_results import draw_from_imagen_museum, draw2_from_imagen_museum, draw_from_videogen_museum, draw2_from_videogen_museum
|
| 10 |
+
from .pre_download import pre_download_all_models, pre_download_image_models_gen, pre_download_image_models_edit, pre_download_video_models_gen
|
| 11 |
from transformers import AutoTokenizer, AutoModelForCausalLM
|
| 12 |
import torch
|
| 13 |
import re
|
|
|
|
| 82 |
self.load_guard(enable_nsfw)
|
| 83 |
self.loaded_models = {}
|
| 84 |
if do_pre_download:
|
| 85 |
+
pre_download_all_models(include_video=False)
|
| 86 |
if do_debug_packages:
|
| 87 |
debug_packages()
|
| 88 |
|
model/pre_download.py
CHANGED
|
@@ -1,16 +1,17 @@
|
|
| 1 |
from .models import IMAGE_GENERATION_MODELS, IMAGE_EDITION_MODELS, VIDEO_GENERATION_MODELS
|
| 2 |
|
| 3 |
-
def pre_download_all_models():
|
| 4 |
"""
|
| 5 |
Pre-download all models to avoid download delay during the first user request
|
| 6 |
"""
|
| 7 |
-
imagen_dl_error =
|
| 8 |
-
imagedit_dl_error =
|
| 9 |
-
|
|
|
|
| 10 |
print("All models downloaded.")
|
| 11 |
print("Models that encountered download error:", "Image Generation:", imagen_dl_error, "Image Edition:", imagedit_dl_error, "Video Generation:", videogen_dl_error)
|
| 12 |
|
| 13 |
-
def
|
| 14 |
"""
|
| 15 |
Pre-download image models to avoid download delay during the first user request
|
| 16 |
"""
|
|
@@ -33,7 +34,7 @@ def pre_download_image_models():
|
|
| 33 |
pass
|
| 34 |
return errored_models
|
| 35 |
|
| 36 |
-
def
|
| 37 |
"""
|
| 38 |
Pre-download image models to avoid download delay during the first user request
|
| 39 |
"""
|
|
@@ -56,7 +57,7 @@ def pre_download_image_models():
|
|
| 56 |
pass
|
| 57 |
return errored_models
|
| 58 |
|
| 59 |
-
def
|
| 60 |
"""
|
| 61 |
Pre-download video models to avoid download delay during the first user request
|
| 62 |
"""
|
|
|
|
| 1 |
from .models import IMAGE_GENERATION_MODELS, IMAGE_EDITION_MODELS, VIDEO_GENERATION_MODELS
|
| 2 |
|
| 3 |
+
def pre_download_all_models(include_video=True):
|
| 4 |
"""
|
| 5 |
Pre-download all models to avoid download delay during the first user request
|
| 6 |
"""
|
| 7 |
+
imagen_dl_error = pre_download_image_models_gen()
|
| 8 |
+
imagedit_dl_error = pre_download_image_models_edit()
|
| 9 |
+
if include_video:
|
| 10 |
+
videogen_dl_error = pre_download_video_models_gen()
|
| 11 |
print("All models downloaded.")
|
| 12 |
print("Models that encountered download error:", "Image Generation:", imagen_dl_error, "Image Edition:", imagedit_dl_error, "Video Generation:", videogen_dl_error)
|
| 13 |
|
| 14 |
+
def pre_download_image_models_gen():
|
| 15 |
"""
|
| 16 |
Pre-download image models to avoid download delay during the first user request
|
| 17 |
"""
|
|
|
|
| 34 |
pass
|
| 35 |
return errored_models
|
| 36 |
|
| 37 |
+
def pre_download_image_models_edit():
|
| 38 |
"""
|
| 39 |
Pre-download image models to avoid download delay during the first user request
|
| 40 |
"""
|
|
|
|
| 57 |
pass
|
| 58 |
return errored_models
|
| 59 |
|
| 60 |
+
def pre_download_video_models_gen():
|
| 61 |
"""
|
| 62 |
Pre-download video models to avoid download delay during the first user request
|
| 63 |
"""
|