# Hugging Face Spaces page residue — Space status: Runtime error
import streamlit as st
import torch
from peft import PeftModel, PeftConfig
from PIL import Image
from PIL import UnidentifiedImageError
from transformers import Blip2ForConditionalGeneration
from transformers import Blip2Processor
# Checkpoint whose processor is used to preprocess uploaded images.
preprocess_ckp: str = "Salesforce/blip2-opt-2.7b"
# Path of the base BLIP-2 model checkpoint.
base_model_ckp: str = "/model/blip2-opt-2.7b-fp16-sharded"
# Path of the PEFT checkpoint applied on top of the base model.
peft_model_ckp: str = "/model/blip2_peft"

# Module-level model state that init_model() is meant to populate.
# The flag is True until the processor/model have been loaded.
init_model_required: bool = True
processor: "Blip2Processor | None" = None
model: "PeftModel | None" = None
# Streamlit re-executes the whole script on every user interaction, which
# resets the module-level state, so a plain global flag cannot prevent the
# model from being reloaded. st.cache_resource (Streamlit >= 1.18) keeps the
# loaded objects alive across reruns; older releases fall back to no caching.
_cache_resource = getattr(st, "cache_resource", lambda func: func)


@_cache_resource
def _load_model():
    """Load the image processor and the PEFT-wrapped BLIP-2 model.

    Returns:
        A ``(processor, model)`` tuple.
    """
    # Processor used to turn uploaded images into model inputs.
    loaded_processor = Blip2Processor.from_pretrained(preprocess_ckp)
    # Base model loaded in 8-bit; device_map="auto" delegates weight placement
    # to transformers.
    base_model = Blip2ForConditionalGeneration.from_pretrained(
        base_model_ckp, load_in_8bit=True, device_map="auto"
    )
    # Wrap the base model with the PEFT checkpoint.
    peft_model = PeftModel.from_pretrained(base_model, peft_model_ckp)
    return loaded_processor, peft_model


def init_model():
    """Populate the module-level ``processor`` and ``model`` on first use.

    Returns:
        The ``(processor, model)`` tuple currently stored at module level.
    """
    # Required: without it, the assignments below would make these names local
    # to the function and reading init_model_required would raise
    # UnboundLocalError.
    global processor, model, init_model_required
    if init_model_required:
        processor, model = _load_model()
        init_model_required = False
    return processor, model
def main():
    """Render the page: upload an image and display a generated caption.

    Uses the module-level ``processor`` and ``model`` that ``init_model()``
    is responsible for populating.
    """
    st.title("Fashion Image Caption using BLIP2")
    init_model()

    uploaded_file = st.file_uploader("Upload image")
    if uploaded_file is None:
        return

    # file_uploader accepts any file type, so a non-image upload would
    # otherwise crash the app with a PIL traceback.
    try:
        image = Image.open(uploaded_file)
    except UnidentifiedImageError:
        st.error("The uploaded file is not a recognizable image.")
        return

    image_col, caption_col = st.columns(2)
    image_col.header("Image")
    image_col.image(image, use_column_width=True)

    # Preprocess the image and move the tensors to the GPU in float16.
    inputs = processor(images=image, return_tensors="pt").to("cuda", torch.float16)
    pixel_values = inputs.pixel_values

    # Predict the caption for the image.
    generated_ids = model.generate(pixel_values=pixel_values, max_length=25)
    generated_caption = processor.batch_decode(
        generated_ids, skip_special_tokens=True
    )[0]

    # Output the predicted text.
    caption_col.header("Generated Caption")
    caption_col.text(generated_caption)


if __name__ == "__main__":
    main()