# File size: 25,107 Bytes
# dd09c30
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
import os
import uuid

import cv2
import gradio as gr
import numpy as np
import torch
from PIL import Image, ImageDraw, ImageFont
from torchvision import transforms

from diffusers import FluxFillPipeline, FluxTransformer2DModel
from diffusers.utils import check_min_version, load_image

# Identifier of the transformer weights handed to FluxTransformer2DModel.from_pretrained
# in load_flux_pipeline.
WEIGHT_PATH = "dielz/textfux-test/transformer" 
# Scheduler selection read by run_inference: "overshoot" swaps in
# StochasticRFOvershotDiscreteScheduler, any other value keeps the pipeline's own scheduler.
# scheduler = "overshoot" # overshoot or default
scheduler = "default"


def read_words_from_text(input_text):
    """
    Reads words/list of words:
    - If input_text is the path of an existing regular file, it reads all non-empty lines from the file.
    - Otherwise, it directly splits the input by newlines into a list.

    Args:
        input_text: Either a path to a UTF-8 text file or the text itself.
            None is treated as empty text.

    Returns:
        A list of stripped, non-empty lines.
    """
    if input_text is None:
        return []
    # isfile (rather than exists) so that text which happens to name a directory,
    # e.g. ".", is treated as literal text instead of failing inside open().
    if isinstance(input_text, str) and os.path.isfile(input_text):
        with open(input_text, 'r', encoding='utf-8') as f:
            raw_lines = list(f)
    else:
        raw_lines = input_text.splitlines()
    stripped = (line.strip() for line in raw_lines)
    return [line for line in stripped if line]

def generate_prompt(words):
    """
    Builds the text prompt describing the template/scene image pair.

    Each word is wrapped in single quotes, the quoted words are joined with
    ", ", and the resulting list is inserted into both the [IMAGE1] and the
    [IMAGE2] sentence of the prompt.
    """
    quoted_words = []
    for word in words:
        quoted_words.append(f"'{word}'")
    joined_words = ', '.join(quoted_words)

    sentences = [
        "The pair of images highlights some white words on a black background, as well as their style on a real-world scene image. ",
        "[IMAGE1] is a template image rendering the text, with the words {words}; ",
        "[IMAGE2] shows the text content {words} naturally and correspondingly integrated into the image.",
    ]
    return ''.join(sentences).format(words=joined_words)

# Word-free variant of the prompt built by generate_prompt. run_inference passes this
# fixed text as `prompt` and the word-specific prompt as `prompt_2`.
prompt_template2 = (
    "The pair of images highlights some white words on a black background, as well as their style on a real-world scene image. "
    "[IMAGE1] is a template image rendering the text, with the words; "
    "[IMAGE2] shows the text content naturally and correspondingly integrated into the image."
)

PIPE = None
def load_flux_pipeline():
    """
    Returns the shared FluxFillPipeline, building it on first use.

    The first call loads the transformer from WEIGHT_PATH in bfloat16, plugs it
    into the "black-forest-labs/FLUX.1-Fill-dev" pipeline, moves the pipeline to
    "cuda" and stores it in the module-level PIPE. Later calls return PIPE as-is.
    """
    global PIPE
    if PIPE is not None:
        return PIPE

    custom_transformer = FluxTransformer2DModel.from_pretrained(
        WEIGHT_PATH,
        torch_dtype=torch.bfloat16
    )
    pipeline = FluxFillPipeline.from_pretrained(
        "black-forest-labs/FLUX.1-Fill-dev",
        transformer=custom_transformer,
        torch_dtype=torch.bfloat16
    )
    PIPE = pipeline.to("cuda")
    PIPE.transformer.to(torch.bfloat16)
    return PIPE

