File size: 2,445 Bytes
9b5318d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
import streamlit as st
from datasets import load_dataset
from transformers import CLIPProcessor, CLIPModel
from PIL import Image
import random
from io import BytesIO
import requests

# Page header: app title followed by a one-line description of what the app does.
st.title("Meme Battle AI")
st.write("Stream memes directly and let AI determine the winner!")

# Load the CLIP model
@st.cache_resource
def load_model():
    """Load the pretrained CLIP model together with its matching processor.

    Both objects are created from the same ``openai/clip-vit-base-patch32``
    checkpoint.

    Returns:
        A ``(model, processor)`` tuple, in that order.
    """
    checkpoint = "openai/clip-vit-base-patch32"
    clip_model = CLIPModel.from_pretrained(checkpoint)
    clip_processor = CLIPProcessor.from_pretrained(checkpoint)
    return clip_model, clip_processor

# Module-level CLIP model and processor; score_meme() reads both of these names.
model, processor = load_model()

# Stream the dataset
@st.cache_resource
def load_streamed_dataset():
    """Open the ``train`` split of the meme/caption dataset in streaming mode.

    Returns:
        The object produced by ``load_dataset`` for
        ``Dhruv-goyal/memes_with_captions`` with ``streaming=True``.
    """
    meme_stream = load_dataset(
        "Dhruv-goyal/memes_with_captions",
        split="train",
        streaming=True,
    )
    return meme_stream

# Module-level streaming dataset handle; fetch_random_memes() draws its samples from it.
dataset = load_streamed_dataset()

# Randomly fetch two memes
def fetch_random_memes(pool_size=10):
    """Pick two memes at random from a small pool pulled off the dataset stream.

    Args:
        pool_size: Number of samples to pull from the stream before choosing.
            Defaults to 10, the previously hard-coded value. Must be at
            least 2.

    Returns:
        A ``(meme1, meme2)`` tuple drawn without replacement from the pool.

    Raises:
        ValueError: If ``pool_size`` is below 2, or the stream yields fewer
            than two samples.
    """
    if pool_size < 2:
        raise ValueError(f"pool_size must be at least 2, got {pool_size}")

    # NOTE(review): take() presumably restarts from the head of the stream on
    # every call, so the pool is likely the same leading rows each time --
    # confirm, and consider shuffling the stream if more variety is wanted.
    candidates = list(dataset.take(pool_size))
    if len(candidates) < 2:
        raise ValueError(
            f"Need at least 2 memes to battle, but the dataset yielded {len(candidates)}"
        )

    meme1, meme2 = random.sample(candidates, 2)
    return meme1, meme2

# Score the memes
def score_meme(image_url, caption):
    """Score how well a caption matches a meme image using the CLIP model.

    Args:
        image_url: URL of the image to download, or an already-decoded
            ``PIL.Image.Image``, which is used as-is without any download.
        caption: Caption text to compare against the image.

    Returns:
        The single image-text logit produced by the model, as a float. If the
        image cannot be loaded or scored, the failure is logged and ``0.0`` is
        returned so the caller can still render a result.
    """
    import logging  # local import keeps this block self-contained

    try:
        if isinstance(image_url, Image.Image):
            # Already a decoded image; nothing to download.
            image = image_url
        else:
            # Bounded wait so a stalled host cannot hang the app, and an HTTP
            # error page is reported as such instead of being parsed as an image.
            response = requests.get(image_url, timeout=10)
            response.raise_for_status()
            image = Image.open(BytesIO(response.content))

        # truncation=True clips over-long captions to the tokenizer's maximum
        # length instead of letting the model fail on them.
        inputs = processor(
            text=[caption],
            images=[image],
            return_tensors="pt",
            padding=True,
            truncation=True,
        )

        # Inference only: skip autograd bookkeeping. transformers' CLIPModel
        # already requires torch, so this import adds no new dependency.
        import torch

        with torch.no_grad():
            outputs = model(**inputs)
        return outputs.logits_per_text.item()
    except Exception:
        # Deliberate best-effort: one bad meme must not crash the page, but
        # the reason is logged so that a 0.0 score can be diagnosed.
        logging.getLogger(__name__).warning(
            "Could not score meme %r", image_url, exc_info=True
        )
        return 0.0

# Streamlit interface: one button press fetches two memes, scores each against
# its own caption, shows both, and announces the higher-scoring one.
st.write("### Select two memes and let the AI determine the winner!")

if st.button("Start Meme Battle"):
    # Score both contenders before rendering anything, so the order of work
    # (fetch, score, score, display, display, verdict) stays as before.
    contenders = []
    for meme in fetch_random_memes():
        caption = meme["caption"]
        image = meme["image"]
        contenders.append((caption, image, score_meme(image, caption)))

    # Render each contender under a 1-based heading.
    for number, (caption, image, score) in enumerate(contenders, start=1):
        st.write(f"#### Meme {number}")
        st.image(image, caption=f"Caption: {caption}")
        st.write(f"AI Score: {score:.2f}")

    # Determine the winner. The emoji were previously mis-encoded bytes.
    score1, score2 = contenders[0][2], contenders[1][2]
    if score1 > score2:
        st.write("🎉 **Meme 1 Wins!**")
    elif score2 > score1:
        st.write("🎉 **Meme 2 Wins!**")
    else:
        st.write("🤝 **It's a tie!**")