nangunan committed on
Commit
d2fb2c8
Β·
1 Parent(s): cda63cc
Files changed (2) hide show
  1. app.py +407 -0
  2. requirements.txt +12 -0
app.py ADDED
@@ -0,0 +1,407 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import difflib
import html
import os
import re
import subprocess
import uuid
import zipfile

import gradio as gr
import language_tool_python
import matplotlib.pyplot as plt
import requests
import torch
import whisper

from transformers import (
    AutoTokenizer,
    AutoModelForSeq2SeqLM,
    pipeline as hf_pipeline,
)
18
+
19
# ────────────────────────────────────────────────────────────────
# Optional evaluation libraries (summary-fidelity metrics).
# Each import is wrapped in try/except so the app still starts when a
# metric package is missing; the warning tells the user how to install it.
try:
    from rouge_score import rouge_scorer
except ImportError:
    rouge_scorer = None
    print("[Warning] rouge_score νŒ¨ν‚€μ§€κ°€ μ—†μŠ΅λ‹ˆλ‹€. pip install rouge-score")

try:
    from bert_score import score as bert_score_func
except ImportError:
    bert_score_func = None
    print("[Warning] bert-score νŒ¨ν‚€μ§€κ°€ μ—†μŠ΅λ‹ˆλ‹€. pip install bert-score")

# ────────────────────────────────────────────────────────────────
# Korean spell checker (py-hanspell); optional — correct_spelling()
# becomes a no-op when it is unavailable.
try:
    from hanspell import spell_checker
except ImportError:
    spell_checker = None

# ────────────────────────────────────────────────────────────────
# LanguageTool rule-based correction (English only).
# Initialization can fail (e.g. no Java runtime); lt_tool stays None then.
try:
    lt_tool = language_tool_python.LanguageTool('en-US')
except Exception as e:
    lt_tool = None
    print(f"[Warning] LanguageTool μ΄ˆκΈ°ν™” μ‹€νŒ¨: {e}")
47
+
48
# ────────────────────────────────────────────────────────────────
# FFmpeg bootstrap (Windows-specific paths): download a static build
# if it is not already present, then it is put on PATH below.
yt_dlp_path = "C:/Windows/System32/yt-dlp.exe"
ffmpeg_path = "C:/ProgramData/chocolatey/bin"

def download_ffmpeg(dest_bin):
    """Ensure ffmpeg.exe exists in *dest_bin* and return that directory.

    Downloads the gyan.dev "essentials" zip, extracts it next to
    *dest_bin*, and moves ffmpeg/ffprobe/ffplay into *dest_bin*.

    Raises:
        RuntimeError: if no ffmpeg.exe is found after extraction.
        requests.HTTPError: if the download request fails.
    """
    # Fast path: already installed.
    if os.path.isdir(dest_bin) and os.path.isfile(os.path.join(dest_bin, "ffmpeg.exe")):
        return dest_bin

    url = "https://www.gyan.dev/ffmpeg/builds/ffmpeg-release-essentials.zip"
    zip_path = os.path.join(os.getcwd(), "ffmpeg.zip")
    extract_root = os.path.dirname(dest_bin)
    os.makedirs(extract_root, exist_ok=True)

    # Stream the archive to disk; an explicit timeout keeps a dead
    # mirror from hanging app start-up forever (original had none).
    with requests.get(url, stream=True, timeout=(10, 300)) as resp:
        resp.raise_for_status()
        with open(zip_path, "wb") as f:
            for chunk in resp.iter_content(8192):
                f.write(chunk)

    try:
        with zipfile.ZipFile(zip_path, "r") as zf:
            zf.extractall(extract_root)
    finally:
        # Always drop the temporary archive, even if extraction fails.
        os.remove(zip_path)

    # The zip contains a versioned top-level folder; search for the binaries.
    for root, _, files in os.walk(extract_root):
        if "ffmpeg.exe" in files:
            os.makedirs(dest_bin, exist_ok=True)
            for fn in ("ffmpeg.exe", "ffprobe.exe", "ffplay.exe"):
                src, dst = os.path.join(root, fn), os.path.join(dest_bin, fn)
                if os.path.isfile(src):
                    os.replace(src, dst)
            return dest_bin
    raise RuntimeError("FFmpeg μ„€μΉ˜ μ‹€νŒ¨")
