FLUX.1-dev-base / aoti_base_example.py
cbensimon's picture
cbensimon HF Staff
Test AOTI base example during app startup
8c2e0d0
raw
history blame
984 Bytes
"""
Modified from https://docs.pytorch.org/tutorials/recipes/torch_export_aoti_python.html
"""
import os
import torch
import torch._inductor
from torchvision.models import ResNet18_Weights, resnet18
# Pretrained ResNet-18, switched to eval mode before it is exported.
model = resnet18(weights=ResNet18_Weights.DEFAULT)
model.eval()

# The compiled package is written to, and later loaded back from, the
# current working directory.
package_path = os.path.join(os.getcwd(), "resnet18.pt2")
inductor_configs = {'max_autotune': True}
device = "cuda"

# Compile
# Moving the model, exporting it and compiling the package all happen under
# inference_mode so no autograd state is recorded along the way.
with torch.inference_mode():
    model = model.to(device=device)
    # Example inputs are created on the same device as the model.
    example_inputs = (torch.randn(2, 3, 224, 224, device=device),)
    exported_program = torch.export.export(
        model,
        example_inputs,
    )
    torch._inductor.aoti_compile_and_package(
        exported_program,
        package_path=package_path,
        inductor_configs=inductor_configs,
    )

# Load
compiled_model = torch._inductor.aoti_load_package(package_path)
# Fresh random inputs with the same shape as the ones used for export.
example_inputs = (torch.randn(2, 3, 224, 224, device=device),)

# Run
with torch.inference_mode():
    # NOTE(review): the inputs tuple is passed as a single positional
    # argument, mirroring the upstream recipe this file is modified from;
    # confirm the loaded package accepts this form rather than
    # `*example_inputs` on the torch version in use.
    output = compiled_model(example_inputs)