|
import os |
|
import uuid |
|
import GPUtil |
|
import gradio as gr |
|
import psutil |
|
import spaces |
|
from videosys import CogVideoXConfig, CogVideoXPABConfig, VideoSysEngine |
|
from transformers import pipeline |
|
|
|
# Point Gradio's temp directory at a ".tmp_outputs" folder under the current
# working directory and set the PyTorch CUDA allocator configuration.
os.environ.update(
    {
        "GRADIO_TEMP_DIR": os.path.join(os.getcwd(), ".tmp_outputs"),
        "PYTORCH_CUDA_ALLOC_CONF": "expandable_segments:True",
    }
)

# Translation pipeline (Helsinki-NLP/opus-mt-ko-en), built once at import time.
translator = pipeline("translation", model="Helsinki-NLP/opus-mt-ko-en")
|
|
|
def translate_to_english(text):
    """Translate ``text`` with the module-level ``translator`` if it has Hangul.

    The text is translated only when at least one character falls in the
    range U+AC00..U+D7A3 (Hangul syllables); otherwise it is returned
    unchanged and the translator is never called.
    """
    hangul_first, hangul_last = "\uAC00", "\uD7A3"
    for ch in text:
        if hangul_first <= ch <= hangul_last:
            # A single Hangul character is enough to translate the whole text.
            outputs = translator(text, max_length=512)
            return outputs[0]["translation_text"]
    return text
|
|
|
def load_model(model_name, enable_video_sys=False, pab_threshold=None, pab_range=2):
    """Build a ``VideoSysEngine`` for the given CogVideoX model.

    Args:
        model_name: Model identifier passed to ``CogVideoXConfig``.
        enable_video_sys: Forwarded to ``CogVideoXConfig`` as ``enable_pab``.
        pab_threshold: Two-element list forwarded to ``CogVideoXPABConfig`` as
            ``spatial_threshold``. Defaults to ``[100, 850]`` when omitted.
        pab_range: Forwarded to ``CogVideoXPABConfig`` as ``spatial_range``.

    Returns:
        The ``VideoSysEngine`` created from the resulting configuration.
    """
    # Build the default list per call; a list literal in the signature would
    # be one shared object that any callee mutation would leak across calls.
    if pab_threshold is None:
        pab_threshold = [100, 850]

    pab_config = CogVideoXPABConfig(spatial_threshold=pab_threshold, spatial_range=pab_range)
    config = CogVideoXConfig(model_name, enable_pab=enable_video_sys, pab_config=pab_config)
    engine = VideoSysEngine(config)
    return engine
|
|
|
def generate(engine, prompt, num_inference_steps=50, guidance_scale=6.0):
    """Generate a video for ``prompt`` and return the path of the saved file.

    The prompt first goes through ``translate_to_english``. The first video
    of the engine's result is saved under ``./.tmp_outputs`` with a random
    UUID-based ``.mp4`` file name.
    """
    english_prompt = translate_to_english(prompt)
    result = engine.generate(
        english_prompt,
        num_inference_steps=num_inference_steps,
        guidance_scale=guidance_scale,
    )
    first_video = result.video[0]

    # Random hex name so separate requests do not overwrite each other.
    file_name = uuid.uuid4().hex + ".mp4"
    output_path = os.path.join("./.tmp_outputs", file_name)

    engine.save_video(first_video, output_path)
    return output_path
|
|
|
@spaces.GPU()
def generate_vanilla(model_name, prompt, num_inference_steps, guidance_scale, progress=gr.Progress(track_tqdm=True)):
    """Generate a video with a default-configured engine and return its path.

    ``progress`` is not read in the body; it is declared with a
    ``gr.Progress(track_tqdm=True)`` default.
    """
    vanilla_engine = load_model(model_name)
    return generate(vanilla_engine, prompt, num_inference_steps, guidance_scale)
|
|
|
@spaces.GPU()
def generate_vs(
    model_name,
    prompt,
    num_inference_steps,
    guidance_scale,
    threshold_start,
    threshold_end,
    gap,
    progress=gr.Progress(track_tqdm=True),
):
    """Generate a video with ``enable_video_sys=True`` and return its path.

    ``threshold_start``, ``threshold_end`` and ``gap`` are converted to ``int``
    before being handed to ``load_model``. ``progress`` is not read in the
    body; it is declared with a ``gr.Progress(track_tqdm=True)`` default.
    """
    # The threshold list is ordered [end, start].
    pab_threshold = [int(value) for value in (threshold_end, threshold_start)]
    pab_engine = load_model(
        model_name,
        enable_video_sys=True,
        pab_threshold=pab_threshold,
        pab_range=int(gap),
    )
    return generate(pab_engine, prompt, num_inference_steps, guidance_scale)
