# PictoVerse / app.py
import random
import re
import string

import gradio as gr
import soundfile as sf
import spaces
import torch
from datasets import load_dataset
from PIL import Image
from transformers import (
    AutoTokenizer,
    SpeechT5ForTextToSpeech,
    SpeechT5HifiGan,
    SpeechT5Processor,
    VisionEncoderDecoderModel,
    ViTImageProcessor,
    pipeline,
)
# --- IMAGE CAPTIONING ---
def load_caption_model():
    caption_model = VisionEncoderDecoderModel.from_pretrained("nlpconnect/vit-gpt2-image-captioning")
    return caption_model
feature_extractor = ViTImageProcessor.from_pretrained("nlpconnect/vit-gpt2-image-captioning")
tokenizer = AutoTokenizer.from_pretrained("nlpconnect/vit-gpt2-image-captioning")
if gr.NO_RELOAD:
    llm_model = load_caption_model()
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    llm_model.to(device)

max_length = 16
num_beams = 4
gen_kwargs = {"max_length": max_length, "num_beams": num_beams}
def predict_step(image_paths):
    images = []
    for image_path in image_paths:
        i_image = Image.open(image_path)
        if i_image.mode != "RGB":
            i_image = i_image.convert(mode="RGB")
        images.append(i_image)
    pixel_values = feature_extractor(images=images, return_tensors="pt").pixel_values
    pixel_values = pixel_values.to(device)
    output_ids = llm_model.generate(pixel_values, **gen_kwargs)
    preds = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
    preds = [pred.strip() for pred in preds]
    return preds
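# Example usage (assuming a local image file "example.jpg" exists):
#   captions = predict_step(["example.jpg"])
#   print(captions[0])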
# --- TEXT TO SPEECH ---
# load the processor
def load_processor():
    processor = SpeechT5Processor.from_pretrained("microsoft/speecht5_tts")
    return processor

# load the model
def load_speech_model():
    speech_model = SpeechT5ForTextToSpeech.from_pretrained("microsoft/speecht5_tts").to(device)
    return speech_model

# load the vocoder, which turns the model's spectrogram output into a waveform
def load_vocoder():
    vocoder = SpeechT5HifiGan.from_pretrained("microsoft/speecht5_hifigan").to(device)
    return vocoder
# load this dataset to get the speaker x-vector embeddings
def load_embeddings_dataset():
    embeddings_dataset = load_dataset("Matthijs/cmu-arctic-xvectors", split="validation")
    return embeddings_dataset
# indices of per-speaker x-vectors in the embeddings dataset
speakers = {
    'awb': 0,     # Scottish male
    'bdl': 1138,  # US male
    'clb': 2271,  # US female
    'jmk': 3403,  # Canadian male
    'ksp': 4535,  # Indian male
    'rms': 5667,  # US male
    'slt': 6799,  # US female
}
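# Each value is a row index into the dataset above, e.g.
# embeddings_dataset[speakers["slt"]]["xvector"] is the x-vector for the US female voice.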
def save_text_to_speech(text, speaker=None):
    # preprocess text
    inputs = processor(text=text, return_tensors="pt").to(device)
    if speaker is not None:
        # load the x-vector containing the speaker's voice characteristics from the dataset
        speaker_embeddings = torch.tensor(embeddings_dataset[speaker]["xvector"]).unsqueeze(0).to(device)
    else:
        # random vector, i.e. a random voice
        speaker_embeddings = torch.randn((1, 512)).to(device)
    # generate speech with the model and vocoder
    speech = speech_model.generate_speech(inputs["input_ids"], speaker_embeddings, vocoder=vocoder)
    if speaker is not None:
        # with a known speaker, use the speaker's ID in the filename
        output_filename = f"{speaker}-{'-'.join(text.split()[:6])}.mp3"
    else:
        # without a speaker, use a random string in the filename
        random_str = ''.join(random.sample(string.ascii_letters + string.digits, k=5))
        output_filename = f"{random_str}-{'-'.join(text.split()[:6])}.mp3"
    # save the generated speech with a 16 kHz sampling rate
    # (MP3 output via soundfile requires libsndfile >= 1.1.0)
    sf.write(output_filename, speech.cpu().numpy(), samplerate=16000)
    # return the filename for reference
    return output_filename
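# Example usage (hypothetical text; writes e.g. "6799-Once-upon-a-time-in-a.mp3",
# since the speaker argument is the x-vector index, not the short name):
#   save_text_to_speech("Once upon a time in a forest", speaker=speakers["slt"])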
def load_text_generator():
    gen = pipeline('text-generation', model='gpt2')  # GPT-2 text generation
    return gen
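# Example (once `gen` is loaded below): gen("A dog in the park")[0]["generated_text"]
# continues the prompt with GPT-2; outputs vary from run to run.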
if gr.NO_RELOAD:
    processor = load_processor()
    speech_model = load_speech_model()
    vocoder = load_vocoder()
    embeddings_dataset = load_embeddings_dataset()
    gen = load_text_generator()
@spaces.GPU()
def gradio_predict(image):
    if image is None:
        return ""
    image_path = "temp_image.jpg"
    image.save(image_path)  # save the uploaded image temporarily
    prediction = predict_step([image_path])
    return prediction[0].capitalize() if prediction else "Prediction failed."
def remove_last_incomplete_sentence(text):
    # find all sentences ending with ., !, or ?
    sentences = re.findall(r'[^.!?]*[.!?]', text, re.DOTALL)
    # if no complete sentence is found, return the original text
    if not sentences:
        return text
    # join the complete sentences and trim whitespace
    cleaned_text = ''.join(sentences).strip()
    return cleaned_text
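# Example: remove_last_incomplete_sentence("A cat. It sleeps! And th")
# returns "A cat. It sleeps!" (the trailing fragment is dropped).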
@spaces.GPU()
def get_story(pred):
    gen_text = gen(pred, max_length=100)[0]
    cleaned_text = remove_last_incomplete_sentence(gen_text['generated_text'])
    output_filename_2 = save_text_to_speech(cleaned_text, speaker=speakers["slt"])
    return cleaned_text, output_filename_2
# --- FRONT END ---
DESCRIPTION = """# PictoVerse
### Dive into the multiverse of storytelling with PictoVerse, where every image unveils an array of parallel dimensions.
PictoVerse crafts captivating narratives from your photos, each set in a distinct universe of its own.
"""
with gr.Blocks() as demo:
    gr.Markdown(DESCRIPTION)
    with gr.Row():
        with gr.Column(scale=1):
            input_image = gr.Image(type='pil', label="Image")
            clear_button = gr.Button("Clear")
        with gr.Column(scale=4):
            output_text = gr.Textbox(label="Prediction")
            gen_text = gr.Textbox(label="Generated Story")
            output_filename_2 = gr.Audio(label='Audio')
            button1 = gr.Button("Generate Story and Audio")
    button1.click(fn=get_story, inputs=output_text, outputs=[gen_text, output_filename_2])
    input_image.change(fn=gradio_predict, inputs=input_image, outputs=output_text)
    clear_button.click(lambda: (None, "", "", None), inputs=[], outputs=[input_image, output_text, gen_text, output_filename_2])

demo.launch()
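# To run locally (assuming the required packages are installed):
#   python app.py
# Gradio prints a local URL to open in the browser.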