# safras_ml_api / main.py
# Author: Arafath10 -- commit 809b3d0 ("Update main.py"), 1.02 kB
from ultralytics import YOLO
import asyncio
from fastapi import FastAPI, File, UploadFile, HTTPException
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
import requests
# Custom-trained YOLO weights, loaded once when the module is imported and
# reused by every request handled below.
model = YOLO('best.pt')

app = FastAPI()

# CORS policy: any origin, any method and any header are accepted, and
# credentials are allowed.
# NOTE(review): a wildcard origin together with allow_credentials=True is
# contradictory under the CORS spec (browsers refuse `*` for credentialed
# requests). Confirm whether credentials are actually needed; if they are,
# list explicit origins instead of "*".
app.add_middleware(
    CORSMiddleware,
    allow_credentials=True,
    allow_origins=["*"],
    allow_headers=["*"],
    allow_methods=["*"],
)
# Class index -> ripeness-stage label; get_prediction looks up probs.top1 here.
# NOTE(review): the order presumably has to match the class order the 'best.pt'
# weights were trained with -- confirm before editing.
names = dict(enumerate([
    'breaking-stage',
    'half-riping-stage',
    'overripe un-healthy',
    'ripe',
    'ripe_with_consumable_disease',
    'unripe',
]))
@app.post("/predict")
async def get_prediction(file: UploadFile = File(...)):
    """Classify an uploaded image and return its ripeness-stage label.

    The upload is written to a private temporary file and passed to ``model``.
    The temporary file is always deleted afterwards, even if inference fails.

    Args:
        file: The image sent as multipart form data.

    Returns:
        The ``names`` label for ``probs.top1`` of the last result returned by
        the model, or ``""`` if the model returned no results.
    """
    import os
    import tempfile

    file_name = file.filename
    print(file_name)

    # The client-supplied filename is untrusted. Writing to it directly would
    # allow path traversal ("../../x"), would let concurrent uploads with the
    # same name overwrite each other, and would leave files behind. Only a plain
    # extension is kept.
    # NOTE(review): presumably the model chooses its loader from the path's
    # extension, which is why it is preserved -- confirm.
    suffix = os.path.splitext(file_name or "")[1]
    if not suffix[1:].isalnum():
        suffix = ""

    contents = await file.read()
    fd, tmp_path = tempfile.mkstemp(suffix=suffix)
    try:
        # Close the file before inference so the model can reopen it by path.
        with os.fdopen(fd, "wb") as tmp:
            tmp.write(contents)
        # NOTE(review): inference runs synchronously inside this coroutine and
        # blocks the event loop while it runs. Consider offloading it to a
        # worker thread, with a lock if the model instance is not thread-safe.
        results = model(tmp_path)
    finally:
        os.remove(tmp_path)

    # Only the last result's label is returned ("" if there are none).
    class_name = ""
    for r in results:
        class_name = names[int(r.probs.top1)]
        print(class_name)
    return class_name