# Source: Hugging Face Space by phantomthe β€” commit 6027a38 ("init")
import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from diffusers import StableDiffusionPipeline, DPMSolverMultistepScheduler
from sqlmodel import Field, Session, SQLModel, create_engine, select
from typing import Optional, List, Tuple
import hashlib
from datetime import datetime
from reportlab.lib.pagesizes import A4
from reportlab.lib.units import cm
from reportlab.platypus import SimpleDocTemplate, Paragraph, Image, PageBreak, Spacer
from reportlab.lib.styles import getSampleStyleSheet, ParagraphStyle
from reportlab.pdfbase import pdfmetrics
from reportlab.pdfbase.ttfonts import TTFont
from reportlab.lib.enums import TA_CENTER, TA_JUSTIFY
from PIL import Image as PILImage
import os
# Fixed character attributes prepended to every image prompt so the
# protagonist looks consistent across all generated illustrations.
CHARACTER_DESCRIPTION = "young korean man with blue hoodie"
# λ°μ΄ν„°λ² μ΄μŠ€ λͺ¨λΈ
class Story(SQLModel, table=True):
id: Optional[int] = Field(default=None, primary_key=True)
prompt: str
content: str
created_at: datetime = Field(default_factory=datetime.now)
class ImageCache(SQLModel, table=True):
    """Maps an MD5 hash of a paragraph's text to a previously generated image file."""
    id: Optional[int] = Field(default=None, primary_key=True)
    # MD5 hex digest of the paragraph text; indexed for cache lookups
    prompt_hash: str = Field(index=True)
    # Path of the saved PNG under generated_images/
    image_path: str
    created_at: datetime = Field(default_factory=datetime.now)
# λ°μ΄ν„°λ² μ΄μŠ€ μ΄ˆκΈ°ν™”
engine = create_engine("sqlite:///storybook.db")
SQLModel.metadata.create_all(engine)
# λͺ¨λΈ μ΄ˆκΈ°ν™”
print("λͺ¨λΈ λ‘œλ”© 쀑...")
device = "cuda" if torch.cuda.is_available() else "cpu"
# LLM λͺ¨λΈ
llm_model_name = "Bllossom/llama-3.2-Korean-Bllossom-AICA-5B"
tokenizer = AutoTokenizer.from_pretrained(llm_model_name)
llm_model = AutoModelForCausalLM.from_pretrained(
llm_model_name,
torch_dtype=torch.float16 if device == "cuda" else torch.float32,
device_map="auto"
)
# Stable Diffusion λͺ¨λΈ
sd_model_name = "Lykon/DreamShaper"
sd_pipe = StableDiffusionPipeline.from_pretrained(
sd_model_name,
torch_dtype=torch.float16 if device == "cuda" else torch.float32,
use_safetensors=False
)
sd_pipe = sd_pipe.to(device)
# λͺ¨λΈ λ‘œλ“œ ν›„ μŠ€μΌ€μ€„λŸ¬ λ³€κ²½
sd_pipe.scheduler = DPMSolverMultistepScheduler.from_config(
sd_pipe.scheduler.config,
use_karras_sigmas=True, # Karras schedule
algorithm_type="dpmsolver++"
)
# 이미지 μ €μž₯ 디렉토리
os.makedirs("generated_images", exist_ok=True)
def generate_story(prompt: str) -> Tuple[str, List[str]]:
    """Generate a five-paragraph Korean story for the given topic.

    Args:
        prompt: story topic entered by the user.

    Returns:
        Tuple of (full story text with paragraphs joined by blank lines,
        list of individual paragraphs). The story is also persisted to the DB.
    """
    system_prompt = f"""당신은 λ›°μ–΄λ‚œ μŠ€ν† λ¦¬ν…”λŸ¬μž…λ‹ˆλ‹€.
λ‹€μŒ 주제λ₯Ό λ°”νƒ•μœΌλ‘œ, 5개의 λ¬Έλ‹¨μœΌλ‘œ κ΅¬μ„±λœ ν₯미둜운 이야기λ₯Ό μž‘μ„±ν•˜μ„Έμš”.
κ·œμΉ™:
- 주인곡은 'μ²­λ…„' λ˜λŠ” 'κ·Έ'둜만 μ§€μΉ­ν•˜μ„Έμš” (이름 μ‚¬μš© κΈˆμ§€)
- 주인곡은 μ•ˆκ²½μ„ μ“΄ 20λŒ€ μ²­λ…„μž…λ‹ˆλ‹€
- 각 문단은 2~4개의 λ¬Έμž₯으둜 ꡬ성
- μ‹œκ°μ μœΌλ‘œ ν‘œν˜„ κ°€λŠ₯ν•œ ꡬ체적인 μž₯λ©΄ λ¬˜μ‚¬ 포함
- 순수 ν•œκ΅­μ–΄λ§Œ μ‚¬μš©
- 각 λ¬Έλ‹¨λ§ˆλ‹€ λͺ…ν™•ν•œ μž₯μ†Œμ™€ 행동 λ¬˜μ‚¬
주제: {prompt}
이야기:"""
    inputs = tokenizer(system_prompt, return_tensors="pt").to(device)
    with torch.no_grad():
        outputs = llm_model.generate(
            **inputs,
            max_new_tokens=1000,
            temperature=0.7,
            do_sample=True,
            top_p=0.92,
            repetition_penalty=1.1,
        )
    # Decode only the newly generated tokens. The previous approach decoded
    # the whole sequence and str.replace()d the prompt out, which silently
    # fails (prompt text leaks into the story) whenever detokenization does
    # not reproduce the prompt byte-for-byte.
    prompt_len = inputs["input_ids"].shape[1]
    story = tokenizer.decode(outputs[0][prompt_len:], skip_special_tokens=True).strip()
    # Split into paragraphs, dropping fragments too short to be real paragraphs
    paragraphs = []
    for p in story.split("\n\n"):
        p = p.strip()
        if p and len(p) > 20:
            paragraphs.append(p)
    paragraphs = paragraphs[:5]  # cap at the five paragraphs the prompt requested
    # Persist the story
    with Session(engine) as session:
        db_story = Story(prompt=prompt, content="\n\n".join(paragraphs))
        session.add(db_story)
        session.commit()
    return "\n\n".join(paragraphs), paragraphs
