Spaces: Running on Zero
| from __future__ import annotations | |
| import os | |
| import base64 | |
| import json | |
| import time | |
| import logging | |
| import folder_paths | |
| import glob | |
| import comfy.utils | |
| from aiohttp import web | |
| from PIL import Image | |
| from io import BytesIO | |
| from folder_paths import map_legacy, filter_files_extensions, filter_files_content_types | |
| class ModelFileManager: | |
| def __init__(self) -> None: | |
| self.cache: dict[str, tuple[list[dict], dict[str, float], float]] = {} | |
| def get_cache(self, key: str, default=None) -> tuple[list[dict], dict[str, float], float] | None: | |
| return self.cache.get(key, default) | |
| def set_cache(self, key: str, value: tuple[list[dict], dict[str, float], float]): | |
| self.cache[key] = value | |
| def clear_cache(self): | |
| self.cache.clear() | |
| def add_routes(self, routes): | |
| # NOTE: This is an experiment to replace `/models` | |
| async def get_model_folders(request): | |
| model_types = list(folder_paths.folder_names_and_paths.keys()) | |
| folder_black_list = ["configs", "custom_nodes"] | |
| output_folders: list[dict] = [] | |
| for folder in model_types: | |
| if folder in folder_black_list: | |
| continue | |
| output_folders.append({"name": folder, "folders": folder_paths.get_folder_paths(folder)}) | |
| return web.json_response(output_folders) | |
| # NOTE: This is an experiment to replace `/models/{folder}` | |
| async def get_all_models(request): | |
| folder = request.match_info.get("folder", None) | |
| if not folder in folder_paths.folder_names_and_paths: | |
| return web.Response(status=404) | |
| files = self.get_model_file_list(folder) | |
| return web.json_response(files) | |
| async def get_model_preview(request): | |
| folder_name = request.match_info.get("folder", None) | |
| path_index = int(request.match_info.get("path_index", None)) | |
| filename = request.match_info.get("filename", None) | |
| if not folder_name in folder_paths.folder_names_and_paths: | |
| return web.Response(status=404) | |
| folders = folder_paths.folder_names_and_paths[folder_name] | |
| folder = folders[0][path_index] | |
| full_filename = os.path.join(folder, filename) | |
| previews = self.get_model_previews(full_filename) | |
| default_preview = previews[0] if len(previews) > 0 else None | |
| if default_preview is None or (isinstance(default_preview, str) and not os.path.isfile(default_preview)): | |
| return web.Response(status=404) | |
| try: | |
| with Image.open(default_preview) as img: | |
| img_bytes = BytesIO() | |
| img.save(img_bytes, format="WEBP") | |
| img_bytes.seek(0) | |
| return web.Response(body=img_bytes.getvalue(), content_type="image/webp") | |
| except: | |
| return web.Response(status=404) | |
| def get_model_file_list(self, folder_name: str): | |
| folder_name = map_legacy(folder_name) | |
| folders = folder_paths.folder_names_and_paths[folder_name] | |
| output_list: list[dict] = [] | |
| for index, folder in enumerate(folders[0]): | |
| if not os.path.isdir(folder): | |
| continue | |
| out = self.cache_model_file_list_(folder) | |
| if out is None: | |
| out = self.recursive_search_models_(folder, index) | |
| self.set_cache(folder, out) | |
| output_list.extend(out[0]) | |
| return output_list | |
| def cache_model_file_list_(self, folder: str): | |
| model_file_list_cache = self.get_cache(folder) | |
| if model_file_list_cache is None: | |
| return None | |
| if not os.path.isdir(folder): | |
| return None | |
| if os.path.getmtime(folder) != model_file_list_cache[1]: | |
| return None | |
| for x in model_file_list_cache[1]: | |
| time_modified = model_file_list_cache[1][x] | |
| folder = x | |
| if os.path.getmtime(folder) != time_modified: | |
| return None | |
| return model_file_list_cache | |
| def recursive_search_models_(self, directory: str, pathIndex: int) -> tuple[list[str], dict[str, float], float]: | |
| if not os.path.isdir(directory): | |
| return [], {}, time.perf_counter() | |
| excluded_dir_names = [".git"] | |
| # TODO use settings | |
| include_hidden_files = False | |
| result: list[str] = [] | |
| dirs: dict[str, float] = {} | |
| for dirpath, subdirs, filenames in os.walk(directory, followlinks=True, topdown=True): | |
| subdirs[:] = [d for d in subdirs if d not in excluded_dir_names] | |
| if not include_hidden_files: | |
| subdirs[:] = [d for d in subdirs if not d.startswith(".")] | |
| filenames = [f for f in filenames if not f.startswith(".")] | |
| filenames = filter_files_extensions(filenames, folder_paths.supported_pt_extensions) | |
| for file_name in filenames: | |
| try: | |
| relative_path = os.path.relpath(os.path.join(dirpath, file_name), directory) | |
| result.append(relative_path) | |
| except: | |
| logging.warning(f"Warning: Unable to access {file_name}. Skipping this file.") | |
| continue | |
| for d in subdirs: | |
| path: str = os.path.join(dirpath, d) | |
| try: | |
| dirs[path] = os.path.getmtime(path) | |
| except FileNotFoundError: | |
| logging.warning(f"Warning: Unable to access {path}. Skipping this path.") | |
| continue | |
| return [{"name": f, "pathIndex": pathIndex} for f in result], dirs, time.perf_counter() | |
| def get_model_previews(self, filepath: str) -> list[str | BytesIO]: | |
| dirname = os.path.dirname(filepath) | |
| if not os.path.exists(dirname): | |
| return [] | |
| basename = os.path.splitext(filepath)[0] | |
| match_files = glob.glob(f"{basename}.*", recursive=False) | |
| image_files = filter_files_content_types(match_files, "image") | |
| safetensors_file = next(filter(lambda x: x.endswith(".safetensors"), match_files), None) | |
| safetensors_metadata = {} | |
| result: list[str | BytesIO] = [] | |
| for filename in image_files: | |
| _basename = os.path.splitext(filename)[0] | |
| if _basename == basename: | |
| result.append(filename) | |
| if _basename == f"{basename}.preview": | |
| result.append(filename) | |
| if safetensors_file: | |
| safetensors_filepath = os.path.join(dirname, safetensors_file) | |
| header = comfy.utils.safetensors_header(safetensors_filepath, max_size=8*1024*1024) | |
| if header: | |
| safetensors_metadata = json.loads(header) | |
| safetensors_images = safetensors_metadata.get("__metadata__", {}).get("ssmd_cover_images", None) | |
| if safetensors_images: | |
| safetensors_images = json.loads(safetensors_images) | |
| for image in safetensors_images: | |
| result.append(BytesIO(base64.b64decode(image))) | |
| return result | |
| def __exit__(self, exc_type, exc_value, traceback): | |
| self.clear_cache() | |