|
|
|
def get_server_status(): |
|
cpu_percent = psutil.cpu_percent() |
|
memory = psutil.virtual_memory() |
|
disk = psutil.disk_usage("/") |
|
try: |
|
gpus = GPUtil.getGPUs() |
|
if gpus: |
|
gpu = gpus[0] |
|
gpu_memory = f"{gpu.memoryUsed}/{gpu.memoryTotal}MB ({gpu.memoryUtil*100:.1f}%)" |
|
else: |
|
gpu_memory = "GPU๋ฅผ ์ฐพ์ ์ ์์" |
|
except: |
|
gpu_memory = "GPU ์ ๋ณด๋ฅผ ์ฌ์ฉํ ์ ์์" |
|
|
|
return { |
|
"cpu": f"{cpu_percent}%", |
|
"memory": f"{memory.percent}%", |
|
"disk": f"{disk.percent}%", |
|
"gpu_memory": gpu_memory, |
|
} |
|
|
|
def update_server_status():
    """Return the server status as a ``(cpu, memory, disk, gpu_memory)`` tuple."""
    snapshot = get_server_status()
    return tuple(snapshot[field] for field in ("cpu", "memory", "disk", "gpu_memory"))
|
|
|
css = """ |
|
footer { |
|
visibility: hidden; |
|
} |
|
""" |
|
|
|
with gr.Blocks(theme="Yntec/HaleyCH_Theme_Orange", css=css) as demo: |
|
with gr.Row(): |
|
with gr.Column(): |
|
prompt = gr.Textbox(label="ํ๋กฌํํธ (200๋จ์ด ์ด๋ด)", value="๋ฐ๋ค ์์ ์ผ๋ชฐ.", lines=3) |
|
|
|
with gr.Column(): |
|
gr.Markdown("**์์ฑ ๋งค๊ฐ๋ณ์**<br>") |
|
with gr.Row(): |
|
model_name = gr.Radio( |
|
["THUDM/CogVideoX-2b", "THUDM/CogVideoX-5b"], label="๋ชจ๋ธ ์ ํ", value="THUDM/CogVideoX-2b" |
|
) |
|
with gr.Row(): |
|
num_inference_steps = gr.Number(label="์ถ๋ก ๋จ๊ณ", value=50) |
|
guidance_scale = gr.Number(label="๊ฐ์ด๋์ค ์ค์ผ์ผ", value=6.0) |
|
with gr.Row(): |
|
pab_range = gr.Number( |
|
label="PAB ๋ธ๋ก๋์บ์คํธ ๋ฒ์", value=2, precision=0, info="๋ธ๋ก๋์บ์คํธ ํ์์คํ
๋ฒ์." |
|
) |
|
pab_threshold_start = gr.Number(label="PAB ์์ ํ์์คํ
", value=850, info="1000 ๋จ๊ณ์์ ์์.") |
|
pab_threshold_end = gr.Number(label="PAB ์ข
๋ฃ ํ์์คํ
", value=100, info="0 ๋จ๊ณ์์ ์ข
๋ฃ.") |
|
with gr.Row(): |
|
generate_button_vs = gr.Button("โก๏ธ VideoSys๋ก ๋น๋์ค ์์ฑ (๋ ๋น ๋ฆ)") |
|
generate_button = gr.Button("๐ฌ ๋น๋์ค ์์ฑ (์๋ณธ)") |
|
with gr.Column(elem_classes="server-status"): |
|
gr.Markdown("#### ์๋ฒ ์ํ") |
|
|
|
with gr.Row(): |
|
cpu_status = gr.Textbox(label="CPU", scale=1) |
|
memory_status = gr.Textbox(label="๋ฉ๋ชจ๋ฆฌ", scale=1) |
|
|
|
with gr.Row(): |
|
disk_status = gr.Textbox(label="๋์คํฌ", scale=1) |
|
gpu_status = gr.Textbox(label="GPU ๋ฉ๋ชจ๋ฆฌ", scale=1) |
|
|
|
with gr.Row(): |
|
refresh_button = gr.Button("์๋ก๊ณ ์นจ") |
|
|
|
with gr.Column(): |
|
with gr.Row(): |
|
video_output_vs = gr.Video(label="VideoSys๋ฅผ ์ฌ์ฉํ CogVideoX", width=720, height=480) |
|
with gr.Row(): |
|
video_output = gr.Video(label="CogVideoX", width=720, height=480) |
|
|
|
generate_button.click( |
|
generate_vanilla, |
|
inputs=[model_name, prompt, num_inference_steps, guidance_scale], |
|
outputs=[video_output], |
|
concurrency_id="gen", |
|
concurrency_limit=1, |
|
) |
|
|
|
generate_button_vs.click( |
|
generate_vs, |
|
inputs=[ |
|
model_name, |
|
prompt, |
|
num_inference_steps, |
|
guidance_scale, |
|
pab_threshold_start, |
|
pab_threshold_end, |
|
pab_range, |
|
], |
|
outputs=[video_output_vs], |
|
concurrency_id="gen", |
|
concurrency_limit=1, |
|
) |
|
|
|
refresh_button.click(update_server_status, outputs=[cpu_status, memory_status, disk_status, gpu_status]) |
|
demo.load(update_server_status, outputs=[cpu_status, memory_status, disk_status, gpu_status], every=1) |
|
|
|
if __name__ == "__main__":
    # Configure the request queue (max_size=10, default_concurrency_limit=1),
    # then launch the app.
    demo.queue(max_size=10, default_concurrency_limit=1).launch()