# style-master / main.py — commit 64d6682 ("add back images of garments")
"""Main entrypoint for the app."""
import base64
from io import BytesIO
import os
from numpy import ndarray
import requests
from fastapi import FastAPI
from fastapi.staticfiles import StaticFiles
from urllib.parse import urlparse
import gradio as gr
from style_master.LLMRecommender import LLMRecommender, make_request_with_retry
from PIL import Image
# Root folder for static assets (garment/model/try-on images); override via env.
static_dir = os.environ.get("STATIC_FILES_PATH") or "data/Assets"
def ndarray_to_base64(image):
    """Encode a numpy image array as a base64 string of its PNG bytes."""
    buffer = BytesIO()
    # Round-trip through PIL so the raw pixel array becomes a real PNG file.
    Image.fromarray(image).save(buffer, format="PNG")
    return base64.b64encode(buffer.getvalue()).decode()
# Function to encode the image
def encode_image(image_path_or_url_or_ndarray):
    """Return a base64 string for an image given as an ndarray, an http(s)
    URL, or a local file path.

    Returns:
        str: base64-encoded image bytes (PNG bytes for the ndarray case).

    Raises:
        requests.HTTPError: if a URL download returns an error status.
        OSError: if a local file path cannot be read.
    """
    if isinstance(image_path_or_url_or_ndarray, ndarray):
        return ndarray_to_base64(image_path_or_url_or_ndarray)
    if image_path_or_url_or_ndarray.startswith("http"):
        # Bound the request and surface HTTP failures instead of silently
        # base64-encoding an error page body.
        response = requests.get(image_path_or_url_or_ndarray, timeout=30)
        response.raise_for_status()
        return base64.b64encode(response.content).decode("utf-8")
    with open(image_path_or_url_or_ndarray, "rb") as image_file:
        return base64.b64encode(image_file.read()).decode("utf-8")
# Canned responses served by call_ootd_api while the real OOTD server
# integration is disabled. Keys are endpoint names.
mock_ootd_api_responses = {
    "sessions": {
        # Fixed fake user profile returned for any login photo.
        "profile": {
            "id": "123",
            "name": "Sam",
        }
    },
    # The mock produces no pre-generated try-on images.
    "try-ons": {"try_on_images": []},
}
def call_ootd_api(name, data):
    """Dispatch a call to the OOTD API endpoint ``name``.

    Currently returns canned mock responses; the real HTTP call below is
    disabled. Session payloads carry an encoded user photo, so they are
    masked in the log line.
    """
    logged_payload = data if name != "sessions" else "***"
    print(f"Calling {name} with data: {logged_payload}")
    # Disabled real-server call, kept for when the integration is re-enabled:
    # url = f"{ootd_server_url}/{name}"
    # headers = {"Content-Type": "application/json"}
    # response = make_request_with_retry(
    #     url, headers, data, suppress_data=name == "sessions"
    # )
    # return response.json()
    return mock_ootd_api_responses.get(name, {})
# LLM-backed recommender shared by all handlers below (login, chat, catalogue).
llmr = LLMRecommender()
def call_try_on_api(garment_ids):
    """Ask the OOTD service to render try-on images for the given garments.

    Relies on the module-level ``profile`` set during login.
    """
    payload = {"session_id": f"{profile['id']}", "garment_ids": garment_ids}
    # Disabled: chain a second try-on pass when two garments of different
    # categories are selected, layering the second garment over the first
    # rendered model image.
    # if len(garment_ids) == 2:
    #     products = [llmr.get_garment(id) for id in garment_ids]
    #     if products[0]["category"] != products[1]["category"]:
    #         data2 = {
    #             "session_id": f"{profile['id']}",
    #             "model_filename": response["try_on_images"][0]["filename"],
    #             "garment_ids": [garment_ids[1]],
    #         }
    #         response2 = call_ootd_api("try-ons", data2)
    #         response["try_on_images"] += response2["try_on_images"]
    #         response["try_on_images"][-1]["extra_garment_id"] = garment_ids[0]
    return call_ootd_api("try-ons", payload)
