neojack commited on
Commit
42d703c
Β·
verified Β·
1 Parent(s): 269fddf

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +23 -15
app.py CHANGED
@@ -2,8 +2,6 @@ import streamlit as st
2
  from datasets import load_dataset
3
  from transformers import CLIPProcessor, CLIPModel
4
  from PIL import Image
5
- from io import BytesIO
6
- import requests
7
  import random
8
 
9
  # Load the CLIP model and processor
@@ -25,21 +23,25 @@ def load_streamed_dataset():
25
  dataset = load_streamed_dataset()
26
 
27
  def fetch_random_memes():
28
- sample_size = 100
 
29
  dataset_samples = list(dataset.shuffle(seed=random.randint(0, 1000)).take(sample_size))
30
  meme1, meme2 = random.sample(dataset_samples, 2)
31
  return meme1, meme2
32
 
33
  def parse_meme(meme):
 
34
  caption = meme["answers"][0] if meme.get("answers") else "No caption available"
35
- image_url = meme["image"]
36
- return caption, image_url
37
 
38
- def score_meme(image_url, caption):
 
39
  try:
40
- response = requests.get(image_url)
41
- image = Image.open(BytesIO(response.content)).convert("RGB")
42
  inputs = processor(text=[caption], images=[image], return_tensors="pt", padding=True)
 
 
43
  outputs = model(**inputs)
44
  logits_per_text = outputs.logits_per_text
45
  return logits_per_text.item()
@@ -48,25 +50,31 @@ def score_meme(image_url, caption):
48
  return 0
49
 
50
  if st.button("Start Meme Battle"):
 
51
  meme1, meme2 = fetch_random_memes()
52
- caption1, image_url1 = parse_meme(meme1)
53
- caption2, image_url2 = parse_meme(meme2)
54
 
55
- score1 = score_meme(image_url1, caption1)
56
- score2 = score_meme(image_url2, caption2)
 
 
 
 
 
57
 
 
58
  st.write("#### Meme 1")
59
- st.image(image_url1, caption=f"Caption: {caption1}")
60
  st.write(f"AI Score: {score1:.2f}")
61
 
 
62
  st.write("#### Meme 2")
63
- st.image(image_url2, caption=f"Caption: {caption2}")
64
  st.write(f"AI Score: {score2:.2f}")
65
 
 
66
  if score1 > score2:
67
  st.write("πŸŽ‰ **Meme 1 Wins!**")
68
  elif score2 > score1:
69
  st.write("πŸŽ‰ **Meme 2 Wins!**")
70
  else:
71
  st.write("🀝 **It's a tie!**")
72
-
 
2
  from datasets import load_dataset
3
  from transformers import CLIPProcessor, CLIPModel
4
  from PIL import Image
 
 
5
  import random
6
 
7
  # Load the CLIP model and processor
 
23
  dataset = load_streamed_dataset()
24
 
25
  def fetch_random_memes():
26
+ """Fetch two random memes from the dataset."""
27
+ sample_size = 100 # Number of samples to shuffle
28
  dataset_samples = list(dataset.shuffle(seed=random.randint(0, 1000)).take(sample_size))
29
  meme1, meme2 = random.sample(dataset_samples, 2)
30
  return meme1, meme2
31
 
32
  def parse_meme(meme):
33
+ """Extract the caption and Pillow image from a meme."""
34
  caption = meme["answers"][0] if meme.get("answers") else "No caption available"
35
+ image = meme["image"] # This is already a PIL image object
36
+ return caption, image
37
 
38
+ def score_meme(image, caption):
39
+ """Score a meme by evaluating the image-caption compatibility."""
40
  try:
41
+ # Preprocess image and caption
 
42
  inputs = processor(text=[caption], images=[image], return_tensors="pt", padding=True)
43
+
44
+ # Get the compatibility score
45
  outputs = model(**inputs)
46
  logits_per_text = outputs.logits_per_text
47
  return logits_per_text.item()
 
50
  return 0
51
 
52
  if st.button("Start Meme Battle"):
53
+ # Fetch random memes
54
  meme1, meme2 = fetch_random_memes()
 
 
55
 
56
+ # Parse captions and images
57
+ caption1, image1 = parse_meme(meme1)
58
+ caption2, image2 = parse_meme(meme2)
59
+
60
+ # Score memes
61
+ score1 = score_meme(image1, caption1)
62
+ score2 = score_meme(image2, caption2)
63
 
64
+ # Display Meme 1
65
  st.write("#### Meme 1")
66
+ st.image(image1, caption=f"Caption: {caption1}")
67
  st.write(f"AI Score: {score1:.2f}")
68
 
69
+ # Display Meme 2
70
  st.write("#### Meme 2")
71
+ st.image(image2, caption=f"Caption: {caption2}")
72
  st.write(f"AI Score: {score2:.2f}")
73
 
74
+ # Determine the winner
75
  if score1 > score2:
76
  st.write("πŸŽ‰ **Meme 1 Wins!**")
77
  elif score2 > score1:
78
  st.write("πŸŽ‰ **Meme 2 Wins!**")
79
  else:
80
  st.write("🀝 **It's a tie!**")