"""Meme Battle AI: a Streamlit app that pits two memes against each other.

Two memes are drawn from a streamed Hugging Face dataset, each image/caption
pair is scored with CLIP, and the meme with the higher score wins.
"""

import random
from io import BytesIO

import requests
import streamlit as st
import torch
from datasets import load_dataset
from PIL import Image
from transformers import CLIPModel, CLIPProcessor

MODEL_NAME = "openai/clip-vit-base-patch32"
DATASET_NAME = "Dhruv-goyal/memes_with_captions"
# How many shuffled samples to pull from the stream before picking two.
POOL_SIZE = 10
# Shuffle buffer for the streamed dataset: larger gives more randomness but
# requires fetching more rows before the first sample is yielded.
SHUFFLE_BUFFER_SIZE = 100
# Seconds to wait for an image download before giving up.
REQUEST_TIMEOUT_S = 10

st.title("Meme Battle AI")
st.write("Stream memes directly and let AI determine the winner!")


@st.cache_resource
def load_model():
    """Load the CLIP model and processor; cached via st.cache_resource so reruns reuse them."""
    model = CLIPModel.from_pretrained(MODEL_NAME)
    processor = CLIPProcessor.from_pretrained(MODEL_NAME)
    return model, processor


model, processor = load_model()


@st.cache_resource
def load_streamed_dataset():
    """Open the dataset's ``train`` split in streaming mode (rows are fetched lazily)."""
    return load_dataset(DATASET_NAME, split="train", streaming=True)


dataset = load_streamed_dataset()


def fetch_random_memes(pool_size=POOL_SIZE):
    """Return two distinct memes drawn at random from the streamed dataset.

    Args:
        pool_size: Number of shuffled samples to read before choosing two
            (at least 2 are always requested).

    Returns:
        A ``(meme1, meme2)`` tuple of dataset rows.

    Raises:
        ValueError: If fewer than two samples could be read from the stream.
    """
    # A plain ``take`` on a stream always yields the same leading rows, so
    # shuffle with a fresh seed on every call to actually vary the memes.
    shuffled = dataset.shuffle(
        seed=random.randrange(2**32), buffer_size=SHUFFLE_BUFFER_SIZE
    )
    samples = list(shuffled.take(max(pool_size, 2)))
    if len(samples) < 2:
        raise ValueError(
            f"Need at least 2 memes for a battle, but only got {len(samples)}."
        )
    meme1, meme2 = random.sample(samples, 2)
    return meme1, meme2


def _load_image(source):
    """Return an RGB PIL image for ``source`` (an image URL or a PIL image).

    NOTE(review): whether the dataset's ``image`` column holds URLs or
    decoded PIL images depends on its feature schema — confirm; both are
    accepted here.
    """
    if isinstance(source, Image.Image):
        image = source
    else:
        response = requests.get(source, timeout=REQUEST_TIMEOUT_S)
        # Fail loudly on HTTP errors instead of handing an error page to PIL.
        response.raise_for_status()
        image = Image.open(BytesIO(response.content))
    # Normalize palette/RGBA/grayscale images to 3-channel RGB; this also
    # forces PIL's lazy decode to happen here, inside the caller's try block.
    return image.convert("RGB")


def score_meme(image_url, caption):
    """Score how well ``caption`` matches the meme image using CLIP.

    Args:
        image_url: Image URL, or an already-decoded PIL image.
        caption: Caption text; truncated to the tokenizer's maximum length.

    Returns:
        CLIP's image-text logit (``logits_per_text``) as a float; higher
        means a better match. Returns 0.0 if the image cannot be fetched or
        decoded or scoring fails, and shows a warning in the app.
    """
    try:
        image = _load_image(image_url)
        # truncation=True keeps over-long captions within CLIP's text length
        # limit; without it the forward pass fails on long captions.
        inputs = processor(
            text=[caption],
            images=[image],
            return_tensors="pt",
            padding=True,
            truncation=True,
        )
        with torch.no_grad():  # inference only; skip autograd bookkeeping
            outputs = model(**inputs)
        return outputs.logits_per_text.item()
    except Exception as err:  # UI boundary: report the failure, don't crash
        st.warning(
            f"Could not score meme ({type(err).__name__}: {err}); scoring it 0."
        )
        return 0.0


st.write("### Select two memes and let the AI determine the winner!")

if st.button("Start Meme Battle"):
    try:
        contestants = fetch_random_memes()
    except ValueError as err:
        st.error(str(err))
        st.stop()

    scores = []
    for number, meme in enumerate(contestants, start=1):
        caption = meme["caption"]
        image_source = meme["image"]
        # Header first so any scoring warning appears under the right meme.
        st.write(f"#### Meme {number}")
        score = score_meme(image_source, caption)
        scores.append(score)
        st.image(image_source, caption=f"Caption: {caption}")
        st.write(f"AI Score: {score:.2f}")

    score1, score2 = scores
    if score1 > score2:
        st.write("🎉 **Meme 1 Wins!**")
    elif score2 > score1:
        st.write("🎉 **Meme 2 Wins!**")
    else:
        st.write("🤝 **It's a tie!**")