def analyze_text_for_english_scene(text: str, paragraph_num: int = 1) -> str:
"""ν…μŠ€νŠΈλ₯Ό λΆ„μ„ν•˜μ—¬ μ˜μ–΄ 씬 μΆ”μΆœ (κΈ°λ³Έ 10개 ν‚€μ›Œλ“œ)"""
# λ””λ²„κΉ…μš© 좜λ ₯
print(f"[{paragraph_num}] ν…μŠ€νŠΈ 뢄석 쀑: {text[:60]}...")
# 핡심 ν‚€μ›Œλ“œ 10개만 처리
# 1. 카페 + λ…ΈνŠΈλΆ/컴퓨터
if "카페" in text and ("λ…ΈνŠΈλΆ" in text or "컴퓨터" in text):
return "working on laptop in coffee shop"
# 2. 카페 (일반)
elif "카페" in text:
return "in a coffee shop"
# 3. ν”„λ‘œκ·Έλž˜λ°/μ½”λ”©
elif "ν”„λ‘œκ·Έλž˜λ°" in text or "μ½”λ”©" in text or "μ½”λ“œ" in text:
return "coding on laptop"
# 4. 회의/λ―ΈνŒ…
elif "회의" in text or "λ―ΈνŒ…" in text:
return "in a meeting"
# 5. λ°œν‘œ/ν”„λ ˆμ  ν…Œμ΄μ…˜
elif "λ°œν‘œ" in text or "ν”„λ ˆμ  ν…Œμ΄μ…˜" in text:
return "giving presentation"
# 6. λ™λ£Œ/νŒ€
elif "λ™λ£Œ" in text or "νŒ€" in text:
return "with team members"
# 7. 성곡/μΆ•ν•˜
elif "성곡" in text or "μΆ•ν•˜" in text:
return "celebrating success"
# 8. κ³„νš
elif "κ³„νš" in text:
return "planning"
# 9. 사무싀
elif "사무싀" in text:
return "in office"
# 10. 투자/투자자
elif "투자" in text:
return "meeting investors"
# κΈ°λ³Έκ°’ (문단별)
defaults = {
1: "young entrepreneur working",
2: "developing project",
3: "collaborating with others",
4: "business presentation",
5: "successful achievement"
}
return defaults.get(paragraph_num, "at work")
def generate_image(text: str, paragraph_num: int = 1) -> str:
    """Generate (or fetch from cache) an illustration for one paragraph.

    Args:
        text: the paragraph to illustrate.
        paragraph_num: 1-based paragraph index (drives the seed offset and
            the fallback scene selection).

    Returns:
        Path of the saved PNG. Results are cached by an MD5 hash of the
        paragraph text, so identical paragraphs reuse the same image.
    """
    prompt_hash = hashlib.md5(text.encode()).hexdigest()
    # Cache lookup
    with Session(engine) as session:
        cached = session.exec(
            select(ImageCache).where(ImageCache.prompt_hash == prompt_hash)
        ).first()
        if cached:
            return cached.image_path
    print(f"\n[{paragraph_num}/5] 이미지 생성 쀑...")
    # Bug fix: paragraph_num was previously not forwarded, so the
    # per-paragraph fallback scenes in the analyzer were never used
    # (every unmatched paragraph got the paragraph-1 default).
    scene = analyze_text_for_english_scene(text, paragraph_num)
    # Fixed character description keeps the protagonist consistent across pages
    final_prompt = f"{CHARACTER_DESCRIPTION} {scene}"
    print("μ΅œμ’… ν”„λ‘¬ν”„νŠΈ: ", final_prompt)
    print(f"ν”„λ‘¬ν”„νŠΈ 길이: {len(final_prompt)} κΈ€μž")
    negative_prompt = "realistic, photo, multiple people, crowd"
    # Fixed base seed plus a small per-paragraph offset: keeps the character
    # stable between pages while still varying each scene.
    base_seed = 396135060
    seed = base_seed + (paragraph_num * 10)  # 10, 20, 30, 40, 50
    generator = torch.Generator(device=device).manual_seed(seed)
    # Generate the image.
    # NOTE: safety_checker / requires_safety_checker were removed from this
    # call β€” they are from_pretrained() arguments, not __call__() arguments,
    # and passing them here raises TypeError on current diffusers releases.
    with torch.no_grad():
        image = sd_pipe(
            prompt=final_prompt,
            negative_prompt=negative_prompt,
            num_inference_steps=20,
            guidance_scale=6.0,
            height=512,
            width=512,
            generator=generator,
        ).images[0]
    # Save the image and record it in the cache
    image_path = f"generated_images/{prompt_hash}.png"
    image.save(image_path)
    with Session(engine) as session:
        cache_entry = ImageCache(prompt_hash=prompt_hash, image_path=image_path)
        session.add(cache_entry)
        session.commit()
    return image_path
