phantomthe committed on
Commit
6027a38
·
1 Parent(s): b00b30b
Files changed (2) hide show
  1. app.py +386 -0
  2. requirements.txt +18 -0
app.py ADDED
@@ -0,0 +1,386 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ from transformers import AutoTokenizer, AutoModelForCausalLM
4
+ from diffusers import StableDiffusionPipeline, DPMSolverMultistepScheduler
5
+ from sqlmodel import Field, Session, SQLModel, create_engine, select
6
+ from typing import Optional, List, Tuple
7
+ import hashlib
8
+ from datetime import datetime
9
+ from reportlab.lib.pagesizes import A4
10
+ from reportlab.lib.units import cm
11
+ from reportlab.platypus import SimpleDocTemplate, Paragraph, Image, PageBreak, Spacer
12
+ from reportlab.lib.styles import getSampleStyleSheet, ParagraphStyle
13
+ from reportlab.pdfbase import pdfmetrics
14
+ from reportlab.pdfbase.ttfonts import TTFont
15
+ from reportlab.lib.enums import TA_CENTER, TA_JUSTIFY
16
+ from PIL import Image as PILImage
17
+ import os
18
+
19
# Fixed English description of the protagonist, added for character
# consistency: it is prepended to every image prompt.
CHARACTER_DESCRIPTION = "young korean man with blue hoodie"
21
+
22
# Database models
class Story(SQLModel, table=True):
    """A generated story together with the prompt that produced it."""

    # Primary key; left as None by default so the database can assign it.
    id: Optional[int] = Field(default=None, primary_key=True)
    # Topic text supplied by the user.
    prompt: str
    # Story body; generate_story stores the paragraphs joined by blank lines.
    content: str
    # Creation time from datetime.now (naive, local clock).
    created_at: datetime = Field(default_factory=datetime.now)
28
+
29
class ImageCache(SQLModel, table=True):
    """Maps the hash of a paragraph's text to the image rendered for it."""

    # Primary key; left as None by default so the database can assign it.
    id: Optional[int] = Field(default=None, primary_key=True)
    # MD5 hex digest of the paragraph text (indexed for cache lookups).
    prompt_hash: str = Field(index=True)
    # Path of the saved PNG file.
    image_path: str
    # Creation time from datetime.now (naive, local clock).
    created_at: datetime = Field(default_factory=datetime.now)
34
+
35
# Database initialisation: SQLite file in the working directory; tables that
# do not exist yet are created at import time.
engine = create_engine("sqlite:///storybook.db")
SQLModel.metadata.create_all(engine)

# Model initialisation (runs at import time)
print("모델 로딩 중...")
device = "cuda" if torch.cuda.is_available() else "cpu"

# LLM used for story generation
llm_model_name = "Bllossom/llama-3.2-Korean-Bllossom-AICA-5B"
tokenizer = AutoTokenizer.from_pretrained(llm_model_name)
llm_model = AutoModelForCausalLM.from_pretrained(
    llm_model_name,
    # Half precision only on GPU; full precision on CPU.
    torch_dtype=torch.float16 if device == "cuda" else torch.float32,
    device_map="auto"
)

# Stable Diffusion model used for the illustrations
sd_model_name = "Lykon/DreamShaper"
sd_pipe = StableDiffusionPipeline.from_pretrained(
    sd_model_name,
    torch_dtype=torch.float16 if device == "cuda" else torch.float32,
    # NOTE(review): this opts out of safetensors weights; confirm that the
    # repository really lacks them before keeping this setting.
    use_safetensors=False
)
sd_pipe = sd_pipe.to(device)

# Replace the scheduler once the pipeline is loaded, reusing its config.
sd_pipe.scheduler = DPMSolverMultistepScheduler.from_config(
    sd_pipe.scheduler.config,
    use_karras_sigmas=True,  # Karras schedule
    algorithm_type="dpmsolver++"
)

