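"""Gradio demo for nanoVLM.

Upload an image and enter a query; the app generates four candidate answers
with the pretrained lusxvr/nanoVLM-222M checkpoint.
"""
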
import gradio as gr
import spaces
import torch
from PIL import Image

from data.processors import get_tokenizer, get_image_processor
from models.vision_language_model import VisionLanguageModel

# Set random seeds for reproducibility
torch.manual_seed(0)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(0)

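# On Hugging Face Spaces with ZeroGPU, the @spaces.GPU decorator requests a GPU
# for the duration of each decorated call; outside Spaces it has no effect.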
@spaces.GPU
def generate_outputs(image: Image.Image, query: str):
    # Pick the best available device: CUDA, then Apple MPS, then CPU
    if torch.cuda.is_available():
        device = torch.device("cuda")
    elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
        device = torch.device("mps")
    else:
        device = torch.device("cpu")

    # Load model
    hf_model = "lusxvr/nanoVLM-222M"
    try:
        model = VisionLanguageModel.from_pretrained(hf_model).to(device)
        model.eval()
    except Exception as e:
        return f"Error loading model: {str(e)}", None, None, None, None

    # Load tokenizer and image processor
    try:
        tokenizer = get_tokenizer(model.cfg.lm_tokenizer)
        image_processor = get_image_processor(model.cfg.vit_img_size)
    except Exception as e:
        return f"Error loading tokenizer or image processor: {str(e)}", None, None, None, None
    # Prepare text input
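    # The "Question: ... Answer:" template follows the prompt format used by
    # the nanoVLM reference generation script; other phrasings are untested here.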
template = f"Question: {query} Answer:"
encoded = tokenizer.batch_encode_plus([template], return_tensors="pt")
tokens = encoded["input_ids"].to(device)
    # Process image: ensure RGB, preprocess, and add a batch dimension
    try:
        img = image.convert("RGB")
        img_t = image_processor(img).unsqueeze(0).to(device)
    except Exception as e:
        return f"Error processing image: {str(e)}", None, None, None, None

    # Generate four candidate outputs
    outputs = []
    max_new_tokens = 50  # fixed generation length from the reference script
    try:
        with torch.no_grad():  # inference only; no gradients needed
            for _ in range(4):
                gen = model.generate(tokens, img_t, max_new_tokens=max_new_tokens)
                out = tokenizer.batch_decode(gen, skip_special_tokens=True)[0]
                outputs.append(out)
    except Exception as e:
        return f"Error during generation: {str(e)}", None, None, None, None

    # First return slot is the error message (None on success), then the four generations
    return None, outputs[0], outputs[1], outputs[2], outputs[3]
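

# Example (local sanity check; assumes a cat.jpg next to this script):
#   err, out1, out2, out3, out4 = generate_outputs(Image.open("cat.jpg"), "What is this?")
#   print(err or out1)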


def main():
    # Define minimal CSS for subtle aesthetic enhancements
    css = """
    .gradio-container {
        font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, sans-serif;
        padding: 20px;
    }
    h1 {
        color: #333;
        text-align: center;
        margin-bottom: 20px;
    }
    .description {
        margin-bottom: 20px;
        line-height: 1.6;
    }
    .gr-button {
        padding: 10px 20px;
    }
    """

    # Define the Gradio interface
    with gr.Blocks(css=css, title="nanoVLM Image-to-Text Generator") as app:
        gr.Markdown("# nanoVLM Image-to-Text Generator")
        gr.Markdown(
            """
<div class="description">
This demo showcases <b>nanoVLM</b>, a lightweight vision-language model from Hugging Face.
Upload an image and provide a query to generate four text descriptions.
The model is based on the <a href="https://github.com/huggingface/nanoVLM/" target="_blank">nanoVLM repository</a>
and uses the pretrained checkpoint <a href="https://huggingface.co/lusxvr/nanoVLM-222M" target="_blank">lusxvr/nanoVLM-222M</a>.
nanoVLM is designed for efficient image-to-text generation in resource-constrained environments.
</div>
"""
        )
        with gr.Row():
            with gr.Column():
                image_input = gr.Image(
                    type="pil",
                    label="Upload Image",
                    value="cat.jpg"  # example image (expected at the repo root)
                )
                query_input = gr.Textbox(
                    label="Query",
                    value="What is this?",
                    placeholder="Enter your query here",
                    lines=2
                )
                submit_button = gr.Button("Generate")
            with gr.Column():
                error_output = gr.Textbox(
                    label="Errors (if any)",
                    placeholder="No errors",
                    visible=True,
                    interactive=False
                )
                output1 = gr.Textbox(
                    label="Generation 1",
                    placeholder="Output 1 will appear here...",
                    lines=3
                )
                output2 = gr.Textbox(
                    label="Generation 2",
                    placeholder="Output 2 will appear here...",
                    lines=3
                )
                output3 = gr.Textbox(
                    label="Generation 3",
                    placeholder="Output 3 will appear here...",
                    lines=3
                )
                output4 = gr.Textbox(
                    label="Generation 4",
                    placeholder="Output 4 will appear here...",
                    lines=3
                )

        # Define action on submit
        submit_button.click(
            fn=generate_outputs,
            inputs=[image_input, query_input],
            outputs=[error_output, output1, output2, output3, output4]
        )

    # Launch the app
    app.launch()


if __name__ == "__main__":
    main()