refoundd committed · Commit 4b67735 · verified · 1 Parent(s): 3e7e1e3

Upload 2 files

Files changed (2)
  1. handler.py +79 -0
  2. requirements.txt +17 -0
handler.py ADDED
@@ -0,0 +1,79 @@
+ import os
+ from typing import Any, Dict
+ from PIL import Image
+ import torch
+ from diffusers import FluxPipeline
+ from huggingface_inference_toolkit.logging import logger
+ from para_attn.first_block_cache.diffusers_adapters import apply_cache_on_pipe
+ import time
+ import torch.distributed as dist
+ from para_attn.context_parallel import init_context_parallel_mesh
+ from para_attn.context_parallel.diffusers_adapters import parallelize_pipe
+ from para_attn.parallel_vae.diffusers_adapters import parallelize_vae
+ dist.init_process_group()
+
+ torch.cuda.set_device(dist.get_rank())
+
+ class EndpointHandler:
+     def __init__(self, path=""):
+         self.pipe = FluxPipeline.from_pretrained(
+             "NoMoreCopyrightOrg/flux-dev",
+             torch_dtype=torch.bfloat16,
+         ).to("cuda")
+         mesh = init_context_parallel_mesh(
+             self.pipe.device.type,
+             max_ring_dim_size=2,
+         )
+         parallelize_pipe(
+             self.pipe,
+             mesh=mesh,
+         )
+         parallelize_vae(self.pipe.vae, mesh=mesh._flatten())
+         apply_cache_on_pipe(self.pipe, residual_diff_threshold=0.12)
+         torch._inductor.config.reorder_for_compute_comm_overlap = True
+         self.pipe.transformer = torch.compile(
+             self.pipe.transformer, mode="max-autotune-no-cudagraphs",
+         )
+         self.pipe.vae = torch.compile(
+             self.pipe.vae, mode="max-autotune-no-cudagraphs",
+         )
+
+     def __call__(self, data: Dict[str, Any]) -> str:
+         logger.info(f"Received incoming request with {data=}")
+
+         if "inputs" in data and isinstance(data["inputs"], str):
+             prompt = data.pop("inputs")
+         elif "prompt" in data and isinstance(data["prompt"], str):
+             prompt = data.pop("prompt")
+         else:
+             raise ValueError(
+                 "Provided input body must contain either the key `inputs` or `prompt` with the"
+                 " prompt to use for the image generation, and it needs to be a non-empty string."
+             )
+
+         parameters = data.pop("parameters", {})
+
+         num_inference_steps = parameters.get("num_inference_steps", 28)
+         width = parameters.get("width", 1024)
+         height = parameters.get("height", 1024)
+         guidance_scale = parameters.get("guidance_scale", 3.5)
+
+         # seed generator (seed cannot be provided as is but via a generator)
+         seed = parameters.get("seed", 0)
+         generator = torch.manual_seed(seed)
+         start_time = time.time()
+         result = self.pipe(  # type: ignore
+             prompt,
+             height=height,
+             width=width,
+             guidance_scale=guidance_scale,
+             num_inference_steps=num_inference_steps,
+             generator=generator,
+             output_type="pil" if dist.get_rank() == 0 else "pt",
+         ).images[0]
+         end_time = time.time()
+         if dist.get_rank() == 0:
+             time_taken = end_time - start_time
+             print(f"Time taken: {time_taken:.2f} seconds")
+             return result
+         return "123"
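
Because the handler calls `dist.init_process_group()` at import time, it has to run with one process per GPU. The sketch below is a minimal local smoke test under that assumption; the file name `test_handler.py`, the prompt text, and the parameter values are illustrative, while the payload keys (`inputs`, `parameters`, `num_inference_steps`, `width`, `height`, `guidance_scale`, `seed`) come from `__call__` above.

```python
# test_handler.py -- hypothetical local smoke test, launched with e.g.:
#   torchrun --nproc_per_node=2 test_handler.py
import torch.distributed as dist

from handler import EndpointHandler  # importing handler.py also initializes the process group


def main():
    handler = EndpointHandler()
    # Payload shape expected by __call__: `inputs` (or `prompt`) plus optional `parameters`.
    payload = {
        "inputs": "a watercolor fox in a snowy forest",
        "parameters": {
            "num_inference_steps": 28,
            "width": 1024,
            "height": 1024,
            "guidance_scale": 3.5,
            "seed": 42,
        },
    }
    result = handler(payload)
    if dist.get_rank() == 0:
        # Rank 0 receives a PIL.Image.Image; the other ranks return the "123" placeholder.
        result.save("sample.png")
    dist.destroy_process_group()


if __name__ == "__main__":
    main()
```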
requirements.txt ADDED
@@ -0,0 +1,17 @@
+ --extra-index-url https://download.pytorch.org/whl/cu126
+ torch==2.6.0+cu126
+ torchvision
+ torchaudio
+ huggingface_hub
+ torchao==0.9.0
+ diffusers==0.32.2
+ peft
+ transformers
+ numpy
+ scipy
+ Pillow
+ sentencepiece
+ protobuf
+ triton
+ schedule
+ para-attn
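
The extra index URL pins torch to the CUDA 12.6 wheel that `para-attn` and `torch.compile` are exercised against here. A small sanity-check sketch, assuming the environment was installed from this requirements.txt; the file name `env_check.py` and the specific assertions are illustrative, not part of the commit.

```python
# env_check.py -- hypothetical check that the pinned environment matches what handler.py expects.
import torch


def main():
    # torch==2.6.0+cu126 is pinned above; the handler needs CUDA, and the
    # context-parallel setup is only useful with two or more visible GPUs.
    assert torch.__version__.startswith("2.6.0"), torch.__version__
    assert torch.cuda.is_available(), "CUDA device required for FluxPipeline"
    print(f"torch {torch.__version__}, {torch.cuda.device_count()} GPU(s) visible")

    import diffusers
    import para_attn  # noqa: F401  -- import check for the para-attn package

    print(f"diffusers {diffusers.__version__}")


if __name__ == "__main__":
    main()
```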