# convert-bf16-enc.py: extract the FLAN-T5-XXL encoder and save it as a standalone T5EncoderModel
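"""
Extract the encoder stack from a local FLAN-T5 checkpoint, rewrap it as a standalone
T5EncoderModel (optionally cast to bfloat16), and save it alongside the tokenizer so the
output directory can later be loaded with T5EncoderModel.from_pretrained(...).

Edit the paths and flags in the CONFIGURATION block below before running.
"""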
import transformers
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, T5EncoderModel, T5Config
import torch
import os
import traceback # For detailed error reporting
# -----------------------------------------------------------------------------
# 1. CONFIGURATION - USER MUST MODIFY THESE
# -----------------------------------------------------------------------------
# Path to your LOCAL full FLAN-T5 model directory (e.g., float32 version)
# This directory should contain config.json, the weight files (pytorch_model.bin or *.safetensors, possibly sharded), and the tokenizer files.
LOCAL_FULL_MODEL_DIR = "D:/ai/Models/t5_flan_encoder" # <--- CHANGE THIS
# Directory where you want to save the encoder-only model
OUTPUT_ENCODER_MODEL_DIR = "D:/ai/Models/flan-t5-xxl-encoder-only" # <--- You can change this
# --- Optional Settings ---
# Set to True to convert and save the encoder in bfloat16 precision.
# Set to False to keep the precision the model was loaded in (typically float32).
SAVE_AS_BFLOAT16 = True
# For very large models, low_cpu_mem_usage is recommended during initial loading.
LOW_CPU_MEM_USAGE_ON_LOAD = True
# -----------------------------------------------------------------------------
# SCRIPT STARTS HERE
# -----------------------------------------------------------------------------
def convert_and_save_encoder():
print("Starting encoder extraction and saving process...")
# --- Create output directory if it doesn't exist ---
if not os.path.exists(OUTPUT_ENCODER_MODEL_DIR):
try:
os.makedirs(OUTPUT_ENCODER_MODEL_DIR)
print(f"Created output directory: {OUTPUT_ENCODER_MODEL_DIR}")
except OSError as e:
print(f"Error creating directory {OUTPUT_ENCODER_MODEL_DIR}: {e}")
return
# --- 1. Load the tokenizer from your local full model directory ---
print(f"\nStep 1: Loading tokenizer from {LOCAL_FULL_MODEL_DIR}...")
try:
tokenizer = AutoTokenizer.from_pretrained(LOCAL_FULL_MODEL_DIR)
print("Tokenizer loaded successfully.")
except Exception as e:
print(f"ERROR: Could not load tokenizer from {LOCAL_FULL_MODEL_DIR}: {e}")
traceback.print_exc()
return
# --- 2. Load the full Seq2Seq model from your local directory ---
print(f"\nStep 2: Loading full Seq2Seq model from {LOCAL_FULL_MODEL_DIR}...")
try:
full_model = AutoModelForSeq2SeqLM.from_pretrained(
LOCAL_FULL_MODEL_DIR,
low_cpu_mem_usage=LOW_CPU_MEM_USAGE_ON_LOAD
)
print(f"Full model loaded. Original dtype: {full_model.dtype}")
except Exception as e:
print(f"ERROR: Could not load full model from {LOCAL_FULL_MODEL_DIR}: {e}")
traceback.print_exc()
return
# --- 3. Get the encoder (T5Stack) from the full model ---
print("\nStep 3: Extracting encoder stack from the full model...")
try:
extracted_encoder_stack = full_model.get_encoder() # This is a T5Stack instance
print("Encoder stack extracted successfully.")
except Exception as e:
print(f"ERROR: Could not extract encoder stack: {e}")
traceback.print_exc()
return
# --- 4. Prepare and Save the encoder as a T5EncoderModel ---
    # Building a fresh T5EncoderModel and transferring weights into it ensures the saved
    # checkpoint has the weight names (e.g. 'encoder.block...', 'shared.weight') and a config
    # that T5EncoderModel.from_pretrained expects.
print(f"\nStep 4: Preparing and saving encoder as T5EncoderModel to {OUTPUT_ENCODER_MODEL_DIR}...")
try:
# Get the config from the extracted encoder stack (which is a T5Config)
encoder_config = extracted_encoder_stack.config
if not isinstance(encoder_config, T5Config): # Should be T5Config
print(f"Warning: Extracted encoder config is type {type(encoder_config)}, expected T5Config.")
# Create a new T5EncoderModel instance. This will initialize its own T5Stack (at .encoder)
# and its own shared embeddings (at .shared) based on the config.
final_model_to_save = T5EncoderModel(config=encoder_config)
print(f"Created an empty T5EncoderModel. Initial dtype: {final_model_to_save.dtype}")
# Get the state_dict from the extracted T5Stack
extracted_stack_state_dict = extracted_encoder_stack.state_dict()
print(f"Extracted T5Stack state_dict contains {len(extracted_stack_state_dict)} keys.")
        # A. Load the shared embeddings into the T5EncoderModel's top-level shared embedding layer.
        # Inside the extracted T5Stack the embedding table is usually stored as 'embed_tokens.weight'
        # (the same tensor as the parent model's 'shared.weight'), so accept either key.
        embedding_key = next(
            (k for k in ('shared.weight', 'embed_tokens.weight') if k in extracted_stack_state_dict),
            None
        )
        if embedding_key is not None:
            # Build a state_dict specifically for the .shared module of T5EncoderModel
            shared_embedding_state_dict = {'weight': extracted_stack_state_dict[embedding_key]}
            final_model_to_save.shared.load_state_dict(shared_embedding_state_dict)
            print(f"Successfully transferred shared embeddings (from '{embedding_key}') to T5EncoderModel.shared.")
        else:
            print("WARNING: Neither 'shared.weight' nor 'embed_tokens.weight' was found in the extracted "
                  "encoder stack's state_dict. The T5EncoderModel's shared embeddings will remain randomly initialized.")
        # B. Prepare the state_dict for the T5Stack *within* T5EncoderModel (final_model_to_save.encoder).
        # These are weights like 'block.<i>...', 'final_layer_norm.weight' and, if present, 'embed_tokens.weight'.
        # Only 'shared.weight' is filtered out, since it is handled above via final_model_to_save.shared;
        # 'embed_tokens.weight' can stay because it points to that same shared embedding tensor.
        state_dict_for_internal_stack = {
            k: v for k, v in extracted_stack_state_dict.items() if not k.startswith('shared.')
        }
if not state_dict_for_internal_stack:
print("WARNING: No weights found for the internal encoder stack (e.g., block weights, final_layer_norm). "
"This part of the T5EncoderModel will remain randomly initialized.")
else:
print(f"Transferring {len(state_dict_for_internal_stack)} keys to T5EncoderModel.encoder (the internal T5Stack)...")
missing_keys, unexpected_keys = final_model_to_save.encoder.load_state_dict(
                state_dict_for_internal_stack, strict=False  # strict=False so mismatched keys are reported below instead of raising
)
if missing_keys:
print(f"WARNING: Missing keys when loading internal T5Stack: {missing_keys}")
if unexpected_keys:
print(f"WARNING: Unexpected keys when loading internal T5Stack: {unexpected_keys}")
if not missing_keys and not unexpected_keys:
print("Successfully transferred weights to T5EncoderModel.encoder.")
else:
print("NOTICE: There were missing or unexpected keys during internal T5Stack weight transfer. Review warnings.")
# C. Optionally, convert the fully assembled T5EncoderModel to bfloat16
if SAVE_AS_BFLOAT16:
print("Converting final T5EncoderModel to bfloat16...")
if hasattr(torch, 'bfloat16'):
final_model_to_save = final_model_to_save.to(dtype=torch.bfloat16)
print(f"Final T5EncoderModel converted. New dtype: {final_model_to_save.dtype}")
else:
print("WARNING: torch.bfloat16 not available. Skipping bfloat16 conversion. Model will be saved in current dtype.")
else:
print("Skipping bfloat16 conversion as per SAVE_AS_BFLOAT16 flag.")
# D. Save the T5EncoderModel instance
# This will save weights with correct prefixes (e.g., 'encoder.block...', 'shared.weight')
# and a config.json appropriate for T5EncoderModel.
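        # Depending on the transformers version, the weights are written as model.safetensors or
        # pytorch_model.bin (possibly sharded for a model this size) alongside config.json.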
print(f"Saving T5EncoderModel to {OUTPUT_ENCODER_MODEL_DIR}...")
final_model_to_save.save_pretrained(OUTPUT_ENCODER_MODEL_DIR)
tokenizer.save_pretrained(OUTPUT_ENCODER_MODEL_DIR) # Save the tokenizer too
print("T5EncoderModel and tokenizer saved successfully.")
except Exception as e:
print(f"ERROR during T5EncoderModel preparation or saving: {e}")
traceback.print_exc()
return
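    # NOTE (untested sketch): recent transformers versions can often load the encoder weights
    # directly from the full checkpoint in one step, e.g.
    #     T5EncoderModel.from_pretrained(LOCAL_FULL_MODEL_DIR, torch_dtype=torch.bfloat16)
    # ignoring the decoder weights. The explicit transfer above is kept for full control over
    # weight names, dtype, and the saved config.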
# --- 5. Verification (Optional but Recommended) ---
print(f"\nStep 5: Verifying the saved encoder model from {OUTPUT_ENCODER_MODEL_DIR}...")
try:
load_kwargs = {}
expected_dtype_after_load = torch.float32 # Default assumption
if SAVE_AS_BFLOAT16 and hasattr(torch, 'bfloat16'):
print("Verification: Attempting to load model explicitly as bfloat16.")
load_kwargs['torch_dtype'] = torch.bfloat16
expected_dtype_after_load = torch.bfloat16
elif SAVE_AS_BFLOAT16 and not hasattr(torch, 'bfloat16'):
print("Verification: bfloat16 was intended but not available; expecting original precision (likely float32).")
            # The model would have been saved in whatever precision it had before save_pretrained,
            # so expect that dtype rather than assuming float32.
            expected_dtype_after_load = final_model_to_save.dtype # dtype the model had when it was saved
else: # SAVE_AS_BFLOAT16 is False
print("Verification: Attempting to load model in its saved precision (expected float32 or original).")
expected_dtype_after_load = final_model_to_save.dtype # Dtype before save
# The crucial test: Loading it should not produce "weights were not initialized" warnings.
# Those warnings appear on stderr if issues occur during from_pretrained.
print(f"Loading T5EncoderModel with kwargs: {load_kwargs}")
loaded_encoder_model = T5EncoderModel.from_pretrained(
OUTPUT_ENCODER_MODEL_DIR,
**load_kwargs
)
print("Successfully loaded encoder model for verification.") # If this prints, "weights not initialized" was likely avoided
print(f"Loaded encoder model dtype: {loaded_encoder_model.dtype}")
if loaded_encoder_model.dtype == expected_dtype_after_load:
print(f"VERIFICATION SUCCESSFUL: Loaded model dtype ({loaded_encoder_model.dtype}) matches expected dtype ({expected_dtype_after_load}).")
else:
print(f"VERIFICATION WARNING: Loaded model dtype is {loaded_encoder_model.dtype}, but expected {expected_dtype_after_load}. "
"This might be okay if precision changed due to availability or specific load behavior.")
# Further test: attempt a simple inference
print("Attempting a sample inference with the loaded encoder model...")
test_input_text = "This is a test sentence for the encoder."
inputs = tokenizer(test_input_text, return_tensors="pt", padding=True, truncation=True)
# Move model and inputs to GPU if available
if torch.cuda.is_available():
print("CUDA available. Moving model and inputs to GPU.")
device = torch.device("cuda")
loaded_encoder_model.to(device)
inputs = {k: v.to(device) for k, v in inputs.items()}
else:
print("CUDA not available. Using CPU.")
device = torch.device("cpu")
loaded_encoder_model.to(device) # Ensure model is on CPU if inputs are
inputs = {k: v.to(device) for k, v in inputs.items()}
with torch.no_grad():
outputs = loaded_encoder_model(**inputs)
last_hidden_states = outputs.last_hidden_state
print(f"Shape of last hidden state from loaded model: {last_hidden_states.shape}")
print("Sample inference completed successfully.")
except Exception as e:
print(f"ERROR during verification: {e}")
traceback.print_exc()
print("\nEncoder extraction and saving process finished.")
if __name__ == "__main__":
# --- Preliminary Checks ---
    if LOCAL_FULL_MODEL_DIR == "./path/to/your/local/flan-t5-xxl":
        print("ERROR: Please change the placeholder path for LOCAL_FULL_MODEL_DIR in the script.")
    elif not os.path.isdir(LOCAL_FULL_MODEL_DIR):
        print(f"ERROR: The specified local full model directory does not exist: {LOCAL_FULL_MODEL_DIR}")
        print("Please ensure the path is correct and the full model is downloaded there.")
else:
print(f"PyTorch version: {torch.__version__}")
print(f"Transformers version: {transformers.__version__}")
if SAVE_AS_BFLOAT16:
if hasattr(torch, 'bfloat16'):
print("torch.bfloat16 is available.")
else:
print("WARNING: SAVE_AS_BFLOAT16 is True, but torch.bfloat16 is NOT available in this PyTorch version. "
"Conversion will be skipped.")
convert_and_save_encoder()
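# -----------------------------------------------------------------------------
# Example of consuming the saved encoder (illustrative sketch only, not run by
# this script): mean-pool the last hidden state into one embedding per input.
# -----------------------------------------------------------------------------
# from transformers import AutoTokenizer, T5EncoderModel
# import torch
#
# tok = AutoTokenizer.from_pretrained(OUTPUT_ENCODER_MODEL_DIR)
# enc = T5EncoderModel.from_pretrained(OUTPUT_ENCODER_MODEL_DIR, torch_dtype=torch.bfloat16).eval()
# batch = tok(["a prompt to embed", "another prompt"], return_tensors="pt", padding=True)
# with torch.no_grad():
#     hidden = enc(**batch).last_hidden_state                    # [batch, seq_len, d_model]
# mask = batch["attention_mask"].unsqueeze(-1).to(hidden.dtype)  # ignore padding tokens
# embeddings = (hidden * mask).sum(dim=1) / mask.sum(dim=1)      # [batch, d_model]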