Luisgust commited on
Commit
1ff6187
·
verified ·
1 Parent(s): 7dca8cc

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +19 -3
main.py CHANGED
@@ -1,4 +1,5 @@
1
  import os
 
2
  from fastapi import FastAPI, File, UploadFile, Form
3
  from fastapi.responses import JSONResponse
4
  from gradio_client import Client, handle_file
@@ -18,18 +19,33 @@ except Exception as e:
18
  @app.post("/get_caption")
19
  async def get_caption(image: UploadFile = File(...), context: str = Form(...)):
20
  try:
21
- contents = await image.read()
22
- # Remove orig_name if it's not supported by handle_file
23
- image_data = handle_file(contents)
 
 
 
24
 
 
 
 
 
25
  result = client.predict(
26
  image=image_data,
27
  question=context,
28
  api_name="/answer_question"
29
  )
30
 
 
31
  return JSONResponse(content={"caption": result})
 
32
  except Exception as e:
33
  print(f"Error during prediction: {e}")
34
  return JSONResponse(content={"error": str(e)}, status_code=500)
35
 
 
 
 
 
 
 
 
1
  import os
2
+ import tempfile
3
  from fastapi import FastAPI, File, UploadFile, Form
4
  from fastapi.responses import JSONResponse
5
  from gradio_client import Client, handle_file
 
19
  @app.post("/get_caption")
20
  async def get_caption(image: UploadFile = File(...), context: str = Form(...)):
21
  try:
22
+ # Create a temporary file
23
+ with tempfile.NamedTemporaryFile(delete=False) as temp_file:
24
+ # Write the uploaded file contents to the temp file
25
+ contents = await image.read()
26
+ temp_file.write(contents)
27
+ temp_file_path = temp_file.name
28
 
29
+ # Use the temporary file path with handle_file or any other processing
30
+ image_data = handle_file(temp_file_path)
31
+
32
+ # Call the Gradio API
33
  result = client.predict(
34
  image=image_data,
35
  question=context,
36
  api_name="/answer_question"
37
  )
38
 
39
+ # Return the result as a JSON response
40
  return JSONResponse(content={"caption": result})
41
+
42
  except Exception as e:
43
  print(f"Error during prediction: {e}")
44
  return JSONResponse(content={"error": str(e)}, status_code=500)
45
 
46
+ finally:
47
+ # Remove the temporary file
48
+ if os.path.exists(temp_file_path):
49
+ os.remove(temp_file_path)
50
+
51
+ # Run the server with: uvicorn main:app --reload