File size: 30,581 Bytes
3943768
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
import base64
import functools
import os
import tempfile
import time
import types
import uuid
from functools import partial
from io import BytesIO
import numpy as np
from PIL.Image import Resampling

from gradio_utils.grclient import check_job
from src.enums import valid_imagegen_models, valid_imagechange_models, valid_imagestyle_models, docs_joiner_default, \
    llava16_model_max_length, llava16_image_tokens, llava16_image_fudge, VIDEO_EXTENSIONS, IMAGE_EXTENSIONS
from src.image_utils import fix_image_file
from src.utils import is_gradio_version4, get_docs_tokens, get_limited_text, makedirs, call_subprocess_onetask, \
    have_fiftyone, sanitize_filename

def is_animated_gif(file_path):
    """Return True if file_path is a GIF with more than one frame.

    Cheap extension check first (case-insensitive), then try seeking to the
    second frame: a single-frame GIF raises EOFError on seek(1).
    Missing or unreadable files return False.

    :param file_path: Path to the candidate file.
    :return: True only for a readable, multi-frame GIF.
    """
    if not file_path.lower().endswith('.gif'):
        return False
    from PIL import Image, UnidentifiedImageError
    try:
        # context manager ensures the file handle is closed (original leaked it)
        with Image.open(file_path) as gif:
            gif.seek(1)
    except (FileNotFoundError, UnidentifiedImageError, EOFError):
        return False
    return True


def gif_to_mp4(gif_path):
    """
    Convert an animated GIF to an MP4 video (H.264).

    The output path is derived from the input by replacing the '.gif'
    suffix with '.mp4'.

    :param gif_path: Path to the input GIF file.
    :return: Path to the written MP4 file.
    """
    from moviepy.editor import VideoFileClip
    clip = VideoFileClip(gif_path)
    mp4_path = gif_path.replace('.gif', '.mp4')
    try:
        clip.write_videofile(mp4_path, codec='libx264')
    finally:
        # release ffmpeg reader resources even if the write fails
        clip.close()
    return mp4_path


def is_video_file(file_path):
    """
    Heuristically decide whether file_path is a playable video.

    The extension must appear in VIDEO_EXTENSIONS; the file must then open
    via OpenCV with at least one frame and a positive frame rate.

    :param file_path: Path to the file.
    :return: True if the file is a video, False otherwise.
    """
    if os.path.splitext(file_path)[-1].lower() not in VIDEO_EXTENSIONS:
        return False

    import cv2
    cap = cv2.VideoCapture(file_path)
    n_frames = cap.get(cv2.CAP_PROP_FRAME_COUNT)
    fps = cap.get(cv2.CAP_PROP_FPS)
    cap.release()

    # A valid video should have more than 0 frames and a positive frame rate
    return n_frames >= 1 and fps > 0


def img_to_base64(image_file, resolution=None, output_format=None, str_bytes=True):
    """Encode an image file as a base64 data-URI string.

    :param image_file: Path to an image whose extension is in IMAGE_EXTENSIONS.
    :param resolution: Optional (width, height) to resize to (bicubic).
    :param output_format: Explicit output format (e.g. 'jpeg', 'png'); if None,
        the input format is kept unless it is neither JPEG nor PNG, in which
        case JPEG (the most widely supported) is used.
    :param str_bytes: If True, return the legacy str(bytes) form ("b'data:...'");
        if False, return a plain "data:image/...;base64,..." string.
    :return: Base64 data-URI string.
    """
    from PIL import Image

    from pathlib import Path
    ext = Path(image_file).suffix
    iformat = IMAGE_EXTENSIONS.get(ext)
    assert iformat is not None, "Invalid file extension %s for file %s" % (ext, image_file)

    image = Image.open(image_file)

    if resolution:
        image = image.resize(resolution, resample=Resampling.BICUBIC)

    if output_format:
        oformat = output_format.upper()
    elif iformat not in ['JPEG', 'PNG']:
        # use jpeg by default if nothing set, so most general format allowed
        oformat = 'JPEG'
    else:
        oformat = iformat

    # JPEG cannot store alpha/palette modes (e.g. RGBA PNG re-encoded as JPEG);
    # convert to RGB to avoid PIL raising on save
    if oformat == 'JPEG' and image.mode not in ('RGB', 'L'):
        image = image.convert('RGB')

    buffered = BytesIO()
    image.save(buffered, format=oformat)
    img_str = base64.b64encode(buffered.getvalue())

    if str_bytes:
        # legacy form: str() of the raw bytes object, i.e. "b'data:...'"
        img_str = str(bytes("data:image/%s;base64," % oformat.lower(), encoding='utf-8') + img_str)
    else:
        img_str = f"data:image/{oformat.lower()};base64,{img_str.decode('utf-8')}"

    return img_str


