# becas / main.py — uploaded by TDN-M; commit 0d7c1bb ("Update main.py", verified), 15.5 kB.
import hashlib
import json
import os
import time
import uuid
from io import BytesIO
from pathlib import Path

import gradio as gr
import requests
from dotenv import load_dotenv
from groq import Groq
from PIL import Image
# Load environment variables (load_dotenv reads a local .env file into os.environ).
load_dotenv()
api_key_token = os.environ.get('api_key_token', '')
groq_api_key = os.environ.get('groq_api_key', '')
# Both credentials are mandatory: fail at import time rather than on the first request.
if not (api_key_token and groq_api_key):
    raise ValueError("Please configure your API keys in .env")
# Initialize the Groq client (workaround for a `proxies` keyword error):
# Groq.__init__ is wrapped globally so that any `proxies` kwarg is dropped
# before the original constructor runs.
# NOTE(review): no visible code passes `proxies` (the call below only passes
# api_key), and `client` is not used by any visible function — confirm both
# the patch and the client are still needed.
original_init = Groq.__init__


def patched_init(self, *args, **kwargs):
    """Invoke the original ``Groq.__init__`` without a ``proxies`` keyword."""
    forwarded = {name: value for name, value in kwargs.items() if name != "proxies"}
    original_init(self, *args, **forwarded)


Groq.__init__ = patched_init
client = Groq(api_key=groq_api_key)
# Basic configuration: TensorArt API base URL and the local directory where
# generated/input images are written (created at import time).
url_pre = "https://ap-east-1.tensorart.cloud/v1"
SAVE_DIR = "generated_images"
os.makedirs(SAVE_DIR, exist_ok=True)
# Product catalog (copied from the original code): tier name -> {display name: id}.
# The display name starts with the short product code (e.g. "C1012").
# NOTE(review): the id values are numeric strings whose meaning is not shown in
# this file (presumably TensorArt model/resource ids — confirm); PRODUCT_GROUPS
# itself is not read by any function visible here.
# The value "is coming" is a placeholder for products without an id yet.
PRODUCT_GROUPS = {
    "Standard": {
        "C1012 Glacier White": "817687427545199895",
        "C1026 Polar": "819910519797326073",
        "C3269 Ash Grey": "821839484099264081",
        "C3168 Silver Wave": "821849044696643212",
        "C1005 Milky White": "821948258441171133",
    },
    "Deluxe": {
        "C2103 Onyx Carrara": "827090618489513527",
        "C2104 Massa": "822075428127644644",
        "C3105 Casla Cloudy": "828912225788997963",
        "C3146 Casla Nova": "828013009961087650",
        "C2240 Marquin": "828085015087780649",
        "C2262 Concrete (Honed)": "822211862058871636",
        # NOTE(review): same id as "C4143 Mario" below — likely a copy-paste error; verify.
        "C3311 Calacatta Sky": "829984593223502930",
        "C3346 Massimo": "827938741386607132",
    },
    "Luxury": {
        # NOTE(review): same id as "C3311 Calacatta Sky" above; verify.
        "C4143 Mario": "829984593223502930",
        "C4145 Marina": "828132560375742058",
        "C4202 Calacatta Gold": "828167757632695310",
        "C1205 Casla Everest": "828296778450463190",
        "C4211 Calacatta Supreme": "828436321937882328",
        "C4204 Calacatta Classic": "828422973179466146",
        "C5240 Spring": "is coming",
        "C1102 Super White": "828545723344775887",
        "C4246 Casla Mystery": "828544778451950698",
        "C4345 Oro": "828891068780182635",
        "C4346 Luxe": "829436426547535131",
        "C4342 Casla Eternal": "829190256201829181",
        "C4221 Athena": "829644354504131520",
        "C4222 Lagoon": "is coming",
        "C5225 Amber": "is coming",
    },
    "Super Luxury": {
        "C4255 Calacatta Extra": "829659013227537217",
    },
}
# Display name -> local texture image used by generate_mask.
# Every catalog entry with a real id (i.e. not the "is coming" placeholder)
# maps to product_images/<short code>.jpg, where the short code is the first
# word of the display name. Entries keep PRODUCT_GROUPS order.
PRODUCT_IMAGE_MAP = {
    display_name: f"product_images/{display_name.split()[0]}.jpg"
    for tier in PRODUCT_GROUPS.values()
    for display_name, product_id in tier.items()
    if product_id != "is coming"
}
# Processing helpers carried over from the original code.
def rewrite_prompt_with_groq(vietnamese_prompt, product_codes):
    """Append a product clause to the user's prompt.

    Despite its name, this does not call Groq: the prompt is returned as-is
    followed by ``", featuring <codes> quartz marble"``, where the codes are
    joined with ``" and "``.

    Args:
        vietnamese_prompt: The user's free-text prompt.
        product_codes: Iterable of short product codes (strings).

    Returns:
        The extended prompt string.
    """
    codes_clause = " and ".join(product_codes)
    return "{}, featuring {} quartz marble".format(vietnamese_prompt, codes_clause)
