"""FastAPI service for Grounding DINO zero-shot object detection.

POST /predict runs Grounding DINO on an uploaded image for the given text
labels, draws the detected boxes, records the request in MySQL and returns
the annotated image as JPEG.
"""
import io
import os

import mysql.connector
import torch
from fastapi import FastAPI, File, Form, HTTPException, UploadFile
from fastapi.responses import StreamingResponse
from PIL import Image, ImageDraw
from transformers import AutoProcessor, GroundingDinoForObjectDetection

app = FastAPI()

# MySQL database connection settings. Each can be overridden through the
# environment so deployments need not rely on values committed to source.
# SECURITY(review): a real password is committed in the fallback below. It
# should be rotated, and the fallback removed once MYSQL_PASSWORD is set in
# every deployment.
username = os.environ.get("MYSQL_USER", "ukrqsqxg_ukrqsqxg")
password = os.environ.get("MYSQL_PASSWORD", "emmy@0790467621")
host = os.environ.get("MYSQL_HOST", "www.kacafix.com")
database = os.environ.get("MYSQL_DATABASE", "ukrqsqxg_millbox_storage")

# One connection is opened at import time and shared by all requests.
cnx = mysql.connector.connect(
    user=username, password=password, host=host, database=database
)
# Module-level cursor kept for any external code that imports it.
# predict() opens its own short-lived cursor per request instead.
cursor = cnx.cursor()

MODEL_ID = "IDEA-Research/grounding-dino-tiny"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = GroundingDinoForObjectDetection.from_pretrained(MODEL_ID).to(device)
processor = AutoProcessor.from_pretrained(MODEL_ID)


def _build_prompt(raw_labels: str) -> str:
    """Turn newline-separated labels into a Grounding DINO text prompt.

    Each non-blank line becomes one query. The query is stripped, lowercased
    and terminated with a period, which is the format Grounding DINO expects.
    Queries are joined with single spaces. Returns "" if no line has any text.
    """
    queries = []
    # splitlines() also handles the "\r\n" endings browsers send in form fields.
    for line in raw_labels.splitlines():
        query = line.strip().lower()
        if not query:
            continue  # a blank line would otherwise become a bare "." query
        queries.append(query if query.endswith(".") else query + ".")
    return " ".join(queries)


def _draw_detections(image: Image.Image, boxes: list, labels) -> None:
    """Draw each [x0, y0, x1, y1] box and its label in red onto *image* in place."""
    draw = ImageDraw.Draw(image)
    for box, label in zip(boxes, labels):
        draw.rectangle(box, outline="red", width=3)
        draw.text((box[0], box[1]), str(label), fill="red")


def _save_request(
    prompt: str, box_threshold: float, text_threshold: float, jpeg_bytes: bytes
) -> None:
    """Insert one row for this prediction into the model_requests table.

    Raises:
        mysql.connector.Error: if the insert or commit fails. The
            transaction is rolled back before the error is re-raised.
    """
    # The shared connection may have been dropped by the server while idle.
    # ping(reconnect=True) re-establishes it before use.
    cnx.ping(reconnect=True)
    query = ("INSERT INTO model_requests (labels, box_threshold, text_threshold, output_image) "
             "VALUES (%s, %s, %s, %s)")
    cur = cnx.cursor()
    try:
        cur.execute(query, (prompt, box_threshold, text_threshold, jpeg_bytes))
        cnx.commit()
    except mysql.connector.Error:
        cnx.rollback()
        raise
    finally:
        cur.close()


@app.post("/predict")
async def predict(
    image: UploadFile = File(...),
    labels: str = Form(...),
    box_threshold: float = Form(...),
    text_threshold: float = Form(...),
):
    """Detect the requested objects in *image* and return an annotated JPEG.

    Args:
        image: uploaded image file in any format Pillow can decode.
        labels: newline-separated object descriptions to search for.
        box_threshold: box confidence threshold passed to the processor.
        text_threshold: text-match threshold passed to the processor.

    Returns:
        A JPEG StreamingResponse with detected boxes and labels drawn in red.

    Raises:
        HTTPException(400): the upload is not a readable image, or *labels*
            contains no non-blank line.
    """
    image_data = await image.read()
    try:
        # Normalise to RGB. The processor expects 3 channels, and JPEG cannot
        # store RGBA/palette images (saving them would raise OSError).
        pil_image = Image.open(io.BytesIO(image_data)).convert("RGB")
    except OSError as err:  # includes PIL.UnidentifiedImageError
        raise HTTPException(
            status_code=400, detail="Uploaded file is not a readable image."
        ) from err

    prompt = _build_prompt(labels)
    if not prompt:
        raise HTTPException(
            status_code=400, detail="At least one non-empty label is required."
        )

    # NOTE(review): inference runs synchronously inside an async endpoint, so
    # it blocks the event loop for its duration.
    inputs = processor(images=pil_image, text=prompt, return_tensors="pt").to(device)
    with torch.no_grad():
        outputs = model(**inputs)

    result = processor.post_process_grounded_object_detection(
        outputs,
        inputs.input_ids,
        box_threshold=box_threshold,
        text_threshold=text_threshold,
        # PIL's size is (width, height); target_sizes is given as (height, width).
        target_sizes=[pil_image.size[::-1]],
    )[0]

    boxes = result["boxes"].int().cpu().tolist()
    _draw_detections(pil_image, boxes, result["labels"])

    output_image_io = io.BytesIO()
    pil_image.save(output_image_io, format="JPEG")
    _save_request(prompt, box_threshold, text_threshold, output_image_io.getvalue())

    output_image_io.seek(0)  # rewind so the response streams from the start
    return StreamingResponse(output_image_io, media_type="image/jpeg")


if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=7860)