Spaces:
Build error
Build error
| import base64 | |
| import logging | |
| import re | |
| import time | |
| from functools import partial | |
| from io import BytesIO | |
| import gradio as gr | |
| import torch | |
| from extensions.multimodal.multimodal_embedder import MultimodalEmbedder | |
| from modules import shared | |
# Extension-wide settings: handed to the MultimodalEmbedder and updated from the UI.
params = dict(
    add_all_images_to_prompt=False,
    # device to run the vision encoder on
    vision_device=None,
    # bits to load the vision encoder in, either 16 or 32
    vision_bits=32,
    # device to run the multimodal projector on
    projector_device=None,
    # bits for the multimodal projector, either 32 or 16
    projector_bits=32,
)
# Hijack slot for the chat input: while 'state' is True, the next chat
# generation is hijacked using 'value'.
input_hijack = dict(state=False, value=["", ""])
# Initialized in ui() rather than at import time, so that params are loaded
# from settings first. Stays None until ui() has run, hence the optional type
# (quoted so the annotation is not evaluated at import time).
multimodal_embedder: "MultimodalEmbedder | None" = None
def add_chat_picture(picture, text, visible_text):
    """Embed `picture` into a chat message as an inline base64 JPEG ``<img>`` tag.

    The picture is resized, keeping its aspect ratio, so that the longest edge
    is about 300 px (to keep the history manageable) unless that would make
    the shortest edge smaller than 224 px (the size for CLIP); in that case
    the shortest edge is 224 px and the longest edge scales accordingly.

    Args:
        picture: PIL image to attach.
        text: Message text. Every ``<image>`` placeholder is replaced by the
            image tag; without a placeholder the tag is appended on a new line.
        visible_text: Displayed text, treated like `text`. When it is empty
            or None, the already processed `text` is used instead.

    Returns:
        Tuple ``(text, visible_text)`` with the image tag embedded in both.
    """
    # JPEG has no alpha channel or palette, so saving e.g. RGBA / LA / P
    # images raises in Pillow. Normalize everything except plain RGB and
    # grayscale before encoding.
    if picture.mode not in ('RGB', 'L'):
        picture = picture.convert('RGB')

    long_side, short_side = max(picture.size), min(picture.size)
    aspect_ratio = long_side / short_side
    shortest_edge = int(max(300 / aspect_ratio, 224))
    longest_edge = int(shortest_edge * aspect_ratio)
    # Portrait pictures get the short edge as width; landscape and square
    # pictures get it as height.
    if picture.width < picture.height:
        new_size = (shortest_edge, longest_edge)
    else:
        new_size = (longest_edge, shortest_edge)
    picture = picture.resize(new_size)

    buffer = BytesIO()
    picture.save(buffer, format="JPEG")
    img_str = base64.b64encode(buffer.getvalue()).decode('utf-8')
    image = f'<img src="data:image/jpeg;base64,{img_str}">'

    if '<image>' in text:
        text = text.replace('<image>', image)
    else:
        text = text + '\n' + image

    if visible_text == '' or visible_text is None:
        # No separate display text: show the processed message itself.
        visible_text = text
    elif '<image>' in visible_text:
        visible_text = visible_text.replace('<image>', image)
    else:
        visible_text = visible_text + '\n' + image

    return text, visible_text
def custom_tokenized_length(prompt):
    """Return the length of `prompt` in tokens, as computed by the multimodal embedder."""
    token_count = multimodal_embedder.len_in_tokens(prompt)
    return token_count
def tokenizer_modifier(state, prompt, input_ids, input_embeds):
    """Embed inline base64 JPEG images found in `prompt`.

    When the prompt holds no ``<img src="data:image/jpeg;base64,...">`` tag,
    the three inputs are handed back unchanged. Otherwise the prompt is run
    through the multimodal embedder, and the resulting ids and embeddings get
    a leading dimension added and are moved to the model's device (ids as
    int64, embeddings in the model's dtype).

    Returns:
        Tuple ``(prompt, input_ids, input_embeds)``.
    """
    start_ts = time.time()

    has_image = re.search(r'<img src="data:image/jpeg;base64,[A-Za-z0-9+/=]+">', prompt)
    if not has_image:
        # Nothing to embed: pass everything through untouched.
        return prompt, input_ids, input_embeds

    prompt, input_ids, input_embeds, total_embedded = multimodal_embedder.forward(prompt, state, params)
    logging.info(f'Embedded {total_embedded} image(s) in {time.time()-start_ts:.2f}s')

    model = shared.model
    batched_ids = input_ids.unsqueeze(0).to(model.device, dtype=torch.int64)
    batched_embeds = input_embeds.unsqueeze(0).to(model.device, dtype=model.dtype)
    return prompt, batched_ids, batched_embeds
def ui():
    """Create the multimodal embedder and build the extension's Gradio controls.

    The embedder is instantiated here, from the module-level `params`, and
    stored in the module-level `multimodal_embedder`. The controls are an
    image picker and a checkbox that toggles `add_all_images_to_prompt`.
    """
    global multimodal_embedder
    multimodal_embedder = MultimodalEmbedder(params)

    def arm_hijack(picture):
        # Bind the uploaded picture now; the remaining text arguments of
        # add_chat_picture are supplied when the hijack is consumed.
        input_hijack.update({"state": True, "value": partial(add_chat_picture, picture)})

    def disarm_hijack():
        input_hijack.update({"state": False, "value": ["", ""]})

    def set_embed_all(enabled):
        params.update({"add_all_images_to_prompt": enabled})

    def reset_picture():
        return None

    with gr.Column():
        picture_select = gr.Image(label='Send a picture', type='pil')
        # The models don't seem to deal well with multiple images.
        # Start from the configured value instead of a hard-coded False so the
        # checkbox agrees with params loaded from settings.
        all_images_checkbox = gr.Checkbox(
            params['add_all_images_to_prompt'],
            label='Embed all images, not only the last one',
        )

    # Prepare the input hijack
    picture_select.upload(arm_hijack, [picture_select], None)
    picture_select.clear(disarm_hijack, None, None)
    all_images_checkbox.change(set_embed_all, all_images_checkbox, None)

    # On send, output None into the image component (presumably resetting it).
    shared.gradio['Generate'].click(reset_picture, None, picture_select)
    shared.gradio['textbox'].submit(reset_picture, None, picture_select)