import transformers
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, T5EncoderModel, T5Config
import torch
import os
import traceback  # For detailed error reporting

# -----------------------------------------------------------------------------
# 1. CONFIGURATION - USER MUST MODIFY THESE
# -----------------------------------------------------------------------------
# Path to your LOCAL full FLAN-T5 model directory (e.g., the float32 version).
# This directory should contain config.json, pytorch_model.bin or .safetensors
# files, tokenizer files, etc.
LOCAL_FULL_MODEL_DIR = "D:/ai/Models/t5_flan_encoder"  # <--- CHANGE THIS

# Directory where you want to save the encoder-only model
OUTPUT_ENCODER_MODEL_DIR = "D:/ai/Models/flan-t5-xxl-encoder-only"  # <--- You can change this

# --- Optional Settings ---
# Set to True to convert and save the encoder in bfloat16 precision.
# Set to False to save in the precision it was loaded in (float32 by default if not specified).
SAVE_AS_BFLOAT16 = True

# For very large models, low_cpu_mem_usage is recommended during initial loading.
LOW_CPU_MEM_USAGE_ON_LOAD = True

# -----------------------------------------------------------------------------
# SCRIPT STARTS HERE
# -----------------------------------------------------------------------------


def convert_and_save_encoder():
    print("Starting encoder extraction and saving process...")

    # --- Create output directory if it doesn't exist ---
    if not os.path.exists(OUTPUT_ENCODER_MODEL_DIR):
        try:
            os.makedirs(OUTPUT_ENCODER_MODEL_DIR)
            print(f"Created output directory: {OUTPUT_ENCODER_MODEL_DIR}")
        except OSError as e:
            print(f"Error creating directory {OUTPUT_ENCODER_MODEL_DIR}: {e}")
            return

    # --- 1. Load the tokenizer from your local full model directory ---
    print(f"\nStep 1: Loading tokenizer from {LOCAL_FULL_MODEL_DIR}...")
    try:
        tokenizer = AutoTokenizer.from_pretrained(LOCAL_FULL_MODEL_DIR)
        print("Tokenizer loaded successfully.")
    except Exception as e:
        print(f"ERROR: Could not load tokenizer from {LOCAL_FULL_MODEL_DIR}: {e}")
        traceback.print_exc()
        return

    # --- 2. Load the full Seq2Seq model from your local directory ---
    print(f"\nStep 2: Loading full Seq2Seq model from {LOCAL_FULL_MODEL_DIR}...")
    try:
        full_model = AutoModelForSeq2SeqLM.from_pretrained(
            LOCAL_FULL_MODEL_DIR,
            low_cpu_mem_usage=LOW_CPU_MEM_USAGE_ON_LOAD,
        )
        print(f"Full model loaded. Original dtype: {full_model.dtype}")
    except Exception as e:
        print(f"ERROR: Could not load full model from {LOCAL_FULL_MODEL_DIR}: {e}")
        traceback.print_exc()
        return

    # --- 3. Get the encoder (T5Stack) from the full model ---
    print("\nStep 3: Extracting encoder stack from the full model...")
    try:
        extracted_encoder_stack = full_model.get_encoder()  # This is a T5Stack instance
        print("Encoder stack extracted successfully.")
    except Exception as e:
        print(f"ERROR: Could not extract encoder stack: {e}")
        traceback.print_exc()
        return

    # --- 4. Prepare and save the encoder as a T5EncoderModel ---
    # This is the corrected approach to ensure proper weight names and config for T5EncoderModel.
    print(f"\nStep 4: Preparing and saving encoder as T5EncoderModel to {OUTPUT_ENCODER_MODEL_DIR}...")
    try:
        # Get the config from the extracted encoder stack (which is a T5Config)
        encoder_config = extracted_encoder_stack.config
        if not isinstance(encoder_config, T5Config):  # Should be T5Config
            print(f"Warning: Extracted encoder config is type {type(encoder_config)}, expected T5Config.")
        # Create a new T5EncoderModel instance. This will initialize its own T5Stack
        # (at .encoder) and its own shared embeddings (at .shared) based on the config.
        final_model_to_save = T5EncoderModel(config=encoder_config)
        print(f"Created an empty T5EncoderModel. Initial dtype: {final_model_to_save.dtype}")

        # Get the state_dict from the extracted T5Stack
        extracted_stack_state_dict = extracted_encoder_stack.state_dict()
        print(f"Extracted T5Stack state_dict contains {len(extracted_stack_state_dict)} keys.")

        # A. Load the shared embeddings into the T5EncoderModel's top-level shared embedding layer.
        # The extracted T5Stack usually exposes the shared embedding as 'embed_tokens.weight'
        # (it is the same tensor as the full model's 'shared.weight'), so accept either key.
        shared_weight = extracted_stack_state_dict.get(
            'shared.weight', extracted_stack_state_dict.get('embed_tokens.weight')
        )
        if shared_weight is not None:
            # Build a state_dict specifically for the .shared module of T5EncoderModel
            final_model_to_save.shared.load_state_dict({'weight': shared_weight})
            print("Successfully transferred shared embeddings to T5EncoderModel.shared.")
        else:
            print("WARNING: Neither 'shared.weight' nor 'embed_tokens.weight' was found in the extracted "
                  "encoder stack's state_dict. The T5EncoderModel's shared embeddings will remain "
                  "randomly initialized.")

        # B. Prepare the state_dict for the T5Stack *within* T5EncoderModel (final_model_to_save.encoder).
        # These are weights like 'block.i...' and 'final_layer_norm.weight'.
        # We filter out 'shared.weight' as it is handled by final_model_to_save.shared;
        # 'embed_tokens.weight' (if present) simply reloads the tied embedding inside the stack.
        state_dict_for_internal_stack = {
            k: v for k, v in extracted_stack_state_dict.items()
            if not k.startswith('shared.')
        }

        if not state_dict_for_internal_stack:
            print("WARNING: No weights found for the internal encoder stack (e.g., block weights, final_layer_norm). "
                  "This part of the T5EncoderModel will remain randomly initialized.")
        else:
            print(f"Transferring {len(state_dict_for_internal_stack)} keys to T5EncoderModel.encoder (the internal T5Stack)...")
            missing_keys, unexpected_keys = final_model_to_save.encoder.load_state_dict(
                state_dict_for_internal_stack,
                strict=False,  # Use strict=False for now so issues surface as warnings
            )
            if missing_keys:
                print(f"WARNING: Missing keys when loading internal T5Stack: {missing_keys}")
            if unexpected_keys:
                print(f"WARNING: Unexpected keys when loading internal T5Stack: {unexpected_keys}")
            if not missing_keys and not unexpected_keys:
                print("Successfully transferred weights to T5EncoderModel.encoder.")
            else:
                print("NOTICE: There were missing or unexpected keys during internal T5Stack weight transfer. Review warnings.")

        # C. Optionally, convert the fully assembled T5EncoderModel to bfloat16
        if SAVE_AS_BFLOAT16:
            print("Converting final T5EncoderModel to bfloat16...")
            if hasattr(torch, 'bfloat16'):
                final_model_to_save = final_model_to_save.to(dtype=torch.bfloat16)
                print(f"Final T5EncoderModel converted. New dtype: {final_model_to_save.dtype}")
            else:
                print("WARNING: torch.bfloat16 not available. Skipping bfloat16 conversion. Model will be saved in current dtype.")
        else:
            print("Skipping bfloat16 conversion as per SAVE_AS_BFLOAT16 flag.")

        # D. Save the T5EncoderModel instance.
        # This will save weights with correct prefixes (e.g., 'encoder.block...', 'shared.weight')
        # and a config.json appropriate for T5EncoderModel.
        print(f"Saving T5EncoderModel to {OUTPUT_ENCODER_MODEL_DIR}...")
        final_model_to_save.save_pretrained(OUTPUT_ENCODER_MODEL_DIR)
        tokenizer.save_pretrained(OUTPUT_ENCODER_MODEL_DIR)  # Save the tokenizer too
        print("T5EncoderModel and tokenizer saved successfully.")

    except Exception as e:
        print(f"ERROR during T5EncoderModel preparation or saving: {e}")
        traceback.print_exc()
        return
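    # Optional: once the weights have been copied and saved, the full Seq2Seq model
    # and the raw encoder stack are no longer needed. For a checkpoint as large as
    # FLAN-T5-XXL, releasing them here keeps peak RAM lower while the verification
    # step below reloads the saved encoder. Uncomment if memory is tight.
    # import gc
    # del full_model, extracted_encoder_stack
    # gc.collect()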
    # --- 5. Verification (Optional but Recommended) ---
    print(f"\nStep 5: Verifying the saved encoder model from {OUTPUT_ENCODER_MODEL_DIR}...")
    try:
        load_kwargs = {}
        expected_dtype_after_load = torch.float32  # Default assumption
        if SAVE_AS_BFLOAT16 and hasattr(torch, 'bfloat16'):
            print("Verification: Attempting to load model explicitly as bfloat16.")
            load_kwargs['torch_dtype'] = torch.bfloat16
            expected_dtype_after_load = torch.bfloat16
        elif SAVE_AS_BFLOAT16 and not hasattr(torch, 'bfloat16'):
            print("Verification: bfloat16 was intended but not available; expecting original precision (likely float32).")
            # The model would have been saved in its original precision in this case.
            expected_dtype_after_load = final_model_to_save.dtype  # Dtype before save if bfloat16 was unavailable
        else:  # SAVE_AS_BFLOAT16 is False
            print("Verification: Attempting to load model in its saved precision (expected float32 or original).")
            expected_dtype_after_load = final_model_to_save.dtype  # Dtype before save

        # The crucial test: loading should not produce "weights were not initialized" warnings.
        # Those warnings appear on stderr if issues occur during from_pretrained.
        print(f"Loading T5EncoderModel with kwargs: {load_kwargs}")
        loaded_encoder_model = T5EncoderModel.from_pretrained(
            OUTPUT_ENCODER_MODEL_DIR,
            **load_kwargs,
        )
        print("Successfully loaded encoder model for verification.")  # If this prints, "weights not initialized" was likely avoided
        print(f"Loaded encoder model dtype: {loaded_encoder_model.dtype}")

        if loaded_encoder_model.dtype == expected_dtype_after_load:
            print(f"VERIFICATION SUCCESSFUL: Loaded model dtype ({loaded_encoder_model.dtype}) matches expected dtype ({expected_dtype_after_load}).")
        else:
            print(f"VERIFICATION WARNING: Loaded model dtype is {loaded_encoder_model.dtype}, but expected {expected_dtype_after_load}. "
                  "This might be okay if precision changed due to availability or specific load behavior.")

        # Further test: attempt a simple inference
        print("Attempting a sample inference with the loaded encoder model...")
        test_input_text = "This is a test sentence for the encoder."
        inputs = tokenizer(test_input_text, return_tensors="pt", padding=True, truncation=True)

        # Move model and inputs to GPU if available
        if torch.cuda.is_available():
            print("CUDA available. Moving model and inputs to GPU.")
            device = torch.device("cuda")
            loaded_encoder_model.to(device)
            inputs = {k: v.to(device) for k, v in inputs.items()}
Using CPU.") device = torch.device("cpu") loaded_encoder_model.to(device) # Ensure model is on CPU if inputs are inputs = {k: v.to(device) for k, v in inputs.items()} with torch.no_grad(): outputs = loaded_encoder_model(**inputs) last_hidden_states = outputs.last_hidden_state print(f"Shape of last hidden state from loaded model: {last_hidden_states.shape}") print("Sample inference completed successfully.") except Exception as e: print(f"ERROR during verification: {e}") traceback.print_exc() print("\nEncoder extraction and saving process finished.") if __name__ == "__main__": # --- Preliminary Checks --- if not os.path.isdir(LOCAL_FULL_MODEL_DIR): print(f"ERROR: The specified local full model directory does not exist: {LOCAL_FULL_MODEL_DIR}") print("Please ensure the path is correct and the full model is downloaded there.") elif LOCAL_FULL_MODEL_DIR == "./path/to/your/local/flan-t5-xxl": print("ERROR: Please change the placeholder path for LOCAL_FULL_MODEL_DIR in the script.") else: print(f"PyTorch version: {torch.__version__}") print(f"Transformers version: {transformers.__version__}") if SAVE_AS_BFLOAT16: if hasattr(torch, 'bfloat16'): print("torch.bfloat16 is available.") else: print("WARNING: SAVE_AS_BFLOAT16 is True, but torch.bfloat16 is NOT available in this PyTorch version. " "Conversion will be skipped.") convert_and_save_encoder()