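"""PictoVerse: a Gradio app that captions an uploaded image, expands the
caption into a short story with GPT-2, and narrates the story with SpeechT5.

Pipeline: ViT-GPT2 image captioning -> GPT-2 text generation -> SpeechT5 TTS.
"""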
import random
import re
import string

import gradio as gr
import soundfile as sf
import spaces
import torch
from datasets import load_dataset
from PIL import Image
from transformers import (
    AutoTokenizer,
    SpeechT5ForTextToSpeech,
    SpeechT5HifiGan,
    SpeechT5Processor,
    VisionEncoderDecoderModel,
    ViTImageProcessor,
    pipeline,
)
# --- IMAGE CAPTIONING ---
def load_caption_model():
    caption_model = VisionEncoderDecoderModel.from_pretrained("nlpconnect/vit-gpt2-image-captioning")
    return caption_model

feature_extractor = ViTImageProcessor.from_pretrained("nlpconnect/vit-gpt2-image-captioning")
tokenizer = AutoTokenizer.from_pretrained("nlpconnect/vit-gpt2-image-captioning")

# Defined outside gr.NO_RELOAD so it still exists in `gradio app.py` reload
# mode, where NO_RELOAD blocks are skipped.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

if gr.NO_RELOAD:
    llm_model = load_caption_model()
    llm_model.to(device)

# Beam-search decoding: captions capped at 16 tokens, searched over 4 beams.
max_length = 16
num_beams = 4
gen_kwargs = {"max_length": max_length, "num_beams": num_beams}
@spaces.GPU  # requests a GPU on ZeroGPU Spaces; a no-op on other hardware
def predict_step(image_paths):
    images = []
    for image_path in image_paths:
        i_image = Image.open(image_path)
        if i_image.mode != "RGB":
            i_image = i_image.convert(mode="RGB")
        images.append(i_image)
    pixel_values = feature_extractor(images=images, return_tensors="pt").pixel_values
    pixel_values = pixel_values.to(device)
    output_ids = llm_model.generate(pixel_values, **gen_kwargs)
    preds = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
    preds = [pred.strip() for pred in preds]
    return preds
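
# Example usage (hypothetical file path; the caption shown is illustrative):
#   predict_step(["photo.jpg"])  # -> ["a dog sitting on a couch"]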
# --- TEXT TO SPEECH ---
# load the processor
def load_processor():
    processor = SpeechT5Processor.from_pretrained("microsoft/speecht5_tts")
    return processor

# load the model
def load_speech_model():
    speech_model = SpeechT5ForTextToSpeech.from_pretrained("microsoft/speecht5_tts").to(device)
    return speech_model

# load the vocoder, i.e. the component that turns the model's spectrogram into audio
def load_vocoder():
    vocoder = SpeechT5HifiGan.from_pretrained("microsoft/speecht5_hifigan").to(device)
    return vocoder

# we load this dataset to get the speaker embeddings
def load_embeddings_dataset():
    embeddings_dataset = load_dataset("Matthijs/cmu-arctic-xvectors", split="validation")
    return embeddings_dataset
# speaker ids from the embeddings dataset
speakers = {
    'awb': 0,     # Scottish male
    'bdl': 1138,  # US male
    'clb': 2271,  # US female
    'jmk': 3403,  # Canadian male
    'ksp': 4535,  # Indian male
    'rms': 5667,  # US male
    'slt': 6799,  # US female
}
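# Each value is a row index into the validation split above: the dataset holds
# one x-vector per utterance, grouped by CMU ARCTIC speaker, so these offsets
# select one representative embedding per voice.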
def save_text_to_speech(text, speaker=None):
    # preprocess text
    inputs = processor(text=text, return_tensors="pt").to(device)
    if speaker is not None:
        # load the x-vector containing the speaker's voice characteristics from the dataset
        speaker_embeddings = torch.tensor(embeddings_dataset[speaker]["xvector"]).unsqueeze(0).to(device)
    else:
        # random vector, meaning a random voice
        speaker_embeddings = torch.randn((1, 512)).to(device)
    # generate speech with the models
    speech = speech_model.generate_speech(inputs["input_ids"], speaker_embeddings, vocoder=vocoder)
    if speaker is not None:
        # if we have a speaker, we use the speaker's ID in the filename
        output_filename = f"{speaker}-{'-'.join(text.split()[:6])}.wav"
    else:
        # if we don't have a speaker, we use a random string in the filename
        random_str = ''.join(random.sample(string.ascii_letters + string.digits, k=5))
        output_filename = f"{random_str}-{'-'.join(text.split()[:6])}.wav"
    # save the generated speech as a 16 kHz WAV file (soundfile cannot write
    # MP3 on all libsndfile builds, so WAV is the safer container)
    sf.write(output_filename, speech.cpu().numpy(), samplerate=16000)
    # return the filename for reference
    return output_filename
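
# Example usage ("slt" is the US female voice; the speaker ID lands in the filename):
#   save_text_to_speech("Hello there", speaker=speakers["slt"])  # -> "6799-Hello-there.wav"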
def load_text_generator():
    gen = pipeline('text-generation', model='gpt2')  # uses GPT-2
    return gen

if gr.NO_RELOAD:
    # Load all heavy models once at startup; gr.NO_RELOAD keeps this block
    # from re-running on every hot reload in `gradio app.py` dev mode.
    processor = load_processor()
    speech_model = load_speech_model()
    vocoder = load_vocoder()
    embeddings_dataset = load_embeddings_dataset()
    gen = load_text_generator()
def gradio_predict(image):
    if image is None:
        return ""
    image_path = "temp_image.jpg"
    # JPEG cannot store an alpha channel, so convert (e.g. RGBA PNG uploads) first
    image.convert("RGB").save(image_path)  # save the uploaded image temporarily
    prediction = predict_step([image_path])
    return prediction[0].capitalize() if prediction else "Prediction failed."
def remove_last_incomplete_sentence(text):
    # Find all sentences ending with ., !, or ?
    sentences = re.findall(r'[^.!?]*[.!?]', text, re.DOTALL)
    # If there's no complete sentence found, return the original text
    if not sentences:
        return text
    # Join the complete sentences
    cleaned_text = ''.join(sentences).strip()
    return cleaned_text
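
# Example: GPT-2 often runs out of tokens mid-sentence, so
#   remove_last_incomplete_sentence("A tale began. It grew and")  # -> "A tale began."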
@spaces.GPU  # requests a GPU on ZeroGPU Spaces; a no-op on other hardware
def get_story(pred):
    if not pred:
        return "", None
    gen_text = gen(pred, max_length=100)[0]
    cleaned_text = remove_last_incomplete_sentence(gen_text['generated_text'])
    output_filename_2 = save_text_to_speech(cleaned_text, speaker=speakers["slt"])
    return cleaned_text, output_filename_2
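
# get_story drives the "Generate Story and Audio" button below: the image
# caption seeds GPT-2, and the trimmed story is narrated with the "slt" voice.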
# --- FRONT END ---
DESCRIPTION = """
# PictoVerse
### Dive into the multiverse of storytelling with PictoVerse, where every image unveils an array of parallel dimensions.
PictoVerse crafts captivating narratives from your photos, each set in a distinct universe of its own.
"""
with gr.Blocks() as demo:
    gr.Markdown(DESCRIPTION)
    with gr.Row():
        with gr.Column(scale=1):
            input_image = gr.Image(type='pil', label="Image")
            clear_button = gr.Button("Clear")
        with gr.Column(scale=4):
            output_text = gr.Textbox(label="Prediction")
            gen_text = gr.Textbox(label="Generated Story")
            output_filename_2 = gr.Audio(label='Audio')
            button1 = gr.Button("Generate Story and Audio")
    button1.click(fn=get_story, inputs=output_text, outputs=[gen_text, output_filename_2])
    input_image.change(fn=gradio_predict, inputs=input_image, outputs=output_text)
    clear_button.click(lambda: (None, "", "", None), inputs=[], outputs=[input_image, output_text, gen_text, output_filename_2])

demo.launch()