"""Load, verify, and merge sharded .safetensors checkpoints into a single state dict."""

import hashlib
import logging
import os
from concurrent.futures import ThreadPoolExecutor, as_completed
from typing import Dict, List, Optional

import torch
from safetensors.torch import load_file, save_file
from tqdm import tqdm

logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
logger = logging.getLogger(__name__)

def calculate_checksum(file_path: str) -> str:
    """
    Calculate the SHA-256 checksum of a file.

    Args:
        file_path (str): Path to the file.

    Returns:
        str: SHA-256 checksum of the file.
    """
    sha256 = hashlib.sha256()
    with open(file_path, "rb") as f:
        # Hash in 4 KiB chunks so large shard files are never read into memory at once.
        for chunk in iter(lambda: f.read(4096), b""):
            sha256.update(chunk)
    return sha256.hexdigest()

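# Hedged sketch, not part of the original script: one way to produce the
# `expected_checksums` list that the loader below consumes, by hashing every
# shard with `calculate_checksum`. The name `compute_expected_checksums` is
# illustrative only.
def compute_expected_checksums(model_parts: List[str]) -> List[str]:
    """Return the SHA-256 digest of each shard, in the same order as `model_parts`."""
    return [calculate_checksum(part) for part in model_parts]
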
def verify_checksums(model_parts: List[str], expected_checksums: List[str]) -> None:
    """
    Verify the checksums of model part files.

    Args:
        model_parts (list): List of model part file paths.
        expected_checksums (list): List of expected checksums for each part.

    Raises:
        RuntimeError: If the list lengths differ or any checksum does not match.
    """
    # Guard against silent truncation: zip() would otherwise skip any parts
    # beyond the length of the shorter list, leaving them unverified.
    if len(model_parts) != len(expected_checksums):
        raise RuntimeError(
            f"Expected {len(model_parts)} checksums, got {len(expected_checksums)}."
        )
    for part, expected_checksum in zip(model_parts, expected_checksums):
        actual_checksum = calculate_checksum(part)
        if actual_checksum != expected_checksum:
            raise RuntimeError(f"Checksum mismatch for {part}: expected {expected_checksum}, got {actual_checksum}")

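# Hedged sketch, not part of the original interface: hashing each shard is
# independent work, so verification can reuse the same ThreadPoolExecutor /
# as_completed pattern the loader uses below. `verify_checksums_parallel` is
# an illustrative alternative to `verify_checksums`.
def verify_checksums_parallel(model_parts: List[str], expected_checksums: List[str]) -> None:
    """Like `verify_checksums`, but hashes the shards concurrently."""
    if len(model_parts) != len(expected_checksums):
        raise RuntimeError(
            f"Expected {len(model_parts)} checksums, got {len(expected_checksums)}."
        )
    expected = dict(zip(model_parts, expected_checksums))
    with ThreadPoolExecutor() as executor:
        futures = {executor.submit(calculate_checksum, part): part for part in model_parts}
        for future in as_completed(futures):
            part = futures[future]
            actual = future.result()
            if actual != expected[part]:
                raise RuntimeError(f"Checksum mismatch for {part}: expected {expected[part]}, got {actual}")
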
def load_part(part: str) -> Dict[str, torch.Tensor]:
    """
    Load a single model part.

    Args:
        part (str): Path to the model part file.

    Returns:
        dict: State dictionary of the model part.
    """
    return load_file(part)

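# Hedged sketch, not part of the original script: `dict.update` in
# `load_charm_model` silently overwrites any tensor name that appears in more
# than one shard. This optional audit uses `safetensors.safe_open`, which reads
# only the file headers, to surface such collisions before merging. The helper
# name `find_duplicate_keys` is illustrative only.
def find_duplicate_keys(model_parts: List[str]) -> Dict[str, List[str]]:
    """Map each tensor name found in more than one shard to the shards that contain it."""
    from safetensors import safe_open  # local import to keep the sketch self-contained

    seen: Dict[str, List[str]] = {}
    for part in model_parts:
        with safe_open(part, framework="pt") as f:
            for key in f.keys():
                seen.setdefault(key, []).append(part)
    return {key: parts for key, parts in seen.items() if len(parts) > 1}
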
def load_charm_model(model_parts: List[str], expected_checksums: Optional[List[str]] = None) -> Dict[str, torch.Tensor]:
    """
    Load and merge multiple .safetensors model files.

    Args:
        model_parts (list): List of model part file paths (e.g., ["model-1-of-10.safetensors", ...]).
        expected_checksums (list, optional): List of expected checksums for each part.

    Returns:
        dict: Merged model state dictionary.

    Raises:
        FileNotFoundError: If any model part file is missing.
        RuntimeError: If there is an issue loading or merging the model parts.
    """
    merged_state_dict = {}

    # Fail fast if any shard is missing before doing any expensive work.
    for part in model_parts:
        if not os.path.exists(part):
            raise FileNotFoundError(f"Model part not found: {part}")

    if expected_checksums:
        logger.info("Verifying checksums...")
        verify_checksums(model_parts, expected_checksums)
        logger.info("Checksums verified successfully.")

    try:
        logger.info("Loading and merging model parts...")
        # Load shards concurrently; each future resolves to one shard's state dict.
        with ThreadPoolExecutor() as executor:
            futures = {executor.submit(load_part, part): part for part in model_parts}
            for future in tqdm(as_completed(futures), total=len(futures), desc="Loading model parts"):
                part = futures[future]
                try:
                    state_dict = future.result()
                    # Note: tensor names that repeat across shards are overwritten silently.
                    merged_state_dict.update(state_dict)
                    logger.debug(f"Loaded part: {part}")
                except Exception as e:
                    logger.error(f"Error loading part {part}: {e}")
                    raise RuntimeError(f"Failed to load part: {part}") from e

        logger.info("Model parts loaded and merged successfully.")
        return merged_state_dict
    except Exception as e:
        logger.error(f"Error loading or merging model parts: {e}")
        raise RuntimeError("Failed to load or merge model parts.") from e

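# Hedged sketch, not part of the original script: a typical way to consume the
# merged checkpoint once it has been saved, by loading it (non-strictly) into an
# already-instantiated torch.nn.Module. `apply_merged_state_dict` is an
# illustrative helper, not an established API.
def apply_merged_state_dict(model: torch.nn.Module, merged_model_path: str) -> torch.nn.Module:
    """Load a merged .safetensors checkpoint into `model` and report any key mismatches."""
    state_dict = load_file(merged_model_path)
    missing, unexpected = model.load_state_dict(state_dict, strict=False)
    if missing or unexpected:
        logger.warning(f"Missing keys: {missing}; unexpected keys: {unexpected}")
    return model
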
if __name__ == "__main__":
    try:
        model_files = [f"model-{i}-of-10.safetensors" for i in range(1, 11)]

        # Placeholder values: replace each entry with the real SHA-256 digest of
        # the corresponding shard, in the same order as model_files.
        expected_checksums = [f"checksum_for_model-{i}-of-10.safetensors" for i in range(1, 11)]

        charm_model = load_charm_model(model_files, expected_checksums)

        output_file = "merged_model.safetensors"
        save_file(charm_model, output_file)
        logger.info(f"Merged model saved as '{output_file}'.")
    except Exception as e:
        logger.error(f"An error occurred: {e}")