# memebattle / app.py
# Last commit: "Update app.py" (ad90ea7, verified) by neojack
import streamlit as st
from datasets import load_dataset
from transformers import CLIPProcessor, CLIPModel
from PIL import Image
import random
# Page header: app title and a one-line description shown above the battle.
st.title("Meme Battle AI")
st.write("Stream memes directly and let AI determine the winner!")
@st.cache_resource
def load_model():
    """Build the CLIP model and its matching processor.

    Wrapped in ``st.cache_resource`` so the pretrained weights are reused
    across Streamlit reruns instead of being reloaded on every interaction.

    Returns:
        A ``(model, processor)`` tuple for the same CLIP checkpoint.
    """
    checkpoint = "openai/clip-vit-base-patch32"
    clip_model = CLIPModel.from_pretrained(checkpoint)
    clip_processor = CLIPProcessor.from_pretrained(checkpoint)
    return clip_model, clip_processor


# Module-level handles used by score_meme().
model, processor = load_model()
@st.cache_resource
def load_streamed_dataset():
    """Open the training split of the meme dataset in streaming mode.

    Streaming avoids downloading the whole dataset up front. The handle is
    cached with ``st.cache_resource`` so reruns reuse it.
    """
    dataset_name = "Dhruv-goyal/memes_with_captions"
    return load_dataset(dataset_name, split="train", streaming=True)


# Module-level stream consumed by fetch_random_memes().
dataset = load_streamed_dataset()
def fetch_random_memes(sample_size=100, source=None):
    """Fetch two distinct, randomly chosen memes from the dataset.

    The stream is reshuffled with a fresh random seed on every call. The
    first ``sample_size`` records are materialised, and two of them are
    drawn without replacement.

    Args:
        sample_size: Number of shuffled records to pull before picking two.
            Must be at least 2.
        source: Dataset exposing ``shuffle(seed=...)`` and ``take(n)``.
            Defaults to the module-level ``dataset``.

    Returns:
        A ``(meme1, meme2)`` tuple of two different dataset records.

    Raises:
        ValueError: If ``sample_size`` < 2 or fewer than two records are
            available.
    """
    if sample_size < 2:
        raise ValueError(f"sample_size must be at least 2, got {sample_size}")
    if source is None:
        source = dataset
    # A new seed per call gives a different ordering for each battle.
    # NOTE(review): for a streaming IterableDataset, shuffle() fills a
    # buffer (library default buffer_size=1000) before yielding, so each
    # call may stream that many records. Consider a smaller buffer_size.
    shuffled = source.shuffle(seed=random.randint(0, 1000))
    candidates = list(shuffled.take(sample_size))
    if len(candidates) < 2:
        raise ValueError(
            f"Need at least 2 memes to battle, but only {len(candidates)} "
            "were available."
        )
    meme1, meme2 = random.sample(candidates, 2)
    return meme1, meme2
def parse_meme(meme):
    """Split a dataset record into its display caption and its image.

    The caption is the first entry of ``meme["answers"]``. If that field is
    missing or empty, a placeholder caption is used instead. The image is
    returned exactly as stored under ``meme["image"]``.
    """
    answers = meme.get("answers")
    if answers:
        caption = answers[0]
    else:
        caption = "No caption available"
    # The dataset's image field is used directly (a PIL image per the
    # original author) — no conversion happens here.
    return caption, meme["image"]
def score_meme(image, caption):
    """Score how well a meme's caption matches its image using CLIP.

    The score is CLIP's ``logits_per_text`` for the single (caption, image)
    pair, i.e. the model's scaled image-text similarity. Higher means the
    caption and the picture agree more closely.

    Args:
        image: The meme image to evaluate.
        caption: The caption text to compare against the image.

    Returns:
        The similarity logit as a Python float, or ``0`` if scoring failed.
        Failures are reported in the app via ``st.error``.
    """
    import torch

    try:
        # CLIP's text encoder has a fixed maximum token length. Truncate
        # long captions instead of letting the processor/model reject them.
        inputs = processor(
            text=[caption],
            images=[image],
            return_tensors="pt",
            padding=True,
            truncation=True,
        )
        # Inference only: skip autograd bookkeeping to save memory and time.
        with torch.no_grad():
            outputs = model(**inputs)
        # One text x one image gives a 1x1 tensor; .item() unwraps it.
        return outputs.logits_per_text.item()
    except Exception as e:
        # Best-effort UI boundary: report the failure and fall back to 0.
        st.error(f"Error scoring meme: {e}")
        return 0
if st.button("Start Meme Battle"):
    # Fetch two distinct random memes from the stream.
    meme1, meme2 = fetch_random_memes()

    # Split each record into its caption and image.
    caption1, image1 = parse_meme(meme1)
    caption2, image2 = parse_meme(meme2)

    # Score each meme by CLIP image-caption similarity.
    score1 = score_meme(image1, caption1)
    score2 = score_meme(image2, caption2)

    # Display Meme 1 and Meme 2 side by side with their scores.
    contenders = [(image1, caption1, score1), (image2, caption2, score2)]
    columns = st.columns(2)
    for number, (column, (image, caption, score)) in enumerate(
        zip(columns, contenders), start=1
    ):
        with column:
            st.write(f"#### Meme {number}")
            st.image(image, caption=f"Caption: {caption}")
            st.write(f"AI Score: {score:.2f}")

    # Higher score wins; exactly equal scores are a tie.
    if score1 > score2:
        st.write("🎉 **Meme 1 Wins!**")
    elif score2 > score1:
        st.write("🎉 **Meme 2 Wins!**")
    else:
        st.write("🤝 **It's a tie!**")