keval-fst commited on
Commit
56f01e9
·
verified ·
1 Parent(s): e637721

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +32 -21
app.py CHANGED
@@ -1,20 +1,20 @@
 
 
1
  from fastapi import FastAPI, File, UploadFile, Form
2
  from fastapi.responses import JSONResponse
3
- from transformers import pipeline
4
- from io import BytesIO
5
  from PIL import Image
6
- import os
7
-
8
- hf_token = os.environ.get("HF_TOKEN")
9
 
10
  app = FastAPI()
11
 
12
- # Load pipeline once
13
- pipe = pipeline("image-text-to-text", model="google/medgemma-4b-pt",use_auth_token=hf_token)
14
 
15
- @app.get("/")
16
- def root():
17
- return {"message": "MedGemma Image-to-Text API running"}
 
18
 
19
  @app.post("/analyze/")
20
  async def analyze_image(prompt: str = Form(...), file: UploadFile = File(...)):
@@ -22,17 +22,28 @@ async def analyze_image(prompt: str = Form(...), file: UploadFile = File(...)):
22
  image_data = await file.read()
23
  image = Image.open(BytesIO(image_data)).convert("RGB")
24
 
25
- response = pipe([
26
- {
27
- "role": "user",
28
- "content": [
29
- {"type": "image", "image": image},
30
- {"type": "text", "text": prompt}
31
- ]
32
- }
33
- ])
34
-
35
- return {"result": response[0]['generated_text']}
 
 
 
 
 
 
 
 
 
 
 
36
 
37
  except Exception as e:
38
  return JSONResponse(status_code=500, content={"error": str(e)})
 
1
+ import os
2
+ import requests
3
  from fastapi import FastAPI, File, UploadFile, Form
4
  from fastapi.responses import JSONResponse
 
 
5
  from PIL import Image
6
+ from io import BytesIO
7
+ import base64
 
8
 
9
  app = FastAPI()
10
 
11
+ HF_API_URL = "https://api-inference.huggingface.co/models/google/medgemma-27b-it"
12
+ HF_TOKEN = os.getenv("HF_TOKEN")
13
 
14
+ headers = {
15
+ "Authorization": f"Bearer {HF_TOKEN}",
16
+ "Content-Type": "application/json"
17
+ }
18
 
19
  @app.post("/analyze/")
20
  async def analyze_image(prompt: str = Form(...), file: UploadFile = File(...)):
 
22
  image_data = await file.read()
23
  image = Image.open(BytesIO(image_data)).convert("RGB")
24
 
25
+ buffered = BytesIO()
26
+ image.save(buffered, format="JPEG")
27
+ img_str = base64.b64encode(buffered.getvalue()).decode()
28
+
29
+ payload = {
30
+ "inputs": [
31
+ {
32
+ "role": "user",
33
+ "content": [
34
+ {"type": "image", "image": img_str},
35
+ {"type": "text", "text": prompt}
36
+ ]
37
+ }
38
+ ]
39
+ }
40
+
41
+ response = requests.post(HF_API_URL, headers=headers, json=payload)
42
+
43
+ if response.status_code != 200:
44
+ return JSONResponse(status_code=response.status_code, content=response.json())
45
+
46
+ return {"result": response.json()[0]['generated_text']}
47
 
48
  except Exception as e:
49
  return JSONResponse(status_code=500, content={"error": str(e)})