def upload_image_to_tensorart(image_path):
    """Upload a local image to TensorArt and return its resource id.

    Two steps: POST ``{url_pre}/resource/image`` to obtain a pre-signed
    ``putUrl`` plus a ``resourceId``, then PUT the raw file bytes to ``putUrl``.

    Args:
        image_path: Path of the local image file to upload.

    Returns:
        The ``resourceId`` on success, or ``None`` when the file does not exist,
        the API response lacks ``putUrl``/``resourceId``, or any step fails.
        Errors are printed, never raised.
    """
    try:
        # Cheap local check first, before any network traffic.
        if not os.path.exists(image_path):
            return None
        url = f"{url_pre}/resource/image"
        payload = json.dumps({"expireSec": "7200"})
        headers = {
            'Content-Type': 'application/json',
            'Accept': 'application/json',
            'Authorization': f'Bearer {api_key_token}'
        }
        response = requests.post(url, headers=headers, data=payload, timeout=30)
        response.raise_for_status()
        resource_response = response.json()
        put_url = resource_response.get('putUrl')
        resource_id = resource_response.get('resourceId')
        # Without both values the upload could not be referenced later, so
        # skip the PUT entirely instead of uploading and then discarding it.
        if not put_url or not resource_id:
            return None
        headers_put = resource_response.get('headers', {'Content-Type': 'image/jpeg'})
        with open(image_path, 'rb') as img_file:
            # Bounded: without a timeout a stalled PUT would block forever.
            upload_response = requests.put(put_url, data=img_file, headers=headers_put, timeout=60)
        if upload_response.status_code not in (200, 203):
            raise Exception(f"PUT failed with status {upload_response.status_code}")
        time.sleep(10)  # Wait for the resource to sync on the TensorArt side.
        return resource_id
    except Exception as e:
        print(f"Upload error: {str(e)}")
        return None
def txt2img(prompt, width, height, product_codes):
    """Generate one image from ``prompt`` with a TensorArt text-to-image job.

    Submits a job to ``{url_pre}/jobs`` and polls it every 10 s for up to
    300 s. On success the first result image is downloaded and saved as PNG
    under ``SAVE_DIR``, named by the md5 of the prompt (so an identical prompt
    overwrites the previous file).

    Args:
        prompt: Text prompt for the diffusion stage.
        width: Output width in pixels.
        height: Output height in pixels.
        product_codes: Not used here; kept for interface compatibility.

    Returns:
        A ``pathlib.Path`` to the saved image on success, otherwise a string
        starting with ``"Error:"`` (callers distinguish the two by type).
    """
    model_id = "779398605850080514"
    vae_id = "ae.sft"
    txt2img_data = {
        # Random per call: an md5 of the current second collided for requests
        # submitted within the same second. uuid4().hex is also 32 hex chars.
        "request_id": uuid.uuid4().hex,
        "stages": [
            {"type": "INPUT_INITIALIZE", "inputInitialize": {"seed": -1, "count": 1}},
            {
                "type": "DIFFUSION",
                "diffusion": {
                    "width": width,
                    "height": height,
                    "prompts": [{"text": prompt}],
                    "negativePrompts": [{"text": " "}],
                    "sdModel": model_id,
                    "sdVae": vae_id,
                    "sampler": "Euler a",
                    "steps": 30,
                    "cfgScale": 8,
                    "clipSkip": 1,
                    "etaNoiseSeedDelta": 31337,
                }
            }
        ]
    }
    headers = {
        'Content-Type': 'application/json',
        'Accept': 'application/json',
        'Authorization': f'Bearer {api_key_token}'
    }
    # Every request is bounded so the 300 s job timeout below can actually hold.
    response = requests.post(f"{url_pre}/jobs", json=txt2img_data, headers=headers, timeout=30)
    if response.status_code != 200:
        return f"Error: {response.status_code} - {response.text}"
    job_id = response.json()['job']['id']
    start_time = time.time()
    timeout = 300
    while True:
        time.sleep(10)
        if time.time() - start_time > timeout:
            return f"Error: Job timed out after {timeout} seconds."
        response = requests.get(f"{url_pre}/jobs/{job_id}", headers=headers, timeout=30)
        if response.status_code != 200:
            return f"Error: {response.status_code} - {response.text}"
        job = response.json()['job']
        job_status = job['status']
        if job_status == 'SUCCESS':
            image_url = job['successInfo']['images'][0]['url']
            response_image = requests.get(image_url, timeout=60)
            if response_image.status_code != 200:
                return f"Error: image download failed with status {response_image.status_code}"
            img = Image.open(BytesIO(response_image.content))
            save_path = Path(SAVE_DIR) / f"{hashlib.md5(prompt.encode()).hexdigest()}.png"
            img.save(save_path)
            return save_path
        if job_status == 'FAILED':
            return "Error: Job failed."
        # NOTE(review): any other status keeps polling until the timeout —
        # confirm whether TensorArt has further terminal states to handle here.
