Spaces:
Sleeping
Sleeping
import streamlit as st | |
from datasets import load_dataset | |
from transformers import CLIPProcessor, CLIPModel | |
from PIL import Image | |
import random | |
st.title("Meme Battle AI")
st.write("Stream memes directly and let AI determine the winner!")


# Load the CLIP model and processor. Streamlit re-executes this whole script on
# every widget interaction, so the loader is wrapped in st.cache_resource to
# fetch the weights once per server process instead of on every button click.
@st.cache_resource
def load_model(model_name="openai/clip-vit-base-patch32"):
    """Load a pretrained CLIP model together with its matching processor.

    Args:
        model_name: Hugging Face Hub id (or local path) of the CLIP checkpoint.
            Defaults to the checkpoint this app was built around.

    Returns:
        A ``(model, processor)`` tuple. The cached objects are shared by every
        session, so callers should treat them as read-only.
    """
    model = CLIPModel.from_pretrained(model_name)
    processor = CLIPProcessor.from_pretrained(model_name)
    return model, processor


model, processor = load_model()
@st.cache_resource
def load_streamed_dataset(path="Dhruv-goyal/memes_with_captions", split="train"):
    """Open the meme dataset in streaming mode (rows are read lazily).

    Cached with ``st.cache_resource`` so resolving the dataset on the Hub
    happens once per server process rather than on every Streamlit rerun.
    In this file the dataset is only consumed through ``shuffle()``/``take()``,
    which return new dataset objects, so sharing one cached instance is fine.

    Args:
        path: Hub id (or local path) of the dataset.
        split: Which split to stream.

    Returns:
        The streaming dataset returned by ``load_dataset(..., streaming=True)``.
    """
    return load_dataset(path, split=split, streaming=True)


dataset = load_streamed_dataset()
def fetch_random_memes(buffer_size=1000):
    """Fetch two different random memes from the streamed dataset.

    The stream is shuffled with a buffered (approximate) shuffle and the first
    two rows are taken. Taking exactly two rows avoids materializing and
    decoding a larger list of rows just to pick two of them.

    Args:
        buffer_size: Size of the shuffle buffer. Candidates are drawn from
            roughly the first ``buffer_size`` rows the stream produces, so a
            larger buffer gives more variety at the cost of reading more rows.

    Returns:
        A ``(meme1, meme2)`` tuple of two different dataset rows.

    Raises:
        ValueError: If the dataset yields fewer than two rows.
    """
    # The seed alone now decides which pair is drawn, so use a wide range to
    # keep successive battles varied.
    seed = random.randrange(2**32)
    shuffled = dataset.shuffle(seed=seed, buffer_size=buffer_size)
    candidates = list(shuffled.take(2))
    if len(candidates) < 2:
        raise ValueError("The dataset must contain at least two memes to battle.")
    meme1, meme2 = candidates
    return meme1, meme2
def parse_meme(meme):
    """Return the ``(caption, image)`` pair stored in one dataset row.

    The caption is the first entry of the row's ``"answers"`` value. When that
    key is missing or empty, a placeholder caption is used instead. The
    ``"image"`` value is passed through untouched; it is expected to already
    be a decoded PIL image.
    """
    answers = meme.get("answers")
    caption = answers[0] if answers else "No caption available"
    return caption, meme["image"]
def score_meme(image, caption):
    """Score a meme by how well its caption matches its image according to CLIP.

    The score is CLIP's ``logits_per_text`` for the single (caption, image)
    pair. A higher value means the model judges the caption a better fit.

    Args:
        image: The meme image, as returned by ``parse_meme``.
        caption: The caption text to compare against the image.

    Returns:
        The score as a float. Returns ``0`` if preprocessing or inference
        fails; the error is also shown in the UI via ``st.error``.
    """
    import torch  # torch already backs the "pt" tensors the processor builds

    try:
        # truncation=True clips over-long captions to the tokenizer's maximum
        # length. Otherwise the text encoder's forward pass fails, and the meme
        # would silently fall back to a score of 0.
        inputs = processor(
            text=[caption],
            images=[image],
            return_tensors="pt",
            padding=True,
            truncation=True,
        )
        # Inference only: skip building the autograd graph.
        with torch.no_grad():
            outputs = model(**inputs)
        # One caption x one image -> a single-element logits tensor.
        return outputs.logits_per_text.item()
    except Exception as e:
        # UI boundary: report the failure instead of crashing the whole app.
        st.error(f"Error scoring meme: {e}")
        return 0
def _show_meme(column, label, image, caption, score):
    """Render one contestant into ``column``: heading, captioned image, AI score."""
    with column:
        st.write(f"#### {label}")
        st.image(image, caption=f"Caption: {caption}")
        st.write(f"AI Score: {score:.2f}")


if st.button("Start Meme Battle"):
    # Draw two different memes and split each into caption + image.
    meme1, meme2 = fetch_random_memes()
    caption1, image1 = parse_meme(meme1)
    caption2, image2 = parse_meme(meme2)

    # Score each meme by CLIP image-caption compatibility; higher wins.
    score1 = score_meme(image1, caption1)
    score2 = score_meme(image2, caption2)

    # Show both contestants side by side.
    col1, col2 = st.columns(2)
    _show_meme(col1, "Meme 1", image1, caption1, score1)
    _show_meme(col2, "Meme 2", image2, caption2, score2)

    # Announce the verdict. The emoji are written as \N{...} escapes so the
    # source stays ASCII and cannot be mangled by a wrong file encoding.
    # NOTE(review): the previous emoji were unreadable mis-encoded characters;
    # trophy / handshake is a best-guess restoration - confirm the intended glyphs.
    if score1 > score2:
        st.write("\N{TROPHY} **Meme 1 Wins!**")
    elif score2 > score1:
        st.write("\N{TROPHY} **Meme 2 Wins!**")
    else:
        st.write("\N{HANDSHAKE} **It's a tie!**")