"""FastAPI service that turns an uploaded image into a 3D mesh with TripoSR.

The ``/process_image/`` endpoint optionally removes the image background,
runs the TripoSR model, exports OBJ and GLB meshes, uploads them together with
the preprocessed image to the ``framebucket3d`` S3 bucket and returns their
S3 URLs.
"""

from fastapi import FastAPI, File, UploadFile, Form
from fastapi.responses import StreamingResponse
from pydantic import BaseModel
from pydantic import Field
from typing import Optional
import logging
import os
import boto3
import json
import secrets
import shlex
import subprocess
import sys
import tempfile
import time
import base64
import gradio as gr
import numpy as np
import rembg
import spaces
import torch
from PIL import Image
from functools import partial
import io
from io import BytesIO
from boto3.exceptions import S3UploadFailedError
from botocore.exceptions import (
    BotoCoreError,
    ClientError,
    NoCredentialsError,
    PartialCredentialsError,
)
import datetime

logger = logging.getLogger(__name__)

app = FastAPI()

# Install the bundled torchmcubes wheel before importing `tsr` below (which
# presumably depends on it). Invoke pip through the running interpreter so the
# package lands in the same environment that will import it.
subprocess.run(
    [
        sys.executable,
        "-m",
        "pip",
        "install",
        "wheel/torchmcubes-0.1.0-cp310-cp310-linux_x86_64.whl",
    ]
)

from tsr.system import TSR
from tsr.utils import remove_background, resize_foreground, to_gradio_3d_orientation

device = "cuda:0" if torch.cuda.is_available() else "cpu"

model = TSR.from_pretrained(
    "stabilityai/TripoSR",
    config_name="config.yaml",
    weight_name="model.ckpt",
)
model.renderer.set_chunk_size(131072)
model.to(device)

rembg_session = rembg.new_session()

ACCESS = os.getenv("ACCESS")
SECRET = os.getenv("SECRET")

# Destination bucket for all generated artifacts and the URL prefix returned
# to clients.
S3_BUCKET = "framebucket3d"
S3_BASE_URL = f"https://{S3_BUCKET}.s3.amazonaws.com"

bedrock = boto3.client(
    service_name="bedrock",
    aws_access_key_id=ACCESS,
    aws_secret_access_key=SECRET,
    region_name="us-east-1",
)
bedrock_runtime = boto3.client(
    service_name="bedrock-runtime",
    aws_access_key_id=ACCESS,
    aws_secret_access_key=SECRET,
    region_name="us-east-1",
)
s3_client = boto3.client(
    "s3",
    aws_access_key_id=ACCESS,
    aws_secret_access_key=SECRET,
    region_name="us-east-1",
)


def _cuda_sync(empty_cache=False):
    """Wait for queued CUDA work (and optionally free cached memory).

    No-op when CUDA is unavailable: ``torch.cuda.synchronize()`` raises on
    CPU-only hosts, which would defeat the CPU fallback chosen for ``device``.
    """
    if torch.cuda.is_available():
        torch.cuda.synchronize()
        if empty_cache:
            torch.cuda.empty_cache()


def _new_temp_path(suffix):
    """Create an empty temp file ending in ``suffix`` and return its path.

    The handle is closed immediately (so nothing leaks and writers can reopen
    the path on any OS); the file is NOT auto-deleted — callers must remove it.
    """
    with tempfile.NamedTemporaryFile(suffix=suffix, delete=False) as handle:
        return handle.name


def _remove_files(paths):
    """Best-effort deletion of temp files; failures are logged, not raised."""
    for path in paths:
        try:
            os.remove(path)
        except FileNotFoundError:
            pass
        except OSError:
            logger.warning("Could not remove temp file %s", path, exc_info=True)


def _is_authorized(auth):
    """Return True if ``auth`` matches the ``AUTHORIZE`` environment variable.

    Uses a constant-time comparison. When ``AUTHORIZE`` is unset or empty every
    request is rejected (an empty secret would otherwise accept ``auth=""``).
    """
    expected = os.getenv("AUTHORIZE")
    if not expected:
        return False
    # compare_digest only accepts ASCII str, so compare UTF-8 bytes instead.
    return secrets.compare_digest(auth.encode("utf-8"), expected.encode("utf-8"))


def upload_file_to_s3(file_path, bucket_name, object_name=None):
    """Upload a local file to S3.

    Args:
        file_path: Path of the local file to upload.
        bucket_name: Destination bucket.
        object_name: Destination key; defaults to the file's base name.

    Returns:
        True on success, False if boto3/botocore reported an upload or
        credential error (the error is logged).
    """
    if object_name is None:
        object_name = os.path.basename(file_path)
    try:
        s3_client.upload_file(file_path, bucket_name, object_name)
    except (BotoCoreError, ClientError, S3UploadFailedError):
        logger.exception(
            "Failed to upload %s to s3://%s/%s", file_path, bucket_name, object_name
        )
        return False
    return True


def check_input_image(input_image):
    """Raise ``gr.Error`` when no image was provided."""
    if input_image is None:
        raise gr.Error("No image uploaded!")


def _has_alpha(image):
    """Return True if the PIL image carries transparency information."""
    if image.mode in ("RGBA", "LA", "PA"):
        return True
    return image.mode == "P" and "transparency" in image.info


