DRgaddam committed
Commit 1317962 · verified · 1 Parent(s): 83fd4be

added file

Files changed (1)
  1. bg_rm_with_fastapi.py +193 -0
bg_rm_with_fastapi.py ADDED
@@ -0,0 +1,193 @@
+ from fastapi import FastAPI, File, UploadFile, Form, HTTPException, Request
+ from fastapi.responses import JSONResponse, StreamingResponse
+ from fastapi.middleware.cors import CORSMiddleware
+ import uvicorn
+ import requests
+ from PIL import Image
+ import io
+ import warnings
+ from transparent_background import Remover
+ import ssl
+ import torch
+ import re
+ import json
+ import numpy as np
+ from torch.quantization import quantize_dynamic
+ from transformers import CLIPProcessor, CLIPModel
+ from langchain_ollama import OllamaLLM
+
+ # Disable SSL verification and warnings
+ ssl._create_default_https_context = ssl._create_unverified_context
+ warnings.filterwarnings('ignore', category=UserWarning)
+ warnings.filterwarnings('ignore', category=FutureWarning)
+
+ llm = OllamaLLM(model="llama2", base_url="http://localhost:11434", system="you are a jewellery expert", temperature=0.0)
+ model = CLIPModel.from_pretrained("openai/clip-vit-large-patch14")
+ processor = CLIPProcessor.from_pretrained("openai/clip-vit-large-patch14")
+
+ quantized_model = quantize_dynamic(model, {torch.nn.Linear}, dtype=torch.qint8)  # 8-bit quantized model
+
+ # Candidate labels for zero-shot CLIP classification of jewellery attributes
+ j_type = ["necklace", "finger ring", "single earring", "earrings", "necklace without chain", "bangles", "pendant without chain"]
+ p_gem = ["diamond center stone", "ruby center stone", "emerald center stone", "sapphire center stone", "amethyst center stone", "pearl center stone", "topaz center stone", "opal center stone", "garnet center stone", "aquamarine center stone"]
+ s_gem = ["surrounded by small diamonds", "surrounded by nothing or no secondary stone"]
+ design = ["modern design", "classic design", "minimalist design", "flower design", "round shaped", "oval shaped", "square shaped", "cushion shaped", "pear shaped"]
+ size = ["small size", "medium size", "large size"]
+ metal = ["gold", "silver"]
+ # occasion = ["wedding occasion", "casual occasion", "formal occasion", "party occasion", "gifting", "travel"]
+ # t_audience = ["women", "men", "teen", "fashionista", "casual"]
+ t_audience = ["women", "men"]
+ visual_desc = ["dazzling", "radiant", "glittering", "shimmering", "captivating", "bold", "playful", "charming"]
+
+ t = [j_type, p_gem, s_gem, design, size, metal, t_audience, visual_desc]
+
+ app = FastAPI()
+ def generating_prompt(image):
+     # For each attribute group, pick the label CLIP scores highest for this image
+     lst1 = []
+     for items in t:
+         inputs = processor(text=items, images=image, return_tensors="pt", padding=True)
+         outputs = quantized_model(**inputs)
+         logits_per_image = outputs.logits_per_image  # image-text similarity scores
+         probs = logits_per_image.softmax(dim=1).detach().numpy()
+         index = np.argmax(probs)
+         lst1.append(items[index])
+     # Ask the LLM to turn the selected attributes into a short title and description
+     res = llm.invoke(f"generate the description (2 to 4 lines) and title (3 to 5 words) of an object from the given features: {str(lst1)}")
+     text = res
+     substring = "Title:"
+     desc = "Description:"
+     match0 = re.search(substring, text)
+     match1 = re.search(desc, text)
+     if match0 and match1:
+         title = text[match0.start():match1.start()]
+         description = text[match1.start():]
+         X = title.split(":", 1)
+         y = description.split(":", 1)
+         di = {X[0].strip(): X[1].strip(), y[0].strip(): y[1].strip()}
+         json_object = json.dumps(di)
+         return json_object
+     else:
+         return f"The substring '{substring}' is not found."
+ # Enable CORS
+ app.add_middleware(
+     CORSMiddleware,
+     allow_origins=["*"],  # Change this to specific domains if needed
+     allow_credentials=True,
+     allow_methods=["*"],
+     allow_headers=["*"],
+ )
+
+ @app.get("/")
+ def home():
+     return {"status": "running", "message": "Background removal service is operational", "version": "1.0"}
+
+ @app.post("/remove-background")
+ async def remove_background(
+     request: Request,
+     image: UploadFile = File(None),
+     imageUrl: str = Form(None),
+     backgroundColor: str = Form(None)):
+     try:
+         input_image = None
+
+         # Handle JSON request
+         if request.headers.get("content-type", "").startswith("application/json"):
+             data = await request.json()
+             imageUrl = data.get("imageUrl")
+             backgroundColor = data.get("backgroundColor")
+
+         if image:
+             # Handle direct image upload
+             input_image = Image.open(io.BytesIO(await image.read()))
+         elif imageUrl:
+             # Handle image URL
+             response = requests.get(imageUrl)
+             if response.status_code != 200:
+                 raise HTTPException(status_code=400, detail="Failed to fetch image from URL")
+             input_image = Image.open(io.BytesIO(response.content))
+         else:
+             raise HTTPException(status_code=400, detail="No image or image URL provided")
+
+         # Initialize remover
+         remover = Remover()
+
+         # Convert input_image to RGB mode
+         input_image = input_image.convert('RGB')
+
+         # Remove background, keeping transparency in the alpha channel
+         output_image = remover.process(input_image, type='rgba')
+
+         # If a background color is specified, apply it
+         if backgroundColor:
+             # Convert hex to RGB
+             bg_color = tuple(int(backgroundColor.lstrip('#')[i:i+2], 16) for i in (0, 2, 4))
+
+             # Create a new image with the background color
+             background = Image.new('RGBA', output_image.size, bg_color + (255,))
+             # Use the alpha channel as the paste mask
+             background.paste(output_image, (0, 0), output_image)
+             output_image = background
+
+         # Save to an in-memory buffer
+         output_buffer = io.BytesIO()
+         output_image.save(output_buffer, format='PNG')
+         output_buffer.seek(0)
+
+         return StreamingResponse(output_buffer, media_type="image/png", headers={"Content-Disposition": "attachment; filename=removed_bg.png"})
+
+     except HTTPException:
+         # Re-raise explicit HTTP errors (e.g. 400s) instead of converting them to 500s
+         raise
+     except Exception as e:
+         print(f"Error processing image: {e}")
+         raise HTTPException(status_code=500, detail=str(e))
+
+
+ @app.post("/description_gen")
+ async def description_gen(
+     request: Request,
+     image: UploadFile = File(None),
+     imageUrl: str = Form(None)):
+     try:
+         input_image = None
+
+         # Handle JSON request
+         if request.headers.get("content-type", "").startswith("application/json"):
+             data = await request.json()
+             imageUrl = data.get("imageUrl")
+
+         if image:
+             # Handle direct image upload
+             input_image = Image.open(io.BytesIO(await image.read()))
+         elif imageUrl:
+             # Handle image URL
+             response = requests.get(imageUrl)
+             if response.status_code != 200:
+                 raise HTTPException(status_code=400, detail="Failed to fetch image from URL")
+             input_image = Image.open(io.BytesIO(response.content))
+         else:
+             raise HTTPException(status_code=400, detail="No image or image URL provided")
+
+         # Convert input_image to RGB mode and generate the title/description JSON
+         input_image = input_image.convert('RGB')
+         output = generating_prompt(input_image)
+
+         return StreamingResponse(io.BytesIO(output.encode("utf-8")), media_type="application/json", headers={"Content-Disposition": "attachment; filename=description.json"})
+
+     except HTTPException:
+         # Re-raise explicit HTTP errors (e.g. 400s) instead of converting them to 500s
+         raise
+     except Exception as e:
+         print(f"Error processing image: {e}")
+         raise HTTPException(status_code=500, detail=str(e))
+
+
+ if __name__ == "__main__":
+     uvicorn.run(app, host="127.0.0.1", port=8000)