# Hugging Face Spaces page status at capture time: Runtime error
"""Main entrypoint for the app.""" | |
import base64
import copy
import os
from io import BytesIO
from urllib.parse import urlparse

import gradio as gr
import requests
from fastapi import FastAPI
from fastapi.staticfiles import StaticFiles
from numpy import ndarray
from PIL import Image

from style_master.LLMRecommender import LLMRecommender, make_request_with_retry
# Root folder for static assets (preset model images, downloaded try-on
# images); it is served under /static. Override with STATIC_FILES_PATH.
static_dir = os.getenv("STATIC_FILES_PATH") or "data/Assets"
def ndarray_to_base64(image):
    """Encode an in-memory image array as a base64 PNG string.

    ``image`` is converted with ``PIL.Image.fromarray``, serialised to PNG in
    an in-memory buffer, and the PNG bytes are returned base64-encoded as a
    ``str``.
    """
    with BytesIO() as png_buffer:
        Image.fromarray(image).save(png_buffer, format="PNG")
        return base64.b64encode(png_buffer.getvalue()).decode()
def encode_image(image_path_or_url_or_ndarray, timeout=30):
    """Return an image's bytes base64-encoded as a ``str``.

    Args:
        image_path_or_url_or_ndarray: One of
            * a numpy ``ndarray`` -- re-encoded as PNG via ``ndarray_to_base64``;
            * an ``http://`` / ``https://`` URL string -- downloaded with
              ``requests``;
            * anything else (a str or path-like) -- read from the local disk.
        timeout: Seconds to wait for the HTTP download (URLs only).

    Returns:
        The base64 encoding of the image bytes, decoded to a ``str``.

    Raises:
        requests.HTTPError: if the URL download returns an error status.
        requests.Timeout: if the download exceeds ``timeout``.
        OSError: if a local file cannot be read.
    """
    source = image_path_or_url_or_ndarray
    if isinstance(source, ndarray):
        return ndarray_to_base64(source)
    # Check the URL scheme rather than a bare "http" prefix, so that local
    # files such as "http_cache.png" are still read from disk.
    if isinstance(source, str) and urlparse(source).scheme in ("http", "https"):
        # Without a timeout requests may block forever; without the status
        # check an HTML error page would silently be encoded as the "image".
        response = requests.get(source, timeout=timeout)
        response.raise_for_status()
        image_bytes = response.content
    else:
        with open(source, "rb") as image_file:
            image_bytes = image_file.read()
    return base64.b64encode(image_bytes).decode("utf-8")
# Canned OOTD server responses, keyed by endpoint name. call_ootd_api returns
# (a copy of) these instead of making a real HTTP request.
mock_ootd_api_responses = {
    "sessions": {
        "profile": {
            "id": "123",
            "name": "Sam",
        }
    },
    "try-ons": {"try_on_images": []},
}


def call_ootd_api(name, data):
    """Call OOTD endpoint ``name`` with payload ``data`` (currently mocked).

    The real HTTP request is disabled (see the commented code below). The
    canned entry from ``mock_ootd_api_responses`` is returned instead, or
    ``{}`` for an unknown endpoint.

    A deep copy is returned on purpose. Callers mutate the response (``login``
    adds a ``"gender"`` key to the profile, for example), and handing out the
    shared mock objects would leak that state into every later call, as a
    real HTTP response never would.

    Args:
        name: Endpoint name, e.g. ``"sessions"`` or ``"try-ons"``.
        data: Request payload. It is logged, except for ``"sessions"``, whose
            payload (the encoded login image) is masked.

    Returns:
        The mock response dict.
    """
    print(f"Calling {name} with data: {data if name != 'sessions' else '***'}")
    return copy.deepcopy(mock_ootd_api_responses.get(name, {}))
    # Real implementation, to re-enable once the OOTD server is reachable:
    # url = f"{ootd_server_url}/{name}"
    # headers = {"Content-Type": "application/json"}
    # response = make_request_with_retry(
    #     url, headers, data, suppress_data=name == "sessions"
    # )
    # response_json = response.json()
    # return response_json
