Mayada committed on
Commit
f3c07ed
1 Parent(s): b3bcaf6

Update app.py

Files changed (1)
  1. app.py +111 -20
app.py CHANGED
@@ -1,24 +1,115 @@
  import gradio as gr
- from transformers import pipeline
-
- # Load your image captioning model from Hugging Face
- model_name = "Mayada/AIC-transformer"  # Update this with your model path
- captioner = pipeline("image-to-text", model=model_name)
-
- # Define a function to generate a caption from an image
- def generate_caption(image):
-     result = captioner(image)
-     return result[0]['generated_text']
-
- # Create a Gradio interface
- interface = gr.Interface(
-     fn=generate_caption,  # Function to process image and return caption
-     inputs=gr.inputs.Image(type="pil"),  # Accept image input
-     outputs="text",  # Output the caption as text
-     title="AIC-transformer-2023",  # Title for your interface
-     description="Description",  # Description for users
+ from gradio.themes.base import Base
+ from PIL import Image
+ import torch
+ import torchvision.transforms as transforms
+ from transformers import VisionEncoderDecoderModel, AutoTokenizer, AutoModelForSeq2SeqLM
+
+ # Load the models
+ caption_model = VisionEncoderDecoderModel.from_pretrained('Mayada/AIC-transformer')  # Your model on Hugging Face
+ caption_tokenizer = AutoTokenizer.from_pretrained('aubmindlab/bert-base-arabertv02')
+ question_model = AutoModelForSeq2SeqLM.from_pretrained("Mihakram/AraT5-base-question-generation")
+ question_tokenizer = AutoTokenizer.from_pretrained("Mihakram/AraT5-base-question-generation")
+
+ # Define the normalization and transformations
+ normalize = transforms.Normalize(
+     mean=[0.485, 0.456, 0.406],  # ImageNet mean
+     std=[0.229, 0.224, 0.225]    # ImageNet standard deviation
  )
 
- # Launch the Gradio interface
- interface.launch()
+ inference_transforms = transforms.Compose([
+     transforms.Resize((224, 224)),
+     transforms.ToTensor(),
+     normalize
+ ])
+
+ # Load the dictionary (use it from your Hugging Face Space or include in the repo)
+ dictionary = {
+     "caption": "alternative_caption"  # Replace with your actual dictionary
+ }
+
+ # Function to correct words in the caption using the dictionary
+ def correct_caption(caption):
+     corrected_words = [dictionary.get(word, word) for word in caption.split()]
+     corrected_caption = " ".join(corrected_words)
+     return corrected_caption
+
+ # Function to generate captions for an image
+ def generate_captions(image):
+     img_tensor = inference_transforms(image).unsqueeze(0)
+     generated = caption_model.generate(
+         img_tensor,
+         num_beams=3,
+         max_length=10,
+         early_stopping=True,
+         do_sample=True,
+         top_k=1000,
+         num_return_sequences=1,
+     )
+     captions = [caption_tokenizer.decode(g, skip_special_tokens=True).strip() for g in generated]
+     return captions
+
+ # Function to generate questions given a context and answer
+ def generate_questions(context, answer):
+     text = "context: " + context + " " + "answer: " + answer + " </s>"
+     text_encoding = question_tokenizer.encode_plus(
+         text, return_tensors="pt"
+     )
+     question_model.eval()
+     generated_ids = question_model.generate(
+         input_ids=text_encoding['input_ids'],
+         attention_mask=text_encoding['attention_mask'],
+         max_length=64,
+         num_beams=5,
+         num_return_sequences=1
+     )
+     questions = [question_tokenizer.decode(g, skip_special_tokens=True, clean_up_tokenization_spaces=True).replace(
+         'question: ', ' ') for g in generated_ids]
+     return questions
+
+ # Gradio Interface Function
+ def caption_question_interface(image):
+     captions = generate_captions(image)
+     corrected_captions = [correct_caption(caption) for caption in captions]
+     questions_with_answers = []
+
+     for caption in corrected_captions:
+         words = caption.split()
+         if len(words) > 0:
+             answer = words[0]
+             question = generate_questions(caption, answer)
+             questions_with_answers.extend([(q, answer) for q in question])
+         if len(words) > 1:
+             answer = words[1]
+             question = generate_questions(caption, answer)
+             questions_with_answers.extend([(q, answer) for q in question])
+         if len(words) > 1:
+             answer = " ".join(words[:2])
+             question = generate_questions(caption, answer)
+             questions_with_answers.extend([(q, answer) for q in question])
+         if len(words) > 2:
+             answer = words[2]
+             question = generate_questions(caption, answer)
+             questions_with_answers.extend([(q, answer) for q in question])
+         if len(words) > 3:
+             answer = words[3]
+             question = generate_questions(caption, answer)
+             questions_with_answers.extend([(q, answer) for q in question])
+
+     formatted_questions = [f"Question: {q}\nAnswer: {a}" for q, a in questions_with_answers]
+     formatted_questions = "\n".join(formatted_questions)
+
+     return "\n".join(corrected_captions), formatted_questions
+
+ gr_interface = gr.Interface(
+     fn=caption_question_interface,
+     inputs=gr.Image(type="pil", label="Input Image"),
+     outputs=[
+         gr.Textbox(label="Generated Captions"),
+         gr.Textbox(label="Generated Questions and Answers")
+     ],
+     title="Image Captioning and Question Generation",
+     description="Generate captions and questions for images using pre-trained models."
+ )
 
+ gr_interface.launch()
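
Note: the committed `dictionary` is a single-entry placeholder ("caption" maps to "alternative_caption"), and its comment suggests the real word-correction map lives in the Space repo. If it ships as a file, it could be loaded at startup instead of being hard-coded. A minimal sketch, assuming a hypothetical `dictionary.json` next to `app.py` (the file name and format are not part of this commit):

```python
import json
from pathlib import Path

# Hypothetical file; this commit only hard-codes a placeholder mapping.
DICT_PATH = Path(__file__).parent / "dictionary.json"

if DICT_PATH.exists():
    # Expected format: {"word_as_generated": "corrected_word", ...}
    with DICT_PATH.open(encoding="utf-8") as f:
        dictionary = json.load(f)
else:
    dictionary = {}  # no corrections available; correct_caption becomes a no-op
```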
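The interface function can also be smoke-tested locally before calling `gr_interface.launch()`. A sketch, assuming a hypothetical `test.jpg`; the `convert("RGB")` call matters because the `Normalize` above uses three-channel statistics, so a grayscale or RGBA image would fail inside `inference_transforms`:

```python
from PIL import Image

# Hypothetical local test image, not part of the commit.
img = Image.open("test.jpg").convert("RGB")  # ensure 3 channels for Normalize

captions_text, questions_text = caption_question_interface(img)
print(captions_text)
print(questions_text)
```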