import os
import traceback

import torch
import transformers
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, T5Config, T5EncoderModel

# Directory holding the full Flan-T5 Seq2Seq checkpoint to convert.
LOCAL_FULL_MODEL_DIR = "D:/ai/Models/t5_flan_encoder"

# Directory where the extracted encoder-only model and tokenizer are written.
OUTPUT_ENCODER_MODEL_DIR = "D:/ai/Models/flan-t5-xxl-encoder-only"

# Cast the extracted encoder to bfloat16 before saving.
SAVE_AS_BFLOAT16 = True

# Passed as low_cpu_mem_usage to from_pretrained to reduce peak RAM while loading.
LOW_CPU_MEM_USAGE_ON_LOAD = True


def convert_and_save_encoder():
    print("Starting encoder extraction and saving process...")
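
    # Rough sizing note (estimate, not measured): flan-t5-xxl has ~11B
    # parameters, so the float32 checkpoint alone needs on the order of 44 GB
    # of CPU RAM, plus the temporary T5EncoderModel shell built in Step 4.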

    if not os.path.exists(OUTPUT_ENCODER_MODEL_DIR):
        try:
            os.makedirs(OUTPUT_ENCODER_MODEL_DIR)
            print(f"Created output directory: {OUTPUT_ENCODER_MODEL_DIR}")
        except OSError as e:
            print(f"Error creating directory {OUTPUT_ENCODER_MODEL_DIR}: {e}")
            return

    print(f"\nStep 1: Loading tokenizer from {LOCAL_FULL_MODEL_DIR}...")
    try:
        tokenizer = AutoTokenizer.from_pretrained(LOCAL_FULL_MODEL_DIR)
        print("Tokenizer loaded successfully.")
    except Exception as e:
        print(f"ERROR: Could not load tokenizer from {LOCAL_FULL_MODEL_DIR}: {e}")
        traceback.print_exc()
        return
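
    # The tokenizer is unaffected by dropping the decoder; it is re-saved next
    # to the encoder in Step 4 so the output folder is self-contained.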

    print(f"\nStep 2: Loading full Seq2Seq model from {LOCAL_FULL_MODEL_DIR}...")
    try:
        # When LOW_CPU_MEM_USAGE_ON_LOAD is True, weights are streamed in
        # shard by shard instead of materializing a randomly initialized
        # model first, which lowers peak RAM during loading.
        full_model = AutoModelForSeq2SeqLM.from_pretrained(
            LOCAL_FULL_MODEL_DIR,
            low_cpu_mem_usage=LOW_CPU_MEM_USAGE_ON_LOAD
        )
        print(f"Full model loaded. Original dtype: {full_model.dtype}")
    except Exception as e:
        print(f"ERROR: Could not load full model from {LOCAL_FULL_MODEL_DIR}: {e}")
        traceback.print_exc()
        return
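
    # From here on everything is CPU-side state_dict surgery; the GPU is only
    # touched in Step 5's sanity check, and only if CUDA is available.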

    print("\nStep 3: Extracting encoder stack from the full model...")
    try:
        extracted_encoder_stack = full_model.get_encoder()
        print("Encoder stack extracted successfully.")
    except Exception as e:
        print(f"ERROR: Could not extract encoder stack: {e}")
        traceback.print_exc()
        return
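
    # Note: get_encoder() returns the model's internal T5Stack. The stack is
    # not meant to ship as a standalone checkpoint, so Step 4 copies its
    # weights into a proper T5EncoderModel, which round-trips cleanly through
    # save_pretrained()/from_pretrained().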

    print(f"\nStep 4: Preparing and saving encoder as T5EncoderModel to {OUTPUT_ENCODER_MODEL_DIR}...")
    try:
        encoder_config = extracted_encoder_stack.config
        if not isinstance(encoder_config, T5Config):
            print(f"Warning: Extracted encoder config is type {type(encoder_config)}, expected T5Config.")

        final_model_to_save = T5EncoderModel(config=encoder_config)
        print(f"Created an empty T5EncoderModel. Initial dtype: {final_model_to_save.dtype}")

        extracted_stack_state_dict = extracted_encoder_stack.state_dict()
        print(f"Extracted T5Stack state_dict contains {len(extracted_stack_state_dict)} keys.")

        # T5Stack.state_dict() exposes the shared embedding table under
        # 'embed_tokens.weight' (a T5Stack has no 'shared.' prefix), so that
        # is the key to look for when transferring the embeddings.
        if 'embed_tokens.weight' in extracted_stack_state_dict:
            shared_embedding_state_dict = {'weight': extracted_stack_state_dict['embed_tokens.weight']}
            final_model_to_save.shared.load_state_dict(shared_embedding_state_dict)
            print("Successfully transferred shared embeddings to T5EncoderModel.shared.")
        else:
            print("WARNING: 'embed_tokens.weight' not found in the extracted encoder stack's state_dict. "
                  "The T5EncoderModel's shared embeddings will remain randomly initialized.")
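
        # In T5EncoderModel, `shared` is handed to the internal T5Stack as its
        # embed_tokens, so the two are one tied module: populating `shared`
        # above also populates `encoder.embed_tokens`.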

        # Everything except a hypothetical 'shared.' prefix goes to the
        # internal T5Stack; its own keys are block.*, final_layer_norm.*, and
        # the tied embed_tokens.weight.
        state_dict_for_internal_stack = {
            k: v for k, v in extracted_stack_state_dict.items() if not k.startswith('shared.')
        }

        if not state_dict_for_internal_stack:
            print("WARNING: No weights found for the internal encoder stack (e.g., block weights, final_layer_norm). "
                  "This part of the T5EncoderModel will remain randomly initialized.")
        else:
            print(f"Transferring {len(state_dict_for_internal_stack)} keys to T5EncoderModel.encoder (the internal T5Stack)...")
            missing_keys, unexpected_keys = final_model_to_save.encoder.load_state_dict(
                state_dict_for_internal_stack, strict=False
            )
            if missing_keys:
                print(f"WARNING: Missing keys when loading internal T5Stack: {missing_keys}")
            if unexpected_keys:
                print(f"WARNING: Unexpected keys when loading internal T5Stack: {unexpected_keys}")
            if not missing_keys and not unexpected_keys:
                print("Successfully transferred weights to T5EncoderModel.encoder.")
            else:
                print("NOTICE: There were missing or unexpected keys during internal T5Stack weight transfer. Review warnings.")
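
        # bfloat16 keeps float32's exponent range in half the bytes and is the
        # precision the T5 family was trained in, so the cast below roughly
        # halves disk size with little quality impact.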

        if SAVE_AS_BFLOAT16:
            print("Converting final T5EncoderModel to bfloat16...")
            if hasattr(torch, 'bfloat16'):
                final_model_to_save = final_model_to_save.to(dtype=torch.bfloat16)
                print(f"Final T5EncoderModel converted. New dtype: {final_model_to_save.dtype}")
            else:
                print("WARNING: torch.bfloat16 not available. Skipping bfloat16 conversion. Model will be saved in current dtype.")
        else:
            print("Skipping bfloat16 conversion as per SAVE_AS_BFLOAT16 flag.")

        print(f"Saving T5EncoderModel to {OUTPUT_ENCODER_MODEL_DIR}...")
        final_model_to_save.save_pretrained(OUTPUT_ENCODER_MODEL_DIR)
        tokenizer.save_pretrained(OUTPUT_ENCODER_MODEL_DIR)
        print("T5EncoderModel and tokenizer saved successfully.")
    except Exception as e:
        print(f"ERROR during T5EncoderModel preparation or saving: {e}")
        traceback.print_exc()
        return
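
    # The output directory now holds config.json, the weight file(s), and the
    # tokenizer files, so it loads like any other Hugging Face checkpoint.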

    print(f"\nStep 5: Verifying the saved encoder model from {OUTPUT_ENCODER_MODEL_DIR}...")
    try:
        load_kwargs = {}
        expected_dtype_after_load = torch.float32

        if SAVE_AS_BFLOAT16 and hasattr(torch, 'bfloat16'):
            print("Verification: Attempting to load model explicitly as bfloat16.")
            load_kwargs['torch_dtype'] = torch.bfloat16
            expected_dtype_after_load = torch.bfloat16
        elif SAVE_AS_BFLOAT16 and not hasattr(torch, 'bfloat16'):
            print("Verification: bfloat16 was intended but not available; expecting original precision (likely float32).")
            expected_dtype_after_load = final_model_to_save.dtype
        else:
            print("Verification: Attempting to load model in its saved precision (expected float32 or original).")
            expected_dtype_after_load = final_model_to_save.dtype
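
        # Note: without an explicit torch_dtype, from_pretrained loads the
        # weights in float32 (the Transformers default), which is why the
        # bfloat16 case sets load_kwargs['torch_dtype'] above.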

        print(f"Loading T5EncoderModel with kwargs: {load_kwargs}")
        loaded_encoder_model = T5EncoderModel.from_pretrained(
            OUTPUT_ENCODER_MODEL_DIR,
            **load_kwargs
        )
        print("Successfully loaded encoder model for verification.")
        print(f"Loaded encoder model dtype: {loaded_encoder_model.dtype}")

        if loaded_encoder_model.dtype == expected_dtype_after_load:
            print(f"VERIFICATION SUCCESSFUL: Loaded model dtype ({loaded_encoder_model.dtype}) matches expected dtype ({expected_dtype_after_load}).")
        else:
            print(f"VERIFICATION WARNING: Loaded model dtype is {loaded_encoder_model.dtype}, but expected {expected_dtype_after_load}. "
                  "This might be okay if precision changed due to availability or specific load behavior.")

        print("Attempting a sample inference with the loaded encoder model...")
        test_input_text = "This is a test sentence for the encoder."
        inputs = tokenizer(test_input_text, return_tensors="pt", padding=True, truncation=True)

        if torch.cuda.is_available():
            print("CUDA available. Moving model and inputs to GPU.")
            device = torch.device("cuda")
        else:
            print("CUDA not available. Using CPU.")
            device = torch.device("cpu")
        loaded_encoder_model.to(device)
        inputs = {k: v.to(device) for k, v in inputs.items()}

        with torch.no_grad():
            outputs = loaded_encoder_model(**inputs)
        last_hidden_states = outputs.last_hidden_state
        print(f"Shape of last hidden state from loaded model: {last_hidden_states.shape}")
        print("Sample inference completed successfully.")
    except Exception as e:
        print(f"ERROR during verification: {e}")
        traceback.print_exc()

    print("\nEncoder extraction and saving process finished.")


if __name__ == "__main__":
    # Check the placeholder first so a forgotten edit of LOCAL_FULL_MODEL_DIR
    # gets the specific message rather than the generic missing-directory one.
    if LOCAL_FULL_MODEL_DIR == "./path/to/your/local/flan-t5-xxl":
        print("ERROR: Please change the placeholder path for LOCAL_FULL_MODEL_DIR in the script.")
    elif not os.path.isdir(LOCAL_FULL_MODEL_DIR):
        print(f"ERROR: The specified local full model directory does not exist: {LOCAL_FULL_MODEL_DIR}")
        print("Please ensure the path is correct and the full model is downloaded there.")
    else:
        print(f"PyTorch version: {torch.__version__}")
        print(f"Transformers version: {transformers.__version__}")
        if SAVE_AS_BFLOAT16:
            if hasattr(torch, 'bfloat16'):
                print("torch.bfloat16 is available.")
            else:
                print("WARNING: SAVE_AS_BFLOAT16 is True, but torch.bfloat16 is NOT available in this PyTorch version. "
                      "Conversion will be skipped.")
        convert_and_save_encoder()