# Directory where generated images are stored
os.makedirs("generated_images", exist_ok=True)
70
+
71
+
72
def generate_story(prompt: str) -> Tuple[str, List[str]]:
    """Generate a Korean story for ``prompt`` and persist it.

    The LLM is asked for a five-paragraph story. The generated text is split
    on blank lines, fragments of 20 characters or fewer are dropped, and at
    most five paragraphs are kept. The result is stored as a ``Story`` row.

    Args:
        prompt: Story topic entered by the user.

    Returns:
        A tuple ``(story_text, paragraphs)`` where ``story_text`` is the kept
        paragraphs joined by blank lines and ``paragraphs`` is the list itself.
    """
    import re

    system_prompt = f"""당신은 뛰어난 스토리텔러입니다.
다음 주제를 바탕으로, 5개의 문단으로 구성된 흥미로운 이야기를 작성하세요.

규칙:
- 주인공은 '청년' 또는 '그'로만 지칭하세요 (이름 사용 금지)
- 주인공은 안경을 쓴 20대 청년입니다
- 각 문단은 2~4개의 문장으로 구성
- 시각적으로 표현 가능한 구체적인 장면 묘사 포함
- 순수 한국어만 사용
- 각 문단마다 명확한 장소와 행동 묘사

주제: {prompt}

이야기:"""

    inputs = tokenizer(system_prompt, return_tensors="pt").to(device)
    prompt_token_count = inputs["input_ids"].shape[1]

    with torch.no_grad():
        outputs = llm_model.generate(
            **inputs,
            max_new_tokens=1000,
            temperature=0.7,
            do_sample=True,
            top_p=0.92,
            repetition_penalty=1.1,
        )

    # The output sequence starts with the prompt tokens. Slice them off by
    # position instead of string-replacing the decoded prompt, which leaks
    # the prompt whenever decoding does not reproduce it character for
    # character.
    generated_tokens = outputs[0][prompt_token_count:]
    story = tokenizer.decode(generated_tokens, skip_special_tokens=True).strip()

    # Split into paragraphs on blank lines (whitespace-only lines count as
    # blank), drop very short fragments, keep the first five.
    candidates = (chunk.strip() for chunk in re.split(r"\n\s*\n", story))
    paragraphs = [chunk for chunk in candidates if len(chunk) > 20][:5]
    story_text = "\n\n".join(paragraphs)

    # Persist the story
    with Session(engine) as session:
        session.add(Story(prompt=prompt, content=story_text))
        session.commit()

    return story_text, paragraphs
121
+
122
+
123
# Ordered keyword rules: the first rule with any keyword present in the text
# decides the scene, so more specific topics are listed first.
_SCENE_KEYWORD_RULES = (
    (("카페",), "in a coffee shop"),
    (("프로그래밍", "코딩", "코드"), "coding on laptop"),
    (("회의", "미팅"), "in a meeting"),
    (("발표", "프레젠테이션"), "giving presentation"),
    (("동료", "팀"), "with team members"),
    (("성공", "축하"), "celebrating success"),
    (("계획",), "planning"),
    (("사무실",), "in office"),
    (("투자",), "meeting investors"),
)

# Fallback scene per paragraph position when no keyword matches.
_DEFAULT_SCENES = {
    1: "young entrepreneur working",
    2: "developing project",
    3: "collaborating with others",
    4: "business presentation",
    5: "successful achievement",
}


def analyze_text_for_english_scene(text: str, paragraph_num: int = 1) -> str:
    """Map a Korean paragraph to a short English scene description.

    The text is checked against a small set of keyword rules in priority
    order. A cafe combined with a laptop or computer is the most specific
    rule and is tested first. When nothing matches, a default scene is chosen
    by ``paragraph_num``.

    Args:
        text: Korean paragraph text.
        paragraph_num: 1-based paragraph position, used for the debug output
            and for selecting the fallback scene.

    Returns:
        An English scene phrase; ``"at work"`` when no keyword matches and
        ``paragraph_num`` has no dedicated default.
    """
    # Debug output
    print(f"[{paragraph_num}] 텍스트 분석 중: {text[:60]}...")

    # Most specific rule: cafe together with a laptop/computer.
    if "카페" in text and ("노트북" in text or "컴퓨터" in text):
        return "working on laptop in coffee shop"

    for keywords, scene in _SCENE_KEYWORD_RULES:
        if any(keyword in text for keyword in keywords):
            return scene

    return _DEFAULT_SCENES.get(paragraph_num, "at work")
