"""Streamlit app: fetch two random memes and let CLIP pick the winner.

Each meme is scored by the CLIP image-text similarity logit between its
image and its own caption; the meme with the higher logit wins.
"""

import random

import streamlit as st
import torch  # already required by CLIPModel / return_tensors="pt"
from datasets import load_dataset
from PIL import Image
from transformers import CLIPModel, CLIPProcessor

# Single source of truth for the checkpoint used by both model and processor.
_CLIP_CHECKPOINT = "openai/clip-vit-base-patch32"

st.title("Meme Battle AI")
st.write("Stream memes directly and let AI determine the winner!")


@st.cache_resource
def load_model():
    """Load the CLIP model and its processor.

    Wrapped in ``st.cache_resource`` so the result is reused instead of being
    reloaded every time Streamlit re-executes the script.

    Returns:
        A ``(model, processor)`` tuple.
    """
    model = CLIPModel.from_pretrained(_CLIP_CHECKPOINT)
    processor = CLIPProcessor.from_pretrained(_CLIP_CHECKPOINT)
    return model, processor


model, processor = load_model()


@st.cache_resource
def load_streamed_dataset():
    """Open the meme dataset in streaming mode.

    Returns:
        The ``train`` split of ``Dhruv-goyal/memes_with_captions`` as returned
        by ``load_dataset(..., streaming=True)``.
    """
    return load_dataset(
        "Dhruv-goyal/memes_with_captions", split="train", streaming=True
    )


dataset = load_streamed_dataset()


def fetch_random_memes(sample_size=100):
    """Fetch two distinct random memes from the dataset.

    The stream is shuffled with a random seed, the first ``sample_size``
    records are materialised, and two of them are drawn without replacement.

    Args:
        sample_size: Number of records to pull from the shuffled stream
            before picking two. Must be at least 2.

    Returns:
        A ``(meme1, meme2)`` tuple of dataset records.

    Raises:
        ValueError: If fewer than two records were obtained
            (raised by ``random.sample``).
    """
    shuffled = dataset.shuffle(seed=random.randint(0, 1000))
    candidates = list(shuffled.take(sample_size))
    meme1, meme2 = random.sample(candidates, 2)
    return meme1, meme2


def parse_meme(meme):
    """Extract the caption and image from a meme record.

    Args:
        meme: A dataset record. Assumes ``answers`` is a sequence of caption
            strings and ``image`` is a PIL image -- TODO confirm against the
            dataset schema.

    Returns:
        A ``(caption, image)`` tuple. The caption is the first entry of
        ``answers``, or ``"No caption available"`` when ``answers`` is
        missing or empty.
    """
    answers = meme.get("answers")
    caption = answers[0] if answers else "No caption available"
    image = meme["image"]
    return caption, image


def score_meme(image, caption):
    """Score a meme by its image-caption compatibility under CLIP.

    Args:
        image: The meme image, as accepted by the CLIP processor.
        caption: The caption text to compare against the image.

    Returns:
        The single image-text logit as a float, or ``0.0`` if scoring failed
        (the error is also reported in the UI via ``st.error``).
    """
    try:
        inputs = processor(
            text=[caption],
            images=[image],
            return_tensors="pt",
            padding=True,
            # CLIP's text encoder has a fixed maximum sequence length; clip
            # over-long captions instead of failing the forward pass.
            truncation=True,
        )
        # Inference only: skip autograd bookkeeping.
        with torch.no_grad():
            outputs = model(**inputs)
        # One text and one image give a single-element logits tensor.
        return outputs.logits_per_text.item()
    except Exception as e:
        # UI boundary: report the problem to the user and fall back to a
        # neutral score rather than aborting the whole battle.
        st.error(f"Error scoring meme: {e}")
        return 0.0


if st.button("Start Meme Battle"):
    meme1, meme2 = fetch_random_memes()

    caption1, image1 = parse_meme(meme1)
    caption2, image2 = parse_meme(meme2)

    score1 = score_meme(image1, caption1)
    score2 = score_meme(image2, caption2)

    # Show both contenders side by side.
    col1, col2 = st.columns(2)
    with col1:
        st.write("#### Meme 1")
        st.image(image1, caption=f"Caption: {caption1}")
        st.write(f"AI Score: {score1:.2f}")
    with col2:
        st.write("#### Meme 2")
        st.image(image2, caption=f"Caption: {caption2}")
        st.write(f"AI Score: {score2:.2f}")

    # Announce the result.
    if score1 > score2:
        st.write("🎉 **Meme 1 Wins!**")
    elif score2 > score1:
        st.write("🎉 **Meme 2 Wins!**")
    else:
        st.write("🤝 **It's a tie!**")