def create_pdf(story_text: str, image_paths: List[str], output_path: str = "storybook.pdf",
               font_path: str = "malgun.ttf"):
    """Build a PDF storybook interleaving paragraphs with their images.

    Args:
        story_text: full story, paragraphs separated by blank lines.
        image_paths: one image path per paragraph (missing/absent files are
            skipped gracefully).
        output_path: destination PDF path.
        font_path: TTF font file with Korean glyphs. New parameter; its
            default matches the previously hard-coded value, so existing
            callers are unaffected.

    Returns:
        The output path of the generated PDF.
    """
    doc = SimpleDocTemplate(output_path, pagesize=A4)
    flowables = []
    # Register the Korean-capable font only once per process instead of on
    # every call (repeated registration re-reads the font file needlessly).
    if "맑은고딕" not in pdfmetrics.getRegisteredFontNames():
        pdfmetrics.registerFont(TTFont('맑은고딕', font_path))
    # Style setup
    styles = getSampleStyleSheet()
    title_style = ParagraphStyle(
        'CustomTitle',
        parent=styles['Heading1'],
        fontName="맑은고딕",
        fontSize=24,
        textColor='black',
        alignment=TA_CENTER,
        spaceAfter=30
    )
    text_style = ParagraphStyle(
        'CustomText',
        parent=styles['Normal'],
        fontName="맑은고딕",
        fontSize=12,
        leading=18,
        alignment=TA_JUSTIFY,
        spaceAfter=20
    )
    flowables.append(Paragraph("AI μŠ€ν† λ¦¬λΆ", title_style))
    flowables.append(Spacer(1, 1*cm))
    paragraphs = story_text.strip().split("\n\n")
    for i, para in enumerate(paragraphs):
        flowables.append(Paragraph(para.strip(), text_style))
        # Attach the paragraph's image only when the file actually exists
        if i < len(image_paths) and os.path.exists(image_paths[i]):
            img = Image(image_paths[i], width=15*cm, height=10*cm)
            flowables.append(img)
            flowables.append(Spacer(1, 1*cm))
        # Page break between paragraphs, but not after the last one
        if i < len(paragraphs) - 1:
            flowables.append(PageBreak())
    doc.build(flowables)
    return output_path