# Single recommender instance, constructed once at import time and shared by
# the handlers below (login, predict, get_all_garments, the chat clear button).
llmr = LLMRecommender()
def call_try_on_api(garment_ids):
    """Request try-on images of ``garment_ids`` for the logged-in profile.

    Uses the module-level ``profile`` (set by ``upload_model_image``) as the
    session, and returns the ``"try-ons"`` endpoint response unchanged.
    """
    payload = dict(session_id=f"{profile['id']}", garment_ids=garment_ids)
    return call_ootd_api("try-ons", payload)
    # Disabled: when two garments of different categories are selected,
    # layer the second onto the first try-on result with a second request.
    # if len(garment_ids) == 2:
    #     products = [llmr.get_garment(id) for id in garment_ids]
    #     if products[0]["category"] != products[1]["category"]:
    #         data2 = {
    #             "session_id": f"{profile['id']}",
    #             "model_filename": response["try_on_images"][0]["filename"],
    #             "garment_ids": [garment_ids[1]],
    #         }
    #         response2 = call_ootd_api("try-ons", data2)
    #         response["try_on_images"] += response2["try_on_images"]
    #         response["try_on_images"][-1]["extra_garment_id"] = garment_ids[0]
def login(image_file_or_ndarray):
    """Open an OOTD session for the person shown in an image.

    Args:
        image_file_or_ndarray: Any input accepted by ``encode_image``.

    Returns:
        ``(profile, try_on_images_folder)``: the session profile dict, with
        the result of ``llmr.login`` stored under ``"gender"``, and the
        per-profile folder (created if missing) under ``static_dir`` where
        try-on images are saved.
    """
    print("login:", type(image_file_or_ndarray))
    encoded = encode_image(image_file_or_ndarray)
    session = call_ootd_api("sessions", {"encodedData": encoded})
    user_profile = session["profile"]
    print(session)
    user_profile["gender"] = llmr.login(user_profile["name"], encoded)
    folder = f"{static_dir}/try-ons/{user_profile['id']}"
    os.makedirs(folder, exist_ok=True)
    return user_profile, folder
# Runtime configuration from the environment.
# NOTE(review): neither name is used by live code in this file;
# ootd_server_url only appears in the disabled request code in call_ootd_api.
ootd_server_url = os.getenv("OOTD_SERVER_URL")
share_gradio_app = os.getenv("SHARE_GRADIO_APP") == "true"
def _plural(count, noun):
    """Return ``noun`` with an ``s`` appended unless ``count`` is exactly 1."""
    return noun if count == 1 else f"{noun}s"


def _be(count):
    """Return the verb agreeing with ``count``: "is" for 1, otherwise "are"."""
    return "is" if count == 1 else "are"


def _intro_message(response):
    """Return the opening sentence of the reply for an ``llmr.invoke`` result."""
    intent = response["intent"]
    if intent in ["unknown", "checkout"]:
        return f"{response['message']}"
    if intent == "recommendation":
        count = len(response["products"])
        return f"Here {_be(count)} {count} {_plural(count, 'recommendation')} for you:"
    if intent == "try-on":
        count = len(response["products"])
        return f"For {count} {_plural(count, 'garment')} you've selected:"
    if intent == "add-to-cart":
        count = len(response["products"])
        return f"Added {count} {_plural(count, 'item')} into your cart:"
    if intent == "view-cart":
        count = len(response["products"])
        return (
            f"There {_be(count)} {count} {_plural(count, 'item')} in your cart"
            f"{':' if count > 0 else '.'}"
        )
    # Unrecognised intent: show the raw response rather than nothing.
    return f"{response}"


def _format_product_list(product_ids):
    """Render ``product_ids`` as a numbered Markdown list of catalogue links."""
    lines = []
    for product_id in product_ids:
        product = llmr.get_garment(product_id)
        title = f"#{product['id']} {product['name']}: ${product['price']}"
        lines.append(f"1. [{title}]({product['image']})\n")
    return "".join(lines) + "\n"