def run_inference(image_input, mask_input, words_input, num_steps=50, guidance_scale=30, seed=42):
    """
    Invokes the Flux model pipeline for inference:
    - Both image_input and mask_input are required to be concatenated composite images.
    - Automatically adjusts image dimensions to be multiples of 32 to meet model input requirements.
    - Generates a prompt based on the word list and passes it to the pipeline for inference execution.

    Args:
        image_input: Composite image, either a PIL image or a string accepted by load_image.
        mask_input: Composite mask, either a PIL image or a string accepted by load_image.
        words_input: Text (or text-file path) forwarded to read_words_from_text.
        num_steps: Value passed as num_inference_steps.
        guidance_scale: Value passed as guidance_scale.
        seed: Seed for the CUDA random generator; converted with int().

    Returns:
        The first image of the pipeline output.

    Raises:
        ValueError: If the image is smaller than 32 pixels along either side, so that
            rounding down to a multiple of 32 would leave an empty image.
    """
    if isinstance(image_input, str):
        inpaint_image = load_image(image_input).convert("RGB")
    else:
        inpaint_image = image_input.convert("RGB")
    if isinstance(mask_input, str):
        extended_mask = load_image(mask_input).convert("RGB")
    else:
        extended_mask = mask_input.convert("RGB")

    # Round both sides down to the nearest multiple of 32.
    width, height = inpaint_image.size
    new_width = (width // 32) * 32
    new_height = (height // 32) * 32
    if new_width == 0 or new_height == 0:
        raise ValueError(
            f"Input image is too small ({width}x{height}); both sides must be at least 32 pixels."
        )
    inpaint_image = inpaint_image.resize((new_width, new_height))
    extended_mask = extended_mask.resize((new_width, new_height))

    words = read_words_from_text(words_input)
    prompt = generate_prompt(words)
    print("Generated prompt:", prompt)

    generator = torch.Generator(device="cuda").manual_seed(int(seed))
    pipe = load_flux_pipeline()

    if scheduler == "overshoot":
        try:
            from diffusers import StochasticRFOvershotDiscreteScheduler
        except ImportError:
            print("StochasticRFOvershotDiscreteScheduler not found. Please ensure you have used the repo's diffusers.")
        else:
            # Bound to its own local name on purpose: assigning to `scheduler` inside this
            # function would make that name function-local, and the read of the module-level
            # setting in the `if` above would then raise UnboundLocalError on every call.
            overshoot_scheduler = StochasticRFOvershotDiscreteScheduler.from_config(pipe.scheduler.config)
            pipe.scheduler = overshoot_scheduler
            pipe.scheduler.set_c(2.0)
            pipe.scheduler.set_overshot_func(lambda t, dt: t + dt)

    result = pipe(
        height=new_height,
        width=new_width,
        image=inpaint_image,
        mask_image=extended_mask,
        num_inference_steps=num_steps,
        generator=generator,
        max_sequence_length=512,
        guidance_scale=guidance_scale,
        prompt=prompt_template2,
        prompt_2=prompt,
    ).images[0]

    return result

# =============================================================================
# Normal Mode: Direct Inference Call
# =============================================================================
def flux_demo_normal(image, mask, words, steps, guidance_scale, seed):
    """
    Gradio callback for normal mode.

    Forwards the image, mask and word list unchanged to run_inference, mapping
    `steps` onto its `num_steps` argument, and returns the image it produces.
    """
    return run_inference(
        image,
        mask,
        words,
        num_steps=steps,
        guidance_scale=guidance_scale,
        seed=seed,
    )

# =============================================================================
# Helper functions for both single-line and multi-line rendering
# =============================================================================
def extract_mask(original, drawn, threshold=30):
    """
    Builds a binary (0/255) RGB mask image from the original and the user-drawn input.

    - When 'drawn' is a dictionary holding a non-None "mask" entry, that entry is
      converted to grayscale if it has three dimensions and binarized at 50.
    - When 'drawn' is a dictionary without a usable mask, its "image" entry is
      inverted and compared against the original.
    - In every other case 'drawn' itself is compared against the original: pixels
      whose mean absolute difference across the last axis exceeds 'threshold'
      become 255.
    """
    if isinstance(drawn, dict):
        sketch = drawn.get("mask")
        if sketch is not None:
            gray = np.array(sketch).astype(np.uint8)
            if gray.ndim == 3:
                gray = cv2.cvtColor(gray, cv2.COLOR_RGB2GRAY)
            binarized = cv2.threshold(gray, 50, 255, cv2.THRESH_BINARY)[1]
            return Image.fromarray(binarized).convert("RGB")
        # No mask layer available: fall back to the inverted drawn image.
        drawn = 255 - np.array(drawn["image"]).astype(np.uint8)

    # int16 keeps the subtraction from wrapping around as uint8 would.
    reference = np.array(original).astype(np.int16)
    candidate = np.array(drawn).astype(np.int16)
    mean_delta = np.abs(candidate - reference).mean(axis=-1)
    binarized = (mean_delta > threshold).astype(np.uint8) * 255
    return Image.fromarray(binarized).convert("RGB")

def get_next_seq_number():
    """
    Returns the lowest free sequence number in the 'outputs_my' directory.

    Numbers are zero-padded to four digits (0001, 0002, ...). Counting starts at 1
    and stops at the first number for which 'result_XXXX.png' does not exist; that
    formatted string is returned.
    """
    index = 1
    seq_str = f"{index:04d}"
    while os.path.exists(os.path.join("outputs_my", f"result_{seq_str}.png")):
        index += 1
        seq_str = f"{index:04d}"
    return seq_str

# =============================================================================
# Single-line text rendering functions
# =============================================================================
def draw_glyph_flexible(font, text, width, height, max_font_size=140):
    """
    Renders text centered on a black canvas of specified size and returns a PIL Image.
    Font size is automatically adjusted to fit the canvas and is limited by max_font_size.

    Args:
        font: PIL font; resized through font.font_variant when that works, otherwise used as-is.
        text: Text to render. Empty or whitespace-only text yields a blank black canvas.
        width: Canvas width in pixels.
        height: Canvas height in pixels.
        max_font_size: Upper bound for the font size; replaced by 200 when width exceeds 1280.

    Returns:
        An RGB PIL Image with white text on a black background.
    """
    img = Image.new(mode='RGB', size=(width, height), color='black')
    if not text or not text.strip():
        return img
    draw = ImageDraw.Draw(img)

    # Initial font size for calculating scale ratio
    g_size = 50
    try:
        new_font = font.font_variant(size=g_size)
    except Exception:
        # Best-effort fallback (e.g. a font object that cannot be resized). Catching
        # Exception instead of a bare except keeps Ctrl-C / SystemExit working.
        new_font = font

    left, top, right, bottom = new_font.getbbox(text)
    text_width_initial = max(right - left, 1)
    text_height_initial = max(bottom - top, 1)

    # Calculate scale ratios based on width and height (text may use 90% of each side)
    width_ratio = width * 0.9 / text_width_initial
    height_ratio = height * 0.9 / text_height_initial
    ratio = min(width_ratio, height_ratio)

    # Adjust maximum font size based on original image width
    if width > 1280:
        max_font_size = 200
    final_font_size = int(g_size * ratio)
    final_font_size = min(final_font_size, max_font_size)  # Apply upper limit

    # Use the final calculated font size, never going below 10
    try:
        final_font = font.font_variant(size=max(final_font_size, 10))
    except Exception:
        final_font = font

    # anchor='mm' places the middle of the text at the canvas centre
    draw.text((width / 2, height / 2), text, font=final_font, fill='white', anchor='mm')
    return img

# =============================================================================
# Multi-line text rendering functions
# =============================================================================
def insert_spaces(text, num_spaces):
    """
    Spreads the characters of text apart by placing num_spaces spaces between
    each pair of neighbouring characters. Text of length 0 or 1 is returned unchanged.
    """
    if len(text) > 1:
        gap = ' ' * num_spaces
        return gap.join(text)
    return text


def draw_glyph2(
    font,
    text,
    polygon,
    vertAng=10,
    scale=1,
    width=512,
    height=512,
    add_space=True,
    scale_factor=2,
    rotate_resample=Image.BICUBIC,
    downsample_resample=Image.Resampling.LANCZOS
):
    """
    Renders text fitted to the minimum-area rectangle of a polygon and returns it as an array.

    The text is drawn in white on a transparent RGBA canvas of size
    (width * scale_factor, height * scale_factor), rotated by the angle derived from
    the rectangle, and the canvas is finally resized to (width, height).

    Args:
        font: PIL font providing font_variant and getbbox.
        text: Text to render.
        polygon: Array of points; it is multiplied by scale_factor * scale and passed
            to cv2.minAreaRect.
        vertAng: Tolerance in degrees around multiples of 90 within which a box that is
            at least as tall as wide is rendered as vertical, character-by-character text.
        scale: Extra multiplier applied to the polygon coordinates.
        width: Width of the returned image.
        height: Height of the returned image.
        add_space: Whether horizontal text narrower than the box may be widened by
            inserting spaces between characters.
        scale_factor: Supersampling factor of the intermediate canvas.
        rotate_resample: Resampling filter used for the rotation.
        downsample_resample: Resampling filter used for the final resize.

    Returns:
        numpy array of the final RGBA image.
    """
    big_w = width * scale_factor
    big_h = height * scale_factor

    big_polygon = polygon * scale_factor * scale
    rect = cv2.minAreaRect(big_polygon.astype(np.float32))
    box = cv2.boxPoints(rect)
    box = np.intp(box)

    # Derive the rotation to apply to the text from the rectangle's angle and orientation.
    w, h = rect[1]
    angle = rect[2]
    if angle < -45:
        angle += 90
    angle = -angle
    if w < h:
        angle += 90

    # Nearly axis-aligned boxes that are at least as tall as wide are drawn vertically.
    vert = False
    if (abs(angle) % 90 < vertAng or abs(90 - abs(angle) % 90) % 90 < vertAng):
        _w = max(box[:, 0]) - min(box[:, 0])
        _h = max(box[:, 1]) - min(box[:, 1])
        if _h >= _w:
            vert = True
            angle = 0

    big_img = Image.new("RGBA", (big_w, big_h), (0, 0, 0, 0))
    tmp = Image.new("RGB", big_img.size, "white")
    tmp_draw = ImageDraw.Draw(tmp)

    # Estimate the text width when its height is scaled to the short side of the box.
    _, _, _tw, _th = tmp_draw.textbbox((0, 0), text, font=font)
    if _th == 0:
        text_w = 0
    else:
        w_f, h_f = float(w), float(h)
        text_w = min(w_f, h_f) * (_tw / _th)

    if text_w <= max(w, h):
        # Text is shorter than the box: optionally widen it with inter-character spaces,
        # using the largest spacing tried before the text would overflow the long side.
        if len(text) > 1 and not vert and add_space:
            for i in range(1, 100):
                text_sp = insert_spaces(text, i)
                _, _, tw2, th2 = tmp_draw.textbbox((0, 0), text_sp, font=font)
                if th2 != 0:
                    if min(w, h) * (tw2 / th2) > max(w, h):
                        break
            text = insert_spaces(text, i-1)
        font_size = min(w, h) * 0.80
    else:
        # Text is longer than the box: shrink the font so it fits the long side.
        shrink = 0.75 if vert else 0.85
        if text_w != 0:
            font_size = min(w, h) / (text_w / max(w, h)) * shrink
        else:
            font_size = min(w, h) * 0.80

    # Thin or degenerate regions can yield a size that truncates to 0, which the font
    # loader rejects; clamp to the smallest valid size instead of crashing.
    new_font = font.font_variant(size=max(int(font_size), 1))
    left, top, right, bottom = new_font.getbbox(text)
    text_width = right - left
    text_height = bottom - top

    layer = Image.new("RGBA", big_img.size, (0, 0, 0, 0))
    draw_layer = ImageDraw.Draw(layer)
    cx, cy = rect[0]
    if not vert:
        draw_layer.text(
            (cx - text_width // 2, cy - text_height // 2 - top),
            text,
            font=new_font,
            fill=(255, 255, 255, 255)
        )
    else:
        # Vertical text: draw one character at a time, moving down by each glyph's bbox bottom.
        _w_ = max(box[:, 0]) - min(box[:, 0])
        x_s = min(box[:, 0]) + _w_ // 2 - text_height // 2
        y_s = min(box[:, 1])
        for c in text:
            draw_layer.text((x_s, y_s), c, font=new_font, fill=(255, 255, 255, 255))
            _, _t, _, _b = new_font.getbbox(c)
            y_s += _b

    rotated_layer = layer.rotate(
        angle,
        expand=True,
        center=(cx, cy),
        resample=rotate_resample
    )

    # Paste the (possibly expanded) rotated layer centred on the canvas.
    xo = int((big_img.width - rotated_layer.width) // 2)
    yo = int((big_img.height - rotated_layer.height) // 2)
    big_img.paste(rotated_layer, (xo, yo), rotated_layer)

    final_img = big_img.resize((width, height), downsample_resample)
    final_np = np.array(final_img)
    return final_np

def render_glyph_multi(original, computed_mask, texts):
    """
    Renders one line of text into each mask region and returns the combined glyph image.

    External contours of the mask whose bounding box covers at least 50 pixels are
    sorted by the (y, x) position of their bounding box and paired with the entries
    of texts in that order. Surplus regions or surplus texts are ignored, and blank
    texts leave their region empty.

    Args:
        original: PIL image; only its size is used, as the size of the output canvas.
        computed_mask: PIL mask image; converted to "L" before contour detection.
        texts: Sequence of strings, one per region.

    Returns:
        An RGB PIL image of the same size as original containing the rendered text.
    """
    mask_np = np.array(computed_mask.convert("L"))
    contours, _ = cv2.findContours(mask_np, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)

    regions = []
    for cnt in contours:
        x, y, w, h = cv2.boundingRect(cnt)
        # Skip specks whose bounding box is smaller than 50 pixels.
        if w * h >= 50:
            regions.append((x, y, cnt))
    # Sort by the bounding box's y first, then x.
    regions.sort(key=lambda r: (r[1], r[0]))

    render_img = Image.new("RGBA", original.size, (0, 0, 0, 0))
    try:
        base_font = ImageFont.truetype("resource/font/Arial-Unicode-Regular.ttf", 40)
    except OSError:
        # Same fallback and warning as the single-line renderer; a bare except here
        # would also have swallowed KeyboardInterrupt / SystemExit.
        base_font = ImageFont.load_default()
        print("Warning: Font not found, using default font.")

    width, height = original.size
    for (_, _, cnt), raw_text in zip(regions, texts):
        text = raw_text.strip()
        if not text:
            continue
        rendered_np = draw_glyph2(
            font=base_font,
            text=text,
            polygon=cnt.reshape(-1, 2),
            vertAng=10,
            scale=1,
            width=width,
            height=height,
            add_space=True,
            scale_factor=1,
            rotate_resample=Image.BICUBIC,
            downsample_resample=Image.Resampling.LANCZOS
        )
        rendered_img = Image.fromarray(rendered_np, mode="RGBA")
        render_img = Image.alpha_composite(render_img, rendered_img)
    return render_img.convert("RGB")


def choose_concat_direction(height, width):
    """
    Picks how the template and scene images are concatenated:
    - 'horizontal' when the image is taller than it is wide.
    - 'vertical' in every other case, including square images.
    """
    if height > width:
        return 'horizontal'
    return 'vertical'

def is_multiline_text(text):
    """
    Reports whether the text contains more than one non-blank line.
    Lines consisting only of whitespace are not counted.
    """
    non_blank = 0
    for line in text.splitlines():
        if line.strip():
            non_blank += 1
    return non_blank > 1

# =============================================================================
# Custom Mode: Unified function that handles both single-line and multi-line
# =============================================================================
def flux_demo_custom(original_image, drawn_mask, words, steps, guidance_scale, seed):
    """
    Gradio callback for custom mode.

    Extracts the mask from the drawn input, then dispatches on the text:
    - more than one non-blank line -> flux_demo_custom_multiline
    - otherwise -> flux_demo_custom_singleline
    The chosen handler's return value is passed through unchanged.
    """
    computed_mask = extract_mask(original_image, drawn_mask)

    multiline = is_multiline_text(words)
    print("Using multi-line text rendering mode" if multiline else "Using single-line text rendering mode")
    handler = flux_demo_custom_multiline if multiline else flux_demo_custom_singleline
    return handler(original_image, computed_mask, words, steps, guidance_scale, seed)

def flux_demo_custom_multiline(original_image, computed_mask, words, steps, guidance_scale, seed):
    """
    Multi-line rendering mode:
    1. Splits the input text into non-empty lines, one per mask region.
    2. Renders those lines into their regions with render_glyph_multi.
    3. Places the rendered image before the original image (side by side when the
       original is taller than wide, stacked otherwise) and pairs it with a mask
       that is empty over the rendered half.
    4. Runs run_inference on the composite, keeps the second half of the result,
       and saves everything through save_results.

    Returns:
        (cropped_result, composite_image, composite_mask)
    """
    glyph_img = render_glyph_multi(original_image, computed_mask, read_words_from_text(words))

    src_w, src_h = original_image.size
    side_by_side = choose_concat_direction(src_h, src_w) == 'horizontal'
    stack = np.hstack if side_by_side else np.vstack

    # The rendered half must not be repainted, hence an all-zero mask for it.
    blank_half = np.zeros((src_h, src_w), dtype=np.uint8)
    image_pair = stack((np.array(glyph_img), np.array(original_image)))
    mask_pair = stack((blank_half, np.array(computed_mask.convert("L"))))

    composite_image = Image.fromarray(image_pair)
    composite_mask = Image.fromarray(cv2.cvtColor(mask_pair, cv2.COLOR_GRAY2RGB))
    result = run_inference(composite_image, composite_mask, words, num_steps=steps, guidance_scale=guidance_scale, seed=seed)

    # Keep only the scene half: right half when side by side, bottom half when stacked.
    out_w, out_h = result.size
    if side_by_side:
        crop_box = (out_w // 2, 0, out_w, out_h)
    else:
        crop_box = (0, out_h // 2, out_w, out_h)
    cropped_result = result.crop(crop_box)

    save_results(result, cropped_result, computed_mask, original_image, composite_image, words)
    return cropped_result, composite_image, composite_mask

def flux_demo_custom_singleline(original_image, computed_mask, words, steps, guidance_scale, seed):
    """
    Single-line rendering mode:
    1. Joins the non-empty input lines with spaces into one line of text.
    2. Renders that line on a black banner whose height is 0.15625 times the image
       width and stacks the banner on top of the original image (the mask gets an
       all-black banner of the same size).
    3. Runs run_inference on the composite and crops the banner's share off the top
       of the result.

    Returns:
        (cropped_result, composite_image, composite_mask)
    """
    single_line_text = ' '.join(read_words_from_text(words))

    w, h = original_image.size
    banner_height = int(w * 0.15625)

    try:
        font = ImageFont.truetype("resource/font/Arial-Unicode-Regular.ttf", 60)
    except IOError:
        font = ImageFont.load_default()
        print("Warning: Font not found, using default font.")

    # Text banner plus an equally sized black strip for the mask.
    banner = draw_glyph_flexible(font, single_line_text, width=w, height=banner_height)
    banner_mask = Image.new("RGB", banner.size, "black")

    # Banner always goes above the scene image.
    stacked_image = np.vstack((np.array(banner), np.array(original_image)))
    composite_image = Image.fromarray(stacked_image)
    stacked_mask = np.vstack((np.array(banner_mask), np.array(computed_mask)))
    composite_mask = Image.fromarray(stacked_mask)

    full_result = run_inference(composite_image, composite_mask, words, num_steps=steps, guidance_scale=guidance_scale, seed=seed)

    # The result may have been resized, so locate the banner edge proportionally.
    res_w, res_h = full_result.size
    crop_top_edge = int(res_h * (banner_height / (h + banner_height)))
    cropped_result = full_result.crop((0, crop_top_edge, res_w, res_h))

    save_results(full_result, cropped_result, computed_mask, original_image, composite_image, words)
    return cropped_result, composite_image, composite_mask

def save_results(result, cropped_result, computed_mask, original_image, composite_image, words):
    """
    Writes one numbered set of outputs below 'outputs_my'.

    The directories are created when missing, the next free sequence number is
    obtained from get_next_seq_number, and then the full result, the cropped result,
    the mask, the original image and the composite image are saved as PNG files,
    followed by the input text as a UTF-8 .txt file.
    """
    for folder in (
        "outputs_my",
        "outputs_my/crop",
        "outputs_my/mask",
        "outputs_my/ori",
        "outputs_my/composite",
        "outputs_my/txt",
    ):
        os.makedirs(folder, exist_ok=True)

    seq = get_next_seq_number()
    image_targets = (
        (result, os.path.join("outputs_my", f"result_{seq}.png")),
        (cropped_result, os.path.join("outputs_my", "crop", f"crop_{seq}.png")),
        (computed_mask, os.path.join("outputs_my", "mask", f"mask_{seq}.png")),
        (original_image, os.path.join("outputs_my", "ori", f"ori_{seq}.png")),
        (composite_image, os.path.join("outputs_my", "composite", f"composite_{seq}.png")),
    )
    txt_filename = os.path.join("outputs_my", "txt", f"words_{seq}.txt")

    for image, path in image_targets:
        image.save(path)
    with open(txt_filename, "w", encoding="utf-8") as f:
        f.write(words)

# =============================================================================
# Gradio Interface
# =============================================================================
# Build the Gradio UI: a "Custom Mode" tab (upload an image, sketch a mask, enter text)
# and a "Normal Mode" tab (upload ready-made image and mask), followed by usage notes.
with gr.Blocks(title="Flux Inference Demo") as demo:
    gr.Markdown("## Flux Inference Demo")
    with gr.Tabs():
        with gr.TabItem("Custom Mode"):
            with gr.Row():
                # Left column: source image and the sketch canvas for the mask.
                with gr.Column(scale=1, min_width=350):
                    gr.Markdown("### Image Input")
                    original_image_custom = gr.Image(type="pil", label="Upload Original Image")
                    gr.Markdown("### Draw Mask on Image")
                    mask_drawing_custom = gr.Image(type="pil", label="Draw Mask on Original Image", tool="sketch")

                # Right column: text and sampling parameters.
                with gr.Column(scale=1, min_width=350):
                    gr.Markdown("### Parameter Settings")
                    words_custom = gr.Textbox(
                        lines=5, 
                        placeholder="Enter text here (single line recommended, faster and stronger).\nMultiple lines are supported, with each line rendered in corresponding mask regions.", 
                        label="Text Input"
                    )
                    steps_custom = gr.Slider(minimum=10, maximum=100, step=1, value=30, label="Inference Steps")
                    guidance_scale_custom = gr.Slider(minimum=1, maximum=50, step=1, value=30, label="Guidance Scale")
                    seed_custom = gr.Number(value=42, label="Random Seed")
                    run_custom = gr.Button("Generate Results")

            # Outputs: the cropped result, plus a preview of the composite inputs fed to the model.
            with gr.Tabs():
                with gr.TabItem("Generated Results"):
                    output_result_custom = gr.Image(type="pil", label="Generated Results")
                with gr.TabItem("Input Preview"):
                    output_composite_custom = gr.Image(type="pil", label="Concatenated Original Image")
                    output_mask_custom = gr.Image(type="pil", label="Concatenated Mask")

            # Copy the uploaded image into the sketch component whenever it changes.
            original_image_custom.change(fn=lambda x: x, inputs=original_image_custom, outputs=mask_drawing_custom)
            # flux_demo_custom returns (cropped result, composite image, composite mask).
            run_custom.click(fn=flux_demo_custom,
                inputs=[original_image_custom, mask_drawing_custom, words_custom, steps_custom, guidance_scale_custom, seed_custom],
                outputs=[output_result_custom, output_composite_custom, output_mask_custom])

        with gr.TabItem("Normal Mode"):
            with gr.Row():
                # Left column: image and mask supplied directly by the user.
                with gr.Column(scale=1, min_width=350):
                    gr.Markdown("### Image Input")
                    image_normal = gr.Image(type="pil", label="Image Input")
                    gr.Markdown("### Mask Input")
                    mask_normal = gr.Image(type="pil", label="Mask Input")
                # Right column: text, sampling parameters and the result.
                with gr.Column(scale=1, min_width=350):
                    gr.Markdown("### Parameter Settings")
                    words_normal = gr.Textbox(lines=5, placeholder="Please enter words here, one per line", label="Text List")
                    steps_normal = gr.Slider(minimum=10, maximum=100, step=1, value=30, label="Inference Steps")
                    guidance_scale_normal = gr.Slider(minimum=1, maximum=50, step=1, value=30, label="Guidance Scale")
                    seed_normal = gr.Number(value=42, label="Random Seed")
                    run_normal = gr.Button("Generate Results")
                    output_normal = gr.Image(type="pil", label="Generated Results")
            # flux_demo_normal forwards the inputs to run_inference and returns a single image.
            run_normal.click(fn=flux_demo_normal,
                inputs=[image_normal, mask_normal, words_normal, steps_normal, guidance_scale_normal, seed_normal],
                outputs=output_normal)

    # Usage notes shown below the tabs.
    gr.Markdown(
        """
        ### Instructions
        - **Custom Mode**: 
          - Upload an original image, then draw a mask on it
          - **Single-line mode**: Enter text without line breaks - all text will be joined and rendered as one line above the image
          - **Multi-line mode**: Enter text with line breaks - each line will be rendered in the corresponding mask region with skewed/curved effects
          - The system automatically detects which mode to use based on your text input
        - **Normal Mode**: Directly upload an image, mask, and a list of words to generate the result image.
        """
    )

if __name__ == "__main__":
    # Check the installed diffusers version against 0.30.1 (diffusers' check_min_version)
    # before starting the Gradio app.
    check_min_version("0.30.1")
    demo.launch()