Begin actual demo
app.py CHANGED
@@ -15,75 +15,80 @@ import gradio as gr
 import spaces
 import torch
 import torch._inductor
+from diffusers import FluxPipeline
 from torch._inductor.package import package_aoti
 from torch.export.pt2_archive._package import AOTICompiledModel
 from torch.export.pt2_archive._package_weights import Weights
-from torchvision.models import ResNet18_Weights, resnet18
 
 
-…
-…
-model.to('cuda')
+pipeline = FluxPipeline.from_pretrained('black-forest-labs/FLUX.1-schnell', torch_dtype=torch.bfloat16).to('cuda')
+package_path = 'pipeline.pt2'
 
-package_path = os.path.join(os.getcwd(), 'resnet18.pt2')
-inductor_configs = {'max_autotune': True}
-example_inputs = (torch.randn(2, 3, 224, 224, device='cuda'),)
 
 @spaces.GPU
-def …
-    …
+def compile_transformer():
+
+    def _example_tensor(*shape):
+        return torch.randn(*shape, device='cuda', dtype=torch.bfloat16)
+
+    is_timestep_distilled = not pipeline.transformer.config.guidance_embeds
+    seq_length = 256 if is_timestep_distilled else 512
+
+    transformer_kwargs = {
+        'hidden_states': _example_tensor(1, 4096, 64),
+        'timestep': torch.tensor([1.], device='cuda', dtype=torch.bfloat16),
+        'guidance': None if is_timestep_distilled else torch.tensor([1.], device='cuda', dtype=torch.bfloat16),
+        'pooled_projections': _example_tensor(1, 768),
+        'encoder_hidden_states': _example_tensor(1, seq_length, 4096),
+        'txt_ids': _example_tensor(seq_length, 3),
+        'img_ids': _example_tensor(4096, 3),
+        'joint_attention_kwargs': {},
+        'return_dict': False,
+    }
+
+    inductor_configs = {
+        'conv_1x1_as_mm': True,
+        'epilogue_fusion': False,
+        'coordinate_descent_tuning': True,
+        'coordinate_descent_check_all_directions': True,
+        'max_autotune': True,
+        'triton.cudagraphs': True,
+    }
+
+    exported = torch.export.export(pipeline.transformer, args=(), kwargs=transformer_kwargs)
+
+    artifacts = torch._inductor.aot_compile(exported.module(), *exported.example_inputs, options=inductor_configs | {
+        'aot_inductor.package_constants_in_so': False,
+        'aot_inductor.package_constants_on_disk': True,
+        'aot_inductor.package': True,
+    })
+
     files = [file for file in artifacts if isinstance(file, str)]
     package_aoti(package_path, files)
+
     weights, = (artifact for artifact in artifacts if isinstance(artifact, Weights))
     weights_: dict[str, torch.Tensor] = {}
+
     for name in weights:
         tensor, _properties = weights.get_weight(name)
         tensor_ = torch.empty_like(tensor, device='cpu').pin_memory()
         weights_[name] = tensor_.copy_(tensor).detach().share_memory_()
+
     return weights_
 
-…
+
+weights = compile_transformer()
 weights = {name: tensor.to('cuda') for name, tensor in weights.items()}
 
-…
+pipeline.transformer = None
 
-compiled_model: AOTICompiledModel | None = None
 
 @spaces.GPU
-def run_model():
-    …
-    # This is even better (I think it was the idea I had originally)
-    # An even more high-level interface would be:
-    # pipeline.transformer = ZeroGPUCompile(pipeline.transformer, kwargs=example_kwargs)
-    # And the compilation with @spaces.GPU, the packaging, the separate weights, etc.
-    # All of that would be handled automatically
-    # OK, but several levels of abstraction should be kept, I think
-    # And maybe start with the low-level one (or even no helper at all, everything manual, but for now I am getting a driver context runtime error)
-    # I should still be able to find an ideal level of abstraction
-    global compiled_model
-    if compiled_model is None:
-        compiled_model = torch._inductor.aoti_load_package(package_path)
-        compiled_model.load_constants(weights, check_full_update=True, user_managed=True)
-        with torch.inference_mode():
-            compiled_model(example_inputs)
-    with torch.inference_mode():
-        return str(compiled_model(example_inputs))
+def generate_image(prompt: str):
+    compiled_transformer: AOTICompiledModel = torch._inductor.aoti_load_package(package_path)
+    compiled_transformer.load_constants(weights, check_full_update=True, user_managed=True)
+    pipeline.transformer = compiled_transformer
+    return pipeline(prompt, num_inference_steps=4).images[0]
 
 
-gr.Interface(run_model, [], 'text').launch(show_error=True)
+gr.Interface(generate_image, 'text', 'image').launch(show_error=True)
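One practical note on the compile step: with 'max_autotune' and coordinate-descent tuning enabled, the one-off compile_transformer() call can take far longer than an ordinary inference call. If it overruns the default ZeroGPU window, spaces.GPU also accepts a duration argument in seconds. A minimal sketch (the 600 s value is illustrative, not from this commit):

import spaces

@spaces.GPU(duration=600)  # give the autotuned AOT compile more GPU time
def compile_transformer():
    ...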
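Why the detour through CPU inside compile_transformer(): on ZeroGPU, each @spaces.GPU call may execute in a different CUDA process, so CUDA tensors created during compilation cannot simply be kept around. Copying each constant into a pinned, shared CPU buffer makes it visible to later worker processes, and pinning keeps the eventual host-to-device upload fast; load_constants(..., user_managed=True) then has the compiled model reference these caller-managed tensors instead of cloning them. The same pattern in isolation, as a sketch with a hypothetical helper name:

import torch

def to_shared_pinned(tensor: torch.Tensor) -> torch.Tensor:
    # Page-locked CPU buffer: enables fast host-to-device copies later.
    buffer = torch.empty_like(tensor, device='cpu').pin_memory()
    # Copy off the GPU, detach from autograd, and move the storage to shared
    # memory so forked processes can see it without serialization.
    return buffer.copy_(tensor).detach().share_memory_()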
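As committed, generate_image() calls aoti_load_package() and load_constants() on every request. The removed ResNet demo cached the loaded model in a global instead; the same pattern carries over if repeated loads become noticeable (a sketch, not part of this commit; the cache only helps when ZeroGPU reuses a worker process across calls):

compiled_transformer: AOTICompiledModel | None = None

@spaces.GPU
def generate_image(prompt: str):
    global compiled_transformer
    if compiled_transformer is None:  # first call in this GPU worker
        compiled_transformer = torch._inductor.aoti_load_package(package_path)
        compiled_transformer.load_constants(weights, check_full_update=True, user_managed=True)
    pipeline.transformer = compiled_transformer
    return pipeline(prompt, num_inference_steps=4).images[0]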
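The comments removed by this commit imagine a higher-level interface, pipeline.transformer = ZeroGPUCompile(pipeline.transformer, kwargs=example_kwargs), with the @spaces.GPU compilation, the packaging, and the separate weights all handled automatically. ZeroGPUCompile is not an existing spaces API; the following is only a sketch of what such a helper could look like, assembled from the code in the diff above:

import spaces
import torch
import torch._inductor
from torch._inductor.package import package_aoti
from torch.export.pt2_archive._package import AOTICompiledModel
from torch.export.pt2_archive._package_weights import Weights


class ZeroGPUCompile:
    # Hypothetical helper, not an existing spaces API: wraps a module so that
    # AOT compilation, packaging, and cross-process weight sharing happen
    # behind a single object.

    def __init__(self, module: torch.nn.Module, kwargs: dict, path: str = 'module.pt2'):
        self.kwargs = kwargs  # example inputs for torch.export
        self.path = path
        # Compile once on a GPU worker at startup; keep only the shared weights.
        self.weights = spaces.GPU(self._compile)(module)
        self.compiled: AOTICompiledModel | None = None

    def _compile(self, module: torch.nn.Module) -> dict[str, torch.Tensor]:
        exported = torch.export.export(module, args=(), kwargs=self.kwargs)
        artifacts = torch._inductor.aot_compile(exported.module(), *exported.example_inputs, options={
            'aot_inductor.package_constants_in_so': False,
            'aot_inductor.package_constants_on_disk': True,
            'aot_inductor.package': True,
        })
        package_aoti(self.path, [a for a in artifacts if isinstance(a, str)])
        weights, = (a for a in artifacts if isinstance(a, Weights))
        shared: dict[str, torch.Tensor] = {}
        for name in weights:
            tensor, _ = weights.get_weight(name)
            pinned = torch.empty_like(tensor, device='cpu').pin_memory()
            shared[name] = pinned.copy_(tensor).detach().share_memory_()
        return shared

    def __call__(self, *args, **kwargs):
        # Expected to run inside a @spaces.GPU function, where CUDA is available.
        if self.compiled is None:
            self.compiled = torch._inductor.aoti_load_package(self.path)
            self.compiled.load_constants(
                {name: tensor.to('cuda') for name, tensor in self.weights.items()},
                check_full_update=True, user_managed=True)
        return self.compiled(*args, **kwargs)

With a helper along these lines, the module-level setup would shrink to the single assignment quoted from the removed comments, and generate_image() would call pipeline.transformer as usual.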