# coding=utf-8
# Copyright 2021 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utilities to dynamically load objects from the Hub."""
import filecmp
import importlib
import os
import re
import shutil
import signal
import sys
import typing
import warnings
from pathlib import Path
from typing import Any, Dict, List, Optional, Union

from .utils import (
    HF_MODULES_CACHE,
    TRANSFORMERS_DYNAMIC_MODULE_NAME,
    cached_file,
    extract_commit_hash,
    is_offline_mode,
    logging,
    try_to_load_from_cache,
)


logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


def init_hf_modules():
    """
    Creates the cache directory for modules with an init, and adds it to the Python path.
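
    Example (a minimal sketch; it only checks the function's side effects and assumes the default cache location):

    ```python
    import sys

    from transformers.dynamic_module_utils import init_hf_modules
    from transformers.utils import HF_MODULES_CACHE

    init_hf_modules()
    # The modules cache is now importable: it is on sys.path and contains an __init__.py.
    assert HF_MODULES_CACHE in sys.path
    ```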
| """ | |
| # This function has already been executed if HF_MODULES_CACHE already is in the Python path. | |
| if HF_MODULES_CACHE in sys.path: | |
| return | |
| sys.path.append(HF_MODULES_CACHE) | |
| os.makedirs(HF_MODULES_CACHE, exist_ok=True) | |
| init_path = Path(HF_MODULES_CACHE) / "__init__.py" | |
| if not init_path.exists(): | |
| init_path.touch() | |
| importlib.invalidate_caches() | |


def create_dynamic_module(name: Union[str, os.PathLike]):
    """
    Creates a dynamic module in the cache directory for modules.

    Args:
        name (`str` or `os.PathLike`):
            The name of the dynamic module to create.
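
    Example (a minimal sketch; the module name is hypothetical):

    ```python
    from transformers.dynamic_module_utils import create_dynamic_module

    # Creates HF_MODULES_CACHE/my_org/my_model (and its parent), each with an __init__.py,
    # so that the new package can be imported.
    create_dynamic_module("my_org/my_model")
    ```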
| """ | |
| init_hf_modules() | |
| dynamic_module_path = (Path(HF_MODULES_CACHE) / name).resolve() | |
| # If the parent module does not exist yet, recursively create it. | |
| if not dynamic_module_path.parent.exists(): | |
| create_dynamic_module(dynamic_module_path.parent) | |
| os.makedirs(dynamic_module_path, exist_ok=True) | |
| init_path = dynamic_module_path / "__init__.py" | |
| if not init_path.exists(): | |
| init_path.touch() | |
| # It is extremely important to invalidate the cache when we change stuff in those modules, or users end up | |
| # with errors about module that do not exist. Same for all other `invalidate_caches` in this file. | |
| importlib.invalidate_caches() | |


def get_relative_imports(module_file: Union[str, os.PathLike]) -> List[str]:
    """
    Get the list of modules that are relatively imported in a module file.

    Args:
        module_file (`str` or `os.PathLike`): The module file to inspect.

    Returns:
        `List[str]`: The list of relative imports in the module.
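
    Example (a minimal sketch; the file content is made up for illustration):

    ```python
    import tempfile

    from transformers.dynamic_module_utils import get_relative_imports

    with tempfile.NamedTemporaryFile("w", suffix=".py", delete=False) as f:
        f.write("from .configuration_my_model import MyConfig\nimport torch\n")

    get_relative_imports(f.name)  # ["configuration_my_model"] -- `torch` is not a relative import
    ```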
| """ | |
| with open(module_file, "r", encoding="utf-8") as f: | |
| content = f.read() | |
| # Imports of the form `import .xxx` | |
| relative_imports = re.findall(r"^\s*import\s+\.(\S+)\s*$", content, flags=re.MULTILINE) | |
| # Imports of the form `from .xxx import yyy` | |
| relative_imports += re.findall(r"^\s*from\s+\.(\S+)\s+import", content, flags=re.MULTILINE) | |
| # Unique-ify | |
| return list(set(relative_imports)) | |


def get_relative_import_files(module_file: Union[str, os.PathLike]) -> List[str]:
    """
    Get the list of all files that are needed for a given module. Note that this function recurses through the
    relative imports (if a imports b and b imports c, it will return module files for b and c).

    Args:
        module_file (`str` or `os.PathLike`): The module file to inspect.

    Returns:
        `List[str]`: The list of all relative imports a given module needs (recursively), which will give us the list
        of module files a given module needs.
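
    Example (a minimal sketch; the chain a -> b -> c is fabricated to show the recursion):

    ```python
    import tempfile
    from pathlib import Path

    from transformers.dynamic_module_utils import get_relative_import_files

    tmp = Path(tempfile.mkdtemp())
    (tmp / "a.py").write_text("from .b import helper\n")
    (tmp / "b.py").write_text("from .c import base\n")
    (tmp / "c.py").write_text("base = object()\n")

    get_relative_import_files(str(tmp / "a.py"))  # [".../b.py", ".../c.py"]
    ```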
| """ | |
| no_change = False | |
| files_to_check = [module_file] | |
| all_relative_imports = [] | |
| # Let's recurse through all relative imports | |
| while not no_change: | |
| new_imports = [] | |
| for f in files_to_check: | |
| new_imports.extend(get_relative_imports(f)) | |
| module_path = Path(module_file).parent | |
| new_import_files = [str(module_path / m) for m in new_imports] | |
| new_import_files = [f for f in new_import_files if f not in all_relative_imports] | |
| files_to_check = [f"{f}.py" for f in new_import_files] | |
| no_change = len(new_import_files) == 0 | |
| all_relative_imports.extend(files_to_check) | |
| return all_relative_imports | |


def get_imports(filename: Union[str, os.PathLike]) -> List[str]:
    """
    Extracts all the libraries (as opposed to relative imports) that are imported in a file.

    Args:
        filename (`str` or `os.PathLike`): The module file to inspect.

    Returns:
        `List[str]`: The list of all packages required to use the input module.
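
    Example (a minimal sketch with fabricated file content):

    ```python
    import tempfile

    from transformers.dynamic_module_utils import get_imports

    code = "import torch\nfrom numpy.linalg import norm\nfrom .local_helper import f\n"
    with tempfile.NamedTemporaryFile("w", suffix=".py", delete=False) as f:
        f.write(code)

    sorted(get_imports(f.name))  # ["numpy", "torch"] -- only top-level, non-relative modules are kept
    ```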
| """ | |
| with open(filename, "r", encoding="utf-8") as f: | |
| content = f.read() | |
| # filter out try/except block so in custom code we can have try/except imports | |
| content = re.sub(r"\s*try\s*:\s*.*?\s*except\s*.*?:", "", content, flags=re.MULTILINE | re.DOTALL) | |
| # Imports of the form `import xxx` | |
| imports = re.findall(r"^\s*import\s+(\S+)\s*$", content, flags=re.MULTILINE) | |
| # Imports of the form `from xxx import yyy` | |
| imports += re.findall(r"^\s*from\s+(\S+)\s+import", content, flags=re.MULTILINE) | |
| # Only keep the top-level module | |
| imports = [imp.split(".")[0] for imp in imports if not imp.startswith(".")] | |
| return list(set(imports)) | |


def check_imports(filename: Union[str, os.PathLike]) -> List[str]:
    """
    Check if the current Python environment contains all the libraries that are imported in a file. Will raise if a
    library is missing.

    Args:
        filename (`str` or `os.PathLike`): The module file to check.

    Returns:
        `List[str]`: The list of relative imports in the file.
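
    Example (a minimal sketch; the package name is deliberately one that is not installed):

    ```python
    import tempfile

    from transformers.dynamic_module_utils import check_imports

    with tempfile.NamedTemporaryFile("w", suffix=".py", delete=False) as f:
        f.write("import some_package_that_is_not_installed\n")

    # Raises ImportError suggesting `pip install some_package_that_is_not_installed`.
    check_imports(f.name)
    ```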
| """ | |
| imports = get_imports(filename) | |
| missing_packages = [] | |
| for imp in imports: | |
| try: | |
| importlib.import_module(imp) | |
| except ImportError: | |
| missing_packages.append(imp) | |
| if len(missing_packages) > 0: | |
| raise ImportError( | |
| "This modeling file requires the following packages that were not found in your environment: " | |
| f"{', '.join(missing_packages)}. Run `pip install {' '.join(missing_packages)}`" | |
| ) | |
| return get_relative_imports(filename) | |


def get_class_in_module(class_name: str, module_path: Union[str, os.PathLike]) -> typing.Type:
    """
    Import a module from the cache directory for modules and extract a class from it.

    Args:
        class_name (`str`): The name of the class to import.
        module_path (`str` or `os.PathLike`): The path to the module to import.

    Returns:
        `typing.Type`: The class looked for.
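
    Example (a minimal sketch; the path is repo-specific and hypothetical, shaped like the values returned by
    `get_cached_module_file`):

    ```python
    cls = get_class_in_module("MyBertModel", "transformers_modules/sgugger/my-bert-model/modeling")
    ```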
| """ | |
| module_path = module_path.replace(os.path.sep, ".") | |
| module = importlib.import_module(module_path) | |
| return getattr(module, class_name) | |


def get_cached_module_file(
    pretrained_model_name_or_path: Union[str, os.PathLike],
    module_file: str,
    cache_dir: Optional[Union[str, os.PathLike]] = None,
    force_download: bool = False,
    resume_download: bool = False,
    proxies: Optional[Dict[str, str]] = None,
    token: Optional[Union[bool, str]] = None,
    revision: Optional[str] = None,
    local_files_only: bool = False,
    repo_type: Optional[str] = None,
    _commit_hash: Optional[str] = None,
    **deprecated_kwargs,
) -> str:
| """ | |
| Prepares Downloads a module from a local folder or a distant repo and returns its path inside the cached | |
| Transformers module. | |
| Args: | |
| pretrained_model_name_or_path (`str` or `os.PathLike`): | |
| This can be either: | |
| - a string, the *model id* of a pretrained model configuration hosted inside a model repo on | |
| huggingface.co. Valid model ids can be located at the root-level, like `bert-base-uncased`, or namespaced | |
| under a user or organization name, like `dbmdz/bert-base-german-cased`. | |
| - a path to a *directory* containing a configuration file saved using the | |
| [`~PreTrainedTokenizer.save_pretrained`] method, e.g., `./my_model_directory/`. | |
| module_file (`str`): | |
| The name of the module file containing the class to look for. | |
| cache_dir (`str` or `os.PathLike`, *optional*): | |
| Path to a directory in which a downloaded pretrained model configuration should be cached if the standard | |
| cache should not be used. | |
| force_download (`bool`, *optional*, defaults to `False`): | |
| Whether or not to force to (re-)download the configuration files and override the cached versions if they | |
| exist. | |
| resume_download (`bool`, *optional*, defaults to `False`): | |
| Whether or not to delete incompletely received file. Attempts to resume the download if such a file exists. | |
| proxies (`Dict[str, str]`, *optional*): | |
| A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128', | |
| 'http://hostname': 'foo.bar:4012'}.` The proxies are used on each request. | |
| token (`str` or *bool*, *optional*): | |
| The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated | |
| when running `huggingface-cli login` (stored in `~/.huggingface`). | |
| revision (`str`, *optional*, defaults to `"main"`): | |
| The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a | |
| git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any | |
| identifier allowed by git. | |
| local_files_only (`bool`, *optional*, defaults to `False`): | |
| If `True`, will only try to load the tokenizer configuration from local files. | |
| repo_type (`str`, *optional*): | |
| Specify the repo type (useful when downloading from a space for instance). | |
| <Tip> | |
| Passing `token=True` is required when you want to use a private model. | |
| </Tip> | |
| Returns: | |
| `str`: The path to the module inside the cache. | |
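
    Example (a minimal sketch; the repo id is hypothetical and mirrors the examples further down in this file):

    ```python
    # Download `modeling.py` from the repo, cache it in the dynamic-module cache, and return its path there.
    module_path = get_cached_module_file("sgugger/my-bert-model", "modeling.py")
    ```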
| """ | |
| use_auth_token = deprecated_kwargs.pop("use_auth_token", None) | |
| if use_auth_token is not None: | |
| warnings.warn( | |
| "The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers.", FutureWarning | |
| ) | |
| if token is not None: | |
| raise ValueError("`token` and `use_auth_token` are both specified. Please set only the argument `token`.") | |
| token = use_auth_token | |
| if is_offline_mode() and not local_files_only: | |
| logger.info("Offline mode: forcing local_files_only=True") | |
| local_files_only = True | |
| # Download and cache module_file from the repo `pretrained_model_name_or_path` of grab it if it's a local file. | |
    pretrained_model_name_or_path = str(pretrained_model_name_or_path)
    is_local = os.path.isdir(pretrained_model_name_or_path)
    if is_local:
        submodule = os.path.basename(pretrained_model_name_or_path)
    else:
        submodule = pretrained_model_name_or_path.replace("/", os.path.sep)
        cached_module = try_to_load_from_cache(
            pretrained_model_name_or_path, module_file, cache_dir=cache_dir, revision=_commit_hash, repo_type=repo_type
        )

    new_files = []
    try:
        # Load from URL or cache if already cached
        resolved_module_file = cached_file(
            pretrained_model_name_or_path,
            module_file,
            cache_dir=cache_dir,
            force_download=force_download,
            proxies=proxies,
            resume_download=resume_download,
            local_files_only=local_files_only,
            token=token,
            revision=revision,
            repo_type=repo_type,
            _commit_hash=_commit_hash,
        )
        if not is_local and cached_module != resolved_module_file:
            new_files.append(module_file)

    except EnvironmentError:
        logger.error(f"Could not locate the {module_file} inside {pretrained_model_name_or_path}.")
        raise

    # Check we have all the requirements in our environment
    modules_needed = check_imports(resolved_module_file)

    # Now we move the module inside our cached dynamic modules.
    full_submodule = TRANSFORMERS_DYNAMIC_MODULE_NAME + os.path.sep + submodule
    create_dynamic_module(full_submodule)
    submodule_path = Path(HF_MODULES_CACHE) / full_submodule
    if submodule == os.path.basename(pretrained_model_name_or_path):
        # We copy local files to avoid putting too many folders in sys.path. This copy is done when the file is new
        # or has changed since the last copy.
        if not (submodule_path / module_file).exists() or not filecmp.cmp(
            resolved_module_file, str(submodule_path / module_file)
        ):
            shutil.copy(resolved_module_file, submodule_path / module_file)
            importlib.invalidate_caches()
        for module_needed in modules_needed:
            module_needed = f"{module_needed}.py"
            module_needed_file = os.path.join(pretrained_model_name_or_path, module_needed)
            if not (submodule_path / module_needed).exists() or not filecmp.cmp(
                module_needed_file, str(submodule_path / module_needed)
            ):
                shutil.copy(module_needed_file, submodule_path / module_needed)
                importlib.invalidate_caches()
    else:
        # Get the commit hash
        commit_hash = extract_commit_hash(resolved_module_file, _commit_hash)

        # The module file will end up being placed in a subfolder with the git hash of the repo. This way we get the
        # benefit of versioning.
        submodule_path = submodule_path / commit_hash
        full_submodule = full_submodule + os.path.sep + commit_hash
        create_dynamic_module(full_submodule)

        if not (submodule_path / module_file).exists():
            shutil.copy(resolved_module_file, submodule_path / module_file)
            importlib.invalidate_caches()
        # Make sure we also have every file the module needs through relative imports
        for module_needed in modules_needed:
            if not (submodule_path / f"{module_needed}.py").exists():
                get_cached_module_file(
                    pretrained_model_name_or_path,
                    f"{module_needed}.py",
                    cache_dir=cache_dir,
                    force_download=force_download,
                    resume_download=resume_download,
                    proxies=proxies,
                    token=token,
                    revision=revision,
                    local_files_only=local_files_only,
                    _commit_hash=commit_hash,
                )
                new_files.append(f"{module_needed}.py")

    if len(new_files) > 0 and revision is None:
        new_files = "\n".join([f"- {f}" for f in new_files])
        repo_type_str = "" if repo_type is None else f"{repo_type}s/"
        url = f"https://huggingface.co/{repo_type_str}{pretrained_model_name_or_path}"
        logger.warning(
            f"A new version of the following files was downloaded from {url}:\n{new_files}"
            "\nMake sure to double-check they do not contain any added malicious code. To avoid downloading new "
            "versions of the code file, you can pin a revision."
        )

    return os.path.join(full_submodule, module_file)


def get_class_from_dynamic_module(
    class_reference: str,
    pretrained_model_name_or_path: Union[str, os.PathLike],
    cache_dir: Optional[Union[str, os.PathLike]] = None,
    force_download: bool = False,
    resume_download: bool = False,
    proxies: Optional[Dict[str, str]] = None,
    token: Optional[Union[bool, str]] = None,
    revision: Optional[str] = None,
    local_files_only: bool = False,
    repo_type: Optional[str] = None,
    code_revision: Optional[str] = None,
    **kwargs,
) -> typing.Type:
| """ | |
| Extracts a class from a module file, present in the local folder or repository of a model. | |
| <Tip warning={true}> | |
| Calling this function will execute the code in the module file found locally or downloaded from the Hub. It should | |
| therefore only be called on trusted repos. | |
| </Tip> | |
| Args: | |
| class_reference (`str`): | |
| The full name of the class to load, including its module and optionally its repo. | |
| pretrained_model_name_or_path (`str` or `os.PathLike`): | |
| This can be either: | |
| - a string, the *model id* of a pretrained model configuration hosted inside a model repo on | |
| huggingface.co. Valid model ids can be located at the root-level, like `bert-base-uncased`, or namespaced | |
| under a user or organization name, like `dbmdz/bert-base-german-cased`. | |
| - a path to a *directory* containing a configuration file saved using the | |
| [`~PreTrainedTokenizer.save_pretrained`] method, e.g., `./my_model_directory/`. | |
| This is used when `class_reference` does not specify another repo. | |
| module_file (`str`): | |
| The name of the module file containing the class to look for. | |
| class_name (`str`): | |
| The name of the class to import in the module. | |
| cache_dir (`str` or `os.PathLike`, *optional*): | |
| Path to a directory in which a downloaded pretrained model configuration should be cached if the standard | |
| cache should not be used. | |
| force_download (`bool`, *optional*, defaults to `False`): | |
| Whether or not to force to (re-)download the configuration files and override the cached versions if they | |
| exist. | |
| resume_download (`bool`, *optional*, defaults to `False`): | |
| Whether or not to delete incompletely received file. Attempts to resume the download if such a file exists. | |
| proxies (`Dict[str, str]`, *optional*): | |
| A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128', | |
| 'http://hostname': 'foo.bar:4012'}.` The proxies are used on each request. | |
| token (`str` or `bool`, *optional*): | |
| The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated | |
| when running `huggingface-cli login` (stored in `~/.huggingface`). | |
| revision (`str`, *optional*, defaults to `"main"`): | |
| The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a | |
| git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any | |
| identifier allowed by git. | |
| local_files_only (`bool`, *optional*, defaults to `False`): | |
| If `True`, will only try to load the tokenizer configuration from local files. | |
| repo_type (`str`, *optional*): | |
| Specify the repo type (useful when downloading from a space for instance). | |
| code_revision (`str`, *optional*, defaults to `"main"`): | |
| The specific revision to use for the code on the Hub, if the code leaves in a different repository than the | |
| rest of the model. It can be a branch name, a tag name, or a commit id, since we use a git-based system for | |
| storing models and other artifacts on huggingface.co, so `revision` can be any identifier allowed by git. | |
| <Tip> | |
| Passing `token=True` is required when you want to use a private model. | |
| </Tip> | |
| Returns: | |
| `typing.Type`: The class, dynamically imported from the module. | |

    Examples:

    ```python
    # Download module `modeling.py` from huggingface.co, cache it, then extract the class `MyBertModel` from this
    # module.
    cls = get_class_from_dynamic_module("modeling.MyBertModel", "sgugger/my-bert-model")

    # Download module `modeling.py` from a given repo, cache it, then extract the class `MyBertModel` from this
    # module.
    cls = get_class_from_dynamic_module("sgugger/my-bert-model--modeling.MyBertModel", "sgugger/another-bert-model")
    ```"""
    use_auth_token = kwargs.pop("use_auth_token", None)
    if use_auth_token is not None:
        warnings.warn(
            "The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers.", FutureWarning
        )
        if token is not None:
            raise ValueError("`token` and `use_auth_token` are both specified. Please set only the argument `token`.")
        token = use_auth_token

    # Catch the name of the repo if it's specified in `class_reference`
    if "--" in class_reference:
        repo_id, class_reference = class_reference.split("--")
    else:
        repo_id = pretrained_model_name_or_path
    module_file, class_name = class_reference.split(".")

    if code_revision is None and pretrained_model_name_or_path == repo_id:
        code_revision = revision
    # And lastly we get the class inside our newly created module
    final_module = get_cached_module_file(
        repo_id,
        module_file + ".py",
        cache_dir=cache_dir,
        force_download=force_download,
        resume_download=resume_download,
        proxies=proxies,
        token=token,
        revision=code_revision,
        local_files_only=local_files_only,
        repo_type=repo_type,
    )
    return get_class_in_module(class_name, final_module.replace(".py", ""))


def custom_object_save(obj: Any, folder: Union[str, os.PathLike], config: Optional[Dict] = None) -> List[str]:
    """
    Save the modeling files corresponding to a custom model/configuration/tokenizer etc. in a given folder. Optionally
    adds the proper fields in a config.

    Args:
        obj (`Any`): The object for which to save the module files.
        folder (`str` or `os.PathLike`): The folder where to save.
        config (`PretrainedConfig` or dictionary, *optional*):
            A config in which to register the auto_map corresponding to this custom object.

    Returns:
        `List[str]`: The list of files saved.
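
    Example (a minimal sketch; `MyConfig` stands for a hypothetical custom `PretrainedConfig` subclass defined in
    its own module and registered for auto classes):

    ```python
    config = MyConfig()
    custom_object_save(config, "./my_model_directory", config=config)
    # Returns the copied file paths: the module defining `MyConfig` plus its relative-import files.
    ```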
| """ | |
| if obj.__module__ == "__main__": | |
| logger.warning( | |
| f"We can't save the code defining {obj} in {folder} as it's been defined in __main__. You should put " | |
| "this code in a separate module so we can include it in the saved folder and make it easier to share via " | |
| "the Hub." | |
| ) | |
| return | |

    def _set_auto_map_in_config(_config):
        module_name = obj.__class__.__module__
        last_module = module_name.split(".")[-1]
        full_name = f"{last_module}.{obj.__class__.__name__}"
        # Special handling for tokenizers
        if "Tokenizer" in full_name:
            slow_tokenizer_class = None
            fast_tokenizer_class = None
            if obj.__class__.__name__.endswith("Fast"):
                # Fast tokenizer: we have the fast tokenizer class, and we may have the slow one as an attribute.
                fast_tokenizer_class = f"{last_module}.{obj.__class__.__name__}"
                if getattr(obj, "slow_tokenizer_class", None) is not None:
                    slow_tokenizer = getattr(obj, "slow_tokenizer_class")
                    slow_tok_module_name = slow_tokenizer.__module__
                    last_slow_tok_module = slow_tok_module_name.split(".")[-1]
                    slow_tokenizer_class = f"{last_slow_tok_module}.{slow_tokenizer.__name__}"
            else:
                # Slow tokenizer: no way to have the fast class
                slow_tokenizer_class = f"{last_module}.{obj.__class__.__name__}"
            full_name = (slow_tokenizer_class, fast_tokenizer_class)

        if isinstance(_config, dict):
            auto_map = _config.get("auto_map", {})
            auto_map[obj._auto_class] = full_name
            _config["auto_map"] = auto_map
        elif getattr(_config, "auto_map", None) is not None:
            _config.auto_map[obj._auto_class] = full_name
        else:
            _config.auto_map = {obj._auto_class: full_name}

    # Add object class to the config auto_map
    if isinstance(config, (list, tuple)):
        for cfg in config:
            _set_auto_map_in_config(cfg)
    elif config is not None:
        _set_auto_map_in_config(config)

    result = []
    # Copy module file to the output folder.
    object_file = sys.modules[obj.__module__].__file__
    dest_file = Path(folder) / (Path(object_file).name)
    shutil.copy(object_file, dest_file)
    result.append(dest_file)

    # Gather all relative imports recursively and make sure they are copied as well.
    for needed_file in get_relative_import_files(object_file):
        dest_file = Path(folder) / (Path(needed_file).name)
        shutil.copy(needed_file, dest_file)
        result.append(dest_file)

    return result


def _raise_timeout_error(signum, frame):
    raise ValueError(
        "Loading this model requires you to execute custom code contained in the model repository on your local "
        "machine. Please set the option `trust_remote_code=True` to permit loading of this model."
    )


TIME_OUT_REMOTE_CODE = 15


def resolve_trust_remote_code(trust_remote_code, model_name, has_local_code, has_remote_code):
    """
    Resolves the `trust_remote_code` flag: an explicit user choice is returned unchanged; when it is `None`, local
    code is preferred if it exists, and otherwise the user is prompted (with a timeout) before remote code is allowed
    to run.
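
    Example (a minimal sketch; the repo id is hypothetical):

    ```python
    # With `trust_remote_code=None` and local code available, remote code is not needed.
    resolve_trust_remote_code(None, "sgugger/my-bert-model", has_local_code=True, has_remote_code=True)  # False
    ```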
    """
    if trust_remote_code is None:
        if has_local_code:
            trust_remote_code = False
        elif has_remote_code and TIME_OUT_REMOTE_CODE > 0:
            try:
                signal.signal(signal.SIGALRM, _raise_timeout_error)
                signal.alarm(TIME_OUT_REMOTE_CODE)
                while trust_remote_code is None:
                    answer = input(
                        f"The repository for {model_name} contains custom code which must be executed to correctly "
                        f"load the model. You can inspect the repository content at https://hf.co/{model_name}.\n"
                        f"You can avoid this prompt in future by passing the argument `trust_remote_code=True`.\n\n"
                        f"Do you wish to run the custom code? [y/N] "
                    )
                    if answer.lower() in ["yes", "y", "1"]:
                        trust_remote_code = True
                    elif answer.lower() in ["no", "n", "0", ""]:
                        trust_remote_code = False
                signal.alarm(0)
            except Exception:
                # OS which does not support signal.SIGALRM
                raise ValueError(
                    f"The repository for {model_name} contains custom code which must be executed to correctly "
                    f"load the model. You can inspect the repository content at https://hf.co/{model_name}.\n"
                    f"Please pass the argument `trust_remote_code=True` to allow custom code to be run."
                )
        elif has_remote_code:
            # For the CI which puts the timeout at 0
            _raise_timeout_error(None, None)

    if has_remote_code and not has_local_code and not trust_remote_code:
        raise ValueError(
            f"Loading {model_name} requires you to execute the configuration file in that"
            " repo on your local machine. Make sure you have read the code there to avoid malicious use, then"
            " set the option `trust_remote_code=True` to remove this error."
        )

    return trust_remote_code