180
+
181
+
182
def generate_image(text: str, paragraph_num: int = 1) -> str:
    """Render (or reuse) the illustration for one story paragraph.

    Images are cached by the MD5 hash of ``text``. A cache entry is reused
    only while its file still exists on disk; otherwise the image is
    generated again.

    Args:
        text: Korean paragraph text to illustrate.
        paragraph_num: 1-based paragraph position; selects the fallback
            scene and offsets the random seed.

    Returns:
        Path of the PNG file under ``generated_images/``.
    """
    # Hash of the paragraph text: cache key and file name.
    prompt_hash = hashlib.md5(text.encode()).hexdigest()

    # Cache lookup; read the path while the session is still open.
    with Session(engine) as session:
        cached = session.exec(
            select(ImageCache).where(ImageCache.prompt_hash == prompt_hash)
        ).first()
        cached_path = cached.image_path if cached else None

    # A row whose file was deleted must not be returned: callers open the
    # path afterwards.
    if cached_path and os.path.exists(cached_path):
        return cached_path

    # Scene extraction. paragraph_num is forwarded so that the per-paragraph
    # fallback scene is used instead of always the first one.
    print(f"\n[{paragraph_num}/5] 이미지 생성 중...")
    scene = analyze_text_for_english_scene(text, paragraph_num)

    # Final prompt: fixed character description plus the scene.
    final_prompt = f"{CHARACTER_DESCRIPTION} {scene}"

    print("최종 프롬프트: ", final_prompt)
    print(f"프롬프트 길이: {len(final_prompt)} 글자")

    # Negative prompt
    negative_prompt = "realistic, photo, multiple people, crowd"

    # Fixed base seed with a small per-paragraph offset (+10, +20, ...).
    base_seed = 396135060
    seed = base_seed + (paragraph_num * 10)
    generator = torch.Generator(device=device).manual_seed(seed)

    # Image generation.
    # NOTE: safety_checker / requires_safety_checker are options of the
    # pipeline constructor (from_pretrained), not of the call, so they are
    # not passed here.
    with torch.no_grad():
        image = sd_pipe(
            prompt=final_prompt,
            negative_prompt=negative_prompt,
            num_inference_steps=20,
            guidance_scale=6.0,
            height=512,
            width=512,
            generator=generator,
        ).images[0]

    # Save the image; the path is derived from the hash.
    image_path = f"generated_images/{prompt_hash}.png"
    image.save(image_path)

    # Record the cache entry unless a row for this hash already exists; in
    # that case the regenerated file restores it, because the path is
    # derived from the same hash.
    if cached_path is None:
        with Session(engine) as session:
            session.add(ImageCache(prompt_hash=prompt_hash, image_path=image_path))
            session.commit()

    return image_path
247
+
248
def create_pdf(story_text: str, image_paths: List[str], output_path: str = "storybook.pdf"):
    """Build the storybook PDF from the story text and its images.

    The text is split into paragraphs on blank lines. Each non-empty
    paragraph is followed by the image at the same index when that file
    exists, and paragraphs are separated by page breaks.

    Args:
        story_text: Story paragraphs separated by blank lines.
        image_paths: Image file paths, matched to paragraphs by index.
        output_path: Destination PDF path.

    Returns:
        ``output_path``.
    """
    from xml.sax.saxutils import escape

    doc = SimpleDocTemplate(output_path, pagesize=A4)

    # Register the Korean TrueType font once; later calls reuse it.
    font_name = "맑은고딕"
    font_path = "malgun.ttf"
    if font_name not in pdfmetrics.getRegisteredFontNames():
        pdfmetrics.registerFont(TTFont(font_name, font_path))

    # Styles
    styles = getSampleStyleSheet()
    title_style = ParagraphStyle(
        'CustomTitle',
        parent=styles['Heading1'],
        fontName=font_name,
        fontSize=24,
        textColor='black',
        alignment=TA_CENTER,
        spaceAfter=30
    )
    text_style = ParagraphStyle(
        'CustomText',
        parent=styles['Normal'],
        fontName=font_name,
        fontSize=12,
        leading=18,
        alignment=TA_JUSTIFY,
        spaceAfter=20
    )

    flowables = [
        Paragraph("AI 스토리북", title_style),
        Spacer(1, 1*cm),
    ]

    paragraphs = [
        chunk.strip()
        for chunk in story_text.strip().split("\n\n")
        if chunk.strip()
    ]
    last_index = len(paragraphs) - 1
    for i, para in enumerate(paragraphs):
        # Paragraph parses XML-like markup, so literal '&', '<' and '>' in
        # generated text must be escaped.
        flowables.append(Paragraph(escape(para), text_style))

        if i < len(image_paths) and os.path.exists(image_paths[i]):
            # 'proportional' fits the image inside the 15x10 cm box while
            # keeping its aspect ratio, so square images are not stretched.
            flowables.append(
                Image(image_paths[i], width=15*cm, height=10*cm, kind='proportional')
            )
            flowables.append(Spacer(1, 1*cm))

        if i < last_index:
            flowables.append(PageBreak())

    doc.build(flowables)
    return output_path