def _normalize_position(position):
    """Reduce a position selection (string, list or set) to a single value.

    A list/set yields its first element (for a set that element is arbitrary);
    an empty list/set falls back to ``"default"``. Other values pass through.
    """
    if isinstance(position, (set, list)):
        # next(iter(...)) also works on sets; the former position[0] raised
        # TypeError for a set because sets are not subscriptable.
        return next(iter(position), "default")
    return position


def _build_workflow_params(image_resource_id, texture_resource_id, prompt_text):
    """Return the TensorArt workflow node graph (params copied from the TensorArt sample).

    Node "2" loads the base image and node "17" the product texture. Node "1"
    segments the base image with the text prompt from node "4", using the models
    loaded by node "3". Node "8" turns that mask into an image, node "10" makes
    the texture seamless, node "13" pastes the texture onto the base image
    through the mask, and node "7" previews the result.
    """
    return {
        "1": {
            "classType": "LayerMask: SegmentAnythingUltra V3",
            "inputs": {
                "black_point": 0.3,
                "detail_dilate": 6,
                "detail_erode": 65,
                "detail_method": "GuidedFilter",
                "device": "cuda",
                "image": ["2", 0],
                "max_megapixels": 2,
                "process_detail": True,
                "prompt": ["4", 0],
                "sam_models": ["3", 0],
                "threshold": 0.3,
                "white_point": 0.99
            },
            "properties": {"Node name for S&R": "LayerMask: SegmentAnythingUltra V3"}
        },
        "10": {
            "classType": "Image Seamless Texture",
            "inputs": {
                "blending": 0.37,
                "images": ["17", 0],
                "tiled": "true",
                "tiles": 2
            },
            "properties": {"Node name for S&R": "Image Seamless Texture"}
        },
        "13": {
            "classType": "Paste By Mask",
            "inputs": {
                "image_base": ["2", 0],
                "image_to_paste": ["10", 0],
                "mask": ["8", 0],
                "resize_behavior": "resize"
            },
            "properties": {"Node name for S&R": "Paste By Mask"}
        },
        "17": {
            "classType": "TensorArt_LoadImage",
            "inputs": {
                "_height": 768,
                "_width": 512,
                "image": texture_resource_id,
                "upload": "image"
            },
            "properties": {"Node name for S&R": "TensorArt_LoadImage"}
        },
        "2": {
            "classType": "TensorArt_LoadImage",
            "inputs": {
                "_height": 1024,
                "_width": 768,
                "image": image_resource_id,
                "upload": "image"
            },
            "properties": {"Node name for S&R": "TensorArt_LoadImage"}
        },
        "3": {
            "classType": "LayerMask: LoadSegmentAnythingModels",
            "inputs": {
                "grounding_dino_model": "GroundingDINO_SwinB (938MB)",
                "sam_model": "sam_vit_h (2.56GB)"
            },
            "properties": {"Node name for S&R": "LayerMask: LoadSegmentAnythingModels"}
        },
        "4": {
            "classType": "TensorArt_PromptText",
            "inputs": {"Text": prompt_text},
            "properties": {"Node name for S&R": "TensorArt_PromptText"}
        },
        "7": {
            "classType": "PreviewImage",
            "inputs": {"images": ["13", 0]},
            "properties": {"Node name for S&R": "PreviewImage"}
        },
        "8": {
            "classType": "MaskToImage",
            "inputs": {"mask": ["1", 1]},
            "properties": {"Node name for S&R": "MaskToImage"}
        }
    }


