neojack commited on
Commit
9b5318d
·
verified ·
1 Parent(s): d6eb356

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +81 -0
app.py ADDED
@@ -0,0 +1,81 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ from datasets import load_dataset
3
+ from transformers import CLIPProcessor, CLIPModel
4
+ from PIL import Image
5
+ import random
6
+ from io import BytesIO
7
+ import requests
8
+
9
+ # Load the CLIP model and processor
10
+ st.title("Meme Battle AI")
11
+ st.write("Stream memes directly and let AI determine the winner!")
12
+
13
+ # Load the CLIP model
14
+ @st.cache_resource
15
+ def load_model():
16
+ model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
17
+ processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
18
+ return model, processor
19
+
20
+ model, processor = load_model()
21
+
22
+ # Stream the dataset
23
+ @st.cache_resource
24
+ def load_streamed_dataset():
25
+ return load_dataset("Dhruv-goyal/memes_with_captions", split="train", streaming=True)
26
+
27
+ dataset = load_streamed_dataset()
28
+
29
+ # Randomly fetch two memes
30
+ def fetch_random_memes():
31
+ samples = list(dataset.take(10)) # Fetch 10 samples from the streamed dataset
32
+ meme1, meme2 = random.sample(samples, 2)
33
+ return meme1, meme2
34
+
35
+ # Score the memes
36
+ def score_meme(image_url, caption):
37
+ try:
38
+ # Load the image from the URL
39
+ image = Image.open(BytesIO(requests.get(image_url).content))
40
+ # Preprocess image and caption
41
+ inputs = processor(text=[caption], images=[image], return_tensors="pt", padding=True)
42
+ # Get the compatibility score
43
+ outputs = model(**inputs)
44
+ logits_per_text = outputs.logits_per_text
45
+ return logits_per_text.item()
46
+ except Exception as e:
47
+ return 0
48
+
49
+ # Streamlit Interface
50
+ st.write("### Select two memes and let the AI determine the winner!")
51
+
52
+ if st.button("Start Meme Battle"):
53
+ meme1, meme2 = fetch_random_memes()
54
+
55
+ # Fetch data for Meme 1
56
+ caption1 = meme1["caption"]
57
+ image_url1 = meme1["image"]
58
+ score1 = score_meme(image_url1, caption1)
59
+
60
+ # Fetch data for Meme 2
61
+ caption2 = meme2["caption"]
62
+ image_url2 = meme2["image"]
63
+ score2 = score_meme(image_url2, caption2)
64
+
65
+ # Display Meme 1
66
+ st.write("#### Meme 1")
67
+ st.image(image_url1, caption=f"Caption: {caption1}")
68
+ st.write(f"AI Score: {score1:.2f}")
69
+
70
+ # Display Meme 2
71
+ st.write("#### Meme 2")
72
+ st.image(image_url2, caption=f"Caption: {caption2}")
73
+ st.write(f"AI Score: {score2:.2f}")
74
+
75
+ # Determine the winner
76
+ if score1 > score2:
77
+ st.write("🎉 **Meme 1 Wins!**")
78
+ elif score2 > score1:
79
+ st.write("🎉 **Meme 2 Wins!**")
80
+ else:
81
+ st.write("🤝 **It's a tie!**")