295
+
296
# Gradio callbacks
def process_story(prompt: str):
    """Handle the "generate story" button.

    Args:
        prompt: Topic text from the input box.

    Returns:
        A tuple ``(story_text, button_update, paragraphs)``: the story for
        the text box, an update that reveals the storybook button, and the
        paragraph list kept in state. A blank prompt yields an empty story
        and keeps the button hidden instead of running the model.
    """
    if not prompt or not prompt.strip():
        return "", gr.update(visible=False), []

    story, paragraphs = generate_story(prompt)
    return story, gr.update(visible=True), paragraphs
301
+
302
def generate_images_batch(paragraphs: List[str]):
    """Generate one image per paragraph, showing a progress bar.

    Args:
        paragraphs: Story paragraphs, in order.

    Returns:
        List of image paths in the same order as ``paragraphs``.
    """
    from tqdm import tqdm

    image_paths = []
    progress = tqdm(paragraphs, total=len(paragraphs), desc="이미지 생성")
    for paragraph_num, para in enumerate(progress, start=1):
        image_paths.append(generate_image(para, paragraph_num=paragraph_num))

        # Release cached GPU memory between images.
        if device == "cuda":
            torch.cuda.empty_cache()

    return image_paths
315
+
316
def create_storybook(story_text: str, paragraphs: List[str]):
    """Generate the images and the PDF for the storybook.

    The story text box is editable, so the paragraphs are taken from the
    current ``story_text`` (split on blank lines). This keeps images and PDF
    text in agreement. ``paragraphs`` is used only when the text is empty.

    Args:
        story_text: Current content of the story text box.
        paragraphs: Paragraph list stored when the story was generated.

    Returns:
        A tuple ``(images, pdf_path)``: the loaded images for the gallery and
        the path of the generated PDF.
    """
    current_paragraphs = [
        chunk.strip() for chunk in story_text.split("\n\n") if chunk.strip()
    ] or list(paragraphs)

    # Image generation
    image_paths = generate_images_batch(current_paragraphs)

    # PDF generation
    pdf_path = create_pdf("\n\n".join(current_paragraphs), image_paths)

    # Gallery data: copy each image into memory so the file handle is closed.
    images = []
    for path in image_paths:
        with PILImage.open(path) as opened:
            images.append(opened.copy())

    return images, pdf_path
328
+
329
# Gradio UI: left column holds the prompt, story and buttons; right column
# holds the image gallery and the PDF download.
with gr.Blocks(title="AI 스토리북 저작 도구") as app:
    gr.Markdown("# AI 스토리북 저작 도구")

    with gr.Row():
        with gr.Column():
            prompt_input = gr.Textbox(
                label="스토리 주제 입력",
                placeholder="예: 스타트업 창업 성공 스토리",
                lines=2
            )
            generate_btn = gr.Button("스토리 생성", variant="primary")

            # Editable, so the user can revise the story before building
            # the book.
            story_output = gr.Textbox(
                label="생성된 스토리",
                lines=15,
                interactive=True
            )

            # Hidden until a story has been generated.
            create_book_btn = gr.Button(
                "스토리북 생성 (이미지 + PDF)",
                variant="secondary",
                visible=False
            )

        with gr.Column():
            image_gallery = gr.Gallery(
                label="생성된 이미지",
                show_label=True,
                elem_id="gallery",
                columns=2,
                rows=3,
                height="auto"
            )

            pdf_output = gr.File(
                label="PDF 다운로드",
                visible=True
            )

    # State: paragraph list of the most recently generated story
    paragraphs_state = gr.State([])

    # Event handlers
    generate_btn.click(
        fn=process_story,
        inputs=[prompt_input],
        outputs=[story_output, create_book_btn, paragraphs_state]
    )

    create_book_btn.click(
        fn=create_storybook,
        inputs=[story_output, paragraphs_state],
        outputs=[image_gallery, pdf_output]
    )

if __name__ == "__main__":
    app.launch(share=True)
requirements.txt ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ --extra-index-url https://download.pytorch.org/whl/cu128
2
+ torch==2.7.1+cu128
3
+ torchvision==0.22.1+cu128
4
+ torchaudio==2.7.1+cu128
5
+ # Hugging Face 및 관련 라이브러리
6
+ transformers==4.52.4
7
+ diffusers==0.33.1
8
+ accelerate==1.7.0
9
+ bitsandbytes==0.46.0
10
+
11
+ # 기타 필수 패키지
12
+ gradio==5.34.0
13
+ sqlmodel==0.0.24
14
+ reportlab==4.4.1
15
+ pillow==11.2.1
16
+
17
+ # character_lora_fine_tuning 사용 시 필요
18
+ peft==0.15.2