Spaces:
Sleeping
Sleeping
File size: 2,445 Bytes
9b5318d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 |
import logging
import random
from io import BytesIO

import requests
import streamlit as st
import torch
from datasets import load_dataset
from PIL import Image
from transformers import CLIPProcessor, CLIPModel
# Page header: app title and a one-line description of what the app does.
st.title("Meme Battle AI")
st.write("Stream memes directly and let AI determine the winner!")
# The CLIP checkpoint is wrapped in st.cache_resource so Streamlit can reuse
# the loaded objects across script reruns instead of reloading them.
@st.cache_resource
def load_model():
    """Load the CLIP model and its matching processor.

    Returns:
        A ``(model, processor)`` tuple for ``openai/clip-vit-base-patch32``.
    """
    checkpoint = "openai/clip-vit-base-patch32"
    clip_model = CLIPModel.from_pretrained(checkpoint)
    clip_processor = CLIPProcessor.from_pretrained(checkpoint)
    return clip_model, clip_processor


model, processor = load_model()
@st.cache_resource
def load_streamed_dataset():
    """Open the ``train`` split of the meme/caption dataset in streaming mode.

    With ``streaming=True`` examples are fetched lazily while iterating
    instead of downloading the full dataset up front.
    """
    stream = load_dataset(
        "Dhruv-goyal/memes_with_captions",
        split="train",
        streaming=True,
    )
    return stream


dataset = load_streamed_dataset()
def fetch_random_memes(pool_size=10, buffer_size=100):
    """Return two distinct memes picked at random from the streamed dataset.

    The stream is shuffled with a fresh seed on every call before taking
    ``pool_size`` examples. Without the shuffle, ``take`` would always return
    the same leading examples of the split, and every battle would be drawn
    from the same handful of memes.

    Args:
        pool_size: Number of examples taken from the shuffled stream; the two
            contenders are sampled from this pool.
        buffer_size: Shuffle-buffer size. Larger values mix the stream more
            thoroughly but fetch more examples before the first one is
            yielded, so each battle takes longer to start.

    Returns:
        A ``(meme1, meme2)`` tuple of dataset examples.

    Raises:
        ValueError: If fewer than two examples could be fetched.
    """
    # Seed from the stdlib RNG so one `random.seed(...)` makes runs reproducible.
    seed = random.randrange(2**32)
    shuffled = dataset.shuffle(seed=seed, buffer_size=buffer_size)
    samples = list(shuffled.take(pool_size))
    if len(samples) < 2:
        raise ValueError(
            f"Need at least 2 memes for a battle, got {len(samples)}."
        )
    meme1, meme2 = random.sample(samples, 2)
    return meme1, meme2
def score_meme(image_url, caption):
    """Score how well ``caption`` matches the meme image according to CLIP.

    The score is CLIP's image-text logit (``outputs.logits_per_text``) for the
    single caption/image pair; higher means a closer match. This measures
    image-caption agreement, not humor.

    Args:
        image_url: URL of the meme image. An already-decoded
            ``PIL.Image.Image`` is also accepted (streamed image columns may
            arrive decoded — NOTE(review): confirm this dataset's schema).
        caption: Caption text compared against the image.

    Returns:
        The logit as a float. If the image cannot be fetched or scored, the
        failure is logged and ``0`` is returned, so one bad meme does not
        crash the battle.
    """
    try:
        if isinstance(image_url, Image.Image):
            image = image_url
        else:
            # Without a timeout, an unresponsive host would hang the app forever.
            response = requests.get(image_url, timeout=10)
            response.raise_for_status()
            image = Image.open(BytesIO(response.content))
        image = image.convert("RGB")
        # truncation=True: CLIP's text encoder has a fixed maximum length, and
        # longer captions would otherwise raise instead of being scored.
        inputs = processor(
            text=[caption],
            images=[image],
            return_tensors="pt",
            padding=True,
            truncation=True,
        )
        # Inference only: skip building the autograd graph.
        with torch.no_grad():
            outputs = model(**inputs)
        return outputs.logits_per_text.item()
    except Exception:
        # Best effort: log the cause instead of silently hiding it.
        logging.getLogger(__name__).exception(
            "Failed to score meme image %r", image_url
        )
        return 0
# Streamlit interface: pressing the button draws two memes, scores both with
# CLIP, renders them with their scores, and announces the higher-scoring one.
st.write("### Select two memes and let the AI determine the winner!")
if st.button("Start Meme Battle"):
    try:
        contenders = fetch_random_memes()
    except ValueError as err:
        st.error(f"Could not start a meme battle: {err}")
        st.stop()

    # Score every contender before rendering anything.
    scores = [score_meme(meme["image"], meme["caption"]) for meme in contenders]

    for number, (meme, score) in enumerate(zip(contenders, scores), start=1):
        st.write(f"#### Meme {number}")
        st.image(meme["image"], caption=f"Caption: {meme['caption']}")
        st.write(f"AI Score: {score:.2f}")

    # Determine the winner; equal scores (including two failed scorings) tie.
    score1, score2 = scores
    if score1 > score2:
        st.write("🏆 **Meme 1 Wins!**")
    elif score2 > score1:
        st.write("🏆 **Meme 2 Wins!**")
    else:
        st.write("🤔 **It's a tie!**")
|