# Charm_15 / model-1-of-278.safetensors
# NOTE: the three lines below are Hugging Face repository-page residue that was
# pasted into this file; kept as comments so the module remains importable.
#   GeminiFan207's picture
#   Rename model-1-of-10.safetensors to model-1-of-278.safetensors
#   84d8acf verified
from safetensors.torch import load_file, save_file
import torch
from typing import List, Dict, Optional
import logging
from tqdm import tqdm
import os
import hashlib
from concurrent.futures import ThreadPoolExecutor, as_completed
# Configure logging: timestamped, level-tagged messages at INFO and above.
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
# Module-level logger named after this module, per the stdlib logging convention.
logger = logging.getLogger(__name__)
def calculate_checksum(file_path: str) -> str:
    """
    Compute the SHA-256 digest of a file, reading it in small chunks.

    Args:
        file_path (str): Path to the file to hash.

    Returns:
        str: Hex-encoded SHA-256 digest of the file contents.
    """
    digest = hashlib.sha256()
    with open(file_path, "rb") as stream:
        # Stream the file in 4 KiB blocks so arbitrarily large shards
        # never need to fit in memory at once.
        while block := stream.read(4096):
            digest.update(block)
    return digest.hexdigest()
def verify_checksums(model_parts: List[str], expected_checksums: List[str]) -> None:
    """
    Verify the SHA-256 checksums of model part files.

    Args:
        model_parts (list): List of model part file paths.
        expected_checksums (list): List of expected checksums, one per part,
            in the same order as ``model_parts``.

    Raises:
        ValueError: If the two lists differ in length (a silent ``zip``
            truncation would otherwise leave trailing parts unverified).
        RuntimeError: If any checksum does not match.
    """
    # Guard against the zip() truncation pitfall: with fewer checksums than
    # parts, the extra parts would silently skip verification entirely.
    if len(model_parts) != len(expected_checksums):
        raise ValueError(
            f"Expected one checksum per model part: got {len(model_parts)} "
            f"part(s) but {len(expected_checksums)} checksum(s)"
        )
    for part, expected_checksum in zip(model_parts, expected_checksums):
        actual_checksum = calculate_checksum(part)
        if actual_checksum != expected_checksum:
            raise RuntimeError(f"Checksum mismatch for {part}: expected {expected_checksum}, got {actual_checksum}")
def load_part(part: str) -> Dict[str, torch.Tensor]:
    """
    Load a single model part.

    Args:
        part (str): Path to the model part file.

    Returns:
        dict: State dictionary of the model part.
    """
    # Thin, picklable wrapper around safetensors' load_file so it can be
    # submitted to the ThreadPoolExecutor in load_charm_model by name.
    return load_file(part)
def load_charm_model(model_parts: List[str], expected_checksums: Optional[List[str]] = None) -> Dict[str, torch.Tensor]:
    """
    Load and merge multiple .safetensors model files.

    Args:
        model_parts (list): List of model part file paths
            (e.g., ["model-1-of-10.safetensors", ...]).
        expected_checksums (list, optional): List of expected SHA-256
            checksums, one per part. When provided, every part is verified
            before loading.

    Returns:
        dict: Merged model state dictionary. Keys appearing in multiple
        parts are overwritten by later parts (a warning is logged).

    Raises:
        FileNotFoundError: If any model part file is missing.
        RuntimeError: If there is an issue loading or merging the model parts.
    """
    merged_state_dict: Dict[str, torch.Tensor] = {}
    # Fail fast before doing any expensive work.
    for part in model_parts:
        if not os.path.exists(part):
            raise FileNotFoundError(f"Model part not found: {part}")
    # Verify checksums if provided.
    if expected_checksums:
        logger.info("Verifying checksums...")
        verify_checksums(model_parts, expected_checksums)
        logger.info("Checksums verified successfully.")
    # Load parts concurrently; merging happens on the main thread as each
    # future completes, so dict updates are not racy.
    try:
        logger.info("Loading and merging model parts...")
        with ThreadPoolExecutor() as executor:
            futures = {executor.submit(load_part, part): part for part in model_parts}
            for future in tqdm(as_completed(futures), total=len(futures), desc="Loading model parts"):
                part = futures[future]
                try:
                    state_dict = future.result()
                except Exception as e:
                    logger.error(f"Error loading part {part}: {e}")
                    # Chain the cause so the original traceback is preserved.
                    raise RuntimeError(f"Failed to load part: {part}") from e
                # dict.update silently overwrites; surface key collisions
                # between shards instead of hiding them.
                duplicates = merged_state_dict.keys() & state_dict.keys()
                if duplicates:
                    logger.warning(f"Overwriting {len(duplicates)} duplicate tensor key(s) while merging {part}")
                merged_state_dict.update(state_dict)  # Merge parameters
                logger.debug(f"Loaded part: {part}")
        logger.info("Model parts loaded and merged successfully.")
        return merged_state_dict
    except RuntimeError:
        # Already specific (e.g. "Failed to load part: ..."); do not re-wrap
        # it into a generic message and lose the diagnostic detail.
        raise
    except Exception as e:
        logger.error(f"Error loading or merging model parts: {e}")
        raise RuntimeError("Failed to load or merge model parts.") from e
# Example usage
if __name__ == "__main__":
    try:
        # List of model part files to merge.
        model_files = [f"model-{i}-of-10.safetensors" for i in range(1, 11)]
        # To enable integrity verification, supply one real SHA-256 hex digest
        # per part (same order as model_files). Leaving this as None skips
        # verification — the previous placeholder list had only 2 fake entries
        # for 10 files and could never pass.
        expected_checksums: Optional[List[str]] = None
        # Load and merge the model.
        charm_model = load_charm_model(model_files, expected_checksums)
        # Save the merged model as a single .safetensors file.
        output_file = "merged_model.safetensors"
        save_file(charm_model, output_file)
        logger.info(f"Merged model saved as '{output_file}'.")
    except Exception as e:
        logger.error(f"An error occurred: {e}")