Model card auto-generated by SimpleTuner
README.md CHANGED
@@ -55,8 +55,8 @@ You may reuse the base model text encoder for inference.
 
 ## Training settings
 
-- Training epochs:
-- Training steps:
+- Training epochs: 4
+- Training steps: 10
 - Learning rate: 0.0001
 - Learning rate schedule: constant
 - Warmup steps: 500
@@ -66,15 +66,15 @@ You may reuse the base model text encoder for inference.
 - Gradient accumulation steps: 1
 - Number of GPUs: 3
 - Gradient checkpointing: False
-- Prediction type: epsilon[]
+- Prediction type: epsilon (extra parameters=['training_scheduler_timestep_spacing=trailing', 'inference_scheduler_timestep_spacing=trailing', 'controlnet_enabled'])
 - Optimizer: adamw_bf16
 - Trainable parameter precision: Pure BF16
 - Base model precision: `no_change`
 - Caption dropout probability: 0.0%
 
 
-- LoRA Rank:
-- LoRA Alpha:
+- LoRA Rank: 64
+- LoRA Alpha: 64.0
 - LoRA Dropout: 0.1
 - LoRA initialisation style: default
 
@@ -97,36 +97,43 @@ You may reuse the base model text encoder for inference.
 
 ```python
 import torch
-from diffusers import
-
-
-
-
-
-
+from diffusers import PixArtSigmaPipeline, PixArtSigmaControlNetPipeline
+# if you're not in the SimpleTuner environment, this import will fail.
+from helpers.models.pixart.controlnet import PixArtSigmaControlNetAdapterModel
+
+# Load base model
+base_model_id = "terminusresearch/pixart-900m-1024-ft-v0.6"
+controlnet_id = "bghira/pixart-controlnet-lora-test"
+
+# Load ControlNet adapter
+controlnet = PixArtSigmaControlNetAdapterModel.from_pretrained(
+    f"{controlnet_id}/controlnet"
+)
+
+# Create pipeline
+pipeline = PixArtSigmaControlNetPipeline.from_pretrained(
+    base_model_id,
+    controlnet=controlnet,
+    torch_dtype=torch.bfloat16
+)
+pipeline.to('cuda' if torch.cuda.is_available() else 'cpu')
+
+# Load your control image
+from PIL import Image
+control_image = Image.open("path/to/control/image.png")
+
+# Generate
 prompt = "A photo-realistic image of a cat"
-
-
-## Optional: quantise the model to save on vram.
-## Note: The model was not quantised during training, so it is not necessary to quantise it during inference time.
-#from optimum.quanto import quantize, freeze, qint8
-#quantize(pipeline.transformer, weights=qint8)
-#freeze(pipeline.transformer)
-
-pipeline.to('cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu') # the pipeline is already in its target precision level
-model_output = pipeline(
+image = pipeline(
     prompt=prompt,
-
+    image=control_image,
     num_inference_steps=16,
-    generator=torch.Generator(device='cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu').manual_seed(42),
-    width=1024,
-    height=1024,
     guidance_scale=4.0,
+    generator=torch.Generator(device='cuda').manual_seed(42),
+    controlnet_conditioning_scale=1.0,
 ).images[0]
 
-
-
+image.save("output.png")
 ```
 
 
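The previous revision's example carried an optional quantisation step that the new example drops. If VRAM is tight it should still apply; a minimal sketch, lifted from the removed lines above, assuming `optimum-quanto` is installed and that the ControlNet pipeline exposes the same `transformer` attribute the original example used:

```python
# Optional: quantise the transformer to save VRAM.
# The model was not quantised during training, so this is purely an
# inference-time memory optimisation, not a requirement.
from optimum.quanto import quantize, freeze, qint8

quantize(pipeline.transformer, weights=qint8)
freeze(pipeline.transformer)
```

Running this before the `pipeline.to(...)` call avoids placing the full-precision weights on the GPU first.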
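The training settings record `inference_scheduler_timestep_spacing=trailing`, while the example leaves the checkpoint's scheduler untouched. If you want to set the spacing explicitly, a hedged sketch using the standard diffusers `from_config` override, assuming the shipped scheduler accepts a `timestep_spacing` option:

```python
# Rebuild the pipeline's scheduler from its own config, overriding the
# timestep spacing to match the card's inference setting.
pipeline.scheduler = pipeline.scheduler.__class__.from_config(
    pipeline.scheduler.config,
    timestep_spacing="trailing",
)
```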
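For readers mapping the LoRA settings above onto familiar tooling: rank 64, alpha 64.0, dropout 0.1 and the default initialisation style correspond roughly to the `peft` configuration below. This is an illustrative sketch only, not SimpleTuner's internal wiring, and the `target_modules` list is an assumption about typical attention projections.

```python
from peft import LoraConfig

# Illustrative only: rank, alpha and dropout are taken from the card above;
# target_modules is an assumed list of attention projections, not read from
# the training run.
lora_config = LoraConfig(
    r=64,                    # LoRA Rank
    lora_alpha=64,           # LoRA Alpha (64.0 in the card)
    lora_dropout=0.1,        # LoRA Dropout
    init_lora_weights=True,  # "default" initialisation style
    target_modules=["to_q", "to_k", "to_v", "to_out.0"],
)
```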