File size: 3,646 Bytes
311e1d5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
from flask import Flask, request, jsonify
from flask_cors import CORS
from ultralytics import YOLO
import numpy as np
import cv2
import io
import os
from dotenv import load_dotenv
from azure.storage.blob import BlobServiceClient

# Load environment variables from a .env file (if one exists) before reading
# any configuration below.
load_dotenv()


def _parse_origins(raw):
    """Convert the ORIGINS setting into a value accepted by flask_cors.

    Returns ``None`` when the variable is unset (the same value that was
    always passed through). Otherwise splits on commas so several origins
    can be configured, e.g. ``"https://a.example, https://b.example"``; a
    single origin becomes a one-element list.
    """
    if raw is None:
        return None
    return [origin.strip() for origin in raw.split(",") if origin.strip()]


# Initialize Flask app and allow cross-origin requests from the configured origins.
app = Flask(__name__)
CORS(app, resources={r"/*": {"origins": _parse_origins(os.getenv("ORIGINS"))}})

# Azure Blob Storage location of the YOLO weights.
connect_str = os.getenv("AZURE_STORAGE_CONNECTION_STRING")
if not connect_str:
    # Fail fast with an actionable message instead of a less obvious error
    # raised from inside the Azure SDK when the connection string is missing.
    raise RuntimeError(
        "AZURE_STORAGE_CONNECTION_STRING environment variable is not set"
    )
blob_service_client = BlobServiceClient.from_connection_string(connect_str)
container_name = os.getenv("AZURE_BLOB_CONTAINER_NAME")
blob_name = os.getenv("AZURE_BLOB_MODEL_PATH")  # Path to the model file within the container

# Populated by load_model(); stays None if loading fails.
model = None

def load_model():
    """Download the YOLO weights from Azure Blob Storage and load them.

    Reads the blob ``blob_name`` from container ``container_name`` through the
    module-level ``blob_service_client``.

    Returns:
        ``True`` when the module-level ``model`` was replaced by the newly
        loaded YOLO model; ``False`` when anything failed. On failure the
        error and the container/blob names are printed, and ``model`` is left
        unchanged (``None`` if nothing was loaded before).
    """
    global model
    import tempfile  # only needed here, to stage the downloaded weights

    try:
        container_client = blob_service_client.get_container_client(container_name)
        blob_client = container_client.get_blob_client(blob_name)
        model_data = blob_client.download_blob().readall()

        # ultralytics' YOLO() expects a file path (str/Path), not a file-like
        # object, so write the weights to a temporary file first. Keep the
        # blob's extension because YOLO uses it to pick the model format.
        suffix = os.path.splitext(blob_name)[1] or ".pt"
        with tempfile.NamedTemporaryFile(suffix=suffix, delete=False) as tmp:
            tmp.write(model_data)
            weights_path = tmp.name
        # The temp file is deliberately not deleted. NOTE(review): for some
        # exported (non-.pt) formats ultralytics may open the path lazily at
        # predict time; confirm before adding cleanup.
        model = YOLO(weights_path)
        print("YOLO model loaded successfully from Azure Blob Storage.")
        return True
    except Exception as e:
        # Startup boundary: report and let the caller decide; predict() must
        # then cope with `model` still being None.
        print(f"Error loading model from Azure Blob Storage: {e}")
        print(f"Container Name: {container_name}")
        print(f"Blob Name: {blob_name}")
        return False


# Load the model eagerly when this module is imported, before any request is
# served (Flask 2.3 removed the `before_first_request` hook). The application
# context is pushed so load_model() can use Flask helpers such as jsonify()
# outside of a request.
with app.app_context():
    print("💮 Loading YOLO model...")
    load_model()

@app.route("/")
def home():
    """Greet the caller and echo the host and full URL of the request."""
    payload = {
        "message": "Welcome to the YOLOv5 API!",
        "host": request.host,
        "url": request.url,
    }
    return jsonify(payload)

