qqwjq1981 committed
Commit d9fe1f1 · verified · 1 Parent(s): b9ac337

Update app.py

Files changed (1)
  1. app.py +108 -82
app.py CHANGED
@@ -41,6 +41,7 @@ import soundfile as sf
 from paddleocr import PaddleOCR
 import cv2
 from rapidfuzz import fuzz
+from tqdm import tqdm
 
 logger = logging.getLogger(__name__)
 
@@ -513,77 +514,117 @@ def solve_optimal_alignment(original_segments, generated_durations, total_durati
 
     return original_segments
 
-def extract_subtitles_with_ocr(video_path):
-    ocr = PaddleOCR(use_angle_cls=True, lang="ch") # Change `lang` as needed
-    vidcap = cv2.VideoCapture(video_path)
-    fps = vidcap.get(cv2.CAP_PROP_FPS)
-
-    subtitles = []
-    frame_id = 0
-    success, image = vidcap.read()
-
-    while success:
-        if frame_id % int(fps) == 0: # OCR 1 frame per second (adjust if needed)
-            result = ocr.ocr(image, cls=True)
-            texts = [line[1][0] for line in result[0]] # Get text parts
-            combined_text = " ".join(texts).strip()
-            if combined_text:
-                subtitles.append({
-                    "time": frame_id / fps,
-                    "text": combined_text
-                })
-
-        frame_id += 1
-        success, image = vidcap.read()
-
-    vidcap.release()
-    return subtitles
-
-def align_subtitles_to_transcripts(ocr_subtitles, whisperx_segments):
-    aligned_pairs = []
-
-    for ocr_entry in ocr_subtitles:
-        ocr_time = ocr_entry["time"]
-        best_score = -1
-        best_segment = None
-
-        for seg in whisperx_segments:
-            # Only consider segments close in time (within +/- 2s)
-            if abs(seg["start"] - ocr_time) < 2.0 or abs(seg["end"] - ocr_time) < 2.0:
-                score = fuzz.ratio(seg["text"], ocr_entry["text"])
-                if score > best_score:
-                    best_score = score
-                    best_segment = seg
-
-        if best_segment:
-            aligned_pairs.append({
-                "whisper_text": best_segment["text"],
-                "ocr_text": ocr_entry["text"],
-                "start": best_segment["start"],
-                "end": best_segment["end"],
-                "similarity": best_score
-            })
-
-    return aligned_pairs
-
-def correct_transcripts_with_ocr(aligned_pairs):
-    corrected_segments = []
-
-    for pair in aligned_pairs:
-        if pair["similarity"] > 80:
-            # Trust OCR more if they are close
-            corrected_text = pair["ocr_text"]
-        else:
-            corrected_text = pair["whisper_text"]
-
-        corrected_segments.append({
-            "start": pair["start"],
-            "end": pair["end"],
-            "text": corrected_text
-        })
-
-    return corrected_segments
-
+def ocr_frame_worker(args):
+    frame_idx, frame_time, frame = args
+    ocr = PaddleOCR(use_angle_cls=True, lang="ch") # Initialize OCR inside worker
+    result = ocr.ocr(frame, cls=True)
+    texts = [line[1][0] for line in result[0]] if result[0] else []
+    combined_text = " ".join(texts).strip()
+    return {"time": frame_time, "text": combined_text}
+
+def extract_ocr_subtitles_parallel(video_path, interval_sec=0.5, num_workers=4):
+    cap = cv2.VideoCapture(video_path)
+    fps = cap.get(cv2.CAP_PROP_FPS)
+    frames = []
+    frame_idx = 0
+    success, frame = cap.read()
+
+    while success:
+        if frame_idx % int(fps * interval_sec) == 0:
+            frame_time = frame_idx / fps
+            frames.append((frame_idx, frame_time, frame.copy()))
+        success, frame = cap.read()
+        frame_idx += 1
+    cap.release()
+
+    ocr_results = []
+    with concurrent.futures.ProcessPoolExecutor(max_workers=num_workers) as executor:
+        futures = [executor.submit(ocr_frame_worker, frame) for frame in frames]
+
+        for f in tqdm(concurrent.futures.as_completed(futures), total=len(futures)):
+            try:
+                result = f.result()
+                if result["text"]:
+                    ocr_results.append(result)
+            except Exception as e:
+                print(f"⚠️ OCR failed for a frame: {e}")
+    return ocr_results
+
+def collapse_ocr_subtitles(ocr_json, text_similarity_threshold=90):
+    collapsed = []
+    current = None
+    for entry in ocr_json:
+        time = entry["time"]
+        text = entry["text"]
+
+        if not current:
+            current = {"start": time, "end": time, "text": text}
+            continue
+
+        sim = fuzz.ratio(current["text"], text)
+        if sim >= text_similarity_threshold:
+            current["end"] = time
+        else:
+            collapsed.append(current)
+            current = {"start": time, "end": time, "text": text}
+    if current:
+        collapsed.append(current)
+    return collapsed
+
+def post_edit_transcribed_segments(transcription_json, video_path,
+                                   interval_sec=0.5,
+                                   text_similarity_threshold=80,
+                                   time_tolerance=1.0,
+                                   num_workers=4):
+    """
+    Given WhisperX transcription (transcription_json) and video,
+    use OCR subtitles to post-correct and merge the transcriptions.
+    """
+
+    # Step 1: Extract OCR subtitles
+    ocr_json = extract_ocr_subtitles_parallel(video_path, interval_sec=interval_sec, num_workers=num_workers)
+
+    # Step 2: Collapse repetitive OCR
+    collapsed_ocr = collapse_ocr_subtitles(ocr_json, text_similarity_threshold=90)
+
+    # Step 3: Merge OCR with WhisperX
+    merged_segments = []
+
+    for entry in transcription_json:
+        start = entry.get("start", 0)
+        end = entry.get("end", 0)
+        base_text = entry.get("text", "")
+
+        best_match = None
+        best_score = -1
+
+        for ocr in collapsed_ocr:
+            # Check time overlap
+            time_overlap = not (ocr["end"] < start - time_tolerance or ocr["start"] > end + time_tolerance)
+            if not time_overlap:
+                continue
+
+            # Text similarity
+            sim = fuzz.ratio(ocr["text"], base_text)
+            if sim > best_score:
+                best_score = sim
+                best_match = ocr
+
+        # If good match found, replace the original text
+        updated_entry = entry.copy()
+        if best_match and best_score >= text_similarity_threshold:
+            updated_entry["text"] = best_match["text"]
+            updated_entry["ocr_matched"] = True
+            updated_entry["ocr_similarity"] = best_score
+        else:
+            updated_entry["ocr_matched"] = False
+            updated_entry["ocr_similarity"] = best_score if best_score >= 0 else None
+
+        merged_segments.append(updated_entry)
+
+    print(f"✅ Post-editing completed: {len(merged_segments)} segments")
+    return merged_segments
+
 # def get_frame_image_bytes(video, t):
 # frame = video.get_frame(t)
 # img = Image.fromarray(frame)
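
For illustration only, not part of the commit: a minimal sketch of how the new collapse_ocr_subtitles() folds per-frame OCR hits into subtitle spans. It assumes app.py can be imported without side effects; the sample entries and values below are made up.

from app import collapse_ocr_subtitles  # assumption: app.py is importable from the repo root

# Hypothetical per-frame OCR output, sampled every 0.5 s
ocr_json = [
    {"time": 12.0, "text": "你好 世界"},
    {"time": 12.5, "text": "你好 世界"},
    {"time": 13.0, "text": "你好 世界!"},  # near-duplicate, fuzz.ratio ≈ 91
    {"time": 13.5, "text": "下一句字幕"},
]

collapsed = collapse_ocr_subtitles(ocr_json, text_similarity_threshold=90)
# Consecutive near-identical frames are merged into a single span:
# [{"start": 12.0, "end": 13.0, "text": "你好 世界"},
#  {"start": 13.5, "end": 13.5, "text": "下一句字幕"}]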
@@ -634,21 +675,6 @@ def correct_transcripts_with_ocr(aligned_pairs):
 # return entry
 
 
-# def post_edit_translated_segments(translated_json, video_path):
-# video = VideoFileClip(video_path)
-
-# def process(entry):
-# mid_time = (entry['start'] + entry['end']) / 2
-# image_bytes = get_frame_image_bytes(video, mid_time)
-# entry = post_edit_segment(entry, image_bytes)
-# return entry
-
-# with concurrent.futures.ThreadPoolExecutor() as executor:
-# edited = list(executor.map(process, translated_json))
-
-# video.close()
-# return edited
-
 def process_entry(entry, i, tts_model, video_width, video_height, process_mode, target_language, font_path, speaker_sample_paths=None):
     logger.debug(f"Processing entry {i}: {entry}")
     error_message = None
@@ -953,12 +979,12 @@ def upload_and_manage(file, target_language, process_mode):
     transcription_json, source_language = transcribe_video_with_speakers(file.name)
     logger.info(f"Transcription completed. Detected source language: {source_language}")
 
+    transcription_json_merged = post_edit_transcribed_segments(transcription_json, file.name)
     # Step 2: Translate the transcription
     logger.info(f"Translating transcription from {source_language} to {target_language}...")
-    translated_json_raw = translate_text(transcription_json, source_language, target_language)
+    translated_json_raw = translate_text(transcription_json_merged, source_language, target_language)
    logger.info(f"Translation completed. Number of translated segments: {len(translated_json_raw)}")
 
-    # translated_json = post_edit_translated_segments(translated_json, file.name)
     translated_json = apply_adaptive_speed(translated_json_raw, source_language, target_language)
 
     # Step 3: Add transcript to video based on timestamps
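
For illustration only, not part of the commit: a minimal sketch of the updated upload_and_manage() flow run outside Gradio, assuming the functions added or referenced above are importable from app.py. The video path and target language are placeholders.

from app import (
    transcribe_video_with_speakers,
    post_edit_transcribed_segments,
    translate_text,
    apply_adaptive_speed,
)

video_path = "input.mp4"   # placeholder
target_language = "en"     # placeholder

# WhisperX transcription with speaker labels
transcription_json, source_language = transcribe_video_with_speakers(video_path)

# New in this commit: OCR the on-screen subtitles and post-correct the transcript before translating
transcription_json_merged = post_edit_transcribed_segments(
    transcription_json,
    video_path,
    interval_sec=0.5,
    text_similarity_threshold=80,
    time_tolerance=1.0,
    num_workers=4,
)

# Translate the corrected segments, then adapt speaking rate
translated_json_raw = translate_text(transcription_json_merged, source_language, target_language)
translated_json = apply_adaptive_speed(translated_json_raw, source_language, target_language)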
 