Luisgust commited on
Commit
0860266
·
verified ·
1 Parent(s): 8ea86a9

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +61 -13
main.py CHANGED
@@ -1,16 +1,64 @@
1
- import cv2,os
 
 
 
 
2
  import AnimeGANv3_src
3
- if __name__ == '__main__':
4
-
5
- f = 'A'
6
- input_imgs_path = r'../../v3-usa\dataset\USA\val'
7
- # input_imgs_path = r'/mnt/data/xinchen/v3-usa/dataset/USA/val'
8
- output_path = 'AnimeGANv3_usa_64_output'
9
- # img = cv2.imread(os.path.join(input_imgs_path, os.listdir(input_imgs_path)[0]))
10
- img = cv2.imread(os.path.join(input_imgs_path, 'jp_16.png'))
11
- out = AnimeGANv3_src.Convert(cv2.cvtColor(img, cv2.COLOR_BGR2RGB), f, True)
12
- # cv2.imshow('d', cv2.cvtColor(out, cv2.COLOR_BGR2RGB))
13
- # cv2.waitKey(0)
14
- cv2.imwrite('a.jpg', cv2.cvtColor(out, cv2.COLOR_BGR2RGB))
15
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
16
 
 
1
+ from fastapi import FastAPI, File, UploadFile, Form, HTTPException
2
+ from fastapi.responses import StreamingResponse
3
+ import os
4
+ import cv2
5
+ import numpy as np
6
  import AnimeGANv3_src
7
+ from io import BytesIO
 
 
 
 
 
 
 
 
 
 
 
8
 
9
+ # Initialize FastAPI
10
+ app = FastAPI()
11
+
12
+ os.makedirs('output', exist_ok=True)
13
+
14
+ def process_image(img_path, style, if_face):
15
+ print(img_path, style, if_face)
16
+ try:
17
+ img = cv2.imread(img_path)
18
+ img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
19
+ style_mapping = {
20
+ "AnimeGANv3_Arcane": "A",
21
+ "AnimeGANv3_Trump v1.0": "T",
22
+ "AnimeGANv3_Shinkai": "S",
23
+ "AnimeGANv3_PortraitSketch": "P",
24
+ "AnimeGANv3_Hayao": "H",
25
+ "AnimeGANv3_Disney v1.0": "D",
26
+ "AnimeGANv3_JP_face v1.0": "J",
27
+ "AnimeGANv3_Kpop v2.0": "K",
28
+ "AnimeGANv3_USA": "U"
29
+ }
30
+ f = style_mapping.get(style, "U")
31
+
32
+ det_face = True if if_face == "Yes" else False
33
+ output = AnimeGANv3_src.Convert(img, f, det_face)
34
+ save_path = f"output/out.{img_path.rsplit('.')[-1]}"
35
+ cv2.imwrite(save_path, output[:, :, ::-1])
36
+ return output, save_path
37
+ except Exception as error:
38
+ print('Error', error)
39
+ return None, None
40
+
41
+ @app.post("/inference/")
42
+ async def inference(file: UploadFile = File(...), Style: str = Form(...), if_face: str = Form(...)):
43
+ try:
44
+ # Save the uploaded file to a temporary location
45
+ file_location = f"temp_{file.filename}"
46
+ with open(file_location, "wb") as f:
47
+ f.write(await file.read())
48
+
49
+ # Process the image
50
+ output, save_path = process_image(file_location, Style, if_face)
51
+
52
+ if output is None:
53
+ raise HTTPException(status_code=500, detail="Processing failed")
54
+
55
+ # Read the processed image and prepare it for response
56
+ with open(save_path, "rb") as result_file:
57
+ result_bytes = result_file.read()
58
+
59
+ # Return the image as a blob
60
+ return StreamingResponse(BytesIO(result_bytes), media_type="image/jpeg")
61
+
62
+ except Exception as e:
63
+ raise HTTPException(status_code=500, detail=str(e))
64