Spaces:
Runtime error
Runtime error
| import torch | |
| import gradio as gr | |
| import random | |
| import numpy as np | |
| from PIL import Image | |
| import imagehash | |
| import cv2 | |
| import os | |
| import spaces | |
| from transformers import AutoProcessor, AutoModelForCausalLM | |
| from transformers.image_utils import to_numpy_array, PILImageResampling, ChannelDimension | |
| from transformers.image_transforms import resize, to_channel_dimension_format | |
| from typing import List | |
| from PIL import Image | |
| from collections import Counter | |
| from datasets import load_dataset, concatenate_datasets | |
# Runtime device and the fine-tuned Idefics2 checkpoint shared by processor and model.
DEVICE = torch.device("cuda")
_CHECKPOINT = "HuggingFaceM4/idefics2_raven_finetuned"

# Both downloads authenticate with the HF_AUTH_TOKEN environment variable
# (a missing variable raises KeyError at import time).
PROCESSOR = AutoProcessor.from_pretrained(
    _CHECKPOINT,
    token=os.environ["HF_AUTH_TOKEN"],
)
MODEL = AutoModelForCausalLM.from_pretrained(
    _CHECKPOINT,
    trust_remote_code=True,
    torch_dtype=torch.bfloat16,
    token=os.environ["HF_AUTH_TOKEN"],
).to(DEVICE)

