sbapan41 committed
Commit b8d1b97 (verified) · 1 Parent(s): 75101d3

Upload 12 files

quantum_dubbing/audio_segments.py ADDED
@@ -0,0 +1,141 @@
from pydub import AudioSegment
from tqdm import tqdm
from .utils import run_command
from .logging_setup import logger
import numpy as np


class Mixer:
    def __init__(self):
        self.parts = []

    def __len__(self):
        parts = self._sync()
        seg = parts[0][1]
        frame_count = max(offset + seg.frame_count() for offset, seg in parts)
        return int(1000.0 * frame_count / seg.frame_rate)

    def overlay(self, sound, position=0):
        self.parts.append((position, sound))
        return self

    def _sync(self):
        positions, segs = zip(*self.parts)

        frame_rate = segs[0].frame_rate
        array_type = segs[0].array_type  # noqa

        offsets = [int(frame_rate * pos / 1000.0) for pos in positions]
        segs = AudioSegment.empty()._sync(*segs)
        return list(zip(offsets, segs))

    def append(self, sound):
        self.overlay(sound, position=len(self))

    def to_audio_segment(self):
        parts = self._sync()
        seg = parts[0][1]
        channels = seg.channels

        frame_count = max(offset + seg.frame_count() for offset, seg in parts)
        sample_count = int(frame_count * seg.channels)

        output = np.zeros(sample_count, dtype="int32")
        for offset, seg in parts:
            sample_offset = offset * channels
            samples = np.frombuffer(seg.get_array_of_samples(), dtype="int32")
            samples = np.int16(samples/np.max(np.abs(samples)) * 32767)
            start = sample_offset
            end = start + len(samples)
            output[start:end] += samples

        return seg._spawn(
            output, overrides={"sample_width": 4}).normalize(headroom=0.0)


def create_translated_audio(
    result_diarize, audio_files, final_file, concat=False, avoid_overlap=False,
):
    total_duration = result_diarize["segments"][-1]["end"]  # in seconds

    if concat:
        """
        file .\audio\1.ogg
        file .\audio\2.ogg
        file .\audio\3.ogg
        file .\audio\4.ogg
        ...
        """

        # Write the file paths to list.txt
        with open("list.txt", "w") as file:
            for i, audio_file in enumerate(audio_files):
                if i == len(audio_files) - 1:  # Check if it's the last item
                    file.write(f"file {audio_file}")
                else:
                    file.write(f"file {audio_file}\n")

        # command = f"ffmpeg -f concat -safe 0 -i list.txt {final_file}"
        command = (
            f"ffmpeg -f concat -safe 0 -i list.txt -c:a pcm_s16le {final_file}"
        )
        run_command(command)

    else:
        # silent audio with total_duration
        base_audio = AudioSegment.silent(
            duration=int(total_duration * 1000), frame_rate=41000
        )
        combined_audio = Mixer()
        combined_audio.overlay(base_audio)

        logger.debug(
            f"Audio duration: {total_duration // 60} "
            f"minutes and {int(total_duration % 60)} seconds"
        )

        last_end_time = 0
        previous_speaker = ""
        for line, audio_file in tqdm(
            zip(result_diarize["segments"], audio_files)
        ):
            start = float(line["start"])

            # Overlay each audio at the corresponding time
            try:
                audio = AudioSegment.from_file(audio_file)
                # audio_a = audio.speedup(playback_speed=1.5)

                if avoid_overlap:
                    speaker = line["speaker"]
                    if (last_end_time - 0.500) > start:
                        overlap_time = last_end_time - start
                        if previous_speaker and previous_speaker != speaker:
                            start = (last_end_time - 0.500)
                        else:
                            start = (last_end_time - 0.200)
                        if overlap_time > 2.5:
                            start = start - 0.3
                        logger.info(
                            f"Avoid overlap for {str(audio_file)} "
                            f"with {str(start)}"
                        )

                    previous_speaker = speaker

                duration_tts_seconds = len(audio) / 1000.0  # to sec
                last_end_time = (start + duration_tts_seconds)

                start_time = start * 1000  # to ms
                combined_audio = combined_audio.overlay(
                    audio, position=start_time
                )
            except Exception as error:
                logger.debug(str(error))
                logger.error(f"Error audio file {audio_file}")

        # combined audio as a file
        combined_audio_data = combined_audio.to_audio_segment()
        combined_audio_data.export(
            final_file, format="wav"
        )  # better than ogg, change if the audio is anomalous
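The Mixer above only stores (position, AudioSegment) pairs and renders them when to_audio_segment is called, which is exactly how create_translated_audio places each TTS clip at its diarized start time. A minimal usage sketch, assuming the package import path quantum_dubbing.audio_segments and two illustrative local clip files:

```python
from pydub import AudioSegment
from quantum_dubbing.audio_segments import Mixer  # assumed import path

# 5 s of silence as the timeline base, mirroring create_translated_audio
base = AudioSegment.silent(duration=5000, frame_rate=44100)

mixer = Mixer()
mixer.overlay(base)                                                 # position 0 by default
mixer.overlay(AudioSegment.from_file("seg_1.ogg"), position=250)    # 0.25 s into the timeline
mixer.overlay(AudioSegment.from_file("seg_2.ogg"), position=3200)   # 3.2 s into the timeline

# Render all overlaps into one normalized segment and write it to disk
mixer.to_audio_segment().export("dubbed_track.wav", format="wav")
```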
quantum_dubbing/language_configuration.py ADDED
@@ -0,0 +1,551 @@
1
+ from .logging_setup import logger
2
+
3
+ LANGUAGES_UNIDIRECTIONAL = {
4
+ "Aymara (ay)": "ay",
5
+ "Bambara (bm)": "bm",
6
+ "Cebuano (ceb)": "ceb",
7
+ "Chichewa (ny)": "ny",
8
+ "Divehi (dv)": "dv",
9
+ "Dogri (doi)": "doi",
10
+ "Ewe (ee)": "ee",
11
+ "Guarani (gn)": "gn",
12
+ "Iloko (ilo)": "ilo",
13
+ "Kinyarwanda (rw)": "rw",
14
+ "Krio (kri)": "kri",
15
+ "Kurdish (ku)": "ku",
16
+ "Kirghiz (ky)": "ky",
17
+ "Ganda (lg)": "lg",
18
+ "Maithili (mai)": "mai",
19
+ "Oriya (or)": "or",
20
+ "Oromo (om)": "om",
21
+ "Quechua (qu)": "qu",
22
+ "Samoan (sm)": "sm",
23
+ "Tigrinya (ti)": "ti",
24
+ "Tsonga (ts)": "ts",
25
+ "Akan (ak)": "ak",
26
+ "Uighur (ug)": "ug"
27
+ }
28
+
29
+ UNIDIRECTIONAL_L_LIST = LANGUAGES_UNIDIRECTIONAL.keys()
30
+
31
+ LANGUAGES = {
32
+ "Automatic detection": "Automatic detection",
33
+ "Arabic (ar)": "ar",
34
+ "Chinese - Simplified (zh-CN)": "zh",
35
+ "Czech (cs)": "cs",
36
+ "Danish (da)": "da",
37
+ "Dutch (nl)": "nl",
38
+ "English (en)": "en",
39
+ "Finnish (fi)": "fi",
40
+ "French (fr)": "fr",
41
+ "German (de)": "de",
42
+ "Greek (el)": "el",
43
+ "Hebrew (he)": "he",
44
+ "Hungarian (hu)": "hu",
45
+ "Italian (it)": "it",
46
+ "Japanese (ja)": "ja",
47
+ "Korean (ko)": "ko",
48
+ "Persian (fa)": "fa", # no aux gTTS
49
+ "Polish (pl)": "pl",
50
+ "Portuguese (pt)": "pt",
51
+ "Russian (ru)": "ru",
52
+ "Spanish (es)": "es",
53
+ "Turkish (tr)": "tr",
54
+ "Ukrainian (uk)": "uk",
55
+ "Urdu (ur)": "ur",
56
+ "Vietnamese (vi)": "vi",
57
+ "Hindi (hi)": "hi",
58
+ "Indonesian (id)": "id",
59
+ "Bengali (bn)": "bn",
60
+ "Telugu (te)": "te",
61
+ "Marathi (mr)": "mr",
62
+ "Tamil (ta)": "ta",
63
+ "Javanese (jw|jv)": "jw",
64
+ "Catalan (ca)": "ca",
65
+ "Nepali (ne)": "ne",
66
+ "Thai (th)": "th",
67
+ "Swedish (sv)": "sv",
68
+ "Amharic (am)": "am",
69
+ "Welsh (cy)": "cy", # no aux gTTS
70
+ "Estonian (et)": "et",
71
+ "Croatian (hr)": "hr",
72
+ "Icelandic (is)": "is",
73
+ "Georgian (ka)": "ka", # no aux gTTS
74
+ "Khmer (km)": "km",
75
+ "Slovak (sk)": "sk",
76
+ "Albanian (sq)": "sq",
77
+ "Serbian (sr)": "sr",
78
+ "Azerbaijani (az)": "az", # no aux gTTS
79
+ "Bulgarian (bg)": "bg",
80
+ "Galician (gl)": "gl", # no aux gTTS
81
+ "Gujarati (gu)": "gu",
82
+ "Kazakh (kk)": "kk", # no aux gTTS
83
+ "Kannada (kn)": "kn",
84
+ "Lithuanian (lt)": "lt", # no aux gTTS
85
+ "Latvian (lv)": "lv",
86
+ "Macedonian (mk)": "mk", # no aux gTTS # error get align model
87
+ "Malayalam (ml)": "ml",
88
+ "Malay (ms)": "ms", # error get align model
89
+ "Romanian (ro)": "ro",
90
+ "Sinhala (si)": "si",
91
+ "Sundanese (su)": "su",
92
+ "Swahili (sw)": "sw", # error aling
93
+ "Afrikaans (af)": "af",
94
+ "Bosnian (bs)": "bs",
95
+ "Latin (la)": "la",
96
+ "Myanmar Burmese (my)": "my",
97
+ "Norwegian (no|nb)": "no",
98
+ "Chinese - Traditional (zh-TW)": "zh-TW",
99
+ "Assamese (as)": "as",
100
+ "Basque (eu)": "eu",
101
+ "Hausa (ha)": "ha",
102
+ "Haitian Creole (ht)": "ht",
103
+ "Armenian (hy)": "hy",
104
+ "Lao (lo)": "lo",
105
+ "Malagasy (mg)": "mg",
106
+ "Mongolian (mn)": "mn",
107
+ "Maltese (mt)": "mt",
108
+ "Punjabi (pa)": "pa",
109
+ "Pashto (ps)": "ps",
110
+ "Slovenian (sl)": "sl",
111
+ "Shona (sn)": "sn",
112
+ "Somali (so)": "so",
113
+ "Tajik (tg)": "tg",
114
+ "Turkmen (tk)": "tk",
115
+ "Tatar (tt)": "tt",
116
+ "Uzbek (uz)": "uz",
117
+ "Yoruba (yo)": "yo",
118
+ **LANGUAGES_UNIDIRECTIONAL
119
+ }
120
+
121
+ BASE_L_LIST = LANGUAGES.keys()
122
+ LANGUAGES_LIST = [list(BASE_L_LIST)[0]] + sorted(list(BASE_L_LIST)[1:])
123
+ INVERTED_LANGUAGES = {value: key for key, value in LANGUAGES.items()}
124
+
125
+ EXTRA_ALIGN = {
126
+ "id": "indonesian-nlp/wav2vec2-large-xlsr-indonesian",
127
+ "bn": "arijitx/wav2vec2-large-xlsr-bengali",
128
+ "mr": "sumedh/wav2vec2-large-xlsr-marathi",
129
+ "ta": "Amrrs/wav2vec2-large-xlsr-53-tamil",
130
+ "jw": "cahya/wav2vec2-large-xlsr-javanese",
131
+ "ne": "shniranjan/wav2vec2-large-xlsr-300m-nepali",
132
+ "th": "sakares/wav2vec2-large-xlsr-thai-demo",
133
+ "sv": "KBLab/wav2vec2-large-voxrex-swedish",
134
+ "am": "agkphysics/wav2vec2-large-xlsr-53-amharic",
135
+ "cy": "Srulikbdd/Wav2Vec2-large-xlsr-welsh",
136
+ "et": "anton-l/wav2vec2-large-xlsr-53-estonian",
137
+ "hr": "classla/wav2vec2-xls-r-parlaspeech-hr",
138
+ "is": "carlosdanielhernandezmena/wav2vec2-large-xlsr-53-icelandic-ep10-1000h",
139
+ "ka": "MehdiHosseiniMoghadam/wav2vec2-large-xlsr-53-Georgian",
140
+ "km": "vitouphy/wav2vec2-xls-r-300m-khmer",
141
+ "sk": "infinitejoy/wav2vec2-large-xls-r-300m-slovak",
142
+ "sq": "Alimzhan/wav2vec2-large-xls-r-300m-albanian-colab",
143
+ "sr": "dnikolic/wav2vec2-xlsr-530-serbian-colab",
144
+ "az": "nijatzeynalov/wav2vec2-large-mms-1b-azerbaijani-common_voice15.0",
145
+ "bg": "infinitejoy/wav2vec2-large-xls-r-300m-bulgarian",
146
+ "gl": "ifrz/wav2vec2-large-xlsr-galician",
147
+ "gu": "Harveenchadha/vakyansh-wav2vec2-gujarati-gnm-100",
148
+ "kk": "aismlv/wav2vec2-large-xlsr-kazakh",
149
+ "kn": "Harveenchadha/vakyansh-wav2vec2-kannada-knm-560",
150
+ "lt": "DeividasM/wav2vec2-large-xlsr-53-lithuanian",
151
+ "lv": "anton-l/wav2vec2-large-xlsr-53-latvian",
152
+ "mk": "", # Konstantin-Bogdanoski/wav2vec2-macedonian-base
153
+ "ml": "gvs/wav2vec2-large-xlsr-malayalam",
154
+ "ms": "", # Duy/wav2vec2_malay
155
+ "ro": "anton-l/wav2vec2-large-xlsr-53-romanian",
156
+ "si": "IAmNotAnanth/wav2vec2-large-xls-r-300m-sinhala",
157
+ "su": "cahya/wav2vec2-large-xlsr-sundanese",
158
+ "sw": "", # Lians/fine-tune-wav2vec2-large-swahili
159
+ "af": "", # ylacombe/wav2vec2-common_voice-af-demo
160
+ "bs": "",
161
+ "la": "",
162
+ "my": "",
163
+ "no": "NbAiLab/wav2vec2-xlsr-300m-norwegian",
164
+ "zh-TW": "jonatasgrosman/wav2vec2-large-xlsr-53-chinese-zh-cn",
165
+ "as": "",
166
+ "eu": "", # cahya/wav2vec2-large-xlsr-basque # verify
167
+ "ha": "infinitejoy/wav2vec2-large-xls-r-300m-hausa",
168
+ "ht": "",
169
+ "hy": "infinitejoy/wav2vec2-large-xls-r-300m-armenian", # no (.)
170
+ "lo": "",
171
+ "mg": "",
172
+ "mn": "tugstugi/wav2vec2-large-xlsr-53-mongolian",
173
+ "mt": "carlosdanielhernandezmena/wav2vec2-large-xlsr-53-maltese-64h",
174
+ "pa": "kingabzpro/wav2vec2-large-xlsr-53-punjabi",
175
+ "ps": "aamirhs/wav2vec2-large-xls-r-300m-pashto-colab",
176
+ "sl": "anton-l/wav2vec2-large-xlsr-53-slovenian",
177
+ "sn": "",
178
+ "so": "",
179
+ "tg": "",
180
+ "tk": "", # Ragav/wav2vec2-tk
181
+ "tt": "anton-l/wav2vec2-large-xlsr-53-tatar",
182
+ "uz": "", # Mekhriddin/wav2vec2-large-xls-r-300m-uzbek-colab
183
+ "yo": "ogbi/wav2vec2-large-mms-1b-yoruba-test",
184
+ }
185
+
186
+
187
+ def fix_code_language(translate_to, syntax="google"):
188
+ if syntax == "google":
189
+ # google-translator, gTTS
190
+ replace_lang_code = {"zh": "zh-CN", "he": "iw", "zh-cn": "zh-CN"}
191
+ elif syntax == "coqui":
192
+ # coqui-xtts
193
+ replace_lang_code = {"zh": "zh-cn", "zh-CN": "zh-cn", "zh-TW": "zh-cn"}
194
+
195
+ new_code_lang = replace_lang_code.get(translate_to, translate_to)
196
+ logger.debug(f"Fix code {translate_to} -> {new_code_lang}")
197
+ return new_code_lang
198
+
199
+
200
+ BARK_VOICES_LIST = {
201
+ "de_speaker_0-Male BARK": "v2/de_speaker_0",
202
+ "de_speaker_1-Male BARK": "v2/de_speaker_1",
203
+ "de_speaker_2-Male BARK": "v2/de_speaker_2",
204
+ "de_speaker_3-Female BARK": "v2/de_speaker_3",
205
+ "de_speaker_4-Male BARK": "v2/de_speaker_4",
206
+ "de_speaker_5-Male BARK": "v2/de_speaker_5",
207
+ "de_speaker_6-Male BARK": "v2/de_speaker_6",
208
+ "de_speaker_7-Male BARK": "v2/de_speaker_7",
209
+ "de_speaker_8-Female BARK": "v2/de_speaker_8",
210
+ "de_speaker_9-Male BARK": "v2/de_speaker_9",
211
+ "en_speaker_0-Male BARK": "v2/en_speaker_0",
212
+ "en_speaker_1-Male BARK": "v2/en_speaker_1",
213
+ "en_speaker_2-Male BARK": "v2/en_speaker_2",
214
+ "en_speaker_3-Male BARK": "v2/en_speaker_3",
215
+ "en_speaker_4-Male BARK": "v2/en_speaker_4",
216
+ "en_speaker_5-Male BARK": "v2/en_speaker_5",
217
+ "en_speaker_6-Male BARK": "v2/en_speaker_6",
218
+ "en_speaker_7-Male BARK": "v2/en_speaker_7",
219
+ "en_speaker_8-Male BARK": "v2/en_speaker_8",
220
+ "en_speaker_9-Female BARK": "v2/en_speaker_9",
221
+ "es_speaker_0-Male BARK": "v2/es_speaker_0",
222
+ "es_speaker_1-Male BARK": "v2/es_speaker_1",
223
+ "es_speaker_2-Male BARK": "v2/es_speaker_2",
224
+ "es_speaker_3-Male BARK": "v2/es_speaker_3",
225
+ "es_speaker_4-Male BARK": "v2/es_speaker_4",
226
+ "es_speaker_5-Male BARK": "v2/es_speaker_5",
227
+ "es_speaker_6-Male BARK": "v2/es_speaker_6",
228
+ "es_speaker_7-Male BARK": "v2/es_speaker_7",
229
+ "es_speaker_8-Female BARK": "v2/es_speaker_8",
230
+ "es_speaker_9-Female BARK": "v2/es_speaker_9",
231
+ "fr_speaker_0-Male BARK": "v2/fr_speaker_0",
232
+ "fr_speaker_1-Female BARK": "v2/fr_speaker_1",
233
+ "fr_speaker_2-Female BARK": "v2/fr_speaker_2",
234
+ "fr_speaker_3-Male BARK": "v2/fr_speaker_3",
235
+ "fr_speaker_4-Male BARK": "v2/fr_speaker_4",
236
+ "fr_speaker_5-Female BARK": "v2/fr_speaker_5",
237
+ "fr_speaker_6-Male BARK": "v2/fr_speaker_6",
238
+ "fr_speaker_7-Male BARK": "v2/fr_speaker_7",
239
+ "fr_speaker_8-Male BARK": "v2/fr_speaker_8",
240
+ "fr_speaker_9-Male BARK": "v2/fr_speaker_9",
241
+ "hi_speaker_0-Female BARK": "v2/hi_speaker_0",
242
+ "hi_speaker_1-Female BARK": "v2/hi_speaker_1",
243
+ "hi_speaker_2-Male BARK": "v2/hi_speaker_2",
244
+ "hi_speaker_3-Female BARK": "v2/hi_speaker_3",
245
+ "hi_speaker_4-Female BARK": "v2/hi_speaker_4",
246
+ "hi_speaker_5-Male BARK": "v2/hi_speaker_5",
247
+ "hi_speaker_6-Male BARK": "v2/hi_speaker_6",
248
+ "hi_speaker_7-Male BARK": "v2/hi_speaker_7",
249
+ "hi_speaker_8-Male BARK": "v2/hi_speaker_8",
250
+ "hi_speaker_9-Female BARK": "v2/hi_speaker_9",
251
+ "it_speaker_0-Male BARK": "v2/it_speaker_0",
252
+ "it_speaker_1-Male BARK": "v2/it_speaker_1",
253
+ "it_speaker_2-Female BARK": "v2/it_speaker_2",
254
+ "it_speaker_3-Male BARK": "v2/it_speaker_3",
255
+ "it_speaker_4-Male BARK": "v2/it_speaker_4",
256
+ "it_speaker_5-Male BARK": "v2/it_speaker_5",
257
+ "it_speaker_6-Male BARK": "v2/it_speaker_6",
258
+ "it_speaker_7-Female BARK": "v2/it_speaker_7",
259
+ "it_speaker_8-Male BARK": "v2/it_speaker_8",
260
+ "it_speaker_9-Female BARK": "v2/it_speaker_9",
261
+ "ja_speaker_0-Female BARK": "v2/ja_speaker_0",
262
+ "ja_speaker_1-Female BARK": "v2/ja_speaker_1",
263
+ "ja_speaker_2-Male BARK": "v2/ja_speaker_2",
264
+ "ja_speaker_3-Female BARK": "v2/ja_speaker_3",
265
+ "ja_speaker_4-Female BARK": "v2/ja_speaker_4",
266
+ "ja_speaker_5-Female BARK": "v2/ja_speaker_5",
267
+ "ja_speaker_6-Male BARK": "v2/ja_speaker_6",
268
+ "ja_speaker_7-Female BARK": "v2/ja_speaker_7",
269
+ "ja_speaker_8-Female BARK": "v2/ja_speaker_8",
270
+ "ja_speaker_9-Female BARK": "v2/ja_speaker_9",
271
+ "ko_speaker_0-Female BARK": "v2/ko_speaker_0",
272
+ "ko_speaker_1-Male BARK": "v2/ko_speaker_1",
273
+ "ko_speaker_2-Male BARK": "v2/ko_speaker_2",
274
+ "ko_speaker_3-Male BARK": "v2/ko_speaker_3",
275
+ "ko_speaker_4-Male BARK": "v2/ko_speaker_4",
276
+ "ko_speaker_5-Male BARK": "v2/ko_speaker_5",
277
+ "ko_speaker_6-Male BARK": "v2/ko_speaker_6",
278
+ "ko_speaker_7-Male BARK": "v2/ko_speaker_7",
279
+ "ko_speaker_8-Male BARK": "v2/ko_speaker_8",
280
+ "ko_speaker_9-Male BARK": "v2/ko_speaker_9",
281
+ "pl_speaker_0-Male BARK": "v2/pl_speaker_0",
282
+ "pl_speaker_1-Male BARK": "v2/pl_speaker_1",
283
+ "pl_speaker_2-Male BARK": "v2/pl_speaker_2",
284
+ "pl_speaker_3-Male BARK": "v2/pl_speaker_3",
285
+ "pl_speaker_4-Female BARK": "v2/pl_speaker_4",
286
+ "pl_speaker_5-Male BARK": "v2/pl_speaker_5",
287
+ "pl_speaker_6-Female BARK": "v2/pl_speaker_6",
288
+ "pl_speaker_7-Male BARK": "v2/pl_speaker_7",
289
+ "pl_speaker_8-Male BARK": "v2/pl_speaker_8",
290
+ "pl_speaker_9-Female BARK": "v2/pl_speaker_9",
291
+ "pt_speaker_0-Male BARK": "v2/pt_speaker_0",
292
+ "pt_speaker_1-Male BARK": "v2/pt_speaker_1",
293
+ "pt_speaker_2-Male BARK": "v2/pt_speaker_2",
294
+ "pt_speaker_3-Male BARK": "v2/pt_speaker_3",
295
+ "pt_speaker_4-Male BARK": "v2/pt_speaker_4",
296
+ "pt_speaker_5-Male BARK": "v2/pt_speaker_5",
297
+ "pt_speaker_6-Male BARK": "v2/pt_speaker_6",
298
+ "pt_speaker_7-Male BARK": "v2/pt_speaker_7",
299
+ "pt_speaker_8-Male BARK": "v2/pt_speaker_8",
300
+ "pt_speaker_9-Male BARK": "v2/pt_speaker_9",
301
+ "ru_speaker_0-Male BARK": "v2/ru_speaker_0",
302
+ "ru_speaker_1-Male BARK": "v2/ru_speaker_1",
303
+ "ru_speaker_2-Male BARK": "v2/ru_speaker_2",
304
+ "ru_speaker_3-Male BARK": "v2/ru_speaker_3",
305
+ "ru_speaker_4-Male BARK": "v2/ru_speaker_4",
306
+ "ru_speaker_5-Female BARK": "v2/ru_speaker_5",
307
+ "ru_speaker_6-Female BARK": "v2/ru_speaker_6",
308
+ "ru_speaker_7-Male BARK": "v2/ru_speaker_7",
309
+ "ru_speaker_8-Male BARK": "v2/ru_speaker_8",
310
+ "ru_speaker_9-Female BARK": "v2/ru_speaker_9",
311
+ "tr_speaker_0-Male BARK": "v2/tr_speaker_0",
312
+ "tr_speaker_1-Male BARK": "v2/tr_speaker_1",
313
+ "tr_speaker_2-Male BARK": "v2/tr_speaker_2",
314
+ "tr_speaker_3-Male BARK": "v2/tr_speaker_3",
315
+ "tr_speaker_4-Female BARK": "v2/tr_speaker_4",
316
+ "tr_speaker_5-Female BARK": "v2/tr_speaker_5",
317
+ "tr_speaker_6-Male BARK": "v2/tr_speaker_6",
318
+ "tr_speaker_7-Male BARK": "v2/tr_speaker_7",
319
+ "tr_speaker_8-Male BARK": "v2/tr_speaker_8",
320
+ "tr_speaker_9-Male BARK": "v2/tr_speaker_9",
321
+ "zh_speaker_0-Male BARK": "v2/zh_speaker_0",
322
+ "zh_speaker_1-Male BARK": "v2/zh_speaker_1",
323
+ "zh_speaker_2-Male BARK": "v2/zh_speaker_2",
324
+ "zh_speaker_3-Male BARK": "v2/zh_speaker_3",
325
+ "zh_speaker_4-Female BARK": "v2/zh_speaker_4",
326
+ "zh_speaker_5-Male BARK": "v2/zh_speaker_5",
327
+ "zh_speaker_6-Female BARK": "v2/zh_speaker_6",
328
+ "zh_speaker_7-Female BARK": "v2/zh_speaker_7",
329
+ "zh_speaker_8-Male BARK": "v2/zh_speaker_8",
330
+ "zh_speaker_9-Female BARK": "v2/zh_speaker_9",
331
+ }
332
+
333
+ VITS_VOICES_LIST = {
334
+ "ar-facebook-mms VITS": "facebook/mms-tts-ara",
335
+ # 'zh-facebook-mms VITS': 'facebook/mms-tts-cmn',
336
+ "zh_Hakka-facebook-mms VITS": "facebook/mms-tts-hak",
337
+ "zh_MinNan-facebook-mms VITS": "facebook/mms-tts-nan",
338
+ # 'cs-facebook-mms VITS': 'facebook/mms-tts-ces',
339
+ # 'da-facebook-mms VITS': 'facebook/mms-tts-dan',
340
+ "nl-facebook-mms VITS": "facebook/mms-tts-nld",
341
+ "en-facebook-mms VITS": "facebook/mms-tts-eng",
342
+ "fi-facebook-mms VITS": "facebook/mms-tts-fin",
343
+ "fr-facebook-mms VITS": "facebook/mms-tts-fra",
344
+ "de-facebook-mms VITS": "facebook/mms-tts-deu",
345
+ "el-facebook-mms VITS": "facebook/mms-tts-ell",
346
+ "el_Ancient-facebook-mms VITS": "facebook/mms-tts-grc",
347
+ "he-facebook-mms VITS": "facebook/mms-tts-heb",
348
+ "hu-facebook-mms VITS": "facebook/mms-tts-hun",
349
+ # 'it-facebook-mms VITS': 'facebook/mms-tts-ita',
350
+ # 'ja-facebook-mms VITS': 'facebook/mms-tts-jpn',
351
+ "ko-facebook-mms VITS": "facebook/mms-tts-kor",
352
+ "fa-facebook-mms VITS": "facebook/mms-tts-fas",
353
+ "pl-facebook-mms VITS": "facebook/mms-tts-pol",
354
+ "pt-facebook-mms VITS": "facebook/mms-tts-por",
355
+ "ru-facebook-mms VITS": "facebook/mms-tts-rus",
356
+ "es-facebook-mms VITS": "facebook/mms-tts-spa",
357
+ "tr-facebook-mms VITS": "facebook/mms-tts-tur",
358
+ "uk-facebook-mms VITS": "facebook/mms-tts-ukr",
359
+ "ur_arabic-facebook-mms VITS": "facebook/mms-tts-urd-script_arabic",
360
+ "ur_devanagari-facebook-mms VITS": "facebook/mms-tts-urd-script_devanagari",
361
+ "ur_latin-facebook-mms VITS": "facebook/mms-tts-urd-script_latin",
362
+ "vi-facebook-mms VITS": "facebook/mms-tts-vie",
363
+ "hi-facebook-mms VITS": "facebook/mms-tts-hin",
364
+ "hi_Fiji-facebook-mms VITS": "facebook/mms-tts-hif",
365
+ "id-facebook-mms VITS": "facebook/mms-tts-ind",
366
+ "bn-facebook-mms VITS": "facebook/mms-tts-ben",
367
+ "te-facebook-mms VITS": "facebook/mms-tts-tel",
368
+ "mr-facebook-mms VITS": "facebook/mms-tts-mar",
369
+ "ta-facebook-mms VITS": "facebook/mms-tts-tam",
370
+ "jw-facebook-mms VITS": "facebook/mms-tts-jav",
371
+ "jw_Suriname-facebook-mms VITS": "facebook/mms-tts-jvn",
372
+ "ca-facebook-mms VITS": "facebook/mms-tts-cat",
373
+ "ne-facebook-mms VITS": "facebook/mms-tts-nep",
374
+ "th-facebook-mms VITS": "facebook/mms-tts-tha",
375
+ "th_Northern-facebook-mms VITS": "facebook/mms-tts-nod",
376
+ "sv-facebook-mms VITS": "facebook/mms-tts-swe",
377
+ "am-facebook-mms VITS": "facebook/mms-tts-amh",
378
+ "cy-facebook-mms VITS": "facebook/mms-tts-cym",
379
+ # "et-facebook-mms VITS": "facebook/mms-tts-est",
380
+ # "ht-facebook-mms VITS": "facebook/mms-tts-hrv",
381
+ "is-facebook-mms VITS": "facebook/mms-tts-isl",
382
+ "km-facebook-mms VITS": "facebook/mms-tts-khm",
383
+ "km_Northern-facebook-mms VITS": "facebook/mms-tts-kxm",
384
+ # "sk-facebook-mms VITS": "facebook/mms-tts-slk",
385
+ "sq_Northern-facebook-mms VITS": "facebook/mms-tts-sqi",
386
+ "az_South-facebook-mms VITS": "facebook/mms-tts-azb",
387
+ "az_North_script_cyrillic-facebook-mms VITS": "facebook/mms-tts-azj-script_cyrillic",
388
+ "az_North_script_latin-facebook-mms VITS": "facebook/mms-tts-azj-script_latin",
389
+ "bg-facebook-mms VITS": "facebook/mms-tts-bul",
390
+ # "gl-facebook-mms VITS": "facebook/mms-tts-glg",
391
+ "gu-facebook-mms VITS": "facebook/mms-tts-guj",
392
+ "kk-facebook-mms VITS": "facebook/mms-tts-kaz",
393
+ "kn-facebook-mms VITS": "facebook/mms-tts-kan",
394
+ # "lt-facebook-mms VITS": "facebook/mms-tts-lit",
395
+ "lv-facebook-mms VITS": "facebook/mms-tts-lav",
396
+ # "mk-facebook-mms VITS": "facebook/mms-tts-mkd",
397
+ "ml-facebook-mms VITS": "facebook/mms-tts-mal",
398
+ "ms-facebook-mms VITS": "facebook/mms-tts-zlm",
399
+ "ms_Central-facebook-mms VITS": "facebook/mms-tts-pse",
400
+ "ms_Manado-facebook-mms VITS": "facebook/mms-tts-xmm",
401
+ "ro-facebook-mms VITS": "facebook/mms-tts-ron",
402
+ # "si-facebook-mms VITS": "facebook/mms-tts-sin",
403
+ "sw-facebook-mms VITS": "facebook/mms-tts-swh",
404
+ # "af-facebook-mms VITS": "facebook/mms-tts-afr",
405
+ # "bs-facebook-mms VITS": "facebook/mms-tts-bos",
406
+ "la-facebook-mms VITS": "facebook/mms-tts-lat",
407
+ "my-facebook-mms VITS": "facebook/mms-tts-mya",
408
+ # "no_Bokmål-facebook-mms VITS": "thomasht86/mms-tts-nob", # verify
409
+ "as-facebook-mms VITS": "facebook/mms-tts-asm",
410
+ "as_Nagamese-facebook-mms VITS": "facebook/mms-tts-nag",
411
+ "eu-facebook-mms VITS": "facebook/mms-tts-eus",
412
+ "ha-facebook-mms VITS": "facebook/mms-tts-hau",
413
+ "ht-facebook-mms VITS": "facebook/mms-tts-hat",
414
+ "hy_Western-facebook-mms VITS": "facebook/mms-tts-hyw",
415
+ "lo-facebook-mms VITS": "facebook/mms-tts-lao",
416
+ "mg-facebook-mms VITS": "facebook/mms-tts-mlg",
417
+ "mn-facebook-mms VITS": "facebook/mms-tts-mon",
418
+ # "mt-facebook-mms VITS": "facebook/mms-tts-mlt",
419
+ "pa_Eastern-facebook-mms VITS": "facebook/mms-tts-pan",
420
+ # "pa_Western-facebook-mms VITS": "facebook/mms-tts-pnb",
421
+ # "ps-facebook-mms VITS": "facebook/mms-tts-pus",
422
+ # "sl-facebook-mms VITS": "facebook/mms-tts-slv",
423
+ "sn-facebook-mms VITS": "facebook/mms-tts-sna",
424
+ "so-facebook-mms VITS": "facebook/mms-tts-son",
425
+ "tg-facebook-mms VITS": "facebook/mms-tts-tgk",
426
+ "tk_script_arabic-facebook-mms VITS": "facebook/mms-tts-tuk-script_arabic",
427
+ "tk_script_latin-facebook-mms VITS": "facebook/mms-tts-tuk-script_latin",
428
+ "tt-facebook-mms VITS": "facebook/mms-tts-tat",
429
+ "tt_Crimean-facebook-mms VITS": "facebook/mms-tts-crh",
430
+ "uz_script_cyrillic-facebook-mms VITS": "facebook/mms-tts-uzb-script_cyrillic",
431
+ "yo-facebook-mms VITS": "facebook/mms-tts-yor",
432
+ "ay-facebook-mms VITS": "facebook/mms-tts-ayr",
433
+ "bm-facebook-mms VITS": "facebook/mms-tts-bam",
434
+ "ceb-facebook-mms VITS": "facebook/mms-tts-ceb",
435
+ "ny-facebook-mms VITS": "facebook/mms-tts-nya",
436
+ "dv-facebook-mms VITS": "facebook/mms-tts-div",
437
+ "doi-facebook-mms VITS": "facebook/mms-tts-dgo",
438
+ "ee-facebook-mms VITS": "facebook/mms-tts-ewe",
439
+ "gn-facebook-mms VITS": "facebook/mms-tts-grn",
440
+ "ilo-facebook-mms VITS": "facebook/mms-tts-ilo",
441
+ "rw-facebook-mms VITS": "facebook/mms-tts-kin",
442
+ "kri-facebook-mms VITS": "facebook/mms-tts-kri",
443
+ "ku_script_arabic-facebook-mms VITS": "facebook/mms-tts-kmr-script_arabic",
444
+ "ku_script_cyrillic-facebook-mms VITS": "facebook/mms-tts-kmr-script_cyrillic",
445
+ "ku_script_latin-facebook-mms VITS": "facebook/mms-tts-kmr-script_latin",
446
+ "ckb-facebook-mms VITS": "razhan/mms-tts-ckb", # Verify w
447
+ "ky-facebook-mms VITS": "facebook/mms-tts-kir",
448
+ "lg-facebook-mms VITS": "facebook/mms-tts-lug",
449
+ "mai-facebook-mms VITS": "facebook/mms-tts-mai",
450
+ "or-facebook-mms VITS": "facebook/mms-tts-ory",
451
+ "om-facebook-mms VITS": "facebook/mms-tts-orm",
452
+ "qu_Huallaga-facebook-mms VITS": "facebook/mms-tts-qub",
453
+ "qu_Lambayeque-facebook-mms VITS": "facebook/mms-tts-quf",
454
+ "qu_South_Bolivian-facebook-mms VITS": "facebook/mms-tts-quh",
455
+ "qu_North_Bolivian-facebook-mms VITS": "facebook/mms-tts-qul",
456
+ "qu_Tena_Lowland-facebook-mms VITS": "facebook/mms-tts-quw",
457
+ "qu_Ayacucho-facebook-mms VITS": "facebook/mms-tts-quy",
458
+ "qu_Cusco-facebook-mms VITS": "facebook/mms-tts-quz",
459
+ "qu_Cajamarca-facebook-mms VITS": "facebook/mms-tts-qvc",
460
+ "qu_Eastern_Apurímac-facebook-mms VITS": "facebook/mms-tts-qve",
461
+ "qu_Huamalíes_Dos_de_Mayo_Huánuco-facebook-mms VITS": "facebook/mms-tts-qvh",
462
+ "qu_Margos_Yarowilca_Lauricocha-facebook-mms VITS": "facebook/mms-tts-qvm",
463
+ "qu_North_Junín-facebook-mms VITS": "facebook/mms-tts-qvn",
464
+ "qu_Napo-facebook-mms VITS": "facebook/mms-tts-qvo",
465
+ "qu_San_Martín-facebook-mms VITS": "facebook/mms-tts-qvs",
466
+ "qu_Huaylla_Wanca-facebook-mms VITS": "facebook/mms-tts-qvw",
467
+ "qu_Northern_Pastaza-facebook-mms VITS": "facebook/mms-tts-qvz",
468
+ "qu_Huaylas_Ancash-facebook-mms VITS": "facebook/mms-tts-qwh",
469
+ "qu_Panao-facebook-mms VITS": "facebook/mms-tts-qxh",
470
+ "qu_Salasaca_Highland-facebook-mms VITS": "facebook/mms-tts-qxl",
471
+ "qu_Northern_Conchucos_Ancash-facebook-mms VITS": "facebook/mms-tts-qxn",
472
+ "qu_Southern_Conchucos-facebook-mms VITS": "facebook/mms-tts-qxo",
473
+ "qu_Cañar_Highland-facebook-mms VITS": "facebook/mms-tts-qxr",
474
+ "sm-facebook-mms VITS": "facebook/mms-tts-smo",
475
+ "ti-facebook-mms VITS": "facebook/mms-tts-tir",
476
+ "ts-facebook-mms VITS": "facebook/mms-tts-tso",
477
+ "ak-facebook-mms VITS": "facebook/mms-tts-aka",
478
+ "ug_script_arabic-facebook-mms VITS": "facebook/mms-tts-uig-script_arabic",
479
+ "ug_script_cyrillic-facebook-mms VITS": "facebook/mms-tts-uig-script_cyrillic",
480
+ }
481
+
482
+ OPENAI_TTS_CODES = [
483
+ "af", "ar", "hy", "az", "be", "bs", "bg", "ca", "zh", "hr", "cs", "da",
484
+ "nl", "en", "et", "fi", "fr", "gl", "de", "el", "he", "hi", "hu", "is",
485
+ "id", "it", "ja", "kn", "kk", "ko", "lv", "lt", "mk", "ms", "mr", "mi",
486
+ "ne", "no", "fa", "pl", "pt", "ro", "ru", "sr", "sk", "sl", "es", "sw",
487
+ "sv", "tl", "ta", "th", "tr", "uk", "ur", "vi", "cy", "zh-TW"
488
+ ]
489
+
490
+ OPENAI_TTS_MODELS = [
491
+ ">alloy OpenAI-TTS",
492
+ ">echo OpenAI-TTS",
493
+ ">fable OpenAI-TTS",
494
+ ">onyx OpenAI-TTS",
495
+ ">nova OpenAI-TTS",
496
+ ">shimmer OpenAI-TTS",
497
+ ">alloy HD OpenAI-TTS",
498
+ ">echo HD OpenAI-TTS",
499
+ ">fable HD OpenAI-TTS",
500
+ ">onyx HD OpenAI-TTS",
501
+ ">nova HD OpenAI-TTS",
502
+ ">shimmer HD OpenAI-TTS"
503
+ ]
504
+
505
+ LANGUAGE_CODE_IN_THREE_LETTERS = {
506
+ "Automatic detection": "aut",
507
+ "ar": "ara",
508
+ "zh": "chi",
509
+ "cs": "cze",
510
+ "da": "dan",
511
+ "nl": "dut",
512
+ "en": "eng",
513
+ "fi": "fin",
514
+ "fr": "fre",
515
+ "de": "ger",
516
+ "el": "gre",
517
+ "he": "heb",
518
+ "hu": "hun",
519
+ "it": "ita",
520
+ "ja": "jpn",
521
+ "ko": "kor",
522
+ "fa": "per",
523
+ "pl": "pol",
524
+ "pt": "por",
525
+ "ru": "rus",
526
+ "es": "spa",
527
+ "tr": "tur",
528
+ "uk": "ukr",
529
+ "ur": "urd",
530
+ "vi": "vie",
531
+ "hi": "hin",
532
+ "id": "ind",
533
+ "bn": "ben",
534
+ "te": "tel",
535
+ "mr": "mar",
536
+ "ta": "tam",
537
+ "jw": "jav",
538
+ "ca": "cat",
539
+ "ne": "nep",
540
+ "th": "tha",
541
+ "sv": "swe",
542
+ "am": "amh",
543
+ "cy": "cym",
544
+ "et": "est",
545
+ "hr": "hrv",
546
+ "is": "isl",
547
+ "km": "khm",
548
+ "sk": "slk",
549
+ "sq": "sqi",
550
+ "sr": "srp",
551
+ }
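fix_code_language only remaps the handful of codes that differ between the internal codes and each backend ("google" for google-translator/gTTS, "coqui" for coqui-xtts); every other code passes through unchanged. A short illustration, with return values taken from the mapping tables above:

```python
from quantum_dubbing.language_configuration import fix_code_language

fix_code_language("zh")                     # -> "zh-CN"  (google-translator / gTTS)
fix_code_language("he")                     # -> "iw"     (legacy Google code for Hebrew)
fix_code_language("zh-TW", syntax="coqui")  # -> "zh-cn"  (coqui-xtts)
fix_code_language("es")                     # -> "es"     (unchanged)
```

Note that only "google" and "coqui" define a replacement table, so any other syntax value would raise UnboundLocalError when replace_lang_code is read.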
quantum_dubbing/languages_gui.py ADDED
The diff for this file is too large to render. See raw diff
 
quantum_dubbing/logging_setup.py ADDED
@@ -0,0 +1,68 @@
import logging
import sys
import warnings
import os


def configure_logging_libs(debug=False):
    warnings.filterwarnings(
        action="ignore", category=UserWarning, module="pyannote"
    )
    modules = [
        "numba", "httpx", "markdown_it", "speechbrain", "fairseq", "pyannote",
        "faiss",
        "pytorch_lightning.utilities.migration.utils",
        "pytorch_lightning.utilities.migration",
        "pytorch_lightning",
        "lightning",
        "lightning.pytorch.utilities.migration.utils",
    ]
    try:
        for module in modules:
            logging.getLogger(module).setLevel(logging.WARNING)
        os.environ['TF_CPP_MIN_LOG_LEVEL'] = "3" if not debug else "1"

        # fix verbose pyannote audio
        def fix_verbose_pyannote(*args, what=""):
            pass
        import pyannote.audio.core.model  # noqa
        pyannote.audio.core.model.check_version = fix_verbose_pyannote
    except Exception as error:
        logger.error(str(error))


def setup_logger(name_log):
    logger = logging.getLogger(name_log)
    logger.setLevel(logging.INFO)

    _default_handler = logging.StreamHandler()  # Set sys.stderr as stream.
    _default_handler.flush = sys.stderr.flush
    logger.addHandler(_default_handler)

    logger.propagate = False

    handlers = logger.handlers

    for handler in handlers:
        formatter = logging.Formatter("[%(levelname)s] >> %(message)s")
        handler.setFormatter(formatter)

    # logger.handlers

    return logger


logger = setup_logger("quantum_dubbing")
logger.setLevel(logging.INFO)


def set_logging_level(verbosity_level):
    logging_level_mapping = {
        "debug": logging.DEBUG,
        "info": logging.INFO,
        "warning": logging.WARNING,
        "error": logging.ERROR,
        "critical": logging.CRITICAL,
    }

    logger.setLevel(logging_level_mapping.get(verbosity_level, logging.INFO))
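The module builds a single shared logger at import time, so other modules import it directly instead of calling setup_logger again. A small sketch of how it is typically driven from a caller:

```python
from quantum_dubbing.logging_setup import (
    logger, set_logging_level, configure_logging_libs
)

configure_logging_libs(debug=False)  # quiet noisy third-party loggers
set_logging_level("debug")           # unknown names fall back to INFO

logger.debug("verbose details")      # prints: [DEBUG] >> verbose details
logger.info("pipeline step done")    # prints: [INFO] >> pipeline step done
```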
quantum_dubbing/mdx_net.py ADDED
@@ -0,0 +1,594 @@
import gc
import hashlib
import os
import queue
import threading
import json
import shlex
import sys
import subprocess
import librosa
import numpy as np
import soundfile as sf
import torch
from tqdm import tqdm

try:
    from .utils import (
        remove_directory_contents,
        create_directories,
    )
except:  # noqa
    from utils import (
        remove_directory_contents,
        create_directories,
    )
from .logging_setup import logger

try:
    import onnxruntime as ort
except Exception as error:
    logger.error(str(error))
# import warnings
# warnings.filterwarnings("ignore")

stem_naming = {
    "Vocals": "Instrumental",
    "Other": "Instruments",
    "Instrumental": "Vocals",
    "Drums": "Drumless",
    "Bass": "Bassless",
}


class MDXModel:
    def __init__(
        self,
        device,
        dim_f,
        dim_t,
        n_fft,
        hop=1024,
        stem_name=None,
        compensation=1.000,
    ):
        self.dim_f = dim_f
        self.dim_t = dim_t
        self.dim_c = 4
        self.n_fft = n_fft
        self.hop = hop
        self.stem_name = stem_name
        self.compensation = compensation

        self.n_bins = self.n_fft // 2 + 1
        self.chunk_size = hop * (self.dim_t - 1)
        self.window = torch.hann_window(
            window_length=self.n_fft, periodic=True
        ).to(device)

        out_c = self.dim_c

        self.freq_pad = torch.zeros(
            [1, out_c, self.n_bins - self.dim_f, self.dim_t]
        ).to(device)

    def stft(self, x):
        x = x.reshape([-1, self.chunk_size])
        x = torch.stft(
            x,
            n_fft=self.n_fft,
            hop_length=self.hop,
            window=self.window,
            center=True,
            return_complex=True,
        )
        x = torch.view_as_real(x)
        x = x.permute([0, 3, 1, 2])
        x = x.reshape([-1, 2, 2, self.n_bins, self.dim_t]).reshape(
            [-1, 4, self.n_bins, self.dim_t]
        )
        return x[:, :, : self.dim_f]

    def istft(self, x, freq_pad=None):
        freq_pad = (
            self.freq_pad.repeat([x.shape[0], 1, 1, 1])
            if freq_pad is None
            else freq_pad
        )
        x = torch.cat([x, freq_pad], -2)
        # c = 4*2 if self.target_name=='*' else 2
        x = x.reshape([-1, 2, 2, self.n_bins, self.dim_t]).reshape(
            [-1, 2, self.n_bins, self.dim_t]
        )
        x = x.permute([0, 2, 3, 1])
        x = x.contiguous()
        x = torch.view_as_complex(x)
        x = torch.istft(
            x,
            n_fft=self.n_fft,
            hop_length=self.hop,
            window=self.window,
            center=True,
        )
        return x.reshape([-1, 2, self.chunk_size])


class MDX:
    DEFAULT_SR = 44100
    # Unit: seconds
    DEFAULT_CHUNK_SIZE = 0 * DEFAULT_SR
    DEFAULT_MARGIN_SIZE = 1 * DEFAULT_SR

    def __init__(
        self, model_path: str, params: MDXModel, processor=0
    ):
        # Set the device and the provider (CPU or CUDA)
        self.device = (
            torch.device(f"cuda:{processor}")
            if processor >= 0
            else torch.device("cpu")
        )
        self.provider = (
            ["CUDAExecutionProvider"]
            if processor >= 0
            else ["CPUExecutionProvider"]
        )

        self.model = params

        # Load the ONNX model using ONNX Runtime
        self.ort = ort.InferenceSession(model_path, providers=self.provider)
        # Preload the model for faster performance
        self.ort.run(
            None,
            {"input": torch.rand(1, 4, params.dim_f, params.dim_t).numpy()},
        )
        self.process = lambda spec: self.ort.run(
            None, {"input": spec.cpu().numpy()}
        )[0]

        self.prog = None

    @staticmethod
    def get_hash(model_path):
        try:
            with open(model_path, "rb") as f:
                f.seek(-10000 * 1024, 2)
                model_hash = hashlib.md5(f.read()).hexdigest()
        except:  # noqa
            model_hash = hashlib.md5(open(model_path, "rb").read()).hexdigest()

        return model_hash

    @staticmethod
    def segment(
        wave,
        combine=True,
        chunk_size=DEFAULT_CHUNK_SIZE,
        margin_size=DEFAULT_MARGIN_SIZE,
    ):
        """
        Segment or join segmented wave array

        Args:
            wave: (np.array) Wave array to be segmented or joined
            combine: (bool) If True, combines segmented wave array.
                If False, segments wave array.
            chunk_size: (int) Size of each segment (in samples)
            margin_size: (int) Size of margin between segments (in samples)

        Returns:
            numpy array: Segmented or joined wave array
        """

        if combine:
            # Initializing as None instead of [] for later numpy array concatenation
            processed_wave = None
            for segment_count, segment in enumerate(wave):
                start = 0 if segment_count == 0 else margin_size
                end = None if segment_count == len(wave) - 1 else -margin_size
                if margin_size == 0:
                    end = None
                if processed_wave is None:  # Create array for first segment
                    processed_wave = segment[:, start:end]
                else:  # Concatenate to existing array for subsequent segments
                    processed_wave = np.concatenate(
                        (processed_wave, segment[:, start:end]), axis=-1
                    )

        else:
            processed_wave = []
            sample_count = wave.shape[-1]

            if chunk_size <= 0 or chunk_size > sample_count:
                chunk_size = sample_count

            if margin_size > chunk_size:
                margin_size = chunk_size

            for segment_count, skip in enumerate(
                range(0, sample_count, chunk_size)
            ):
                margin = 0 if segment_count == 0 else margin_size
                end = min(skip + chunk_size + margin_size, sample_count)
                start = skip - margin

                cut = wave[:, start:end].copy()
                processed_wave.append(cut)

                if end == sample_count:
                    break

        return processed_wave

    def pad_wave(self, wave):
        """
        Pad the wave array to match the required chunk size

        Args:
            wave: (np.array) Wave array to be padded

        Returns:
            tuple: (padded_wave, pad, trim)
                - padded_wave: Padded wave array
                - pad: Number of samples that were padded
                - trim: Number of samples that were trimmed
        """
        n_sample = wave.shape[1]
        trim = self.model.n_fft // 2
        gen_size = self.model.chunk_size - 2 * trim
        pad = gen_size - n_sample % gen_size

        # Padded wave
        wave_p = np.concatenate(
            (
                np.zeros((2, trim)),
                wave,
                np.zeros((2, pad)),
                np.zeros((2, trim)),
            ),
            1,
        )

        mix_waves = []
        for i in range(0, n_sample + pad, gen_size):
            waves = np.array(wave_p[:, i:i + self.model.chunk_size])
            mix_waves.append(waves)

        mix_waves = torch.tensor(mix_waves, dtype=torch.float32).to(
            self.device
        )

        return mix_waves, pad, trim

    def _process_wave(self, mix_waves, trim, pad, q: queue.Queue, _id: int):
        """
        Process each wave segment in a multi-threaded environment

        Args:
            mix_waves: (torch.Tensor) Wave segments to be processed
            trim: (int) Number of samples trimmed during padding
            pad: (int) Number of samples padded during padding
            q: (queue.Queue) Queue to hold the processed wave segments
            _id: (int) Identifier of the processed wave segment

        Returns:
            numpy array: Processed wave segment
        """
        mix_waves = mix_waves.split(1)
        with torch.no_grad():
            pw = []
            for mix_wave in mix_waves:
                self.prog.update()
                spec = self.model.stft(mix_wave)
                processed_spec = torch.tensor(self.process(spec))
                processed_wav = self.model.istft(
                    processed_spec.to(self.device)
                )
                processed_wav = (
                    processed_wav[:, :, trim:-trim]
                    .transpose(0, 1)
                    .reshape(2, -1)
                    .cpu()
                    .numpy()
                )
                pw.append(processed_wav)
        processed_signal = np.concatenate(pw, axis=-1)[:, :-pad]
        q.put({_id: processed_signal})
        return processed_signal

    def process_wave(self, wave: np.array, mt_threads=1):
        """
        Process the wave array in a multi-threaded environment

        Args:
            wave: (np.array) Wave array to be processed
            mt_threads: (int) Number of threads to be used for processing

        Returns:
            numpy array: Processed wave array
        """
        self.prog = tqdm(total=0)
        chunk = wave.shape[-1] // mt_threads
        waves = self.segment(wave, False, chunk)

        # Create a queue to hold the processed wave segments
        q = queue.Queue()
        threads = []
        for c, batch in enumerate(waves):
            mix_waves, pad, trim = self.pad_wave(batch)
            self.prog.total = len(mix_waves) * mt_threads
            thread = threading.Thread(
                target=self._process_wave, args=(mix_waves, trim, pad, q, c)
            )
            thread.start()
            threads.append(thread)
        for thread in threads:
            thread.join()
        self.prog.close()

        processed_batches = []
        while not q.empty():
            processed_batches.append(q.get())
        processed_batches = [
            list(wave.values())[0]
            for wave in sorted(
                processed_batches, key=lambda d: list(d.keys())[0]
            )
        ]
        assert len(processed_batches) == len(
            waves
        ), "Incomplete processed batches, please reduce batch size!"
        return self.segment(processed_batches, True, chunk)


def run_mdx(
    model_params,
    output_dir,
    model_path,
    filename,
    exclude_main=False,
    exclude_inversion=False,
    suffix=None,
    invert_suffix=None,
    denoise=False,
    keep_orig=True,
    m_threads=2,
    device_base="cuda",
):
    if device_base == "cuda":
        device = torch.device("cuda:0")
        processor_num = 0
        device_properties = torch.cuda.get_device_properties(device)
        vram_gb = device_properties.total_memory / 1024**3
        m_threads = 1 if vram_gb < 8 else 2
    else:
        device = torch.device("cpu")
        processor_num = -1
        m_threads = 1

    if os.environ.get("ZERO_GPU") == "TRUE":
        duration = librosa.get_duration(filename=filename)

        if duration < 60:
            pass
        elif duration >= 60 and duration <= 900:
            m_threads = 4
        elif duration > 900:
            m_threads = 16

        logger.info(f"MDX-NET Threads: {m_threads}, duration {duration}")

    model_hash = MDX.get_hash(model_path)
    mp = model_params.get(model_hash)
    model = MDXModel(
        device,
        dim_f=mp["mdx_dim_f_set"],
        dim_t=2 ** mp["mdx_dim_t_set"],
        n_fft=mp["mdx_n_fft_scale_set"],
        stem_name=mp["primary_stem"],
        compensation=mp["compensate"],
    )

    mdx_sess = MDX(model_path, model, processor=processor_num)
    wave, sr = librosa.load(filename, mono=False, sr=44100)
    # normalizing input wave gives better output
    peak = max(np.max(wave), abs(np.min(wave)))
    wave /= peak
    if denoise:
        wave_processed = -(mdx_sess.process_wave(-wave, m_threads)) + (
            mdx_sess.process_wave(wave, m_threads)
        )
        wave_processed *= 0.5
    else:
        wave_processed = mdx_sess.process_wave(wave, m_threads)
    # return to previous peak
    wave_processed *= peak
    stem_name = model.stem_name if suffix is None else suffix

    main_filepath = None
    if not exclude_main:
        main_filepath = os.path.join(
            output_dir,
            f"{os.path.basename(os.path.splitext(filename)[0])}_{stem_name}.wav",
        )
        sf.write(main_filepath, wave_processed.T, sr)

    invert_filepath = None
    if not exclude_inversion:
        diff_stem_name = (
            stem_naming.get(stem_name)
            if invert_suffix is None
            else invert_suffix
        )
        stem_name = (
            f"{stem_name}_diff" if diff_stem_name is None else diff_stem_name
        )
        invert_filepath = os.path.join(
            output_dir,
            f"{os.path.basename(os.path.splitext(filename)[0])}_{stem_name}.wav",
        )
        sf.write(
            invert_filepath,
            (-wave_processed.T * model.compensation) + wave.T,
            sr,
        )

    if not keep_orig:
        os.remove(filename)

    del mdx_sess, wave_processed, wave
    gc.collect()
    torch.cuda.empty_cache()
    return main_filepath, invert_filepath


MDX_DOWNLOAD_LINK = "https://github.com/TRvlvr/model_repo/releases/download/all_public_uvr_models/"
UVR_MODELS = [
    "UVR-MDX-NET-Voc_FT.onnx",
    "UVR_MDXNET_KARA_2.onnx",
    "Reverb_HQ_By_FoxJoy.onnx",
    "UVR-MDX-NET-Inst_HQ_4.onnx",
]
BASE_DIR = "."  # os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
mdxnet_models_dir = os.path.join(BASE_DIR, "mdx_models")
output_dir = os.path.join(BASE_DIR, "clean_song_output")


def convert_to_stereo_and_wav(audio_path):
    wave, sr = librosa.load(audio_path, mono=False, sr=44100)

    # check if mono
    if type(wave[0]) != np.ndarray or audio_path[-4:].lower() != ".wav":  # noqa
        stereo_path = f"{os.path.splitext(audio_path)[0]}_stereo.wav"
        stereo_path = os.path.join(output_dir, stereo_path)

        command = shlex.split(
            f'ffmpeg -y -loglevel error -i "{audio_path}" -ac 2 -f wav "{stereo_path}"'
        )
        sub_params = {
            "stdout": subprocess.PIPE,
            "stderr": subprocess.PIPE,
            "creationflags": subprocess.CREATE_NO_WINDOW
            if sys.platform == "win32"
            else 0,
        }
        process_wav = subprocess.Popen(command, **sub_params)
        output, errors = process_wav.communicate()
        if process_wav.returncode != 0 or not os.path.exists(stereo_path):
            raise Exception("Error processing audio to stereo wav")

        return stereo_path
    else:
        return audio_path


def process_uvr_task(
    orig_song_path: str = "aud_test.mp3",
    main_vocals: bool = False,
    dereverb: bool = True,
    song_id: str = "mdx",  # folder output name
    only_voiceless: bool = False,
    remove_files_output_dir: bool = False,
):
    if os.environ.get("QUANTUM_DEVICE") == "cpu":
        device_base = "cpu"
    else:
        device_base = "cuda" if torch.cuda.is_available() else "cpu"

    if remove_files_output_dir:
        remove_directory_contents(output_dir)

    with open(os.path.join(mdxnet_models_dir, "data.json")) as infile:
        mdx_model_params = json.load(infile)

    song_output_dir = os.path.join(output_dir, song_id)
    create_directories(song_output_dir)
    orig_song_path = convert_to_stereo_and_wav(orig_song_path)

    logger.debug(f"onnxruntime device >> {ort.get_device()}")

    if only_voiceless:
        logger.info("Voiceless Track Separation...")
        return run_mdx(
            mdx_model_params,
            song_output_dir,
            os.path.join(mdxnet_models_dir, "UVR-MDX-NET-Inst_HQ_4.onnx"),
            orig_song_path,
            suffix="Voiceless",
            denoise=False,
            keep_orig=True,
            exclude_inversion=True,
            device_base=device_base,
        )

    logger.info("Vocal Track Isolation and Voiceless Track Separation...")
    vocals_path, instrumentals_path = run_mdx(
        mdx_model_params,
        song_output_dir,
        os.path.join(mdxnet_models_dir, "UVR-MDX-NET-Voc_FT.onnx"),
        orig_song_path,
        denoise=True,
        keep_orig=True,
        device_base=device_base,
    )

    if main_vocals:
        logger.info("Main Voice Separation from Supporting Vocals...")
        backup_vocals_path, main_vocals_path = run_mdx(
            mdx_model_params,
            song_output_dir,
            os.path.join(mdxnet_models_dir, "UVR_MDXNET_KARA_2.onnx"),
            vocals_path,
            suffix="Backup",
            invert_suffix="Main",
            denoise=True,
            device_base=device_base,
        )
    else:
        backup_vocals_path, main_vocals_path = None, vocals_path

    if dereverb:
        logger.info("Vocal Clarity Enhancement through De-Reverberation...")
        _, vocals_dereverb_path = run_mdx(
            mdx_model_params,
            song_output_dir,
            os.path.join(mdxnet_models_dir, "Reverb_HQ_By_FoxJoy.onnx"),
            main_vocals_path,
            invert_suffix="DeReverb",
            exclude_main=True,
            denoise=True,
            device_base=device_base,
        )
    else:
        vocals_dereverb_path = main_vocals_path

    return (
        vocals_path,
        instrumentals_path,
        backup_vocals_path,
        main_vocals_path,
        vocals_dereverb_path,
    )


if __name__ == "__main__":
    from utils import download_manager

    for id_model in UVR_MODELS:
        download_manager(
            os.path.join(MDX_DOWNLOAD_LINK, id_model), mdxnet_models_dir
        )
    (
        vocals_path_,
        instrumentals_path_,
        backup_vocals_path_,
        main_vocals_path_,
        vocals_dereverb_path_,
    ) = process_uvr_task(
        orig_song_path="aud.mp3",
        main_vocals=True,
        dereverb=True,
        song_id="mdx",
        remove_files_output_dir=True,
    )
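process_wave splits the stereo array into one chunk per thread via MDX.segment, and the margin samples added around each chunk are trimmed again when the processed chunks are joined. A rough sketch of that split/join round trip on dummy data (the chunk and margin sizes are chosen only for illustration, and the MDX class is the one defined above):

```python
import numpy as np
from quantum_dubbing.mdx_net import MDX  # assumed package import; needs onnxruntime installed

sr = 44100
wave = np.random.randn(2, 10 * sr)  # 10 s of stereo noise as a stand-in signal

# Split into ~5 s chunks with a 1 s margin around each (combine=False)
chunks = MDX.segment(wave, combine=False, chunk_size=5 * sr, margin_size=sr)

# ... in the real pipeline each chunk goes through pad_wave / stft / ONNX / istft ...

# Joining with the same margin trims the overlap again and restores the length
rejoined = MDX.segment(chunks, combine=True, margin_size=sr)
assert rejoined.shape == wave.shape
```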
quantum_dubbing/postprocessor.py ADDED
@@ -0,0 +1,231 @@
1
+ from .utils import remove_files, run_command
2
+ from .text_multiformat_processor import get_subtitle
3
+ from .logging_setup import logger
4
+ import unicodedata
5
+ import shutil
6
+ import copy
7
+ import os
8
+ import re
9
+
10
+ OUTPUT_TYPE_OPTIONS = [
11
+ "video (mp4)",
12
+ "video (mkv)",
13
+ "audio (mp3)",
14
+ "audio (ogg)",
15
+ "audio (wav)",
16
+ "subtitle",
17
+ "subtitle [by speaker]",
18
+ "video [subtitled] (mp4)",
19
+ "video [subtitled] (mkv)",
20
+ "audio [original vocal sound]",
21
+ "audio [original background sound]",
22
+ "audio [original vocal and background sound]",
23
+ "audio [original vocal-dereverb sound]",
24
+ "audio [original vocal-dereverb and background sound]",
25
+ "raw media",
26
+ ]
27
+
28
+ DOCS_OUTPUT_TYPE_OPTIONS = [
29
+ "videobook (mp4)",
30
+ "videobook (mkv)",
31
+ "audiobook (wav)",
32
+ "audiobook (mp3)",
33
+ "audiobook (ogg)",
34
+ "book (txt)",
35
+ ] # Add DOCX and etc.
36
+
37
+
38
+ def get_no_ext_filename(file_path):
39
+ file_name_with_extension = os.path.basename(rf"{file_path}")
40
+ filename_without_extension, _ = os.path.splitext(file_name_with_extension)
41
+ return filename_without_extension
42
+
43
+
44
+ def get_video_info(link):
45
+ aux_name = f"video_url_{link}"
46
+ params_dlp = {"quiet": True, "no_warnings": True, "noplaylist": True}
47
+ try:
48
+ from yt_dlp import YoutubeDL
49
+
50
+ with YoutubeDL(params_dlp) as ydl:
51
+ if link.startswith(("www.youtube.com/", "m.youtube.com/")):
52
+ link = "https://" + link
53
+ info_dict = ydl.extract_info(link, download=False, process=False)
54
+ video_id = info_dict.get("id", aux_name)
55
+ video_title = info_dict.get("title", video_id)
56
+ if "youtube.com" in link and "&list=" in link:
57
+ video_title = ydl.extract_info(
58
+ "https://m.youtube.com/watch?v="+video_id,
59
+ download=False,
60
+ process=False
61
+ ).get("title", video_title)
62
+ except Exception as error:
63
+ logger.error(str(error))
64
+ video_title, video_id = aux_name, "NO_ID"
65
+ return video_title, video_id
66
+
67
+
68
+ def sanitize_file_name(file_name):
69
+ # Normalize the string to NFKD form to separate combined
70
+ # characters into base characters and diacritics
71
+ normalized_name = unicodedata.normalize("NFKD", file_name)
72
+ # Replace any non-ASCII characters or special symbols with an underscore
73
+ sanitized_name = re.sub(r"[^\w\s.-]", "_", normalized_name)
74
+ return sanitized_name
75
+
76
+
77
+ def get_output_file(
78
+ original_file,
79
+ new_file_name,
80
+ soft_subtitles,
81
+ output_directory="",
82
+ ):
83
+ directory_base = "." # default directory
84
+
85
+ if output_directory and os.path.isdir(output_directory):
86
+ new_file_path = os.path.join(output_directory, new_file_name)
87
+ else:
88
+ new_file_path = os.path.join(directory_base, "outputs", new_file_name)
89
+ remove_files(new_file_path)
90
+
91
+ cm = None
92
+ if soft_subtitles and original_file.endswith(".mp4"):
93
+ if new_file_path.endswith(".mp4"):
94
+ cm = f'ffmpeg -y -i "{original_file}" -i sub_tra.srt -i sub_ori.srt -map 0:v -map 0:a -map 1 -map 2 -c:v copy -c:a copy -c:s mov_text "{new_file_path}"'
95
+ else:
96
+ cm = f'ffmpeg -y -i "{original_file}" -i sub_tra.srt -i sub_ori.srt -map 0:v -map 0:a -map 1 -map 2 -c:v copy -c:a copy -c:s srt -movflags use_metadata_tags -map_metadata 0 "{new_file_path}"'
97
+ elif new_file_path.endswith(".mkv"):
98
+ cm = f'ffmpeg -i "{original_file}" -c:v copy -c:a copy "{new_file_path}"'
99
+ elif new_file_path.endswith(".wav") and not original_file.endswith(".wav"):
100
+ cm = f'ffmpeg -y -i "{original_file}" -acodec pcm_s16le -ar 44100 -ac 2 "{new_file_path}"'
101
+ elif new_file_path.endswith(".ogg"):
102
+ cm = f'ffmpeg -i "{original_file}" -c:a libvorbis "{new_file_path}"'
103
+ elif new_file_path.endswith(".mp3") and not original_file.endswith(".mp3"):
104
+ cm = f'ffmpeg -y -i "{original_file}" -codec:a libmp3lame -qscale:a 2 "{new_file_path}"'
105
+
106
+ if cm:
107
+ try:
108
+ run_command(cm)
109
+ except Exception as error:
110
+ logger.error(str(error))
111
+ remove_files(new_file_path)
112
+ shutil.copy2(original_file, new_file_path)
113
+ else:
114
+ shutil.copy2(original_file, new_file_path)
115
+
116
+ return os.path.abspath(new_file_path)
117
+
118
+
119
+ def media_out(
120
+ media_file,
121
+ lang_code,
122
+ media_out_name="",
123
+ extension="mp4",
124
+ file_obj="video_dub.mp4",
125
+ soft_subtitles=False,
126
+ subtitle_files="disable",
127
+ ):
128
+ if media_out_name:
129
+ base_name = media_out_name + "_origin"
130
+ else:
131
+ if os.path.exists(media_file):
132
+ base_name = get_no_ext_filename(media_file)
133
+ else:
134
+ base_name, _ = get_video_info(media_file)
135
+
136
+ media_out_name = f"{base_name}__{lang_code}"
137
+
138
+ f_name = f"{sanitize_file_name(media_out_name)}.{extension}"
139
+
140
+ if subtitle_files != "disable":
141
+ final_media = [get_output_file(file_obj, f_name, soft_subtitles)]
142
+ name_tra = f"{sanitize_file_name(media_out_name)}.{subtitle_files}"
143
+ name_ori = f"{sanitize_file_name(base_name)}.{subtitle_files}"
144
+ tgt_subs = f"sub_tra.{subtitle_files}"
145
+ ori_subs = f"sub_ori.{subtitle_files}"
146
+ final_subtitles = [
147
+ get_output_file(tgt_subs, name_tra, False),
148
+ get_output_file(ori_subs, name_ori, False)
149
+ ]
150
+ return final_media + final_subtitles
151
+ else:
152
+ return get_output_file(file_obj, f_name, soft_subtitles)
153
+
154
+
155
+ def get_subtitle_speaker(media_file, result, language, extension, base_name):
156
+
157
+ segments_base = copy.deepcopy(result)
158
+
159
+ # Sub segments by speaker
160
+ segments_by_speaker = {}
161
+ for segment in segments_base["segments"]:
162
+ if segment["speaker"] not in segments_by_speaker.keys():
163
+ segments_by_speaker[segment["speaker"]] = [segment]
164
+ else:
165
+ segments_by_speaker[segment["speaker"]].append(segment)
166
+
167
+ if not base_name:
168
+ if os.path.exists(media_file):
169
+ base_name = get_no_ext_filename(media_file)
170
+ else:
171
+ base_name, _ = get_video_info(media_file)
172
+
173
+ files_subs = []
174
+ for name_sk, segments in segments_by_speaker.items():
175
+
176
+ subtitle_speaker = get_subtitle(
177
+ language,
178
+ {"segments": segments},
179
+ extension,
180
+ filename=name_sk,
181
+ )
182
+
183
+ media_out_name = f"{base_name}_{language}_{name_sk}"
184
+
185
+ output = media_out(
186
+ media_file, # not used here because media_out_name is provided
187
+ language,
188
+ media_out_name,
189
+ extension,
190
+ file_obj=subtitle_speaker,
191
+ )
192
+
193
+ files_subs.append(output)
194
+
195
+ return files_subs
196
+
197
+
198
+ def sound_separate(media_file, task_uvr):
199
+ from .mdx_net import process_uvr_task
200
+
201
+ outputs = []
202
+
203
+ if "vocal" in task_uvr:
204
+ try:
205
+ _, _, _, _, vocal_audio = process_uvr_task(
206
+ orig_song_path=media_file,
207
+ main_vocals=False,
208
+ dereverb=True if "dereverb" in task_uvr else False,
209
+ remove_files_output_dir=True,
210
+ )
211
+ outputs.append(vocal_audio)
212
+ except Exception as error:
213
+ logger.error(str(error))
214
+
215
+ if "background" in task_uvr:
216
+ try:
217
+ background_audio, _ = process_uvr_task(
218
+ orig_song_path=media_file,
219
+ song_id="voiceless",
220
+ only_voiceless=True,
221
+ remove_files_output_dir=False if "vocal" in task_uvr else True,
222
+ )
223
+ # copy_files(background_audio, ".")
224
+ outputs.append(background_audio)
225
+ except Exception as error:
226
+ logger.error(str(error))
227
+
228
+ if not outputs:
229
+ raise Exception("Error in UVR process: no output was generated")
230
+
231
+ return outputs
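+
+ # Illustrative sketch: a hypothetical call that splits a track into vocals and
+ # background with the UVR/MDX-Net task string; "song.mp3" is a placeholder path.
+ def _demo_sound_separate():
+     # Returns [dereverbed vocal track, voiceless background track]
+     return sound_separate("song.mp3", task_uvr="vocal dereverb background")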
quantum_dubbing/preprocessor.py ADDED
@@ -0,0 +1,309 @@
1
+ from .utils import remove_files
2
+ import os, shutil, subprocess, time, shlex, sys # noqa
3
+ from .logging_setup import logger
4
+ import json
5
+
6
+ ERROR_INCORRECT_CODEC_PARAMETERS = [
7
+ "prores", # mov
8
+ "ffv1", # mkv
9
+ "msmpeg4v3", # avi
10
+ "wmv2", # wmv
11
+ "theora", # ogv
12
+ ] # fix final merge
13
+
14
+ TESTED_CODECS = [
15
+ "h264", # mp4
16
+ "h265", # mp4
17
+ "hevc", # test
18
+ "vp9", # webm
19
+ "mpeg4", # mp4
20
+ "mpeg2video", # mpg
21
+ "mjpeg", # avi
22
+ ]
23
+
24
+
25
+ class OperationFailedError(Exception):
26
+ def __init__(self, message="The operation did not complete successfully."):
27
+ self.message = message
28
+ super().__init__(self.message)
29
+
30
+
31
+ def get_video_codec(video_file):
32
+ command_base = rf'ffprobe -v error -select_streams v:0 -show_entries stream=codec_name -of json "{video_file}"'
33
+ command = shlex.split(command_base)
34
+ try:
35
+ process = subprocess.Popen(
36
+ command,
37
+ stdout=subprocess.PIPE,
38
+ creationflags=subprocess.CREATE_NO_WINDOW if sys.platform == "win32" else 0,
39
+ )
40
+ output, _ = process.communicate()
41
+ codec_info = json.loads(output.decode('utf-8'))
42
+ codec_name = codec_info['streams'][0]['codec_name']
43
+ return codec_name
44
+ except Exception as error:
45
+ logger.debug(str(error))
46
+ return None
47
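+ # Illustrative sketch: get_video_codec wraps an ffprobe JSON query, and None
+ # means the codec could not be read. The path below is a placeholder assumption.
+ def _demo_get_video_codec(path="input_video.mkv"):
+     codec = get_video_codec(path)
+     # e.g. "h264" or "vp9"; later used to decide between a stream copy
+     # and a full re-encode to H.264/AAC.
+     return codec in TESTED_CODECS if codec else False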
+
48
+
49
+ def audio_preprocessor(preview, base_audio, audio_wav, use_cuda=False):
50
+ base_audio = base_audio.strip()
51
+ previous_files_to_remove = [audio_wav]
52
+ remove_files(previous_files_to_remove)
53
+
54
+ if preview:
55
+ logger.warning(
56
+ "Creating a 10-second preview video. To disable "
57
+ "this option, go to advanced settings and turn off preview."
58
+ )
59
+ wav_ = f'ffmpeg -y -i "{base_audio}" -ss 00:00:20 -t 00:00:10 -vn -acodec pcm_s16le -ar 44100 -ac 2 audio.wav'
60
+ else:
61
+ wav_ = f'ffmpeg -y -i "{base_audio}" -vn -acodec pcm_s16le -ar 44100 -ac 2 audio.wav'
62
+
63
+ # Run cmd process
64
+ sub_params = {
65
+ "stdout": subprocess.PIPE,
66
+ "stderr": subprocess.PIPE,
67
+ "creationflags": subprocess.CREATE_NO_WINDOW
68
+ if sys.platform == "win32"
69
+ else 0,
70
+ }
71
+ wav_ = shlex.split(wav_)
72
+ result_convert_audio = subprocess.Popen(wav_, **sub_params)
73
+ output, errors = result_convert_audio.communicate()
74
+ time.sleep(1)
75
+ if result_convert_audio.returncode in [1, 2] or not os.path.exists(
76
+ audio_wav
77
+ ):
78
+ raise OperationFailedError(f"Error can't create the audio file:\n{errors.decode('utf-8')}")
79
+
80
+
81
+ def audio_video_preprocessor(
82
+ preview, video, OutputFile, audio_wav, use_cuda=False
83
+ ):
84
+ video = video.strip()
85
+ previous_files_to_remove = [OutputFile, "audio.webm", audio_wav]
86
+ remove_files(previous_files_to_remove)
87
+
88
+ if os.path.exists(video):
89
+ if preview:
90
+ logger.warning(
91
+ "Creating a 10-second preview video. "
92
+ "To disable this option, go to advanced "
93
+ "settings and turn off preview."
94
+ )
95
+ mp4_ = f'ffmpeg -y -i "{video}" -ss 00:00:20 -t 00:00:10 -c:v libx264 -c:a aac -strict experimental Video.mp4'
96
+ else:
97
+ video_codec = get_video_codec(video)
98
+ if not video_codec:
99
+ logger.debug("No video codec found in video")
100
+ else:
101
+ logger.info(f"Video codec: {video_codec}")
102
+
103
+ # Check if the file has the ".mp4" extension or a supported codec
104
+ if video.endswith(".mp4") or video_codec in TESTED_CODECS:
105
+ destination_path = os.path.join(os.getcwd(), "Video.mp4")
106
+ shutil.copy(video, destination_path)
107
+ time.sleep(0.5)
108
+ if os.path.exists(OutputFile):
109
+ mp4_ = "ffmpeg -h"
110
+ else:
111
+ mp4_ = f'ffmpeg -y -i "{video}" -c copy Video.mp4'
112
+ else:
113
+ logger.warning(
114
+ "File does not have the '.mp4' extension or a "
115
+ "supported codec. Converting video to mp4 (codec: h264)."
116
+ )
117
+ mp4_ = f'ffmpeg -y -i "{video}" -c:v libx264 -c:a aac -strict experimental Video.mp4'
118
+ else:
119
+ if preview:
120
+ logger.warning(
121
+ "Creating a 10-second preview from the link. "
122
+ "To disable this option, go to advanced "
123
+ "settings and turn off preview."
124
+ )
125
+ # https://github.com/yt-dlp/yt-dlp/issues/2220
126
+ mp4_ = f'yt-dlp -f "mp4" --downloader ffmpeg --downloader-args "ffmpeg_i: -ss 00:00:20 -t 00:00:10" --force-overwrites --max-downloads 1 --no-warnings --no-playlist --no-abort-on-error --ignore-no-formats-error --restrict-filenames -o {OutputFile} {video}'
127
+ wav_ = "ffmpeg -y -i Video.mp4 -vn -acodec pcm_s16le -ar 44100 -ac 2 audio.wav"
128
+ else:
129
+ mp4_ = f'yt-dlp -f "mp4" --force-overwrites --max-downloads 1 --no-warnings --no-playlist --no-abort-on-error --ignore-no-formats-error --restrict-filenames -o {OutputFile} {video}'
130
+ wav_ = f"python -m yt_dlp --output {audio_wav} --force-overwrites --max-downloads 1 --no-warnings --no-playlist --no-abort-on-error --ignore-no-formats-error --extract-audio --audio-format wav {video}"
131
+
132
+ # Run cmd process
133
+ mp4_ = shlex.split(mp4_)
134
+ sub_params = {
135
+ "stdout": subprocess.PIPE,
136
+ "stderr": subprocess.PIPE,
137
+ "creationflags": subprocess.CREATE_NO_WINDOW
138
+ if sys.platform == "win32"
139
+ else 0,
140
+ }
141
+
142
+ if os.path.exists(video):
143
+ logger.info("Process video...")
144
+ result_convert_video = subprocess.Popen(mp4_, **sub_params)
145
+ # result_convert_video.wait()
146
+ output, errors = result_convert_video.communicate()
147
+ time.sleep(1)
148
+ if result_convert_video.returncode in [1, 2] or not os.path.exists(
149
+ OutputFile
150
+ ):
151
+ raise OperationFailedError(f"Error processing video:\n{errors.decode('utf-8')}")
152
+ logger.info("Process audio...")
153
+ wav_ = "ffmpeg -y -i Video.mp4 -vn -acodec pcm_s16le -ar 44100 -ac 2 audio.wav"
154
+ wav_ = shlex.split(wav_)
155
+ result_convert_audio = subprocess.Popen(wav_, **sub_params)
156
+ output, errors = result_convert_audio.communicate()
157
+ time.sleep(1)
158
+ if result_convert_audio.returncode in [1, 2] or not os.path.exists(
159
+ audio_wav
160
+ ):
161
+ raise OperationFailedError(f"Error can't create the audio file:\n{errors.decode('utf-8')}")
162
+
163
+ else:
164
+ wav_ = shlex.split(wav_)
165
+ if preview:
166
+ result_convert_video = subprocess.Popen(mp4_, **sub_params)
167
+ output, errors = result_convert_video.communicate()
168
+ time.sleep(0.5)
169
+ result_convert_audio = subprocess.Popen(wav_, **sub_params)
170
+ output, errors = result_convert_audio.communicate()
171
+ time.sleep(0.5)
172
+ if result_convert_audio.returncode in [1, 2] or not os.path.exists(
173
+ audio_wav
174
+ ):
175
+ raise OperationFailedError(
176
+ f"Error can't create the preview file:\n{errors.decode('utf-8')}"
177
+ )
178
+ else:
179
+ logger.info("Process audio...")
180
+ result_convert_audio = subprocess.Popen(wav_, **sub_params)
181
+ output, errors = result_convert_audio.communicate()
182
+ time.sleep(1)
183
+ if result_convert_audio.returncode in [1, 2] or not os.path.exists(
184
+ audio_wav
185
+ ):
186
+ raise OperationFailedError(f"Error can't download the audio:\n{errors.decode('utf-8')}")
187
+ logger.info("Process video...")
188
+ result_convert_video = subprocess.Popen(mp4_, **sub_params)
189
+ output, errors = result_convert_video.communicate()
190
+ time.sleep(1)
191
+ if result_convert_video.returncode in [1, 2] or not os.path.exists(
192
+ OutputFile
193
+ ):
194
+ raise OperationFailedError(f"Error can't download the video:\n{errors.decode('utf-8')}")
195
+
196
+
197
+ def old_audio_video_preprocessor(preview, video, OutputFile, audio_wav):
198
+ previous_files_to_remove = [OutputFile, "audio.webm", audio_wav]
199
+ remove_files(previous_files_to_remove)
200
+
201
+ if os.path.exists(video):
202
+ if preview:
203
+ logger.warning(
204
+ "Creating a 10-second preview video. "
205
+ "To disable this option, go to advanced "
206
+ "settings and turn off preview."
207
+ )
208
+ command = f'ffmpeg -y -i "{video}" -ss 00:00:20 -t 00:00:10 -c:v libx264 -c:a aac -strict experimental Video.mp4'
209
+ result_convert_video = subprocess.run(
210
+ command, capture_output=True, text=True, shell=True
211
+ )
212
+ else:
213
+ # Check if the file ends with ".mp4" extension
214
+ if video.endswith(".mp4"):
215
+ destination_path = os.path.join(os.getcwd(), "Video.mp4")
216
+ shutil.copy(video, destination_path)
217
+ result_convert_video = {}
218
+ result_convert_video = subprocess.run(
219
+ "echo Video copied",
220
+ capture_output=True,
221
+ text=True,
222
+ shell=True,
223
+ )
224
+ else:
225
+ logger.warning(
226
+ "File does not have the '.mp4' extension. Converting video."
227
+ )
228
+ command = f'ffmpeg -y -i "{video}" -c:v libx264 -c:a aac -strict experimental Video.mp4'
229
+ result_convert_video = subprocess.run(
230
+ command, capture_output=True, text=True, shell=True
231
+ )
232
+
233
+ if result_convert_video.returncode in [1, 2]:
234
+ raise OperationFailedError("Error can't convert the video")
235
+
236
+ for i in range(120):
237
+ time.sleep(1)
238
+ logger.info("Process video...")
239
+ if os.path.exists(OutputFile):
240
+ time.sleep(1)
241
+ command = "ffmpeg -y -i Video.mp4 -vn -acodec pcm_s16le -ar 44100 -ac 2 audio.wav"
242
+ result_convert_audio = subprocess.run(
243
+ command, capture_output=True, text=True, shell=True
244
+ )
245
+ time.sleep(1)
246
+ break
247
+ if i == 119:
248
+ # if not os.path.exists(OutputFile):
249
+ raise OperationFailedError("Error processing video")
250
+
251
+ if result_convert_audio.returncode in [1, 2]:
252
+ raise OperationFailedError(
253
+ f"Error can't create the audio file: {result_convert_audio.stderr}"
254
+ )
255
+
256
+ for i in range(120):
257
+ time.sleep(1)
258
+ logger.info("Process audio...")
259
+ if os.path.exists(audio_wav):
260
+ break
261
+ if i == 119:
262
+ raise OperationFailedError("Error can't create the audio file")
263
+
264
+ else:
265
+ video = video.strip()
266
+ if preview:
267
+ logger.warning(
268
+ "Creating a 10-second preview from the link. "
269
+ "To disable this option, go to "
270
+ "advanced settings and turn off preview."
271
+ )
272
+ # https://github.com/yt-dlp/yt-dlp/issues/2220
273
+ mp4_ = f'yt-dlp -f "mp4" --downloader ffmpeg --downloader-args "ffmpeg_i: -ss 00:00:20 -t 00:00:10" --force-overwrites --max-downloads 1 --no-warnings --no-abort-on-error --ignore-no-formats-error --restrict-filenames -o {OutputFile} {video}'
274
+ wav_ = "ffmpeg -y -i Video.mp4 -vn -acodec pcm_s16le -ar 44100 -ac 2 audio.wav"
275
+ result_convert_video = subprocess.run(
276
+ mp4_, capture_output=True, text=True, shell=True
277
+ )
278
+ result_convert_audio = subprocess.run(
279
+ wav_, capture_output=True, text=True, shell=True
280
+ )
281
+ if result_convert_audio.returncode in [1, 2]:
282
+ raise OperationFailedError("Error can't download a preview")
283
+ else:
284
+ mp4_ = f'yt-dlp -f "mp4" --force-overwrites --max-downloads 1 --no-warnings --no-abort-on-error --ignore-no-formats-error --restrict-filenames -o {OutputFile} {video}'
285
+ wav_ = f"python -m yt_dlp --output {audio_wav} --force-overwrites --max-downloads 1 --no-warnings --no-abort-on-error --ignore-no-formats-error --extract-audio --audio-format wav {video}"
286
+
287
+ result_convert_audio = subprocess.run(
288
+ wav_, capture_output=True, text=True, shell=True
289
+ )
290
+
291
+ if result_convert_audio.returncode in [1, 2]:
292
+ raise OperationFailedError("Error can't download the audio")
293
+
294
+ for i in range(120):
295
+ time.sleep(1)
296
+ logger.info("Process audio...")
297
+ if os.path.exists(audio_wav) and not os.path.exists(
298
+ "audio.webm"
299
+ ):
300
+ time.sleep(1)
301
+ result_convert_video = subprocess.run(
302
+ mp4_, capture_output=True, text=True, shell=True
303
+ )
304
+ break
305
+ if i == 119:
306
+ raise OperationFailedError("Error downloading the audio")
307
+
308
+ if result_convert_video.returncode in [1, 2]:
309
+ raise OperationFailedError("Error can't download the video")
quantum_dubbing/speech_segmentation.py ADDED
@@ -0,0 +1,499 @@
1
+ from whisperx.alignment import (
2
+ DEFAULT_ALIGN_MODELS_TORCH as DAMT,
3
+ DEFAULT_ALIGN_MODELS_HF as DAMHF,
4
+ )
5
+ from whisperx.utils import TO_LANGUAGE_CODE
6
+ import whisperx
7
+ import torch
8
+ import gc
9
+ import os
10
+ import soundfile as sf
11
+ from IPython.utils import capture # noqa
12
+ from .language_configuration import EXTRA_ALIGN, INVERTED_LANGUAGES
13
+ from .logging_setup import logger
14
+ from .postprocessor import sanitize_file_name
15
+ from .utils import remove_directory_contents, run_command
16
+
17
+ # ZERO GPU CONFIG
18
+ import spaces
19
+ import copy
20
+ import random
21
+ import time
22
+
23
+ def random_sleep():
24
+ if os.environ.get("ZERO_GPU") == "TRUE":
25
+ print("Random sleep")
26
+ sleep_time = round(random.uniform(7.2, 9.9), 1)
27
+ time.sleep(sleep_time)
28
+
29
+
30
+ @spaces.GPU
31
+ def load_and_transcribe_audio(asr_model, audio, compute_type, language, asr_options, batch_size, segment_duration_limit):
32
+ # Load model
33
+ model = whisperx.load_model(
34
+ asr_model,
35
+ os.environ.get("QUANTUM_DEVICE") if os.environ.get("ZERO_GPU") != "TRUE" else "cuda",
36
+ compute_type=compute_type,
37
+ language=language,
38
+ asr_options=asr_options,
39
+ )
40
+
41
+ # Transcribe audio
42
+ result = model.transcribe(
43
+ audio,
44
+ batch_size=batch_size,
45
+ chunk_size=segment_duration_limit,
46
+ print_progress=True,
47
+ )
48
+
49
+ del model
50
+ gc.collect()
51
+ torch.cuda.empty_cache() # noqa
52
+
53
+ return result
54
+
55
+ def load_align_and_align_segments(result, audio, DAMHF):
56
+
57
+ # Load alignment model
58
+ model_a, metadata = whisperx.load_align_model(
59
+ language_code=result["language"],
60
+ device=os.environ.get("QUANTUM_DEVICE") if os.environ.get("ZERO_GPU") != "TRUE" else "cpu",
61
+ model_name=None
62
+ if result["language"] in DAMHF.keys()
63
+ else EXTRA_ALIGN[result["language"]],
64
+ )
65
+
66
+ # Align segments
67
+ alignment_result = whisperx.align(
68
+ result["segments"],
69
+ model_a,
70
+ metadata,
71
+ audio,
72
+ os.environ.get("QUANTUM_DEVICE") if os.environ.get("ZERO_GPU") != "TRUE" else "cpu",
73
+ return_char_alignments=True,
74
+ print_progress=False,
75
+ )
76
+
77
+ # Clean up
78
+ del model_a
79
+ gc.collect()
80
+ torch.cuda.empty_cache() # noqa
81
+
82
+ return alignment_result
83
+
84
+ @spaces.GPU
85
+ def diarize_audio(diarize_model, audio_wav, min_speakers, max_speakers):
86
+
87
+ if os.environ.get("ZERO_GPU") == "TRUE":
88
+ diarize_model.model.to(torch.device("cuda"))
89
+ diarize_segments = diarize_model(
90
+ audio_wav,
91
+ min_speakers=min_speakers,
92
+ max_speakers=max_speakers
93
+ )
94
+ return diarize_segments
95
+
96
+ # ZERO GPU CONFIG
97
+
98
+ ASR_MODEL_OPTIONS = [
99
+ "tiny",
100
+ "base",
101
+ "small",
102
+ "medium",
103
+ "large",
104
+ "large-v1",
105
+ "large-v2",
106
+ "large-v3",
107
+ "distil-large-v2",
108
+ "Systran/faster-distil-whisper-large-v3",
109
+ "tiny.en",
110
+ "base.en",
111
+ "small.en",
112
+ "medium.en",
113
+ "distil-small.en",
114
+ "distil-medium.en",
115
+ "OpenAI_API_Whisper",
116
+ ]
117
+
118
+ COMPUTE_TYPE_GPU = [
119
+ "default",
120
+ "auto",
121
+ "int8",
122
+ "int8_float32",
123
+ "int8_float16",
124
+ "int8_bfloat16",
125
+ "float16",
126
+ "bfloat16",
127
+ "float32"
128
+ ]
129
+
130
+ COMPUTE_TYPE_CPU = [
131
+ "default",
132
+ "auto",
133
+ "int8",
134
+ "int8_float32",
135
+ "int16",
136
+ "float32",
137
+ ]
138
+
139
+ WHISPER_MODELS_PATH = './WHISPER_MODELS'
140
+
141
+
142
+ def openai_api_whisper(
143
+ input_audio_file,
144
+ source_lang=None,
145
+ chunk_duration=1800
146
+ ):
147
+
148
+ info = sf.info(input_audio_file)
149
+ duration = info.duration
150
+
151
+ output_directory = "./whisper_api_audio_parts"
152
+ os.makedirs(output_directory, exist_ok=True)
153
+ remove_directory_contents(output_directory)
154
+
155
+ if duration > chunk_duration:
156
+ # Split the audio file into smaller chunks with 30-minute duration
157
+ cm = f'ffmpeg -i "{input_audio_file}" -f segment -segment_time {chunk_duration} -c:a libvorbis "{output_directory}/output%03d.ogg"'
158
+ run_command(cm)
159
+ # Get list of generated chunk files
160
+ chunk_files = sorted(
161
+ [f"{output_directory}/{f}" for f in os.listdir(output_directory) if f.endswith('.ogg')]
162
+ )
163
+ else:
164
+ one_file = f"{output_directory}/output000.ogg"
165
+ cm = f'ffmpeg -i "{input_audio_file}" -c:a libvorbis {one_file}'
166
+ run_command(cm)
167
+ chunk_files = [one_file]
168
+
169
+ # Transcript
170
+ segments = []
171
+ language = source_lang if source_lang else None
172
+ for i, chunk in enumerate(chunk_files):
173
+ from openai import OpenAI
174
+ client = OpenAI()
175
+
176
+ audio_file = open(chunk, "rb")
177
+ transcription = client.audio.transcriptions.create(
178
+ model="whisper-1",
179
+ file=audio_file,
180
+ language=language,
181
+ response_format="verbose_json",
182
+ timestamp_granularities=["segment"],
183
+ )
184
+
185
+ try:
186
+ transcript_dict = transcription.model_dump()
187
+ except: # noqa
188
+ transcript_dict = transcription.to_dict()
189
+
190
+ if language is None:
191
+ logger.info(f'Language detected: {transcript_dict["language"]}')
192
+ language = TO_LANGUAGE_CODE[transcript_dict["language"]]
193
+
194
+ chunk_time = chunk_duration * (i)
195
+
196
+ for seg in transcript_dict["segments"]:
197
+
198
+ if "start" in seg.keys():
199
+ segments.append(
200
+ {
201
+ "text": seg["text"],
202
+ "start": seg["start"] + chunk_time,
203
+ "end": seg["end"] + chunk_time,
204
+ }
205
+ )
206
+
207
+ audio = whisperx.load_audio(input_audio_file)
208
+ result = {"segments": segments, "language": language}
209
+
210
+ return audio, result
211
+
212
+
213
+ def find_whisper_models():
214
+ path = WHISPER_MODELS_PATH
215
+ folders = []
216
+
217
+ if os.path.exists(path):
218
+ for folder in os.listdir(path):
219
+ folder_path = os.path.join(path, folder)
220
+ if (
221
+ os.path.isdir(folder_path)
222
+ and 'model.bin' in os.listdir(folder_path)
223
+ ):
224
+ folders.append(folder)
225
+ return folders
226
+
227
+ def transcribe_speech(
228
+ audio_wav,
229
+ asr_model,
230
+ compute_type,
231
+ batch_size,
232
+ SOURCE_LANGUAGE,
233
+ literalize_numbers=True,
234
+ segment_duration_limit=15,
235
+ ):
236
+ """
237
+ Transcribe speech using a whisper model.
238
+
239
+ Parameters:
240
+ - audio_wav (str): Path to the audio file in WAV format.
241
+ - asr_model (str): The whisper model to be loaded.
242
+ - compute_type (str): Type of compute to be used (e.g., 'int8', 'float16').
243
+ - batch_size (int): Batch size for transcription.
244
+ - SOURCE_LANGUAGE (str): Source language for transcription.
245
+
246
+ Returns:
247
+ - Tuple containing:
248
+ - audio: Loaded audio file.
249
+ - result: Transcription result as a dictionary.
250
+ """
251
+
252
+ if asr_model == "OpenAI_API_Whisper":
253
+ if literalize_numbers:
254
+ logger.info(
255
+ "OpenAI's API Whisper does not support "
256
+ "the literalization of numbers."
257
+ )
258
+ return openai_api_whisper(audio_wav, SOURCE_LANGUAGE)
259
+
260
+ # https://github.com/openai/whisper/discussions/277
261
+ prompt = "以下是普通话的句子。" if SOURCE_LANGUAGE == "zh" else None
262
+ SOURCE_LANGUAGE = (
263
+ SOURCE_LANGUAGE if SOURCE_LANGUAGE != "zh-TW" else "zh"
264
+ )
265
+ asr_options = {
266
+ "initial_prompt": prompt,
267
+ "suppress_numerals": literalize_numbers
268
+ }
269
+
270
+ if asr_model not in ASR_MODEL_OPTIONS:
271
+
272
+ base_dir = WHISPER_MODELS_PATH
273
+ if not os.path.exists(base_dir):
274
+ os.makedirs(base_dir)
275
+ model_dir = os.path.join(base_dir, sanitize_file_name(asr_model))
276
+
277
+ if not os.path.exists(model_dir):
278
+ from ctranslate2.converters import TransformersConverter
279
+
280
+ quantization = "float32"
281
+ # Download new model
282
+ try:
283
+ converter = TransformersConverter(
284
+ asr_model,
285
+ low_cpu_mem_usage=True,
286
+ copy_files=[
287
+ "tokenizer_config.json", "preprocessor_config.json"
288
+ ]
289
+ )
290
+ converter.convert(
291
+ model_dir,
292
+ quantization=quantization,
293
+ force=False
294
+ )
295
+ except Exception as error:
296
+ if "File tokenizer_config.json does not exist" in str(error):
297
+ converter._copy_files = [
298
+ "tokenizer.json", "preprocessor_config.json"
299
+ ]
300
+ converter.convert(
301
+ model_dir,
302
+ quantization=quantization,
303
+ force=True
304
+ )
305
+ else:
306
+ raise error
307
+
308
+ asr_model = model_dir
309
+ logger.info(f"ASR Model: {str(model_dir)}")
310
+
311
+ audio = whisperx.load_audio(audio_wav)
312
+
313
+ result = load_and_transcribe_audio(
314
+ asr_model, audio, compute_type, SOURCE_LANGUAGE, asr_options, batch_size, segment_duration_limit
315
+ )
316
+
317
+ if result["language"] == "zh" and not prompt:
318
+ result["language"] = "zh-TW"
319
+ logger.info("Chinese - Traditional (zh-TW)")
320
+
321
+
322
+ return audio, result
323
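+ # Illustrative sketch, not a definitive usage: a hypothetical call to
+ # transcribe_speech. The audio path is a placeholder; the model and compute
+ # type come from the option lists above.
+ def _demo_transcribe_speech():
+     audio, result = transcribe_speech(
+         audio_wav="audio.wav",
+         asr_model="large-v3",
+         compute_type="float16",
+         batch_size=8,
+         SOURCE_LANGUAGE="en",
+     )
+     # result["segments"] holds {"text", "start", "end"} entries ready for alignment
+     return audio, result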
+
324
+
325
+ def align_speech(audio, result):
326
+ """
327
+ Aligns speech segments based on the provided audio and result metadata.
328
+
329
+ Parameters:
330
+ - audio (array): The audio data in a suitable format for alignment.
331
+ - result (dict): Metadata containing information about the segments
332
+ and language.
333
+
334
+ Returns:
335
+ - result (dict): Updated metadata after aligning the segments with
336
+ the audio. This includes character-level alignments if
337
+ 'return_char_alignments' is set to True.
338
+
339
+ Notes:
340
+ - This function uses language-specific models to align speech segments.
341
+ - It performs language compatibility checks and selects the
342
+ appropriate alignment model.
343
+ - Cleans up memory by releasing resources after alignment.
344
+ """
345
+ DAMHF.update(DAMT) # lang align
346
+ if (
347
+ not result["language"] in DAMHF.keys()
348
+ and not result["language"] in EXTRA_ALIGN.keys()
349
+ ):
350
+ logger.warning(
351
+ "Automatic detection: Source language not compatible with align"
352
+ )
353
+ raise ValueError(
354
+ f"Detected language {result['language']} incompatible, "
355
+ "you can select the source language to avoid this error."
356
+ )
357
+ if (
358
+ result["language"] in EXTRA_ALIGN.keys()
359
+ and EXTRA_ALIGN[result["language"]] == ""
360
+ ):
361
+ lang_name = (
362
+ INVERTED_LANGUAGES[result["language"]]
363
+ if result["language"] in INVERTED_LANGUAGES.keys()
364
+ else result["language"]
365
+ )
366
+ logger.warning(
367
+ "No compatible wav2vec2 model found "
368
+ f"for the language '{lang_name}', skipping alignment."
369
+ )
370
+ return result
371
+
372
+ # random_sleep()
373
+ result = load_align_and_align_segments(result, audio, DAMHF)
374
+
375
+ return result
376
+
377
+
378
+ diarization_models = {
379
+ "pyannote_3.1": "pyannote/speaker-diarization-3.1",
380
+ "pyannote_2.1": "pyannote/speaker-diarization@2.1",
381
+ "disable": "",
382
+ }
383
+
384
+
385
+ def reencode_speakers(result):
386
+
387
+ if result["segments"][0]["speaker"] == "SPEAKER_00":
388
+ return result
389
+
390
+ speaker_mapping = {}
391
+ counter = 0
392
+
393
+ logger.debug("Reencode speakers")
394
+
395
+ for segment in result["segments"]:
396
+ old_speaker = segment["speaker"]
397
+ if old_speaker not in speaker_mapping:
398
+ speaker_mapping[old_speaker] = f"SPEAKER_{counter:02d}"
399
+ counter += 1
400
+ segment["speaker"] = speaker_mapping[old_speaker]
401
+
402
+ return result
403
+
404
+
405
+ def diarize_speech(
406
+ audio_wav,
407
+ result,
408
+ min_speakers,
409
+ max_speakers,
410
+ YOUR_HF_TOKEN,
411
+ model_name="pyannote/speaker-diarization@2.1",
412
+ ):
413
+ """
414
+ Performs speaker diarization on speech segments.
415
+
416
+ Parameters:
417
+ - audio_wav (array): Audio data in WAV format to perform speaker
418
+ diarization.
419
+ - result (dict): Metadata containing information about speech segments
420
+ and alignments.
421
+ - min_speakers (int): Minimum number of speakers expected in the audio.
422
+ - max_speakers (int): Maximum number of speakers expected in the audio.
423
+ - YOUR_HF_TOKEN (str): Your Hugging Face API token for model
424
+ authentication.
425
+ - model_name (str): Name of the speaker diarization model to be used
426
+ (default: "pyannote/speaker-diarization@2.1").
427
+
428
+ Returns:
429
+ - result_diarize (dict): Updated metadata after assigning speaker
430
+ labels to segments.
431
+
432
+ Notes:
433
+ - This function utilizes a speaker diarization model to label speaker
434
+ segments in the audio.
435
+ - It assigns speakers to word-level segments based on diarization results.
436
+ - Cleans up memory by releasing resources after diarization.
437
+ - If only one speaker is specified, each segment is automatically assigned
438
+ as the first speaker, eliminating the need for diarization inference.
439
+ """
440
+
441
+ if max(min_speakers, max_speakers) > 1 and model_name:
442
+ try:
443
+
444
+ diarize_model = whisperx.DiarizationPipeline(
445
+ model_name=model_name,
446
+ use_auth_token=YOUR_HF_TOKEN,
447
+ device=os.environ.get("QUANTUM_DEVICE"),
448
+ )
449
+
450
+ except Exception as error:
451
+ error_str = str(error)
452
+ gc.collect()
453
+ torch.cuda.empty_cache() # noqa
454
+ if "'NoneType' object has no attribute 'to'" in error_str:
455
+ if model_name == diarization_models["pyannote_2.1"]:
456
+ raise ValueError(
457
+ "Accept the license agreement for using Pyannote 2.1."
458
+ " You need to have an account on Hugging Face and "
459
+ "accept the license to use the models: "
460
+ "https://huggingface.co/pyannote/speaker-diarization "
461
+ "and https://huggingface.co/pyannote/segmentation "
462
+ "Get your KEY TOKEN here: "
463
+ "https://hf.co/settings/tokens "
464
+ )
465
+ elif model_name == diarization_models["pyannote_3.1"]:
466
+ raise ValueError(
467
+ "New Licence Pyannote 3.1: You need to have an account"
468
+ " on Hugging Face and accept the license to use the "
469
+ "models: https://huggingface.co/pyannote/speaker-diarization-3.1 " # noqa
470
+ "and https://huggingface.co/pyannote/segmentation-3.0 "
471
+ )
472
+ else:
473
+ raise error
474
+
475
+ random_sleep()
476
+ diarize_segments = diarize_audio(diarize_model, audio_wav, min_speakers, max_speakers)
477
+
478
+ result_diarize = whisperx.assign_word_speakers(
479
+ diarize_segments, result
480
+ )
481
+
482
+ for segment in result_diarize["segments"]:
483
+ if "speaker" not in segment:
484
+ segment["speaker"] = "SPEAKER_00"
485
+ logger.warning(
486
+ f"No speaker detected in {segment['start']}. First TTS "
487
+ f"will be used for the segment text: {segment['text']} "
488
+ )
489
+
490
+ del diarize_model
491
+ gc.collect()
492
+ torch.cuda.empty_cache() # noqa
493
+ else:
494
+ result_diarize = result
495
+ result_diarize["segments"] = [
496
+ {**item, "speaker": "SPEAKER_00"}
497
+ for item in result_diarize["segments"]
498
+ ]
499
+ return reencode_speakers(result_diarize)
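+
+ # Illustrative sketch of how the three stages in this module are typically
+ # chained. Paths, speaker counts and the Hugging Face token are placeholder
+ # assumptions, not values from the actual pipeline.
+ def _demo_speech_segmentation_pipeline(hf_token="hf_xxx"):
+     audio, result = transcribe_speech(
+         "audio.wav", "large-v3", "float16", 8, "en"
+     )
+     result = align_speech(audio, result)
+     result_diarize = diarize_speech(
+         "audio.wav",
+         result,
+         min_speakers=1,
+         max_speakers=2,
+         YOUR_HF_TOKEN=hf_token,
+         model_name=diarization_models["pyannote_3.1"],
+     )
+     # Each segment now carries a normalized speaker label (SPEAKER_00, ...)
+     return result_diarize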
quantum_dubbing/text_multiformat_processor.py ADDED
@@ -0,0 +1,987 @@
1
+ from .logging_setup import logger
2
+ from whisperx.utils import get_writer
3
+ from .utils import remove_files, run_command, remove_directory_contents
4
+ from typing import List
5
+ import srt
6
+ import re
7
+ import os
8
+ import copy
9
+ import string
10
+ import soundfile as sf
11
+ from PIL import Image, ImageOps, ImageDraw, ImageFont
12
+
13
+ punctuation_list = list(
14
+ string.punctuation + "¡¿«»„”“”‚‘’「」『』《》()【】〈〉〔〕〖〗〘〙〚〛⸤⸥⸨⸩"
15
+ )
16
+ symbol_list = punctuation_list + ["", "..", "..."]
17
+
18
+
19
+ def extract_from_srt(file_path):
20
+ with open(file_path, "r", encoding="utf-8") as file:
21
+ srt_content = file.read()
22
+
23
+ subtitle_generator = srt.parse(srt_content)
24
+ srt_content_list = list(subtitle_generator)
25
+
26
+ return srt_content_list
27
+
28
+
29
+ def clean_text(text):
30
+
31
+ # Remove content within square brackets
32
+ text = re.sub(r'\[.*?\]', '', text)
33
+ # Remove content within <comment> tags
34
+ text = re.sub(r'<comment>.*?</comment>', '', text)
35
+ # Remove HTML tags
36
+ text = re.sub(r'<.*?>', '', text)
37
+ # Remove "♫" and "♪" content
38
+ text = re.sub(r'♫.*?♫', '', text)
39
+ text = re.sub(r'♪.*?♪', '', text)
40
+ # Replace newline characters with ". "
41
+ text = text.replace("\n", ". ")
42
+ # Remove double quotation marks
43
+ text = text.replace('"', '')
44
+ # Collapse multiple spaces and replace with a single space
45
+ text = re.sub(r"\s+", " ", text)
46
+ # Normalize spaces around periods
47
+ text = re.sub(r"[\s\.]+(?=\s)", ". ", text)
48
+ # Check if there are ♫ or ♪ symbols present
49
+ if '♫' in text or '♪' in text:
50
+ return ""
51
+
52
+ text = text.strip()
53
+
54
+ # Valid text
55
+ return text if text not in symbol_list else ""
56
+
57
+
58
+ def srt_file_to_segments(file_path, speaker=False):
59
+ try:
60
+ srt_content_list = extract_from_srt(file_path)
61
+ except Exception as error:
62
+ logger.error(str(error))
63
+ fixed_file = "fixed_sub.srt"
64
+ remove_files(fixed_file)
65
+ fix_sub = f'ffmpeg -i "{file_path}" "{fixed_file}" -y'
66
+ run_command(fix_sub)
67
+ srt_content_list = extract_from_srt(fixed_file)
68
+
69
+ segments = []
70
+ for segment in srt_content_list:
71
+
72
+ text = clean_text(str(segment.content))
73
+
74
+ if text:
75
+ segments.append(
76
+ {
77
+ "text": text,
78
+ "start": float(segment.start.total_seconds()),
79
+ "end": float(segment.end.total_seconds()),
80
+ }
81
+ )
82
+
83
+ if not segments:
84
+ raise Exception("No data found in srt subtitle file")
85
+
86
+ if speaker:
87
+ segments = [{**seg, "speaker": "SPEAKER_00"} for seg in segments]
88
+
89
+ return {"segments": segments}
90
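+ # Illustrative sketch: the shape of the value returned by srt_file_to_segments,
+ # with made-up timings.
+ _EXAMPLE_SRT_SEGMENTS = {
+     "segments": [
+         {"text": "Hello world.", "start": 0.5, "end": 2.1},
+         {"text": "Second line.", "start": 2.4, "end": 4.0},
+     ]
+ }  # with speaker=True each entry also carries "speaker": "SPEAKER_00"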
+
91
+
92
+ # documents
93
+
94
+
95
+ def dehyphenate(lines: List[str], line_no: int) -> List[str]:
96
+ next_line = lines[line_no + 1]
97
+ word_suffix = next_line.split(" ")[0]
98
+
99
+ lines[line_no] = lines[line_no][:-1] + word_suffix
100
+ lines[line_no + 1] = lines[line_no + 1][len(word_suffix):]
101
+ return lines
102
+
103
+
104
+ def remove_hyphens(text: str) -> str:
105
+ """
106
+
107
+ This fails for:
108
+ * Natural dashes: well-known, self-replication, use-cases, non-semantic,
109
+ Post-processing, Window-wise, viewpoint-dependent
110
+ * Trailing math operands: 2 - 4
111
+ * Names: Lopez-Ferreras, VGG-19, CIFAR-100
112
+ """
113
+ lines = [line.rstrip() for line in text.split("\n")]
114
+
115
+ # Find dashes
116
+ line_numbers = []
117
+ for line_no, line in enumerate(lines[:-1]):
118
+ if line.endswith("-"):
119
+ line_numbers.append(line_no)
120
+
121
+ # Replace
122
+ for line_no in line_numbers:
123
+ lines = dehyphenate(lines, line_no)
124
+
125
+ return "\n".join(lines)
126
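+ # Illustrative sketch: remove_hyphens rejoins a word split across lines by a
+ # trailing hyphen, with the caveats listed in the docstring above.
+ def _demo_remove_hyphens():
+     broken = "The transla-\ntion pipeline"
+     # The suffix "tion" is moved up to complete the word, leaving
+     # "The translation\n pipeline"
+     return remove_hyphens(broken)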
+
127
+
128
+ def pdf_to_txt(pdf_file, start_page, end_page):
129
+ from pypdf import PdfReader
130
+
131
+ with open(pdf_file, "rb") as file:
132
+ reader = PdfReader(file)
133
+ logger.debug(f"Total pages: {reader.get_num_pages()}")
134
+ text = ""
135
+
136
+ start_page_idx = max((start_page-1), 0)
137
+ end_page_inx = min((end_page), (reader.get_num_pages()))
138
+ document_pages = reader.pages[start_page_idx:end_page_inx]
139
+ logger.info(
140
+ f"Selected pages from {start_page_idx} to {end_page_inx}: "
141
+ f"{len(document_pages)}"
142
+ )
143
+
144
+ for page in document_pages:
145
+ text += remove_hyphens(page.extract_text())
146
+ return text
147
+
148
+
149
+ def docx_to_txt(docx_file):
150
+ # https://github.com/AlJohri/docx2pdf update
151
+ from docx import Document
152
+
153
+ doc = Document(docx_file)
154
+ text = ""
155
+ for paragraph in doc.paragraphs:
156
+ text += paragraph.text + "\n"
157
+ return text
158
+
159
+
160
+ def replace_multiple_elements(text, replacements):
161
+ pattern = re.compile("|".join(map(re.escape, replacements.keys())))
162
+ replaced_text = pattern.sub(
163
+ lambda match: replacements[match.group(0)], text
164
+ )
165
+
166
+ # Remove multiple spaces
167
+ replaced_text = re.sub(r"\s+", " ", replaced_text)
168
+
169
+ return replaced_text
170
+
171
+
172
+ def document_preprocessor(file_path, is_string, start_page, end_page):
173
+ if not is_string:
174
+ file_ext = os.path.splitext(file_path)[1].lower()
175
+
176
+ if is_string:
177
+ text = file_path
178
+ elif file_ext == ".pdf":
179
+ text = pdf_to_txt(file_path, start_page, end_page)
180
+ elif file_ext == ".docx":
181
+ text = docx_to_txt(file_path)
182
+ elif file_ext == ".txt":
183
+ with open(
184
+ file_path, "r", encoding='utf-8', errors='replace'
185
+ ) as file:
186
+ text = file.read()
187
+ else:
188
+ raise Exception("Unsupported file format")
189
+
190
+ # Add space to break segments more easily later
191
+ replacements = {
192
+ "、": "、 ",
193
+ "。": "。 ",
194
+ # "\n": " ",
195
+ }
196
+ text = replace_multiple_elements(text, replacements)
197
+
198
+ # Save text to a .txt file
199
+ # file_name = os.path.splitext(os.path.basename(file_path))[0]
200
+ txt_file_path = "./text_preprocessor.txt"
201
+
202
+ with open(
203
+ txt_file_path, "w", encoding='utf-8', errors='replace'
204
+ ) as txt_file:
205
+ txt_file.write(text)
206
+
207
+ return txt_file_path, text
208
+
209
+
210
+ def split_text_into_chunks(text, chunk_size):
211
+ words = re.findall(r"\b\w+\b", text)
212
+ chunks = []
213
+ current_chunk = ""
214
+ for word in words:
215
+ if (
216
+ len(current_chunk) + len(word) + 1 <= chunk_size
217
+ ): # Adding 1 for the space between words
218
+ if current_chunk:
219
+ current_chunk += " "
220
+ current_chunk += word
221
+ else:
222
+ chunks.append(current_chunk)
223
+ current_chunk = word
224
+ if current_chunk:
225
+ chunks.append(current_chunk)
226
+ return chunks
227
+
228
+
229
+ def determine_chunk_size(file_name):
230
+ patterns = {
231
+ re.compile(r".*-(Male|Female)$"): 1024, # by character
232
+ re.compile(r".* BARK$"): 100, # t 64 256
233
+ re.compile(r".* VITS$"): 500,
234
+ re.compile(
235
+ r".+\.(wav|mp3|ogg|m4a)$"
236
+ ): 150, # t 250 400 api automatic split
237
+ re.compile(r".* VITS-onnx$"): 250, # automatic sentence split
238
+ re.compile(r".* OpenAI-TTS$"): 1024 # max characters 4096
239
+ }
240
+
241
+ for pattern, chunk_size in patterns.items():
242
+ if pattern.match(file_name):
243
+ return chunk_size
244
+
245
+ # Default chunk size if the file doesn't match any pattern; max 1800
246
+ return 100
247
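+ # Illustrative sketch: hypothetical voice/file names used only to exercise the
+ # patterns above; they are not taken from the actual TTS voice lists.
+ def _demo_determine_chunk_size():
+     return (
+         determine_chunk_size("en-US-Example-Male"),  # 1024, per-character TTS
+         determine_chunk_size("speaker_0 BARK"),      # 100
+         determine_chunk_size("reference.wav"),       # 150, audio file pattern
+     )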
+
248
+
249
+ def plain_text_to_segments(result_text=None, chunk_size=None):
250
+ if not chunk_size:
251
+ chunk_size = 100
252
+ text_chunks = split_text_into_chunks(result_text, chunk_size)
253
+
254
+ segments_chunks = []
255
+ for num, chunk in enumerate(text_chunks):
256
+ chunk_dict = {
257
+ "text": chunk,
258
+ "start": (1.0 + num),
259
+ "end": (2.0 + num),
260
+ "speaker": "SPEAKER_00",
261
+ }
262
+ segments_chunks.append(chunk_dict)
263
+
264
+ result_diarize = {"segments": segments_chunks}
265
+
266
+ return result_diarize
267
+
268
+
269
+ def segments_to_plain_text(result_diarize):
270
+ complete_text = ""
271
+ for seg in result_diarize["segments"]:
272
+ complete_text += seg["text"] + " " # issue
273
+
274
+ # Save text to a .txt file
275
+ # file_name = os.path.splitext(os.path.basename(file_path))[0]
276
+ txt_file_path = "./text_translation.txt"
277
+
278
+ with open(
279
+ txt_file_path, "w", encoding='utf-8', errors='replace'
280
+ ) as txt_file:
281
+ txt_file.write(complete_text)
282
+
283
+ return txt_file_path, complete_text
284
+
285
+
286
+ # doc to video
287
+
288
+ COLORS = {
289
+ "black": (0, 0, 0),
290
+ "white": (255, 255, 255),
291
+ "red": (255, 0, 0),
292
+ "green": (0, 255, 0),
293
+ "blue": (0, 0, 255),
294
+ "yellow": (255, 255, 0),
295
+ "light_gray": (200, 200, 200),
296
+ "light_blue": (173, 216, 230),
297
+ "light_green": (144, 238, 144),
298
+ "light_yellow": (255, 255, 224),
299
+ "light_pink": (255, 182, 193),
300
+ "lavender": (230, 230, 250),
301
+ "peach": (255, 218, 185),
302
+ "light_cyan": (224, 255, 255),
303
+ "light_salmon": (255, 160, 122),
304
+ "light_green_yellow": (173, 255, 47),
305
+ }
306
+
307
+ BORDER_COLORS = ["dynamic"] + list(COLORS.keys())
308
+
309
+
310
+ def calculate_average_color(img):
311
+ # Resize the image to a small size for faster processing
312
+ img_small = img.resize((50, 50))
313
+ # Calculate the average color
314
+ average_color = img_small.convert("RGB").resize((1, 1)).getpixel((0, 0))
315
+ return average_color
316
+
317
+
318
+ def add_border_to_image(
319
+ image_path,
320
+ target_width,
321
+ target_height,
322
+ border_color=None
323
+ ):
324
+
325
+ img = Image.open(image_path)
326
+
327
+ # Calculate the width and height for the new image with borders
328
+ original_width, original_height = img.size
329
+ original_aspect_ratio = original_width / original_height
330
+ target_aspect_ratio = target_width / target_height
331
+
332
+ # Resize the image to fit the target resolution retaining aspect ratio
333
+ if original_aspect_ratio > target_aspect_ratio:
334
+ # Image is wider, calculate new height
335
+ new_height = int(target_width / original_aspect_ratio)
336
+ resized_img = img.resize((target_width, new_height))
337
+ else:
338
+ # Image is taller, calculate new width
339
+ new_width = int(target_height * original_aspect_ratio)
340
+ resized_img = img.resize((new_width, target_height))
341
+
342
+ # Calculate padding for borders
343
+ padding = (0, 0, 0, 0)
344
+ if resized_img.size[0] != target_width or resized_img.size[1] != target_height:
345
+ if original_aspect_ratio > target_aspect_ratio:
346
+ # Add borders vertically
347
+ padding = (0, (target_height - resized_img.size[1]) // 2, 0, (target_height - resized_img.size[1]) // 2)
348
+ else:
349
+ # Add borders horizontally
350
+ padding = ((target_width - resized_img.size[0]) // 2, 0, (target_width - resized_img.size[0]) // 2, 0)
351
+
352
+ # Add borders with specified color
353
+ if not border_color or border_color == "dynamic":
354
+ border_color = calculate_average_color(resized_img)
355
+ else:
356
+ border_color = COLORS.get(border_color, (0, 0, 0))
357
+
358
+ bordered_img = ImageOps.expand(resized_img, padding, fill=border_color)
359
+
360
+ bordered_img.save(image_path)
361
+
362
+ return image_path
363
+
364
+
365
+ def resize_and_position_subimage(
366
+ subimage,
367
+ max_width,
368
+ max_height,
369
+ subimage_position,
370
+ main_width,
371
+ main_height
372
+ ):
373
+ subimage_width, subimage_height = subimage.size
374
+
375
+ # Resize subimage if it exceeds maximum dimensions
376
+ if subimage_width > max_width or subimage_height > max_height:
377
+ # Calculate scaling factor
378
+ width_scale = max_width / subimage_width
379
+ height_scale = max_height / subimage_height
380
+ scale = min(width_scale, height_scale)
381
+
382
+ # Resize subimage
383
+ subimage = subimage.resize(
384
+ (int(subimage_width * scale), int(subimage_height * scale))
385
+ )
386
+
387
+ # Calculate position to place the subimage
388
+ if subimage_position == "top-left":
389
+ subimage_x = 0
390
+ subimage_y = 0
391
+ elif subimage_position == "top-right":
392
+ subimage_x = main_width - subimage.width
393
+ subimage_y = 0
394
+ elif subimage_position == "bottom-left":
395
+ subimage_x = 0
396
+ subimage_y = main_height - subimage.height
397
+ elif subimage_position == "bottom-right":
398
+ subimage_x = main_width - subimage.width
399
+ subimage_y = main_height - subimage.height
400
+ else:
401
+ raise ValueError(
402
+ "Invalid subimage_position. Choose from 'top-left', 'top-right',"
403
+ " 'bottom-left', or 'bottom-right'."
404
+ )
405
+
406
+ return subimage, subimage_x, subimage_y
407
+
408
+
409
+ def create_image_with_text_and_subimages(
410
+ text,
411
+ subimages,
412
+ width,
413
+ height,
414
+ text_color,
415
+ background_color,
416
+ output_file
417
+ ):
418
+ # Create an image with the specified resolution and background color
419
+ image = Image.new('RGB', (width, height), color=background_color)
420
+
421
+ # Initialize ImageDraw object
422
+ draw = ImageDraw.Draw(image)
423
+
424
+ # Load a font
425
+ font = ImageFont.load_default() # You can specify your font file here
426
+
427
+ # Calculate text size and position
428
+ text_bbox = draw.textbbox((0, 0), text, font=font)
429
+ text_width = text_bbox[2] - text_bbox[0]
430
+ text_height = text_bbox[3] - text_bbox[1]
431
+ text_x = (width - text_width) / 2
432
+ text_y = (height - text_height) / 2
433
+
434
+ # Draw text on the image
435
+ draw.text((text_x, text_y), text, fill=text_color, font=font)
436
+
437
+ # Paste subimages onto the main image
438
+ for subimage_path, subimage_position in subimages:
439
+ # Open the subimage
440
+ subimage = Image.open(subimage_path)
441
+
442
+ # Convert subimage to RGBA mode if it doesn't have an alpha channel
443
+ if subimage.mode != 'RGBA':
444
+ subimage = subimage.convert('RGBA')
445
+
446
+ # Resize and position the subimage
447
+ subimage, subimage_x, subimage_y = resize_and_position_subimage(
448
+ subimage, width / 4, height / 4, subimage_position, width, height
449
+ )
450
+
451
+ # Paste the subimage onto the main image
452
+ image.paste(subimage, (int(subimage_x), int(subimage_y)), subimage)
453
+
454
+ image.save(output_file)
455
+
456
+ return output_file
457
+
458
+
459
+ def doc_to_txtximg_pages(
460
+ document,
461
+ width,
462
+ height,
463
+ start_page,
464
+ end_page,
465
+ bcolor
466
+ ):
467
+ from pypdf import PdfReader
468
+
469
+ images_folder = "pdf_images/"
470
+ os.makedirs(images_folder, exist_ok=True)
471
+ remove_directory_contents(images_folder)
472
+
473
+ # First image
474
+ text_image = os.path.basename(document)[:-4]
475
+ subimages = [("./assets/logo.jpeg", "top-left")]
476
+ text_color = (255, 255, 255) if bcolor == "black" else (0, 0, 0) # w|b
477
+ background_color = COLORS.get(bcolor, (255, 255, 255)) # dynamic white
478
+ first_image = "pdf_images/0000_00_aaa.png"
479
+
480
+ create_image_with_text_and_subimages(
481
+ text_image,
482
+ subimages,
483
+ width,
484
+ height,
485
+ text_color,
486
+ background_color,
487
+ first_image
488
+ )
489
+
490
+ reader = PdfReader(document)
491
+ logger.debug(f"Total pages: {reader.get_num_pages()}")
492
+
493
+ start_page_idx = max((start_page-1), 0)
494
+ end_page_inx = min((end_page), (reader.get_num_pages()))
495
+ document_pages = reader.pages[start_page_idx:end_page_inx]
496
+
497
+ logger.info(
498
+ f"Selected pages from {start_page_idx} to {end_page_inx}: "
499
+ f"{len(document_pages)}"
500
+ )
501
+
502
+ data_doc = {}
503
+ for i, page in enumerate(document_pages):
504
+
505
+ count = 0
506
+ images = []
507
+ for image_file_object in page.images:
508
+ img_name = f"{images_folder}{i:04d}_{count:02d}_{image_file_object.name}"
509
+ images.append(img_name)
510
+ with open(img_name, "wb") as fp:
511
+ fp.write(image_file_object.data)
512
+ count += 1
513
+ img_name = add_border_to_image(img_name, width, height, bcolor)
514
+
515
+ data_doc[i] = {
516
+ "text": remove_hyphens(page.extract_text()),
517
+ "images": images
518
+ }
519
+
520
+ return data_doc
521
+
522
+
523
+ def page_data_to_segments(result_text=None, chunk_size=None):
524
+
525
+ if not chunk_size:
526
+ chunk_size = 100
527
+
528
+ segments_chunks = []
529
+ time_global = 0
530
+ for page, result_data in result_text.items():
531
+ # result_image = result_data["images"]
532
+ result_text = result_data["text"]
533
+ text_chunks = split_text_into_chunks(result_text, chunk_size)
534
+ if not text_chunks:
535
+ text_chunks = [" "]
536
+
537
+ for chunk in text_chunks:
538
+ chunk_dict = {
539
+ "text": chunk,
540
+ "start": (1.0 + time_global),
541
+ "end": (2.0 + time_global),
542
+ "speaker": "SPEAKER_00",
543
+ "page": page,
544
+ }
545
+ segments_chunks.append(chunk_dict)
546
+ time_global += 1
547
+
548
+ result_diarize = {"segments": segments_chunks}
549
+
550
+ return result_diarize
551
+
552
+
553
+ def update_page_data(result_diarize, doc_data):
554
+ complete_text = ""
555
+ current_page = result_diarize["segments"][0]["page"]
556
+ text_page = ""
557
+
558
+ for seg in result_diarize["segments"]:
559
+ text = seg["text"] + " " # issue
560
+ complete_text += text
561
+
562
+ page = seg["page"]
563
+
564
+ if page == current_page:
565
+ text_page += text
566
+ else:
567
+ doc_data[current_page]["text"] = text_page
568
+
569
+ # Next
570
+ text_page = text
571
+ current_page = page
572
+
573
+ if doc_data[current_page]["text"] != text_page:
574
+ doc_data[current_page]["text"] = text_page
575
+
576
+ return doc_data
577
+
578
+
579
+ def fix_timestamps_docs(result_diarize, audio_files):
580
+ current_start = 0.0
581
+
582
+ for seg, audio in zip(result_diarize["segments"], audio_files):
583
+ duration = round(sf.info(audio).duration, 2)
584
+
585
+ seg["start"] = current_start
586
+ current_start += duration
587
+ seg["end"] = current_start
588
+
589
+ return result_diarize
590
+
591
+
592
+ def create_video_from_images(
593
+ doc_data,
594
+ result_diarize
595
+ ):
596
+
597
+ # First image path
598
+ first_image = "pdf_images/0000_00_aaa.png"
599
+
600
+ # Time segments and images
601
+ max_pages_idx = len(doc_data) - 1
602
+ current_page = result_diarize["segments"][0]["page"]
603
+ duration_page = 0.0
604
+ last_image = None
605
+
606
+ for seg in result_diarize["segments"]:
607
+ start = seg["start"]
608
+ end = seg["end"]
609
+ duration_seg = end - start
610
+
611
+ page = seg["page"]
612
+
613
+ if page == current_page:
614
+ duration_page += duration_seg
615
+ else:
616
+
617
+ images = doc_data[current_page]["images"]
618
+
619
+ if first_image:
620
+ images = [first_image] + images
621
+ first_image = None
622
+ if not doc_data[min(max_pages_idx, (current_page+1))]["text"].strip():
623
+ images = images + doc_data[min(max_pages_idx, (current_page+1))]["images"]
624
+ if not images and last_image:
625
+ images = [last_image]
626
+
627
+ # Calculate images duration
628
+ time_duration_per_image = round((duration_page / len(images)), 2)
629
+ doc_data[current_page]["time_per_image"] = time_duration_per_image
630
+
631
+ # Next values
632
+ doc_data[current_page]["images"] = images
633
+ last_image = images[-1]
634
+ duration_page = duration_seg
635
+ current_page = page
636
+
637
+ if "time_per_image" not in doc_data[current_page].keys():
638
+ images = doc_data[current_page]["images"]
639
+ if first_image:
640
+ images = [first_image] + images
641
+ if not images:
642
+ images = [last_image]
643
+ time_duration_per_image = round((duration_page / len(images)), 2)
644
+ doc_data[current_page]["time_per_image"] = time_duration_per_image
645
+
646
+ # Timestamped image video.
647
+ with open("list.txt", "w") as file:
648
+
649
+ for i, page in enumerate(doc_data.values()):
650
+
651
+ duration = page["time_per_image"]
652
+ for img in page["images"]:
653
+ if i == len(doc_data) - 1 and img == page["images"][-1]: # Check if it's the last item
654
+ file.write(f"file {img}\n")
655
+ file.write(f"outpoint {duration}")
656
+ else:
657
+ file.write(f"file {img}\n")
658
+ file.write(f"outpoint {duration}\n")
659
+
660
+ out_video = "video_from_images.mp4"
661
+ remove_files(out_video)
662
+
663
+ cm = f"ffmpeg -y -f concat -i list.txt -c:v libx264 -preset veryfast -crf 18 -pix_fmt yuv420p {out_video}"
664
+ cm_alt = f"ffmpeg -f concat -i list.txt -c:v libx264 -r 30 -pix_fmt yuv420p -y {out_video}"
665
+ try:
666
+ run_command(cm)
667
+ except Exception as error:
668
+ logger.error(str(error))
669
+ remove_files(out_video)
670
+ run_command(cm_alt)
671
+
672
+ return out_video
673
+
674
+
675
+ def merge_video_and_audio(video_doc, final_wav_file):
676
+
677
+ fixed_audio = "fixed_audio.mp3"
678
+ remove_files(fixed_audio)
679
+ cm = f"ffmpeg -i {final_wav_file} -c:a libmp3lame {fixed_audio}"
680
+ run_command(cm)
681
+
682
+ vid_out = "video_book.mp4"
683
+ remove_files(vid_out)
684
+ cm = f"ffmpeg -i {video_doc} -i {fixed_audio} -c:v copy -c:a copy -map 0:v -map 1:a -shortest {vid_out}"
685
+ run_command(cm)
686
+
687
+ return vid_out
688
+
689
+
690
+ # subtitles
691
+
692
+
693
+ def get_subtitle(
694
+ language,
695
+ segments_data,
696
+ extension,
697
+ filename=None,
698
+ highlight_words=False,
699
+ ):
700
+ if not filename:
701
+ filename = "task_subtitle"
702
+
703
+ is_ass_extension = False
704
+ if extension == "ass":
705
+ is_ass_extension = True
706
+ extension = "srt"
707
+
708
+ sub_file = filename + "." + extension
709
+ support_name = filename + ".mp3"
710
+ remove_files(sub_file)
711
+
712
+ writer = get_writer(extension, output_dir=".")
713
+ word_options = {
714
+ "highlight_words": highlight_words,
715
+ "max_line_count": None,
716
+ "max_line_width": None,
717
+ }
718
+
719
+ # Get data subs
720
+ subtitle_data = copy.deepcopy(segments_data)
721
+ subtitle_data["language"] = (
722
+ "ja" if language in ["ja", "zh", "zh-TW"] else language
723
+ )
724
+
725
+ # Clean
726
+ if not highlight_words:
727
+ subtitle_data.pop("word_segments", None)
728
+ for segment in subtitle_data["segments"]:
729
+ for key in ["speaker", "chars", "words"]:
730
+ segment.pop(key, None)
731
+
732
+ writer(
733
+ subtitle_data,
734
+ support_name,
735
+ word_options,
736
+ )
737
+
738
+ if is_ass_extension:
739
+ temp_name = filename + ".ass"
740
+ remove_files(temp_name)
741
+ convert_sub = f'ffmpeg -i "{sub_file}" "{temp_name}" -y'
742
+ run_command(convert_sub)
743
+ sub_file = temp_name
744
+
745
+ return sub_file
746
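+ # Illustrative sketch, assuming placeholder segment values: writing an srt
+ # file for a single translated segment.
+ def _demo_get_subtitle():
+     segments_data = {
+         "segments": [
+             {"text": "Hola mundo.", "start": 0.0, "end": 1.5, "speaker": "SPEAKER_00"}
+         ]
+     }
+     # Returns "task_subtitle.srt"; with extension="ass" the srt output is
+     # converted afterwards via ffmpeg.
+     return get_subtitle("es", segments_data, "srt")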
+
747
+
748
+ def process_subtitles(
749
+ deep_copied_result,
750
+ align_language,
751
+ result_diarize,
752
+ output_format_subtitle,
753
+ TRANSLATE_AUDIO_TO,
754
+ ):
755
+ name_ori = "sub_ori."
756
+ name_tra = "sub_tra."
757
+ remove_files(
758
+ [name_ori + output_format_subtitle, name_tra + output_format_subtitle]
759
+ )
760
+
761
+ writer = get_writer(output_format_subtitle, output_dir=".")
762
+ word_options = {
763
+ "highlight_words": False,
764
+ "max_line_count": None,
765
+ "max_line_width": None,
766
+ }
767
+
768
+ # original lang
769
+ subs_copy_result = copy.deepcopy(deep_copied_result)
770
+ subs_copy_result["language"] = (
771
+ "zh" if align_language == "zh-TW" else align_language
772
+ )
773
+ for segment in subs_copy_result["segments"]:
774
+ segment.pop("speaker", None)
775
+
776
+ try:
777
+ writer(
778
+ subs_copy_result,
779
+ name_ori[:-1] + ".mp3",
780
+ word_options,
781
+ )
782
+ except Exception as error:
783
+ logger.error(str(error))
784
+ if str(error) == "list indices must be integers or slices, not str":
785
+ logger.error(
786
+ "Related to poor word segmentation"
787
+ " in segments after alignment."
788
+ )
789
+ subs_copy_result["segments"][0].pop("words")
790
+ writer(
791
+ subs_copy_result,
792
+ name_ori[:-1] + ".mp3",
793
+ word_options,
794
+ )
795
+
796
+ # translated lang
797
+ subs_tra_copy_result = copy.deepcopy(result_diarize)
798
+ subs_tra_copy_result["language"] = (
799
+ "ja" if TRANSLATE_AUDIO_TO in ["ja", "zh", "zh-TW"] else align_language
800
+ )
801
+ subs_tra_copy_result.pop("word_segments", None)
802
+ for segment in subs_tra_copy_result["segments"]:
803
+ for key in ["speaker", "chars", "words"]:
804
+ segment.pop(key, None)
805
+
806
+ writer(
807
+ subs_tra_copy_result,
808
+ name_tra[:-1] + ".mp3",
809
+ word_options,
810
+ )
811
+
812
+ return name_tra + output_format_subtitle
813
+
814
+
815
+ def linguistic_level_segments(
816
+ result_base,
817
+ linguistic_unit="word", # word or char
818
+ ):
819
+ linguistic_unit = linguistic_unit[:4]
820
+ linguistic_unit_key = linguistic_unit + "s"
821
+ result = copy.deepcopy(result_base)
822
+
823
+ if linguistic_unit_key not in result["segments"][0].keys():
824
+ raise ValueError("No alignment detected, can't process")
825
+
826
+ segments_by_unit = []
827
+ for segment in result["segments"]:
828
+ segment_units = segment[linguistic_unit_key]
829
+ # segment_speaker = segment.get("speaker", "SPEAKER_00")
830
+
831
+ for unit in segment_units:
832
+
833
+ text = unit[linguistic_unit]
834
+
835
+ if "start" in unit.keys():
836
+ segments_by_unit.append(
837
+ {
838
+ "start": unit["start"],
839
+ "end": unit["end"],
840
+ "text": text,
841
+ # "speaker": segment_speaker,
842
+ }
843
+ )
844
+ elif not segments_by_unit:
845
+ pass
846
+ else:
847
+ segments_by_unit[-1]["text"] += text
848
+
849
+ return {"segments": segments_by_unit}
850
+
851
+
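A small illustration of how linguistic_level_segments flattens word-aligned segments (timings are invented):

aligned = {
    "segments": [
        {
            "text": "hello world",
            "words": [
                {"word": "hello", "start": 1.0, "end": 1.4},
                {"word": " world", "start": 1.5, "end": 1.9},
            ],
        }
    ]
}
word_level = linguistic_level_segments(aligned, linguistic_unit="word")
# -> {"segments": [{"start": 1.0, "end": 1.4, "text": "hello"},
#                  {"start": 1.5, "end": 1.9, "text": " world"}]}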
852
+ def break_aling_segments(
853
+ result: dict,
854
+ break_characters: str = "", # ":|,|.|"
855
+ ):
856
+ result_align = copy.deepcopy(result)
857
+
858
+ break_characters_list = break_characters.split("|")
859
+ break_characters_list = [i for i in break_characters_list if i != '']
860
+
861
+ if not break_characters_list:
862
+ logger.info("No valid break characters were specified.")
863
+ return result
864
+
865
+ logger.info(f"Redivide text segments by: {str(break_characters_list)}")
866
+
867
+ # create new with filters
868
+ normal = []
869
+
870
+ def process_chars(chars, letter_new_start, num, text):
871
+ start_key, end_key = "start", "end"
872
+ start_value = end_value = None
873
+
874
+ for char in chars:
875
+ if start_key in char:
876
+ start_value = char[start_key]
877
+ break
878
+
879
+ for char in reversed(chars):
880
+ if end_key in char:
881
+ end_value = char[end_key]
882
+ break
883
+
884
+ if start_value is None or end_value is None:  # 0.0 is a valid timestamp
885
+ raise Exception(
886
+ f"Unable to obtain a valid timestamp for chars: {str(chars)}"
887
+ )
888
+
889
+ return {
890
+ "start": start_value,
891
+ "end": end_value,
892
+ "text": text,
893
+ "words": chars,
894
+ }
895
+
896
+ for i, segment in enumerate(result_align['segments']):
897
+
898
+ logger.debug(f"- Process segment: {i}, text: {segment['text']}")
899
+ # start = segment['start']
900
+ letter_new_start = 0
901
+ for num, char in enumerate(segment['chars']):
902
+
903
+ if char["char"] is None:
904
+ continue
905
+
906
+ # if "start" in char:
907
+ # start = char["start"]
908
+
909
+ # if "end" in char:
910
+ # end = char["end"]
911
+
912
+ # Break by character
913
+ if char['char'] in break_characters_list:
914
+
915
+ text = segment['text'][letter_new_start:num+1]
916
+
917
+ logger.debug(
918
+ f"Break in: {char['char']}, position: {num}, text: {text}"
919
+ )
920
+
921
+ chars = segment['chars'][letter_new_start:num+1]
922
+
923
+ if not text:
924
+ logger.debug("No text")
925
+ continue
926
+
927
+ if num == 0 and not text.strip():
928
+ logger.debug("blank space in start")
929
+ continue
930
+
931
+ if len(text) == 1:
932
+ logger.debug(f"Short char append, num: {num}")
933
+ normal[-1]["text"] += text
934
+ normal[-1]["words"].append(chars)
935
+ continue
936
+
937
+ # logger.debug(chars)
938
+ normal_dict = process_chars(chars, letter_new_start, num, text)
939
+
940
+ letter_new_start = num+1
941
+
942
+ normal.append(normal_dict)
943
+
944
+ # If we reach the end of the segment, add the last part of chars.
945
+ if num == len(segment["chars"]) - 1:
946
+
947
+ text = segment['text'][letter_new_start:num+1]
948
+
949
+ # If the remaining text length is not the expected length, log it
950
+ if num not in [len(text)-1, len(text)] and text:
951
+ logger.debug(f'Remaining text: {text}')
952
+
953
+ if not text:
954
+ logger.debug("No remaining text.")
955
+ continue
956
+
957
+ if len(text) == 1:
958
+ logger.debug(f"Short char append, num: {num}")
959
+ normal[-1]["text"] += text
960
+ normal[-1]["words"].append(chars)
961
+ continue
962
+
963
+ chars = segment['chars'][letter_new_start:num+1]
964
+
965
+ normal_dict = process_chars(chars, letter_new_start, num, text)
966
+
967
+ letter_new_start = num+1
968
+
969
+ normal.append(normal_dict)
970
+
971
+ # Rename char to word
972
+ for item in normal:
973
+ words_list = item['words']
974
+ for word_item in words_list:
975
+ if 'char' in word_item:
976
+ word_item['word'] = word_item.pop('char')
977
+
978
+ # Convert to the default result dict format
979
+ break_segments = {"segments": normal}
980
+
981
+ msg_count = (
982
+ f"Segment count before: {len(result['segments'])}, "
983
+ f"after: {len(break_segments['segments'])}."
984
+ )
985
+ logger.info(msg_count)
986
+
987
+ return break_segments
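A rough, self-contained illustration of the redivision above; the character timings are invented and mirror the chars shape produced by the alignment step:

char_aligned = {
    "segments": [
        {
            "text": "Hi. Ok",
            "chars": [
                {"char": "H", "start": 10.0, "end": 10.1},
                {"char": "i", "start": 10.1, "end": 10.2},
                {"char": ".", "start": 10.2, "end": 10.3},
                {"char": " "},
                {"char": "O", "start": 10.4, "end": 10.5},
                {"char": "k", "start": 10.5, "end": 10.6},
            ],
        }
    ]
}
new_result = break_aling_segments(char_aligned, break_characters=".")
# The single segment is split into two timed pieces: "Hi." and " Ok".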
quantum_dubbing/text_to_speech.py ADDED
@@ -0,0 +1,1574 @@
1
+ from gtts import gTTS
2
+ import edge_tts, asyncio, json, glob # noqa
3
+ from tqdm import tqdm
4
+ import librosa, os, re, torch, gc, subprocess # noqa
5
+ from .language_configuration import (
6
+ fix_code_language,
7
+ BARK_VOICES_LIST,
8
+ VITS_VOICES_LIST,
9
+ )
10
+ from .utils import (
11
+ download_manager,
12
+ create_directories,
13
+ copy_files,
14
+ rename_file,
15
+ remove_directory_contents,
16
+ remove_files,
17
+ run_command,
18
+ )
19
+ import numpy as np
20
+ from typing import Any, Dict
21
+ from pathlib import Path
22
+ import soundfile as sf
23
+ import platform
24
+ import logging
25
+ import traceback
26
+ from .logging_setup import logger
27
+
28
+
29
+ class TTS_OperationError(Exception):
30
+ def __init__(self, message="The operation did not complete successfully."):
31
+ self.message = message
32
+ super().__init__(self.message)
33
+
34
+
35
+ def verify_saved_file_and_size(filename):
36
+ if not os.path.exists(filename):
37
+ raise TTS_OperationError(f"File '{filename}' was not saved.")
38
+ if os.path.getsize(filename) == 0:
39
+ raise TTS_OperationError(
40
+ f"File '{filename}' has a zero size. "
41
+ "Related to incorrect TTS for the target language"
42
+ )
43
+
44
+
45
+ def error_handling_in_tts(error, segment, TRANSLATE_AUDIO_TO, filename):
46
+ traceback.print_exc()
47
+ logger.error(f"Error: {str(error)}")
48
+ try:
49
+ from tempfile import TemporaryFile
50
+
51
+ tts = gTTS(segment["text"], lang=fix_code_language(TRANSLATE_AUDIO_TO))
52
+ # tts.save(filename)
53
+ f = TemporaryFile()
54
+ tts.write_to_fp(f)
55
+
56
+ # Reset the file pointer to the beginning of the file
57
+ f.seek(0)
58
+
59
+ # Read audio data from the TemporaryFile using soundfile
60
+ audio_data, samplerate = sf.read(f)
61
+ f.close() # Close the TemporaryFile
62
+ sf.write(
63
+ filename, audio_data, samplerate, format="ogg", subtype="vorbis"
64
+ )
65
+
66
+ logger.warning(
67
+ 'The auxiliary TTS (gTTS) will be used '
68
+ f'instead of the requested TTS: {segment["tts_name"]}'
69
+ )
70
+ verify_saved_file_and_size(filename)
71
+ except Exception as error:
72
+ logger.critical(f"Error: {str(error)}")
73
+ sample_rate_aux = 22050
74
+ duration = float(segment["end"]) - float(segment["start"])
75
+ data = np.zeros(int(sample_rate_aux * duration)).astype(np.float32)
76
+ sf.write(
77
+ filename, data, sample_rate_aux, format="ogg", subtype="vorbis"
78
+ )
79
+ logger.error("Audio will be replaced -> [silent audio].")
80
+ verify_saved_file_and_size(filename)
81
+
82
+
83
+ def pad_array(array, sr):
84
+
85
+ if isinstance(array, list):
86
+ array = np.array(array)
87
+
88
+ if not array.shape[0]:
89
+ raise ValueError("The generated audio does not contain any data")
90
+
91
+ valid_indices = np.where(np.abs(array) > 0.001)[0]
92
+
93
+ if len(valid_indices) == 0:
94
+ logger.debug(f"No valid indices: {array}")
95
+ return array
96
+
97
+ try:
98
+ pad_indice = int(0.1 * sr)
99
+ start_pad = max(0, valid_indices[0] - pad_indice)
100
+ end_pad = min(len(array), valid_indices[-1] + 1 + pad_indice)
101
+ padded_array = array[start_pad:end_pad]
102
+ return padded_array
103
+ except Exception as error:
104
+ logger.error(str(error))
105
+ return array
106
+
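A quick, self-contained check of the trimming behaviour of pad_array (synthetic signal, not real TTS output):

import numpy as np

sr = 16000
silence = np.zeros(sr, dtype=np.float32)       # 1 s of silence
voiced = 0.5 * np.ones(sr, dtype=np.float32)   # 1 s above the 0.001 threshold
audio = np.concatenate([silence, voiced, silence])
trimmed = pad_array(audio, sr)
print(len(trimmed) / sr)  # 1.2: the voiced second plus ~0.1 s of padding on each side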
107
+
108
+ # =====================================
109
+ # EDGE TTS
110
+ # =====================================
111
+
112
+
113
+ def edge_tts_voices_list():
114
+ try:
115
+ completed_process = subprocess.run(
116
+ ["edge-tts", "--list-voices"], capture_output=True, text=True
117
+ )
118
+ lines = completed_process.stdout.strip().split("\n")
119
+ except Exception as error:
120
+ logger.debug(str(error))
121
+ lines = []
122
+
123
+ voices = []
124
+ for line in lines:
125
+ if line.startswith("Name: "):
126
+ voice_entry = {}
127
+ voice_entry["Name"] = line.split(": ")[1]
128
+ elif line.startswith("Gender: "):
129
+ voice_entry["Gender"] = line.split(": ")[1]
130
+ voices.append(voice_entry)
131
+
132
+ formatted_voices = [
133
+ f"{entry['Name']}-{entry['Gender']}" for entry in voices
134
+ ]
135
+
136
+ if not formatted_voices:
137
+ logger.warning(
138
+ "The list of Edge TTS voices could not be obtained, "
139
+ "switching to an alternative method"
140
+ )
141
+ tts_voice_list = asyncio.new_event_loop().run_until_complete(
142
+ edge_tts.list_voices()
143
+ )
144
+ formatted_voices = sorted(
145
+ [f"{v['ShortName']}-{v['Gender']}" for v in tts_voice_list]
146
+ )
147
+
148
+ if not formatted_voices:
149
+ logger.error("Can't get EDGE TTS - list voices")
150
+
151
+ return formatted_voices
152
+
153
+
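A usage sketch, assuming the edge-tts package (or its CLI) is installed and reachable:

voices = edge_tts_voices_list()  # e.g. ['en-GB-SoniaNeural-Female', ...]
en_female = [v for v in voices if v.startswith("en-") and v.endswith("-Female")]
print(en_female[:3])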
154
+ def segments_egde_tts(filtered_edge_segments, TRANSLATE_AUDIO_TO, is_gui):
155
+ for segment in tqdm(filtered_edge_segments["segments"]):
156
+ speaker = segment["speaker"] # noqa
157
+ text = segment["text"]
158
+ start = segment["start"]
159
+ tts_name = segment["tts_name"]
160
+
161
+ # make the tts audio
162
+ filename = f"audio/{start}.ogg"
163
+ temp_file = filename[:-3] + "mp3"
164
+
165
+ logger.info(f"{text} >> {filename}")
166
+ try:
167
+ if is_gui:
168
+ asyncio.run(
169
+ edge_tts.Communicate(
170
+ text, "-".join(tts_name.split("-")[:-1])
171
+ ).save(temp_file)
172
+ )
173
+ else:
174
+ # nest_asyncio.apply() if not is_gui else None
175
+ command = f'edge-tts -t "{text}" -v "{tts_name.replace("-Male", "").replace("-Female", "")}" --write-media "{temp_file}"'
176
+ run_command(command)
177
+ verify_saved_file_and_size(temp_file)
178
+
179
+ data, sample_rate = sf.read(temp_file)
180
+ data = pad_array(data, sample_rate)
181
+ # os.remove(temp_file)
182
+
183
+ # Save file
184
+ sf.write(
185
+ file=filename,
186
+ samplerate=sample_rate,
187
+ data=data,
188
+ format="ogg",
189
+ subtype="vorbis",
190
+ )
191
+ verify_saved_file_and_size(filename)
192
+
193
+ except Exception as error:
194
+ error_handling_in_tts(error, segment, TRANSLATE_AUDIO_TO, filename)
195
+
196
+
197
+ # =====================================
198
+ # BARK TTS
199
+ # =====================================
200
+
201
+
202
+ def segments_bark_tts(
203
+ filtered_bark_segments, TRANSLATE_AUDIO_TO, model_id_bark="suno/bark-small"
204
+ ):
205
+ from transformers import AutoProcessor, BarkModel
206
+ from optimum.bettertransformer import BetterTransformer
207
+
208
+ device = os.environ.get("QUANTUM_DEVICE")
209
+ torch_dtype_env = torch.float16 if device == "cuda" else torch.float32
210
+
211
+ # load model bark
212
+ model = BarkModel.from_pretrained(
213
+ model_id_bark, torch_dtype=torch_dtype_env
214
+ ).to(device)
215
+ model = model.to(device)
216
+ processor = AutoProcessor.from_pretrained(
217
+ model_id_bark, return_tensors="pt"
218
+ ) # , padding=True
219
+ if device == "cuda":
220
+ # convert to bettertransformer
221
+ model = BetterTransformer.transform(model, keep_original_model=False)
222
+ # enable CPU offload
223
+ # model.enable_cpu_offload()
224
+ sampling_rate = model.generation_config.sample_rate
225
+
226
+ # filtered_segments = filtered_bark_segments['segments']
227
+ # Sorting the segments by 'tts_name'
228
+ # sorted_segments = sorted(filtered_segments, key=lambda x: x['tts_name'])
229
+ # logger.debug(sorted_segments)
230
+
231
+ for segment in tqdm(filtered_bark_segments["segments"]):
232
+ speaker = segment["speaker"] # noqa
233
+ text = segment["text"]
234
+ start = segment["start"]
235
+ tts_name = segment["tts_name"]
236
+
237
+ inputs = processor(text, voice_preset=BARK_VOICES_LIST[tts_name]).to(
238
+ device
239
+ )
240
+
241
+ # make the tts audio
242
+ filename = f"audio/{start}.ogg"
243
+ logger.info(f"{text} >> {filename}")
244
+ try:
245
+ # Infer
246
+ with torch.inference_mode():
247
+ speech_output = model.generate(
248
+ **inputs,
249
+ do_sample=True,
250
+ fine_temperature=0.4,
251
+ coarse_temperature=0.8,
252
+ pad_token_id=processor.tokenizer.pad_token_id,
253
+ )
254
+ # Save file
255
+ data_tts = pad_array(
256
+ speech_output.cpu().numpy().squeeze().astype(np.float32),
257
+ sampling_rate,
258
+ )
259
+ sf.write(
260
+ file=filename,
261
+ samplerate=sampling_rate,
262
+ data=data_tts,
263
+ format="ogg",
264
+ subtype="vorbis",
265
+ )
266
+ verify_saved_file_and_size(filename)
267
+ except Exception as error:
268
+ error_handling_in_tts(error, segment, TRANSLATE_AUDIO_TO, filename)
269
+ gc.collect()
270
+ torch.cuda.empty_cache()
271
+ try:
272
+ del processor
273
+ del model
274
+ gc.collect()
275
+ torch.cuda.empty_cache()
276
+ except Exception as error:
277
+ logger.error(str(error))
278
+ gc.collect()
279
+ torch.cuda.empty_cache()
280
+
281
+
282
+ # =====================================
283
+ # VITS TTS
284
+ # =====================================
285
+
286
+
287
+ def uromanize(input_string):
288
+ """Convert non-Roman strings to Roman using the `uroman` perl package."""
289
+ # script_path = os.path.join(uroman_path, "bin", "uroman.pl")
290
+
291
+ if not os.path.exists("./uroman"):
292
+ logger.info(
293
+ "Clonning repository uroman https://github.com/isi-nlp/uroman.git"
294
+ " for romanize the text"
295
+ )
296
+ process = subprocess.Popen(
297
+ ["git", "clone", "https://github.com/isi-nlp/uroman.git"],
298
+ stdout=subprocess.PIPE,
299
+ stderr=subprocess.PIPE,
300
+ )
301
+ stdout, stderr = process.communicate()
302
+ script_path = os.path.join("./uroman", "uroman", "uroman.pl")
303
+ 
+ # script_path may be undefined when ./uroman already exists, so set it here as well
+ script_path = os.path.join("./uroman", "uroman", "uroman.pl")
+ command = ["perl", script_path]
305
+
306
+ process = subprocess.Popen(
307
+ command,
308
+ stdin=subprocess.PIPE,
309
+ stdout=subprocess.PIPE,
310
+ stderr=subprocess.PIPE,
311
+ )
312
+ # Execute the perl command
313
+ stdout, stderr = process.communicate(input=input_string.encode())
314
+
315
+ if process.returncode != 0:
316
+ raise ValueError(f"Error {process.returncode}: {stderr.decode()}")
317
+
318
+ # Return the output as a string and skip the new-line character at the end
319
+ return stdout.decode()[:-1]
320
+
321
+
322
+ def segments_vits_tts(filtered_vits_segments, TRANSLATE_AUDIO_TO):
323
+ from transformers import VitsModel, AutoTokenizer
324
+
325
+ filtered_segments = filtered_vits_segments["segments"]
326
+ # Sorting the segments by 'tts_name'
327
+ sorted_segments = sorted(filtered_segments, key=lambda x: x["tts_name"])
328
+ logger.debug(sorted_segments)
329
+
330
+ model_name_key = None
331
+ for segment in tqdm(sorted_segments):
332
+ speaker = segment["speaker"] # noqa
333
+ text = segment["text"]
334
+ start = segment["start"]
335
+ tts_name = segment["tts_name"]
336
+
337
+ if tts_name != model_name_key:
338
+ model_name_key = tts_name
339
+ model = VitsModel.from_pretrained(VITS_VOICES_LIST[tts_name])
340
+ tokenizer = AutoTokenizer.from_pretrained(
341
+ VITS_VOICES_LIST[tts_name]
342
+ )
343
+ sampling_rate = model.config.sampling_rate
344
+
345
+ if tokenizer.is_uroman:
346
+ romanize_text = uromanize(text)
347
+ logger.debug(f"Romanize text: {romanize_text}")
348
+ inputs = tokenizer(romanize_text, return_tensors="pt")
349
+ else:
350
+ inputs = tokenizer(text, return_tensors="pt")
351
+
352
+ # make the tts audio
353
+ filename = f"audio/{start}.ogg"
354
+ logger.info(f"{text} >> {filename}")
355
+ try:
356
+ # Infer
357
+ with torch.no_grad():
358
+ speech_output = model(**inputs).waveform
359
+
360
+ data_tts = pad_array(
361
+ speech_output.cpu().numpy().squeeze().astype(np.float32),
362
+ sampling_rate,
363
+ )
364
+ # Save file
365
+ sf.write(
366
+ file=filename,
367
+ samplerate=sampling_rate,
368
+ data=data_tts,
369
+ format="ogg",
370
+ subtype="vorbis",
371
+ )
372
+ verify_saved_file_and_size(filename)
373
+ except Exception as error:
374
+ error_handling_in_tts(error, segment, TRANSLATE_AUDIO_TO, filename)
375
+ gc.collect()
376
+ torch.cuda.empty_cache()
377
+ try:
378
+ del tokenizer
379
+ del model
380
+ gc.collect()
381
+ torch.cuda.empty_cache()
382
+ except Exception as error:
383
+ logger.error(str(error))
384
+ gc.collect()
385
+ torch.cuda.empty_cache()
386
+
387
+
388
+ # =====================================
389
+ # Coqui XTTS
390
+ # =====================================
391
+
392
+
393
+ def coqui_xtts_voices_list():
394
+ main_folder = "_XTTS_"
395
+ pattern_coqui = re.compile(r".+\.(wav|mp3|ogg|m4a)$")
396
+ pattern_automatic_speaker = re.compile(r"AUTOMATIC_SPEAKER_\d+\.wav$")
397
+
398
+ # List only files in the directory matching the pattern but not matching
399
+ # AUTOMATIC_SPEAKER_00.wav, AUTOMATIC_SPEAKER_01.wav, etc.
400
+ wav_voices = [
401
+ "_XTTS_/" + f
402
+ for f in os.listdir(main_folder)
403
+ if os.path.isfile(os.path.join(main_folder, f))
404
+ and pattern_coqui.match(f)
405
+ and not pattern_automatic_speaker.match(f)
406
+ ]
407
+
408
+ return ["_XTTS_/AUTOMATIC.wav"] + wav_voices
409
+
410
+
411
+ def seconds_to_hhmmss_ms(seconds):
412
+ hours = seconds // 3600
413
+ minutes = (seconds % 3600) // 60
414
+ seconds = seconds % 60
415
+ milliseconds = int((seconds - int(seconds)) * 1000)
416
+ return "%02d:%02d:%02d.%03d" % (hours, minutes, int(seconds), milliseconds)
417
+
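For example:

print(seconds_to_hhmmss_ms(3723.5))  # -> "01:02:03.500"
print(seconds_to_hhmmss_ms(0.25))    # -> "00:00:00.250"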
418
+
419
+ def audio_trimming(audio_path, destination, start, end):
420
+ if isinstance(start, (int, float)):
421
+ start = seconds_to_hhmmss_ms(start)
422
+ if isinstance(end, (int, float)):
423
+ end = seconds_to_hhmmss_ms(end)
424
+
425
+ if destination:
426
+ file_directory = destination
427
+ else:
428
+ file_directory = os.path.dirname(audio_path)
429
+
430
+ file_name = os.path.splitext(os.path.basename(audio_path))[0]
431
+ file_ = f"{file_name}_trim.wav"
432
+ # file_ = f'{os.path.splitext(audio_path)[0]}_trim.wav'
433
+ output_path = os.path.join(file_directory, file_)
434
+
435
+ # -t (duration from -ss) | -to (time stop) | -af silenceremove=1:0:-50dB (remove silence)
436
+ command = f'ffmpeg -y -loglevel error -i "{audio_path}" -ss {start} -to {end} -acodec pcm_s16le -f wav "{output_path}"'
437
+ run_command(command)
438
+
439
+ return output_path
440
+
441
+
442
+ def convert_to_xtts_good_sample(audio_path: str = "", destination: str = ""):
443
+ if destination:
444
+ file_directory = destination
445
+ else:
446
+ file_directory = os.path.dirname(audio_path)
447
+
448
+ file_name = os.path.splitext(os.path.basename(audio_path))[0]
449
+ file_ = f"{file_name}_good_sample.wav"
450
+ # file_ = f'{os.path.splitext(audio_path)[0]}_good_sample.wav'
451
+ mono_path = os.path.join(file_directory, file_) # get root
452
+
453
+ command = f'ffmpeg -y -loglevel error -i "{audio_path}" -ac 1 -ar 22050 -sample_fmt s16 -f wav "{mono_path}"'
454
+ run_command(command)
455
+
456
+ return mono_path
457
+
458
+
459
+ def sanitize_file_name(file_name):
460
+ import unicodedata
461
+
462
+ # Normalize the string to NFKD form to separate combined characters into
463
+ # base characters and diacritics
464
+ normalized_name = unicodedata.normalize("NFKD", file_name)
465
+ # Replace any non-ASCII characters or special symbols with an underscore
466
+ sanitized_name = re.sub(r"[^\w\s.-]", "_", normalized_name)
467
+ return sanitized_name
468
+
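For example (an ASCII-only case; accented characters are first decomposed and their diacritics replaced with underscores):

print(sanitize_file_name("speaker#1 (demo).wav"))  # -> "speaker_1 _demo_.wav"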
469
+
470
+ def create_wav_file_vc(
471
+ sample_name="", # name final file
472
+ audio_wav="", # path
473
+ start=None, # trim start
474
+ end=None, # trim end
475
+ output_final_path="_XTTS_",
476
+ get_vocals_dereverb=True,
477
+ ):
478
+ sample_name = sample_name if sample_name else "default_name"
479
+ sample_name = sanitize_file_name(sample_name)
480
+ audio_wav = audio_wav if isinstance(audio_wav, str) else audio_wav.name
481
+
482
+ BASE_DIR = (
483
+ "." # os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
484
+ )
485
+
486
+ output_dir = os.path.join(BASE_DIR, "clean_song_output") # remove content
487
+ # remove_directory_contents(output_dir)
488
+
489
+ if start or end:
490
+ # Cut file
491
+ audio_segment = audio_trimming(audio_wav, output_dir, start, end)
492
+ else:
493
+ # Complete file
494
+ audio_segment = audio_wav
495
+
496
+ from .mdx_net import process_uvr_task
497
+
498
+ try:
499
+ _, _, _, _, audio_segment = process_uvr_task(
500
+ orig_song_path=audio_segment,
501
+ main_vocals=True,
502
+ dereverb=get_vocals_dereverb,
503
+ )
504
+ except Exception as error:
505
+ logger.error(str(error))
506
+
507
+ sample = convert_to_xtts_good_sample(audio_segment)
508
+
509
+ sample_name = f"{sample_name}.wav"
510
+ sample_rename = rename_file(sample, sample_name)
511
+
512
+ copy_files(sample_rename, output_final_path)
513
+
514
+ final_sample = os.path.join(output_final_path, sample_name)
515
+ if os.path.exists(final_sample):
516
+ logger.info(final_sample)
517
+ return final_sample
518
+ else:
519
+ raise Exception(f"Error wav: {final_sample}")
520
+
521
+
522
+ def create_new_files_for_vc(
523
+ speakers_coqui,
524
+ segments_base,
525
+ dereverb_automatic=True
526
+ ):
527
+ # Before calling this, previous automatic samples are expected to be removed (delete_previous_automatic)
528
+ output_dir = os.path.join(".", "clean_song_output") # remove content
529
+ remove_directory_contents(output_dir)
530
+
531
+ for speaker in speakers_coqui:
532
+ filtered_speaker = [
533
+ segment
534
+ for segment in segments_base
535
+ if segment["speaker"] == speaker
536
+ ]
537
+ if len(filtered_speaker) > 4:
538
+ filtered_speaker = filtered_speaker[1:]
539
+ if filtered_speaker[0]["tts_name"] == "_XTTS_/AUTOMATIC.wav":
540
+ name_automatic_wav = f"AUTOMATIC_{speaker}"
541
+ if os.path.exists(f"_XTTS_/{name_automatic_wav}.wav"):
542
+ logger.info(f"WAV automatic {speaker} exists")
543
+ # path_wav = path_automatic_wav
544
+ pass
545
+ else:
546
+ # create wav
547
+ wav_ok = False
548
+ for seg in filtered_speaker:
549
+ duration = float(seg["end"]) - float(seg["start"])
550
+ if duration > 7.0 and duration < 12.0:
551
+ logger.info(
552
+ f'Processing segment: {seg["start"]}, {seg["end"]}, {seg["speaker"]}, {duration}, {seg["text"]}'
553
+ )
554
+ create_wav_file_vc(
555
+ sample_name=name_automatic_wav,
556
+ audio_wav="audio.wav",
557
+ start=(float(seg["start"]) + 1.0),
558
+ end=(float(seg["end"]) - 1.0),
559
+ get_vocals_dereverb=dereverb_automatic,
560
+ )
561
+ wav_ok = True
562
+ break
563
+
564
+ if not wav_ok:
565
+ logger.info("Taking the first segment")
566
+ seg = filtered_speaker[0]
567
+ logger.info(
568
+ f'Processing segment: {seg["start"]}, {seg["end"]}, {seg["speaker"]}, {seg["text"]}'
569
+ )
570
+ max_duration = float(seg["end"]) - float(seg["start"])
571
+ max_duration = max(2.0, min(max_duration, 9.0))
572
+
573
+ create_wav_file_vc(
574
+ sample_name=name_automatic_wav,
575
+ audio_wav="audio.wav",
576
+ start=(float(seg["start"])),
577
+ end=(float(seg["start"]) + max_duration),
578
+ get_vocals_dereverb=dereverb_automatic,
579
+ )
580
+
581
+
582
+ def segments_coqui_tts(
583
+ filtered_coqui_segments,
584
+ TRANSLATE_AUDIO_TO,
585
+ model_id_coqui="tts_models/multilingual/multi-dataset/xtts_v2",
586
+ speakers_coqui=None,
587
+ delete_previous_automatic=True,
588
+ dereverb_automatic=True,
589
+ emotion=None,
590
+ ):
591
+ """XTTS
592
+ Install:
593
+ pip install -q TTS==0.21.1
594
+ pip install -q numpy==1.23.5
595
+
596
+ Notes:
597
+ - tts_name is the wav|mp3|ogg|m4a file for VC
598
+ """
599
+ from TTS.api import TTS
600
+
601
+ TRANSLATE_AUDIO_TO = fix_code_language(TRANSLATE_AUDIO_TO, syntax="coqui")
602
+ supported_lang_coqui = [
603
+ "zh-cn",
604
+ "en",
605
+ "fr",
606
+ "de",
607
+ "it",
608
+ "pt",
609
+ "pl",
610
+ "tr",
611
+ "ru",
612
+ "nl",
613
+ "cs",
614
+ "ar",
615
+ "es",
616
+ "hu",
617
+ "ko",
618
+ "ja",
619
+ ]
620
+ if TRANSLATE_AUDIO_TO not in supported_lang_coqui:
621
+ raise TTS_OperationError(
622
+ f"'{TRANSLATE_AUDIO_TO}' is not a supported language for Coqui XTTS"
623
+ )
624
+ # Emotion and speed could only be used with Coqui Studio models, which have been discontinued
625
+ # emotions = ["Neutral", "Happy", "Sad", "Angry", "Dull"]
626
+
627
+ if delete_previous_automatic:
628
+ for spk in speakers_coqui:
629
+ remove_files(f"_XTTS_/AUTOMATIC_{spk}.wav")
630
+
631
+ directory_audios_vc = "_XTTS_"
632
+ create_directories(directory_audios_vc)
633
+ create_new_files_for_vc(
634
+ speakers_coqui,
635
+ filtered_coqui_segments["segments"],
636
+ dereverb_automatic,
637
+ )
638
+
639
+ # Init TTS
640
+ device = os.environ.get("QUANTUM_DEVICE")
641
+ model = TTS(model_id_coqui).to(device)
642
+ sampling_rate = 24000
643
+
644
+ # filtered_segments = filtered_coqui_segments['segments']
645
+ # Sorting the segments by 'tts_name'
646
+ # sorted_segments = sorted(filtered_segments, key=lambda x: x['tts_name'])
647
+ # logger.debug(sorted_segments)
648
+
649
+ for segment in tqdm(filtered_coqui_segments["segments"]):
650
+ speaker = segment["speaker"]
651
+ text = segment["text"]
652
+ start = segment["start"]
653
+ tts_name = segment["tts_name"]
654
+ if tts_name == "_XTTS_/AUTOMATIC.wav":
655
+ tts_name = f"_XTTS_/AUTOMATIC_{speaker}.wav"
656
+
657
+ # make the tts audio
658
+ filename = f"audio/{start}.ogg"
659
+ logger.info(f"{text} >> {filename}")
660
+ try:
661
+ # Infer
662
+ wav = model.tts(
663
+ text=text, speaker_wav=tts_name, language=TRANSLATE_AUDIO_TO
664
+ )
665
+ data_tts = pad_array(
666
+ wav,
667
+ sampling_rate,
668
+ )
669
+ # Save file
670
+ sf.write(
671
+ file=filename,
672
+ samplerate=sampling_rate,
673
+ data=data_tts,
674
+ format="ogg",
675
+ subtype="vorbis",
676
+ )
677
+ verify_saved_file_and_size(filename)
678
+ except Exception as error:
679
+ error_handling_in_tts(error, segment, TRANSLATE_AUDIO_TO, filename)
680
+ gc.collect()
681
+ torch.cuda.empty_cache()
682
+ try:
683
+ del model
684
+ gc.collect()
685
+ torch.cuda.empty_cache()
686
+ except Exception as error:
687
+ logger.error(str(error))
688
+ gc.collect()
689
+ torch.cuda.empty_cache()
690
+
691
+
692
+ # =====================================
693
+ # PIPER TTS
694
+ # =====================================
695
+
696
+
697
+ def piper_tts_voices_list():
698
+ file_path = download_manager(
699
+ url="https://huggingface.co/rhasspy/piper-voices/resolve/main/voices.json",
700
+ path="./PIPER_MODELS",
701
+ )
702
+
703
+ with open(file_path, "r", encoding="utf8") as file:
704
+ data = json.load(file)
705
+ piper_id_models = [key + " VITS-onnx" for key in data.keys()]
706
+
707
+ return piper_id_models
708
+
709
+
710
+ def replace_text_in_json(file_path, key_to_replace, new_text, condition=None):
711
+ # Read the JSON file
712
+ with open(file_path, "r", encoding="utf-8") as file:
713
+ data = json.load(file)
714
+
715
+ # Modify the specified key's value with the new text
716
+ if key_to_replace in data:
717
+ if condition:
718
+ value_condition = condition
719
+ else:
720
+ value_condition = data[key_to_replace]
721
+
722
+ if data[key_to_replace] == value_condition:
723
+ data[key_to_replace] = new_text
724
+
725
+ # Write the modified content back to the JSON file
726
+ with open(file_path, "w") as file:
727
+ json.dump(
728
+ data, file, indent=2
729
+ ) # Write the modified data back to the file with indentation for readability
730
+
731
+
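A small self-contained illustration of replace_text_in_json; the file path and values are made up:

import json, os, tempfile

cfg_path = os.path.join(tempfile.gettempdir(), "piper_cfg_example.json")
with open(cfg_path, "w", encoding="utf-8") as f:
    json.dump({"phoneme_type": "PhonemeType.ESPEAK", "sample_rate": 22050}, f)

# Rewrites the key only when its current value matches the condition.
replace_text_in_json(cfg_path, "phoneme_type", "espeak", condition="PhonemeType.ESPEAK")

with open(cfg_path, encoding="utf-8") as f:
    print(json.load(f)["phoneme_type"])  # -> espeak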
732
+ def load_piper_model(
733
+ model: str,
734
+ data_dir: list,
735
+ download_dir: str = "",
736
+ update_voices: bool = False,
737
+ ):
738
+ from piper import PiperVoice
739
+ from piper.download import ensure_voice_exists, find_voice, get_voices
740
+
741
+ try:
742
+ import onnxruntime as rt
743
+
744
+ if rt.get_device() == "GPU" and os.environ.get("QUANTUM_DEVICE") == "cuda":
745
+ logger.debug("onnxruntime device > GPU")
746
+ cuda = True
747
+ else:
748
+ logger.info(
749
+ "onnxruntime device > CPU"
750
+ ) # try pip install onnxruntime-gpu
751
+ cuda = False
752
+ except Exception as error:
753
+ raise TTS_OperationError(f"onnxruntime error: {str(error)}")
754
+
755
+ # Disable CUDA in Windows
756
+ if platform.system() == "Windows":
757
+ logger.info("Employing CPU exclusivity with Piper TTS")
758
+ cuda = False
759
+
760
+ if not download_dir:
761
+ # Download to first data directory by default
762
+ download_dir = data_dir[0]
763
+ else:
764
+ data_dir = [os.path.join(data_dir[0], download_dir)]
765
+
766
+ # Download voice if file doesn't exist
767
+ model_path = Path(model)
768
+ if not model_path.exists():
769
+ # Load voice info
770
+ voices_info = get_voices(download_dir, update_voices=update_voices)
771
+
772
+ # Resolve aliases for backwards compatibility with old voice names
773
+ aliases_info: Dict[str, Any] = {}
774
+ for voice_info in voices_info.values():
775
+ for voice_alias in voice_info.get("aliases", []):
776
+ aliases_info[voice_alias] = {"_is_alias": True, **voice_info}
777
+
778
+ voices_info.update(aliases_info)
779
+ ensure_voice_exists(model, data_dir, download_dir, voices_info)
780
+ model, config = find_voice(model, data_dir)
781
+
782
+ replace_text_in_json(
783
+ config, "phoneme_type", "espeak", "PhonemeType.ESPEAK"
784
+ )
785
+
786
+ # Load voice
787
+ voice = PiperVoice.load(model, config_path=config, use_cuda=cuda)
788
+
789
+ return voice
790
+
791
+
792
+ def synthesize_text_to_audio_np_array(voice, text, synthesize_args):
793
+ audio_stream = voice.synthesize_stream_raw(text, **synthesize_args)
794
+
795
+ # Collect the audio bytes into a single NumPy array
796
+ audio_data = b""
797
+ for audio_bytes in audio_stream:
798
+ audio_data += audio_bytes
799
+
800
+ # Ensure correct data type and convert audio bytes to NumPy array
801
+ audio_np = np.frombuffer(audio_data, dtype=np.int16)
802
+ return audio_np
803
+
804
+
805
+ def segments_vits_onnx_tts(filtered_onnx_vits_segments, TRANSLATE_AUDIO_TO):
806
+ """
807
+ Install:
808
+ pip install -q piper-tts==1.2.0 onnxruntime-gpu # for cuda118
809
+ """
810
+
811
+ data_dir = [
812
+ str(Path.cwd())
813
+ ] # "Data directory to check for downloaded models (default: current directory)"
814
+ download_dir = "PIPER_MODELS"
815
+ # model_name = "en_US-lessac-medium" tts_name in a dict like VITS
816
+ update_voices = True # "Download latest voices.json during startup",
817
+
818
+ synthesize_args = {
819
+ "speaker_id": None,
820
+ "length_scale": 1.0,
821
+ "noise_scale": 0.667,
822
+ "noise_w": 0.8,
823
+ "sentence_silence": 0.0,
824
+ }
825
+
826
+ filtered_segments = filtered_onnx_vits_segments["segments"]
827
+ # Sorting the segments by 'tts_name'
828
+ sorted_segments = sorted(filtered_segments, key=lambda x: x["tts_name"])
829
+ logger.debug(sorted_segments)
830
+
831
+ model_name_key = None
832
+ for segment in tqdm(sorted_segments):
833
+ speaker = segment["speaker"] # noqa
834
+ text = segment["text"]
835
+ start = segment["start"]
836
+ tts_name = segment["tts_name"].replace(" VITS-onnx", "")
837
+
838
+ if tts_name != model_name_key:
839
+ model_name_key = tts_name
840
+ model = load_piper_model(
841
+ tts_name, data_dir, download_dir, update_voices
842
+ )
843
+ sampling_rate = model.config.sample_rate
844
+
845
+ # make the tts audio
846
+ filename = f"audio/{start}.ogg"
847
+ logger.info(f"{text} >> {filename}")
848
+ try:
849
+ # Infer
850
+ speech_output = synthesize_text_to_audio_np_array(
851
+ model, text, synthesize_args
852
+ )
853
+ data_tts = pad_array(
854
+ speech_output, # .cpu().numpy().squeeze().astype(np.float32),
855
+ sampling_rate,
856
+ )
857
+ # Save file
858
+ sf.write(
859
+ file=filename,
860
+ samplerate=sampling_rate,
861
+ data=data_tts,
862
+ format="ogg",
863
+ subtype="vorbis",
864
+ )
865
+ verify_saved_file_and_size(filename)
866
+ except Exception as error:
867
+ error_handling_in_tts(error, segment, TRANSLATE_AUDIO_TO, filename)
868
+ gc.collect()
869
+ torch.cuda.empty_cache()
870
+ try:
871
+ del model
872
+ gc.collect()
873
+ torch.cuda.empty_cache()
874
+ except Exception as error:
875
+ logger.error(str(error))
876
+ gc.collect()
877
+ torch.cuda.empty_cache()
878
+
879
+
880
+ # =====================================
881
+ # OPENAI TTS
882
+ # =====================================
883
+
884
+
885
+ def segments_openai_tts(
886
+ filtered_openai_tts_segments, TRANSLATE_AUDIO_TO
887
+ ):
888
+ from openai import OpenAI
889
+
890
+ client = OpenAI()
891
+ sampling_rate = 24000
892
+
893
+ # filtered_segments = filtered_openai_tts_segments['segments']
894
+ # Sorting the segments by 'tts_name'
895
+ # sorted_segments = sorted(filtered_segments, key=lambda x: x['tts_name'])
896
+
897
+ for segment in tqdm(filtered_openai_tts_segments["segments"]):
898
+ speaker = segment["speaker"] # noqa
899
+ text = segment["text"].strip()
900
+ start = segment["start"]
901
+ tts_name = segment["tts_name"]
902
+
903
+ # make the tts audio
904
+ filename = f"audio/{start}.ogg"
905
+ logger.info(f"{text} >> {filename}")
906
+
907
+ try:
908
+ # Request
909
+ response = client.audio.speech.create(
910
+ model="tts-1-hd" if "HD" in tts_name else "tts-1",
911
+ voice=tts_name.split()[0][1:],
912
+ response_format="wav",
913
+ input=text
914
+ )
915
+
916
+ audio_bytes = b''
917
+ for data in response.iter_bytes(chunk_size=4096):
918
+ audio_bytes += data
919
+
920
+ speech_output = np.frombuffer(audio_bytes, dtype=np.int16)
921
+
922
+ # Save file
923
+ data_tts = pad_array(
924
+ speech_output[240:],
925
+ sampling_rate,
926
+ )
927
+
928
+ sf.write(
929
+ file=filename,
930
+ samplerate=sampling_rate,
931
+ data=data_tts,
932
+ format="ogg",
933
+ subtype="vorbis",
934
+ )
935
+ verify_saved_file_and_size(filename)
936
+
937
+ except Exception as error:
938
+ error_handling_in_tts(error, segment, TRANSLATE_AUDIO_TO, filename)
939
+
940
+
941
+ # =====================================
942
+ # Select task TTS
943
+ # =====================================
944
+
945
+
946
+ def find_spkr(pattern, speaker_to_voice, segments):
947
+ return [
948
+ speaker
949
+ for speaker, voice in speaker_to_voice.items()
950
+ if pattern.match(voice) and any(
951
+ segment["speaker"] == speaker for segment in segments
952
+ )
953
+ ]
954
+
955
+
956
+ def filter_by_speaker(speakers, segments):
957
+ return {
958
+ "segments": [
959
+ segment
960
+ for segment in segments
961
+ if segment["speaker"] in speakers
962
+ ]
963
+ }
964
+
965
+
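A small illustration of how these two helpers route speakers to a TTS backend; the voices and segments are invented:

import re

segments = [
    {"speaker": "SPEAKER_00", "text": "hola", "start": 0.0, "end": 1.0},
    {"speaker": "SPEAKER_01", "text": "que tal", "start": 1.0, "end": 2.0},
]
speaker_to_voice = {
    "SPEAKER_00": "en-GB-SoniaNeural-Female",  # matches the Edge TTS pattern
    "SPEAKER_01": "en-facebook-mms VITS",      # matches the VITS pattern
}
pattern_edge = re.compile(r".*-(Male|Female)$")
edge_speakers = find_spkr(pattern_edge, speaker_to_voice, segments)
print(edge_speakers)                                # ['SPEAKER_00']
print(filter_by_speaker(edge_speakers, segments))   # only the SPEAKER_00 segments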
966
+ def audio_segmentation_to_voice(
967
+ result_diarize,
968
+ TRANSLATE_AUDIO_TO,
969
+ is_gui,
970
+ tts_voice00,
971
+ tts_voice01="",
972
+ tts_voice02="",
973
+ tts_voice03="",
974
+ tts_voice04="",
975
+ tts_voice05="",
976
+ tts_voice06="",
977
+ tts_voice07="",
978
+ tts_voice08="",
979
+ tts_voice09="",
980
+ tts_voice10="",
981
+ tts_voice11="",
982
+ dereverb_automatic=True,
983
+ model_id_bark="suno/bark-small",
984
+ model_id_coqui="tts_models/multilingual/multi-dataset/xtts_v2",
985
+ delete_previous_automatic=True,
986
+ ):
987
+
988
+ remove_directory_contents("audio")
989
+
990
+ # Mapping speakers to voice variables
991
+ speaker_to_voice = {
992
+ "SPEAKER_00": tts_voice00,
993
+ "SPEAKER_01": tts_voice01,
994
+ "SPEAKER_02": tts_voice02,
995
+ "SPEAKER_03": tts_voice03,
996
+ "SPEAKER_04": tts_voice04,
997
+ "SPEAKER_05": tts_voice05,
998
+ "SPEAKER_06": tts_voice06,
999
+ "SPEAKER_07": tts_voice07,
1000
+ "SPEAKER_08": tts_voice08,
1001
+ "SPEAKER_09": tts_voice09,
1002
+ "SPEAKER_10": tts_voice10,
1003
+ "SPEAKER_11": tts_voice11,
1004
+ }
1005
+
1006
+ # Assign 'SPEAKER_00' to segments without a 'speaker' key
1007
+ for segment in result_diarize["segments"]:
1008
+ if "speaker" not in segment:
1009
+ segment["speaker"] = "SPEAKER_00"
1010
+ logger.warning(
1011
+ "NO SPEAKER DETECT IN SEGMENT: First TTS will be used in the"
1012
+ f" segment time {segment['start'], segment['text']}"
1013
+ )
1014
+ # Assign the TTS name
1015
+ segment["tts_name"] = speaker_to_voice[segment["speaker"]]
1016
+
1017
+ # Find TTS method
1018
+ pattern_edge = re.compile(r".*-(Male|Female)$")
1019
+ pattern_bark = re.compile(r".* BARK$")
1020
+ pattern_vits = re.compile(r".* VITS$")
1021
+ pattern_coqui = re.compile(r".+\.(wav|mp3|ogg|m4a)$")
1022
+ pattern_vits_onnx = re.compile(r".* VITS-onnx$")
1023
+ pattern_openai_tts = re.compile(r".* OpenAI-TTS$")
1024
+
1025
+ all_segments = result_diarize["segments"]
1026
+
1027
+ speakers_edge = find_spkr(pattern_edge, speaker_to_voice, all_segments)
1028
+ speakers_bark = find_spkr(pattern_bark, speaker_to_voice, all_segments)
1029
+ speakers_vits = find_spkr(pattern_vits, speaker_to_voice, all_segments)
1030
+ speakers_coqui = find_spkr(pattern_coqui, speaker_to_voice, all_segments)
1031
+ speakers_vits_onnx = find_spkr(
1032
+ pattern_vits_onnx, speaker_to_voice, all_segments
1033
+ )
1034
+ speakers_openai_tts = find_spkr(
1035
+ pattern_openai_tts, speaker_to_voice, all_segments
1036
+ )
1037
+
1038
+ # Filter method in segments
1039
+ filtered_edge = filter_by_speaker(speakers_edge, all_segments)
1040
+ filtered_bark = filter_by_speaker(speakers_bark, all_segments)
1041
+ filtered_vits = filter_by_speaker(speakers_vits, all_segments)
1042
+ filtered_coqui = filter_by_speaker(speakers_coqui, all_segments)
1043
+ filtered_vits_onnx = filter_by_speaker(speakers_vits_onnx, all_segments)
1044
+ filtered_openai_tts = filter_by_speaker(speakers_openai_tts, all_segments)
1045
+
1046
+ # Infer
1047
+ if filtered_edge["segments"]:
1048
+ logger.info(f"EDGE TTS: {speakers_edge}")
1049
+ segments_egde_tts(filtered_edge, TRANSLATE_AUDIO_TO, is_gui) # mp3
1050
+ if filtered_bark["segments"]:
1051
+ logger.info(f"BARK TTS: {speakers_bark}")
1052
+ segments_bark_tts(
1053
+ filtered_bark, TRANSLATE_AUDIO_TO, model_id_bark
1054
+ ) # wav
1055
+ if filtered_vits["segments"]:
1056
+ logger.info(f"VITS TTS: {speakers_vits}")
1057
+ segments_vits_tts(filtered_vits, TRANSLATE_AUDIO_TO) # wav
1058
+ if filtered_coqui["segments"]:
1059
+ logger.info(f"Coqui TTS: {speakers_coqui}")
1060
+ segments_coqui_tts(
1061
+ filtered_coqui,
1062
+ TRANSLATE_AUDIO_TO,
1063
+ model_id_coqui,
1064
+ speakers_coqui,
1065
+ delete_previous_automatic,
1066
+ dereverb_automatic,
1067
+ ) # wav
1068
+ if filtered_vits_onnx["segments"]:
1069
+ logger.info(f"PIPER TTS: {speakers_vits_onnx}")
1070
+ segments_vits_onnx_tts(filtered_vits_onnx, TRANSLATE_AUDIO_TO) # wav
1071
+ if filtered_openai_tts["segments"]:
1072
+ logger.info(f"OpenAI TTS: {speakers_openai_tts}")
1073
+ segments_openai_tts(filtered_openai_tts, TRANSLATE_AUDIO_TO) # wav
1074
+
1075
+ [result.pop("tts_name", None) for result in result_diarize["segments"]]
1076
+ return [
1077
+ speakers_edge,
1078
+ speakers_bark,
1079
+ speakers_vits,
1080
+ speakers_coqui,
1081
+ speakers_vits_onnx,
1082
+ speakers_openai_tts
1083
+ ]
1084
+
1085
+
1086
+ def accelerate_segments(
1087
+ result_diarize,
1088
+ max_accelerate_audio,
1089
+ valid_speakers,
1090
+ acceleration_rate_regulation=False,
1091
+ folder_output="audio2",
1092
+ ):
1093
+ logger.info("Apply acceleration")
1094
+
1095
+ (
1096
+ speakers_edge,
1097
+ speakers_bark,
1098
+ speakers_vits,
1099
+ speakers_coqui,
1100
+ speakers_vits_onnx,
1101
+ speakers_openai_tts
1102
+ ) = valid_speakers
1103
+
1104
+ create_directories(f"{folder_output}/audio/")
1105
+ remove_directory_contents(f"{folder_output}/audio/")
1106
+
1107
+ audio_files = []
1108
+ speakers_list = []
1109
+
1110
+ max_count_segments_idx = len(result_diarize["segments"]) - 1
1111
+
1112
+ for i, segment in tqdm(enumerate(result_diarize["segments"])):
1113
+ text = segment["text"] # noqa
1114
+ start = segment["start"]
1115
+ end = segment["end"]
1116
+ speaker = segment["speaker"]
1117
+
1118
+ # find name audio
1119
+ # if speaker in speakers_edge:
1120
+ filename = f"audio/{start}.ogg"
1121
+ # elif speaker in speakers_bark + speakers_vits + speakers_coqui + speakers_vits_onnx:
1122
+ # filename = f"audio/{start}.wav" # wav
1123
+
1124
+ # duration
1125
+ duration_true = end - start
1126
+ duration_tts = librosa.get_duration(filename=filename)
1127
+
1128
+ # Accelerate percentage
1129
+ acc_percentage = duration_tts / duration_true
1130
+
1131
+ # Smooth the acceleration using the gap before the next segment
1132
+ if acceleration_rate_regulation and acc_percentage >= 1.3:
1133
+ try:
1134
+ next_segment = result_diarize["segments"][
1135
+ min(max_count_segments_idx, i + 1)
1136
+ ]
1137
+ next_start = next_segment["start"]
1138
+ next_speaker = next_segment["speaker"]
1139
+ duration_with_next_start = next_start - start
1140
+
1141
+ if duration_with_next_start > duration_true:
1142
+ extra_time = duration_with_next_start - duration_true
1143
+
1144
+ if speaker == next_speaker:
1145
+ # half
1146
+ smoth_duration = duration_true + (extra_time * 0.5)
1147
+ else:
1148
+ # 7/10
1149
+ smoth_duration = duration_true + (extra_time * 0.7)
1150
+ logger.debug(
1151
+ f"Base acc: {acc_percentage}, "
1152
+ f"smoth acc: {duration_tts / smoth_duration}"
1153
+ )
1154
+ acc_percentage = max(1.2, (duration_tts / smoth_duration))
1155
+
1156
+ except Exception as error:
1157
+ logger.error(str(error))
1158
+
1159
+ if acc_percentage > max_accelerate_audio:
1160
+ acc_percentage = max_accelerate_audio
1161
+ elif acc_percentage <= 1.15 and acc_percentage >= 0.8:
1162
+ acc_percentage = 1.0
1163
+ elif acc_percentage <= 0.79:
1164
+ acc_percentage = 0.8
1165
+
1166
+ # Round
1167
+ acc_percentage = round(acc_percentage + 0.0, 1)
1168
+
1169
+ # Determine the source audio format, if needed
1170
+ if speaker in speakers_edge:
1171
+ info_enc = sf.info(filename).format
1172
+ else:
1173
+ info_enc = "OGG"
1174
+
1175
+ # Apply acceleration (or slow-down) to the audio file and write it to the folder_output folder
1176
+ if acc_percentage == 1.0 and info_enc == "OGG":
1177
+ copy_files(filename, f"{folder_output}{os.sep}audio")
1178
+ else:
1179
+ os.system(
1180
+ f"ffmpeg -y -loglevel panic -i {filename} -filter:a atempo={acc_percentage} {folder_output}/{filename}"
1181
+ )
1182
+
1183
+ if logger.isEnabledFor(logging.DEBUG):
1184
+ duration_create = librosa.get_duration(
1185
+ filename=f"{folder_output}/{filename}"
1186
+ )
1187
+ logger.debug(
1188
+ f"acc_percen is {acc_percentage}, tts duration "
1189
+ f"is {duration_tts}, new duration is {duration_create}"
1190
+ f", for {filename}"
1191
+ )
1192
+
1193
+ audio_files.append(f"{folder_output}/{filename}")
1194
+ speaker = "TTS Speaker {:02d}".format(int(speaker[-2:]) + 1)
1195
+ speakers_list.append(speaker)
1196
+
1197
+ return audio_files, speakers_list
1198
+
1199
+
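The ratio clamping inside accelerate_segments can be illustrated with made-up numbers (this mirrors, but is not, the code above):

duration_true = 4.0   # time slot available in the original audio
duration_tts = 6.0    # length of the synthesized speech
max_accelerate_audio = 1.9

acc = duration_tts / duration_true   # 1.5 -> speech must play 1.5x faster
acc = min(acc, max_accelerate_audio)
if 0.8 <= acc <= 1.15:
    acc = 1.0                        # close enough: keep the original speed
elif acc <= 0.79:
    acc = 0.8                        # never slow down below 0.8x
print(round(acc, 1))                 # 1.5, later handed to ffmpeg's atempo filter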
1200
+ # =====================================
1201
+ # Tone color converter
1202
+ # =====================================
1203
+
1204
+
1205
+ def se_process_audio_segments(
1206
+ source_seg, tone_color_converter, device, remove_previous_processed=True
1207
+ ):
1208
+ # list wav seg
1209
+ source_audio_segs = glob.glob(f"{source_seg}/*.wav")
1210
+ if not source_audio_segs:
1211
+ raise ValueError(
1212
+ f"No audio segments found in {str(source_audio_segs)}"
1213
+ )
1214
+
1215
+ source_se_path = os.path.join(source_seg, "se.pth")
1216
+
1217
+ # if exist not create wav
1218
+ if os.path.isfile(source_se_path):
1219
+ se = torch.load(source_se_path).to(device)
1220
+ logger.debug(f"Previous created {source_se_path}")
1221
+ else:
1222
+ se = tone_color_converter.extract_se(source_audio_segs, source_se_path)
1223
+
1224
+ return se
1225
+
1226
+
1227
+ def create_wav_vc(
1228
+ valid_speakers,
1229
+ segments_base,
1230
+ audio_name,
1231
+ max_segments=10,
1232
+ target_dir="processed",
1233
+ get_vocals_dereverb=False,
1234
+ ):
1235
+ # valid_speakers = list({item['speaker'] for item in segments_base})
1236
+
1237
+ # Before calling this, previous automatic samples are expected to be removed (delete_previous_automatic)
1238
+ output_dir = os.path.join(".", target_dir) # remove content
1239
+ # remove_directory_contents(output_dir)
1240
+
1241
+ path_source_segments = []
1242
+ path_target_segments = []
1243
+ for speaker in valid_speakers:
1244
+ filtered_speaker = [
1245
+ segment
1246
+ for segment in segments_base
1247
+ if segment["speaker"] == speaker
1248
+ ]
1249
+ if len(filtered_speaker) > 4:
1250
+ filtered_speaker = filtered_speaker[1:]
1251
+
1252
+ dir_name_speaker = speaker + audio_name
1253
+ dir_name_speaker_tts = "tts" + speaker + audio_name
1254
+ dir_path_speaker = os.path.join(output_dir, dir_name_speaker)
1255
+ dir_path_speaker_tts = os.path.join(output_dir, dir_name_speaker_tts)
1256
+ create_directories([dir_path_speaker, dir_path_speaker_tts])
1257
+
1258
+ path_target_segments.append(dir_path_speaker)
1259
+ path_source_segments.append(dir_path_speaker_tts)
1260
+
1261
+ # create wav
1262
+ max_segments_count = 0
1263
+ for seg in filtered_speaker:
1264
+ duration = float(seg["end"]) - float(seg["start"])
1265
+ if duration > 3.0 and duration < 18.0:
1266
+ logger.info(
1267
+ f'Processing segment: {seg["start"]}, {seg["end"]}, {seg["speaker"]}, {duration}, {seg["text"]}'
1268
+ )
1269
+ name_new_wav = str(seg["start"])
1270
+
1271
+ check_segment_audio_target_file = os.path.join(
1272
+ dir_path_speaker, f"{name_new_wav}.wav"
1273
+ )
1274
+
1275
+ if os.path.exists(check_segment_audio_target_file):
1276
+ logger.debug(
1277
+ "Segment vc source exists: "
1278
+ f"{check_segment_audio_target_file}"
1279
+ )
1280
+ pass
1281
+ else:
1282
+ create_wav_file_vc(
1283
+ sample_name=name_new_wav,
1284
+ audio_wav="audio.wav",
1285
+ start=(float(seg["start"]) + 1.0),
1286
+ end=(float(seg["end"]) - 1.0),
1287
+ output_final_path=dir_path_speaker,
1288
+ get_vocals_dereverb=get_vocals_dereverb,
1289
+ )
1290
+
1291
+ file_name_tts = f"audio2/audio/{str(seg['start'])}.ogg"
1292
+ # copy_files(file_name_tts, os.path.join(output_dir, dir_name_speaker_tts)
1293
+ convert_to_xtts_good_sample(
1294
+ file_name_tts, dir_path_speaker_tts
1295
+ )
1296
+
1297
+ max_segments_count += 1
1298
+ if max_segments_count == max_segments:
1299
+ break
1300
+
1301
+ if max_segments_count == 0:
1302
+ logger.info("Taking the first segment")
1303
+ seg = filtered_speaker[0]
1304
+ logger.info(
1305
+ f'Processing segment: {seg["start"]}, {seg["end"]}, {seg["speaker"]}, {seg["text"]}'
1306
+ )
1307
+ max_duration = float(seg["end"]) - float(seg["start"])
1308
+ max_duration = max(1.0, min(max_duration, 18.0))
1309
+
1310
+ name_new_wav = str(seg["start"])
1311
+ create_wav_file_vc(
1312
+ sample_name=name_new_wav,
1313
+ audio_wav="audio.wav",
1314
+ start=(float(seg["start"])),
1315
+ end=(float(seg["start"]) + max_duration),
1316
+ output_final_path=dir_path_speaker,
1317
+ get_vocals_dereverb=get_vocals_dereverb,
1318
+ )
1319
+
1320
+ file_name_tts = f"audio2/audio/{str(seg['start'])}.ogg"
1321
+ # copy_files(file_name_tts, os.path.join(output_dir, dir_name_speaker_tts)
1322
+ convert_to_xtts_good_sample(file_name_tts, dir_path_speaker_tts)
1323
+
1324
+ logger.debug(f"Base: {str(path_source_segments)}")
1325
+ logger.debug(f"Target: {str(path_target_segments)}")
1326
+
1327
+ return path_source_segments, path_target_segments
1328
+
1329
+
1330
+ def toneconverter_openvoice(
1331
+ result_diarize,
1332
+ preprocessor_max_segments,
1333
+ remove_previous_process=True,
1334
+ get_vocals_dereverb=False,
1335
+ model="openvoice",
1336
+ ):
1337
+ audio_path = "audio.wav"
1338
+ # se_path = "se.pth"
1339
+ target_dir = "processed"
1340
+ create_directories(target_dir)
1341
+
1342
+ from openvoice import se_extractor
1343
+ from openvoice.api import ToneColorConverter
1344
+
1345
+ audio_name = f"{os.path.basename(audio_path).rsplit('.', 1)[0]}_{se_extractor.hash_numpy_array(audio_path)}"
1346
+ # se_path = os.path.join(target_dir, audio_name, 'se.pth')
1347
+
1348
+ # create wav seg original and target
1349
+
1350
+ valid_speakers = list(
1351
+ {item["speaker"] for item in result_diarize["segments"]}
1352
+ )
1353
+
1354
+ logger.info("Openvoice preprocessor...")
1355
+
1356
+ if remove_previous_process:
1357
+ remove_directory_contents(target_dir)
1358
+
1359
+ path_source_segments, path_target_segments = create_wav_vc(
1360
+ valid_speakers,
1361
+ result_diarize["segments"],
1362
+ audio_name,
1363
+ max_segments=preprocessor_max_segments,
1364
+ get_vocals_dereverb=get_vocals_dereverb,
1365
+ )
1366
+
1367
+ logger.info("Openvoice loading model...")
1368
+ model_path_openvoice = "./OPENVOICE_MODELS"
1369
+ url_model_openvoice = "https://huggingface.co/myshell-ai/OpenVoice/resolve/main/checkpoints/converter"
1370
+
1371
+ if "v2" in model:
1372
+ model_path = os.path.join(model_path_openvoice, "v2")
1373
+ url_model_openvoice = url_model_openvoice.replace(
1374
+ "OpenVoice", "OpenVoiceV2"
1375
+ ).replace("checkpoints/", "")
1376
+ else:
1377
+ model_path = os.path.join(model_path_openvoice, "v1")
1378
+ create_directories(model_path)
1379
+
1380
+ config_url = f"{url_model_openvoice}/config.json"
1381
+ checkpoint_url = f"{url_model_openvoice}/checkpoint.pth"
1382
+
1383
+ config_path = download_manager(url=config_url, path=model_path)
1384
+ checkpoint_path = download_manager(
1385
+ url=checkpoint_url, path=model_path
1386
+ )
1387
+
1388
+ device = os.environ.get("QUANTUM_DEVICE")
1389
+ tone_color_converter = ToneColorConverter(config_path, device=device)
1390
+ tone_color_converter.load_ckpt(checkpoint_path)
1391
+
1392
+ logger.info("Openvoice tone color converter:")
1393
+ global_progress_bar = tqdm(total=len(result_diarize["segments"]), desc="Progress")
1394
+
1395
+ for source_seg, target_seg, speaker in zip(
1396
+ path_source_segments, path_target_segments, valid_speakers
1397
+ ):
1398
+ # source_se_path = os.path.join(source_seg, 'se.pth')
1399
+ source_se = se_process_audio_segments(source_seg, tone_color_converter, device)
1400
+ # target_se_path = os.path.join(target_seg, 'se.pth')
1401
+ target_se = se_process_audio_segments(target_seg, tone_color_converter, device)
1402
+
1403
+ # Iterate through this speaker's segments
1404
+ encode_message = "@MyShell"
1405
+ filtered_speaker = [
1406
+ segment
1407
+ for segment in result_diarize["segments"]
1408
+ if segment["speaker"] == speaker
1409
+ ]
1410
+ for seg in filtered_speaker:
1411
+ src_path = (
1412
+ save_path
1413
+ ) = f"audio2/audio/{str(seg['start'])}.ogg" # overwrite
1414
+ logger.debug(f"{src_path}")
1415
+
1416
+ tone_color_converter.convert(
1417
+ audio_src_path=src_path,
1418
+ src_se=source_se,
1419
+ tgt_se=target_se,
1420
+ output_path=save_path,
1421
+ message=encode_message,
1422
+ )
1423
+
1424
+ global_progress_bar.update(1)
1425
+
1426
+ global_progress_bar.close()
1427
+
1428
+ try:
1429
+ del tone_color_converter
1430
+ gc.collect()
1431
+ torch.cuda.empty_cache()
1432
+ except Exception as error:
1433
+ logger.error(str(error))
1434
+ gc.collect()
1435
+ torch.cuda.empty_cache()
1436
+
1437
+
1438
+ def toneconverter_freevc(
1439
+ result_diarize,
1440
+ remove_previous_process=True,
1441
+ get_vocals_dereverb=False,
1442
+ ):
1443
+ audio_path = "audio.wav"
1444
+ target_dir = "processed"
1445
+ create_directories(target_dir)
1446
+
1447
+ from openvoice import se_extractor
1448
+
1449
+ audio_name = f"{os.path.basename(audio_path).rsplit('.', 1)[0]}_{se_extractor.hash_numpy_array(audio_path)}"
1450
+
1451
+ # create wav seg; original is target and dubbing is source
1452
+ valid_speakers = list(
1453
+ {item["speaker"] for item in result_diarize["segments"]}
1454
+ )
1455
+
1456
+ logger.info("FreeVC preprocessor...")
1457
+
1458
+ if remove_previous_process:
1459
+ remove_directory_contents(target_dir)
1460
+
1461
+ path_source_segments, path_target_segments = create_wav_vc(
1462
+ valid_speakers,
1463
+ result_diarize["segments"],
1464
+ audio_name,
1465
+ max_segments=1,
1466
+ get_vocals_dereverb=get_vocals_dereverb,
1467
+ )
1468
+
1469
+ logger.info("FreeVC loading model...")
1470
+ device_id = os.environ.get("QUANTUM_DEVICE")
1471
+ device = None if device_id == "cpu" else device_id
1472
+ try:
1473
+ from TTS.api import TTS
1474
+ tts = TTS(
1475
+ model_name="voice_conversion_models/multilingual/vctk/freevc24",
1476
+ progress_bar=False
1477
+ ).to(device)
1478
+ except Exception as error:
1479
+ logger.error(str(error))
1480
+ logger.error("Error loading the FreeVC model.")
1481
+ return
1482
+
1483
+ logger.info("FreeVC process:")
1484
+ global_progress_bar = tqdm(total=len(result_diarize["segments"]), desc="Progress")
1485
+
1486
+ for source_seg, target_seg, speaker in zip(
1487
+ path_source_segments, path_target_segments, valid_speakers
1488
+ ):
1489
+
1490
+ filtered_speaker = [
1491
+ segment
1492
+ for segment in result_diarize["segments"]
1493
+ if segment["speaker"] == speaker
1494
+ ]
1495
+
1496
+ files_and_directories = os.listdir(target_seg)
1497
+ wav_files = [file for file in files_and_directories if file.endswith(".wav")]
1498
+ original_wav_audio_segment = os.path.join(target_seg, wav_files[0])
1499
+
1500
+ for seg in filtered_speaker:
1501
+
1502
+ src_path = (
1503
+ save_path
1504
+ ) = f"audio2/audio/{str(seg['start'])}.ogg" # overwrite
1505
+ logger.debug(f"{src_path} - {original_wav_audio_segment}")
1506
+
1507
+ wav = tts.voice_conversion(
1508
+ source_wav=src_path,
1509
+ target_wav=original_wav_audio_segment,
1510
+ )
1511
+
1512
+ sf.write(
1513
+ file=save_path,
1514
+ samplerate=tts.voice_converter.vc_config.audio.output_sample_rate,
1515
+ data=wav,
1516
+ format="ogg",
1517
+ subtype="vorbis",
1518
+ )
1519
+
1520
+ global_progress_bar.update(1)
1521
+
1522
+ global_progress_bar.close()
1523
+
1524
+ try:
1525
+ del tts
1526
+ gc.collect()
1527
+ torch.cuda.empty_cache()
1528
+ except Exception as error:
1529
+ logger.error(str(error))
1530
+ gc.collect()
1531
+ torch.cuda.empty_cache()
1532
+
1533
+
1534
+ def toneconverter(
1535
+ result_diarize,
1536
+ preprocessor_max_segments,
1537
+ remove_previous_process=True,
1538
+ get_vocals_dereverb=False,
1539
+ method_vc="freevc"
1540
+ ):
1541
+
1542
+ if method_vc == "freevc":
1543
+ if preprocessor_max_segments > 1:
1544
+ logger.info("FreeVC only uses one segment.")
1545
+ return toneconverter_freevc(
1546
+ result_diarize,
1547
+ remove_previous_process=remove_previous_process,
1548
+ get_vocals_dereverb=get_vocals_dereverb,
1549
+ )
1550
+ elif "openvoice" in method_vc:
1551
+ return toneconverter_openvoice(
1552
+ result_diarize,
1553
+ preprocessor_max_segments,
1554
+ remove_previous_process=remove_previous_process,
1555
+ get_vocals_dereverb=get_vocals_dereverb,
1556
+ model=method_vc,
1557
+ )
1558
+
1559
+
1560
+ if __name__ == "__main__":
1561
+ from segments import result_diarize
1562
+
1563
+ audio_segmentation_to_voice(
1564
+ result_diarize,
1565
+ TRANSLATE_AUDIO_TO="en",
1566
+ max_accelerate_audio=2.1,
1567
+ is_gui=True,
1568
+ tts_voice00="en-facebook-mms VITS",
1569
+ tts_voice01="en-CA-ClaraNeural-Female",
1570
+ tts_voice02="en-GB-ThomasNeural-Male",
1571
+ tts_voice03="en-GB-SoniaNeural-Female",
1572
+ tts_voice04="en-NZ-MitchellNeural-Male",
1573
+ tts_voice05="en-GB-MaisieNeural-Female",
1574
+ )
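A minimal usage sketch for the tone converter entry point above (not part of the commit; the segment keys and the audio2/audio/<start>.ogg file layout are assumed from how the functions above use them):

# Hypothetical example: re-timbre the already dubbed per-segment files.
result_diarize = {
    "segments": [
        {"start": 0.0, "end": 2.5, "speaker": "SPEAKER_00", "text": "Hello."},
        {"start": 2.5, "end": 5.1, "speaker": "SPEAKER_01", "text": "Hi there."},
    ]
}
toneconverter(
    result_diarize,
    preprocessor_max_segments=1,  # FreeVC always uses a single reference segment
    remove_previous_process=True,
    get_vocals_dereverb=False,
    method_vc="freevc",  # an "openvoice" value dispatches to toneconverter_openvoice
)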
quantum_dubbing/translate_segments.py ADDED
@@ -0,0 +1,457 @@
1
+ from tqdm import tqdm
2
+ from deep_translator import GoogleTranslator
3
+ from itertools import chain
4
+ import copy
5
+ from .language_configuration import fix_code_language, INVERTED_LANGUAGES
6
+ from .logging_setup import logger
7
+ import re
8
+ import json
9
+ import time
10
+
11
+ TRANSLATION_PROCESS_OPTIONS = [
12
+ "google_translator_batch",
13
+ "google_translator",
14
+ "gpt-3.5-turbo-0125_batch",
15
+ "gpt-3.5-turbo-0125",
16
+ "gpt-4-turbo-preview_batch",
17
+ "gpt-4-turbo-preview",
18
+ "disable_translation",
19
+ ]
20
+ DOCS_TRANSLATION_PROCESS_OPTIONS = [
21
+ "google_translator",
22
+ "gpt-3.5-turbo-0125",
23
+ "gpt-4-turbo-preview",
24
+ "disable_translation",
25
+ ]
26
+
27
+
28
+ def translate_iterative(segments, target, source=None):
29
+ """
30
+ Translate text segments individually to the specified language.
31
+
32
+ Parameters:
33
+ - segments (list): A list of dictionaries with 'text' as a key for
34
+ segment text.
35
+ - target (str): Target language code.
36
+ - source (str, optional): Source language code. Defaults to None.
37
+
38
+ Returns:
39
+ - list: Translated text segments in the target language.
40
+
41
+ Notes:
42
+ - Translates each segment using Google Translate.
43
+
44
+ Example:
45
+ segments = [{'text': 'first segment.'}, {'text': 'second segment.'}]
46
+ translated_segments = translate_iterative(segments, 'es')
47
+ """
48
+
49
+ segments_ = copy.deepcopy(segments)
50
+
51
+ if (
52
+ not source
53
+ ):
54
+ logger.debug("No source language")
55
+ source = "auto"
56
+
57
+ translator = GoogleTranslator(source=source, target=target)
58
+
59
+ for line in tqdm(range(len(segments_))):
60
+ text = segments_[line]["text"]
61
+ translated_line = translator.translate(text.strip())
62
+ segments_[line]["text"] = translated_line
63
+
64
+ return segments_
65
+
66
+
67
+ def verify_translate(
68
+ segments,
69
+ segments_copy,
70
+ translated_lines,
71
+ target,
72
+ source
73
+ ):
74
+ """
75
+ Apply the translated lines when their count matches the segments; otherwise
76
+ switch to iterative translation.
77
+ """
78
+ if len(segments) == len(translated_lines):
79
+ for line in range(len(segments_copy)):
80
+ logger.debug(
81
+ f"{segments_copy[line]['text']} >> "
82
+ f"{translated_lines[line].strip()}"
83
+ )
84
+ segments_copy[line]["text"] = translated_lines[
85
+ line].replace("\t", "").replace("\n", "").strip()
86
+ return segments_copy
87
+ else:
88
+ logger.error(
89
+ "The translation failed, switching to google_translate iterative. "
90
+ f"{len(segments), len(translated_lines)}"
91
+ )
92
+ return translate_iterative(segments, target, source)
93
+
94
+
95
+ def translate_batch(segments, target, chunk_size=2000, source=None):
96
+ """
97
+ Translate a batch of text segments into the specified language in chunks,
98
+ respecting the character limit.
99
+
100
+ Parameters:
101
+ - segments (list): List of dictionaries with 'text' as a key for segment
102
+ text.
103
+ - target (str): Target language code.
104
+ - chunk_size (int, optional): Maximum character limit for each translation
105
+ chunk (default is 2000; max 5000).
106
+ - source (str, optional): Source language code. Defaults to None.
107
+
108
+ Returns:
109
+ - list: Translated text segments in the target language.
110
+
111
+ Notes:
112
+ - Splits input segments into chunks respecting the character limit for
113
+ translation.
114
+ - Translates the chunks using Google Translate.
115
+ - If chunked translation fails, switches to iterative translation using
116
+ `translate_iterative()`.
117
+
118
+ Example:
119
+ segments = [{'text': 'first segment.'}, {'text': 'second segment.'}]
120
+ translated = translate_batch(segments, 'es', chunk_size=4000, source='en')
121
+ """
122
+
123
+ segments_copy = copy.deepcopy(segments)
124
+
125
+ if (
126
+ not source
127
+ ):
128
+ logger.debug("No source language")
129
+ source = "auto"
130
+
131
+ # Get text
132
+ text_lines = []
133
+ for line in range(len(segments_copy)):
134
+ text = segments_copy[line]["text"].strip()
135
+ text_lines.append(text)
136
+
137
+ # chunk limit
138
+ text_merge = []
139
+ actual_chunk = ""
140
+ global_text_list = []
141
+ actual_text_list = []
142
+ for one_line in text_lines:
143
+ one_line = " " if not one_line else one_line
144
+ if (len(actual_chunk) + len(one_line)) <= chunk_size:
145
+ if actual_chunk:
146
+ actual_chunk += " ||||| "
147
+ actual_chunk += one_line
148
+ actual_text_list.append(one_line)
149
+ else:
150
+ text_merge.append(actual_chunk)
151
+ actual_chunk = one_line
152
+ global_text_list.append(actual_text_list)
153
+ actual_text_list = [one_line]
154
+ if actual_chunk:
155
+ text_merge.append(actual_chunk)
156
+ global_text_list.append(actual_text_list)
157
+
158
+ # translate chunks
159
+ progress_bar = tqdm(total=len(segments), desc="Translating")
160
+ translator = GoogleTranslator(source=source, target=target)
161
+ split_list = []
162
+ try:
163
+ for text, text_iterable in zip(text_merge, global_text_list):
164
+ translated_line = translator.translate(text.strip())
165
+ split_text = translated_line.split("|||||")
166
+ if len(split_text) == len(text_iterable):
167
+ progress_bar.update(len(split_text))
168
+ else:
169
+ logger.debug(
170
+ "Chunk fixing iteratively. Len chunk: "
171
+ f"{len(split_text)}, expected: {len(text_iterable)}"
172
+ )
173
+ split_text = []
174
+ for txt_iter in text_iterable:
175
+ translated_txt = translator.translate(txt_iter.strip())
176
+ split_text.append(translated_txt)
177
+ progress_bar.update(1)
178
+ split_list.append(split_text)
179
+ progress_bar.close()
180
+ except Exception as error:
181
+ progress_bar.close()
182
+ logger.error(str(error))
183
+ logger.warning(
184
+ "The translation in chunks failed, switching to iterative."
185
+ " Related: too many request"
186
+ ) # consider using a proxy or a smaller chunk_size
187
+ return translate_iterative(segments, target, source)
188
+
189
+ # Flatten the per-chunk results back into a single list of translated lines
190
+ translated_lines = list(chain.from_iterable(split_list))
191
+
192
+ return verify_translate(
193
+ segments, segments_copy, translated_lines, target, source
194
+ )
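# Illustrative sketch of the chunking scheme above (assumed values, not part of
# the module): segment texts are joined with the " ||||| " delimiter, the merged
# chunk is translated in a single request, and the result is split back:
#
#   texts  = ["first segment.", "second segment."]
#   chunk  = "first segment. ||||| second segment."
#   result = GoogleTranslator(source="en", target="es").translate(chunk)
#   lines  = result.split("|||||")   # expected to yield len(texts) items
#
# When the split count does not match, each line of that chunk is re-translated
# individually, as implemented in the loop above.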
195
+
196
+
197
+ def call_gpt_translate(
198
+ client,
199
+ model,
200
+ system_prompt,
201
+ user_prompt,
202
+ original_text=None,
203
+ batch_lines=None,
204
+ ):
205
+
206
+ # https://platform.openai.com/docs/guides/text-generation/json-mode
207
+ response = client.chat.completions.create(
208
+ model=model,
209
+ response_format={"type": "json_object"},
210
+ messages=[
211
+ {"role": "system", "content": system_prompt},
212
+ {"role": "user", "content": user_prompt}
213
+ ]
214
+ )
215
+ result = response.choices[0].message.content
216
+ logger.debug(f"Result: {str(result)}")
217
+
218
+ try:
219
+ translation = json.loads(result)
220
+ except Exception as error:
221
+ match_result = re.search(r'\{.*?\}', result)
222
+ if match_result:
223
+ logger.error(str(error))
224
+ json_str = match_result.group(0)
225
+ translation = json.loads(json_str)
226
+ else:
227
+ raise error
228
+
229
+ # Get valid data
230
+ if batch_lines:
231
+ for conversation in translation.values():
232
+ if isinstance(conversation, dict):
233
+ conversation = list(conversation.values())[0]
234
+ if (
235
+ list(
236
+ original_text["conversation"][0].values()
237
+ )[0].strip() ==
238
+ list(conversation[0].values())[0].strip()
239
+ ):
240
+ continue
241
+ if len(conversation) == batch_lines:
242
+ break
243
+
244
+ fix_conversation_length = []
245
+ for line in conversation:
246
+ for speaker_code, text_tr in line.items():
247
+ fix_conversation_length.append({speaker_code: text_tr})
248
+
249
+ logger.debug(f"Data batch: {str(fix_conversation_length)}")
250
+ logger.debug(
251
+ f"Lines Received: {len(fix_conversation_length)},"
252
+ f" expected: {batch_lines}"
253
+ )
254
+
255
+ return fix_conversation_length
256
+
257
+ else:
258
+ if isinstance(translation, dict):
259
+ translation = list(translation.values())[0]
260
+ if isinstance(translation, list):
261
+ translation = translation[0]
262
+ if isinstance(translation, set):
263
+ translation = list(translation)[0]
264
+ if not isinstance(translation, str):
265
+ raise ValueError(f"No valid response received: {str(translation)}")
266
+
267
+ return translation
268
+
269
+
270
+ def gpt_sequential(segments, model, target, source=None):
271
+ from openai import OpenAI
272
+
273
+ translated_segments = copy.deepcopy(segments)
274
+
275
+ client = OpenAI()
276
+ progress_bar = tqdm(total=len(segments), desc="Translating")
277
+
278
+ lang_tg = re.sub(r'\([^)]*\)', '', INVERTED_LANGUAGES[target]).strip()
279
+ lang_sc = ""
280
+ if source:
281
+ lang_sc = re.sub(r'\([^)]*\)', '', INVERTED_LANGUAGES[source]).strip()
282
+
283
+ fixed_target = fix_code_language(target)
284
+ fixed_source = fix_code_language(source) if source else "auto"
285
+
286
+ system_prompt = "Machine translation designed to output the translated_text JSON."
287
+
288
+ for i, line in enumerate(translated_segments):
289
+ text = line["text"].strip()
290
+ start = line["start"]
291
+ user_prompt = f"Translate the following {lang_sc} text into {lang_tg}, write the fully translated text and nothing more:\n{text}"
292
+
293
+ time.sleep(0.5)
294
+
295
+ try:
296
+ translated_text = call_gpt_translate(
297
+ client,
298
+ model,
299
+ system_prompt,
300
+ user_prompt,
301
+ )
302
+
303
+ except Exception as error:
304
+ logger.error(
305
+ f"{str(error)} >> The text of segment {start} "
306
+ "is being corrected with Google Translate"
307
+ )
308
+ translator = GoogleTranslator(
309
+ source=fixed_source, target=fixed_target
310
+ )
311
+ translated_text = translator.translate(text.strip())
312
+
313
+ translated_segments[i]["text"] = translated_text.strip()
314
+ progress_bar.update(1)
315
+
316
+ progress_bar.close()
317
+
318
+ return translated_segments
319
+
320
+
321
+ def gpt_batch(segments, model, target, token_batch_limit=900, source=None):
322
+ from openai import OpenAI
323
+ import tiktoken
324
+
325
+ token_batch_limit = max(100, (token_batch_limit - 40) // 2)
326
+ progress_bar = tqdm(total=len(segments), desc="Translating")
327
+ segments_copy = copy.deepcopy(segments)
328
+ encoding = tiktoken.get_encoding("cl100k_base")
329
+ client = OpenAI()
330
+
331
+ lang_tg = re.sub(r'\([^)]*\)', '', INVERTED_LANGUAGES[target]).strip()
332
+ lang_sc = ""
333
+ if source:
334
+ lang_sc = re.sub(r'\([^)]*\)', '', INVERTED_LANGUAGES[source]).strip()
335
+
336
+ fixed_target = fix_code_language(target)
337
+ fixed_source = fix_code_language(source) if source else "auto"
338
+
339
+ name_speaker = "ABCDEFGHIJKL"
340
+
341
+ translated_lines = []
342
+ text_data_dict = []
343
+ num_tokens = 0
344
+ count_sk = {char: 0 for char in "ABCDEFGHIJKL"}
345
+
346
+ for i, line in enumerate(segments_copy):
347
+ text = line["text"]
348
+ speaker = line["speaker"]
349
+ last_start = line["start"]
350
+ # text_data_dict.append({str(int(speaker[-1])+1): text})
351
+ index_sk = int(speaker[-2:])
352
+ character_sk = name_speaker[index_sk]
353
+ count_sk[character_sk] += 1
354
+ code_sk = character_sk+str(count_sk[character_sk])
355
+ text_data_dict.append({code_sk: text})
356
+ num_tokens += len(encoding.encode(text)) + 7
357
+ if num_tokens >= token_batch_limit or i == len(segments_copy)-1:
358
+ try:
359
+ batch_lines = len(text_data_dict)
360
+ batch_conversation = {"conversation": copy.deepcopy(text_data_dict)}
361
+ # Reset vars
362
+ num_tokens = 0
363
+ text_data_dict = []
364
+ count_sk = {char: 0 for char in "ABCDEFGHIJKL"}
365
+ # Process translation
366
+ # https://arxiv.org/pdf/2309.03409.pdf
367
+ system_prompt = f"Machine translation designed to output the translated_conversation key JSON containing a list of {batch_lines} items."
368
+ user_prompt = f"Translate each of the following text values in conversation{' from' if lang_sc else ''} {lang_sc} to {lang_tg}:\n{batch_conversation}"
369
+ logger.debug(f"Prompt: {str(user_prompt)}")
370
+
371
+ conversation = call_gpt_translate(
372
+ client,
373
+ model,
374
+ system_prompt,
375
+ user_prompt,
376
+ original_text=batch_conversation,
377
+ batch_lines=batch_lines,
378
+ )
379
+
380
+ if len(conversation) < batch_lines:
381
+ raise ValueError(
382
+ "Incomplete result received. Batch lines: "
383
+ f"{len(conversation)}, expected: {batch_lines}"
384
+ )
385
+
386
+ for j, translated_text in enumerate(conversation):  # j avoids shadowing the outer segment index i
387
+ if j + 1 > batch_lines:
388
+ break
389
+ translated_lines.append(list(translated_text.values())[0])
390
+
391
+ progress_bar.update(batch_lines)
392
+
393
+ except Exception as error:
394
+ logger.error(str(error))
395
+
396
+ first_start = segments_copy[max(0, i-(batch_lines-1))]["start"]
397
+ logger.warning(
398
+ f"The batch from {first_start} to {last_start} "
399
+ "failed, is being corrected with Google Translate"
400
+ )
401
+
402
+ translator = GoogleTranslator(
403
+ source=fixed_source,
404
+ target=fixed_target
405
+ )
406
+
407
+ for txt_source in batch_conversation["conversation"]:
408
+ translated_txt = translator.translate(
409
+ list(txt_source.values())[0].strip()
410
+ )
411
+ translated_lines.append(translated_txt.strip())
412
+ progress_bar.update(1)
413
+
414
+ progress_bar.close()
415
+
416
+ return verify_translate(
417
+ segments, segments_copy, translated_lines, fixed_target, fixed_source
418
+ )
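# Illustrative sketch of the batch payload built above (assumed values, not part
# of the module): each segment text is keyed by a short speaker code, e.g.
# "SPEAKER_00" -> "A", "SPEAKER_01" -> "B", with a per-speaker counter appended:
#
#   {"conversation": [{"A1": "Hello."}, {"B1": "Hi."}, {"A2": "How are you?"}]}
#
# The model is prompted to return a translated_conversation list of the same
# length, which is flattened back into translated_lines in segment order; any
# failed batch falls back to Google Translate line by line.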
419
+
420
+
421
+ def translate_text(
422
+ segments,
423
+ target,
424
+ translation_process="google_translator_batch",
425
+ chunk_size=4500,
426
+ source=None,
427
+ token_batch_limit=1000,
428
+ ):
429
+ """Translates text segments using a specified process."""
430
+ match translation_process:
431
+ case "google_translator_batch":
432
+ return translate_batch(
433
+ segments,
434
+ fix_code_language(target),
435
+ chunk_size,
436
+ fix_code_language(source)
437
+ )
438
+ case "google_translator":
439
+ return translate_iterative(
440
+ segments,
441
+ fix_code_language(target),
442
+ fix_code_language(source)
443
+ )
444
+ case model if model in ["gpt-3.5-turbo-0125", "gpt-4-turbo-preview"]:
445
+ return gpt_sequential(segments, model, target, source)
446
+ case model if model in ["gpt-3.5-turbo-0125_batch", "gpt-4-turbo-preview_batch",]:
447
+ return gpt_batch(
448
+ segments,
449
+ translation_process.replace("_batch", ""),
450
+ target,
451
+ token_batch_limit,
452
+ source
453
+ )
454
+ case "disable_translation":
455
+ return segments
456
+ case _:
457
+ raise ValueError("No valid translation process")
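A minimal usage sketch for translate_text (not part of the commit; the language codes and segment values are placeholders, and fix_code_language() is expected to normalize the codes for the Google backend):

# Hypothetical example: batch-translate diarized segments to Spanish.
segments = [
    {"text": "first segment.", "speaker": "SPEAKER_00", "start": 0.0},
    {"text": "second segment.", "speaker": "SPEAKER_01", "start": 2.3},
]
translated = translate_text(
    segments,
    target="es",
    translation_process="google_translator_batch",
    chunk_size=4500,
    source="en",
)

The GPT-based options call OpenAI() directly, so they expect OPENAI_API_KEY in the environment; the gpt_batch path additionally reads the "speaker" and "start" keys of each segment.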
quantum_dubbing/utils.py ADDED
@@ -0,0 +1,487 @@
1
+ import os, zipfile, rarfile, shutil, subprocess, shlex, sys # noqa
2
+ from .logging_setup import logger
3
+ from urllib.parse import urlparse
4
+ from IPython.utils import capture
5
+ import re
6
+
7
+ VIDEO_EXTENSIONS = [
8
+ ".mp4",
9
+ ".avi",
10
+ ".mov",
11
+ ".mkv",
12
+ ".wmv",
13
+ ".flv",
14
+ ".webm",
15
+ ".m4v",
16
+ ".mpeg",
17
+ ".mpg",
18
+ ".3gp"
19
+ ]
20
+
21
+ AUDIO_EXTENSIONS = [
22
+ ".mp3",
23
+ ".wav",
24
+ ".aiff",
25
+ ".aif",
26
+ ".flac",
27
+ ".aac",
28
+ ".ogg",
29
+ ".wma",
30
+ ".m4a",
31
+ ".alac",
32
+ ".pcm",
33
+ ".opus",
34
+ ".ape",
35
+ ".amr",
36
+ ".ac3",
37
+ ".vox",
38
+ ".caf"
39
+ ]
40
+
41
+ SUBTITLE_EXTENSIONS = [
42
+ ".srt",
43
+ ".vtt",
44
+ ".ass"
45
+ ]
46
+
47
+
48
+ def run_command(command):
49
+ logger.debug(command)
50
+ if isinstance(command, str):
51
+ command = shlex.split(command)
52
+
53
+ sub_params = {
54
+ "stdout": subprocess.PIPE,
55
+ "stderr": subprocess.PIPE,
56
+ "creationflags": subprocess.CREATE_NO_WINDOW
57
+ if sys.platform == "win32"
58
+ else 0,
59
+ }
60
+ process_command = subprocess.Popen(command, **sub_params)
61
+ output, errors = process_command.communicate()
62
+ if (
63
+ process_command.returncode != 0
64
+ ): # or not os.path.exists(mono_path) or os.path.getsize(mono_path) == 0:
65
+ logger.error("Error comnand")
66
+ raise Exception(errors.decode())
67
+
68
+
69
+ def print_tree_directory(root_dir, indent=""):
70
+ if not os.path.exists(root_dir):
71
+ logger.error(f"{indent} Invalid directory or file: {root_dir}")
72
+ return
73
+
74
+ items = os.listdir(root_dir)
75
+
76
+ for index, item in enumerate(sorted(items)):
77
+ item_path = os.path.join(root_dir, item)
78
+ is_last_item = index == len(items) - 1
79
+
80
+ if os.path.isfile(item_path) and item_path.endswith(".zip"):
81
+ with zipfile.ZipFile(item_path, "r") as zip_file:
82
+ print(
83
+ f"{indent}{'└──' if is_last_item else '├──'} {item} (zip file)"
84
+ )
85
+ zip_contents = zip_file.namelist()
86
+ for zip_item in sorted(zip_contents):
87
+ print(
88
+ f"{indent}{' ' if is_last_item else '│ '}{zip_item}"
89
+ )
90
+ else:
91
+ print(f"{indent}{'└──' if is_last_item else '├──'} {item}")
92
+
93
+ if os.path.isdir(item_path):
94
+ new_indent = indent + (" " if is_last_item else "│ ")
95
+ print_tree_directory(item_path, new_indent)
96
+
97
+
98
+ def upload_model_list():
99
+ weight_root = "weights"
100
+ models = []
101
+ for name in os.listdir(weight_root):
102
+ if name.endswith(".pth"):
103
+ models.append("weights/" + name)
104
+ if models:
105
+ logger.debug(models)
106
+
107
+ index_root = "logs"
108
+ index_paths = [None]
109
+ for name in os.listdir(index_root):
110
+ if name.endswith(".index"):
111
+ index_paths.append("logs/" + name)
112
+ if index_paths:
113
+ logger.debug(index_paths)
114
+
115
+ return models, index_paths
116
+
117
+
118
+ def manual_download(url, dst):
119
+ if "drive.google" in url:
120
+ logger.info("Drive url")
121
+ if "folders" in url:
122
+ logger.info("folder")
123
+ os.system(f'gdown --folder "{url}" -O {dst} --fuzzy -c')
124
+ else:
125
+ logger.info("single")
126
+ os.system(f'gdown "{url}" -O {dst} --fuzzy -c')
127
+ elif "huggingface" in url:
128
+ logger.info("HuggingFace url")
129
+ if "/blob/" in url or "/resolve/" in url:
130
+ if "/blob/" in url:
131
+ url = url.replace("/blob/", "/resolve/")
132
+ download_manager(url=url, path=dst, overwrite=True, progress=True)
133
+ else:
134
+ os.system(f"git clone {url} {dst+'repo/'}")
135
+ elif "http" in url:
136
+ logger.info("URL")
137
+ download_manager(url=url, path=dst, overwrite=True, progress=True)
138
+ elif os.path.exists(url):
139
+ logger.info("Path")
140
+ copy_files(url, dst)
141
+ else:
142
+ logger.error(f"No valid URL: {url}")
143
+
144
+
145
+ def download_list(text_downloads):
146
+
147
+ if os.environ.get("ZERO_GPU") == "TRUE":
148
+ raise RuntimeError("This option is disabled in this demo.")
149
+
150
+ try:
151
+ urls = [elem.strip() for elem in text_downloads.split(",")]
152
+ except Exception as error:
153
+ raise ValueError(f"No valid URL. {str(error)}")
154
+
155
+ create_directories(["downloads", "logs", "weights"])
156
+
157
+ path_download = "downloads/"
158
+ for url in urls:
159
+ manual_download(url, path_download)
160
+
161
+ # Tree
162
+ print("####################################")
163
+ print_tree_directory("downloads", indent="")
164
+ print("####################################")
165
+
166
+ # Place files
167
+ select_zip_and_rar_files("downloads/")
168
+
169
+ models, _ = upload_model_list()
170
+
171
+ # Remove the cloned repo contents (HF Space cleanup)
172
+ remove_directory_contents("downloads/repo")
173
+
174
+ return f"Downloaded = {models}"
175
+
176
+
177
+ def select_zip_and_rar_files(directory_path="downloads/"):
178
+ # filter
179
+ zip_files = []
180
+ rar_files = []
181
+
182
+ for file_name in os.listdir(directory_path):
183
+ if file_name.endswith(".zip"):
184
+ zip_files.append(file_name)
185
+ elif file_name.endswith(".rar"):
186
+ rar_files.append(file_name)
187
+
188
+ # extract
189
+ for file_name in zip_files:
190
+ file_path = os.path.join(directory_path, file_name)
191
+ with zipfile.ZipFile(file_path, "r") as zip_ref:
192
+ zip_ref.extractall(directory_path)
193
+
194
+ for file_name in rar_files:
195
+ file_path = os.path.join(directory_path, file_name)
196
+ with rarfile.RarFile(file_path, "r") as rar_ref:
197
+ rar_ref.extractall(directory_path)
198
+
199
+ # set in path
200
+ def move_files_with_extension(src_dir, extension, destination_dir):
201
+ for root, _, files in os.walk(src_dir):
202
+ for file_name in files:
203
+ if file_name.endswith(extension):
204
+ source_file = os.path.join(root, file_name)
205
+ destination = os.path.join(destination_dir, file_name)
206
+ shutil.move(source_file, destination)
207
+
208
+ move_files_with_extension(directory_path, ".index", "logs/")
209
+ move_files_with_extension(directory_path, ".pth", "weights/")
210
+
211
+ return "Download complete"
212
+
213
+
214
+ def is_file_with_extensions(string_path, extensions):
215
+ return any(string_path.lower().endswith(ext) for ext in extensions)
216
+
217
+
218
+ def is_video_file(string_path):
219
+ return is_file_with_extensions(string_path, VIDEO_EXTENSIONS)
220
+
221
+
222
+ def is_audio_file(string_path):
223
+ return is_file_with_extensions(string_path, AUDIO_EXTENSIONS)
224
+
225
+
226
+ def is_subtitle_file(string_path):
227
+ return is_file_with_extensions(string_path, SUBTITLE_EXTENSIONS)
228
+
229
+
230
+ def get_directory_files(directory):
231
+ audio_files = []
232
+ video_files = []
233
+ sub_files = []
234
+
235
+ for item in os.listdir(directory):
236
+ item_path = os.path.join(directory, item)
237
+
238
+ if os.path.isfile(item_path):
239
+
240
+ if is_audio_file(item_path):
241
+ audio_files.append(item_path)
242
+
243
+ elif is_video_file(item_path):
244
+ video_files.append(item_path)
245
+
246
+ elif is_subtitle_file(item_path):
247
+ sub_files.append(item_path)
248
+
249
+ logger.info(
250
+ f"Files in path ({directory}): "
251
+ f"{str(audio_files + video_files + sub_files)}"
252
+ )
253
+
254
+ return audio_files, video_files, sub_files
255
+
256
+
257
+ def get_valid_files(paths):
258
+ valid_paths = []
259
+ for path in paths:
260
+ if os.path.isdir(path):
261
+ audio_files, video_files, sub_files = get_directory_files(path)
262
+ valid_paths.extend(audio_files)
263
+ valid_paths.extend(video_files)
264
+ valid_paths.extend(sub_files)
265
+ else:
266
+ valid_paths.append(path)
267
+
268
+ return valid_paths
269
+
270
+
271
+ def extract_video_links(link):
272
+
273
+ params_dlp = {"quiet": False, "no_warnings": True, "noplaylist": False}
274
+
275
+ try:
276
+ from yt_dlp import YoutubeDL
277
+ with capture.capture_output() as cap:
278
+ with YoutubeDL(params_dlp) as ydl:
279
+ info_dict = ydl.extract_info( # noqa
280
+ link, download=False, process=True
281
+ )
282
+
283
+ urls = re.findall(r'\[youtube\] Extracting URL: (.*?)\n', cap.stdout)
284
+ logger.info(f"List of videos in ({link}): {str(urls)}")
285
+ del cap
286
+ except Exception as error:
287
+ logger.error(f"{link} >> {str(error)}")
288
+ urls = [link]
289
+
290
+ return urls
291
+
292
+
293
+ def get_link_list(urls):
294
+ valid_links = []
295
+ for url_video in urls:
296
+ if "youtube.com" in url_video and "/watch?v=" not in url_video:
297
+ url_links = extract_video_links(url_video)
298
+ valid_links.extend(url_links)
299
+ else:
300
+ valid_links.append(url_video)
301
+ return valid_links
302
+
303
+ # =====================================
304
+ # Download Manager
305
+ # =====================================
306
+
307
+
308
+ def load_file_from_url(
309
+ url: str,
310
+ model_dir: str,
311
+ file_name: str | None = None,
312
+ overwrite: bool = False,
313
+ progress: bool = True,
314
+ ) -> str:
315
+ """Download a file from `url` into `model_dir`,
316
+ using the file present if possible.
317
+
318
+ Returns the path to the downloaded file.
319
+ """
320
+ os.makedirs(model_dir, exist_ok=True)
321
+ if not file_name:
322
+ parts = urlparse(url)
323
+ file_name = os.path.basename(parts.path)
324
+ cached_file = os.path.abspath(os.path.join(model_dir, file_name))
325
+
326
+ # Overwrite
327
+ if os.path.exists(cached_file):
328
+ if overwrite or os.path.getsize(cached_file) == 0:
329
+ remove_files(cached_file)
330
+
331
+ # Download
332
+ if not os.path.exists(cached_file):
333
+ logger.info(f'Downloading: "{url}" to {cached_file}\n')
334
+ from torch.hub import download_url_to_file
335
+
336
+ download_url_to_file(url, cached_file, progress=progress)
337
+ else:
338
+ logger.debug(cached_file)
339
+
340
+ return cached_file
341
+
342
+
343
+ def friendly_name(file: str):
344
+ if file.startswith("http"):
345
+ file = urlparse(file).path
346
+
347
+ file = os.path.basename(file)
348
+ model_name, extension = os.path.splitext(file)
349
+ return model_name, extension
350
+
351
+
352
+ def download_manager(
353
+ url: str,
354
+ path: str,
355
+ extension: str = "",
356
+ overwrite: bool = False,
357
+ progress: bool = True,
358
+ ):
359
+ url = url.strip()
360
+
361
+ name, ext = friendly_name(url)
362
+ name += ext if not extension else f".{extension}"
363
+
364
+ if url.startswith("http"):
365
+ filename = load_file_from_url(
366
+ url=url,
367
+ model_dir=path,
368
+ file_name=name,
369
+ overwrite=overwrite,
370
+ progress=progress,
371
+ )
372
+ else:
373
+ filename = path
374
+
375
+ return filename
376
+
377
+
378
+ # =====================================
379
+ # File management
380
+ # =====================================
381
+
382
+
383
+ # only remove files
384
+ def remove_files(file_list):
385
+ if isinstance(file_list, str):
386
+ file_list = [file_list]
387
+
388
+ for file in file_list:
389
+ if os.path.exists(file):
390
+ os.remove(file)
391
+
392
+
393
+ def remove_directory_contents(directory_path):
394
+ """
395
+ Removes all files and subdirectories within a directory.
396
+
397
+ Parameters:
398
+ directory_path (str): Path to the directory whose
399
+ contents need to be removed.
400
+ """
401
+ if os.path.exists(directory_path):
402
+ for filename in os.listdir(directory_path):
403
+ file_path = os.path.join(directory_path, filename)
404
+ try:
405
+ if os.path.isfile(file_path):
406
+ os.remove(file_path)
407
+ elif os.path.isdir(file_path):
408
+ shutil.rmtree(file_path)
409
+ except Exception as e:
410
+ logger.error(f"Failed to delete {file_path}. Reason: {e}")
411
+ logger.info(f"Content in '{directory_path}' removed.")
412
+ else:
413
+ logger.error(f"Directory '{directory_path}' does not exist.")
414
+
415
+
416
+ # Create directory if not exists
417
+ def create_directories(directory_path):
418
+ if isinstance(directory_path, str):
419
+ directory_path = [directory_path]
420
+ for one_dir_path in directory_path:
421
+ if not os.path.exists(one_dir_path):
422
+ os.makedirs(one_dir_path)
423
+ logger.debug(f"Directory '{one_dir_path}' created.")
424
+
425
+
426
+ def move_files(source_dir, destination_dir, extension=""):
427
+ """
428
+ Moves file(s) from the source path to the destination path.
429
+
430
+ Parameters:
431
+ source_dir (str): Path to the source directory.
432
+ destination_dir (str): Path to the destination directory.
433
+ extension (str): Only move files with this extension.
434
+ """
435
+ create_directories(destination_dir)
436
+
437
+ for filename in os.listdir(source_dir):
438
+ source_path = os.path.join(source_dir, filename)
439
+ destination_path = os.path.join(destination_dir, filename)
440
+ if extension and not filename.endswith(extension):
441
+ continue
442
+ os.replace(source_path, destination_path)
443
+
444
+
445
+ def copy_files(source_path, destination_path):
446
+ """
447
+ Copies a file or multiple files from a source path to a destination path.
448
+
449
+ Parameters:
450
+ source_path (str or list): Path or list of paths to the source
451
+ file(s) or directory.
452
+ destination_path (str): Path to the destination directory.
453
+ """
454
+ create_directories(destination_path)
455
+
456
+ if isinstance(source_path, str):
457
+ source_path = [source_path]
458
+
459
+ if os.path.isdir(source_path[0]):
460
+ # Copy all files from the source directory to the destination directory
461
+ base_path = source_path[0]
462
+ source_path = os.listdir(source_path[0])
463
+ source_path = [
464
+ os.path.join(base_path, file_name) for file_name in source_path
465
+ ]
466
+
467
+ for one_source_path in source_path:
468
+ if os.path.exists(one_source_path):
469
+ shutil.copy2(one_source_path, destination_path)
470
+ logger.debug(
471
+ f"File '{one_source_path}' copied to '{destination_path}'."
472
+ )
473
+ else:
474
+ logger.error(f"File '{one_source_path}' does not exist.")
475
+
476
+
477
+ def rename_file(current_name, new_name):
478
+ file_directory = os.path.dirname(current_name)
479
+
480
+ if os.path.exists(current_name):
481
+ dir_new_name_file = os.path.join(file_directory, new_name)
482
+ os.rename(current_name, dir_new_name_file)
483
+ logger.debug(f"File '{current_name}' renamed to '{new_name}'.")
484
+ return dir_new_name_file
485
+ else:
486
+ logger.error(f"File '{current_name}' does not exist.")
487
+ return None
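A minimal usage sketch for the download helpers above (not part of the commit; the URL and file names are placeholders):

# Hypothetical example: fetch a model file, then extract audio with ffmpeg.
create_directories(["downloads", "logs", "weights"])
model_path = download_manager(
    url="https://example.com/models/voice.pth",
    path="downloads",
    overwrite=False,
    progress=True,
)
run_command("ffmpeg -y -i video.mp4 -vn -acodec pcm_s16le audio.wav")

download_manager() reuses an existing non-empty file unless overwrite=True, and run_command() raises an Exception with the captured stderr when the subprocess exits with a non-zero code.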