def preprocess(input_image, do_remove_background, foreground_ratio):
    """Prepare a PIL image for TripoSR.

    Args:
        input_image: Uploaded ``PIL.Image`` in any mode.
        do_remove_background: If True, strip the background with rembg,
            rescale the foreground with ``foreground_ratio`` and composite the
            result onto grey.
        foreground_ratio: Passed to ``resize_foreground``; only used when
            ``do_remove_background`` is True.

    Returns:
        An RGB ``PIL.Image``. Transparent pixels are composited onto 50% grey;
        non-RGB modes without alpha are converted to RGB.
    """

    def fill_background(image):
        # Alpha-composite onto mid-grey: rgb * a + 0.5 * (1 - a).
        # Expects a 4-channel (RGBA) image.
        image = np.array(image).astype(np.float32) / 255.0
        image = image[:, :, :3] * image[:, :, 3:4] + (1 - image[:, :, 3:4]) * 0.5
        return Image.fromarray((image * 255.0).astype(np.uint8))

    if do_remove_background:
        image = input_image.convert("RGB")
        image = remove_background(image, rembg_session)
        image = resize_foreground(image, foreground_ratio)
        image = fill_background(image)
    elif _has_alpha(input_image):
        image = fill_background(input_image.convert("RGBA"))
    elif input_image.mode != "RGB":
        # e.g. "L", "CMYK" or palette images without transparency.
        image = input_image.convert("RGB")
    else:
        image = input_image

    _cuda_sync(empty_cache=True)
    return image


def generate(image, mc_resolution, formats=("obj", "glb")):
    """Run TripoSR on ``image`` and export the mesh as OBJ and GLB temp files.

    Args:
        image: Preprocessed RGB ``PIL.Image``.
        mc_resolution: Resolution passed to ``model.extract_mesh``.
        formats: Accepted for backward compatibility and currently ignored;
            both OBJ and GLB are always written.

    Returns:
        ``(obj_path, glb_path)``. Both are temp files the caller must delete.
    """
    _cuda_sync()
    scene_codes = model(image, device=device)
    mesh = model.extract_mesh(scene_codes, resolution=mc_resolution)[0]
    _cuda_sync()
    mesh = to_gradio_3d_orientation(mesh)

    created = []
    try:
        mesh_path_glb = _new_temp_path(".glb")
        created.append(mesh_path_glb)
        mesh.export(mesh_path_glb)

        mesh_path_obj = _new_temp_path(".obj")
        created.append(mesh_path_obj)
        # The OBJ is written after negating X (scale -1 on the first axis),
        # so it is mirrored relative to the GLB exported above.
        mesh.apply_scale([-1, 1, 1])
        mesh.export(mesh_path_obj)
    except BaseException:
        _remove_files(created)
        raise

    _cuda_sync(empty_cache=True)
    return mesh_path_obj, mesh_path_glb


@app.post("/process_image/")
async def process_image(
    file: UploadFile = File(...),
    seed: int = Form(...),  # required; currently unused
    enhance_image: bool = Form(...),  # required; currently unused
    do_remove_background: bool = Form(...),
    foreground_ratio: float = Form(...),  # expected in (0.0, 1.0); not validated here
    mc_resolution: int = Form(...),  # expected in [256, 4096]; not validated here
    auth: str = Form(...),
    text_prompt: Optional[str] = Form(None),  # currently unused
):
    """Convert an uploaded image to 3D meshes and upload the results to S3.

    Returns:
        ``{"img_path", "obj_path", "glb_path"}`` S3 URLs on success,
        ``{"Internal Server Error": False}`` if any upload fails, or
        ``{"Authentication": "Failed"}`` when ``auth`` does not match.

    NOTE(review): inference runs synchronously inside this coroutine, which
    blocks the event loop for the whole request (and thereby serializes access
    to the shared model). Confirm this is intended before offloading it.
    """
    if not _is_authorized(auth):
        return {"Authentication": "Failed"}

    image_bytes = await file.read()
    image_pil = Image.open(BytesIO(image_bytes))
    preprocessed = preprocess(image_pil, do_remove_background, foreground_ratio)
    mesh_name_obj, mesh_name_glb = generate(preprocessed, mc_resolution)
    preprocessed_path = _new_temp_path(".png")

    try:
        preprocessed.save(preprocessed_path)

        timestamp = datetime.datetime.now().strftime("%Y%m%d%H%M%S%f")
        object_name = f"object_{timestamp}.obj"
        object_name_2 = f"object_{timestamp}.glb"
        object_name_3 = f"object_{timestamp}.png"

        uploads = (
            (preprocessed_path, object_name_3),
            (mesh_name_obj, object_name),
            (mesh_name_glb, object_name_2),
        )
        if not all(upload_file_to_s3(path, S3_BUCKET, key) for path, key in uploads):
            return {"Internal Server Error": False}

        return {
            "img_path": f"{S3_BASE_URL}/{object_name_3}",
            "obj_path": f"{S3_BASE_URL}/{object_name}",
            "glb_path": f"{S3_BASE_URL}/{object_name_2}",
        }
    finally:
        # The local copies are no longer needed once uploaded (or failed).
        _remove_files((preprocessed_path, mesh_name_obj, mesh_name_glb))


if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=7860)