|
import matplotlib.pyplot as plt |
|
|
|
|
|
# Target weight sparsity levels, in percent: 0, 10, ..., 70.
sparsity = list(range(0, 80, 10))

# Perplexity observed at each sparsity level; index-aligned with `sparsity`
# (perplexity[i] belongs to sparsity[i]).
perplexity = [
    5.8393,
    5.8781,
    6.0102,
    6.3076,
    7.0094,
    9.0642,
    20.2265,
    103.5209,
]
|
|
|
|
|
# Plot perplexity against target weight sparsity. The 0%-sparsity perplexity
# is drawn as a dashed horizontal reference line. The figure is written to
# "perplexity_vs_sparsity.png" and then displayed.
plt.figure(figsize=(6, 4))

plt.plot(sparsity, perplexity, marker='o', linestyle='-', color='b',
         label='Perplexity')

# Take the baseline from the data rather than repeating the literal, so the
# reference line stays in sync if the measurements change. Raises ValueError
# if no 0% entry is present in `sparsity`.
baseline_perplexity = perplexity[sparsity.index(0)]
plt.axhline(y=baseline_perplexity, color='g', linestyle='--',
            label='Perplexity at 0% Sparsity')

plt.title("Perplexity vs. Weight Target Sparsity", fontsize=14)
plt.xlabel("Weight Target Sparsity (%)", fontsize=12)
plt.ylabel("Perplexity (lower is better)", fontsize=12)

plt.legend(fontsize=10)
plt.grid(True)
plt.xticks(fontsize=10)
plt.yticks(fontsize=10)

# Keep the title and axis labels inside the saved image.
plt.tight_layout()

# Save BEFORE show(): with a blocking or non-interactive backend, show()
# releases the figure, so a savefig() afterwards would write an empty image.
plt.savefig("perplexity_vs_sparsity.png")
plt.show()