# Hugging Face Spaces status: Sleeping
import random
from io import BytesIO

import requests
import streamlit as st
import torch
from datasets import load_dataset
from PIL import Image
from transformers import CLIPModel, CLIPProcessor
# Page header: app title and a one-line description shown above the battle.
st.title("Meme Battle AI")
st.write("Stream memes directly and let AI determine the winner!")
@st.cache_resource
def load_model(model_name="openai/clip-vit-base-patch32"):
    """Load a pretrained CLIP model and its matching processor.

    Streamlit re-executes the whole script on every widget interaction, so
    the loader is wrapped in ``st.cache_resource``. The weights are then
    loaded once and shared across reruns instead of being reloaded on each
    button click.

    Args:
        model_name: Hugging Face checkpoint used for both the model and the
            processor. Defaults to ``openai/clip-vit-base-patch32``.

    Returns:
        tuple: ``(model, processor)``.
    """
    model = CLIPModel.from_pretrained(model_name)
    processor = CLIPProcessor.from_pretrained(model_name)
    return model, processor


model, processor = load_model()
@st.cache_resource
def load_streamed_dataset():
    """Open the meme dataset's ``train`` split in streaming mode.

    ``streaming=True`` means examples are fetched lazily rather than
    downloading the whole dataset up front. The handle is cached with
    ``st.cache_resource`` so it is not re-created on every Streamlit rerun.
    Callers only derive new views from it (``shuffle``/``take``), so sharing
    it is safe.
    """
    return load_dataset("Dhruv-goyal/memes_with_captions", split="train", streaming=True)


dataset = load_streamed_dataset()
def fetch_random_memes(sample_size=100, source=None):
    """Pick two memes at random from the streamed dataset.

    The steps are:

    1. Shuffle the stream with a fresh random seed.
    2. Materialize the first ``sample_size`` examples.
    3. Choose two of them without replacement.

    Args:
        sample_size: Number of examples to read from the shuffled stream
            before choosing two. Defaults to 100.
        source: Object exposing ``shuffle(seed=...)`` and ``take(n)``, such as
            the streaming dataset. Defaults to the module-level ``dataset``.

    Returns:
        tuple: ``(meme1, meme2)``, two examples taken from different
        positions of the sample.

    Raises:
        ValueError: If fewer than two examples could be read from the stream.
    """
    if source is None:
        source = dataset
    shuffled = source.shuffle(seed=random.randint(0, 1000))
    samples = list(shuffled.take(sample_size))
    if len(samples) < 2:
        # random.sample would also fail here, but with a less helpful message.
        raise ValueError(
            f"Need at least 2 memes for a battle, but only got {len(samples)}."
        )
    meme1, meme2 = random.sample(samples, 2)
    return meme1, meme2
def parse_meme(meme):
    """Extract the caption text and image reference from a dataset row.

    Args:
        meme: A mapping with an ``"image"`` entry and an optional
            ``"answers"`` entry holding the caption(s).

    Returns:
        tuple: ``(caption, image_url)``.

        - If ``"answers"`` is a list, the caption is its first element.
        - If ``"answers"`` is a plain string, that string is the caption.
        - If ``"answers"`` is missing or empty, the caption is
          ``"No caption available"``.

    Raises:
        KeyError: If the row has no ``"image"`` entry.
    """
    answers = meme.get("answers")
    if not answers:
        caption = "No caption available"
    elif isinstance(answers, str):
        # NOTE(review): the schema presumably stores a list of answers; a bare
        # string would otherwise be indexed down to its first character.
        caption = answers
    else:
        caption = answers[0]
    image_url = meme["image"]
    return caption, image_url
def score_meme(image_url, caption, timeout=10):
    """Score how well ``caption`` matches the meme image with CLIP.

    Args:
        image_url: URL of the meme image, or an already-decoded
            ``PIL.Image.Image``, which is used directly.
        caption: Caption text compared against the image.
        timeout: Seconds to wait for the image download before giving up.

    Returns:
        CLIP's ``logits_per_text`` value for the single image/caption pair;
        higher means a closer match. Returns ``0`` when downloading, decoding
        or scoring fails. The error is shown in the app via ``st.error``.
    """
    try:
        if isinstance(image_url, Image.Image):
            image = image_url.convert("RGB")
        else:
            # Without a timeout, a stalled image host would hang the app indefinitely.
            response = requests.get(image_url, timeout=timeout)
            # Fail on HTTP errors instead of passing an error page to PIL.
            response.raise_for_status()
            image = Image.open(BytesIO(response.content)).convert("RGB")
        # CLIP's text encoder accepts at most 77 tokens, so long captions are
        # truncated instead of raising.
        inputs = processor(
            text=[caption],
            images=[image],
            return_tensors="pt",
            padding=True,
            truncation=True,
        )
        # Inference only: no need to build an autograd graph.
        with torch.no_grad():
            outputs = model(**inputs)
        return outputs.logits_per_text.item()
    except Exception as e:
        # UI boundary: report the failure and let the battle continue with a 0 score.
        st.error(f"Error scoring meme: {e}")
        return 0
# Run one battle per button click: fetch two memes, score both, show the winner.
if st.button("Start Meme Battle"):
    try:
        memes = fetch_random_memes()
    except ValueError as e:
        # Raised when the stream yields fewer than two memes to choose from.
        st.error(f"Could not fetch memes: {e}")
    else:
        # Score both memes before rendering anything, so scoring errors
        # appear above the results.
        results = []
        for meme in memes:
            caption, image_url = parse_meme(meme)
            results.append((caption, image_url, score_meme(image_url, caption)))

        for number, (caption, image_url, score) in enumerate(results, start=1):
            st.write(f"#### Meme {number}")
            st.image(image_url, caption=f"Caption: {caption}")
            st.write(f"AI Score: {score:.2f}")

        score1, score2 = results[0][2], results[1][2]
        if score1 > score2:
            st.write("🏆 **Meme 1 Wins!**")
        elif score2 > score1:
            st.write("🏆 **Meme 2 Wins!**")
        else:
            st.write("🤝 **It's a tie!**")