How to load the model?
#2
by
jvhoffbauer
- opened
Loading the model with
reward_model = AutoModelForSequenceClassification.from_pretrained(
"trl-lib/llama-7b-se-rm-peft",
num_labels=1,
torch_dtype=torch.bfloat16
)
tokenizer = AutoTokenizer.from_pretrained("trl-lib/llama-7b-se-rm-peft")
yields the following error:
---------------------------------------------------------------------------
HTTPError Traceback (most recent call last)
File /workspaces/reddit_qa/venv/lib/python3.10/site-packages/huggingface_hub/utils/_errors.py:259, in hf_raise_for_status(response, endpoint_name)
258 try:
--> 259 response.raise_for_status()
260 except HTTPError as e:
File /workspaces/reddit_qa/venv/lib/python3.10/site-packages/requests/models.py:1021, in Response.raise_for_status(self)
1020 if http_error_msg:
-> 1021 raise HTTPError(http_error_msg, response=self)
HTTPError: 404 Client Error: Not Found for url: https://huggingface.co/trl-lib/llama-7b-se-rm-peft/resolve/main/config.json
The above exception was the direct cause of the following exception:
EntryNotFoundError Traceback (most recent call last)
File /workspaces/reddit_qa/venv/lib/python3.10/site-packages/transformers/utils/hub.py:427, in cached_file(path_or_repo_id, filename, cache_dir, force_download, resume_download, proxies, token, revision, local_files_only, subfolder, repo_type, user_agent, _raise_exceptions_for_missing_entries, _raise_exceptions_for_connection_errors, _commit_hash, **deprecated_kwargs)
425 try:
426 # Load from URL or cache if already cached
--> 427 resolved_file = hf_hub_download(
428 path_or_repo_id,
429 filename,
430 subfolder=None if len(subfolder) == 0 else subfolder,
431 repo_type=repo_type,
432 revision=revision,
433 cache_dir=cache_dir,
434 user_agent=user_agent,
435 force_download=force_download,
436 proxies=proxies,
437 resume_download=resume_download,
438 token=token,
439 local_files_only=local_files_only,
440 )
441 except GatedRepoError as e:
File /workspaces/reddit_qa/venv/lib/python3.10/site-packages/huggingface_hub/utils/_validators.py:120, in validate_hf_hub_args.._inner_fn(*args, **kwargs)
118 kwargs = smoothly_deprecate_use_auth_token(fn_name=fn.__name__, has_token=has_token, kwargs=kwargs)
--> 120 return fn(*args, **kwargs)
File /workspaces/reddit_qa/venv/lib/python3.10/site-packages/huggingface_hub/file_download.py:1195, in hf_hub_download(repo_id, filename, subfolder, repo_type, revision, library_name, library_version, cache_dir, local_dir, local_dir_use_symlinks, user_agent, force_download, force_filename, proxies, etag_timeout, resume_download, token, local_files_only, legacy_cache_layout)
1194 try:
-> 1195 metadata = get_hf_file_metadata(
1196 url=url,
1197 token=token,
1198 proxies=proxies,
1199 timeout=etag_timeout,
1200 )
1201 except EntryNotFoundError as http_error:
1202 # Cache the non-existence of the file and raise
File /workspaces/reddit_qa/venv/lib/python3.10/site-packages/huggingface_hub/utils/_validators.py:120, in validate_hf_hub_args.._inner_fn(*args, **kwargs)
118 kwargs = smoothly_deprecate_use_auth_token(fn_name=fn.__name__, has_token=has_token, kwargs=kwargs)
--> 120 return fn(*args, **kwargs)
File /workspaces/reddit_qa/venv/lib/python3.10/site-packages/huggingface_hub/file_download.py:1541, in get_hf_file_metadata(url, token, proxies, timeout)
1532 r = _request_wrapper(
1533 method="HEAD",
1534 url=url,
(...)
1539 timeout=timeout,
1540 )
-> 1541 hf_raise_for_status(r)
1543 # Return
File /workspaces/reddit_qa/venv/lib/python3.10/site-packages/huggingface_hub/utils/_errors.py:269, in hf_raise_for_status(response, endpoint_name)
268 message = f"{response.status_code} Client Error." + "\n\n" + f"Entry Not Found for url: {response.url}."
--> 269 raise EntryNotFoundError(message, response) from e
271 elif error_code == "GatedRepo":
EntryNotFoundError: 404 Client Error. (Request ID: Root=1-64e60e7f-642a805a39a142330e405e81)
Entry Not Found for url: https://huggingface.co/trl-lib/llama-7b-se-rm-peft/resolve/main/config.json.
The above exception was the direct cause of the following exception:
OSError Traceback (most recent call last)
Cell In[10], line 1
----> 1 reward_model = AutoModelForSequenceClassification.from_pretrained(
2 "trl-lib/llama-7b-se-rm-peft",
3 num_labels=1,
4 torch_dtype=torch.bfloat16
5 )
7 tokenizer = AutoTokenizer.from_pretrained("trl-lib/llama-7b-se-rm-peft")
File /workspaces/reddit_qa/venv/lib/python3.10/site-packages/transformers/models/auto/auto_factory.py:479, in _BaseAutoModelClass.from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs)
476 if kwargs.get("torch_dtype", None) == "auto":
477 _ = kwargs.pop("torch_dtype")
--> 479 config, kwargs = AutoConfig.from_pretrained(
480 pretrained_model_name_or_path,
481 return_unused_kwargs=True,
482 trust_remote_code=trust_remote_code,
483 **hub_kwargs,
484 **kwargs,
485 )
487 # if torch_dtype=auto was passed here, ensure to pass it on
488 if kwargs_orig.get("torch_dtype", None) == "auto":
File /workspaces/reddit_qa/venv/lib/python3.10/site-packages/transformers/models/auto/configuration_auto.py:1004, in AutoConfig.from_pretrained(cls, pretrained_model_name_or_path, **kwargs)
1002 kwargs["name_or_path"] = pretrained_model_name_or_path
1003 trust_remote_code = kwargs.pop("trust_remote_code", None)
-> 1004 config_dict, unused_kwargs = PretrainedConfig.get_config_dict(pretrained_model_name_or_path, **kwargs)
1005 has_remote_code = "auto_map" in config_dict and "AutoConfig" in config_dict["auto_map"]
1006 has_local_code = "model_type" in config_dict and config_dict["model_type"] in CONFIG_MAPPING
File /workspaces/reddit_qa/venv/lib/python3.10/site-packages/transformers/configuration_utils.py:620, in PretrainedConfig.get_config_dict(cls, pretrained_model_name_or_path, **kwargs)
618 original_kwargs = copy.deepcopy(kwargs)
619 # Get config dict associated with the base config file
--> 620 config_dict, kwargs = cls._get_config_dict(pretrained_model_name_or_path, **kwargs)
621 if "_commit_hash" in config_dict:
622 original_kwargs["_commit_hash"] = config_dict["_commit_hash"]
File /workspaces/reddit_qa/venv/lib/python3.10/site-packages/transformers/configuration_utils.py:675, in PretrainedConfig._get_config_dict(cls, pretrained_model_name_or_path, **kwargs)
671 configuration_file = kwargs.pop("_configuration_file", CONFIG_NAME)
673 try:
674 # Load from local folder or from cache or download from model Hub and cache
--> 675 resolved_config_file = cached_file(
676 pretrained_model_name_or_path,
677 configuration_file,
678 cache_dir=cache_dir,
679 force_download=force_download,
680 proxies=proxies,
681 resume_download=resume_download,
682 local_files_only=local_files_only,
683 token=token,
684 user_agent=user_agent,
685 revision=revision,
686 subfolder=subfolder,
687 _commit_hash=commit_hash,
688 )
689 commit_hash = extract_commit_hash(resolved_config_file, commit_hash)
690 except EnvironmentError:
691 # Raise any environment error raise by `cached_file`. It will have a helpful error message adapted to
692 # the original exception.
File /workspaces/reddit_qa/venv/lib/python3.10/site-packages/transformers/utils/hub.py:478, in cached_file(path_or_repo_id, filename, cache_dir, force_download, resume_download, proxies, token, revision, local_files_only, subfolder, repo_type, user_agent, _raise_exceptions_for_missing_entries, _raise_exceptions_for_connection_errors, _commit_hash, **deprecated_kwargs)
476 if revision is None:
477 revision = "main"
--> 478 raise EnvironmentError(
479 f"{path_or_repo_id} does not appear to have a file named {full_filename}. Checkout "
480 f"'https://huggingface.co/{path_or_repo_id}/{revision}' for available files."
481 ) from e
482 except HTTPError as err:
483 # First we try to see if we have a cached version (not up to date):
484 resolved_file = try_to_load_from_cache(path_or_repo_id, full_filename, cache_dir=cache_dir, revision=revision)
OSError: trl-lib/llama-7b-se-rm-peft does not appear to have a file named config.json. Checkout 'https://huggingface.co/trl-lib/llama-7b-se-rm-peft/main' for available files.
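This repository is an adapter checkpoint: it contains only the LoRA layer parameters, not a full model. That is why there is no `config.json` for `AutoModelForSequenceClassification.from_pretrained` to load. You first need to merge these PEFT adapter layers into the base model, using: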
python examples/stack_llama/scripts/merge_peft_adapter.py --adapter_model_name=XXX --base_model_name=YYY --output_name=ZZZ
You can find `merge_peft_adapter.py` under `examples/stack_llama/scripts/` in the TRL repository. Good luck!
Thanks a lot!
jvhoffbauer
changed discussion status to closed