Luisgust commited on
Commit
545c27e
·
verified ·
1 Parent(s): cd344ff

Create main.py

Browse files
Files changed (1) hide show
  1. main.py +55 -0
main.py ADDED
@@ -0,0 +1,55 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import tempfile
3
+ from fastapi import FastAPI, File, UploadFile, Form
4
+ from fastapi.responses import JSONResponse
5
+ from gradio_client import Client, handle_file
6
+ from deep_translator import GoogleTranslator
7
+
8
+ app = FastAPI()
9
+
10
+ HF_TOKEN = os.getenv("HF_TOKEN")
11
+ if not HF_TOKEN:
12
+ raise ValueError("HF_TOKEN environment variable is not set.")
13
+
14
+ try:
15
+ client = Client("Luisgust/moondream1", hf_token=HF_TOKEN)
16
+ except Exception as e:
17
+ print(f"Failed to initialize Gradio client: {e}")
18
+ raise
19
+
20
+ @app.post("/get_caption")
21
+ async def get_caption(image: UploadFile = File(...), context: str = Form(...)):
22
+ try:
23
+ # Create a temporary file
24
+ with tempfile.NamedTemporaryFile(delete=False) as temp_file:
25
+ # Write the uploaded file contents to the temp file
26
+ contents = await image.read()
27
+ temp_file.write(contents)
28
+ temp_file_path = temp_file.name
29
+
30
+ # Use the temporary file path with handle_file or any other processing
31
+ image_data = handle_file(temp_file_path)
32
+
33
+ # Call the Gradio API to get the description
34
+ description = client.predict(
35
+ image=image_data,
36
+ question=context,
37
+ api_name="/answer_question"
38
+ )
39
+
40
+ # Translate the description to Arabic
41
+ translator = GoogleTranslator(source='auto', target='ar')
42
+ translated_description = translator.translate(description)
43
+
44
+ # Return the translated result as a JSON response
45
+ return JSONResponse(content={"caption": translated_description})
46
+
47
+ except Exception as e:
48
+ print(f"Error during prediction: {e}")
49
+ return JSONResponse(content={"error": str(e)}, status_code=500)
50
+
51
+ finally:
52
+ # Remove the temporary file
53
+ if os.path.exists(temp_file_path):
54
+ os.remove(temp_file_path)
55
+