# Number of `<image>` placeholder tokens one image occupies in the prompt:
# the resampler's latent count when a resampler is used, otherwise the number
# of vision patches per image.
_config = MODEL.config
image_seq_len = (
    _config.perceiver_config.resampler_n_latents
    if _config.use_resampler
    else (_config.vision_config.image_size // _config.vision_config.patch_size) ** 2
)

BOS_TOKEN = PROCESSOR.tokenizer.bos_token
# Token ids of the image placeholder markers; passed to generate() as bad words
# so the model never emits them.
BAD_WORDS_IDS = PROCESSOR.tokenizer(
    ["<image>", "<fake_token_around_image>"],
    add_special_tokens=False,
).input_ids
# Puzzles are drawn from the validation split.
DATASET = load_dataset("HuggingFaceM4/RAVEN_rendered", split="validation")
| ## Utils | |
def convert_to_rgb(image):
    """Return `image` as an RGB PIL image.

    RGB inputs are returned unchanged. Any other mode is converted via RGBA and
    alpha-composited onto a white canvas before dropping to RGB, because a bare
    `image.convert("RGB")` produces a wrong background for transparent images.
    """
    if image.mode != "RGB":
        rgba = image.convert("RGBA")
        white_canvas = Image.new("RGBA", rgba.size, (255, 255, 255))
        return Image.alpha_composite(white_canvas, rgba).convert("RGB")
    return image
# HACK: the processor is otherwise the same as the Idefics one; only its
# per-image transform is redefined here, so that images are resized with
# BILINEAR instead of the BICUBIC interpolation mentioned for siglip.
def custom_transform(x):
    """Turn one image into a normalized channels-first tensor for the model.

    Steps: RGB conversion, BILINEAR resize to 960x960, rescale by 1/255,
    normalization with the image processor's mean/std, then channels-first.

    Returns:
        torch.Tensor built from the processed array.
    """
    image_processor = PROCESSOR.image_processor
    pixels = to_numpy_array(convert_to_rgb(x))
    pixels = resize(pixels, (960, 960), resample=PILImageResampling.BILINEAR)
    pixels = image_processor.rescale(pixels, scale=1 / 255)
    pixels = image_processor.normalize(
        pixels,
        mean=image_processor.image_mean,
        std=image_processor.image_std,
    )
    channels_first = to_channel_dimension_format(pixels, ChannelDimension.FIRST)
    return torch.tensor(channels_first)
def pixel_difference(image1, image2):
    """Heuristically decide whether two images show (nearly) the same figure.

    Three signals are compared:

    * colour: difference between each image's dominant non-white pixel value;
    * shape: perceptual-hash distance between the images' Canny edge maps;
    * surface: difference in the count of non-white (!= 255) pixels.

    Each signal is "close" when it falls under its threshold, and the images
    are judged similar when at least two of the three signals are close.

    Returns:
        bool: True if at least two of the three signals are close.
    """

    def color(im):
        # Most frequent pixel value; when white (255) dominates, the runner-up.
        most_common = Counter(np.array(im).flatten().tolist()).most_common(2)
        if most_common[0][0] == 255 and len(most_common) > 1:
            return most_common[1][0]
        # A single-valued image (e.g. an all-white panel) has no runner-up,
        # so its only value is returned instead of raising IndexError.
        return most_common[0][0]

    def canny_edges(im):
        edges = cv2.Canny(np.array(im), 50, 100)
        edges[edges != 0] = 255
        return Image.fromarray(edges)

    def phash(im):
        return imagehash.phash(canny_edges(im), hash_size=32)

    def surface(im):
        return (np.array(im) != 255).sum()

    color_diff = np.abs(color(image1) - color(image2))
    hash_diff = phash(image1) - phash(image2)
    surface_diff = np.abs(surface(image1) - surface(image2))

    color_close = color_diff < 10
    hash_close = int(hash_diff / 7) < 10
    # Surface difference as a percentage of a 160x160 area
    # (NOTE(review): assumes 160x160 panels — confirm against the dataset).
    surface_close = int(surface_diff / (160 * 160) * 100) < 10

    # Majority vote: similar iff at least two of the three signals are close.
    return bool(sum((color_close, hash_close, surface_close)) >= 2)
| # End of Utils | |
def load_sample(dataset=None):
    """Pick a random puzzle and reset the round's text boxes.

    Args:
        dataset: Sized, integer-indexable collection of rows with "image" and
            "label" keys. Defaults to the module-level DATASET.

    Returns:
        tuple: (image, label, "", "", ""), matching the outputs
        [imagebox, solution, model_prediction, user_prediction, result].

    Raises:
        ValueError: if the dataset is empty.
    """
    if dataset is None:
        dataset = DATASET
    n = len(dataset)
    if n == 0:
        raise ValueError("Cannot load a sample from an empty dataset.")
    # randrange excludes n; randint(0, n) could return n and raise IndexError.
    sample = dataset[random.randrange(n)]
    return sample["image"], sample["label"], "", "", ""
def model_inference(
    image,
):
    """Ask the model which answer panel completes the puzzle shown in `image`.

    Args:
        image: Puzzle image from the Gradio image box (configured with
            type="pil"), or None when no image is loaded.

    Returns:
        str: The last non-whitespace character of the generated answer
        (expected to be the answer letter), or "" if the model produced only
        whitespace/special tokens.

    Raises:
        ValueError: if `image` is None.
    """
    # NOTE(review): `spaces` is imported at file level but this handler is not
    # decorated with `@spaces.GPU`; on ZeroGPU hardware GPU work in a handler
    # needs that decorator — confirm the Space's hardware.
    if image is None:
        raise ValueError("`image` is None. It should be a PIL image.")

    prompt = (
        f"{BOS_TOKEN}User:<fake_token_around_image>{'<image>' * image_seq_len}"
        "<fake_token_around_image>Which figure should complete the logical sequence?"
        "<end_of_utterance>\nAssistant:"
    )
    inputs = PROCESSOR.tokenizer(
        prompt,
        return_tensors="pt",
        add_special_tokens=False,
    )
    # Pixel values come from custom_transform instead of the stock transform.
    inputs["pixel_values"] = PROCESSOR.image_processor(
        [image],
        transform=custom_transform,
    )
    inputs = {k: v.to(DEVICE) for k, v in inputs.items()}
    prompt_length = inputs["input_ids"].shape[1]

    # The answer is a single letter, so exactly one new token is requested.
    # `max_length` would count the (much longer) prompt as well, which only
    # triggers a warning and still stops after a single new token.
    generated_ids = MODEL.generate(
        **inputs,
        bad_words_ids=BAD_WORDS_IDS,
        max_new_tokens=1,
    )
    # Decode only the newly generated tokens, not the echoed prompt.
    answer = PROCESSOR.batch_decode(
        generated_ids[:, prompt_length:],
        skip_special_tokens=True,
    )[0].strip()
    return answer[-1] if answer else ""
def _readonly_one_line_box(label):
    """Create a visible, single-line, non-editable TextArea with `label`."""
    return gr.TextArea(
        label=label,
        visible=True,
        lines=1,
        max_lines=1,
        interactive=False,
    )


# These boxes are created outside the Blocks context and placed into the
# layout later with `.render()`.
model_prediction = _readonly_one_line_box("AI's guess")
user_prediction = _readonly_one_line_box("Your guess")
result = _readonly_one_line_box("Win or lose?")

# Page styling: caps container width, centers h1 contents, animates resizes.
css = """
.gradio-container{max-width: 1000px!important}
h1{display: flex;align-items: center;justify-content: center;gap: .25em}
*{transition: width 0.5s ease, flex-grow 0.5s ease}
"""
def _constant(value):
    """Return a zero-argument callback that always returns `value`."""
    return lambda: value


def _judge(user_guess, solution):
    """Build the "Win or lose?" message by comparing the user's guess to the solution.

    Args:
        user_guess: Letter chosen by the user ("A"-"H").
        solution: Contents of the hidden solution box, i.e. the sample's label
            as a 0-based answer index (0 -> "A"); empty if no sample is loaded.

    Returns:
        str: The win marker when the guess is correct, otherwise a message
        naming the correct letter; "" when there is no solution to compare.
    """
    if solution is None or str(solution).strip() == "":
        return ""
    answer = chr(ord("A") + int(solution))
    # NOTE(review): the two marker strings look like mis-encoded emoji; they are
    # kept byte-for-byte — confirm against the original file's encoding.
    return "π₯" if user_guess == answer else f"π© The solution is {answer}"


with gr.Blocks(title="Beat the AI", theme=gr.themes.Base(), css=css) as demo:
    gr.Markdown(
        "Are you smarter than the AI?"
    )
    load_new_sample = gr.Button(value="Load new sample")
    with gr.Row(equal_height=True):
        with gr.Column(scale=4, min_width=250) as upload_area:
            imagebox = gr.Image(
                image_mode="L",
                type="pil",
                visible=True,
                sources=None,
            )
        with gr.Column(scale=4):
            with gr.Row():
                a = gr.Button(value="A", min_width=1)
                b = gr.Button(value="B", min_width=1)
                c = gr.Button(value="C", min_width=1)
                d = gr.Button(value="D", min_width=1)
            with gr.Row():
                e = gr.Button(value="E", min_width=1)
                f = gr.Button(value="F", min_width=1)
                g = gr.Button(value="G", min_width=1)
                h = gr.Button(value="H", min_width=1)
            with gr.Row():
                model_prediction.render()
                user_prediction.render()
            # Hidden holder for the current sample's label.
            solution = gr.TextArea(
                label="Solution",
                visible=False,
                lines=1,
                max_lines=1,
                interactive=False,
            )
            with gr.Row():
                result.render()

    load_new_sample.click(
        fn=load_sample,
        inputs=[],
        outputs=[imagebox, solution, model_prediction, user_prediction, result],
    )

    # One sequential chain per answer button: record the user's guess, run the
    # model, then judge. Chaining guarantees the judge sees the guess just set
    # (independent listeners could race). `.success` skips the judge when
    # inference fails (e.g. no image loaded).
    for letter, button in zip("ABCDEFGH", (a, b, c, d, e, f, g, h)):
        button.click(
            fn=_constant(letter),
            inputs=[],
            outputs=[user_prediction],
        ).then(
            fn=model_inference,
            inputs=[imagebox],
            outputs=[model_prediction],
        ).success(
            fn=_judge,
            inputs=[user_prediction, solution],
            outputs=[result],
        )
# Entry point. The former bare `demo.load()` registered nothing (called without
# a function it only returns a decorator) and has been dropped.
# Queue: at most 40 pending requests; api_open=False closes the backend REST
# routes so requests cannot bypass the queue.
if __name__ == "__main__":
    demo.queue(max_size=40, api_open=False)
    demo.launch(max_threads=400)