72
+
73
download_ffmpeg(ffmpeg_path)
# Prepend the ffmpeg directory so whisper / yt-dlp subprocesses find it.
os.environ["PATH"] = ffmpeg_path + os.pathsep + os.environ.get("PATH","")

# ────────────────────────────────────────────────────────────────
# Whisper ASR model ("medium" checkpoint), loaded once at start-up.
asr_model = whisper.load_model("medium")
79
+
80
# ────────────────────────────────────────────────────────────────
# Summarization models (tokenizer/model used directly, not via pipeline).
# UI label  ->  Hugging Face repo id.
SUMMARY_MODELS = {
    "mT5_multilingual_XLSum": "csebuetnlp/mT5_multilingual_XLSum",
    "Pegasus XSum": "google/pegasus-xsum",
    "BART-large CNN": "facebook/bart-large-cnn",
    "DistilBART CNN": "sshleifer/distilbart-cnn-12-6"
}
# Lazy caches filled by load_summarizer(), keyed by UI label.
tokenizers, models = {}, {}
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
90
+
91
def load_summarizer(label: str):
    """Load the summarization model and tokenizer for *label* into the caches.

    No-op when the model is already cached under that label.
    """
    if label not in models:
        repo_id = SUMMARY_MODELS[label]
        tokenizers[label] = AutoTokenizer.from_pretrained(repo_id, use_fast=False)
        seq2seq = AutoModelForSeq2SeqLM.from_pretrained(repo_id).to(device)
        seq2seq.eval()
        models[label] = seq2seq
100
+
101
# Shared ROUGE scorer (only constructed when rouge_score is installed).
if rouge_scorer:
    scorer = rouge_scorer.RougeScorer(["rouge1","rouge2","rougeL"], use_stemmer=True)

# ────────────────────────────────────────────────────────────────
# Grammar-correction back-ends.
# A value of None marks a non-HF method handled specially in
# correct_text(); a string is a text2text-generation repo id.
GRAMMAR_MODELS = {
    "LanguageTool-en": None,
    "py-hanspell": None,
    "GEC-ν•œκ΅­μ–΄": "Soyoung97/gec_kr"
}
# Lazy cache of HF pipelines, keyed by GRAMMAR_MODELS key.
grammar_pipes = {}
112
+
113
def load_grammar_pipe(name: str):
    """Instantiate the HF text2text pipeline for *name* and cache it."""
    repo_id = GRAMMAR_MODELS[name]
    if torch.cuda.is_available():
        dev = 0
    else:
        dev = -1
    tok = AutoTokenizer.from_pretrained(repo_id)
    grammar_pipes[name] = hf_pipeline(
        "text2text-generation", model=repo_id, tokenizer=tok, device=dev
    )
121
+
122
def correct_spelling(text, max_chunk=500):
    """Spell-check Korean *text* with py-hanspell.

    The text is split at sentence boundaries and regrouped into chunks
    of at most *max_chunk* characters, because the hanspell web API
    rejects long payloads.  Returns *text* unchanged when py-hanspell
    is not installed; a chunk the checker fails on is kept as-is.
    """
    if not spell_checker:
        return text
    # The capturing split keeps the sentence-ending punctuation as
    # separate parts so it is re-attached to its sentence.
    parts = re.split(r'([.?!]\s*)', text)
    segs, curr = [], ""
    for p in parts:
        if len(curr) + len(p) <= max_chunk:
            curr += p
        else:
            segs.append(curr)
            curr = p
    if curr:
        segs.append(curr)
    out = []
    for s in segs:
        try:
            out.append(spell_checker.check(s).checked)
        except Exception:
            # Narrowed from a bare `except:` so SystemExit and
            # KeyboardInterrupt are no longer swallowed.
            out.append(s)
    return " ".join(o.strip() for o in out)
