# memebattle / app.py
# Author: neojack · commit 9b5318d ("Create app.py", verified) · 2.45 kB
import streamlit as st
from datasets import load_dataset
from transformers import CLIPProcessor, CLIPModel
from PIL import Image
import random
from io import BytesIO
import requests
# Page header: app title and a one-line description shown above the battle.
st.title("Meme Battle AI")
st.write("Stream memes directly and let AI determine the winner!")
# CLIP checkpoint used for scoring. st.cache_resource keeps the loaded objects
# across Streamlit reruns so the weights are not reloaded on every interaction.
@st.cache_resource
def load_model():
    """Load the pretrained CLIP model together with its matching processor.

    Returns:
        A ``(model, processor)`` tuple for ``openai/clip-vit-base-patch32``.
    """
    checkpoint = "openai/clip-vit-base-patch32"
    clip_model = CLIPModel.from_pretrained(checkpoint)
    clip_processor = CLIPProcessor.from_pretrained(checkpoint)
    return clip_model, clip_processor


# Module-level handles read by score_meme().
model, processor = load_model()
# Open the meme dataset in streaming mode (rows are read on demand rather than
# downloading the whole split); the handle is cached across Streamlit reruns.
@st.cache_resource
def load_streamed_dataset():
    """Return a streaming view of the ``train`` split of the memes dataset."""
    dataset_name = "Dhruv-goyal/memes_with_captions"
    streamed = load_dataset(dataset_name, split="train", streaming=True)
    return streamed


# Shared stream that fetch_random_memes() samples from.
dataset = load_streamed_dataset()
# Randomly fetch two memes from the stream.
def fetch_random_memes(pool_size=10, shuffle_buffer=100, source=None):
    """Pick two distinct memes at random from the streamed dataset.

    Without shuffling, ``take(pool_size)`` always returns the first
    ``pool_size`` rows of the stream, so the same few memes would keep
    reappearing. When ``shuffle_buffer`` is truthy the stream is shuffled first,
    so the candidate pool varies between calls.

    Args:
        pool_size: Number of rows read from the stream as candidates.
        shuffle_buffer: ``buffer_size`` passed to ``source.shuffle``. ``0`` or
            ``None`` disables shuffling (candidates are then the first rows).
        source: Dataset to draw from; defaults to the module-level ``dataset``.

    Returns:
        A ``(meme1, meme2)`` tuple of two distinct rows from the pool.

    Raises:
        ValueError: If fewer than two rows could be read from the stream.
    """
    if source is None:
        source = dataset
    if shuffle_buffer:
        # Seed drawn from ``random`` so seeding ``random`` also fixes the shuffle.
        source = source.shuffle(
            seed=random.randrange(2**32), buffer_size=shuffle_buffer
        )
    candidates = list(source.take(pool_size))
    if len(candidates) < 2:
        raise ValueError(
            f"need at least 2 memes for a battle, got {len(candidates)}"
        )
    meme1, meme2 = random.sample(candidates, 2)
    return meme1, meme2
# Score a meme by how well CLIP thinks its caption matches its image.
def score_meme(image_url, caption, timeout=10):
    """Return CLIP's ``logits_per_text`` for one (caption, image) pair.

    Args:
        image_url: URL of the meme image. An already-decoded
            ``PIL.Image.Image`` is also accepted. NOTE(review): which form
            arrives depends on how the dataset's ``image`` column is decoded;
            confirm against the dataset schema.
        caption: Caption text compared against the image.
        timeout: Seconds to wait for the image download before giving up.

    Returns:
        The logit as a float; the app treats higher as better. Any failure
        (download, decoding, inference) is logged and scored ``0.0``, so one
        broken meme does not abort the battle.
    """
    import logging

    import torch

    try:
        if isinstance(image_url, Image.Image):
            image = image_url
        else:
            # The timeout keeps a stalled host from hanging the app.
            # raise_for_status avoids handing an HTTP error page to PIL.
            response = requests.get(image_url, timeout=timeout)
            response.raise_for_status()
            image = Image.open(BytesIO(response.content))
        image = image.convert("RGB")
        inputs = processor(
            text=[caption], images=[image], return_tensors="pt", padding=True
        )
        # Inference only: skip autograd bookkeeping.
        with torch.no_grad():
            outputs = model(**inputs)
        return outputs.logits_per_text.item()
    except Exception:
        # Best-effort scoring: keep the battle going, but record why it failed.
        logging.getLogger(__name__).exception("Could not score meme %r", image_url)
        return 0.0
# Streamlit interface: each button press pits two random memes against each
# other, shows both with their CLIP scores, and announces the higher score.
st.write("### Select two memes and let the AI determine the winner!")
if st.button("Start Meme Battle"):
    memes = fetch_random_memes()

    # Score both memes first, then render them.
    scores = [score_meme(meme["image"], meme["caption"]) for meme in memes]

    for number, (meme, score) in enumerate(zip(memes, scores), start=1):
        st.write(f"#### Meme {number}")
        st.image(meme["image"], caption=f"Caption: {meme['caption']}")
        st.write(f"AI Score: {score:.2f}")

    # Higher score wins. Equal scores (including both memes falling back to
    # the failure score) are reported as a tie.
    score1, score2 = scores
    if score1 > score2:
        st.write("🎉 **Meme 1 Wins!**")
    elif score2 > score1:
        st.write("🎉 **Meme 2 Wins!**")
    else:
        st.write("🤝 **It's a tie!**")