def base64_to_img(img_str, output_path):
    """
    Decode a base64 data-URI string and write it to disk.

    :param img_str: "data:image/<fmt>;base64,<data>" string, possibly wrapped
        as a stringified bytes literal ("b'...'") as produced upstream with
        str_bytes=True.
    :param output_path: Output path without extension; the extension is taken
        from the data-URI media type.
    :return: The path to the saved file.
    """
    # strip the stringified-bytes wrapper, if present
    if img_str.startswith("b'"):
        img_str = img_str[2:-1]

    # metadata before the first comma, payload after
    meta, payload = img_str.split(",", 1)
    # e.g. "data:image/png" -> "png"
    img_format = meta.split(';')[0].split('/')[-1]
    output_file = f"{output_path}.{img_format}"
    with open(output_file, "wb") as fh:
        fh.write(base64.b64decode(payload))
    print(f"Image saved to {output_file} with format {img_format}")
    return output_file


def video_to_base64frames(video_path):
    """Read every frame of a video and return them as base64-encoded JPEG strings."""
    import cv2
    cap = cv2.VideoCapture(video_path)

    frames_b64 = []
    while cap.isOpened():
        ok, frame = cap.read()
        if not ok:
            break
        _, jpg_buf = cv2.imencode(".jpg", frame)
        frames_b64.append(base64.b64encode(jpg_buf).decode("utf-8"))

    cap.release()
    print(len(frames_b64), "frames read.")
    return frames_b64