def _decode_image(data):
    """Decode raw upload bytes into an OpenCV colour image, or return None.

    ``cv2.imdecode`` signals an undecodable buffer by returning ``None`` and
    raises ``cv2.error`` for an empty buffer; both cases map to ``None`` here.
    """
    if not data:
        return None
    try:
        return cv2.imdecode(np.frombuffer(data, np.uint8), cv2.IMREAD_COLOR)
    except cv2.error:
        return None


def _group_detections(results, names):
    """Group YOLO boxes by normalised class label.

    Labels come from ``names[class_id]``, lower-cased with spaces replaced by
    underscores (e.g. "Root Piece" -> "root_piece"). Each label maps to a
    list of ``{"bbox": [x1, y1, x2, y2], "class_id": int, "confidence":
    float}`` dicts; coordinates are truncated to int and confidence is
    rounded to 2 decimals.
    """
    predictions = {}
    for result in results:
        for box in result.boxes:
            class_id = int(box.cls[0].item())
            x1, y1, x2, y2 = map(int, box.xyxy[0])
            label = names[class_id].replace(" ", "_").lower()
            predictions.setdefault(label, []).append({
                "bbox": [x1, y1, x2, y2],
                "class_id": class_id,
                "confidence": round(box.conf[0].item(), 2),
            })
    return predictions


@app.route("/predict", methods=["POST"])
def predict():
    """Receive an image from the frontend, run YOLO model, and return detections.

    Expects a multipart upload in the form field ``file``. On success responds
    with ``{"prediction": {label: [detection, ...]}}``.

    Error responses (JSON ``{"error": ...}``):
      * 400 -- no ``file`` field, or the upload is not a decodable image;
      * 503 -- the model failed to load at startup;
      * 500 -- inference raised an exception.
    """
    if "file" not in request.files:
        return jsonify({"error": "No file uploaded"}), 400
    if model is None:
        return jsonify({"error": "Model is not loaded"}), 503

    image_array = _decode_image(request.files["file"].read())
    if image_array is None:
        # Never forward None to model.predict(): ultralytics does not reject
        # it. NOTE(review): it appears to fall back to bundled demo images.
        return jsonify({"error": "Uploaded file is not a valid image"}), 400

    try:
        results = model.predict(image_array)
        predictions = _group_detections(results, model.names)
    except Exception as e:
        return jsonify({"error": str(e)}), 500

    return jsonify({"prediction": predictions})

if __name__ == "__main__":
    if os.getenv("ENVIRONMENT") == "production":
        import sys

        if len(sys.argv) > 1:
            # Explicit gunicorn command-line arguments were supplied
            # (e.g. "app:app -w 4"); let gunicorn's own CLI parse them.
            from gunicorn.app.wsgiapp import run

            run()
        else:
            # gunicorn.app.wsgiapp.run() reads APP_MODULE from sys.argv and,
            # unless a config file sets wsgi_app, aborts with "No application
            # module specified." when it is missing. Serve the already-created
            # `app` object through gunicorn's custom-application API instead.
            from gunicorn.app.base import BaseApplication

            class _StandaloneApplication(BaseApplication):
                """Minimal gunicorn application wrapping an in-process WSGI app."""

                def __init__(self, application, options=None):
                    # Must be set before super().__init__(), which calls load_config().
                    self.options = options or {}
                    self.application = application
                    super().__init__()

                def load_config(self):
                    for key, value in self.options.items():
                        if key in self.cfg.settings and value is not None:
                            self.cfg.set(key.lower(), value)

                def load(self):
                    return self.application

            # Bind like the development server below; PORT overrides 5000.
            bind = f"0.0.0.0:{os.getenv('PORT', '5000')}"
            _StandaloneApplication(app, {"bind": bind}).run()
    else:
        # Use Flask's built-in server for local development.
        # NOTE: debug=True enables Werkzeug's interactive debugger, which is
        # reachable from other machines when bound to 0.0.0.0 -- only run
        # this on a trusted network.
        app.run(debug=True, host='0.0.0.0', port=5000)