def login(image_file_or_ndarray):
    """Create an OOTD session from the user's photo.

    Returns a tuple of (profile dict, per-user try-on image folder path);
    the folder is created on disk if missing.
    """
    print("login:", type(image_file_or_ndarray))
    payload = {"encodedData": encode_image(image_file_or_ndarray)}
    response = call_ootd_api("sessions", payload)
    user_profile = response["profile"]
    print(response)
    # The LLM infers gender from the photo; used to pick a preset model later.
    user_profile["gender"] = llmr.login(user_profile["name"], payload["encodedData"])
    folder = f"{static_dir}/try-ons/{user_profile['id']}"
    os.makedirs(folder, exist_ok=True)
    return user_profile, folder
# Whether to create a public Gradio share link (set SHARE_GRADIO_APP=true).
share_gradio_app = os.environ.get("SHARE_GRADIO_APP") == "true"
# Base URL of the real OOTD server (used by the disabled call in call_ootd_api).
ootd_server_url = os.environ.get("OOTD_SERVER_URL")
def predict(message, history):
    """Chat handler: run the user's message through the LLM recommender and
    stream back a markdown reply with product links/images and, for try-on
    requests, freshly rendered try-on images.

    Args:
        message: the user's chat message.
        history: prior chat turns (unused; required by gr.ChatInterface).

    Yields:
        str: the markdown reply (yielded once, when fully assembled).
    """
    print("predict:", message)
    response = llmr.invoke(message)
    print(response)
    intent = response["intent"]
    # Headline line of the reply, keyed on the recognized intent.
    if intent in ["unknown", "checkout"]:
        partial_message = f"{response['message']}"
    elif intent == "recommendation":
        count = len(response["products"])
        partial_message = f"Here are {count} recommendation{'s' if count > 1 else ''} for you:"
    elif intent == "try-on":
        count = len(response["products"])
        partial_message = f"For {count} garment{'s' if count > 1 else ''} you've selected:"
    elif intent == "add-to-cart":
        count = len(response["products"])
        partial_message = f"Added {count} item{'s' if count > 1 else ''} into your cart:"
    elif intent == "view-cart":
        count = len(response["products"])
        partial_message = (
            f"There {'are' if count > 1 else 'is'} {count} "
            f"item{'s' if count > 1 else ''} in your cart"
            f"{':' if count > 0 else '.'}"
        )
    else:
        # Unrecognized intent: dump the raw response for debugging.
        partial_message = f"{response}"
    if "products" in response and response["products"]:
        partial_message += "\n\n"
        for garment_id in response["products"]:
            product = llmr.get_garment(garment_id)
            url = product["image"]
            title = f"#{product['id']} {product['name']}: ${product['price']}"
            # Link to the full URL, but embed via the path component so the
            # image is served from this app's /static mount.
            partial_message += f"1. [{title}]({url})\n"
            partial_message += f"![{title}]({urlparse(url).path})\n"
    if intent == "try-on":
        # Use a dedicated name for the try-on result: the original code
        # rebound `response` here (and again per downloaded image), clobbering
        # the LLM response and the list being iterated.
        try_on_response = call_try_on_api(response["products"])
        print(try_on_response)
        images = try_on_response["try_on_images"]
        partial_message += (
            f"\n\nwe have created {len(images)} "
            f"try-on image{'s:' if len(images) > 1 else '.'}"
        )
        partial_message += "\n\n"
        for try_on_image in images:
            url = try_on_image["url"]
            file_path = f"{try_on_images_folder}/{try_on_image['filename']}"
            # Download the rendered image into the per-user static folder so
            # the markdown image link below resolves locally.
            download = make_request_with_retry(url)
            with open(file_path, "wb") as f:
                f.write(download.content)
            title = f"{profile['name']} wearing "
            extra_garment_id = try_on_image.get("extra_garment_id")
            if extra_garment_id:
                # Layered try-on: name both garments in the caption.
                extra_product = llmr.get_garment(extra_garment_id)
                title += f"{extra_product['name']} and "
            product = llmr.get_garment(try_on_image["garment_id"])
            title += f"{product['name']}"
            partial_message += f"1. [{title}]({url})\n"
            partial_message += f"![{title}]({urlparse(url).path})\n"
    yield partial_message
# FastAPI application that hosts both the static assets and the Gradio UI.
app = FastAPI()
# Shown in the chat description so users know which model powers replies.
model = os.environ.get("OPENAI_MODEL_NAME")
href = "https://platform.openai.com/docs/models"
title = "Style Master"
# Text file with one example chat question per line (path must be set in env).
questions_file_path = os.environ.get("QUESTIONS_FILE_PATH")
# Open the file for reading
with open(questions_file_path, "r") as file:
    examples = file.readlines()
