Luisgust commited on
Commit
0ee03ec
·
verified ·
1 Parent(s): 65c6d78

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +51 -58
main.py CHANGED
@@ -1,74 +1,67 @@
1
- from fastapi import FastAPI, HTTPException, UploadFile, File
2
- from fastapi.responses import FileResponse
3
- from insightface.app import FaceAnalysis
 
 
4
  import insightface
 
 
 
5
  import cv2
6
- import shutil
7
- from tempfile import NamedTemporaryFile
8
-
9
- assert insightface.__version__ >= '0.7'
10
 
11
  app = FastAPI()
12
 
13
- def prepare_app():
14
- app = FaceAnalysis(name='buffalo_l')
15
- app.prepare(ctx_id=0, det_size=(640, 640))
16
- swapper = insightface.model_zoo.get_model('inswapper_128.onnx', download=True, download_zip=True)
17
- return app, swapper
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
18
 
19
- face_app, swapper = prepare_app()
 
 
 
 
 
 
 
20
 
21
  def sort_faces(faces):
22
  return sorted(faces, key=lambda x: x.bbox[0])
23
 
24
  def get_face(faces, face_id):
25
  if len(faces) < face_id or face_id < 1:
26
- raise HTTPException(status_code=400, detail=f"The image includes only {len(faces)} faces, however, you asked for face {face_id}")
27
- return faces[face_id-1]
28
-
29
- def swap_faces(source_image_path, source_face_index, destination_image_path, destination_face_index):
30
- source_image = cv2.imread(source_image_path)
31
- destination_image = cv2.imread(destination_image_path)
32
-
33
- if source_image is None:
34
- raise HTTPException(status_code=400, detail="Source image could not be read.")
35
- if destination_image is None:
36
- raise HTTPException(status_code=400, detail="Destination image could not be read.")
37
 
38
- faces = sort_faces(face_app.get(source_image))
39
- if not faces:
40
- raise HTTPException(status_code=400, detail="No faces detected in the source image.")
41
- source_face = get_face(faces, source_face_index)
42
-
43
- res_faces = sort_faces(face_app.get(destination_image))
44
- if not res_faces:
45
- raise HTTPException(status_code=400, detail="No faces detected in the destination image.")
46
- res_face = get_face(res_faces, destination_face_index)
47
 
48
- result = swapper.get(destination_image, res_face, source_face, paste_back=True)
49
- return result
50
-
51
- @app.post("/swap_faces/")
52
- async def swap_faces_endpoint(
53
- source_file: UploadFile = File(...),
54
- source_face_index: int = 1,
55
- destination_file: UploadFile = File(...),
56
- destination_face_index: int = 1
57
- ):
58
- with NamedTemporaryFile(delete=False, suffix=".jpg") as source_temp_file:
59
- shutil.copyfileobj(source_file.file, source_temp_file)
60
- source_image_path = source_temp_file.name
61
 
62
- with NamedTemporaryFile(delete=False, suffix=".jpg") as destination_temp_file:
63
- shutil.copyfileobj(destination_file.file, destination_temp_file)
64
- destination_image_path = destination_temp_file.name
65
 
66
- try:
67
- result = swap_faces(source_image_path, source_face_index, destination_image_path, destination_face_index)
68
- result_path = "result.jpg"
69
- cv2.imwrite(result_path, result)
70
- return FileResponse(result_path, media_type="image/jpeg")
71
- except HTTPException as e:
72
- raise e
73
- except Exception as e:
74
- raise HTTPException(status_code=500, detail=str(e))
 
1
+
2
+ from fastapi import FastAPI, UploadFile, File, Form, HTTPException
3
+ from fastapi.responses import StreamingResponse, JSONResponse, FileResponse
4
+ from fastapi.staticfiles import StaticFiles
5
+ from fastapi.templating import Jinja2Templates
6
  import insightface
7
+ from insightface.app import FaceAnalysis
8
+ from insightface.model_zoo import get_model
9
+ import numpy as np
10
  import cv2
11
+ import io
 
 
 
12
 
13
  app = FastAPI()
14
 
15
+ # Initialize FaceAnalysis app and swapper
16
+ face_analysis = FaceAnalysis(name='buffalo_l')
17
+ face_analysis.prepare(ctx_id=0, det_size=(640, 640))
18
+ swapper = get_model('inswapper_128.onnx', download=True, download_zip=True)
19
+
20
+ @app.post("/swap_faces/")
21
+ async def swap_faces(source_file: UploadFile = File(...),
22
+ source_face_index: int = Form(...),
23
+ destination_file: UploadFile = File(...),
24
+ destination_face_index: int = Form(...)):
25
+ """Swaps faces between the source and destination images based on the specified face indices."""
26
+ source_bytes = await source_file.read()
27
+ destination_bytes = await destination_file.read()
28
+
29
+ # Decode images
30
+ source_image = cv2.imdecode(np.frombuffer(source_bytes, np.uint8), cv2.IMREAD_COLOR)
31
+ destination_image = cv2.imdecode(np.frombuffer(destination_bytes, np.uint8), cv2.IMREAD_COLOR)
32
+
33
+ # Face detection and sorting
34
+ faces_source = sort_faces(face_analysis.get(source_image))
35
+ if not faces_source:
36
+ raise HTTPException(status_code=400, detail="No faces detected in the source image.")
37
+ source_face = get_face(faces_source, source_face_index)
38
+
39
+ faces_destination = sort_faces(face_analysis.get(destination_image))
40
+ if not faces_destination:
41
+ raise HTTPException(status_code=400, detail="No faces detected in the destination image.")
42
+ destination_face = get_face(faces_destination, destination_face_index)
43
 
44
+ # Swap faces
45
+ result_image = swapper.get(destination_image, destination_face, source_face, paste_back=True)
46
+
47
+ # Convert result_image back to bytes
48
+ _, result_bytes = cv2.imencode('.jpg', result_image)
49
+
50
+ # Return the image bytes as a streaming response
51
+ return StreamingResponse(io.BytesIO(result_bytes), media_type="image/jpeg")
52
 
53
  def sort_faces(faces):
54
  return sorted(faces, key=lambda x: x.bbox[0])
55
 
56
  def get_face(faces, face_id):
57
  if len(faces) < face_id or face_id < 1:
58
+ raise ValueError(f"The image includes only {len(faces)} faces, however, you asked for face {face_id}")
59
+ return faces[face_id - 1]
 
 
 
 
 
 
 
 
 
60
 
61
+ @app.exception_handler(ValueError)
62
+ async def value_error_handler(request, exc):
63
+ """Custom exception handler to return JSON error responses for ValueError."""
64
+ return JSONResponse(status_code=400, content={"detail": str(exc)})
 
 
 
 
 
65
 
 
 
 
 
 
 
 
 
 
 
 
 
 
66
 
 
 
 
67