EvgenyKu's picture
initial commit
f2d49da
raw
history blame
13.4 kB
import os
#import random
import datetime
import spaces
import torch
import gradio as gr
#from huggingface_hub import hf_hub_download
import traceback
from transformers import pipeline
from huggingface_hub import login
from diffusers import FluxPipeline
from deep_translator import GoogleTranslator
login(token = os.getenv('HF_TOKEN'))
print("="*50)
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
print(f"GPU count: {torch.cuda.device_count()}")
if torch.cuda.is_available():
print(f"Current device: {torch.cuda.current_device()}")
print(f"Device name: {torch.cuda.get_device_name(0)}")
print("="*50)
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"{datetime.datetime.now()} Загрузка модели FLUX.1-dev")
pipe = FluxPipeline.from_pretrained(
# pretrained_model_name_or_path = local_path,
"black-forest-labs/FLUX.1-dev",
torch_dtype=torch.bfloat16, # Используем bfloat16 для A100
# low_cpu_mem_usage=True, # Экономия памяти
# device_map="balanced",
# local_files_only=True
# variant="fp16",
use_safetensors=True
)
print(f"{datetime.datetime.now()} Загрузка модели FLUX.1-dev успешно завершена")
print(f"{datetime.datetime.now()} Загрузка LoRA")
pipe.load_lora_weights("Shakker-Labs/FLUX.1-dev-LoRA-add-details", weight_name="FLUX-dev-lora-add_details.safetensors")
print(f"{datetime.datetime.now()} Загрузка LoRA успешно завершена")
pipe.fuse_lora(lora_scale=1.0)
pipe.to(device)
pipe.enable_model_cpu_offload() # Выгрузка неиспользуемых компонентов
# print(f"{datetime.datetime.now()} Загрузка модели stabilityai/stable-diffusion-x4-upscaler")
# upscaler_pipeline = StableDiffusionUpscalePipeline.from_pretrained(
# "stabilityai/stable-diffusion-x4-upscaler",
# torch_dtype=torch.float16
# ).to(device)
# print(f"{datetime.datetime.now()} Загрузка модели stabilityai/stable-diffusion-x4-upscaler успешно завершена")
#
# upscaler_pipeline.enable_model_cpu_offload() # Выгрузка неиспользуемых компонентов
print(f"{datetime.datetime.now()} Загрузка модели briaai/RMBG-1.4")
bg_remover = pipeline("image-segmentation", "briaai/RMBG-1.4", trust_remote_code=True )
print(f"{datetime.datetime.now()} Загрузка модели briaai/RMBG-1.4 успешно завершена")
@spaces.GPU()
def generate_image(object_name, remove_bg=True):
try:
# Формирование промпта
object_name = translate_ru_en(object_name)
prompt = create_template_prompt(object_name)
# Для имитации генерации (можно заменить на реальный вызов ComfyUI API)
print(f"Генерация иконки для объекта: {object_name}")
print(f"Промпт: {prompt[:100]}...")
# print(f"Параметры: seed={seed}, steps={steps}, размер={width}x{height}")
print(f"Опции: remove_bg={remove_bg}")
steps = os.getenv('STEPS') if os.getenv('STEPS') is not None else 10
print(f"Шаги: {steps}")
image = pipe(
prompt,
height=1024,
width=1024,
guidance_scale=3.5,
num_inference_steps=int(steps),
generator=torch.Generator(device).manual_seed(42)
).images[0]
torch.cuda.empty_cache()
# if upscale :
# torch.cuda.empty_cache()
# upscaled_image = upscaler_pipeline(
# prompt="", # Обязательный параметр, но может быть пустым
# image=image,
# num_inference_steps=steps, # Оптимально для качества/скорости
# guidance_scale=1.0 # Минимальное значение для апскейла
# ).images[0]
# return upscaled_image
if remove_bg :
remove_bg_image = bg_remover(image)
torch.cuda.empty_cache()
return remove_bg_image
torch.cuda.empty_cache()
return image
except Exception as e:
print(f"Ошибка при генерации изображения: {e}")
traceback.print_exc()
return None
def create_template_prompt(object_name):
template = load_text("prompt.txt")
return template.format(object_name = object_name)
def translate_ru_en(text: str):
try:
# Проверка на кириллицу (если включено)
if not any('\u0400' <= char <= '\u04FF' for char in text):
return text
# Создаем переводчик
translator = GoogleTranslator(source="ru", target="en")
# Выполняем перевод
return translator.translate(text)
except Exception as e:
print(f"Ошибка перевода: {e}")
traceback.print_exc()
return text
def load_text(file_name):
with open(file_name, 'r', encoding='utf-8') as f:
return f.read()
custom_css = load_text("style.css")
# Создание интерфейса Gradio
with gr.Blocks(title="3D Icon Generator", css=custom_css, theme=gr.themes.Default()) as app:
gr.Markdown("# iconDDDzilla")
gr.Markdown("### Create 3d icons with transparent background in one click!")
with gr.Row():
with gr.Column():
# Входные параметры
object_input = gr.Textbox(label="Object name", placeholder="Type object name (for example: calendar, phone, camera)")
remove_bg_checkbox = gr.Checkbox(label="Remove background", value=True)
# with gr.Accordion("Расширенные настройки", open=False):
# custom_prompt = gr.Textbox(label="Пользовательский промпт", placeholder="Оставьте пустым для использования шаблона", lines=3)
#
# with gr.Row():
# seed = gr.Number(label="Seed", value=276789180904019, precision=0)
# steps = gr.Slider(minimum=1, maximum=5, value=5, step=1, label="Шаги")
#
# with gr.Row():
# width = gr.Slider(minimum=512, maximum=2048, value=1024, step=64, label="Ширина")
# height = gr.Slider(minimum=512, maximum=2048, value=1024, step=64, label="Высота")
#
# with gr.Row():
# #upscale_checkbox = gr.Checkbox(label="Применить апскейл", value=True)
# remove_bg_checkbox = gr.Checkbox(label="Удалить фон", value=False)
# Кнопка генерации
generate_btn = gr.Button("Run")
with gr.Column():
# Выходное изображение
output_image = gr.Image(label="Image")
# Примеры использования
# examples = gr.Examples(
# examples=[
# ["calendar", "", 276789180904019, 5, 1024, 1024, True],
# ["camera", "", 391847529184, 5, 1024, 1024, True],
# ["smartphone", "", 654321987654, 5, 1024, 1024, True],
# ["headphones", "", 123456789012, 5, 1024, 1024, True],
# ],
# inputs=[
# object_input,
# custom_prompt,
# seed,
# steps,
# width,
# height,
# #upscale_checkbox,
# remove_bg_checkbox
# ],
# outputs=[output_image],
# fn=generate_image,
# )
# Информация о моделях
# with gr.Accordion("Информация о используемых моделях", open=False):
# gr.Markdown("""
# ## Используемые модели
#
# - **Основная модель:** [flux1-dev-fp8.safetensors](https://huggingface.co/lllyasviel/flux1_dev/blob/main/flux1-dev-fp8.safetensors) от Stability AI
# - **Модель апскейла:** [4x_NMKD-Superscale-SP_178000_G.pth](https://huggingface.co/gemasai/4x_NMKD-Superscale-SP_178000_G) для улучшения качества изображения
# - **Модель удаления фона:** [RMBG-1.4](https://huggingface.co/briaai/RMBG-1.4) для качественного удаления фона
#
# Все модели автоматически загружаются при первом запуске приложения.
# """)
# Привязка функции к нажатию кнопки
generate_btn.click(
fn=generate_image,
inputs=[
object_input,
# custom_prompt,
# seed,
# steps,
# width,
# height,
#upscale_checkbox,
remove_bg_checkbox
],
outputs=[output_image]
)
# Запуск приложения
if __name__ == "__main__":
app.launch()
### OLD UNUSED CODE!!! BE CAREFUL!!! ###
# Создаем необходимые директории
# os.makedirs("models", exist_ok=True)
# os.makedirs("models/checkpoints", exist_ok=True)
# os.makedirs("models/loras", exist_ok=True)
# os.makedirs("models/upscale_models", exist_ok=True)
# os.makedirs("models/rembg", exist_ok=True)
# os.makedirs("outputs", exist_ok=True)
# os.makedirs("temp_uploads", exist_ok=True)
# Загрузка моделей с Hugging Face
# def download_model(repo_id, filename, local_dir):
# """Загрузка модели с Hugging Face Hub"""
# local_path = os.path.join(local_dir, filename)
# if not os.path.exists(local_path):
# print(f"Загрузка {filename} из {repo_id}...")
# try:
# file_path = hf_hub_download(
# repo_id=repo_id,
# filename=filename,
# local_dir=local_dir,
# local_dir_use_symlinks=False
# )
# print(f"Модель успешно загружена: {file_path}")
# return file_path
# except Exception as e:
# print(f"Ошибка при загрузке модели {filename}: {e}")
# return None
# else:
# print(f"Модель {filename} уже присутствует: {local_path}")
# return local_path
# Вынесем загрузку моделей за пределы функции generate_image
# это критично для работы на ZeroGPU в Hugging Face Spaces
# def download_all_models():
# models = {
# "flux": download_model(
# "lllyasviel/flux1_dev",
# "flux1-dev-fp8.safetensors",
# "models/checkpoints"
# ),
# # "upscale": download_model(
# # "gemasai/4x_NMKD-Superscale-SP_178000_G",
# # "4x_NMKD-Superscale-SP_178000_G.pth",
# # "models/upscale_models"
# # ),
# # "rembg": "briaai/RMBG-1.4" # Используем модель RMBG-1.4 через transformers pipeline
# }
# return models
# def save_uploaded_file(file):
# """Сохранение загруженного файла"""
# if file is None:
# return None
#
# filename = os.path.join("temp_uploads", f"{random.randint(1000000, 9999999)}{os.path.splitext(file.name)[1]}")
# with open(filename, "wb") as f:
# f.write(file.read())
# return filename
# Функция для получения значения по индексу (используется в ComfyUI)
# def get_value_at_index(obj, index):
# """Получение значения из объекта по индексу из экспортированного ComfyUI кода"""
# if isinstance(obj, list):
# return obj[index]
# elif isinstance(obj, tuple):
# return obj[index]
# elif isinstance(obj, dict):
# return list(obj.values())[index]
# else:
# return obj[index]
# Загрузка моделей при запуске
#models = download_all_models()
# Инициализация pipeline для удаления фона
# try:
# rembg_pipeline = pipeline("image-segmentation", model=models["rembg"], trust_remote_code=True)
# print(f"Модель удаления фона успешно загружена: {models['rembg']}")
# except Exception as e:
# print(f"Ошибка при загрузке модели удаления фона: {e}")
# rembg_pipeline = None
# Укажите абсолютный путь к модели
# model_path = os.path.abspath("models/checkpoints/flux1-dev-fp8.safetensors")
# Проверьте существование ключевого файла
# if not os.path.exists(os.path.join(model_path, "model_index.json")):
# raise FileNotFoundError(f"Модель не найдена по пути: {model_path}")
# local_path = os.path.join("models/checkpoints", "")