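# Hugging Face Space: generates a short Korean story with an LLM, illustrates each
# paragraph with Stable Diffusion, and bundles the result into a downloadable PDF.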
import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from diffusers import StableDiffusionPipeline, DPMSolverMultistepScheduler
from sqlmodel import Field, Session, SQLModel, create_engine, select
from typing import Optional, List, Tuple
import hashlib
from datetime import datetime
from reportlab.lib.pagesizes import A4
from reportlab.lib.units import cm
from reportlab.platypus import SimpleDocTemplate, Paragraph, Image, PageBreak, Spacer
from reportlab.lib.styles import getSampleStyleSheet, ParagraphStyle
from reportlab.pdfbase import pdfmetrics
from reportlab.pdfbase.ttfonts import TTFont
from reportlab.lib.enums import TA_CENTER, TA_JUSTIFY
from PIL import Image as PILImage
import os

# Fixed character description so the protagonist stays consistent across images
CHARACTER_DESCRIPTION = "young korean man with blue hoodie"

# Database models
class Story(SQLModel, table=True):
    id: Optional[int] = Field(default=None, primary_key=True)
    prompt: str
    content: str
    created_at: datetime = Field(default_factory=datetime.now)

class ImageCache(SQLModel, table=True):
    id: Optional[int] = Field(default=None, primary_key=True)
    prompt_hash: str = Field(index=True)
    image_path: str
    created_at: datetime = Field(default_factory=datetime.now)

# Database initialization
engine = create_engine("sqlite:///storybook.db")
SQLModel.metadata.create_all(engine)
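# create_all() only creates tables that do not exist yet, so restarting the app
# reuses the existing storybook.db file.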

# Model initialization
print("Loading models...")
device = "cuda" if torch.cuda.is_available() else "cpu"

# Story-generation LLM
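# Note: device_map="auto" requires the accelerate package; float16 weights are
# used only when a CUDA device is available.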
llm_model_name = "Bllossom/llama-3.2-Korean-Bllossom-AICA-5B"
tokenizer = AutoTokenizer.from_pretrained(llm_model_name)
llm_model = AutoModelForCausalLM.from_pretrained(
    llm_model_name,
    torch_dtype=torch.float16 if device == "cuda" else torch.float32,
    device_map="auto"
)

# Stable Diffusion model
sd_model_name = "Lykon/DreamShaper"
sd_pipe = StableDiffusionPipeline.from_pretrained(
    sd_model_name,
    torch_dtype=torch.float16 if device == "cuda" else torch.float32,
    use_safetensors=False,
    safety_checker=None,  # disable the built-in safety checker at load time
    requires_safety_checker=False
)
sd_pipe = sd_pipe.to(device)

# Swap the scheduler after loading the model
sd_pipe.scheduler = DPMSolverMultistepScheduler.from_config(
    sd_pipe.scheduler.config,
    use_karras_sigmas=True,  # Karras schedule
    algorithm_type="dpmsolver++"
)
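# DPM-Solver++ with Karras sigmas usually gives usable results in ~20-25 steps,
# which is why num_inference_steps=20 is enough in generate_image() below.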

# Directory for generated images
os.makedirs("generated_images", exist_ok=True)

def generate_story(prompt: str) -> Tuple[str, List[str]]:
    """Generate a story from the prompt."""
    system_prompt = f"""당신은 뛰어난 스토리텔러입니다.
다음 주제를 바탕으로, 5개의 문단으로 구성된 흥미로운 이야기를 작성하세요.
규칙:
- 주인공은 '청년' 또는 '그'로만 지칭하세요 (이름 사용 금지)
- 주인공은 안경을 쓴 20대 청년입니다
- 각 문단은 2~4개의 문장으로 구성
- 시각적으로 표현 가능한 구체적인 장면 묘사 포함
- 순수 한국어만 사용
- 각 문단마다 명확한 장소와 행동 묘사
주제: {prompt}
이야기:"""
    inputs = tokenizer(system_prompt, return_tensors="pt").to(device)
    with torch.no_grad():
        outputs = llm_model.generate(
            **inputs,
            max_new_tokens=1000,
            temperature=0.7,
            do_sample=True,
            top_p=0.92,
            repetition_penalty=1.1,
        )
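    # generate() returns the prompt tokens followed by the continuation, so the
    # decoded text still contains the prompt; it is stripped out below.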
    story = tokenizer.decode(outputs[0], skip_special_tokens=True)
    story = story.replace(system_prompt, "").strip()
    # Split into paragraphs
    paragraphs = []
    raw_paragraphs = story.split("\n\n")
    for p in raw_paragraphs:
        p = p.strip()
        if p and len(p) > 20:
            paragraphs.append(p)
    paragraphs = paragraphs[:5]
    # Save to the database
    with Session(engine) as session:
        db_story = Story(prompt=prompt, content="\n\n".join(paragraphs))
        session.add(db_story)
        session.commit()
    return "\n\n".join(paragraphs), paragraphs

def analyze_text_for_english_scene(text: str, paragraph_num: int = 1) -> str:
    """Analyze the text and pick an English scene description (10 basic keyword groups)."""
    # Debug output
    print(f"[{paragraph_num}] Analyzing text: {text[:60]}...")
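    # The story text is Korean, so the checks below match Korean keywords and map
    # them to short English scene phrases for the Stable Diffusion prompt.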
    # Only the 10 core keyword groups are handled
    # 1. Cafe + laptop/computer
    if "카페" in text and ("노트북" in text or "컴퓨터" in text):
        return "working on laptop in coffee shop"
    # 2. Cafe (general)
    elif "카페" in text:
        return "in a coffee shop"
    # 3. Programming / coding
    elif "프로그래밍" in text or "코딩" in text or "코드" in text:
        return "coding on laptop"
    # 4. Meeting
    elif "회의" in text or "미팅" in text:
        return "in a meeting"
    # 5. Presentation
    elif "발표" in text or "프레젠테이션" in text:
        return "giving presentation"
    # 6. Colleagues / team
    elif "동료" in text or "팀" in text:
        return "with team members"
    # 7. Success / celebration
    elif "성공" in text or "축하" in text:
        return "celebrating success"
    # 8. Planning
    elif "계획" in text:
        return "planning"
    # 9. Office
    elif "사무실" in text:
        return "in office"
    # 10. Investment / investors
    elif "투자" in text:
        return "meeting investors"
    # Per-paragraph defaults
    defaults = {
        1: "young entrepreneur working",
        2: "developing project",
        3: "collaborating with others",
        4: "business presentation",
        5: "successful achievement"
    }
    return defaults.get(paragraph_num, "at work")

def generate_image(text: str, paragraph_num: int = 1) -> str:
    """Generate an image from a story paragraph."""
    # Hash the paragraph text for caching
    prompt_hash = hashlib.md5(text.encode()).hexdigest()
    # Check the cache
    with Session(engine) as session:
        cached = session.exec(
            select(ImageCache).where(ImageCache.prompt_hash == prompt_hash)
        ).first()
        if cached:
            return cached.image_path
    # Pick a scene for this paragraph
    print(f"\n[{paragraph_num}/5] Generating image...")
    scene = analyze_text_for_english_scene(text, paragraph_num)
    # Build the final prompt
    final_prompt = f"{CHARACTER_DESCRIPTION} {scene}"
    print("Final prompt: ", final_prompt)
    print(f"Prompt length: {len(final_prompt)} characters")
    # Negative prompt
    negative_prompt = "realistic, photo, multiple people, crowd"
    # Fixed base seed
    base_seed = 396135060
    # Small per-paragraph seed offset
    seed = base_seed + (paragraph_num * 10)  # 10, 20, 30, 40, 50
    # Alternative: vary the seed with the text hash
    #text_hash = int(hashlib.md5(text.encode()).hexdigest()[:8], 16)
    #seed = base_seed + (text_hash % 1000)
    generator = torch.Generator(device=device).manual_seed(seed)
    #generator = torch.Generator(device=device).manual_seed(
    #    torch.randint(0, 100000, (1,)).item()
    #)
    # Generate the image
    with torch.no_grad():
        image = sd_pipe(
            prompt=final_prompt,
            negative_prompt=negative_prompt,
            num_inference_steps=20,
            guidance_scale=6.0,
            height=512,
            width=512,
            generator=generator
        ).images[0]
    # Save the image
    image_path = f"generated_images/{prompt_hash}.png"
    image.save(image_path)
    # Cache the result
    with Session(engine) as session:
        cache_entry = ImageCache(prompt_hash=prompt_hash, image_path=image_path)
        session.add(cache_entry)
        session.commit()
    return image_path

def create_pdf(story_text: str, image_paths: List[str], output_path: str = "storybook.pdf"):
    """Build a PDF from the story text and images."""
    doc = SimpleDocTemplate(output_path, pagesize=A4)
    story = []
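    # ReportLab's standard PDF fonts cannot render Hangul, so a TrueType font is
    # registered here; malgun.ttf (Malgun Gothic) must ship alongside the script,
    # otherwise the registration below fails.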
    font_path = "malgun.ttf"
    pdfmetrics.registerFont(TTFont("MalgunGothic", font_path))
    # Paragraph styles
    styles = getSampleStyleSheet()
    title_style = ParagraphStyle(
        'CustomTitle',
        parent=styles['Heading1'],
        fontName="MalgunGothic",
        fontSize=24,
        textColor='black',
        alignment=TA_CENTER,
        spaceAfter=30
    )
    text_style = ParagraphStyle(
        'CustomText',
        parent=styles['Normal'],
        fontName="MalgunGothic",
        fontSize=12,
        leading=18,
        alignment=TA_JUSTIFY,
        spaceAfter=20
    )
    story.append(Paragraph("AI Storybook", title_style))
    story.append(Spacer(1, 1*cm))
    paragraphs = story_text.strip().split("\n\n")
    for i, para in enumerate(paragraphs):
        story.append(Paragraph(para.strip(), text_style))
        if i < len(image_paths) and os.path.exists(image_paths[i]):
            img = Image(image_paths[i], width=15*cm, height=10*cm)
            story.append(img)
            story.append(Spacer(1, 1*cm))
        if i < len(paragraphs) - 1:
            story.append(PageBreak())
    doc.build(story)
    return output_path

# Gradio interface
def process_story(prompt: str):
    """Handle story generation."""
    story, paragraphs = generate_story(prompt)
    return story, gr.update(visible=True), paragraphs

def generate_images_batch(paragraphs: List[str]):
    """Generate images as a batch (with a progress bar)."""
    from tqdm import tqdm
    image_paths = []
    for i, para in tqdm(enumerate(paragraphs), total=len(paragraphs), desc="Generating images"):
        img_path = generate_image(para, paragraph_num=i+1)
        image_paths.append(img_path)
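        # Free cached GPU memory between images to reduce the risk of running
        # out of VRAM.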
        if device == "cuda":
            torch.cuda.empty_cache()
    return image_paths

def create_storybook(story_text: str, paragraphs: List[str]):
    """Create the storybook PDF."""
    # Generate the images
    image_paths = generate_images_batch(paragraphs)
    # Build the PDF
    pdf_path = create_pdf(story_text, image_paths)
    # Load the images for the gallery
    images = [PILImage.open(path) for path in image_paths]
    return images, pdf_path

# Gradio UI
with gr.Blocks(title="AI Storybook Creator") as app:
    gr.Markdown("# AI Storybook Creator")
    with gr.Row():
        with gr.Column():
            prompt_input = gr.Textbox(
                label="Story topic",
                placeholder="e.g. a startup founder's success story",
                lines=2
            )
            generate_btn = gr.Button("Generate story", variant="primary")
            story_output = gr.Textbox(
                label="Generated story",
                lines=15,
                interactive=True
            )
            create_book_btn = gr.Button(
                "Create storybook (images + PDF)",
                variant="secondary",
                visible=False
            )
        with gr.Column():
            image_gallery = gr.Gallery(
                label="Generated images",
                show_label=True,
                elem_id="gallery",
                columns=2,
                rows=3,
                height="auto"
            )
            pdf_output = gr.File(
                label="Download PDF",
                visible=True
            )
    # Session state
    paragraphs_state = gr.State([])
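    # gr.State keeps the paragraph list on the server per session, so the second
    # button can reuse it without re-parsing the story textbox.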
    # Event handlers
    generate_btn.click(
        fn=process_story,
        inputs=[prompt_input],
        outputs=[story_output, create_book_btn, paragraphs_state]
    )
    create_book_btn.click(
        fn=create_storybook,
        inputs=[story_output, paragraphs_state],
        outputs=[image_gallery, pdf_output]
    )
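# share=True additionally requests a temporary public Gradio share link when the
# app is launched locally.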

if __name__ == "__main__":
    app.launch(share=True)