# becas / main.py
# Snapshot of commit 28b9e42 by TDN-M ("Update main.py", verified), 4.25 kB.
import hashlib
import json
import os
import time
import uuid
from io import BytesIO
from pathlib import Path
from typing import Optional

import requests
from fastapi import FastAPI, File, UploadFile, HTTPException, Query
from fastapi.responses import JSONResponse, FileResponse
from PIL import Image
from pydantic import BaseModel
# Create the FastAPI application.
app = FastAPI(title="CaslaQuartz Image Generation API")

# Directory where uploaded inputs and generated images are written.
# Overridable via the SAVE_DIR environment variable; defaults to the original path.
SAVE_DIR = os.environ.get("SAVE_DIR", "/generated_images")
# Create it once at import time so the endpoints can write into it;
# exist_ok makes restarts harmless (the original called mkdir twice).
Path(SAVE_DIR).mkdir(parents=True, exist_ok=True)
# Pydantic request models.
class GenerateRequest(BaseModel):
    """JSON body accepted by ``POST /text2img``."""

    # Free-text prompt describing the desired image.
    prompt: str
    # Either "WIDTHxHEIGHT" (e.g. "1280x720") or the literal "Custom size".
    size_choice: str
    # "WIDTHxHEIGHT", used only when size_choice == "Custom size".
    # Optional[...] so an explicit JSON null is accepted: a bare `str`
    # annotation with a None default rejects null under Pydantic v2.
    custom_size: Optional[str] = None
    # Product codes woven into the prompt and passed to the generator.
    product_codes: list[str]
class Img2ImgRequest(BaseModel):
    """Image-to-image request parameters.

    NOTE(review): the /img2img endpoint in this file takes these values as
    query parameters and does not reference this model; confirm whether it
    is still needed.
    """

    # Position to apply the texture to.
    position: str
    # Either "WIDTHxHEIGHT" (e.g. "1280x720") or the literal "Custom size".
    size_choice: str
    # "WIDTHxHEIGHT", used only when size_choice == "Custom size".
    # Optional[...] so an explicit null validates under Pydantic v2.
    custom_size: Optional[str] = None
    # Selected product codes.
    product_codes: list[str]
# Endpoint: Text2Img
@app.post("/text2img", summary="Generate image from text prompt")
async def text2img(request: GenerateRequest):
    """Generate an image from a text prompt and return it as a PNG file.

    The size comes from ``request.size_choice`` ("WIDTHxHEIGHT"), or from
    ``request.custom_size`` when ``size_choice == "Custom size"``.

    Raises:
        HTTPException: 400 if the size is missing, malformed or not positive;
            500 if prompt rewriting, generation or saving fails.
    """
    # NOTE(review): if rewrite_prompt_with_groq / txt2img do blocking network
    # I/O, calling them from this async endpoint blocks the event loop.
    try:
        # Validate size.
        if request.size_choice == "Custom size":
            if not request.custom_size:
                raise HTTPException(status_code=400, detail="Custom size is required.")
            size_text = request.custom_size
        else:
            size_text = request.size_choice
        try:
            width, height = map(int, size_text.split("x"))
        except ValueError:
            # Covers non-integers as well as too few / too many "x" parts.
            raise HTTPException(
                status_code=400,
                detail=f"Invalid size {size_text!r}; expected WIDTHxHEIGHT, e.g. 1280x720.",
            ) from None
        if width <= 0 or height <= 0:
            raise HTTPException(status_code=400, detail="Width and height must be positive.")

        rewritten_prompt = rewrite_prompt_with_groq(request.prompt, request.product_codes)
        result = txt2img(rewritten_prompt, width, height, request.product_codes)
        # txt2img reports failure by returning an error-message string.
        if isinstance(result, str):
            raise HTTPException(status_code=500, detail=result)
        if result is None:
            raise HTTPException(status_code=500, detail="Failed to generate image.")

        # Key the file on every input that affects the image, so requests that
        # differ only in size or products do not overwrite each other.
        key = json.dumps([request.prompt, width, height, request.product_codes])
        save_path = Path(SAVE_DIR) / f"{hashlib.md5(key.encode()).hexdigest()}.png"
        # Write to a unique temp file and atomically swap it in. FileResponse
        # streams lazily, so an in-place overwrite by a concurrent identical
        # request could otherwise truncate a file that is being sent.
        tmp_path = save_path.with_name(f"{save_path.stem}.{uuid.uuid4().hex}.tmp")
        try:
            result.save(tmp_path, format="PNG")
            os.replace(tmp_path, save_path)
        finally:
            tmp_path.unlink(missing_ok=True)
        return FileResponse(save_path, media_type="image/png", filename="generated_image.png")
    except HTTPException:
        # Preserve deliberate status codes (e.g. 400) instead of turning them into 500.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