examples = [example.strip() for example in examples]
description = f"""\
<div align="left">
<p> Powered by: <a href="{href}">{model}</a></p>
</div>
"""
css = """
.container {
height: 100vh;
}
#image_upload{min-height:400px}
#image_upload [data-testid="image"], #image_upload [data-testid="image"] > div{min-height: 400px}
#image_upload .touch-none{display: flex}
"""
def get_all_garments():
    """Gallery data source: (image, caption) pairs for every garment."""
    print("get_all_garments")
    items = []
    for garment in llmr.get_garments():
        caption = f"#{garment['id']} {garment['name']}: ${garment['price']}"
        items.append((garment["image"], caption))
    return items
def upload_model_image(model_path=None, model_image=None):
    """Log the user in from an uploaded photo or a preset model image.

    Side effects: rebinds the module-level ``profile`` and
    ``try_on_images_folder`` used by the chat handlers.
    Returns the PIL image to show in the model widget.
    """
    print("upload_model_image:", model_path if model_path else type(model_image))
    global profile, try_on_images_folder
    profile, try_on_images_folder = login(model_path or model_image)
    print(profile)
    if model_path is None:
        # Uploaded photo: display the preset model matching the inferred gender.
        model_path = f"{static_dir}/models/{profile['gender']}.png"
        print("refresh model:", model_path)
    return Image.open(model_path)
def visible_component():
    """Reveal the chat and catalogue tabs (hidden until a model is chosen)."""
    show = gr.update(visible=True)
    return show, gr.update(visible=True)
def use_preset_model_image(model_image):
    """Handle a preset-model click: log in with it and unhide the other tabs."""
    displayed_image = upload_model_image(model_image=model_image)
    main_update, catalogue_update = visible_component()
    return displayed_image, main_update, catalogue_update
# UI layout: the chat and catalogue tabs start hidden and are revealed only
# after a model photo is provided on the "Change Model" tab.
with gr.Blocks(title=title, css=css) as demo:
    with gr.Tab(f"Chat with {title}", visible=False) as main_tab:
        with gr.Column(elem_classes=["container"]):
            # Setting up the Gradio chat interface.
            chat_ui = gr.ChatInterface(
                predict,
                # title=title,
                description=description,
                examples=examples,
                cache_examples=False,
            )
            # Clearing the chat also resets the LLM session for this user.
            chat_ui.clear_btn.click(lambda: llmr.login(profile["name"]))
    with gr.Tab("View Catalogue", visible=False) as catalogue_tab:
        # Grid of all garments; get_all_garments is called lazily by Gradio.
        gallery = gr.Gallery(
            get_all_garments,
            label="Generated images",
            show_label=False,
            elem_id="gallery",
            columns=[5],
            rows=[4],
            object_fit="contain",
            height=800,
        )
    with gr.Tab("Change Model"):
        with gr.Column():
            model_image = gr.Image(elem_id="image_upload", width=600)
            upload_button = gr.UploadButton(
                "Upload Image", file_types=["image"], file_count="single"
            )
            # After a successful upload + login, reveal the hidden tabs.
            upload_button.upload(upload_model_image, upload_button, model_image).then(
                fn=visible_component, inputs=None, outputs=[main_tab, catalogue_tab]
            )
            # Clickable preset models (also reveal the hidden tabs on click).
            example = gr.Examples(
                label="Preset Models",
                inputs=model_image,
                outputs=[model_image, main_tab, catalogue_tab],
                fn=use_preset_model_image,
                run_on_click=True,
                examples_per_page=2,
                examples=[
                    os.path.join(static_dir, "models/female.png"),
                    os.path.join(static_dir, "models/male.png"),
                ],
            )
# Queue Gradio events so long-running handlers (LLM calls, try-ons) don't block.
demo.queue()
# Serve garment/model/try-on images referenced by the markdown replies.
app.mount(
    "/static", StaticFiles(directory=static_dir, follow_symlink=True), name="static"
)
# Mount the Gradio UI at the web root of the FastAPI app.
app = gr.mount_gradio_app(app, demo, path="/")