Spaces:
Runtime error
Runtime error
| import gradio as gr | |
| import matplotlib.pyplot as plt | |
| import numpy as np | |
### CHINCHILLA PARAMS:
# Coefficients of the Chinchilla parametric loss fit
#     L(N, D) = E + A / N**alpha + B / D**beta
# where N is the parameter count and D the number of training tokens.
E = 1.62
A = 406.4
B = 410.7
alpha = 0.336
beta = 0.283
# One billion; the UI takes model and dataset sizes in billions.
Bn = 1_000_000_000
# Prefactor of the compute-optimal split used by n_opt / d_opt:
#     N_opt = G * (C/6)**(beta/(alpha+beta))
#     D_opt = (C/6)**(alpha/(alpha+beta)) / G
G = (alpha * A / (beta * B)) ** (1 / (alpha + beta))
###
def to_flops(N, D):
    """Return the approximate training cost, in FLOPs, of N parameters on D tokens.

    Uses the standard 6 FLOPs per parameter per token estimate.
    """
    flops_per_param_token = 6
    return flops_per_param_token * N * D
def n_opt(C):
    """Return the compute-optimal parameter count for a budget of C FLOPs.

    Chinchilla allocation: N_opt = G * (C/6)**(beta/(alpha+beta)).
    """
    exponent = beta / (alpha + beta)
    return G * (C / 6) ** exponent
def d_opt(C):
    """Return the compute-optimal number of training tokens for C FLOPs.

    Chinchilla allocation: D_opt = (1/G) * (C/6)**(alpha/(alpha+beta)).
    """
    inverse_g = 1 / G
    exponent = alpha / (alpha + beta)
    return inverse_g * (C / 6) ** exponent
def get_kd(kn):
    """Return the data multiplier kd that matches the compute-optimal loss.

    For a model of ``kn * N_opt`` parameters, finds ``kd`` such that training
    on ``kd * D_opt`` tokens reaches the same Chinchilla loss as the
    compute-optimal pair ``(N_opt, D_opt)``. Equating the two losses gives

        kd = (1 - (kn**-alpha - 1) * frac) ** (-1/beta),

    where ``frac = (A/B) * G**(-alpha-beta)`` is ``(A/B) * D_opt**beta / N_opt**alpha``.

    Args:
        kn: model size as a fraction of the compute-optimal size. May be a
            scalar or any array-like of such fractions.

    Returns:
        ``kd`` with the same shape as ``kn`` (a plain float for scalar input).
        Entries are NaN where the bracketed base is <= 0. That happens when
        the model is so small that no finite dataset reaches the target loss
        (and also for kn <= 0).
    """
    frac = (A / B) * (G ** (-alpha - beta))
    # kn == 0 or kn < 0 would otherwise emit divide/invalid warnings; those
    # inputs simply end up as NaN below.
    with np.errstate(divide="ignore", invalid="ignore"):
        base = 1 - ((np.asarray(kn, dtype=float) ** -alpha - 1) * frac)
        # A non-positive base means the target loss is unreachable. Map it to
        # NaN explicitly: raising a negative Python float to a fractional
        # power would otherwise yield a complex number.
        kd = np.where(base > 0, base, np.nan) ** (1 / (-beta))
    return float(kd) if np.ndim(kd) == 0 else kd
def compute_overhead(kn, kd):
    """Return the fractional extra training compute of a (kn, kd) allocation.

    Training FLOPs scale with N * D, so training ``kn * N_opt`` parameters
    on ``kd * D_opt`` tokens costs ``kn * kd`` times the optimal budget.
    The result is that ratio minus one (0.0 means no overhead).
    """
    ratio = kn * kd
    return ratio - 1
### PRECOMPUTE CURVE:
# x-range of the overhead plot: model size as a fraction of the
# compute-optimal size.
kn_min = 0.2
kn_max = 2
kns = np.linspace(kn_min, kn_max, 100)
# Training-compute overhead (%) at every kn. It is evaluated on the same
# grid as ``kns`` so that the plotted x and y values line up. Previously
# x ran over 0.05..2 while y was computed over 0.2..2.
overheads = [compute_overhead(kn, get_kd(kn)) * 100 for kn in kns]
def plot_curve(kn, kd):
    """Draw the precomputed overhead curve and highlight one configuration.

    Args:
        kn: x-coordinate of the highlighted point (model size as a fraction
            of the compute-optimal size).
        kd: y-coordinate of the highlighted point. Despite its name, the
            caller in this file passes the training-compute overhead in
            percent, which is what the y-axis shows.

    Returns:
        The matplotlib Figure that was drawn into. It is also left as
        pyplot's current figure.
    """
    # Reuse one named figure and wipe it each call. Repeated calls then
    # neither stack extra curves/points onto the previous plot nor leak a
    # new pyplot-managed figure per request.
    fig = plt.figure(num="compute_overhead", clear=True)
    ax = fig.add_subplot(111)
    ax.plot(kns, overheads)
    ax.scatter([kn], [kd])
    ax.set_xlabel("Fraction of compute optimal model size")
    ax.set_ylabel("Compute overhead (%)")
    return fig
def compute(N, D):
    """Summarise the compute trade-off for N billion parameters on D billion tokens.

    Also redraws the overhead curve (via ``plot_curve``) with this
    configuration highlighted.

    Args:
        N: model size, in billions of parameters.
        D: dataset size, in billions of tokens.

    Returns:
        A newline-separated, human-readable summary. If the inputs are
        invalid or the configuration cannot reach the compute-optimal loss,
        an explanatory message is returned instead.
    """
    if N is None or D is None or N <= 0 or D <= 0:
        return "Model size and dataset size must both be positive."
    C = to_flops(N * Bn, D * Bn)
    # n_opt works in raw parameter counts, so N (given in billions) must be
    # scaled by Bn before taking the ratio.
    kn = N * Bn / n_opt(C)
    overhead_pct = 100 * compute_overhead(kn, get_kd(kn))
    # For very small models no finite dataset reaches the optimal loss.
    # get_kd then yields NaN, or a complex value for older implementations.
    if np.iscomplexobj(overhead_pct) or not np.isfinite(overhead_pct):
        return (
            f"Compute budget (FLOPs): {C:.2E}\n"
            "Model is too small: no dataset size reaches the compute-optimal "
            "loss for this budget."
        )
    plot_curve(kn, overhead_pct)
    # C is a raw FLOP count (to_flops multiplies raw params by raw tokens).
    return (
        f"Compute budget (FLOPs): {C:.2E}\n"
        f"Training compute overhead (%): {overhead_pct:.2f}\n"
        f"Inference cost fraction (%): {kn * 100:.2f}"
    )
def _run(n_billion, d_billion):
    """Gradio callback: return the Markdown summary and the overhead plot."""
    text = compute(n_billion, d_billion)
    # compute() draws via pyplot, so the current figure holds the plot.
    # "  \n" is a Markdown hard line break; a bare "\n" would be collapsed.
    return text.replace("\n", "  \n"), plt.gcf()


with gr.Blocks() as demo:
    N = gr.Number(value=1, label="Model size (in B parameters)")
    D = gr.Number(value=100, label="Dataset size (in B tokens)")
    button = gr.Button("Compute!")
    plot = gr.Plot(label="Compute overhead")
    md = gr.Markdown()
    button.click(fn=_run, inputs=[N, D], outputs=[md, plot])
    # Populate the summary and plot for the default inputs on page load.
    demo.load(fn=_run, inputs=[N, D], outputs=[md, plot])

if __name__ == "__main__":
    demo.launch()