Spaces:
Running
on
Zero
Running
on
Zero
| from .models import IMAGE_GENERATION_MODELS, IMAGE_EDITION_MODELS, VIDEO_GENERATION_MODELS | |
def pre_download_all_models():
    """
    Pre-download all models to avoid download delay during the first user request.

    Runs the image generation, image edition and video generation
    pre-downloads in that order, then prints the model strings that failed
    to load in each category.
    """
    imagen_dl_error = pre_download_image_models()
    imagedit_dl_error = pre_download_image_edition_models()
    videogen_dl_error = pre_download_video_models()
    print("All models downloaded.")
    print(
        "Models that encountered download error:",
        "Image Generation:", imagen_dl_error,
        "Image Edition:", imagedit_dl_error,
        "Video Generation:", videogen_dl_error,
    )


def _pre_download_imagenhub_models(model_strings, label):
    """
    Load every "imagenhub" model listed in ``model_strings`` once.

    Each entry is split on "_" into exactly three parts
    (library, model name, model type); entries whose library part is not
    "imagenhub" are skipped.

    Args:
        model_strings: iterable of model identifier strings.
        label: human-readable category used in the progress message
            (e.g. "image generation").

    Returns:
        list of the model strings whose loading raised an exception.

    Raises:
        ValueError: if an entry does not split into exactly three parts.
        ImportError: if ``imagen_hub`` cannot be imported.
    """
    # Imported lazily so that importing this module does not require imagen_hub.
    import imagen_hub

    errored_models = []
    for model_string in model_strings:
        model_lib, model_name, _model_type = model_string.split("_")
        if model_lib != "imagenhub":
            continue
        try:
            print(f"Loading {label} model:", model_name)
            temp_model = imagen_hub.get_model(model_name)  # Forcing model to download weight files
            # Drop the reference right away; only the download side effect is wanted.
            del temp_model
        except Exception as e:
            # Best effort: one failing model must not abort the remaining downloads.
            print(f"Failed to load model {model_name} \n {e}")
            errored_models.append(model_string)
    return errored_models


def pre_download_image_models():
    """
    Pre-download image generation models to avoid download delay during the first user request.

    Returns:
        list of entries from IMAGE_GENERATION_MODELS that failed to load.
    """
    # NOTE: a second function used to be defined under this same name and
    # silently replaced this one, so the generation models were never
    # pre-downloaded. The edition variant is now
    # pre_download_image_edition_models().
    return _pre_download_imagenhub_models(IMAGE_GENERATION_MODELS, "image generation")


def pre_download_image_edition_models():
    """
    Pre-download image edition models to avoid download delay during the first user request.

    Returns:
        list of entries from IMAGE_EDITION_MODELS that failed to load.
    """
    return _pre_download_imagenhub_models(IMAGE_EDITION_MODELS, "image edition")
def pre_download_video_models():
    """
    Pre-download video models to avoid download delay during the first user request.

    Walks VIDEO_GENERATION_MODELS, splitting each entry on "_" into exactly
    three parts (library, model name, model type), and loads every entry
    whose library part is "videogenhub".

    Returns:
        list of the model strings whose loading raised an exception.
    """
    import videogen_hub

    failed = []
    for model_string in VIDEO_GENERATION_MODELS:
        model_lib, model_name, _model_type = model_string.split("_")
        # Only videogenhub-backed entries are loaded here.
        if model_lib != "videogenhub":
            continue
        try:
            print("Loading video generation model:", model_name)
            loaded = videogen_hub.get_model(model_name)  # Forcing model to download weight files
            del loaded
        except Exception as e:
            # Keep going: a single failure is recorded, not fatal.
            print(f"Failed to load model {model_name} \n {e}")
            failed.append(model_string)
    return failed