Spaces:
Sleeping
Sleeping
File size: 2,656 Bytes
9b5318d cefdf31 9b5318d 42d703c 269fddf 9b5318d 269fddf 42d703c 269fddf 42d703c 269fddf 42d703c 9b5318d 42d703c 9b5318d 42d703c 9b5318d 269fddf 9b5318d 42d703c 9b5318d 42d703c 9b5318d ad90ea7 9b5318d ad90ea7 9b5318d 42d703c 9b5318d ad90ea7 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 |
import streamlit as st
from datasets import load_dataset
from transformers import CLIPProcessor, CLIPModel
from PIL import Image
import random
# Page header: app title followed by a one-line description.
st.title("Meme Battle AI")
st.write("Stream memes directly and let AI determine the winner!")
@st.cache_resource
def load_model():
    """Load the pretrained CLIP model together with its processor.

    The function is wrapped in ``st.cache_resource``, so the pair is built
    once and reused rather than reloaded each time the script runs.

    Returns:
        A ``(model, processor)`` tuple, both loaded from the
        ``openai/clip-vit-base-patch32`` checkpoint.
    """
    checkpoint = "openai/clip-vit-base-patch32"
    clip_model = CLIPModel.from_pretrained(checkpoint)
    clip_processor = CLIPProcessor.from_pretrained(checkpoint)
    return clip_model, clip_processor


# Module-level handles; score_meme reads both of these.
model, processor = load_model()
@st.cache_resource
def load_streamed_dataset():
    """Open the train split of the meme dataset in streaming mode.

    Wrapped in ``st.cache_resource`` so the dataset handle is created once
    and shared between script runs.

    Returns:
        The object ``load_dataset`` returns for the requested split with
        ``streaming=True``.
    """
    meme_stream = load_dataset(
        "Dhruv-goyal/memes_with_captions",
        split="train",
        streaming=True,
    )
    return meme_stream


# Module-level handle; fetch_random_memes draws its samples from this.
dataset = load_streamed_dataset()
def fetch_random_memes(sample_size=100, source=None):
    """Fetch two random memes from the dataset.

    The source is shuffled with a freshly drawn seed, its first
    ``sample_size`` examples are materialized into a list, and two entries
    are drawn from that list without replacement.

    Args:
        sample_size: Number of shuffled examples to pull before picking the
            pair. Must be at least 2. Defaults to 100.
        source: Object exposing ``shuffle(seed=...)`` and ``take(n)``.
            Defaults to the module-level ``dataset``.

    Returns:
        A ``(meme1, meme2)`` tuple holding two different entries of the
        sampled list.

    Raises:
        ValueError: If ``sample_size`` is below 2, or the source yields
            fewer than two examples.
    """
    if sample_size < 2:
        raise ValueError(f"sample_size must be at least 2, got {sample_size}")

    stream = dataset if source is None else source
    shuffled = stream.shuffle(seed=random.randint(0, 1000))
    candidates = list(shuffled.take(sample_size))

    # Fail with an explicit message instead of random.sample's generic
    # "Sample larger than population" error.
    if len(candidates) < 2:
        raise ValueError(
            f"Need at least 2 memes to pick a pair, got {len(candidates)}"
        )

    meme1, meme2 = random.sample(candidates, 2)
    return meme1, meme2
def parse_meme(meme):
    """Extract the caption and image from a meme record.

    Args:
        meme: Mapping with an ``"image"`` entry and, optionally, an
            ``"answers"`` entry holding either a sequence of captions or a
            single caption string.

    Returns:
        A ``(caption, image)`` tuple. ``caption`` is the first answer (or
        the whole string when ``"answers"`` is a plain string), falling
        back to ``"No caption available"`` when there is no usable caption.
        ``image`` is ``meme["image"]`` unchanged.

    Raises:
        KeyError: If ``meme`` has no ``"image"`` entry.
    """
    fallback = "No caption available"
    answers = meme.get("answers")

    if isinstance(answers, str):
        # Indexing a bare string would keep only its first character.
        caption = answers
    elif answers:
        caption = answers[0]
    else:
        caption = None

    # Presumably already a PIL image decoded by the dataset loader -- confirm.
    image = meme["image"]

    # An empty or None first answer also counts as "no caption".
    return caption or fallback, image
def score_meme(image, caption):
    """Score a meme by evaluating the image-caption compatibility.

    The image and caption are run through the module-level CLIP
    ``processor`` and ``model``; the score is the resulting text-to-image
    logit.

    Args:
        image: The meme image, handed to the processor as-is.
        caption: Caption text to compare against the image.

    Returns:
        The single image-text logit as a Python float, or ``0`` when
        scoring fails (the error is shown in the UI via ``st.error``).
    """
    try:
        # Imported here because the file's import block does not bring
        # torch in; kept inside the try so a failure is reported like any
        # other scoring error.
        import torch

        # truncation=True clips captions that exceed the text encoder's
        # maximum sequence length instead of letting the model raise.
        inputs = processor(
            text=[caption],
            images=[image],
            return_tensors="pt",
            padding=True,
            truncation=True,
        )

        # Inference only: skip autograd bookkeeping for the forward pass.
        with torch.no_grad():
            outputs = model(**inputs)

        # One caption and one image are passed in, so .item() reads the
        # single resulting logit.
        return outputs.logits_per_text.item()
    except Exception as e:
        # Best-effort UI boundary: report the problem and fall back to 0.
        st.error(f"Error scoring meme: {e}")
        return 0
# Run one battle each time the button is pressed: pick two memes, score
# each caption against its image, show both, and announce the result.
if st.button("Start Meme Battle"):
    # Fetch random memes
    meme1, meme2 = fetch_random_memes()

    # Parse captions and images
    caption1, image1 = parse_meme(meme1)
    caption2, image2 = parse_meme(meme2)

    # Score memes
    score1 = score_meme(image1, caption1)
    score2 = score_meme(image2, caption2)

    # Display Meme 1 and Meme 2 side by side
    col1, col2 = st.columns(2)
    with col1:
        st.write("#### Meme 1")
        st.image(image1, caption=f"Caption: {caption1}")
        st.write(f"AI Score: {score1:.2f}")
    with col2:
        st.write("#### Meme 2")
        st.image(image2, caption=f"Caption: {caption2}")
        st.write(f"AI Score: {score2:.2f}")

    # Determine the winner
    # NOTE(review): the leading emoji were restored from mis-decoded text
    # ("π" / "π€"); the exact glyphs are a best guess -- confirm.
    if score1 > score2:
        st.write("🎉 **Meme 1 Wins!**")
    elif score2 > score1:
        st.write("🎉 **Meme 2 Wins!**")
    else:
        st.write("🤝 **It's a tie!**")
|