training / flax /run_speculative_decoding.py
GalaktischeGurke's picture
Saving train state of step 5000
4008bf9 verified
#!/usr/bin/env python3
# make sure to use branch: https://github.com/huggingface/transformers/pull/26701
import copy
import time
import torch
from datasets import load_dataset
from transformers import (
AutoProcessor,
WhisperForConditionalGeneration,
)
DEVICE = "cuda"
DTYPE = torch.float16
SAMPLING_RATE = 16_000
BATCH_SIZE = 1
USE_FLASH_ATTN_2 = True
# TO DEBUG
GAMMAS = [5, 7, 6, 5, 4, 3, 5]
COUNT = 0
# local loading is faster
teacher = WhisperForConditionalGeneration.from_pretrained(
"/home/patrick/distil_whisper/",
torch_dtype=DTYPE,
variant="fp16",
low_cpu_mem_usage=True,
use_flash_attention_2=USE_FLASH_ATTN_2,
)
student = WhisperForConditionalGeneration.from_pretrained(
"/home/patrick/distil_whisper_student/",
torch_dtype=DTYPE,
variant="fp16",
low_cpu_mem_usage=True,
use_flash_attention_2=USE_FLASH_ATTN_2,
)
# student = WhisperForCausalLM.from_pretrained("/home/patrick/distil_whisper_student", torch_dtype=DTYPE, variant="fp16", low_cpu_mem_usage=True, use_flash_attention_2=USE_FLASH_ATTN_2)
student.generation_config = copy.deepcopy(teacher.generation_config)
student.generation_config.num_assistant_tokens_schedule = "constant"
# teacher = WhisperForConditionalGeneration.from_pretrained(
# "openai/whisper-large-v2", torch_dtype=DTYPE, variant="fp16", low_cpu_mem_usage=True
# )
# student = WhisperForConditionalGeneration.from_pretrained(
# "sanchit-gandhi/large-32-2-gpu-flat-lr", torch_dtype=DTYPE, variant="fp16", low_cpu_mem_usage=True
# )
#
teacher.to(DEVICE)
student.to(DEVICE)
processor = AutoProcessor.from_pretrained("sanchit-gandhi/large-32-2-gpu-flat-lr")
ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
total_time_default = 0
total_time_spec = 0
total_time_spec_2 = 0
input_values = ds[0]["audio"]["array"]
inputs = processor(input_values, return_tensors="pt", sampling_rate=SAMPLING_RATE)
input_features = inputs.input_features.to(device=DEVICE, dtype=DTYPE)
_ = teacher.generate(input_features, max_length=100)
end_idx = ds.shape[0]
for audio_idx in range(0, end_idx, BATCH_SIZE):
input_values = ds[audio_idx : audio_idx + BATCH_SIZE]
input_values = [i["array"] for i in input_values["audio"]]
inputs = processor(input_values, return_tensors="pt", sampling_rate=SAMPLING_RATE)
input_features = inputs.input_features.to(device=DEVICE, dtype=DTYPE)
start_time = time.time()
out = teacher.generate(input_features, max_length=100)
run_time = time.time() - start_time
print(f"Normal Decoding: {run_time}")
total_time_default += run_time
default_out = processor.batch_decode(out, skip_special_tokens=True)
# print("Output", default_out)
# start_time = time.time()
# with torch.no_grad():
# encoder_outputs = teacher.get_encoder()(input_features).last_hidden_state
# out, ratio = speculative_decoding(teacher, student, encoder_outputs, max_length=100, gamma=5)
# run_time = time.time() - start_time
# print(20 * "=")
# print(f"Speculative Decoding: {run_time}")
# total_time_spec += run_time
# spec_out = processor.batch_decode(out)
start_time = time.time()
with torch.no_grad():
encoder_outputs = teacher.get_encoder()(input_features)
out = teacher.generate(
assistant_model=student,
assistant_encoder_outputs=encoder_outputs,
encoder_outputs=encoder_outputs,
max_length=100,
)
run_time = time.time() - start_time
spec_out_2 = processor.batch_decode(out, skip_special_tokens=True)
print(f"Speculative Decoding 2: {run_time}")
total_time_spec_2 += run_time
if spec_out_2 != default_out:
COUNT += 1
print(f"Audio {audio_idx} does not match. Spec: {spec_out_2}, True: {default_out}")
print(20 * "=")
print("Total time", total_time_default)
print(f"Overall speed-up spec 2 {total_time_default / total_time_spec_2}")
# print(f"Overall speed-up {total_time_default / total_time_spec}")