|
import os |
|
import tempfile |
|
from fastapi import FastAPI, File, UploadFile, Form |
|
from fastapi.responses import JSONResponse |
|
from gradio_client import Client, handle_file |
|
from deep_translator import GoogleTranslator |
|
|
|
# ASGI application object; the route handlers in this module register on it.
app = FastAPI()

# Access token read from the environment and passed to the Gradio client as
# ``hf_token``. Startup is aborted when it is missing or empty.
HF_TOKEN = os.environ.get("HF_TOKEN")
if not HF_TOKEN:
    raise ValueError("HF_TOKEN environment variable is not set.")

# Module-level Gradio client shared by the request handlers. A failure here is
# reported on stdout and then re-raised, so the module does not finish
# importing without a usable client.
try:
    client = Client("Luisgust/moondream1", hf_token=HF_TOKEN)
except Exception as init_error:
    print(f"Failed to initialize Gradio client: {init_error}")
    raise
|
|
|
@app.post("/get_caption")
async def get_caption(image: UploadFile = File(...), context: str = Form(...)):
    """Answer a question about an uploaded image and translate the answer to Arabic.

    The upload is written to a temporary file, wrapped with ``handle_file`` and
    sent, together with ``context`` as the question, to the
    ``/answer_question`` endpoint of the module-level Gradio ``client``. The
    returned text is then translated with ``GoogleTranslator``
    (``source='auto'``, ``target='ar'``).

    Args:
        image: Uploaded file received as a multipart form field.
        context: Question forwarded to the model, received as a form field.

    Returns:
        A ``JSONResponse`` with ``{"caption": <translated text>}`` on success,
        or ``{"error": <exception message>}`` with status code 500 when any
        step raises.
    """
    # Bound before the ``try`` so the ``finally`` clause can always reference
    # it, even when the failure happens before the temp file has been created.
    temp_file_path = None
    try:
        # delete=False keeps the file on disk after the ``with`` block closes
        # it, because it is handed to ``handle_file`` by path afterwards. It is
        # removed explicitly in ``finally``.
        with tempfile.NamedTemporaryFile(delete=False) as temp_file:
            # Record the path first so that a failed read or write still
            # leaves enough information to clean the file up.
            temp_file_path = temp_file.name
            contents = await image.read()
            temp_file.write(contents)

        image_data = handle_file(temp_file_path)

        # NOTE(review): ``client.predict`` and ``translator.translate`` are
        # called synchronously inside an ``async def`` handler, so they
        # presumably block the event loop while they run -- consider
        # offloading them to a worker thread.
        description = client.predict(
            image=image_data,
            question=context,
            api_name="/answer_question",
        )

        translator = GoogleTranslator(source='auto', target='ar')
        translated_description = translator.translate(description)

        return JSONResponse(content={"caption": translated_description})

    except Exception as e:
        print(f"Error during prediction: {e}")
        return JSONResponse(content={"error": str(e)}, status_code=500)

    finally:
        # Skip cleanup when no temp file was ever created; tolerate the file
        # having already disappeared.
        if temp_file_path is not None:
            try:
                os.remove(temp_file_path)
            except FileNotFoundError:
                pass
|
|
|
|