added file
Browse files- bg_rm_with_fastapi.py +193 -0
bg_rm_with_fastapi.py
ADDED
@@ -0,0 +1,193 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from fastapi import FastAPI, File, UploadFile, Form, HTTPException, Request
|
2 |
+
from fastapi.responses import JSONResponse, StreamingResponse
|
3 |
+
from fastapi.middleware.cors import CORSMiddleware
|
4 |
+
import uvicorn
|
5 |
+
import requests
|
6 |
+
from PIL import Image
|
7 |
+
import io
|
8 |
+
import warnings
|
9 |
+
from transparent_background import Remover
|
10 |
+
import ssl
|
11 |
+
import torch
|
12 |
+
import re
|
13 |
+
import json
|
14 |
+
import numpy as np
|
15 |
+
from torch.quantization import quantize_dynamic
|
16 |
+
from transformers import CLIPProcessor, CLIPModel
|
17 |
+
from langchain_ollama import OllamaLLM
|
18 |
+
|
19 |
+
# Disable SSL verification and warnings
|
20 |
+
ssl._create_default_https_context = ssl._create_unverified_context
|
21 |
+
warnings.filterwarnings('ignore', category=UserWarning)
|
22 |
+
warnings.filterwarnings('ignore', category=FutureWarning)
|
23 |
+
|
24 |
+
llm = OllamaLLM(model="llama2",base_url = "http://localhost:11434",system="you are an jewellery expert",temperature=0.0)
|
25 |
+
model=CLIPModel.from_pretrained("openai/clip-vit-large-patch14")
|
26 |
+
processor = CLIPProcessor.from_pretrained("openai/clip-vit-large-patch14")
|
27 |
+
|
28 |
+
quantized_model = quantize_dynamic(model, {torch.nn.Linear}, dtype=torch.qint8) #8-bit quantized model
|
29 |
+
|
30 |
+
j_type=["necklace", "finger ring","single earring","earrings","necklace without chain","bangels","pendant without chain"]
|
31 |
+
p_gem=["diamond center stone", "ruby center stone", "emerald center stone", "sapphire center stone", "amethyst center stone", "pearl center stone", "topaz center stone", "opal center stone", "garnet center stone", "aquamarine center stone"]
|
32 |
+
s_gem=["surrounded by small diamond","surounded by nothing or no secondary stone"]
|
33 |
+
design=[ "modern design", "classic design", "minimalist design", "flower design","round shaped", "oval shaped", "square shaped", "cushion shaped", "pear shaped"]
|
34 |
+
size=["small size", "medium size", "large size"]
|
35 |
+
metal=["gold", "silver"]
|
36 |
+
# occasion=["wedding occasion", "casual occasion", "formal occasion", "party occasion", "gifting ", "travel"]
|
37 |
+
# t_audience=["women", "men", "teen", "fashionista", "casual"]
|
38 |
+
t_audience=["women", "men"]
|
39 |
+
visual_desc=["dazzling", "radiant", "glittering", "shimmering", "captivating", "bold", "playful", "charming"]
|
40 |
+
|
41 |
+
t=[j_type,p_gem,s_gem,design,size,metal,t_audience,visual_desc]
|
42 |
+
|
43 |
+
app = FastAPI()
|
44 |
+
def generating_prompt(image):
|
45 |
+
lst1=[]
|
46 |
+
image=image
|
47 |
+
#add the path of image to generate description
|
48 |
+
for items in t:
|
49 |
+
inputs = processor(text=items, images=image, return_tensors="pt", padding=True)
|
50 |
+
# with torch.cuda.amp.autocast():
|
51 |
+
outputs = quantized_model(**inputs)
|
52 |
+
# print(outputs)
|
53 |
+
logits_per_image = outputs.logits_per_image # this is the image-text similarity score
|
54 |
+
probs = logits_per_image.softmax(dim=1).detach().numpy()
|
55 |
+
probs=np.array(probs)
|
56 |
+
# print(probs)
|
57 |
+
indes=np.argmax(probs)
|
58 |
+
lst1.append(items[indes])
|
59 |
+
res = llm.invoke(f"generate the description(2 to 4 lines) and title(3 to 5 words) of a object from the given features :{str(lst1)}")
|
60 |
+
text = res
|
61 |
+
substring = "Title:"
|
62 |
+
desc="Description:"
|
63 |
+
match0 = re.search(substring, text)
|
64 |
+
match1 = re.search(desc,text)
|
65 |
+
if match0 and match1:
|
66 |
+
title=text[match0.start():match1.start()]
|
67 |
+
description = text[match1.start():]
|
68 |
+
X = title.split(":")
|
69 |
+
y = description.split(":")
|
70 |
+
di = {X[0]:X[1],y[0]:y[1]}
|
71 |
+
json_object = json.dumps(di)
|
72 |
+
return json_object
|
73 |
+
else:
|
74 |
+
return f"The substring '{substring}' is not found."
|
75 |
+
# Enable CORS
|
76 |
+
app.add_middleware(
|
77 |
+
CORSMiddleware,
|
78 |
+
allow_origins=["*"], # Change this to specific domains if needed
|
79 |
+
allow_credentials=True,
|
80 |
+
allow_methods=["*"],
|
81 |
+
allow_headers=["*"],
|
82 |
+
)
|
83 |
+
|
84 |
+
@app.get("/")
|
85 |
+
def home():
|
86 |
+
return {"status": "running", "message": "Background removal service is operational", "version": "1.0"}
|
87 |
+
|
88 |
+
@app.post("/remove-background")
|
89 |
+
async def remove_background(
|
90 |
+
request: Request,
|
91 |
+
image: UploadFile = File(None),
|
92 |
+
imageUrl: str = Form(None),
|
93 |
+
backgroundColor: str = Form(None)):
|
94 |
+
try:
|
95 |
+
input_image = None
|
96 |
+
|
97 |
+
# Handle JSON request
|
98 |
+
if request.headers.get("content-type") == "application/json":
|
99 |
+
data = await request.json()
|
100 |
+
imageUrl = data.get("imageUrl")
|
101 |
+
backgroundColor = data.get("backgroundColor")
|
102 |
+
|
103 |
+
if image:
|
104 |
+
# Handle direct image upload
|
105 |
+
input_image = Image.open(io.BytesIO(await image.read()))
|
106 |
+
elif imageUrl:
|
107 |
+
# Handle image URL
|
108 |
+
response = requests.get(imageUrl)
|
109 |
+
if response.status_code != 200:
|
110 |
+
raise HTTPException(status_code=400, detail="Failed to fetch image from URL")
|
111 |
+
input_image = Image.open(io.BytesIO(response.content))
|
112 |
+
else:
|
113 |
+
raise HTTPException(status_code=400, detail="No image or image URL provided")
|
114 |
+
|
115 |
+
# Initialize remover
|
116 |
+
remover = Remover()
|
117 |
+
|
118 |
+
# Convert input_image to RGB mode
|
119 |
+
input_image = input_image.convert('RGB')
|
120 |
+
|
121 |
+
# Remove background using new method
|
122 |
+
output_image = remover.process(input_image, type='rgba'
|
123 |
+
|
124 |
+
)
|
125 |
+
|
126 |
+
# If background color is specified, apply it
|
127 |
+
if backgroundColor:
|
128 |
+
# Convert hex to RGB
|
129 |
+
bg_color = tuple(int(backgroundColor.lstrip('#')[i:i+2], 16) for i in (0, 2, 4))
|
130 |
+
|
131 |
+
# Create new image with background color
|
132 |
+
background = Image.new('RGBA', output_image.size, bg_color + (255,))
|
133 |
+
# Use alpha channel as mask
|
134 |
+
background.paste(output_image, (0, 0), output_image)
|
135 |
+
output_image = background
|
136 |
+
|
137 |
+
# Save to buffer
|
138 |
+
output_buffer = io.BytesIO()
|
139 |
+
output_image.save(output_buffer, format='PNG')
|
140 |
+
output_buffer.seek(0)
|
141 |
+
|
142 |
+
return StreamingResponse(output_buffer, media_type="image/png", headers={"Content-Disposition": "attachment; filename=removed_bg.png"})
|
143 |
+
|
144 |
+
except Exception as e:
|
145 |
+
print(f"Error processing image: {e}")
|
146 |
+
raise HTTPException(status_code=500, detail=str(e))
|
147 |
+
|
148 |
+
|
149 |
+
|
150 |
+
@app.post("/description_gen")
|
151 |
+
async def description_gen(
|
152 |
+
request: Request,
|
153 |
+
image: UploadFile = File(None),
|
154 |
+
imageUrl: str = Form(None) ):
|
155 |
+
try:
|
156 |
+
input_image = None
|
157 |
+
|
158 |
+
# Handle JSON request
|
159 |
+
if request.headers.get("content-type") == "application/json":
|
160 |
+
data = await request.json()
|
161 |
+
imageUrl = data.get("imageUrl")
|
162 |
+
|
163 |
+
if image:
|
164 |
+
# Handle direct image upload
|
165 |
+
input_image = Image.open(io.BytesIO(await image.read()))
|
166 |
+
elif imageUrl:
|
167 |
+
# Handle image URL
|
168 |
+
response = requests.get(imageUrl)
|
169 |
+
if response.status_code != 200:
|
170 |
+
raise HTTPException(status_code=400, detail="Failed to fetch image from URL")
|
171 |
+
input_image = Image.open(io.BytesIO(response.content))
|
172 |
+
else:
|
173 |
+
raise HTTPException(status_code=400, detail="No image or image URL provided")
|
174 |
+
|
175 |
+
|
176 |
+
|
177 |
+
# Convert input_image to RGB mode
|
178 |
+
input_image = input_image.convert('RGB')
|
179 |
+
output = generating_prompt(input_image)
|
180 |
+
|
181 |
+
return StreamingResponse(output, media_type="text/json", headers={"Content-Disposition": "attachment; filename=discription.json"})
|
182 |
+
|
183 |
+
except Exception as e:
|
184 |
+
print(f"Error processing image: {e}")
|
185 |
+
raise HTTPException(status_code=500, detail=str(e))
|
186 |
+
|
187 |
+
|
188 |
+
|
189 |
+
|
190 |
+
|
191 |
+
|
192 |
+
if __name__ == "__main__":
|
193 |
+
uvicorn.run(app, host="127.0.0.1", port=8000)
|