# deletesoon / app.py
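# Car damage assessment pipeline:
#   - Mohaddz/DamageSeg (SegFormer) segments visible damage regions,
#   - Mohaddz/huggingCars (SegFormer) segments car parts,
#   - pooled logits from both models feed a Keras model that predicts which parts to replace.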
import gradio as gr
import torch
from PIL import Image
import numpy as np
import tensorflow as tf
from transformers import SegformerForSemanticSegmentation, AutoFeatureExtractor
import cv2
import json
# Load models
part_seg_model = SegformerForSemanticSegmentation.from_pretrained("Mohaddz/huggingCars")
damage_seg_model = SegformerForSemanticSegmentation.from_pretrained("Mohaddz/DamageSeg")
feature_extractor = AutoFeatureExtractor.from_pretrained("Mohaddz/huggingCars")
dl_model = tf.keras.models.load_model('improved_car_damage_prediction_model.h5')
# Load parts list
with open('cars117.json', 'r', encoding='utf-8') as f:
    data = json.load(f)
all_parts = sorted(list(set(part for entry in data.values() for part in entry.get('replaced_parts', []))))
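# Note: the Keras model's output is indexed by this sorted parts list;
# the ordering is assumed to match how the prediction model was trained.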

def process_image(image):
    # Convert to RGB if it's not
    if image.mode != 'RGB':
        image = image.convert('RGB')

    # Prepare input for the model
    inputs = feature_extractor(images=image, return_tensors="pt")

    # Get damage segmentation
    with torch.no_grad():
        damage_output = damage_seg_model(**inputs).logits
        damage_features = damage_output.squeeze().detach().numpy()

    # Create damage segmentation heatmap
    damage_heatmap = create_heatmap(damage_features)
    damage_heatmap_resized = cv2.resize(damage_heatmap, (image.size[0], image.size[1]))

    # Create annotated damage image
    image_array = np.array(image)
    damage_mask = np.argmax(damage_features, axis=0)
    # cv2.resize does not support int64 arrays, so cast the class mask before resizing
    damage_mask_resized = cv2.resize(damage_mask.astype(np.uint8), (image.size[0], image.size[1]),
                                     interpolation=cv2.INTER_NEAREST)
    overlay = np.zeros_like(image_array)
    overlay[damage_mask_resized > 0] = [255, 0, 0]  # Red color for damage
    annotated_image = cv2.addWeighted(image_array, 1, overlay, 0.5, 0)

    # Process for part prediction and heatmap
    with torch.no_grad():
        part_output = part_seg_model(**inputs).logits
        part_features = part_output.squeeze().detach().numpy()
    part_heatmap = create_heatmap(part_features)
    part_heatmap_resized = cv2.resize(part_heatmap, (image.size[0], image.size[1]))

    # Predict parts to replace from the spatially pooled logits of both segmentation models
    input_vector = np.concatenate([part_features.mean(axis=(1, 2)), damage_features.mean(axis=(1, 2))])
    prediction = dl_model.predict(np.array([input_vector]))
    predicted_parts = [(all_parts[i], float(prob)) for i, prob in enumerate(prediction[0]) if prob > 0.1]
    predicted_parts.sort(key=lambda x: x[1], reverse=True)

    return (Image.fromarray(annotated_image),
            Image.fromarray(damage_heatmap_resized),
            Image.fromarray(part_heatmap_resized),
            "\n".join([f"{part}: {prob:.2f}" for part, prob in predicted_parts[:5]]))

def create_heatmap(features):
    heatmap = np.sum(features, axis=0)
    # Normalize to [0, 1]; the small epsilon guards against division by zero on flat maps
    heatmap = (heatmap - heatmap.min()) / (heatmap.max() - heatmap.min() + 1e-8)
    heatmap = np.uint8(255 * heatmap)
    # applyColorMap returns BGR; convert to RGB so PIL renders the colors correctly
    colored = cv2.applyColorMap(heatmap, cv2.COLORMAP_JET)
    return cv2.cvtColor(colored, cv2.COLOR_BGR2RGB)

iface = gr.Interface(
    fn=process_image,
    inputs=gr.Image(type="pil"),
    outputs=[
        gr.Image(type="pil", label="Annotated Damage"),
        gr.Image(type="pil", label="Damage Heatmap"),
        gr.Image(type="pil", label="Part Segmentation Heatmap"),
        gr.Textbox(label="Predicted Parts to Replace")
    ],
    title="Car Damage Assessment",
    description="Upload an image of a damaged car to get an assessment."
)
iface.launch(share=True)