def generate_mask(image_resource_id, position, selected_product_code):
    """Paste the selected product texture onto the region of an uploaded image.

    The region is found by text-prompted segmentation using ``position``
    (lower-cased). The product texture comes from ``PRODUCT_IMAGE_MAP``, is
    uploaded to TensorArt, and the workflow is run via ``run_workflow``.

    Args:
        image_resource_id: TensorArt resource id of the already-uploaded base image.
        position: Region description string, or a list/set whose first element is used.
        selected_product_code: Product display name, a key of ``PRODUCT_IMAGE_MAP``.

    Returns:
        Path (str) of the generated image, or ``None`` on any failure; errors
        are printed rather than raised.
    """
    try:
        if not image_resource_id:
            raise Exception("Không có image_resource_id hợp lệ - ảnh gốc chưa được upload")
        print(f"Using image_resource_id: {image_resource_id}")
        short_code = selected_product_code.split()[0]
        texture_filepath = PRODUCT_IMAGE_MAP.get(selected_product_code)
        # Guard against None before touching the filesystem: os.path.exists(None)
        # raises TypeError, which used to hide the "not found" error below.
        texture_exists = bool(texture_filepath) and os.path.exists(texture_filepath)
        print(f"Texture file: {texture_filepath}, exists: {texture_exists}")
        if not texture_exists:
            raise Exception(f"Không tìm thấy ảnh sản phẩm cho mã {short_code}")
        # Validation happens first so a bad product code fails without sleeping.
        time.sleep(10)  # Wait for the uploaded base image to sync on TensorArt.
        texture_resource_id = upload_image_to_tensorart(texture_filepath)
        print(f"Texture resource_id: {texture_resource_id}")
        if not texture_resource_id:
            raise Exception(f"Không thể upload ảnh sản phẩm {short_code}")
        time.sleep(10)  # Wait for the texture resource to sync.
        position = _normalize_position(position)
        print(f"Position: {position}, type: {type(position)}")
        workflow_params = _build_workflow_params(
            image_resource_id, texture_resource_id, position.lower()
        )
        payload = {
            # Random suffix: a bare per-second timestamp collided for concurrent requests.
            "requestId": f"workflow_{int(time.time())}_{uuid.uuid4().hex[:8]}",
            "params": workflow_params,
            "runningNotifyUrl": ""
        }
        return run_workflow(payload, "full_workflow")
    except Exception as e:
        print(f"Mask generation error: {str(e)}")
        return None
def run_workflow(payload, step_name):
    """Submit a TensorArt workflow job and save its first output image as JPEG.

    The job is posted to ``{url_pre}/jobs/workflow`` and then polled every 5 s,
    at most 36 times (about 3 minutes).

    Args:
        payload: Workflow request body (``requestId``, ``params``, ...).
        step_name: Prefix for the output file name.

    Returns:
        The saved image path as ``str`` on success. ``None`` when submission or
        a poll request fails, the job reports ``FAILED``, or polling runs out.
    """
    headers = {
        'Content-Type': 'application/json',
        'Accept': 'application/json',
        'Authorization': f'Bearer {api_key_token}'
    }
    response = requests.post(f"{url_pre}/jobs/workflow", json=payload, headers=headers, timeout=300)
    if response.status_code != 200:
        return None
    job_id = response.json()['job']['id']
    max_attempts = 36
    for _ in range(max_attempts):
        response = requests.get(f"{url_pre}/jobs/{job_id}", headers=headers, timeout=30)
        if response.status_code != 200:
            # Error bodies carry no 'job' key; bail out like a failed submission.
            return None
        job = response.json()['job']
        status = job['status']
        if status == 'SUCCESS':
            image_url = job['successInfo']['images'][0]['url']
            image_response = requests.get(image_url, timeout=60)
            image = Image.open(BytesIO(image_response.content))
            # JPEG cannot store alpha/palette modes (RGBA, LA, P, ...); convert
            # everything except RGB/L to RGB so save() does not fail.
            if image.mode not in ('RGB', 'L'):
                image = image.convert('RGB')
            # Random suffix: concurrent jobs finishing in the same second used
            # to overwrite each other's output file.
            output_path = Path(SAVE_DIR) / f"{step_name}_{int(time.time())}_{uuid.uuid4().hex[:8]}.jpg"
            image.save(output_path)
            return str(output_path)
        if status == 'FAILED':
            # Terminal failure: stop instead of polling out the remaining attempts.
            return None
        time.sleep(5)
    return None
