- README.md +14 -19
- app.py +69 -0
- bg_rm_with_fastapi.py +193 -0
README.md
CHANGED
@@ -1,19 +1,14 @@
----
-title:
-emoji:
-colorFrom:
-colorTo:
-sdk:
-
-
-
-
-short_description:
----
-
-
-
-Edit `/src/streamlit_app.py` to customize this app to your heart's desire. :heart:
-
-If you have any questions, checkout our [documentation](https://docs.streamlit.io) and [community
-forums](https://discuss.streamlit.io).
+---
+title: Background Remove
+emoji: π
+colorFrom: gray
+colorTo: yellow
+sdk: streamlit
+sdk_version: 1.44.0
+app_file: app.py
+pinned: false
+license: mit
+short_description: remove_background
+---
+
+Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py
ADDED
@@ -0,0 +1,69 @@
import sys
import types
import warnings

# Suppress specific timm warning
warnings.filterwarnings("ignore", category=FutureWarning, module="timm.models.layers")

# Monkey-patch torch.classes path issue
if not hasattr(sys.modules.get("torch"), "__path__"):
    torch_classes = types.SimpleNamespace(_path=[])
    sys.modules["torch.classes"] = torch_classes

import streamlit as st
import cv2
import numpy as np
from io import BytesIO
from PIL import Image
from transparent_background import Remover

def process_image(image):
    try:
        remover = Remover()
        img = Image.open(image).convert("RGB")

        with st.spinner("Processing image..."):
            out = remover.process(img, type="rgba")

        return out  # `out` is already a PIL Image

    except Exception as e:
        st.error(f"An error occurred: {str(e)}")
        return None

def main():
    st.title("Image Upload and Processing App")

    uploaded_file = st.file_uploader("Upload an image", type=["jpg", "jpeg", "png", "tif", "tiff"])

    if uploaded_file is not None:
        try:
            image = Image.open(uploaded_file)

            # Convert unsupported image modes (16-bit grayscale) to RGB
            if image.mode == 'I;16':
                image = image.point(lambda i: i * (1.0 / 256)).convert('RGB')

            st.image(image, caption="Uploaded Image", use_container_width=True)

            # Rewind the upload buffer: displaying the image above may have consumed it
            uploaded_file.seek(0)
            processed_pil = process_image(uploaded_file)

            if processed_pil is not None:
                st.image(processed_pil, caption="Processed Image", use_container_width=True)

                buf = BytesIO()
                processed_pil.save(buf, format="PNG")
                byte_im = buf.getvalue()

                st.download_button(
                    label="Download Processed Image",
                    data=byte_im,
                    file_name="processed_image.png",
                    mime="image/png"
                )

        except Exception as e:
            st.error(f"Error loading image: {e}")

if __name__ == "__main__":
    main()
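
As a quick way to exercise the same removal step outside Streamlit, here is a minimal sketch; the `input.jpg` and `output.png` paths are placeholders, and it assumes `transparent_background` is installed:

from PIL import Image
from transparent_background import Remover

remover = Remover()
img = Image.open("input.jpg").convert("RGB")   # placeholder input path
out = remover.process(img, type="rgba")        # same call app.py makes; returns a PIL Image with alpha
out.save("output.png")                         # placeholder output path
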
bg_rm_with_fastapi.py
ADDED
@@ -0,0 +1,193 @@
from fastapi import FastAPI, File, UploadFile, Form, HTTPException, Request
from fastapi.responses import JSONResponse, StreamingResponse
from fastapi.middleware.cors import CORSMiddleware
import uvicorn
import requests
from PIL import Image
import io
import warnings
from transparent_background import Remover
import ssl
import torch
import re
import json
import numpy as np
from torch.quantization import quantize_dynamic
from transformers import CLIPProcessor, CLIPModel
from langchain_ollama import OllamaLLM

# Disable SSL verification and warnings
ssl._create_default_https_context = ssl._create_unverified_context
warnings.filterwarnings('ignore', category=UserWarning)
warnings.filterwarnings('ignore', category=FutureWarning)

llm = OllamaLLM(model="llama2", base_url="http://localhost:11434", system="you are a jewellery expert", temperature=0.0)
model = CLIPModel.from_pretrained("openai/clip-vit-large-patch14")
processor = CLIPProcessor.from_pretrained("openai/clip-vit-large-patch14")

quantized_model = quantize_dynamic(model, {torch.nn.Linear}, dtype=torch.qint8)  # 8-bit quantized model

# Candidate labels for zero-shot CLIP classification of jewellery attributes
j_type = ["necklace", "finger ring", "single earring", "earrings", "necklace without chain", "bangles", "pendant without chain"]
p_gem = ["diamond center stone", "ruby center stone", "emerald center stone", "sapphire center stone", "amethyst center stone", "pearl center stone", "topaz center stone", "opal center stone", "garnet center stone", "aquamarine center stone"]
s_gem = ["surrounded by small diamonds", "surrounded by nothing or no secondary stone"]
design = ["modern design", "classic design", "minimalist design", "flower design", "round shaped", "oval shaped", "square shaped", "cushion shaped", "pear shaped"]
size = ["small size", "medium size", "large size"]
metal = ["gold", "silver"]
# occasion = ["wedding occasion", "casual occasion", "formal occasion", "party occasion", "gifting", "travel"]
# t_audience = ["women", "men", "teen", "fashionista", "casual"]
t_audience = ["women", "men"]
visual_desc = ["dazzling", "radiant", "glittering", "shimmering", "captivating", "bold", "playful", "charming"]

t = [j_type, p_gem, s_gem, design, size, metal, t_audience, visual_desc]

app = FastAPI()

def generating_prompt(image):
    lst1 = []
    # For each attribute group, score the image against every candidate label
    # with CLIP and keep the highest-probability label
    for items in t:
        inputs = processor(text=items, images=image, return_tensors="pt", padding=True)
        outputs = quantized_model(**inputs)
        logits_per_image = outputs.logits_per_image  # image-text similarity scores
        probs = logits_per_image.softmax(dim=1).detach().numpy()
        probs = np.array(probs)
        index = np.argmax(probs)
        lst1.append(items[index])

    # Ask the local Ollama model to turn the selected attributes into a title and description
    res = llm.invoke(f"generate the description (2 to 4 lines) and title (3 to 5 words) of an object from the given features: {str(lst1)}")
    text = res
    substring = "Title:"
    desc = "Description:"
    match0 = re.search(substring, text)
    match1 = re.search(desc, text)
    if match0 and match1:
        title = text[match0.start():match1.start()]
        description = text[match1.start():]
        X = title.split(":")
        y = description.split(":")
        di = {X[0]: X[1], y[0]: y[1]}
        json_object = json.dumps(di)
        return json_object
    else:
        return f"The substring '{substring}' is not found."

# Enable CORS
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # Change this to specific domains if needed
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

@app.get("/")
def home():
    return {"status": "running", "message": "Background removal service is operational", "version": "1.0"}

@app.post("/remove-background")
async def remove_background(
        request: Request,
        image: UploadFile = File(None),
        imageUrl: str = Form(None),
        backgroundColor: str = Form(None)):
    try:
        input_image = None

        # Handle JSON request
        if request.headers.get("content-type") == "application/json":
            data = await request.json()
            imageUrl = data.get("imageUrl")
            backgroundColor = data.get("backgroundColor")

        if image:
            # Handle direct image upload
            input_image = Image.open(io.BytesIO(await image.read()))
        elif imageUrl:
            # Handle image URL
            response = requests.get(imageUrl)
            if response.status_code != 200:
                raise HTTPException(status_code=400, detail="Failed to fetch image from URL")
            input_image = Image.open(io.BytesIO(response.content))
        else:
            raise HTTPException(status_code=400, detail="No image or image URL provided")

        # Initialize remover
        remover = Remover()

        # Convert input_image to RGB mode
        input_image = input_image.convert('RGB')

        # Remove background
        output_image = remover.process(input_image, type='rgba')

        # If background color is specified, apply it
        if backgroundColor:
            # Convert hex to RGB
            bg_color = tuple(int(backgroundColor.lstrip('#')[i:i + 2], 16) for i in (0, 2, 4))

            # Create new image with background color
            background = Image.new('RGBA', output_image.size, bg_color + (255,))
            # Use alpha channel as mask
            background.paste(output_image, (0, 0), output_image)
            output_image = background

        # Save to buffer
        output_buffer = io.BytesIO()
        output_image.save(output_buffer, format='PNG')
        output_buffer.seek(0)

        return StreamingResponse(output_buffer, media_type="image/png", headers={"Content-Disposition": "attachment; filename=removed_bg.png"})

    except HTTPException:
        # Re-raise intentional 4xx errors without converting them to 500s
        raise
    except Exception as e:
        print(f"Error processing image: {e}")
        raise HTTPException(status_code=500, detail=str(e))


@app.post("/description_gen")
async def description_gen(
        request: Request,
        image: UploadFile = File(None),
        imageUrl: str = Form(None)):
    try:
        input_image = None

        # Handle JSON request
        if request.headers.get("content-type") == "application/json":
            data = await request.json()
            imageUrl = data.get("imageUrl")

        if image:
            # Handle direct image upload
            input_image = Image.open(io.BytesIO(await image.read()))
        elif imageUrl:
            # Handle image URL
            response = requests.get(imageUrl)
            if response.status_code != 200:
                raise HTTPException(status_code=400, detail="Failed to fetch image from URL")
            input_image = Image.open(io.BytesIO(response.content))
        else:
            raise HTTPException(status_code=400, detail="No image or image URL provided")

        # Convert input_image to RGB mode
        input_image = input_image.convert('RGB')
        output = generating_prompt(input_image)

        # `output` is a JSON string; wrap it in a byte stream for the response
        return StreamingResponse(io.BytesIO(output.encode("utf-8")), media_type="application/json", headers={"Content-Disposition": "attachment; filename=description.json"})

    except HTTPException:
        raise
    except Exception as e:
        print(f"Error processing image: {e}")
        raise HTTPException(status_code=500, detail=str(e))


if __name__ == "__main__":
    uvicorn.run(app, host="127.0.0.1", port=8000)
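
For reference, a minimal client sketch for the two endpoints added in bg_rm_with_fastapi.py; it assumes the server is running locally on port 8000, and the file paths and image URL are placeholders:

import requests

BASE = "http://127.0.0.1:8000"

# /remove-background with a direct file upload and an optional background colour
with open("input.jpg", "rb") as f:  # placeholder path
    r = requests.post(f"{BASE}/remove-background",
                      files={"image": f},
                      data={"backgroundColor": "#ffffff"})
with open("removed_bg.png", "wb") as out:
    out.write(r.content)

# /description_gen with a JSON body carrying an image URL
r = requests.post(f"{BASE}/description_gen",
                  json={"imageUrl": "https://example.com/ring.jpg"})  # placeholder URL
print(r.text)
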