cat-dog-api / app.py
blaxx14's picture
Add application file
20e43d5
raw
history blame
1.13 kB
import io
import os

import numpy as np
import requests
import tensorflow as tf
from fastapi import FastAPI, File, HTTPException, UploadFile
from PIL import Image
# Remote URL the model file is fetched from when it is not present locally.
MODEL_URL = "https://huggingface.co/blaxx14/cat-vs-dog-inceptionv3/resolve/main/cat_dog_inception_v3.h5"
# Local path (relative to the working directory) the model is saved to and
# later loaded from with tf.keras.models.load_model.
MODEL_PATH = "cat_dog_inception_v3.h5"
# FastAPI application instance; the /predict/ route is registered on it.
app = FastAPI()
def download_model():
    """Download the model file to ``MODEL_PATH`` unless it already exists.

    The response body is streamed to a temporary ``.part`` file and moved
    into place only after the transfer has completed, so a failed or
    interrupted download never leaves a truncated file at ``MODEL_PATH``
    that a later start would mistake for a valid model.

    Raises:
        requests.RequestException: On connection failure, timeout, or a
            non-success HTTP status.
        OSError: If the file cannot be written or renamed.
    """
    if os.path.exists(MODEL_PATH):
        return

    print("Downloading model...")
    partial_path = MODEL_PATH + ".part"
    try:
        # stream=True avoids holding the whole model in memory; the
        # (connect, read) timeout keeps startup from hanging on a stalled
        # connection.
        with requests.get(MODEL_URL, stream=True, timeout=(10, 60)) as response:
            # Without this check an HTML error page would be saved as the model.
            response.raise_for_status()
            with open(partial_path, "wb") as f:
                for chunk in response.iter_content(chunk_size=1024 * 1024):
                    f.write(chunk)
        # Atomic rename: MODEL_PATH only ever appears fully written.
        os.replace(partial_path, MODEL_PATH)
    finally:
        # After a successful replace the .part file no longer exists; on any
        # failure this removes the incomplete download.
        if os.path.exists(partial_path):
            os.remove(partial_path)
    print("Model downloaded!")
# Import-time side effects: make sure the model file is present, then load it
# once into the module-level `model` that the /predict/ handler uses.
download_model()
print("Loading model...")
model = tf.keras.models.load_model(MODEL_PATH)
print("Model loaded!")
def preprocess_image(image, target_size=(150, 150)):
    """Decode encoded image bytes into a single-item batch array.

    Args:
        image: Raw bytes of an encoded image file (any format Pillow can open).
        target_size: ``(width, height)`` the image is resized to. Defaults to
            ``(150, 150)``, the size this function previously hard-coded.
            NOTE(review): presumably this must match the loaded model's
            input size -- confirm against the model before changing it.

    Returns:
        numpy.ndarray: Array of shape ``(1, height, width, 3)`` holding the
        RGB pixel values divided by 255.0.

    Raises:
        PIL.UnidentifiedImageError: If the bytes cannot be identified as an
            image (a subclass of ``OSError``).
    """
    # The context manager releases the decoder's resources once the converted,
    # resized copy has been produced.
    with Image.open(io.BytesIO(image)) as source:
        img = source.convert("RGB").resize(target_size)
    pixels = np.array(img) / 255.0
    # Add a leading batch axis.
    return np.expand_dims(pixels, axis=0)
@app.post("/predict/")
async def predict(file: UploadFile = File(...)):
    """Run the loaded model on an uploaded image.

    Args:
        file: The uploaded image file from the multipart request.

    Returns:
        dict: ``{"prediction": ...}`` where the value is the model's output
        array converted to nested Python lists.

    Raises:
        HTTPException: With status 400 when the upload is empty or cannot be
            decoded as an image.
    """
    image = await file.read()
    try:
        processed_img = preprocess_image(image)
    except (OSError, ValueError) as exc:
        # Pillow reports unreadable or truncated image data as OSError
        # (UnidentifiedImageError is a subclass); answer with a client
        # error instead of letting it surface as an HTTP 500.
        raise HTTPException(
            status_code=400,
            detail="Uploaded file is not a valid image.",
        ) from exc
    prediction = model.predict(processed_img)
    return {"prediction": prediction.tolist()}