# Gradio μΈν„°νŽ˜μ΄μŠ€
def process_story(prompt: str):
"""μŠ€ν† λ¦¬ 생성 처리"""
story, paragraphs = generate_story(prompt)
return story, gr.update(visible=True), paragraphs
def generate_images_batch(paragraphs: List[str]):
    """Generate one image per paragraph, showing a progress bar.

    Frees cached GPU memory between images to keep peak VRAM usage down.
    """
    from tqdm import tqdm
    image_paths = []
    progress = tqdm(enumerate(paragraphs), total=len(paragraphs), desc="이미지 생성")
    for index, paragraph in progress:
        image_paths.append(generate_image(paragraph, paragraph_num=index + 1))
        if device == "cuda":
            torch.cuda.empty_cache()
    return image_paths
def create_storybook(story_text: str, paragraphs: List[str]):
    """Book-button handler: render images for every paragraph and build the PDF.

    Returns (gallery images, PDF path) for the Gradio outputs.
    """
    # Images first, then the PDF that embeds them
    paths = generate_images_batch(paragraphs)
    pdf_path = create_pdf(story_text, paths)
    # Open the saved files again for the gallery widget
    gallery_images = [PILImage.open(path) for path in paths]
    return gallery_images, pdf_path
# Gradio UI: left column takes the prompt and shows the story,
# right column shows the generated images and the PDF download
with gr.Blocks(title="AI μŠ€ν† λ¦¬λΆ μ €μž‘ 도ꡬ") as app:
    gr.Markdown("# AI μŠ€ν† λ¦¬λΆ μ €μž‘ 도ꡬ")
    with gr.Row():
        with gr.Column():
            prompt_input = gr.Textbox(
                label="μŠ€ν† λ¦¬ 주제 μž…λ ₯",
                placeholder="예: μŠ€νƒ€νŠΈμ—… μ°½μ—… 성곡 μŠ€ν† λ¦¬",
                lines=2
            )
            generate_btn = gr.Button("μŠ€ν† λ¦¬ 생성", variant="primary")
            story_output = gr.Textbox(
                label="μƒμ„±λœ μŠ€ν† λ¦¬",
                lines=15,
                interactive=True
            )
            # Hidden until a story exists; process_story reveals it
            create_book_btn = gr.Button(
                "μŠ€ν† λ¦¬λΆ 생성 (이미지 + PDF)",
                variant="secondary",
                visible=False
            )
        with gr.Column():
            image_gallery = gr.Gallery(
                label="μƒμ„±λœ 이미지",
                show_label=True,
                elem_id="gallery",
                columns=2,
                rows=3,
                height="auto"
            )
            pdf_output = gr.File(
                label="PDF λ‹€μš΄λ‘œλ“œ",
                visible=True
            )
    # Cross-event state: paragraph list produced by process_story,
    # consumed later by create_storybook
    paragraphs_state = gr.State([])
    # Event handlers
    generate_btn.click(
        fn=process_story,
        inputs=[prompt_input],
        outputs=[story_output, create_book_btn, paragraphs_state]
    )
    create_book_btn.click(
        fn=create_storybook,
        inputs=[story_output, paragraphs_state],
        outputs=[image_gallery, pdf_output]
    )

if __name__ == "__main__":
    # share=True exposes a public Gradio link
    app.launch(share=True)