# API endpoints for the frontend.
def api_text2img(prompt, size, custom_size, product_codes):
    """Frontend endpoint: text-to-image featuring the selected products.

    Args:
        prompt: User prompt text.
        size: A ``"WxH"`` preset, or ``"Custom size"`` to use ``custom_size``.
        custom_size: ``"WxH"`` string used when ``size == "Custom size"``.
        product_codes: Selected product display names; the first word of each
            is used as its short code.

    Returns:
        ``{"image_url": "/file=<path>"}`` on success, otherwise ``{"error": message}``.
    """
    try:
        size_spec = custom_size if size == "Custom size" else size
        width, height = map(int, size_spec.split("x"))
        if not product_codes:
            return {"error": "At least one product code required"}
        codes = [entry.split()[0] for entry in product_codes]
        outcome = txt2img(rewrite_prompt_with_groq(prompt, codes), width, height, codes)
    except Exception as exc:
        return {"error": str(exc)}
    # txt2img reports failures as "Error: ..." strings and success as a Path.
    if isinstance(outcome, str):
        return {"error": outcome}
    return {"image_url": f"/file={outcome}"}
def api_img2img(file, position, size, custom_size, product_codes):
    """Frontend endpoint: apply the first selected product's texture to an image.

    Args:
        file: Uploaded image object exposing ``.save(path)`` (presumably a PIL
            Image — confirm against the caller).
        position: Region description, forwarded unchanged to ``generate_mask``.
        size: A ``"WxH"`` preset, or ``"Custom size"`` to use ``custom_size``.
        custom_size: ``"WxH"`` string used when ``size == "Custom size"``.
        product_codes: Selected product display names; only the first is used.

    Returns:
        ``{"image_url": "/file=<path>"}`` on success, otherwise ``{"error": message}``.
    """
    try:
        # Parsed only to validate the size inputs; the workflow does not use them.
        _width, _height = map(int, (custom_size if size == "Custom size" else size).split("x"))
        if not product_codes:
            return {"error": "At least one product code required"}
        # Random suffix: a per-second timestamp alone let concurrent uploads
        # overwrite each other's input file before it was sent to TensorArt.
        image_path = Path(SAVE_DIR) / f"input_{int(time.time())}_{uuid.uuid4().hex[:8]}.jpg"
        # JPEG cannot store alpha/palette modes; convert image objects that
        # expose a non-RGB/L mode before saving.
        if getattr(file, "mode", None) not in (None, "RGB", "L"):
            file = file.convert("RGB")
        file.save(image_path)
        image_resource_id = upload_image_to_tensorart(str(image_path))
        if not image_resource_id:
            return {"error": "Failed to upload image"}
        output_path = generate_mask(image_resource_id, position, product_codes[0])
        if not output_path:
            return {"error": "Failed to generate image"}
        return {"image_url": f"/file={output_path}"}
    except Exception as e:
        return {"error": str(e)}
# Gradio UI (optional).
# NOTE(review): api_text2img / api_img2img are not attached to any component
# here, so the only exposed endpoint is the placeholder below — confirm how the
# frontend is expected to reach them.
with gr.Blocks() as demo:
    gr.Markdown("## Backend Gradio cho CaslaQuartz")
    gr.Interface(fn=lambda x: "API only", inputs="text", outputs="text")

# Launch Gradio on all network interfaces, port 7860 (runs at import time).
_LAUNCH_OPTIONS = {"server_name": "0.0.0.0", "server_port": 7860}
demo.launch(**_LAUNCH_OPTIONS)