File size: 2,656 Bytes
9b5318d
 
 
 
cefdf31
9b5318d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
42d703c
 
269fddf
 
9b5318d
 
269fddf
42d703c
269fddf
42d703c
 
269fddf
42d703c
 
9b5318d
42d703c
9b5318d
42d703c
 
9b5318d
 
 
 
269fddf
9b5318d
 
 
42d703c
9b5318d
 
42d703c
 
 
 
 
 
 
9b5318d
ad90ea7
 
9b5318d
ad90ea7
 
 
 
 
 
 
 
 
9b5318d
42d703c
9b5318d
 
 
 
 
 
ad90ea7
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
import random

import streamlit as st
import torch
from datasets import load_dataset
from transformers import CLIPProcessor, CLIPModel
from PIL import Image

# Page header
st.title("Meme Battle AI")
st.write("Stream memes directly and let AI determine the winner!")

@st.cache_resource
def load_model(model_name="openai/clip-vit-base-patch32"):
    """Load a CLIP model together with its matching processor.

    Decorated with ``st.cache_resource`` so the weights are loaded once and
    reused across Streamlit reruns instead of on every widget interaction.

    Args:
        model_name: Hugging Face Hub id of the CLIP checkpoint to load.

    Returns:
        A ``(model, processor)`` tuple.
    """
    model = CLIPModel.from_pretrained(model_name)
    processor = CLIPProcessor.from_pretrained(model_name)
    return model, processor

# Load the CLIP model and processor
model, processor = load_model()

@st.cache_resource
def load_streamed_dataset():
    """Open the meme dataset's train split in streaming mode.

    Cached with ``st.cache_resource`` so the same iterable dataset handle is
    reused across reruns; examples are only fetched when it is iterated.
    """
    streamed = load_dataset(
        "Dhruv-goyal/memes_with_captions",
        split="train",
        streaming=True,
    )
    return streamed

# Module-level handle used by fetch_random_memes()
dataset = load_streamed_dataset()

def fetch_random_memes(sample_size=100, buffer_size=1000):
    """Fetch two random memes from the streamed dataset.

    The stream is shuffled with a fresh random seed, the first ``sample_size``
    shuffled examples are materialized, and two of them are picked without
    replacement.

    Args:
        sample_size: How many shuffled examples to pull before picking two.
            Larger values widen the candidate pool but download more images.
        buffer_size: Size of the streaming shuffle buffer. The shuffled stream
            reads this many examples before yielding its first one, so it
            dominates the cost of each call (1000 is the ``datasets`` default).

    Returns:
        A tuple ``(meme1, meme2)`` of two examples drawn without replacement.

    Raises:
        ValueError: If fewer than two examples could be read from the stream.
    """
    shuffled = dataset.shuffle(seed=random.randint(0, 1000), buffer_size=buffer_size)
    candidates = list(shuffled.take(sample_size))
    if len(candidates) < 2:
        raise ValueError(
            f"Need at least 2 memes for a battle, but the stream yielded {len(candidates)}."
        )
    meme1, meme2 = random.sample(candidates, 2)
    return meme1, meme2

def parse_meme(meme):
    """Split a dataset example into a ``(caption, image)`` pair.

    The caption is the first entry of ``meme["answers"]``; when that field is
    missing or empty, the placeholder "No caption available" is used. The
    image is returned exactly as stored under ``meme["image"]`` (presumably a
    PIL image decoded by ``datasets`` -- confirm against the dataset schema).
    """
    answers = meme.get("answers")
    if answers:
        caption = answers[0]
    else:
        caption = "No caption available"
    return caption, meme["image"]

def score_meme(image, caption):
    """Score a meme by how well its caption matches its image according to CLIP.

    Args:
        image: The meme image as returned by ``parse_meme``.
        caption: The caption text to compare against the image.

    Returns:
        The CLIP text-image logit as a float (higher means a closer match),
        or 0 if scoring fails; the error is reported in the Streamlit UI.
    """
    try:
        # CLIP's text encoder has a fixed maximum sequence length; truncate
        # long captions instead of letting the forward pass fail on them.
        inputs = processor(
            text=[caption],
            images=[image],
            return_tensors="pt",
            padding=True,
            truncation=True,
        )

        # Inference only: disable autograd tracking to save memory and time.
        with torch.no_grad():
            outputs = model(**inputs)
        # One text vs. one image -> a single logit.
        return outputs.logits_per_text.item()
    except Exception as e:
        st.error(f"Error scoring meme: {e}")
        return 0

if st.button("Start Meme Battle"):
    # Streaming the dataset and running CLIP can take a while; show progress.
    with st.spinner("Fetching and scoring memes..."):
        # Fetch random memes
        meme1, meme2 = fetch_random_memes()

        # Parse captions and images
        caption1, image1 = parse_meme(meme1)
        caption2, image2 = parse_meme(meme2)

        # Score memes (higher = caption matches image better, per CLIP)
        score1 = score_meme(image1, caption1)
        score2 = score_meme(image2, caption2)

    # Display Meme 1 and Meme 2 side by side
    col1, col2 = st.columns(2)

    with col1:
        st.write("#### Meme 1")
        st.image(image1, caption=f"Caption: {caption1}")
        st.write(f"AI Score: {score1:.2f}")

    with col2:
        st.write("#### Meme 2")
        st.image(image2, caption=f"Caption: {caption2}")
        st.write(f"AI Score: {score2:.2f}")

    # Determine the winner (emoji restored from mis-decoded UTF-8)
    if score1 > score2:
        st.write("🎉 **Meme 1 Wins!**")
    elif score2 > score1:
        st.write("🎉 **Meme 2 Wins!**")
    else:
        st.write("🤝 **It's a tie!**")