import io
from enum import Enum

import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from matplotlib.ticker import ScalarFormatter
from PIL import Image

class AttentionType(Enum):
    """Whether a layer attends locally (sliding window) or globally."""
    LOCAL = 0
    GLOBAL = 1

def gqa_kv_per_layer_per_token(n_kv_heads, d_head, kv_parameter_size):
    # GQA caches one K and one V vector of width d_head per KV head, per layer, per token.
    # kv_parameter_size is the size of one stored element (same unit as bandwidth and weights).
    return 2 * kv_parameter_size * n_kv_heads * d_head

def mla_kv_per_layer_per_token(d_compressed, kv_parameter_size):
    # MLA caches a single compressed latent of width d_compressed covering both K and V.
    return kv_parameter_size * d_compressed

def tokens_per_second(batch_size, bandwidth, total_kv_size, param_size):
    # Memory-bound roofline: each decode step streams the weights once plus every sequence's KV cache.
    return (batch_size * bandwidth) / (batch_size * total_kv_size + param_size)
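
# Rough sanity check of the formula above (illustrative numbers, not tied to any
# particular device): with 1 TB/s of bandwidth, 16 GB of weights, and a KV cache
# small enough to ignore, a single sequence decodes at about 1e12 / 16e9 ≈ 62 tokens/s.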

def compute_tps(kv_per_layer_per_token, seq_len, batch_size, total_param_size,
                num_layers, swa_pattern, swa_size, bandwidth):
    """Estimate decode throughput (tokens/s) at each context length in seq_len."""
    tps_values = []
    for ctx_len in seq_len:
        total_kv_size = 0
        for layer in range(num_layers):
            # Sliding-window (local) layers cache at most swa_size tokens of KV;
            # global layers cache KV for the entire context.
            if swa_pattern[layer % len(swa_pattern)] == AttentionType.LOCAL:
                total_kv_size += kv_per_layer_per_token * min(ctx_len, swa_size)
            else:
                total_kv_size += kv_per_layer_per_token * ctx_len
        tps = tokens_per_second(batch_size, bandwidth, total_kv_size, total_param_size)
        tps_values.append(tps)
    return tps_values

def create_throughput_plot(
    model_name,
    memory_bandwidth,
    num_parameters,
    parameter_size,
    kv_parameter_size,
    num_layers,
    num_heads,
    d_model,
    ctx_length,
    local_layers,
    global_layers,
    swa_size,
    gqa_heads,
    mla_d_compressed,
):
    # Convert GB/s -> bytes/s and billions of parameters -> parameters.
    memory_bandwidth = float(memory_bandwidth) * 1_000_000_000
    num_parameters = float(num_parameters) * 1_000_000_000
    d_head = d_model // num_heads
    # parameter_size is given in bits per weight, so divide by 8 to get total weight bytes.
    total_param_size = num_parameters * (parameter_size / 8.0)
    # Repeating layer pattern, e.g. local_layers=5, global_layers=1 -> a 5:1 SWA pattern.
    swa_pattern = ([AttentionType.LOCAL] * local_layers +
                   [AttentionType.GLOBAL] * global_layers)
    if len(swa_pattern) == 0:
        swa_pattern = [AttentionType.GLOBAL]
    sns.set_theme(style="whitegrid", context="paper")
    palette = sns.color_palette("viridis", len(gqa_heads) + len(mla_d_compressed))
    plt.figure(figsize=(14, 8), dpi=300)
    # Sweep context lengths from 100 to 100,000 tokens on a log grid.
    seq_len = np.logspace(2, 5, 100).astype(int)
    batch_size = 1
    tps_values = []
    gqa_count = len(gqa_heads)
    # One curve per GQA configuration (number of KV heads).
    for i, n_kv_head in enumerate(gqa_heads):
        n_kv_head = int(n_kv_head)
        kv_per_token = gqa_kv_per_layer_per_token(n_kv_head, d_head, kv_parameter_size)
        gqa_tps_values = compute_tps(kv_per_token, seq_len, batch_size, total_param_size,
                                     num_layers, swa_pattern, swa_size, memory_bandwidth)
        tps_values.extend(gqa_tps_values)
        plt.plot(seq_len, gqa_tps_values, label=f"GQA: {n_kv_head} heads", color=palette[i],
                 linewidth=3.5, alpha=0.85)
    plt.axvline(x=ctx_length, color='red', linestyle='--', alpha=0.8, linewidth=2.5,
                label=f"Max Context Length ({ctx_length:,})")
    local_count = swa_pattern.count(AttentionType.LOCAL)
    global_count = swa_pattern.count(AttentionType.GLOBAL)
    if local_count > 0:
        plt.axvline(x=swa_size, color='blue', linestyle='--', alpha=0.8, linewidth=2.5,
                    label=f"Sliding Window Limit ({swa_size:,})")
    # One curve per MLA configuration (compressed KV latent dimension).
    for i, d_comp in enumerate(mla_d_compressed):
        d_comp = int(d_comp)
        kv_per_token = mla_kv_per_layer_per_token(d_comp, kv_parameter_size)
        mla_tps_values = compute_tps(kv_per_token, seq_len, batch_size, total_param_size,
                                     num_layers, swa_pattern, swa_size, memory_bandwidth)
        tps_values.extend(mla_tps_values)
        plt.plot(seq_len, mla_tps_values, label=f"MLA: dc = {d_comp}",
                 color=palette[i + gqa_count], linewidth=3.5, alpha=0.85)
    plt.xscale('log')
    # Fit the y-axis to the computed curves; fall back to a fixed range if the
    # values are empty or non-finite.
    if len(tps_values) > 0 and all(np.isfinite(tps_values)):
        min_tps = min(tps_values)
        max_tps = max(tps_values)
        y_min = max(0, min_tps * 0.9)
        y_max = max_tps * 1.1
        plt.ylim(y_min, y_max)
    else:
        plt.ylim(15, 40)
    ax = plt.gca()
    ax.xaxis.set_major_formatter(ScalarFormatter())
    ax.yaxis.set_major_formatter(ScalarFormatter())
    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)
    ax.spines['left'].set_linewidth(1.5)
    ax.spines['bottom'].set_linewidth(1.5)
    attn_label = "Global" if local_count == 0 else f"SWA {local_count}:{global_count}"
    # model_name follows a "device:model" convention; fall back to the full string.
    device_name = model_name.split(':')[0] if ':' in model_name else model_name
    # Annotate with the hardware and quantization settings used for the estimate.
    plt.annotate(
        f"{device_name}\n"
        f"Bandwidth: {memory_bandwidth / 1e9:.1f} GB/s\n"
        f"Parameter Size: {parameter_size:.1f} bits\n"
        f"Attention Kind: {attn_label}",
        xy=(0.8, 0.97),
        xycoords='axes fraction',
        bbox=dict(boxstyle="round,pad=0.4", facecolor="white", alpha=0.9, edgecolor='darkgray'),
        va='top',
        fontsize=11,
    )
    plt.xlabel('Context Length (tokens)', fontsize=14, fontweight='bold')
    plt.ylabel('Tokens per Second', fontsize=14, fontweight='bold')
    plt.tick_params(axis='both', which='major', labelsize=12)
    model_title = model_name.split(':')[1] if ':' in model_name else model_name
    plt.title(f"{model_title}: Tokens Per Second vs. Sequence Length", fontsize=18,
              fontweight='bold', pad=20)
    plt.legend(title="Configuration", frameon=True, framealpha=0.95, fontsize=12, title_fontsize=14)
    plt.grid(True, alpha=0.5)
    # Render to an in-memory PNG and return it as a PIL image.
    buf = io.BytesIO()
    plt.savefig(buf, format='png')
    plt.close()
    buf.seek(0)
    img = Image.open(buf)
    return img
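

# Example usage. The values below are illustrative placeholders (roughly a modern
# consumer GPU and a mid-sized quantized model), not measurements or published
# specs for any particular system; substitute your own hardware and model numbers.
if __name__ == "__main__":
    img = create_throughput_plot(
        model_name="Example GPU:Example-27B",
        memory_bandwidth=1000,       # GB/s
        num_parameters=27,           # billions of parameters
        parameter_size=4,            # bits per weight (e.g. 4-bit quantization)
        kv_parameter_size=2,         # bytes per stored KV element (fp16 cache)
        num_layers=46,
        num_heads=32,
        d_model=4096,
        ctx_length=131_072,
        local_layers=5,              # 5 local (sliding-window) layers per global layer
        global_layers=1,
        swa_size=1024,
        gqa_heads=[8, 4, 2],
        mla_d_compressed=[512],
    )
    img.save("throughput_plot.png")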