"""Run forward/backward benchmarks for one case and optionally plot them."""
import argparse
import glob
import importlib
import itertools
import os

import torch
from common.bench_framework import (make_bwd_benchmark_for_case,
                                    make_bwd_benchmark_plot_for_case,
                                    make_fwd_benchmark_for_case,
                                    make_fwd_benchmark_plot_for_case)
from common.diff_engine import DiffCase, calculate_diff


def make_title_tag() -> str:
    """Return a ``[<device> | torch <version>]`` tag for plot titles.

    The device part is the name of CUDA device 0 when CUDA is available,
    otherwise the literal ``"CPU"``.
    """
    if torch.cuda.is_available():
        dev_name = torch.cuda.get_device_name(0)
    else:
        dev_name = "CPU"
    torch_ver = torch.__version__
    return f"[{dev_name} | torch {torch_ver}]"


def plot_result(r_path: str) -> None:
    """Render ``<r_path>.csv`` as a grouped bar chart in ``<r_path>.png``.

    The CSV must provide a ``config`` column (used for the x axis) and
    ``Naive`` / ``Cuda`` columns (one bar each per config).

    Args:
        r_path: Path of the result file without extension.
    """
    # Imported lazily so that runs without --plot do not need these packages.
    import matplotlib.pyplot as plt
    import pandas as pd

    df = pd.read_csv(r_path + ".csv")

    fig = plt.figure(figsize=(12, 6))
    try:
        ax = df.plot(x="config",
                     y=["Naive", "Cuda"],
                     kind="bar",
                     ax=fig.gca())
        ax.set_title("Speedup over torch (higher is better)\n" +
                     make_title_tag(),
                     fontsize=14,
                     fontweight="bold")
        ax.set_ylabel("Relative Speedup", fontsize=14)
        ax.set_xlabel("")
        plt.xticks(rotation=45,
                   fontsize=12,
                   ha="right",
                   rotation_mode="anchor")

        # Annotate every bar with its value, e.g. "x1.37".
        for container in ax.containers:
            labels = [f"x{v.get_height():.2f}" for v in container]
            ax.bar_label(container,
                         labels=labels,
                         label_type="edge",
                         fontsize=10)

        fig.tight_layout()
        fig.savefig(r_path + ".png", bbox_inches="tight")
    finally:
        # pyplot keeps every figure alive until it is closed explicitly;
        # this function is called once per benchmark direction.
        plt.close(fig)


def _remove_artifacts(directory: str, *patterns: str) -> None:
    """Delete every file in ``directory`` matching one of the glob patterns."""
    for pattern in patterns:
        for path in glob.glob(os.path.join(directory, pattern)):
            os.remove(path)


def main() -> None:
    """Parse CLI arguments, run the diff check, then the benchmarks.

    With ``--plot`` the forward and backward benchmarks are run over a small
    set of configs and each result is rendered to a PNG; afterwards the
    ``*.html`` and ``*.csv`` files in the output directory are removed.
    Without ``--plot`` a larger sweep is run and the ``*.html`` and ``*.png``
    files are removed instead.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--case",
                    choices=["rms", "add_rms", "poly", "mul_poly"],
                    required=True)
    ap.add_argument("--plot", action="store_true")
    ap.add_argument(
        "--save-path",
        type=str,
        default="./configs/",
        help="Path to save benchmark results",
    )
    args = ap.parse_args()

    torch.set_default_device("cuda")
    mod = importlib.import_module(f"cases.{args.case}")
    case: DiffCase = mod.CASE

    calculate_diff(
        case,
        batch_size=2,
        seq_len=128,
        hidden_size=4096,
    )

    save_dir = os.path.join(args.save_path, args.case)
    # The "poly" cases ("poly", "mul_poly") are benchmarked at larger dims.
    dim = [8192, 16384] if "poly" in args.case else [2048, 4096]

    if args.plot:
        batch_size_range = [1]
        seq_length_range = [4096, 8192, 16384]
        # NOTE: the plot benchmarks take configs as (batch, seq, dim), while
        # the table benchmarks below take (dim, batch, seq). Keep both orders.
        configs = list(
            itertools.product(batch_size_range, seq_length_range, dim))

        plot_makers = (
            ("fwd", make_fwd_benchmark_plot_for_case),
            ("bwd", make_bwd_benchmark_plot_for_case),
        )
        for direction, make_bench in plot_makers:
            plot_name = f"plot_{args.case}-{direction}-perf"
            bench = make_bench(
                case=case,
                configs=configs,
                plot_name=plot_name,
                line_names={
                    "naive": "Naive",
                    "cuda": "Cuda",
                },
            )
            bench.run(print_data=True, save_path=save_dir)
            plot_result(os.path.join(save_dir, plot_name))

        _remove_artifacts(save_dir, "*.html", "*.csv")
    else:
        batch_size_range = [2**i for i in range(0, 4, 1)]
        seq_length_range = [2**i for i in range(10, 14, 1)]
        configs = list(
            itertools.product(dim, batch_size_range, seq_length_range))

        table_makers = (
            ("fwd", make_fwd_benchmark_for_case),
            ("bwd", make_bwd_benchmark_for_case),
        )
        for direction, make_bench in table_makers:
            bench = make_bench(
                case=case,
                configs=configs,
                plot_name=f"{args.case}-{direction}-perf",
                line_names={
                    "naive": "Naive",
                    "cuda": "Cuda",
                    "speedup": "SpeedUp"
                },
            )
            bench.run(print_data=True, save_path=save_dir)

        _remove_artifacts(save_dir, "*.html", "*.png")


if __name__ == "__main__":
    main()