134
+
135
def correct_text(text, method="GEC-ν•œκ΅­μ–΄"):
    """Run the selected grammar-correction back-end over *text*.

    Falls through and returns *text* untouched for unknown methods or
    when the LanguageTool instance is unavailable.
    """
    if method == "py-hanspell":
        return correct_spelling(text)

    if method == "LanguageTool-en" and lt_tool:
        found = lt_tool.check(text)
        return language_tool_python.utils.correct(text, found)

    if method == "GEC-ν•œκ΅­μ–΄":
        if method not in grammar_pipes:
            load_grammar_pipe(method)
        gec = grammar_pipes[method]
        sentences = re.split(r'(?<=[.?!])\s+', text)
        # Correct sentence by sentence to stay inside the model context.
        fixed = [
            gec(s, max_length=256, min_length=1, do_sample=False)[0]["generated_text"].strip()
            for s in sentences
        ]
        return " ".join(fixed)

    return text
152
+
153
+ # ────────────────────────────────────────────────────────────────
154
+ # ꡐ정λ₯  + Diff
155
def calculate_correction_rate(original, corrected):
    """Percentage of *original*'s word tokens touched by the correction.

    Counted from SequenceMatcher opcodes: every original token covered
    by a non-'equal' opcode counts as changed.  Pure insertions span
    zero original tokens and therefore contribute nothing.  Returns a
    value rounded to two decimals.
    """
    before = original.split()
    after = corrected.split()
    matcher = difflib.SequenceMatcher(None, before, after)
    changed = 0
    for tag, i1, i2, _j1, _j2 in matcher.get_opcodes():
        if tag != 'equal':
            changed += i2 - i1
    denom = max(len(before), 1)
    return round(100 * changed / denom, 2)
162
+
163
def highlight_diff(original: str, corrected: str) -> str:
    """Render *corrected* as HTML, highlighting words added vs *original*.

    Word-level ndiff: unchanged words are emitted as-is, added words are
    wrapped in a red <span>, removed words are dropped.  Fixes over the
    previous version: ndiff's "? " hint lines are skipped instead of
    being rendered as garbage text, and every word is HTML-escaped so
    transcript text cannot inject markup into the page.
    """
    pieces = []
    for token in difflib.ndiff(original.split(), corrected.split()):
        word = html.escape(token[2:])
        if token.startswith("+ "):
            pieces.append(f"<span style='color:red;'>{word}</span>")
        elif token.startswith("  "):
            pieces.append(word)
        # "- " (removed) and "? " (hint) lines are intentionally dropped.
    return " ".join(pieces)
174
+
175
+ # ────────────────────────────────────────────────────────────────
176
+ # YouTube
177
def download_audio(url):
    """Download *url*'s best audio track as an mp3 via yt-dlp.

    Returns the generated local filename; raises RuntimeError carrying
    yt-dlp's stderr on failure.
    """
    out_name = f"yt_{uuid.uuid4().hex[:8]}.mp3"
    proc = subprocess.run(
        [
            yt_dlp_path,
            "-f", "bestaudio",
            "--extract-audio",
            "--audio-format", "mp3",
            "-o", out_name,
            url,
        ],
        capture_output=True,
        text=True,
    )
    if proc.returncode != 0:
        raise RuntimeError(proc.stderr)
    return out_name
183
+
184
def get_transcript(url, state):
    """Transcribe the YouTube video at *url* with Whisper.

    *state* is the per-session cache {"url": ..., "orig": ...}; when it
    matches *url*, the cached transcript is returned without downloading
    again.  The temporary audio file is removed even when transcription
    raises (the original leaked it on failure).
    """
    if state and state.get("url") == url:
        return state["orig"], state
    audio = download_audio(url)
    try:
        result = asr_model.transcribe(audio)
        orig = result.get("text", "")
    finally:
        # Clean up the temp mp3 on both success and failure.
        os.remove(audio)
    return orig, {"url": url, "orig": orig}