def _render_try_ons(garment_ids):
    """Create try-on images for ``garment_ids``, save them locally and list them.

    Each generated image is downloaded into the module-level
    ``try_on_images_folder`` and rendered as a numbered Markdown link,
    captioned with the profile name and the garment name(s).
    """
    try_on_response = call_try_on_api(garment_ids)
    print(try_on_response)
    images = try_on_response["try_on_images"]
    count = len(images)
    parts = [
        f"\n\nWe have created {count} try-on {_plural(count, 'image')}"
        f"{':' if count > 0 else '.'}\n\n"
    ]
    for image in images:
        url = image["url"]
        file_path = f"{try_on_images_folder}/{image['filename']}"
        # Keep a local copy of the generated image in the profile's folder.
        download = make_request_with_retry(url)
        with open(file_path, "wb") as f:
            f.write(download.content)
        title = f"{profile['name']} wearing "
        # Present only when two try-ons were chained into one image.
        extra_garment_id = image.get("extra_garment_id")
        if extra_garment_id:
            title += f"{llmr.get_garment(extra_garment_id)['name']} and "
        title += f"{llmr.get_garment(image['garment_id'])['name']}"
        parts.append(f"1. [{title}]({url})\n")
    parts.append("\n")
    return "".join(parts)


def predict(message, history):
    """Answer one chat message; the handler behind ``gr.ChatInterface``.

    ``llmr.invoke`` classifies the message. The reply starts with an
    intent-specific sentence. When the response lists products, a Markdown
    list of them follows, and for the "try-on" intent the generated try-on
    images are appended as well.

    Args:
        message: The user's chat message.
        history: Conversation history supplied by Gradio; not used here.

    Yields:
        The complete Markdown reply (yielded once).
    """
    print("predict:", message)
    response = llmr.invoke(message)
    print(response)
    reply = _intro_message(response)
    if "products" in response and response["products"]:
        reply += "\n\n" + _format_product_list(response["products"])
        if response["intent"] == "try-on":
            reply += _render_try_ons(response["products"])
    yield reply
# FastAPI host application. At the end of the module /static and the Gradio
# UI are mounted on it, and `app` is rebound to gr.mount_gradio_app's result.
app = FastAPI()
# Page/tab title and the "Powered by" credit shown in the chat description.
title = "Style Master"
href = "https://platform.openai.com/docs/models"
model = os.getenv("OPENAI_MODEL_NAME")
# Optional file of example chat prompts, one per line (QUESTIONS_FILE_PATH).
questions_file_path = os.environ.get("QUESTIONS_FILE_PATH")
if questions_file_path:
    with open(questions_file_path, "r", encoding="utf-8") as file:
        # Skip blank lines so they don't turn into empty example prompts.
        examples = [line.strip() for line in file if line.strip()]
else:
    # Previously open(None) crashed the app at import time. Fall back to
    # None, which gr.ChatInterface treats as "no examples".
    examples = None
description = f"""\ | |
<div align="left"> | |
<p> Powered by: <a href="{href}">{model}</a></p> | |
</div> | |
""" | |
css = """ | |
.container { | |
height: 100vh; | |
} | |
#image_upload{min-height:400px} | |
#image_upload [data-testid="image"], #image_upload [data-testid="image"] > div{min-height: 400px} | |
#image_upload .touch-none{display: flex} | |
""" | |
def get_all_garments():
    """Return the whole catalogue as ``(image, caption)`` pairs for a gallery.

    Captions use the same ``#<id> <name>: $<price>`` format as the chat's
    product lists.
    """
    print("get_all_garments")
    gallery_items = []
    for garment in llmr.get_garments():
        image = garment["image"]
        caption = f"#{garment['id']} {garment['name']}: ${garment['price']}"
        gallery_items.append((image, caption))
    return gallery_items
