cbensimon HF Staff committed on
Commit
9544e60
·
1 Parent(s): 4d0c9f3

Begin actual demo

Browse files
Files changed (1) hide show
  1. app.py +54 -49
app.py CHANGED
@@ -15,75 +15,80 @@ import gradio as gr
15
  import spaces
16
  import torch
17
  import torch._inductor
 
18
  from torch._inductor.package import package_aoti
19
  from torch.export.pt2_archive._package import AOTICompiledModel
20
  from torch.export.pt2_archive._package_weights import Weights
21
- from torchvision.models import ResNet18_Weights, resnet18
22
 
23
 
24
# Pretrained ResNet-18, switched to eval mode and moved to the GPU. This is the
# module that compile_model() exports.
model = resnet18(weights=ResNet18_Weights.DEFAULT).eval().to('cuda')

# Destination of the compiled package, resolved against the current working directory.
package_path = os.path.join(os.getcwd(), 'resnet18.pt2')
inductor_configs = dict(max_autotune=True)
# Single positional input: a random batch of shape (2, 3, 224, 224) on the GPU.
example_inputs = (torch.randn((2, 3, 224, 224), device='cuda'),)
31
 
32
@spaces.GPU
def compile_model():
    """Export and AOT-compile the module-level ``model``, then return its weights.

    The string artifacts produced by the compiler are packaged into
    ``package_path``. The single ``Weights`` artifact is converted into a plain
    dict mapping each weight name to a CPU copy of the tensor.

    Returns:
        dict[str, torch.Tensor]: one CPU tensor per weight name. Each tensor is
        allocated through ``pin_memory()`` and then passed through
        ``share_memory_()``.
    """
    aoti_options = {
        'aot_inductor.package_constants_in_so': False,
        'aot_inductor.package_constants_on_disk': True,
        'aot_inductor.package': True,
        'max_autotune': True,
    }

    # Only the export and the compilation run under inference mode.
    with torch.inference_mode():
        program = torch.export.export(model, example_inputs)
        artifacts = torch._inductor.aot_compile(
            program.module(),
            *program.example_inputs,
            options=aoti_options,
        )

    # Artifacts are a mix of file paths (str) and one Weights object.
    package_aoti(package_path, [item for item in artifacts if isinstance(item, str)])

    # Unpacking into a one-element target fails loudly unless exactly one
    # Weights artifact is present.
    (weights,) = [item for item in artifacts if isinstance(item, Weights)]

    def _host_copy(name):
        # Copy one weight into a freshly allocated CPU buffer.
        # NOTE(review): the shared-memory step is presumably what lets the
        # tensors leave the @spaces.GPU worker - confirm.
        source, _ = weights.get_weight(name)
        buffer = torch.empty_like(source, device='cpu').pin_memory()
        return buffer.copy_(source).detach().share_memory_()

    return {name: _host_copy(name) for name in weights}
54
 
55
# Compile once at import time and move every returned weight onto the GPU.
weights = {name: tensor.to('cuda') for name, tensor in compile_model().items()}

# Remove the module-level name of the eager model.
del model

# Set by run_model() the first time it is called.
compiled_model: AOTICompiledModel | None = None
61
 
62
@spaces.GPU
def run_model():
    """Run the compiled model on ``example_inputs`` and return the output as text.

    On the first call (while the module-level ``compiled_model`` is ``None``)
    the package at ``package_path`` is loaded, the module-level ``weights`` are
    bound to it, and one extra forward pass is executed before the pass whose
    result is returned.

    Returns:
        str: ``str()`` of the compiled model's output.
    """
    # TODO: loading the compiled model should really happen in worker init,
    # taking (path, weights) as arguments.
    # Something like: @spaces.GPU(aoti_load=(package_path, weights))
    # That would probably solve the driver runtime error seen on idle reuse
    # and avoid handling state manually through a global.
    # Or alternatively:
    #     pipeline.transformer = ZeroGPUCompiledModel(pt2_path, weights)
    # with ZeroGPUCompiledModel instances then loaded automatically during
    # worker init. That is even better (I think it was my original idea).
    # An even higher-level interface would be:
    #     pipeline.transformer = ZeroGPUCompile(pipeline.transformer, kwargs=example_kwargs)
    # with compilation under @spaces.GPU, packaging, separate weights, etc.
    # all handled automatically.
    # Still, several levels of abstraction should probably be kept, perhaps
    # starting with the low-level one (or even no helper at all and everything
    # done manually, but for now that gives a driver context runtime error).
    # I should still be able to find an ideal level of abstraction.
    global compiled_model

    first_use = compiled_model is None
    if first_use:
        compiled_model = torch._inductor.aoti_load_package(package_path)
        compiled_model.load_constants(weights, check_full_update=True, user_managed=True)

    with torch.inference_mode():
        if first_use:
            # Extra pass right after loading; its result is discarded
            # (presumably a warm-up).
            compiled_model(example_inputs)
        return str(compiled_model(example_inputs))
87
-
88
-
89
# Gradio UI with no input components and a text output; errors are shown in the UI.
gr.Interface(
    fn=run_model,
    inputs=[],
    outputs='text',
).launch(show_error=True)
 
15
  import spaces
16
  import torch
17
  import torch._inductor
18
+ from diffusers import FluxPipeline
19
  from torch._inductor.package import package_aoti