192
+
193
# ────────────────────────────────────────────────────────────────
# Safe chunked summarization (direct model.generate calls).
def summarize_long_text(text: str, label: str, chunk_size: int = 512) -> str:
    """Summarize arbitrarily long *text* with the model named *label*.

    Map-reduce scheme: the token ids are cut into windows of at most
    *chunk_size* tokens (clamped to the model context), each window is
    summarized on its own, and the concatenated partial summaries are
    summarized once more into the final result.
    """
    load_summarizer(label)
    tok = tokenizers[label]
    mdl = models[label]

    token_ids = tok(text, return_tensors="pt", truncation=False).input_ids[0]

    # Leave a small margin below the model's positional limit.
    ctx_limit = getattr(mdl.config, "max_position_embeddings", 1024) - 4
    window = min(chunk_size, ctx_limit)

    gen_kwargs = dict(max_new_tokens=128, num_beams=4, do_sample=False)

    partials = []
    for start in range(0, len(token_ids), window):
        piece = token_ids[start:start + window].unsqueeze(0).to(device)
        out = mdl.generate(piece, **gen_kwargs)
        partials.append(tok.decode(out[0], skip_special_tokens=True))

    # Reduce step: summarize the joined partial summaries.
    merged = tok(
        " ".join(partials), return_tensors="pt", truncation=True, max_length=ctx_limit
    ).to(device)
    final_ids = mdl.generate(**merged, **gen_kwargs)
    return tok.decode(final_ids[0], skip_special_tokens=True)
228
+
229
+ # ────────────────────────────────────────────────────────────────
230
def summarize_single(url, label, grammar_method, transcript_state):
    """Gradio handler: transcribe, correct, and summarize with one model.

    Returns (original transcript, correction HTML with diff highlights,
    summary text, fidelity bar chart, updated transcript cache state).
    """
    transcript, updated_state = get_transcript(url, transcript_state)
    corrected = correct_text(transcript, method=grammar_method)
    rate = calculate_correction_rate(transcript, corrected)
    diff_html = (
        f"<div><b>ꡐ정λ₯ :</b> {rate}%</div><hr/>"
        f"{highlight_diff(transcript, corrected)}"
    )

    # Very short transcripts are not worth summarizing.
    if len(corrected) > 100:
        summary = summarize_long_text(corrected, label)
    else:
        summary = "⚠️ μš”μ•½ λΆˆκ°€"

    rouge_vals = [0, 0, 0]
    if rouge_scorer and summary.strip():
        scores = scorer.score(transcript, summary)
        rouge_vals = [
            scores["rouge1"].fmeasure,
            scores["rouge2"].fmeasure,
            scores["rougeL"].fmeasure,
        ]

    bert_f1 = 0
    if bert_score_func and summary.strip():
        # Prefer Korean scoring; fall back to English when it fails.
        try:
            _, _, f1 = bert_score_func([summary], [transcript], lang="ko")
        except Exception:
            _, _, f1 = bert_score_func([summary], [transcript], lang="en")
        bert_f1 = float(f1.mean())

    fig, ax = plt.subplots()
    ax.bar(["R1", "R2", "RL", "BERT-F1"], rouge_vals + [bert_f1])
    ax.set_ylim(0, 1)
    ax.set_ylabel("Score")
    ax.set_title("Summary Fidelity")
    plt.tight_layout()

    return transcript, diff_html, summary, fig, updated_state
257
+
258
+ # ────────────────────────────────────────────────────────────────
259
def _fidelity_note(bert_f1):
    """One-line interpretation of a BERTScore F1 value (shared helper)."""
    if bert_f1 > 0.8:
        return "핡심 정보 잘 반영"
    if bert_f1 > 0.5:
        return "μ£Όμš” λ‚΄μš© 포함"
    return "정보 손싀 많음"