def upload_model_image(model_path=None, model_image=None):
    """Log in with a model picture and return the PIL image to display.

    Args:
        model_path: Path of an uploaded image file (from the upload button).
        model_image: In-memory image from the preset examples, used when
            ``model_path`` is falsy. Presumably a numpy array -- confirm.

    Side effects:
        Rebinds the module-level ``profile`` and ``try_on_images_folder``
        with the values returned by ``login``.

    Returns:
        The uploaded file opened with PIL. When no path was given, returns
        instead the stock ``models/<gender>.png`` under ``static_dir``, chosen
        by the ``"gender"`` that ``login`` stored on the profile.
    """
    global profile, try_on_images_folder
    print("upload_model_image:", model_path or type(model_image))
    profile, try_on_images_folder = login(model_path or model_image)
    print(profile)
    display_path = model_path
    if display_path is None:
        display_path = f"{static_dir}/models/{profile['gender']}.png"
    print("refresh model:", display_path)
    return Image.open(display_path)
def visible_component():
    """Return two ``gr.update(visible=True)`` updates, one per hidden tab."""
    return tuple(gr.update(visible=True) for _ in range(2))
def use_preset_model_image(model_image):
    """Handle a preset-model click: log in with it and reveal the tabs.

    Returns the image to display followed by two visibility updates, matching
    the ``outputs=[model_image, main_tab, catalogue_tab]`` wiring.
    """
    return (upload_model_image(model_image=model_image), *visible_component())
# UI layout. The chat and catalogue tabs start hidden (visible=False) and are
# revealed by visible_component once a model image has been chosen on the
# "Change Model" tab. They stay hidden until then because predict and the
# clear button rely on the module-level `profile` that login step sets.
with gr.Blocks(title=title, css=css) as demo:
    with gr.Tab(f"Chat with {title}", visible=False) as main_tab:
        with gr.Column(elem_classes=["container"]):
            # Setting up the Gradio chat interface.
            chat_ui = gr.ChatInterface(
                predict,
                # title=title,
                description=description,
                examples=examples,
                cache_examples=False,
            )
            # Clearing the chat re-runs llmr.login for the current profile.
            # NOTE(review): login() calls llmr.login(name, encoded_image) with
            # two arguments; confirm the one-argument form here is intended.
            chat_ui.clear_btn.click(lambda: llmr.login(profile["name"]))
    with gr.Tab("View Catalogue", visible=False) as catalogue_tab:
        gallery = gr.Gallery(
            get_all_garments,  # callable value: Gradio calls it for the contents
            label="Generated images",
            show_label=False,
            elem_id="gallery",
            columns=[5],
            rows=[4],
            object_fit="contain",
            height=800,
        )
    with gr.Tab("Change Model"):
        with gr.Column():
            model_image = gr.Image(elem_id="image_upload", width=600)
            upload_button = gr.UploadButton(
                "Upload Image", file_types=["image"], file_count="single"
            )
            # Upload: log in with the file and show it, then reveal both tabs.
            upload_button.upload(upload_model_image, upload_button, model_image).then(
                fn=visible_component, inputs=None, outputs=[main_tab, catalogue_tab]
            )
            # Preset models. With run_on_click=True, clicking one immediately
            # runs use_preset_model_image, which logs in and reveals the tabs.
            example = gr.Examples(
                label="Preset Models",
                inputs=model_image,
                outputs=[model_image, main_tab, catalogue_tab],
                fn=use_preset_model_image,
                run_on_click=True,
                examples_per_page=2,
                examples=[
                    os.path.join(static_dir, "models/female.png"),
                    os.path.join(static_dir, "models/male.png"),
                ],
            )
# Enable Gradio's request queue, and serve static_dir under /static (the
# preset model images and downloaded try-on images live there). Finally mount
# the Gradio UI at the FastAPI root, rebinding `app` to the combined app.
demo.queue()
app.mount(
    "/static", StaticFiles(directory=static_dir, follow_symlink=True), name="static"
)
app = gr.mount_gradio_app(app, demo, path="/")