# memebattle / app.py
# Last commit: 269fddf "Update app.py" by neojack (verified), 2.28 kB
import streamlit as st
from datasets import load_dataset
from transformers import CLIPProcessor, CLIPModel
from PIL import Image
from io import BytesIO
import requests
import random
# Page header: app title and a one-line description shown above the battle button.
st.title("Meme Battle AI")
st.write("Stream memes directly and let AI determine the winner!")
@st.cache_resource
def load_model():
    """Load the CLIP checkpoint and its matching processor.

    Decorated with ``st.cache_resource`` so the pair is built once and the
    cached objects are reused on later Streamlit reruns instead of being
    reloaded every time the script re-executes.

    Returns:
        A ``(model, processor)`` tuple for ``openai/clip-vit-base-patch32``.
    """
    checkpoint = "openai/clip-vit-base-patch32"
    clip_model = CLIPModel.from_pretrained(checkpoint)
    clip_processor = CLIPProcessor.from_pretrained(checkpoint)
    return clip_model, clip_processor


# Module-level handles read by score_meme().
model, processor = load_model()
@st.cache_resource
def load_streamed_dataset():
    """Open the ``train`` split of the meme dataset in streaming mode.

    With ``streaming=True`` rows are fetched lazily rather than downloading
    the whole split up front. ``st.cache_resource`` keeps the dataset handle
    alive across Streamlit reruns.

    Returns:
        The streaming dataset for ``Dhruv-goyal/memes_with_captions``.
    """
    repo_id = "Dhruv-goyal/memes_with_captions"
    return load_dataset(repo_id, split="train", streaming=True)


# Module-level handle read by fetch_random_memes().
dataset = load_streamed_dataset()
def fetch_random_memes(sample_size=100):
    """Pick two memes at random from the module-level streamed ``dataset``.

    The stream is shuffled with a fresh random seed, the first
    ``sample_size`` rows of that shuffle are collected, and two of them are
    drawn without replacement.

    Note: shuffling a streaming dataset is buffer-based (approximate), so
    the pool is not a uniform sample of the whole split.

    Args:
        sample_size: How many shuffled rows to pull before choosing two.
            Defaults to 100 (the previously hard-coded value); must be >= 2.

    Returns:
        A ``(meme1, meme2)`` tuple of dataset rows.

    Raises:
        ValueError: If ``sample_size`` is below 2, or the stream yields fewer
            than two rows.
    """
    if sample_size < 2:
        # Checked before streaming so no rows are fetched for a bad request.
        raise ValueError(f"sample_size must be at least 2, got {sample_size}")
    # A new seed per call gives a different shuffle for every battle.
    shuffled = dataset.shuffle(seed=random.randint(0, 1000))
    pool = list(shuffled.take(sample_size))
    if len(pool) < 2:
        raise ValueError(
            f"Need at least 2 memes to battle, but the dataset yielded {len(pool)}"
        )
    meme1, meme2 = random.sample(pool, 2)
    return meme1, meme2
def parse_meme(meme):
    """Split one dataset row into the caption to show and its image reference.

    The caption is the first element of ``meme["answers"]`` when that field
    exists and is truthy; otherwise the placeholder "No caption available"
    is used. The image reference is ``meme["image"]`` as-is.

    Args:
        meme: A dataset row (mapping) with an ``"image"`` key and an
            optional ``"answers"`` sequence.

    Returns:
        A ``(caption, image_url)`` tuple.

    Raises:
        KeyError: If the row has no ``"image"`` key.
    """
    answers = meme.get("answers")
    if answers:
        caption = answers[0]
    else:
        caption = "No caption available"
    return caption, meme["image"]
def score_meme(image_url, caption, timeout=10):
    """Score how well ``caption`` matches the meme image according to CLIP.

    The score is CLIP's ``logits_per_text`` for the single (caption, image)
    pair, an image-text similarity logit. Higher means CLIP considers the
    caption a better fit for the image. It does not measure humor.

    Args:
        image_url: URL of the meme image. An already-decoded
            ``PIL.Image.Image`` is also accepted and used without
            downloading. (Presumably the dataset stores URLs; this also
            covers a decoded image column. TODO confirm the schema.)
        caption: Caption text to compare against the image.
        timeout: Seconds to wait on the image download before giving up.

    Returns:
        The similarity logit as a float, or ``0`` if downloading, decoding
        or scoring failed. Failures are also reported in the UI via
        ``st.error``.
    """
    import torch  # local import: only needed for the no-grad context below

    try:
        if isinstance(image_url, Image.Image):
            image = image_url.convert("RGB")
        else:
            # Without a timeout a stalled image host would hang the rerun forever.
            response = requests.get(image_url, timeout=timeout)
            # Report HTTP errors (404, 500, ...) directly instead of handing an
            # HTML error page to PIL, which fails with a confusing message.
            response.raise_for_status()
            image = Image.open(BytesIO(response.content)).convert("RGB")
        # truncation=True clips captions longer than CLIP's text context
        # window instead of raising.
        inputs = processor(
            text=[caption],
            images=[image],
            return_tensors="pt",
            padding=True,
            truncation=True,
        )
        # Inference only: don't build an autograd graph for this forward pass.
        with torch.no_grad():
            outputs = model(**inputs)
        return outputs.logits_per_text.item()
    except Exception as e:
        # Deliberate best-effort at the UI boundary: show the error and let
        # the battle continue with a neutral score.
        st.error(f"Error scoring meme: {e}")
        return 0
def _show_meme(label, image_url, caption, score):
    """Render one contestant: a heading, the image with its caption, and its score."""
    st.write(f"#### {label}")
    st.image(image_url, caption=f"Caption: {caption}")
    st.write(f"AI Score: {score:.2f}")


# When the button is pressed, pick two memes, score each caption against its
# image with CLIP, and declare the higher-scoring one the winner.
if st.button("Start Meme Battle"):
    meme1, meme2 = fetch_random_memes()
    caption1, image_url1 = parse_meme(meme1)
    caption2, image_url2 = parse_meme(meme2)
    score1 = score_meme(image_url1, caption1)
    score2 = score_meme(image_url2, caption2)

    _show_meme("Meme 1", image_url1, caption1, score1)
    _show_meme("Meme 2", image_url2, caption2, score2)

    # Emoji are written as \N{...} escapes (party popper, handshake) so a
    # mis-decoding editor or copy step cannot garble them into mojibake.
    if score1 > score2:
        st.write("\N{PARTY POPPER} **Meme 1 Wins!**")
    elif score2 > score1:
        st.write("\N{PARTY POPPER} **Meme 2 Wins!**")
    else:
        st.write("\N{HANDSHAKE} **It's a tie!**")