@functools.lru_cache(maxsize=10000, typed=False)
def video_to_frames(video_path, output_dir, resolution=None, image_format="jpg", video_frame_period=None,
                    extract_frames=None,
                    verbose=False):
    import cv2
    """
    Convert video to frames, save them as image files in the specified format, and return the list of file names.

    NOTE: results are lru_cached on the argument tuple, so repeated identical calls
    return the previously computed file names without re-extracting (assumes the
    frame files are still on disk).

    :param video_path: Path to the input video file (or URL when fiftyone is available).
    :param output_dir: Directory where the output frames will be saved; if None, a
        temp directory derived from the video path is used.
    :param resolution: Tuple specifying the desired resolution (width, height) or None to keep the original resolution.
    :param image_format: String specifying the desired image format (e.g., "jpg", "png").
    :param video_frame_period: How often to sample frames from the video. If None or <1,
      an automatic period is chosen so at most ~20 frames are extracted.
      e.g. if pass non-real-time video, can set to 1 to save all frames, to mimic passing actual frames separately otherwise
    :param extract_frames: Number of frames to extract from the video in automatic-period mode.
    :param verbose: Boolean to control whether to print progress messages.
    :return: List of file names for the saved frames.

    Example usage:
    file_names = video_to_frames("input_video.mp4", "output_frames", resolution=(640, 480), image_format="png", verbose=True)
    print(file_names)
    """
    if output_dir is None:
        output_dir = os.path.join(tempfile.gettempdir(), 'image_path_%s' % sanitize_filename(video_path))

    enable_fiftyone = True  # optimal against issues if using function server
    if enable_fiftyone and \
            have_fiftyone and \
            (video_frame_period is not None and video_frame_period < 1 or not os.path.isfile(video_path)):
        # handles either automatic period or urls
        from src.vision.extract_movie import extract_unique_frames
        args = ()
        urls = [video_path] if not os.path.isfile(video_path) else None
        file = video_path if os.path.isfile(video_path) else None
        kwargs = {'urls': urls, 'file': file, 'download_dir': None, 'export_dir': output_dir,
                  'extract_frames': extract_frames}
        # fifty one is complex program and leaves around processes
        if False:  # NOTE: Assumes using function server to handle isolation if want production grade behavior
            func_new = partial(call_subprocess_onetask, extract_unique_frames, args, kwargs)
        else:
            func_new = functools.partial(extract_unique_frames, *args, **kwargs)
        export_dir = func_new()
        return [os.path.join(export_dir, x) for x in os.listdir(export_dir)]

    if video_frame_period and video_frame_period < 1:
        video_frame_period = None
    if video_frame_period in [None, 0]:
        # e.g. if no fiftyone and so can't do 0 case, then assume ok to do period based
        total_frames = count_frames(video_path)
        extract_frames = min(20, extract_frames or 20)  # no more than 20 frames total for now
        # max(1, ...) guards short videos (total_frames < extract_frames) against
        # a zero period, which would cause a modulo-by-zero below
        video_frame_period = max(1, total_frames // extract_frames)

    video = cv2.VideoCapture(video_path)
    makedirs(output_dir)

    # default extension without the dot; filenames are built as "frame_NNNN.<ext>"
    # (the old '.jpg' fallback produced a double dot: "frame_0000..jpg")
    image_format = image_format or 'jpg'

    frame_count = 0
    file_names = []
    while True:
        success, frame = video.read()
        if not success:
            break

        # keep first frame, then keep a frame every video_frame_period frames
        if frame_count % video_frame_period != 0:
            frame_count += 1
            continue
        if resolution:
            frame = cv2.resize(frame, resolution)

        frame_filename = os.path.join(output_dir, f"frame_{frame_count:04d}.{image_format}")
        cv2.imwrite(frame_filename, frame)
        file_names.append(frame_filename)
        frame_count += 1
    video.release()

    if verbose:
        print(f"{frame_count} frames saved to {output_dir}.")

    return file_names


def count_frames(video_path):
    """Return the total number of frames in a video, or -1 if it cannot be opened."""
    import cv2

    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        print("Error: Could not open video.")
        return -1

    # total frame count as reported by the container metadata
    n_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    cap.release()
    return n_frames


def process_file_list(file_list, output_dir, resolution=None, image_format="jpg",
                      rotate_align_resize_image=True,
                      video_frame_period=None,
                      extract_frames=None,
                      verbose=False):
    # FIXME: resolution is only applied to videos here; images are resized later,
    # when byte-encoding for the LLM
    """
    Process a list of files, converting any videos to frames so the returned
    list contains image files only.

    :param file_list: List of file paths (or URLs, e.g. youtube) to be processed.
    :param output_dir: Directory where the output frames will be saved.
    :param resolution: Tuple specifying the desired resolution (width, height) or None to keep the original resolution.
      Does not affect images as inputs, handled elsewhere when converting to base64 for LLM
    :param image_format: String specifying the desired image format (e.g., "jpg", "png").
    :param rotate_align_resize_image:  Whether to apply rotation, alignment, resize before giving to LLM
    :param video_frame_period: Period to save frames, if <1 then automatic
    :param extract_frames: how many frames to extract if automatic period mode
    :param verbose: Boolean to control whether to print progress messages.
    :return: Updated list of file names containing only image files.
    """
    if file_list is None:
        file_list = []
    if image_format is None:
        image_format = 'jpg'

    image_files = []

    for path in file_list:
        exists = os.path.isfile(path)
        # non-files (e.g. youtube urls) are treated as potential videos too
        maybe_video = (exists and is_video_file(path)) or not exists or is_animated_gif(path)

        if is_animated_gif(path):
            # FIXME: could convert gif -> mp4 with gif_to_mp4(gif_path)()
            # fiftyone can't handle animated gifs, so disable automatic-period mode
            # (note: intentionally affects subsequent files too, as before)
            extract_frames = None
            if video_frame_period is not None and video_frame_period < 1:
                video_frame_period = None

        if maybe_video:
            if verbose:
                print(f"Processing video file: {path}")
            # output_dir=None means derive the location from the file itself
            image_files.extend(video_to_frames(path, None, resolution, image_format,
                                               video_frame_period, extract_frames, verbose))
        else:
            # still image: optionally normalize orientation/alignment/size first
            if rotate_align_resize_image:
                image_files.append(fix_image_file(path, do_align=True, do_rotate=True,
                                                  do_pad=False, relaxed_resize=True))
            else:
                image_files.append(path)

    return image_files


def fix_llava_prompt(file,
                     prompt=None,
                     allow_prompt_auto=True,
                     ):
    """Normalize the prompt for a llava call.

    When allow_prompt_auto is set, 'auto'/None prompts become a default
    describe-the-image question, or '' when there is no file either (let the
    model handle it).  Any remaining None prompt becomes '' — or raises when
    the HARD_ASSERTS env var is set.
    """
    if allow_prompt_auto and prompt in ('auto', None):
        # prompt = "According to the image, describe the image in full details with a well-structured response."
        # with no file and no prompt, pass nothing through
        prompt = '' if file in ('', None) else "Describe the image and what does the image say?"
    # allow prompt = '', will describe image by default
    if prompt is None:
        if os.environ.get('HARD_ASSERTS'):
            raise ValueError('prompt is None')
        prompt = ''
    return prompt


def llava_prep(file_list,
               llava_model,
               image_model='llava-v1.6-vicuna-13b',
               client=None):
    """Prepare each file (and a shared client) for llava inference via _llava_prep.

    Returns (resolved_image_model, client, prepared_file_list); the image model
    returned is the one resolved for the first file.
    """
    # without a client, only a single file is supported (client is created once)
    assert client is not None or len(file_list) == 1

    prepped_files = []
    prepped_models = []
    for one_file in file_list:
        model_i, client, file_i = _llava_prep(one_file,
                                              llava_model,
                                              image_model=image_model,
                                              client=client)
        prepped_files.append(file_i)
        prepped_models.append(model_i)
    assert len(prepped_models) >= 1
    assert len(prepped_files) >= 1
    return prepped_models[0], client, prepped_files


def _llava_prep(file,
                llava_model,
                image_model='llava-v1.6-vicuna-13b',
                client=None):
    """Resolve the llava server address and model, and normalize the file argument.

    llava_model looks like "[http[s]://]ip:port[:model_name]"; a trailing model
    name overrides image_model.  Creates a GradioClient when none is supplied,
    converts an on-disk file to base64 for gradio<4, and dumps a numpy array to
    a temporary jpeg file.
    """
    # peel off an optional scheme so the ':' split below only sees ip:port[:model]
    prefix = ''
    for scheme in ('http://', 'https://'):
        if llava_model.startswith(scheme):
            prefix = scheme
    llava_model = llava_model[len(prefix):]

    parts = llava_model.split(':')
    assert len(parts) >= 2
    # FIXME: Allow choose model in UI
    # parts[0]/parts[1] are ip/port; the server's default model is assumed ok
    if len(parts) >= 3:
        image_model = parts[2]
        llava_model = ':'.join(parts[:2])
    # restore the scheme
    llava_model = prefix + llava_model

    if client is None:
        from gradio_utils.grclient import GradioClient
        client = GradioClient(llava_model, check_hash=False, serialize=is_gradio_version4)
        client.setup()

    # older gradio cannot upload files directly; inline as base64
    if not is_gradio_version4 and file and os.path.isfile(file):
        file = img_to_base64(file)

    assert image_model, "No image model specified"

    # raw arrays get written to a throwaway jpeg so they can be uploaded
    if isinstance(file, np.ndarray):
        from PIL import Image
        im = Image.fromarray(file)
        file = "%s.jpeg" % str(uuid.uuid4())
        im.save(file)

    return image_model, client, file


# Error text emitted by the llava gradio server under load; responses containing
# it are filtered out before multi-image reduction (see get_llava_response/stream)
server_error_msg = "**NETWORK ERROR DUE TO HIGH TRAFFIC. PLEASE REGENERATE OR REFRESH THIS PAGE.**"


def get_prompt_with_texts(texts, prompt, max_new_tokens, min_max_new_tokens, tokenizer):
    """Merge per-image answers into one reduction prompt that fits the model context.

    Joins the per-image answers, trims them via get_docs_tokens to the
    available input budget (shrinking max_new_tokens to min_max_new_tokens if
    the full join would overflow), and appends an instruction to reduce the
    answers to a single response.  'image'/'Image' are rewritten to
    'document'/'Document' since no image accompanies the reduction call.
    """
    if tokenizer is None:
        raise RuntimeError("Not setup for multi-image without tokenizer")
        # from transformers import AutoTokenizer
        # tokenizer = AutoTokenizer.from_pretrained(base_model)
    model_max_length = getattr(tokenizer, 'model_max_length', llava16_model_max_length)

    user_part = '\n\nReduce the above information into single correct answer to the following question: ' + prompt
    user_part_tokens = len(tokenizer.encode(user_part))

    text_context_list = ['Answer #%s:\n\n%s' % (ii, text) for ii, text in enumerate(texts)]

    # shrink the generation budget if the full join would overflow the context
    full_join_tokens = len(tokenizer.encode(docs_joiner_default.join(text_context_list)))
    if user_part_tokens + full_join_tokens + max_new_tokens >= model_max_length:
        max_new_tokens = min_max_new_tokens
    max_input_tokens = model_max_length - max_new_tokens - llava16_image_fudge  # fudge for extra chars

    top_k_docs, _one_doc_size, _num_doc_tokens = \
        get_docs_tokens(tokenizer, text_context_list=text_context_list, max_input_tokens=max_input_tokens)
    texts_joined = docs_joiner_default.join(text_context_list[:top_k_docs])

    prompt_with_texts = '\n"""\n' + texts_joined + '\n"""\n' + user_part

    return prompt_with_texts.replace('image', 'document').replace('Image', 'Document')


def get_llava_response(file=None,
                       llava_model=None,
                       prompt=None,
                       chat_conversation=[],
                       allow_prompt_auto=False,
                       image_model='llava-v1.6-vicuna-13b', temperature=0.2,
                       top_p=0.7, max_new_tokens=512,
                       min_max_new_tokens=512,
                       tokenizer=None,
                       image_process_mode="Default",
                       include_image=False,
                       client=None,
                       max_time=None,
                       force_stream=True,
                       verbose=False,
                       ):
    """Blocking llava call over one or more image files; returns (response, prompt).

    With force_stream (or more than one file) this delegates to
    get_llava_stream and returns the final streamed chunk.  Otherwise each
    file is queried separately and, for multiple files, the per-file answers
    are reduced by a final text-only call via get_prompt_with_texts.
    """
    max_new_tokens = min(max_new_tokens, 1024)  # for hard_cutoff to be easy to know

    # NOTE: locals() snapshot must happen right here — after the max_new_tokens
    # clamp but before any other local is defined — so it captures exactly the
    # keyword arguments expected by get_llava_stream
    kwargs = locals().copy()

    # multiple files always go through the streaming path
    force_stream |= isinstance(file, list) and len(file) > 1
    if isinstance(file, str):
        file_list = [file]
    elif isinstance(file, list):
        file_list = file
        if len(file_list) == 0:
            file_list = [None]
    else:
        file_list = [None]

    if force_stream:
        text = ''
        # drain the generator, keeping only the final (complete) chunk
        for res in get_llava_stream(**kwargs):
            text = res
        return text, prompt

    image_model = os.path.basename(image_model)  # in case passed HF link
    prompt = fix_llava_prompt(file_list, prompt, allow_prompt_auto=allow_prompt_auto)
    # with many files, keep per-file generations short to leave room for reduction
    max_new_tokens1 = max_new_tokens if len(file_list) <= 4 else min(max_new_tokens, min_max_new_tokens)
    if tokenizer:
        model_max_length = tokenizer.model_max_length
    else:
        model_max_length = llava16_model_max_length
    # reserve context budget for image tokens (when a file is present) plus a fudge
    image_tokens = llava16_image_tokens if len(file_list) >= 1 and file_list[0] is not None else 0
    fudge = llava16_image_fudge
    hard_limit_tokens = model_max_length - max_new_tokens1 - fudge - image_tokens
    prompt = get_limited_text(hard_limit_tokens, prompt, tokenizer, verbose=False)

    image_model, client, file_list = \
        llava_prep(file_list, llava_model,
                   image_model=image_model,
                   client=client)

    reses = []
    for file in file_list:
        # chat history only makes sense for a single-image call
        res = client.predict(prompt,
                             chat_conversation if len(file_list) == 1 else [],
                             file,
                             image_process_mode,
                             include_image,
                             image_model,
                             temperature,
                             top_p,
                             max_new_tokens1,
                             api_name='/textbox_api_submit')
        reses.append(res)

    if len(reses) > 1:
        # drop failed responses, then reduce per-image answers with a text-only call
        reses = [x for x in reses if server_error_msg not in x]
        prompt_with_texts = get_prompt_with_texts(reses, prompt, max_new_tokens, min_max_new_tokens, tokenizer)
        res = client.predict(prompt_with_texts,
                             chat_conversation,
                             None,
                             image_process_mode,
                             include_image,
                             image_model,
                             temperature,
                             top_p,
                             max_new_tokens,
                             api_name='/textbox_api_submit')
    else:
        res = reses[0]

    return res, prompt


def get_llava_stream(file, llava_model,
                     prompt=None,
                     chat_conversation=[],
                     allow_prompt_auto=False,
                     image_model='llava-v1.6-vicuna-13b', temperature=0.2,
                     top_p=0.7, max_new_tokens=512,
                     min_max_new_tokens=512,
                     tokenizer=None,
                     image_process_mode="Default",
                     include_image=False,
                     client=None,
                     verbose_level=0,
                     max_time=None,
                     force_stream=True,  # dummy arg
                     verbose=False,
                     ):
    """Streaming llava generator over one or more image files.

    Submits one gradio job per file and polls them, yielding partial text only
    in the single-file case.  For multiple files it collects the per-file
    answers, then recurses once without images to reduce them to a single
    streamed answer.  The generator's return value (via StopIteration) is the
    final text.
    """
    max_new_tokens = min(max_new_tokens, 1024)  # for hard_cutoff to be easy to know

    # normalize file into a non-empty list (None entry means text-only call)
    if isinstance(file, str):
        file_list = [file]
    elif isinstance(file, list):
        file_list = file
        if len(file_list) == 0:
            file_list = [None]
    else:
        file_list = [None]

    image_model = os.path.basename(image_model)  # in case passed HF link
    prompt = fix_llava_prompt(file_list, prompt, allow_prompt_auto=allow_prompt_auto)
    # with many files, keep per-file generations short to leave room for reduction
    max_new_tokens1 = max_new_tokens if len(file_list) <= 4 else min(max_new_tokens, min_max_new_tokens)
    if tokenizer:
        model_max_length = tokenizer.model_max_length
    else:
        model_max_length = llava16_model_max_length
    # reserve context budget for image tokens (when a file is present) plus a fudge
    image_tokens = llava16_image_tokens if len(file_list) >= 1 and file_list[0] is not None else 0
    fudge = llava16_image_fudge
    hard_limit_tokens = model_max_length - max_new_tokens1 - fudge - image_tokens
    prompt = get_limited_text(hard_limit_tokens, prompt, tokenizer)

    image_model, client, file_list = \
        llava_prep(file_list, llava_model,
                   image_model=image_model,
                   client=client)

    # one async gradio job per file
    jobs = []
    for file in file_list:
        job = client.submit(prompt,
                            chat_conversation,
                            file,
                            image_process_mode,
                            include_image,
                            image_model,
                            temperature,
                            top_p,
                            max_new_tokens1,
                            api_name='/textbox_api_submit')
        jobs.append(job)

    t0 = time.time()
    # job_outputs_nums[ji] = how many outputs of job ji we have already consumed
    job_outputs_nums = [0] * len(jobs)
    texts = [''] * len(jobs)
    done_all = False
    reses = [''] * len(jobs)
    # polling loop: drain new outputs from each job until all done or timed out
    while True:
        for ji, job in enumerate(jobs):
            if verbose_level == 2:
                print("Inside: %s" % llava_model, time.time() - t0, flush=True)
            # skip (don't raise on) jobs that errored
            e = check_job(job, timeout=0, raise_exception=False)
            if e is not None:
                continue
            if max_time is not None and time.time() - t0 > max_time:
                done_all = True
                break
            outputs_list = job.outputs().copy()
            # only the outputs that appeared since the last poll
            job_outputs_num_new = len(outputs_list[job_outputs_nums[ji]:])
            for num in range(job_outputs_num_new):
                reses[ji] = outputs_list[job_outputs_nums[ji] + num]
                if verbose_level == 2:
                    print('Stream %d: %s' % (num, reses[ji]), flush=True)
                elif verbose_level == 1:
                    print('Stream %d' % (job_outputs_nums[ji] + num), flush=True)
                if reses[ji]:
                    texts[ji] = reses[ji]
                    # only stream partials to the caller for a single-file call
                    if len(jobs) == 1:
                        yield texts[ji]
            job_outputs_nums[ji] += job_outputs_num_new
            time.sleep(0.005)
        if done_all or all([job.done() for job in jobs]):
            break

    # final sweep: collect any outputs produced between the last poll and done()
    for ji, job in enumerate(jobs):
        e = check_job(job, timeout=0, raise_exception=False)
        if e is not None:
            continue
        outputs_list = job.outputs().copy()
        job_outputs_num_new = len(outputs_list[job_outputs_nums[ji]:])
        for num in range(job_outputs_num_new):
            reses[ji] = outputs_list[job_outputs_nums[ji] + num]
            if verbose_level == 2:
                print('Final Stream %d: %s' % (num, reses[ji]), flush=True)
            elif verbose_level == 1:
                print('Final Stream %d' % (job_outputs_nums[ji] + num), flush=True)
            if reses[ji]:
                texts[ji] = reses[ji]
                if len(jobs) == 1:
                    yield texts[ji]
        job_outputs_nums[ji] += job_outputs_num_new
        if verbose_level == 1:
            print("total job_outputs_num=%d" % job_outputs_nums[ji], flush=True)

    if len(jobs) > 1:
        # recurse without image(s)
        ntexts_before = len(texts)
        # drop answers that hit the server error
        texts = [x for x in texts if server_error_msg not in x]
        ntexts_after = len(texts)
        if ntexts_after != ntexts_before:
            print("texts: %s -> %s" % (ntexts_before, ntexts_after))
        prompt_with_texts = get_prompt_with_texts(texts, prompt, max_new_tokens, min_max_new_tokens, tokenizer)
        text = ''
        max_new_tokens = max_new_tokens if len(jobs) > 4 else min(max_new_tokens, min_max_new_tokens)
        for res in get_llava_stream(None,
                                    llava_model,
                                    prompt=prompt_with_texts,
                                    chat_conversation=chat_conversation,
                                    allow_prompt_auto=allow_prompt_auto,
                                    image_model=image_model,
                                    temperature=temperature,
                                    top_p=top_p,
                                    # avoid long outputs
                                    max_new_tokens=max_new_tokens,
                                    min_max_new_tokens=min_max_new_tokens,
                                    tokenizer=tokenizer,
                                    image_process_mode=image_process_mode,
                                    include_image=include_image,
                                    client=client,
                                    verbose_level=verbose_level,
                                    max_time=max_time,
                                    force_stream=force_stream,  # dummy arg
                                    verbose=verbose,
                                    ):
            text = res
            yield text
    else:
        assert len(texts) == 1
        text = texts[0]

    # generator return value, delivered via StopIteration.value
    return text


def get_image_model_dict(enable_image,
                         image_models,
                         image_gpu_ids,
                         ):
    """Build {model_name: {'pipe': ..., 'make_image': ...}} for enabled image models.

    :param enable_image: When False, return an empty dict without loading anything.
    :param image_models: List of requested model names (subset of valid_image*_models).
    :param image_gpu_ids: GPU id per model (same order as image_models), or
        None/empty for 'auto' placement everywhere.
    :return: Dict mapping each loaded model name to its pipeline and make_image function.
    :raises ValueError: If a requested model name is not recognized.
    """
    image_dict = {}
    if not enable_image:
        return image_dict

    # None and empty list both mean: auto-place every model
    # (single check; `not None` is True, so the old separate None check was redundant)
    if not image_gpu_ids:
        image_gpu_ids = ['auto'] * len(image_models)

    for image_model_name in valid_imagegen_models + valid_imagechange_models + valid_imagestyle_models:
        if image_model_name in image_models:
            imagegen_index = image_models.index(image_model_name)
            # imports are done lazily per model so only requested backends load
            if image_model_name == 'sdxl_turbo':
                from src.vision.sdxl_turbo import get_pipe_make_image, make_image
            elif image_model_name == 'playv2':
                from src.vision.playv2 import get_pipe_make_image, make_image
            elif image_model_name == 'sdxl':
                from src.vision.stable_diffusion_xl import get_pipe_make_image, make_image
            elif image_model_name == 'sd3':
                # sd3 reuses the sdxl code path with a different base model and no refiner
                from src.vision.stable_diffusion_xl import get_pipe_make_image, make_image
                get_pipe_make_image = functools.partial(get_pipe_make_image,
                                                        base_model='stabilityai/stable-diffusion-3-medium-diffusers',
                                                        refiner_model=None)
                make_image = functools.partial(make_image,
                                               base_model='stabilityai/stable-diffusion-3-medium-diffusers',
                                               refiner_model=None)
            elif image_model_name == 'flux.1-dev':
                from src.vision.flux import get_pipe_make_image, make_image
            elif image_model_name == 'flux.1-schnell':
                from src.vision.flux import get_pipe_make_image_2 as get_pipe_make_image
                from src.vision.flux import make_image
            elif image_model_name == 'sdxl_change':
                from src.vision.sdxl_turbo import get_pipe_change_image as get_pipe_make_image, change_image
                make_image = change_image
            # FIXME: style
            else:
                raise ValueError("Invalid image_model_name=%s" % image_model_name)
            pipe = get_pipe_make_image(gpu_id=image_gpu_ids[imagegen_index])
            image_dict[image_model_name] = dict(pipe=pipe, make_image=make_image)
    return image_dict


def pdf_to_base64_pngs(pdf_path, quality=75, max_size=(1024, 1024), ext='png', pages=None):
    """
    Render PDF pages to images and return them base64-encoded, resized to stay
    within size limits (e.g. Claude's image size limits).
    """
    # https://github.com/anthropics/anthropic-cookbook/blob/main/multimodal/reading_charts_graphs_powerpoints.ipynb
    from PIL import Image
    import io
    import fitz
    import tempfile

    doc = fitz.open(pdf_path)

    # default to all pages; otherwise accept any list/tuple/generator of indices
    if pages is None:
        pages = list(range(doc.page_count))
    else:
        assert isinstance(pages, (list, tuple, types.GeneratorType))

    # render each requested page at 300 DPI into its own temp file
    image_paths = []
    for page_num in pages:
        page = doc.load_page(page_num)
        pix = page.get_pixmap(matrix=fitz.Matrix(300 / 72, 300 / 72))
        out_path = f"{tempfile.mkdtemp()}/page_{page_num + 1}.{ext}"
        pix.save(out_path)
        image_paths.append(out_path)
    doc.close()

    if ext == 'png':
        iformat = 'PNG'
    elif ext in ['jpeg', 'jpg']:
        iformat = 'JPEG'
    else:
        raise ValueError("No such ext=%s" % ext)

    encoded = []
    for path in image_paths:
        image = Image.open(path)
        # shrink in place when either dimension exceeds the maximum
        if image.size[0] > max_size[0] or image.size[1] > max_size[1]:
            image.thumbnail(max_size, Image.Resampling.LANCZOS)
        buf = io.BytesIO()
        image.save(buf, format=iformat, optimize=True, quality=quality)
        buf.seek(0)
        encoded.append(base64.b64encode(buf.getvalue()).decode('utf-8'))

    return encoded