|
import gradio as gr |
|
import spaces |
|
import torch |
|
from PIL import Image |
|
|
|
|
|
# Seed PyTorch's CPU RNG, and every CUDA RNG when CUDA is available, with a
# fixed value at import time.
_SEED = 0
torch.manual_seed(_SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(_SEED)
|
|
|
from models.vision_language_model import VisionLanguageModel |
|
from data.processors import get_tokenizer, get_image_processor |
|
|
|
|
|
@spaces.GPU
def generate_outputs(image, query):
    """Answer ``query`` about ``image`` four times with the nanoVLM checkpoint.

    Failures are reported through the first element of the returned tuple
    rather than raised, so the UI can display them in its error textbox.

    Args:
        image: PIL image coming from the ``gr.Image(type="pil")`` input, or
            ``None`` when no image has been supplied.
        query: Question text inserted into the ``"Question: ... Answer:"``
            prompt template.

    Returns:
        A 5-tuple ``(error, out1, out2, out3, out4)``. On success ``error``
        is ``None`` and the remaining items are the decoded generations; on
        failure ``error`` is a human-readable message and the four outputs
        are ``None``.
    """
    # Must stay equal to the number of generation textboxes wired in main().
    num_generations = 4
    max_new_tokens = 50

    def failure(message):
        # Error message first, then one empty slot per generation textbox.
        return (message,) + (None,) * num_generations

    # Fail fast: without an image there is no point in loading the model,
    # and the user gets a clear message instead of an AttributeError text.
    if image is None:
        return failure("Error processing image: no image was provided")

    # Pick the best available backend: CUDA, then Apple MPS, then CPU.
    if torch.cuda.is_available():
        device = torch.device("cuda")
    elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
        device = torch.device("mps")
    else:
        device = torch.device("cpu")

    # The broad ``except Exception`` handlers below are deliberate: this
    # function is the UI boundary, and every failure is turned into a
    # message for the error textbox.
    hf_model = "lusxvr/nanoVLM-222M"
    try:
        model = VisionLanguageModel.from_pretrained(hf_model).to(device)
        model.eval()
    except Exception as e:
        return failure(f"Error loading model: {e}")

    try:
        tokenizer = get_tokenizer(model.cfg.lm_tokenizer)
        image_processor = get_image_processor(model.cfg.vit_img_size)
    except Exception as e:
        return failure(f"Error loading tokenizer or image processor: {e}")

    # Tokenization is guarded like every other step so that a tokenizer
    # failure also ends up in the error textbox.
    try:
        template = f"Question: {query} Answer:"
        encoded = tokenizer.batch_encode_plus([template], return_tensors="pt")
        tokens = encoded["input_ids"].to(device)
    except Exception as e:
        return failure(f"Error tokenizing query: {e}")

    try:
        img = image.convert("RGB")
        # unsqueeze(0) adds a leading dimension before the tensor is moved
        # to the selected device.
        img_t = image_processor(img).unsqueeze(0).to(device)
    except Exception as e:
        return failure(f"Error processing image: {e}")

    outputs = []
    try:
        for _ in range(num_generations):
            gen = model.generate(tokens, img_t, max_new_tokens=max_new_tokens)
            outputs.append(
                tokenizer.batch_decode(gen, skip_special_tokens=True)[0]
            )
    except Exception as e:
        return failure(f"Error during generation: {e}")

    return (None, *outputs)
|
|
|
|
|
def main():
    """Build the nanoVLM Gradio interface and launch it.

    The page has an image input, a query textbox and a "Generate" button in
    one column, and an error textbox plus four generation textboxes in the
    other. Clicking the button calls ``generate_outputs`` and routes its five
    return values to the error box and the four generation boxes, in order.
    """
    page_css = """
    .gradio-container {
        font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, sans-serif;
        padding: 20px;
    }
    h1 {
        color: #333;
        text-align: center;
        margin-bottom: 20px;
    }
    .description {
        margin-bottom: 20px;
        line-height: 1.6;
    }
    .gr-button {
        padding: 10px 20px;
    }
    """

    with gr.Blocks(title="nanoVLM Image-to-Text Generator", css=page_css) as demo:
        gr.Markdown("# nanoVLM Image-to-Text Generator")
        gr.Markdown(
            """
            <div class="description">
            This demo showcases <b>nanoVLM</b>, a lightweight vision-language model by HuggingFace.
            Upload an image and provide a query to generate four text descriptions.
            The model is based on the <a href="https://github.com/huggingface/nanoVLM/" target="_blank">nanoVLM repository</a>
            and uses the pretrained model <a href="https://huggingface.co/lusxvr/nanoVLM-222M" target="_blank">lusxvr/nanoVLM-222M</a>.
            nanoVLM is designed for efficient image-to-text generation, ideal for resource-constrained environments.
            </div>
            """
        )

        with gr.Row():
            # Left column: user inputs and the trigger button.
            with gr.Column():
                image_box = gr.Image(
                    type="pil",
                    label="Upload Image",
                    value="cat.jpg",
                )
                query_box = gr.Textbox(
                    label="Query",
                    value="What is this?",
                    placeholder="Enter your query here",
                    lines=2,
                )
                run_button = gr.Button("Generate")

            # Right column: error message followed by the four generations.
            with gr.Column():
                error_box = gr.Textbox(
                    label="Errors (if any)",
                    placeholder="No errors",
                    visible=True,
                    interactive=False,
                )
                generation_boxes = [
                    gr.Textbox(
                        label=f"Generation {slot}",
                        placeholder=f"Output {slot} will appear here...",
                        lines=3,
                    )
                    for slot in range(1, 5)
                ]

        # Output order mirrors the tuple returned by generate_outputs.
        run_button.click(
            fn=generate_outputs,
            inputs=[image_box, query_box],
            outputs=[error_box, *generation_boxes],
        )

    demo.launch()
|
|
|
|
|
# Build and launch the demo only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()