def summarize_all(url, grammar_method, transcript_state):
    """Gradio handler: run every summarization model and compare fidelity.

    Returns, in Gradio output order: original transcript, correction
    HTML, one metrics figure per model, one interpretation string per
    model, the aggregate HTML table, and the updated transcript state.
    The BERT-F1 interpretation thresholds, previously duplicated in two
    places, now live in _fidelity_note(); the local formerly named
    `html` was renamed to avoid shadowing the stdlib module.
    """
    orig, new_state = get_transcript(url, transcript_state)
    corr = correct_text(orig, method=grammar_method)
    corr_rate = calculate_correction_rate(orig, corr)
    corr_html = f"<div><b>ꡐ정λ₯ :</b> {corr_rate}%</div><hr/>{highlight_diff(orig, corr)}"

    figs, interps, rv_list, bf_list = [], [], [], []
    summaries_plain = []
    labels = list(SUMMARY_MODELS.keys())

    for label in labels:
        summ = summarize_long_text(corr, label)
        summaries_plain.append(summ)

        rv = [0, 0, 0]
        bf = 0
        if rouge_scorer:
            sc = scorer.score(orig, summ)
            rv = [sc["rouge1"].fmeasure, sc["rouge2"].fmeasure, sc["rougeL"].fmeasure]
        if bert_score_func:
            # Prefer Korean scoring; fall back to English when it fails.
            try:
                _, _, F = bert_score_func([summ], [orig], lang="ko")
            except Exception:
                _, _, F = bert_score_func([summ], [orig], lang="en")
            bf = float(F.mean())
        rv_list.append(rv)
        bf_list.append(bf)

        fig, ax = plt.subplots()
        ax.bar(["R1", "R2", "RL", "BERT-F1"], rv + [bf])
        ax.set_ylim(0, 1)
        ax.set_title(label)
        plt.tight_layout()
        figs.append(fig)

        interps.append(f"{label}: {_fidelity_note(bf)} (F1={bf:.2f})")

    # Aggregate comparison table.
    table = "<h3>λͺ¨λΈλ³„ μš”μ•½ & Fidelity Metrics</h3>"
    table += f"<p><b>ꡐ정λ₯ :</b> {corr_rate}%</p>"
    table += "<table border='1' style='border-collapse:collapse; width:100%; table-layout:fixed;'>"
    table += "<tr><th style='width:12%'>λͺ¨λΈ</th><th style='width:58%'>μš”μ•½λ¬Έ</th><th style='width:5%'>R1</th><th style='width:5%'>R2</th><th style='width:5%'>RL</th><th style='width:7%'>BERT-F1</th><th style='width:8%'>해석</th></tr>"

    for i, label in enumerate(labels):
        r1, r2, rl = rv_list[i]
        bf = bf_list[i]
        note = _fidelity_note(bf)
        summ_html = summaries_plain[i].replace("<", "&lt;")
        table += (
            f"<tr>"
            f"<td>{label}</td>"
            f"<td style='white-space:pre-wrap; word-break:break-word'>{summ_html}</td>"
            f"<td>{r1:.2f}</td><td>{r2:.2f}</td><td>{rl:.2f}</td>"
            f"<td>{bf:.2f}</td><td>{note}</td>"
            f"</tr>"
        )
    table += "</table>"

    return [orig, corr_html] + figs + interps + [table, new_state]
319
+
320
+ # ────────────────────────────────────────────────────────────────
321
def save_summary(url, label):
    """Re-run the full pipeline for *url* and save *label*'s summary.

    Writes summary_<label>.txt into the working directory and returns
    the path for the Gradio download component.  Note: the transcript
    cache is bypassed (state=None), so the audio is fetched again.
    """
    transcript, _state = get_transcript(url, None)
    corrected = correct_text(transcript, "GEC-ν•œκ΅­μ–΄")
    summary_text = summarize_long_text(corrected, label)
    out_path = os.path.join(os.getcwd(), f"summary_{label}.txt")
    with open(out_path, "w", encoding="utf-8") as fh:
        fh.write(summary_text)
    return out_path
