neojack commited on
Commit
269fddf
Β·
verified Β·
1 Parent(s): cefdf31

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +14 -25
app.py CHANGED
@@ -10,7 +10,6 @@ import random
10
  st.title("Meme Battle AI")
11
  st.write("Stream memes directly and let AI determine the winner!")
12
 
13
- # Load the CLIP model
14
  @st.cache_resource
15
  def load_model():
16
  model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
@@ -19,65 +18,55 @@ def load_model():
19
 
20
  model, processor = load_model()
21
 
22
- # Stream the dataset
23
  @st.cache_resource
24
  def load_streamed_dataset():
25
  return load_dataset("Dhruv-goyal/memes_with_captions", split="train", streaming=True)
26
 
27
  dataset = load_streamed_dataset()
28
 
29
- # Randomly fetch two memes
30
  def fetch_random_memes():
31
- samples = list(dataset.take(10)) # Fetch 10 samples from the streamed dataset
32
- meme1, meme2 = random.sample(samples, 2)
 
33
  return meme1, meme2
34
 
35
- # Score the memes
 
 
 
 
36
  def score_meme(image_url, caption):
37
  try:
38
- # Load the image from the URL
39
- image = Image.open(BytesIO(requests.get(image_url).content))
40
- # Preprocess image and caption
41
  inputs = processor(text=[caption], images=[image], return_tensors="pt", padding=True)
42
- # Get the compatibility score
43
  outputs = model(**inputs)
44
  logits_per_text = outputs.logits_per_text
45
  return logits_per_text.item()
46
  except Exception as e:
 
47
  return 0
48
 
49
- # Streamlit Interface
50
- st.write("### Select two memes and let the AI determine the winner!")
51
-
52
  if st.button("Start Meme Battle"):
53
  meme1, meme2 = fetch_random_memes()
 
 
54
 
55
- # Extract caption and image for Meme 1
56
- caption1 = meme1["answers"][0] if meme1["answers"] else "No caption available"
57
- image_url1 = meme1["image"]
58
-
59
- # Extract caption and image for Meme 2
60
- caption2 = meme2["answers"][0] if meme2["answers"] else "No caption available"
61
- image_url2 = meme2["image"]
62
-
63
- # Fetch scores
64
  score1 = score_meme(image_url1, caption1)
65
  score2 = score_meme(image_url2, caption2)
66
 
67
- # Display Meme 1
68
  st.write("#### Meme 1")
69
  st.image(image_url1, caption=f"Caption: {caption1}")
70
  st.write(f"AI Score: {score1:.2f}")
71
 
72
- # Display Meme 2
73
  st.write("#### Meme 2")
74
  st.image(image_url2, caption=f"Caption: {caption2}")
75
  st.write(f"AI Score: {score2:.2f}")
76
 
77
- # Determine the winner
78
  if score1 > score2:
79
  st.write("πŸŽ‰ **Meme 1 Wins!**")
80
  elif score2 > score1:
81
  st.write("πŸŽ‰ **Meme 2 Wins!**")
82
  else:
83
  st.write("🀝 **It's a tie!**")
 
 
10
  st.title("Meme Battle AI")
11
  st.write("Stream memes directly and let AI determine the winner!")
12
 
 
13
  @st.cache_resource
14
  def load_model():
15
  model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
 
18
 
19
  model, processor = load_model()
20
 
 
21
  @st.cache_resource
22
  def load_streamed_dataset():
23
  return load_dataset("Dhruv-goyal/memes_with_captions", split="train", streaming=True)
24
 
25
  dataset = load_streamed_dataset()
26
 
 
27
  def fetch_random_memes():
28
+ sample_size = 100
29
+ dataset_samples = list(dataset.shuffle(seed=random.randint(0, 1000)).take(sample_size))
30
+ meme1, meme2 = random.sample(dataset_samples, 2)
31
  return meme1, meme2
32
 
33
+ def parse_meme(meme):
34
+ caption = meme["answers"][0] if meme.get("answers") else "No caption available"
35
+ image_url = meme["image"]
36
+ return caption, image_url
37
+
38
  def score_meme(image_url, caption):
39
  try:
40
+ response = requests.get(image_url)
41
+ image = Image.open(BytesIO(response.content)).convert("RGB")
 
42
  inputs = processor(text=[caption], images=[image], return_tensors="pt", padding=True)
 
43
  outputs = model(**inputs)
44
  logits_per_text = outputs.logits_per_text
45
  return logits_per_text.item()
46
  except Exception as e:
47
+ st.error(f"Error scoring meme: {e}")
48
  return 0
49
 
 
 
 
50
  if st.button("Start Meme Battle"):
51
  meme1, meme2 = fetch_random_memes()
52
+ caption1, image_url1 = parse_meme(meme1)
53
+ caption2, image_url2 = parse_meme(meme2)
54
 
 
 
 
 
 
 
 
 
 
55
  score1 = score_meme(image_url1, caption1)
56
  score2 = score_meme(image_url2, caption2)
57
 
 
58
  st.write("#### Meme 1")
59
  st.image(image_url1, caption=f"Caption: {caption1}")
60
  st.write(f"AI Score: {score1:.2f}")
61
 
 
62
  st.write("#### Meme 2")
63
  st.image(image_url2, caption=f"Caption: {caption2}")
64
  st.write(f"AI Score: {score2:.2f}")
65
 
 
66
  if score1 > score2:
67
  st.write("πŸŽ‰ **Meme 1 Wins!**")
68
  elif score2 > score1:
69
  st.write("πŸŽ‰ **Meme 2 Wins!**")
70
  else:
71
  st.write("🀝 **It's a tie!**")
72
+