Update main.py
Browse files
main.py
CHANGED
@@ -1,74 +1,67 @@
|
|
1 |
-
|
2 |
-
from fastapi
|
3 |
-
from
|
|
|
|
|
4 |
import insightface
|
|
|
|
|
|
|
5 |
import cv2
|
6 |
-
import
|
7 |
-
from tempfile import NamedTemporaryFile
|
8 |
-
|
9 |
-
assert insightface.__version__ >= '0.7'
|
10 |
|
11 |
app = FastAPI()
|
12 |
|
13 |
-
|
14 |
-
|
15 |
-
|
16 |
-
|
17 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
18 |
|
19 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
20 |
|
21 |
def sort_faces(faces):
|
22 |
return sorted(faces, key=lambda x: x.bbox[0])
|
23 |
|
24 |
def get_face(faces, face_id):
|
25 |
if len(faces) < face_id or face_id < 1:
|
26 |
-
raise
|
27 |
-
return faces[face_id-1]
|
28 |
-
|
29 |
-
def swap_faces(source_image_path, source_face_index, destination_image_path, destination_face_index):
|
30 |
-
source_image = cv2.imread(source_image_path)
|
31 |
-
destination_image = cv2.imread(destination_image_path)
|
32 |
-
|
33 |
-
if source_image is None:
|
34 |
-
raise HTTPException(status_code=400, detail="Source image could not be read.")
|
35 |
-
if destination_image is None:
|
36 |
-
raise HTTPException(status_code=400, detail="Destination image could not be read.")
|
37 |
|
38 |
-
|
39 |
-
|
40 |
-
|
41 |
-
|
42 |
-
|
43 |
-
res_faces = sort_faces(face_app.get(destination_image))
|
44 |
-
if not res_faces:
|
45 |
-
raise HTTPException(status_code=400, detail="No faces detected in the destination image.")
|
46 |
-
res_face = get_face(res_faces, destination_face_index)
|
47 |
|
48 |
-
result = swapper.get(destination_image, res_face, source_face, paste_back=True)
|
49 |
-
return result
|
50 |
-
|
51 |
-
@app.post("/swap_faces/")
|
52 |
-
async def swap_faces_endpoint(
|
53 |
-
source_file: UploadFile = File(...),
|
54 |
-
source_face_index: int = 1,
|
55 |
-
destination_file: UploadFile = File(...),
|
56 |
-
destination_face_index: int = 1
|
57 |
-
):
|
58 |
-
with NamedTemporaryFile(delete=False, suffix=".jpg") as source_temp_file:
|
59 |
-
shutil.copyfileobj(source_file.file, source_temp_file)
|
60 |
-
source_image_path = source_temp_file.name
|
61 |
|
62 |
-
with NamedTemporaryFile(delete=False, suffix=".jpg") as destination_temp_file:
|
63 |
-
shutil.copyfileobj(destination_file.file, destination_temp_file)
|
64 |
-
destination_image_path = destination_temp_file.name
|
65 |
|
66 |
-
try:
|
67 |
-
result = swap_faces(source_image_path, source_face_index, destination_image_path, destination_face_index)
|
68 |
-
result_path = "result.jpg"
|
69 |
-
cv2.imwrite(result_path, result)
|
70 |
-
return FileResponse(result_path, media_type="image/jpeg")
|
71 |
-
except HTTPException as e:
|
72 |
-
raise e
|
73 |
-
except Exception as e:
|
74 |
-
raise HTTPException(status_code=500, detail=str(e))
|
|
|
1 |
+
|
2 |
+
from fastapi import FastAPI, UploadFile, File, Form, HTTPException
|
3 |
+
from fastapi.responses import StreamingResponse, JSONResponse, FileResponse
|
4 |
+
from fastapi.staticfiles import StaticFiles
|
5 |
+
from fastapi.templating import Jinja2Templates
|
6 |
import insightface
|
7 |
+
from insightface.app import FaceAnalysis
|
8 |
+
from insightface.model_zoo import get_model
|
9 |
+
import numpy as np
|
10 |
import cv2
|
11 |
+
import io
|
|
|
|
|
|
|
12 |
|
13 |
app = FastAPI()
|
14 |
|
15 |
+
# Initialize FaceAnalysis app and swapper
|
16 |
+
face_analysis = FaceAnalysis(name='buffalo_l')
|
17 |
+
face_analysis.prepare(ctx_id=0, det_size=(640, 640))
|
18 |
+
swapper = get_model('inswapper_128.onnx', download=True, download_zip=True)
|
19 |
+
|
20 |
+
@app.post("/swap_faces/")
|
21 |
+
async def swap_faces(source_file: UploadFile = File(...),
|
22 |
+
source_face_index: int = Form(...),
|
23 |
+
destination_file: UploadFile = File(...),
|
24 |
+
destination_face_index: int = Form(...)):
|
25 |
+
"""Swaps faces between the source and destination images based on the specified face indices."""
|
26 |
+
source_bytes = await source_file.read()
|
27 |
+
destination_bytes = await destination_file.read()
|
28 |
+
|
29 |
+
# Decode images
|
30 |
+
source_image = cv2.imdecode(np.frombuffer(source_bytes, np.uint8), cv2.IMREAD_COLOR)
|
31 |
+
destination_image = cv2.imdecode(np.frombuffer(destination_bytes, np.uint8), cv2.IMREAD_COLOR)
|
32 |
+
|
33 |
+
# Face detection and sorting
|
34 |
+
faces_source = sort_faces(face_analysis.get(source_image))
|
35 |
+
if not faces_source:
|
36 |
+
raise HTTPException(status_code=400, detail="No faces detected in the source image.")
|
37 |
+
source_face = get_face(faces_source, source_face_index)
|
38 |
+
|
39 |
+
faces_destination = sort_faces(face_analysis.get(destination_image))
|
40 |
+
if not faces_destination:
|
41 |
+
raise HTTPException(status_code=400, detail="No faces detected in the destination image.")
|
42 |
+
destination_face = get_face(faces_destination, destination_face_index)
|
43 |
|
44 |
+
# Swap faces
|
45 |
+
result_image = swapper.get(destination_image, destination_face, source_face, paste_back=True)
|
46 |
+
|
47 |
+
# Convert result_image back to bytes
|
48 |
+
_, result_bytes = cv2.imencode('.jpg', result_image)
|
49 |
+
|
50 |
+
# Return the image bytes as a streaming response
|
51 |
+
return StreamingResponse(io.BytesIO(result_bytes), media_type="image/jpeg")
|
52 |
|
53 |
def sort_faces(faces):
|
54 |
return sorted(faces, key=lambda x: x.bbox[0])
|
55 |
|
56 |
def get_face(faces, face_id):
|
57 |
if len(faces) < face_id or face_id < 1:
|
58 |
+
raise ValueError(f"The image includes only {len(faces)} faces, however, you asked for face {face_id}")
|
59 |
+
return faces[face_id - 1]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
60 |
|
61 |
+
@app.exception_handler(ValueError)
|
62 |
+
async def value_error_handler(request, exc):
|
63 |
+
"""Custom exception handler to return JSON error responses for ValueError."""
|
64 |
+
return JSONResponse(status_code=400, content={"detail": str(exc)})
|
|
|
|
|
|
|
|
|
|
|
65 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
66 |
|
|
|
|
|
|
|
67 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|