20
  from torch.export.pt2_archive._package import AOTICompiledModel
21
  from torch.export.pt2_archive._package_weights import Weights
 
22
 
23
 
24
# FLUX.1-schnell pipeline, loaded in bfloat16 and moved to the GPU at import time.
pipeline = FluxPipeline.from_pretrained(
    'black-forest-labs/FLUX.1-schnell',
    torch_dtype=torch.bfloat16,
)
pipeline = pipeline.to('cuda')

# Path of the AOTInductor package written by compile_transformer() and
# reloaded by generate_image().
package_path = 'pipeline.pt2'
 
26
 
 
 
 
27
 
28
@spaces.GPU
def compile_transformer():
    """Export and AOT-compile ``pipeline.transformer``, then return its weights.

    Example keyword arguments are built on the GPU in bfloat16, the transformer
    is exported with them, and the export is compiled with the inductor options
    below. The string artifacts are packaged into ``package_path``; the single
    ``Weights`` artifact is converted into a plain dict of CPU tensors.

    Returns:
        dict[str, torch.Tensor]: one CPU tensor per weight name. Each tensor is
        allocated through ``pin_memory()`` and then passed through
        ``share_memory_()``.
    """
    transformer = pipeline.transformer
    gpu_bf16 = {'device': 'cuda', 'dtype': torch.bfloat16}

    # A transformer without guidance embeddings is treated as timestep-distilled:
    # it gets the shorter text sequence and no guidance input.
    is_timestep_distilled = not transformer.config.guidance_embeds
    if is_timestep_distilled:
        seq_length = 256
        guidance = None
    else:
        seq_length = 512
        guidance = torch.tensor([1.], **gpu_bf16)

    # Key order is kept as-is: the dict is handed to torch.export as kwargs.
    transformer_kwargs = {
        'hidden_states': torch.randn(1, 4096, 64, **gpu_bf16),
        'timestep': torch.tensor([1.], **gpu_bf16),
        'guidance': guidance,
        'pooled_projections': torch.randn(1, 768, **gpu_bf16),
        'encoder_hidden_states': torch.randn(1, seq_length, 4096, **gpu_bf16),
        'txt_ids': torch.randn(seq_length, 3, **gpu_bf16),
        'img_ids': torch.randn(4096, 3, **gpu_bf16),
        'joint_attention_kwargs': {},
        'return_dict': False,
    }

    inductor_configs = {
        'conv_1x1_as_mm': True,
        'epilogue_fusion': False,
        'coordinate_descent_tuning': True,
        'coordinate_descent_check_all_directions': True,
        'max_autotune': True,
        'triton.cudagraphs': True,
    }
    packaging_options = {
        'aot_inductor.package_constants_in_so': False,
        'aot_inductor.package_constants_on_disk': True,
        'aot_inductor.package': True,
    }

    exported = torch.export.export(transformer, args=(), kwargs=transformer_kwargs)
    artifacts = torch._inductor.aot_compile(
        exported.module(),
        *exported.example_inputs,
        # Packaging options are merged last, so they win on any key clash.
        options={**inductor_configs, **packaging_options},
    )

    # Artifacts are a mix of file paths (str) and one Weights object.
    package_aoti(package_path, [item for item in artifacts if isinstance(item, str)])

    # Unpacking into a one-element target fails loudly unless exactly one
    # Weights artifact is present.
    (weights,) = [item for item in artifacts if isinstance(item, Weights)]

    def _host_copy(name):
        # Copy one weight into a freshly allocated CPU buffer.
        # NOTE(review): the shared-memory step is presumably what lets the
        # tensors leave the @spaces.GPU worker - confirm.
        source, _ = weights.get_weight(name)
        buffer = torch.empty_like(source, device='cpu').pin_memory()
        return buffer.copy_(source).detach().share_memory_()

    return {name: _host_copy(name) for name in weights}
78
 
79
+
80
# Compile once at import time and move every returned weight onto the GPU.
weights = {name: tensor.to('cuda') for name, tensor in compile_transformer().items()}

# Drop the pipeline's reference to the eager transformer; generate_image()
# assigns the compiled one before running the pipeline.
pipeline.transformer = None
84
 
 
85
 
86
@spaces.GPU
def generate_image(prompt: str):
    """Generate one image for ``prompt`` with the AOT-compiled transformer.

    Every call reloads the package at ``package_path``, binds the module-level
    ``weights`` to it, and installs the result as ``pipeline.transformer``
    before running the pipeline for 4 inference steps.

    Args:
        prompt: text prompt forwarded to the pipeline.

    Returns:
        The first entry of the pipeline output's ``images``.
    """
    transformer: AOTICompiledModel = torch._inductor.aoti_load_package(package_path)
    transformer.load_constants(weights, check_full_update=True, user_managed=True)

    # NOTE(review): the pipeline now holds whatever aoti_load_package returned
    # instead of the original module - confirm it exposes every attribute the
    # pipeline reads from its transformer (e.g. `.config`).
    pipeline.transformer = transformer

    result = pipeline(prompt, num_inference_steps=4)
    return result.images[0]
92
+
93
+
94
# Gradio UI with a text input and an image output; errors are shown in the UI.
gr.Interface(
    fn=generate_image,
    inputs='text',
    outputs='image',
).launch(show_error=True)