329
+
330
# ────────────────────────────────────────────────────────────────
# CSS: render the corrected-subtitle HTML areas as scrollable boxes.
CUSTOM_CSS = """
#corr_box, #corr_box_all {
    border: 1px solid #ccc;
    padding: 10px;
    border-radius: 6px;
    background-color: #fafafa;
    max-height: 300px;
    overflow-y: auto;
    white-space: pre-wrap;
}
"""
343
+
344
# Gradio UI
with gr.Blocks(css=CUSTOM_CSS) as demo:
    gr.Markdown("## 🎬 YouTube μš”μ•½ μ„œλΉ„μŠ€ (ꡐ정 + ꡐ정λ₯  + Diff κ°•μ‘°, μ•ˆμ „ μ²­ν¬μš”μ•½)")

    with gr.Tabs():
        # Tab 1: summarize with a single, user-selected model.
        with gr.TabItem("단일 λͺ¨λΈ μš”μ•½"):
            url_input = gr.Textbox(label="YouTube URL")
            model_sel = gr.Dropdown(list(SUMMARY_MODELS.keys()), label="μš”μ•½ λͺ¨λΈ")
            grammar_sel = gr.Dropdown(list(GRAMMAR_MODELS.keys()), label="ꡐ정 λͺ¨λΈ", value="GEC-ν•œκ΅­μ–΄")
            # Per-session transcript cache consumed by get_transcript().
            transcript_state = gr.State(None)
            btn_single = gr.Button("μš”μ•½ μ‹€ν–‰")

            orig_tb = gr.Textbox(label="원문 μžλ§‰", lines=10)
            corr_tb = gr.HTML(label="ꡐ정 μžλ§‰ (변경점 κ°•μ‘°)", elem_id="corr_box")
            sum_tb = gr.Textbox(label="μš”μ•½ κ²°κ³Ό", lines=8)
            fidelity_plot = gr.Plot(label="Fidelity Metrics")
            save_btn = gr.Button("μš”μ•½ μ €μž₯")
            download_single = gr.File(label="λ‹€μš΄λ‘œλ“œ 파일")

            btn_single.click(
                fn=summarize_single,
                inputs=[url_input, model_sel, grammar_sel, transcript_state],
                outputs=[orig_tb, corr_tb, sum_tb, fidelity_plot, transcript_state]
            )
            save_btn.click(
                fn=save_summary,
                inputs=[url_input, model_sel],
                outputs=[download_single]
            )

        # Tab 2: run every model and compare fidelity metrics side by side.
        with gr.TabItem("전체 λͺ¨λΈ 비ꡐ"):
            url_all = gr.Textbox(label="YouTube URL")
            grammar_sel_all = gr.Dropdown(list(GRAMMAR_MODELS.keys()), label="ꡐ정 λͺ¨λΈ", value="GEC-ν•œκ΅­μ–΄")
            transcript_state_all = gr.State(None)
            btn_all = gr.Button("λͺ¨λ‘ μ‹€ν–‰")

            orig_all = gr.Textbox(label="원문 μžλ§‰", lines=10)
            corr_all = gr.HTML(label="ꡐ정 μžλ§‰ (변경점 κ°•μ‘°)", elem_id="corr_box_all")

            # One plot + one interpretation slot per summarization model,
            # in SUMMARY_MODELS order (must match summarize_all's outputs).
            plot_components, interp_components = [], []
            for label in SUMMARY_MODELS:
                plot_components.append(gr.Plot(label=f"{label} Metrics"))
                interp_components.append(gr.HTML(label=f"{label} 해석"))

            agg_table = gr.HTML(label="λͺ¨λΈλ³„ μš”μ•½ & Fidelity Metrics")
            save_all_sel = gr.Radio(list(SUMMARY_MODELS.keys()), label="μ €μž₯ λͺ¨λΈ μ§€μ •")
            save_all_btn = gr.Button("선택 μš”μ•½ μ €μž₯")
            download_all = gr.File(label="λ‹€μš΄λ‘œλ“œ 파일")

            btn_all.click(
                fn=summarize_all,
                inputs=[url_all, grammar_sel_all, transcript_state_all],
                outputs=[orig_all, corr_all] + plot_components + interp_components + [agg_table, transcript_state_all]
            )
            save_all_btn.click(
                fn=save_summary,
                inputs=[url_all, save_all_sel],
                outputs=[download_all]
            )

if __name__ == '__main__':
    # Bind to localhost; Gradio picks a free port automatically.
    demo.launch(server_name="127.0.0.1")
    # Or fully automatic: demo.launch()
requirements.txt ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ torch
2
+ transformers
3
+ sentencepiece
4
+ gradio
5
+ git+https://github.com/openai/whisper.git
6
+ matplotlib
7
+ requests
8
+ # NOTE: "uuid" removed from the requirements — it is a Python standard-library module; the PyPI "uuid" package is an obsolete Python 2 shim that breaks installs on Python 3.
9
+ language-tool-python
10
+ rouge-score
11
+ bert-score
12
+ yt-dlp