# Hugging Face Space — status: Sleeping
import random
from io import BytesIO

import requests
import streamlit as st
import torch
from datasets import load_dataset
from PIL import Image
from transformers import CLIPModel, CLIPProcessor
# Page header: app title and one-line description.
st.title("Meme Battle AI")
st.write("Stream memes directly and let AI determine the winner!")
# Load the CLIP model
@st.cache_resource
def load_model(model_name="openai/clip-vit-base-patch32"):
    """Load a CLIP model and its matching processor from the Hugging Face Hub.

    Wrapped in ``st.cache_resource`` because Streamlit re-executes this whole
    script on every widget interaction (e.g. each button press); without the
    cache the model weights would be reloaded each time.

    Args:
        model_name: Hub checkpoint identifier used for both the model and
            the processor.

    Returns:
        A ``(model, processor)`` tuple.
    """
    model = CLIPModel.from_pretrained(model_name)
    processor = CLIPProcessor.from_pretrained(model_name)
    return model, processor


model, processor = load_model()
# Stream the dataset
@st.cache_resource
def load_streamed_dataset(path="Dhruv-goyal/memes_with_captions", split="train"):
    """Open the meme dataset in streaming mode.

    Cached with ``st.cache_resource`` so the Hub lookup is not repeated on
    every Streamlit rerun.

    Args:
        path: Hub dataset identifier.
        split: Dataset split to stream.

    Returns:
        The streaming dataset object returned by ``load_dataset(...,
        streaming=True)``.
    """
    return load_dataset(path, split=split, streaming=True)


dataset = load_streamed_dataset()
# Randomly fetch two memes
def fetch_random_memes(pool_size=10):
    """Return two different rows sampled from the start of the stream.

    NOTE: ``dataset.take(pool_size)`` reads the *first* ``pool_size`` rows of
    the (unshuffled) stream, so pairs are always drawn from that same pool.
    Raise ``pool_size`` for more variety, at the cost of streaming more rows.

    Args:
        pool_size: Number of leading rows to read before sampling.

    Returns:
        A ``(meme1, meme2)`` tuple of dataset rows.

    Raises:
        ValueError: If the stream yields fewer than two rows.
    """
    samples = list(dataset.take(pool_size))
    if len(samples) < 2:
        raise ValueError(f"need at least 2 memes to battle, got {len(samples)}")
    meme1, meme2 = random.sample(samples, 2)
    return meme1, meme2
# Score the memes
def score_meme(image_url, caption, timeout=10):
    """Return CLIP's image-caption compatibility logit for one meme.

    The score is ``outputs.logits_per_text`` from the CLIP model, i.e. how
    well *caption* matches the image; it is not a measure of humour.

    Args:
        image_url: URL of the meme image. An already-decoded
            ``PIL.Image.Image`` is also accepted and used as-is
            (NOTE(review): the dataset's ``image`` column may already be
            decoded to PIL rather than a URL — confirm the schema).
        caption: Caption text to compare against the image.
        timeout: Seconds to wait for the HTTP download before giving up.

    Returns:
        The compatibility logit as a float, or ``0.0`` when the image could
        not be fetched/decoded or the model call failed (a warning is shown
        in the app in that case).
    """
    try:
        if isinstance(image_url, Image.Image):
            image = image_url
        else:
            response = requests.get(image_url, timeout=timeout)
            # Fail on HTTP errors instead of handing an error page to PIL.
            response.raise_for_status()
            image = Image.open(BytesIO(response.content))
        # Normalise palette/RGBA/greyscale images to 3-channel RGB.
        image = image.convert("RGB")
        inputs = processor(
            text=[caption], images=[image], return_tensors="pt", padding=True
        )
        # Inference only: skip building the autograd graph.
        with torch.no_grad():
            outputs = model(**inputs)
        return outputs.logits_per_text.item()
    except Exception as err:
        # Best-effort: one broken meme must not crash the battle, but the
        # failure is surfaced instead of being silently scored as 0.
        st.warning(f"Could not score meme with caption {caption!r}: {err}")
        return 0.0
# Streamlit Interface
st.write("### Select two memes and let the AI determine the winner!")
if st.button("Start Meme Battle"):
    try:
        battle = fetch_random_memes()
    except ValueError as err:
        # Raised when there are fewer than two memes to sample from.
        st.error(f"Could not start the battle: {err}")
    else:
        scores = []
        for number, meme in enumerate(battle, start=1):
            caption = meme["caption"]
            image_url = meme["image"]
            score = score_meme(image_url, caption)
            scores.append(score)
            st.write(f"#### Meme {number}")
            st.image(image_url, caption=f"Caption: {caption}")
            st.write(f"AI Score: {score:.2f}")

        # Higher CLIP logit wins; equal scores (e.g. both memes failed to
        # score and fell back to 0) are reported as a tie.
        score1, score2 = scores
        if score1 > score2:
            st.write("🏆 **Meme 1 Wins!**")
        elif score2 > score1:
            st.write("🏆 **Meme 2 Wins!**")
        else:
            st.write("🤝 **It's a tie!**")