|
from flask import Flask, request, jsonify
|
|
from flask_cors import CORS
|
|
from ultralytics import YOLO
|
|
import numpy as np
|
|
import cv2
|
|
import io
|
|
import os
|
|
from dotenv import load_dotenv
|
|
from azure.storage.blob import BlobServiceClient
|
|
|
|
|
|
load_dotenv()

app = Flask(__name__)

# ORIGINS may hold one origin or several separated by commas
# (e.g. "https://a.example,https://b.example"). When it is unset or empty the
# raw value is passed through unchanged, exactly as before.
_origins_env = os.getenv("ORIGINS")
allowed_origins = (
    [origin.strip() for origin in _origins_env.split(",") if origin.strip()]
    if _origins_env
    else _origins_env
)
CORS(app, resources={r"/*": {"origins": allowed_origins}})

connect_str = os.getenv("AZURE_STORAGE_CONNECTION_STRING")
if not connect_str:
    # Fail fast with an actionable message instead of an opaque error from
    # deep inside the Azure SDK.
    raise RuntimeError(
        "AZURE_STORAGE_CONNECTION_STRING is not set; cannot connect to Azure Blob Storage."
    )
blob_service_client = BlobServiceClient.from_connection_string(connect_str)

# Location of the model weights inside Blob Storage; validated in load_model().
container_name = os.getenv("AZURE_BLOB_CONTAINER_NAME")
blob_name = os.getenv("AZURE_BLOB_MODEL_PATH")

# Set by load_model() on success; stays None if loading fails.
model = None
|
|
|
|
def load_model():
    """Download the YOLO weights from Azure Blob Storage and load them.

    Fetches blob ``blob_name`` (env ``AZURE_BLOB_MODEL_PATH``) from container
    ``container_name`` (env ``AZURE_BLOB_CONTAINER_NAME``) and stores the
    loaded model in the module-level ``model`` global used by ``/predict``.
    Errors are printed rather than raised so the app can still start. In that
    case ``model`` is left unchanged (``None`` if nothing was loaded before).

    Returns:
        bool: True if the model was loaded, False otherwise.
    """
    global model

    import contextlib
    import tempfile

    if not container_name or not blob_name:
        print(
            "Error loading model: AZURE_BLOB_CONTAINER_NAME and "
            "AZURE_BLOB_MODEL_PATH must both be set."
        )
        print(f"Container Name: {container_name}")
        print(f"Blob Name: {blob_name}")
        return False

    weights_path = None
    try:
        container_client = blob_service_client.get_container_client(container_name)
        blob_client = container_client.get_blob_client(blob_name)

        # YOLO() takes a filesystem path (or model name), not a file-like
        # object, so the weights are staged in a temporary file. The blob's
        # extension is kept because ultralytics uses the file suffix to
        # decide how to load it (.pt vs. exported formats such as .onnx).
        suffix = os.path.splitext(blob_name)[1] or ".pt"
        with tempfile.NamedTemporaryFile(suffix=suffix, delete=False) as tmp:
            weights_path = tmp.name
            # Stream straight to disk instead of buffering the whole blob.
            blob_client.download_blob().readinto(tmp)

        model = YOLO(weights_path)
    except Exception as e:
        # Best-effort startup: report and carry on without a model.
        print(f"Error loading model from Azure Blob Storage: {e}")
        print(f"Container Name: {container_name}")
        print(f"Blob Name: {blob_name}")
        if weights_path is not None:
            # Don't leave a partial or unusable download behind.
            with contextlib.suppress(OSError):
                os.remove(weights_path)
        return False

    # The staged file is deliberately kept. NOTE(review): non-.pt formats may
    # be read lazily on first prediction, so deleting it here could break
    # them. Confirm for the formats in use before adding cleanup.
    print("YOLO model loaded successfully from Azure Blob Storage.")
    return True
|
|
|
|
|
|
# Load the model eagerly, once per process at import time, inside a Flask
# application context, so the weights are ready before the first /predict
# request is served.
with app.app_context():
    print("💮 Loading YOLO model...")
    load_model()
|
|
|
|
@app.route("/")
def home():
    """Landing endpoint.

    Returns a JSON welcome message with the host and full URL of the
    incoming request.
    """
    payload = {
        "message": "Welcome to the YOLOv5 API!",
        "host": request.host,
        "url": request.url,
    }
    return jsonify(payload)
|
|
|
|
def _decode_image(data):
    """Decode raw upload bytes into an OpenCV BGR image array.

    Returns None when ``data`` is empty or is not an image OpenCV can decode.
    """
    if not data:
        # cv2.imdecode raises on an empty buffer instead of returning None.
        return None
    return cv2.imdecode(np.frombuffer(data, np.uint8), cv2.IMREAD_COLOR)


def _group_detections(results, names):
    """Flatten YOLO results into ``{label: [detection, ...]}``.

    ``names`` maps class ids to class names (``model.names``). Labels are
    lower-cased with spaces replaced by underscores. Each detection is a dict
    with an integer ``bbox`` ``[x1, y1, x2, y2]``, the ``class_id`` and the
    ``confidence`` rounded to 2 decimals.
    """
    grouped = {}
    for result in results:
        for box in result.boxes:
            class_id = int(box.cls[0].item())
            x1, y1, x2, y2 = map(int, box.xyxy[0])
            label = names[class_id].replace(" ", "_").lower()
            grouped.setdefault(label, []).append(
                {
                    "bbox": [x1, y1, x2, y2],
                    "class_id": class_id,
                    "confidence": round(box.conf[0].item(), 2),
                }
            )
    return grouped


@app.route("/predict", methods=["POST"])
def predict():
    """Receive an image from the frontend, run YOLO model, and return detections.

    Expects a multipart POST with the image in the ``file`` field.

    Returns:
        200 with ``{"prediction": {label: [detection, ...]}}`` on success;
        400 if no file was uploaded or it cannot be decoded as an image;
        503 if the model failed to load at startup;
        500 if inference itself raises.
    """
    if "file" not in request.files:
        return jsonify({"error": "No file uploaded"}), 400

    if model is None:
        # load_model() failed; without this guard the request would fail with
        # an opaque "'NoneType' object has no attribute 'predict'".
        return jsonify({"error": "Model is not loaded"}), 503

    image_array = _decode_image(request.files["file"].read())
    if image_array is None:
        # Must reject here: ultralytics treats a None source as "use the
        # bundled sample images" and would return unrelated detections.
        return jsonify({"error": "Uploaded file is not a valid image"}), 400

    try:
        results = model.predict(image_array)
        predictions = _group_detections(results, model.names)
    except Exception as e:
        return jsonify({"error": str(e)}), 500

    return jsonify({"prediction": predictions})
|
|
|
|
if __name__ == "__main__":
    if os.getenv("ENVIRONMENT") == "production":
        import sys

        from gunicorn.app.wsgiapp import run

        # gunicorn's CLI entry point takes APP_MODULE from sys.argv, so a bare
        # `python <this file>` would abort with "No application module
        # specified". When no arguments are given, serve this already-imported
        # module's `app`. This avoids re-importing the file, which would
        # download the model again. Explicit arguments still win.
        if len(sys.argv) == 1:
            sys.argv.append("__main__:app")
        run()
    else:
        # Development server only (debug mode enables the interactive debugger).
        app.run(debug=True, host="0.0.0.0", port=int(os.getenv("PORT", "5000")))