# Endpoint: Img2Img
@app.post("/img2img", summary="Generate image from input image")
async def img2img(
    image: UploadFile = File(...),
    position: str = Query(..., description="Position to apply texture"),
    size_choice: str = Query(..., description="Image size"),
    custom_size: Optional[str] = Query(None, description="Custom size (e.g., 1280x720)"),
    product_codes: list[str] = Query(..., description="Selected product codes"),
):
    """Apply a product texture to an uploaded image and return the result.

    Only ``product_codes[0]`` is used. The size parameters are validated but
    not forwarded to ``generate_mask``.

    Raises:
        HTTPException: 400 for a missing or malformed size or no product
            codes; 500 if uploading or generating the image fails.
    """
    try:
        # Validate size (the parsed values are currently only used for validation).
        if size_choice == "Custom size":
            if not custom_size:
                raise HTTPException(status_code=400, detail="Custom size is required.")
            size_text = custom_size
        else:
            size_text = size_choice
        try:
            width, height = map(int, size_text.split("x"))
        except ValueError:
            raise HTTPException(
                status_code=400,
                detail=f"Invalid size {size_text!r}; expected WIDTHxHEIGHT, e.g. 1280x720.",
            ) from None
        if width <= 0 or height <= 0:
            raise HTTPException(status_code=400, detail="Width and height must be positive.")
        if not product_codes:
            raise HTTPException(status_code=400, detail="At least one product code is required.")

        # Unique name: a whole-second timestamp alone collides across
        # concurrent uploads, letting one request overwrite another's input.
        image_path = Path(SAVE_DIR) / f"input_{int(time.time())}_{uuid.uuid4().hex}.jpg"
        try:
            image_path.write_bytes(await image.read())
            image_resource_id = upload_image_to_tensorart(str(image_path))
        finally:
            # The local copy is only needed for the upload; remove it so
            # inputs do not accumulate on disk.
            image_path.unlink(missing_ok=True)
        if not image_resource_id:
            raise HTTPException(status_code=500, detail="Failed to upload input image.")

        # Generate mask and apply texture.
        output_path = generate_mask(image_resource_id, position, product_codes[0])
        if not output_path:
            raise HTTPException(status_code=500, detail="Failed to generate image.")
        return FileResponse(output_path, media_type="image/jpeg", filename="output_image.jpg")
    except HTTPException:
        # Preserve deliberate status codes instead of re-wrapping them as 500.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
# Helper functions (kept from the original code).
def rewrite_prompt_with_groq(vietnamese_prompt, product_codes):
    """Append the selected product codes to the user's prompt.

    Despite its name, this function makes no Groq call and does no
    translation; it only formats a suffix, e.g. ``"kitchen"`` with
    ``["A", "B"]`` becomes ``"kitchen, featuring A and B quartz marble"``.

    Args:
        vietnamese_prompt: The user's prompt text.
        product_codes: Product codes to mention; may be empty or None.

    Returns:
        The prompt followed by ``", featuring <codes> quartz marble"``.
    """
    codes = " and ".join(product_codes or [])
    # Without codes, avoid the double space in "featuring  quartz marble".
    if not codes:
        return f"{vietnamese_prompt}, featuring quartz marble"
    return f"{vietnamese_prompt}, featuring {codes} quartz marble"
def txt2img(prompt, width, height, product_codes):
    """Generate an image for *prompt* at ``width`` x ``height``.

    The /text2img endpoint treats a ``str`` return value as an error message
    and otherwise calls ``.save(path)`` on the result, so a real
    implementation should return an error string or an object with
    ``.save()`` (presumably a PIL Image — confirm).

    Raises:
        NotImplementedError: always; the generation logic is not present
            in this file.
    """
    # The placeholder body used to be `pass`, which silently returned None and
    # made the endpoint fail later with an opaque AttributeError.
    raise NotImplementedError("txt2img: image generation backend is not implemented.")
def upload_image_to_tensorart(image_path):
    """Upload the image at *image_path* (a ``str`` path) and return its resource id.

    The /img2img endpoint treats a falsy return value as an upload failure.

    Raises:
        NotImplementedError: always; the TensorArt upload logic is not
            present in this file.
    """
    # Previously `pass` (implicit None), which the caller reported only as a
    # generic "Failed to upload input image." error.
    raise NotImplementedError("upload_image_to_tensorart: TensorArt upload is not implemented.")
def generate_mask(image_resource_id, position, selected_product_code):
    """Apply *selected_product_code*'s texture at *position* on an uploaded image.

    The /img2img endpoint serves the returned value as a file path with
    media type image/jpeg and treats a falsy return value as a failure.

    Raises:
        NotImplementedError: always; the mask/texture logic is not present
            in this file.
    """
    # Previously `pass` (implicit None), which surfaced only as a generic
    # "Failed to generate image." error.
    raise NotImplementedError("generate_mask: mask generation is not implemented.")