A Le Thanh Son commited on
Commit
6d75162
·
1 Parent(s): d1fc75f
Files changed (9) hide show
  1. .gitignore +54 -0
  2. README.md +94 -4
  3. app.py +294 -0
  4. generator.py +186 -0
  5. hf_requirements.txt +11 -0
  6. models.py +203 -0
  7. requirements.txt +11 -0
  8. test_model.py +73 -0
  9. watermarking.py +79 -0
.gitignore ADDED
@@ -0,0 +1,54 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Python
2
+ __pycache__/
3
+ *.py[cod]
4
+ *$py.class
5
+ *.so
6
+ .Python
7
+ build/
8
+ develop-eggs/
9
+ dist/
10
+ downloads/
11
+ eggs/
12
+ .eggs/
13
+ lib/
14
+ lib64/
15
+ parts/
16
+ sdist/
17
+ var/
18
+ wheels/
19
+ *.egg-info/
20
+ .installed.cfg
21
+ *.egg
22
+
23
+ # Environments
24
+ .env
25
+ .venv
26
+ env/
27
+ venv/
28
+ ENV/
29
+ env.bak/
30
+ venv.bak/
31
+
32
+ # Jupyter Notebook
33
+ .ipynb_checkpoints
34
+
35
+ # PyCharm
36
+ .idea/
37
+
38
+ # VS Code
39
+ .vscode/
40
+
41
+ # Temporary files
42
+ *.tmp
43
+ *.wav
44
+ *.mp3
45
+ *.ogg
46
+ temp/
47
+
48
+ # Logs
49
+ logs/
50
+ *.log
51
+
52
+ # HuggingFace
53
+ .cache/
54
+ huggingface/
README.md CHANGED
@@ -1,12 +1,102 @@
1
  ---
2
- title: Csm 1b Gradio V2
3
- emoji: 📊
4
  colorFrom: indigo
5
  colorTo: green
6
  sdk: gradio
7
- sdk_version: 5.21.0
8
  app_file: app.py
9
  pinned: false
10
  ---
11
 
12
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  ---
2
+ title: CSM-1B Gradio Demo
3
+ emoji: 🎙️
4
  colorFrom: indigo
5
  colorTo: green
6
  sdk: gradio
7
+ sdk_version: 4.26.0
8
  app_file: app.py
9
  pinned: false
10
  ---
11
 
12
+ # CSM-1B Text-to-Speech Demo
13
+
14
+ Ứng dụng này sử dụng mô hình CSM-1B (Collaborative Speech Model) để chuyển đổi văn bản thành giọng nói với chất lượng cao.
15
+
16
+ ## Tính năng
17
+
18
+ - **Tạo âm thanh đơn giản**: Chuyển đổi văn bản thành giọng nói với các tùy chọn về ID người nói, thời lượng, temperature và top-k.
19
+ - **Tạo âm thanh với ngữ cảnh**: Cung cấp các đoạn âm thanh và văn bản làm ngữ cảnh để mô hình tạo ra âm thanh phù hợp hơn.
20
+ - **Tối ưu GPU**: Sử dụng ZeroGPU của Hugging Face Spaces để tối ưu việc sử dụng GPU.
21
+
22
+ ## Cài đặt và Cấu hình
23
+
24
+ ### Yêu cầu truy cập
25
+
26
+ Để sử dụng mô hình CSM-1B, bạn cần có quyền truy cập vào các mô hình sau trên Hugging Face:
27
+
28
+ - [meta-llama/Llama-3.2-1B](https://huggingface.co/meta-llama/Llama-3.2-1B)
29
+ - [sesame/csm-1b](https://huggingface.co/sesame/csm-1b)
30
+
31
+ ### Cấu hình Hugging Face Token
32
+
33
+ 1. Tạo tài khoản Hugging Face nếu bạn chưa có.
34
+ 2. Truy cập vào [Hugging Face Settings](https://huggingface.co/settings/tokens) để tạo token.
35
+ 3. Yêu cầu quyền truy cập vào các mô hình nếu cần.
36
+ 4. Đặt biến môi trường `HF_TOKEN` với giá trị là token của bạn:
37
+ ```bash
38
+ export HF_TOKEN=your_token_here
39
+ ```
40
+ 5. Hoặc bạn có thể nhập token trực tiếp trong tab "Cấu hình" của ứng dụng.
41
+
42
+ ### Cài đặt
43
+
44
+ ```bash
45
+ git clone https://github.com/yourusername/csm-1b-gradio.git
46
+ cd csm-1b-gradio
47
+ pip install -r requirements.txt
48
+ ```
49
+
50
+ ## Cách sử dụng
51
+
52
+ 1. Khởi động ứng dụng:
53
+ ```bash
54
+ python app.py
55
+ ```
56
+ 2. Mở trình duyệt web và truy cập địa chỉ được hiển thị (thường là http://127.0.0.1:7860).
57
+ 3. Nhập văn bản bạn muốn chuyển thành giọng nói.
58
+ 4. Chọn ID người nói (từ 0-10).
59
+ 5. Điều chỉnh các tham số như thời lượng tối đa, temperature và top-k.
60
+ 6. Nhấn nút "Tạo âm thanh" để tạo giọng nói.
61
+
62
+ ## Thông tin về mô hình
63
+
64
+ CSM-1B là một mô hình text-to-speech tiên tiến được phát triển bởi Sesame AI Labs. Mô hình này có khả năng tạo giọng nói tự nhiên từ văn bản với nhiều giọng nói khác nhau.
65
+
66
+ ## ZeroGPU
67
+
68
+ Ứng dụng này sử dụng ZeroGPU của Hugging Face Spaces để tối ưu việc sử dụng GPU. ZeroGPU giúp giải phóng bộ nhớ GPU khi không sử dụng, giúp tiết kiệm tài nguyên và cải thiện hiệu suất.
69
+
70
+ ```python
71
+ import spaces
72
+
73
+ @spaces.GPU
74
+ def my_gpu_function():
75
+ # Hàm này sẽ chỉ sử dụng GPU khi được gọi
76
+ # và giải phóng GPU sau khi hoàn thành
77
+ pass
78
+ ```
79
+
80
+ Khi triển khai trên Hugging Face Spaces, ZeroGPU sẽ tự động quản lý việc sử dụng GPU, giúp ứng dụng hoạt động hiệu quả hơn.
81
+
82
+ ## Lưu ý
83
+
84
+ - Mô hình này sử dụng watermarking để đánh dấu âm thanh được tạo ra bởi AI.
85
+ - Thời gian tạo âm thanh phụ thuộc vào độ dài văn bản và cấu hình phần cứng.
86
+ - Bạn cần có quyền truy cập vào mô hình CSM-1B trên Hugging Face để sử dụng ứng dụng này.
87
+
88
+ ## Triển khai trên Hugging Face Spaces
89
+
90
+ Để triển khai ứng dụng này trên Hugging Face Spaces:
91
+
92
+ 1. Tạo một Space mới trên Hugging Face với SDK là Gradio.
93
+ 2. Tải lên tất cả các file của dự án.
94
+ 3. Trong phần cài đặt của Space, thêm biến môi trường `HF_TOKEN` với giá trị là token của bạn.
95
+ 4. Chọn cấu hình phần cứng phù hợp (khuyến nghị sử dụng GPU).
96
+
97
+ ## Tài nguyên
98
+
99
+ - [GitHub Repository](https://github.com/SesameAILabs/csm-1b)
100
+ - [Hugging Face Model](https://huggingface.co/sesame/csm-1b)
101
+ - [Hugging Face Space Demo](https://huggingface.co/spaces/sesame/csm-1b)
102
+ - [Hugging Face Spaces ZeroGPU](https://huggingface.co/docs/hub/spaces-sdks-docker-zero-gpu)
app.py ADDED
@@ -0,0 +1,294 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import tempfile
3
+ import time
4
+ from typing import List, Tuple
5
+
6
+ import gradio as gr
7
+ import torch
8
+ import torchaudio
9
+ import spaces
10
+ from dataclasses import dataclass
11
+ from generator import Segment, load_csm_1b
12
+ from huggingface_hub import login
13
+
14
+ # Kiểm tra xem có GPU không và cấu hình thiết bị phù hợp
15
+ device = "cuda" if torch.cuda.is_available() else "cpu"
16
+ print(f"Sử dụng thiết bị: {device}")
17
+
18
+ # Đăng nhập vào Hugging Face Hub nếu có token
19
+ def login_huggingface():
20
+ hf_token = os.environ.get("HF_TOKEN")
21
+ if hf_token:
22
+ print("Đang đăng nhập vào Hugging Face Hub...")
23
+ login(token=hf_token)
24
+ print("Đã đăng nhập thành công!")
25
+ else:
26
+ print("Không tìm thấy HF_TOKEN trong biến môi trường. Một số mô hình có thể không truy cập được.")
27
+
28
+ # Đăng nhập khi khởi động
29
+ login_huggingface()
30
+
31
+ # Tải mô hình CSM-1B
32
+ generator = None
33
+
34
+ def load_model():
35
+ global generator
36
+ if generator is None:
37
+ print("Đang tải mô hình CSM-1B...")
38
+ generator = load_csm_1b(device=device)
39
+ print("Đã tải xong mô hình!")
40
+ return generator
41
+
42
+ # Hàm chuyển đổi âm thanh thành tensor
43
+ def audio_to_tensor(audio_path: str) -> Tuple[torch.Tensor, int]:
44
+ waveform, sample_rate = torchaudio.load(audio_path)
45
+ waveform = waveform.mean(dim=0) # Chuyển stereo thành mono nếu cần
46
+ return waveform, sample_rate
47
+
48
+ # Hàm lưu tensor âm thanh thành file
49
+ def save_audio(audio_tensor: torch.Tensor, sample_rate: int) -> str:
50
+ temp_dir = tempfile.gettempdir()
51
+ output_path = os.path.join(temp_dir, f"csm1b_output_{int(time.time())}.wav")
52
+ torchaudio.save(output_path, audio_tensor.unsqueeze(0), sample_rate)
53
+ return output_path
54
+
55
+ # Hàm tạo âm thanh từ văn bản sử dụng ZeroGPU
56
+ @spaces.GPU
57
+ def generate_speech(
58
+ text: str,
59
+ speaker_id: int,
60
+ context_audio_files: List[Tuple[str, str, int]],
61
+ max_duration_ms: float = 30000,
62
+ temperature: float = 0.9,
63
+ top_k: int = 50,
64
+ progress=gr.Progress()
65
+ ) -> str:
66
+ # Tải mô hình nếu chưa tải
67
+ generator = load_model()
68
+
69
+ # Chuẩn bị ngữ cảnh (context)
70
+ context = []
71
+ progress(0.1, "Đang xử lý ngữ cảnh...")
72
+
73
+ for audio_file, text_content, speaker in context_audio_files:
74
+ if audio_file and text_content:
75
+ waveform, sample_rate = audio_to_tensor(audio_file)
76
+ # Resample nếu cần
77
+ if sample_rate != generator.sample_rate:
78
+ waveform = torchaudio.functional.resample(waveform, orig_freq=sample_rate, new_freq=generator.sample_rate)
79
+ context.append(Segment(speaker=speaker, text=text_content, audio=waveform))
80
+
81
+ progress(0.3, "Đang tạo âm thanh...")
82
+ # Tạo âm thanh từ văn bản
83
+ audio = generator.generate(
84
+ text=text,
85
+ speaker=speaker_id,
86
+ context=context,
87
+ max_audio_length_ms=max_duration_ms,
88
+ temperature=temperature,
89
+ topk=top_k
90
+ )
91
+
92
+ progress(0.8, "Đang lưu âm thanh...")
93
+ # Lưu âm thanh thành file
94
+ output_path = save_audio(audio, generator.sample_rate)
95
+
96
+ progress(1.0, "Hoàn thành!")
97
+ return output_path
98
+
99
+ # Tạo giao diện Gradio
100
+ def create_demo():
101
+ with gr.Blocks(title="CSM-1B Text-to-Speech") as demo:
102
+ gr.Markdown("# CSM-1B Text-to-Speech Demo")
103
+ gr.Markdown("Mô hình CSM-1B (Collaborative Speech Model) là một mô hình text-to-speech tiên tiến có khả năng tạo giọng nói tự nhiên từ văn bản.")
104
+
105
+ with gr.Tab("Tạo âm thanh đơn giản"):
106
+ with gr.Row():
107
+ with gr.Column():
108
+ text_input = gr.Textbox(
109
+ label="Văn bản cần chuyển thành giọng nói",
110
+ placeholder="Nhập văn bản bạn muốn chuyển thành giọng nói...",
111
+ lines=5
112
+ )
113
+ speaker_id = gr.Number(
114
+ label="ID người nói",
115
+ value=0,
116
+ precision=0,
117
+ minimum=0,
118
+ maximum=10
119
+ )
120
+
121
+ with gr.Row():
122
+ max_duration = gr.Slider(
123
+ label="Thời lượng tối đa (ms)",
124
+ minimum=1000,
125
+ maximum=90000,
126
+ value=30000,
127
+ step=1000
128
+ )
129
+ temperature = gr.Slider(
130
+ label="Temperature",
131
+ minimum=0.1,
132
+ maximum=1.5,
133
+ value=0.9,
134
+ step=0.1
135
+ )
136
+ top_k = gr.Slider(
137
+ label="Top-K",
138
+ minimum=1,
139
+ maximum=100,
140
+ value=50,
141
+ step=1
142
+ )
143
+
144
+ generate_btn = gr.Button("Tạo âm thanh")
145
+
146
+ with gr.Column():
147
+ output_audio = gr.Audio(label="Âm thanh đầu ra", type="filepath")
148
+
149
+ with gr.Tab("Tạo âm thanh với ngữ cảnh"):
150
+ gr.Markdown("Tính năng này cho phép bạn cung cấp các đoạn âm thanh và văn bản làm ngữ cảnh để mô hình tạo ra âm thanh phù hợp hơn.")
151
+
152
+ with gr.Row():
153
+ with gr.Column():
154
+ context_text1 = gr.Textbox(label="Văn bản ngữ cảnh 1", lines=2)
155
+ context_audio1 = gr.Audio(label="Âm thanh ngữ cảnh 1", type="filepath")
156
+ context_speaker1 = gr.Number(label="ID người nói 1", value=0, precision=0)
157
+
158
+ context_text2 = gr.Textbox(label="Văn bản ngữ cảnh 2", lines=2)
159
+ context_audio2 = gr.Audio(label="Âm thanh ngữ cảnh 2", type="filepath")
160
+ context_speaker2 = gr.Number(label="ID người nói 2", value=1, precision=0)
161
+
162
+ text_input_context = gr.Textbox(
163
+ label="Văn bản cần chuyển thành giọng nói",
164
+ placeholder="Nhập văn bản bạn muốn chuyển thành giọng nói...",
165
+ lines=3
166
+ )
167
+ speaker_id_context = gr.Number(
168
+ label="ID người nói",
169
+ value=0,
170
+ precision=0
171
+ )
172
+
173
+ with gr.Row():
174
+ max_duration_context = gr.Slider(
175
+ label="Thời lượng tối đa (ms)",
176
+ minimum=1000,
177
+ maximum=90000,
178
+ value=30000,
179
+ step=1000
180
+ )
181
+ temperature_context = gr.Slider(
182
+ label="Temperature",
183
+ minimum=0.1,
184
+ maximum=1.5,
185
+ value=0.9,
186
+ step=0.1
187
+ )
188
+ top_k_context = gr.Slider(
189
+ label="Top-K",
190
+ minimum=1,
191
+ maximum=100,
192
+ value=50,
193
+ step=1
194
+ )
195
+
196
+ generate_context_btn = gr.Button("Tạo âm thanh với ngữ cảnh")
197
+
198
+ with gr.Column():
199
+ output_audio_context = gr.Audio(label="Âm thanh đầu ra", type="filepath")
200
+
201
+ # Thêm tab cấu hình Hugging Face
202
+ with gr.Tab("Cấu hình"):
203
+ gr.Markdown("### Cấu hình Hugging Face Token")
204
+ gr.Markdown("""
205
+ Để sử dụng mô hình CSM-1B, bạn cần có quyền truy cập vào mô hình trên Hugging Face.
206
+
207
+ Bạn có thể cấu hình token của mình bằng cách:
208
+ 1. Tạo token tại [Hugging Face Settings](https://huggingface.co/settings/tokens)
209
+ 2. Đặt biến môi trường `HF_TOKEN` với giá trị là token của bạn
210
+
211
+ Lưu ý: Trong Hugging Face Spaces, bạn có thể đặt biến môi trường trong phần Cài đặt của Space.
212
+ """)
213
+
214
+ hf_token_input = gr.Textbox(
215
+ label="Hugging Face Token (Chỉ sử dụng trong phiên này)",
216
+ placeholder="Nhập token của bạn...",
217
+ type="password"
218
+ )
219
+
220
+ def set_token(token):
221
+ if token:
222
+ os.environ["HF_TOKEN"] = token
223
+ login(token=token)
224
+ return "Đã đặt token thành công! Bạn có thể tải mô hình bây giờ."
225
+ return "Token không hợp lệ. Vui lòng nhập token hợp lệ."
226
+
227
+ set_token_btn = gr.Button("Đặt Token")
228
+ token_status = gr.Textbox(label="Trạng thái", interactive=False)
229
+
230
+ set_token_btn.click(fn=set_token, inputs=hf_token_input, outputs=token_status)
231
+
232
+ # Thêm tab thông tin về ZeroGPU
233
+ with gr.Tab("Thông tin GPU"):
234
+ gr.Markdown("### Thông tin về ZeroGPU")
235
+ gr.Markdown("""
236
+ Ứng dụng này sử dụng ZeroGPU của Hugging Face Spaces để tối ưu việc sử dụng GPU.
237
+
238
+ ZeroGPU giúp giải phóng bộ nhớ GPU khi không sử dụng, giúp tiết kiệm tài nguyên và cải thiện hiệu suất.
239
+
240
+ Khi bạn tạo âm thanh, GPU sẽ được sử dụng tự động và giải phóng sau khi hoàn thành.
241
+ """)
242
+
243
+ def check_gpu():
244
+ if torch.cuda.is_available():
245
+ gpu_name = torch.cuda.get_device_name(0)
246
+ gpu_memory = torch.cuda.get_device_properties(0).total_memory / (1024**3)
247
+ return f"GPU: {gpu_name}\nBộ nhớ: {gpu_memory:.2f} GB"
248
+ else:
249
+ return "Không tìm thấy GPU. Ứng dụng sẽ chạy trên CPU."
250
+
251
+ check_gpu_btn = gr.Button("Kiểm tra GPU")
252
+ gpu_info = gr.Textbox(label="Thông tin GPU", interactive=False)
253
+
254
+ check_gpu_btn.click(fn=check_gpu, inputs=None, outputs=gpu_info)
255
+
256
+ # Kết nối các thành phần
257
+ generate_btn.click(
258
+ fn=generate_speech,
259
+ inputs=[
260
+ text_input,
261
+ speaker_id,
262
+ gr.State([]), # Không có ngữ cảnh
263
+ max_duration,
264
+ temperature,
265
+ top_k
266
+ ],
267
+ outputs=output_audio
268
+ )
269
+
270
+ generate_context_btn.click(
271
+ fn=generate_speech,
272
+ inputs=[
273
+ text_input_context,
274
+ speaker_id_context,
275
+ gr.State([
276
+ (context_audio1, context_text1, context_speaker1),
277
+ (context_audio2, context_text2, context_speaker2)
278
+ ]),
279
+ max_duration_context,
280
+ temperature_context,
281
+ top_k_context
282
+ ],
283
+ outputs=output_audio_context
284
+ )
285
+
286
+ # Tải mô hình khi khởi động
287
+ demo.load(fn=load_model)
288
+
289
+ return demo
290
+
291
+ # Khởi chạy ứng dụng
292
+ if __name__ == "__main__":
293
+ demo = create_demo()
294
+ demo.queue().launch()
generator.py ADDED
@@ -0,0 +1,186 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from dataclasses import dataclass
2
+ from typing import List, Tuple
3
+
4
+ import torch
5
+ import torchaudio
6
+ from huggingface_hub import hf_hub_download, login
7
+ from models import Model
8
+ from moshi.models import loaders
9
+ from tokenizers.processors import TemplateProcessing
10
+ from transformers import AutoTokenizer
11
+ from watermarking import CSM_1B_GH_WATERMARK, load_watermarker, watermark
12
+
13
+
14
+ @dataclass
15
+ class Segment:
16
+ speaker: int
17
+ text: str
18
+ # (num_samples,), sample_rate = 24_000
19
+ audio: torch.Tensor
20
+
21
+
22
+ def load_llama3_tokenizer():
23
+ """
24
+ https://github.com/huggingface/transformers/issues/22794#issuecomment-2092623992
25
+ """
26
+ tokenizer_name = "meta-llama/Llama-3.2-1B"
27
+ tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
28
+ bos = tokenizer.bos_token
29
+ eos = tokenizer.eos_token
30
+ tokenizer._tokenizer.post_processor = TemplateProcessing(
31
+ single=f"{bos}:0 $A:0 {eos}:0",
32
+ pair=f"{bos}:0 $A:0 {eos}:0 {bos}:1 $B:1 {eos}:1",
33
+ special_tokens=[(f"{bos}", tokenizer.bos_token_id), (f"{eos}", tokenizer.eos_token_id)],
34
+ )
35
+
36
+ return tokenizer
37
+
38
+
39
+ class Generator:
40
+ def __init__(
41
+ self,
42
+ model: Model,
43
+ ):
44
+ self._model = model
45
+ self._model.setup_caches(1)
46
+
47
+ self._text_tokenizer = load_llama3_tokenizer()
48
+
49
+ device = next(model.parameters()).device
50
+ mimi_weight = hf_hub_download(loaders.DEFAULT_REPO, loaders.MIMI_NAME)
51
+ mimi = loaders.get_mimi(mimi_weight, device=device)
52
+ mimi.set_num_codebooks(32)
53
+ self._audio_tokenizer = mimi
54
+
55
+ self._watermarker = load_watermarker(device=device)
56
+
57
+ self.sample_rate = mimi.sample_rate
58
+ self.device = device
59
+
60
+ def _tokenize_text_segment(self, text: str, speaker: int) -> Tuple[torch.Tensor, torch.Tensor]:
61
+ frame_tokens = []
62
+ frame_masks = []
63
+
64
+ text_tokens = self._text_tokenizer.encode(f"[{speaker}]{text}")
65
+ text_frame = torch.zeros(len(text_tokens), 33).long()
66
+ text_frame_mask = torch.zeros(len(text_tokens), 33).bool()
67
+ text_frame[:, -1] = torch.tensor(text_tokens)
68
+ text_frame_mask[:, -1] = True
69
+
70
+ frame_tokens.append(text_frame.to(self.device))
71
+ frame_masks.append(text_frame_mask.to(self.device))
72
+
73
+ return torch.cat(frame_tokens, dim=0), torch.cat(frame_masks, dim=0)
74
+
75
+ def _tokenize_audio(self, audio: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
76
+ frame_tokens = []
77
+ frame_masks = []
78
+
79
+ # (K, T)
80
+ audio = audio.to(self.device)
81
+ audio_tokens = self._audio_tokenizer.encode(audio.unsqueeze(0).unsqueeze(0))[0]
82
+ # add EOS frame
83
+ eos_frame = torch.zeros(audio_tokens.size(0), 1).to(self.device)
84
+ audio_tokens = torch.cat([audio_tokens, eos_frame], dim=1)
85
+
86
+ audio_frame = torch.zeros(audio_tokens.size(1), 33).long().to(self.device)
87
+ audio_frame_mask = torch.zeros(audio_tokens.size(1), 33).bool().to(self.device)
88
+ audio_frame[:, :-1] = audio_tokens.transpose(0, 1)
89
+ audio_frame_mask[:, :-1] = True
90
+
91
+ frame_tokens.append(audio_frame)
92
+ frame_masks.append(audio_frame_mask)
93
+
94
+ return torch.cat(frame_tokens, dim=0), torch.cat(frame_masks, dim=0)
95
+
96
+ def _tokenize_segment(self, segment: Segment) -> Tuple[torch.Tensor, torch.Tensor]:
97
+ """
98
+ Returns:
99
+ (seq_len, 33), (seq_len, 33)
100
+ """
101
+ text_tokens, text_masks = self._tokenize_text_segment(segment.text, segment.speaker)
102
+ audio_tokens, audio_masks = self._tokenize_audio(segment.audio)
103
+
104
+ return torch.cat([text_tokens, audio_tokens], dim=0), torch.cat([text_masks, audio_masks], dim=0)
105
+
106
+ @torch.inference_mode()
107
+ def generate(
108
+ self,
109
+ text: str,
110
+ speaker: int,
111
+ context: List[Segment],
112
+ max_audio_length_ms: float = 90_000,
113
+ temperature: float = 0.9,
114
+ topk: int = 50,
115
+ ) -> torch.Tensor:
116
+ self._model.reset_caches()
117
+
118
+ max_audio_frames = int(max_audio_length_ms / 80)
119
+ tokens, tokens_mask = [], []
120
+ for segment in context:
121
+ segment_tokens, segment_tokens_mask = self._tokenize_segment(segment)
122
+ tokens.append(segment_tokens)
123
+ tokens_mask.append(segment_tokens_mask)
124
+
125
+ gen_segment_tokens, gen_segment_tokens_mask = self._tokenize_text_segment(text, speaker)
126
+ tokens.append(gen_segment_tokens)
127
+ tokens_mask.append(gen_segment_tokens_mask)
128
+
129
+ prompt_tokens = torch.cat(tokens, dim=0).long().to(self.device)
130
+ prompt_tokens_mask = torch.cat(tokens_mask, dim=0).bool().to(self.device)
131
+
132
+ samples = []
133
+ curr_tokens = prompt_tokens.unsqueeze(0)
134
+ curr_tokens_mask = prompt_tokens_mask.unsqueeze(0)
135
+ curr_pos = torch.arange(0, prompt_tokens.size(0)).unsqueeze(0).long().to(self.device)
136
+
137
+ max_seq_len = 2048 - max_audio_frames
138
+ if curr_tokens.size(1) >= max_seq_len:
139
+ raise ValueError(f"Inputs too long, must be below max_seq_len - max_audio_frames: {max_seq_len}")
140
+
141
+ for _ in range(max_audio_frames):
142
+ sample = self._model.generate_frame(curr_tokens, curr_tokens_mask, curr_pos, temperature, topk)
143
+ if torch.all(sample == 0):
144
+ break # eos
145
+
146
+ samples.append(sample)
147
+
148
+ curr_tokens = torch.cat([sample, torch.zeros(1, 1).long().to(self.device)], dim=1).unsqueeze(1)
149
+ curr_tokens_mask = torch.cat(
150
+ [torch.ones_like(sample).bool(), torch.zeros(1, 1).bool().to(self.device)], dim=1
151
+ ).unsqueeze(1)
152
+ curr_pos = curr_pos[:, -1:] + 1
153
+
154
+ audio = self._audio_tokenizer.decode(torch.stack(samples).permute(1, 2, 0)).squeeze(0).squeeze(0)
155
+
156
+ # This applies an imperceptible watermark to identify audio as AI-generated.
157
+ # Watermarking ensures transparency, dissuades misuse, and enables traceability.
158
+ # Please be a responsible AI citizen and keep the watermarking in place.
159
+ # If using CSM 1B in another application, use your own private key and keep it secret.
160
+ audio, wm_sample_rate = watermark(self._watermarker, audio, self.sample_rate, CSM_1B_GH_WATERMARK)
161
+ audio = torchaudio.functional.resample(audio, orig_freq=wm_sample_rate, new_freq=self.sample_rate)
162
+
163
+ return audio
164
+
165
+
166
+ def load_csm_1b(device: str = "cuda") -> Generator:
167
+ """
168
+ Tải mô hình CSM-1B từ Hugging Face Hub.
169
+
170
+ Args:
171
+ device: Thiết bị để chạy mô hình (cuda hoặc cpu)
172
+
173
+ Returns:
174
+ Generator: Đối tượng Generator để tạo âm thanh từ văn bản
175
+ """
176
+ try:
177
+ model = Model.from_pretrained("sesame/csm-1b")
178
+ model.to(device=device, dtype=torch.bfloat16)
179
+
180
+ generator = Generator(model)
181
+ return generator
182
+ except Exception as e:
183
+ print(f"Lỗi khi tải mô hình: {e}")
184
+ print("Vui lòng kiểm tra xem bạn đã đăng nhập vào Hugging Face Hub chưa.")
185
+ print("Bạn có thể cần phải yêu cầu quyền truy cập vào mô hình tại: https://huggingface.co/sesame/csm-1b")
186
+ raise e
hf_requirements.txt ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ torch>=2.0.0
2
+ torchaudio>=2.0.0
3
+ tokenizers>=0.13.0
4
+ transformers>=4.30.0
5
+ huggingface_hub>=0.16.0
6
+ moshi>=0.2.2
7
+ torchtune>=0.4.0
8
+ torchao>=0.9.0
9
+ silentcipher @ git+https://github.com/SesameAILabs/silentcipher@master
10
+ gradio>=4.13.0
11
+ huggingface-hub-spaces>=0.19.0
models.py ADDED
@@ -0,0 +1,203 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from dataclasses import dataclass
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ import torchtune
6
+ from huggingface_hub import PyTorchModelHubMixin
7
+ from torchtune.models import llama3_2
8
+
9
+
10
+ def llama3_2_1B() -> torchtune.modules.transformer.TransformerDecoder:
11
+ return llama3_2.llama3_2(
12
+ vocab_size=128_256,
13
+ num_layers=16,
14
+ num_heads=32,
15
+ num_kv_heads=8,
16
+ embed_dim=2048,
17
+ max_seq_len=2048,
18
+ intermediate_dim=8192,
19
+ attn_dropout=0.0,
20
+ norm_eps=1e-5,
21
+ rope_base=500_000,
22
+ scale_factor=32,
23
+ )
24
+
25
+
26
+ def llama3_2_100M() -> torchtune.modules.transformer.TransformerDecoder:
27
+ return llama3_2.llama3_2(
28
+ vocab_size=128_256,
29
+ num_layers=4,
30
+ num_heads=8,
31
+ num_kv_heads=2,
32
+ embed_dim=1024,
33
+ max_seq_len=2048,
34
+ intermediate_dim=8192,
35
+ attn_dropout=0.0,
36
+ norm_eps=1e-5,
37
+ rope_base=500_000,
38
+ scale_factor=32,
39
+ )
40
+
41
+
42
+ FLAVORS = {
43
+ "llama-1B": llama3_2_1B,
44
+ "llama-100M": llama3_2_100M,
45
+ }
46
+
47
+
48
+ def _prepare_transformer(model):
49
+ embed_dim = model.tok_embeddings.embedding_dim
50
+ model.tok_embeddings = nn.Identity()
51
+ model.output = nn.Identity()
52
+ return model, embed_dim
53
+
54
+
55
+ def _create_causal_mask(seq_len: int, device: torch.device):
56
+ return torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool, device=device))
57
+
58
+
59
+ def _index_causal_mask(mask: torch.Tensor, input_pos: torch.Tensor):
60
+ """
61
+ Args:
62
+ mask: (max_seq_len, max_seq_len)
63
+ input_pos: (batch_size, seq_len)
64
+
65
+ Returns:
66
+ (batch_size, seq_len, max_seq_len)
67
+ """
68
+ r = mask[input_pos, :]
69
+ return r
70
+
71
+
72
+ def _multinomial_sample_one_no_sync(probs): # Does multinomial sampling without a cuda synchronization
73
+ q = torch.empty_like(probs).exponential_(1)
74
+ return torch.argmax(probs / q, dim=-1, keepdim=True).to(dtype=torch.int)
75
+
76
+
77
+ def sample_topk(logits: torch.Tensor, topk: int, temperature: float):
78
+ logits = logits / temperature
79
+
80
+ filter_value: float = -float("Inf")
81
+ indices_to_remove = logits < torch.topk(logits, topk)[0][..., -1, None]
82
+ scores_processed = logits.masked_fill(indices_to_remove, filter_value)
83
+ scores_processed = torch.nn.functional.log_softmax(scores_processed, dim=-1)
84
+ probs = torch.nn.functional.softmax(scores_processed, dim=-1)
85
+
86
+ sample_token = _multinomial_sample_one_no_sync(probs)
87
+ return sample_token
88
+
89
+
90
+ @dataclass
91
+ class ModelArgs:
92
+ backbone_flavor: str
93
+ decoder_flavor: str
94
+ text_vocab_size: int
95
+ audio_vocab_size: int
96
+ audio_num_codebooks: int
97
+
98
+
99
+ class Model(
100
+ nn.Module,
101
+ PyTorchModelHubMixin,
102
+ repo_url="https://github.com/SesameAILabs/csm",
103
+ pipeline_tag="text-to-speech",
104
+ license="apache-2.0",
105
+ ):
106
+ def __init__(self, config: ModelArgs):
107
+ super().__init__()
108
+ self.config = config
109
+
110
+ self.backbone, backbone_dim = _prepare_transformer(FLAVORS[config.backbone_flavor]())
111
+ self.decoder, decoder_dim = _prepare_transformer(FLAVORS[config.decoder_flavor]())
112
+
113
+ self.text_embeddings = nn.Embedding(config.text_vocab_size, backbone_dim)
114
+ self.audio_embeddings = nn.Embedding(config.audio_vocab_size * config.audio_num_codebooks, backbone_dim)
115
+
116
+ self.projection = nn.Linear(backbone_dim, decoder_dim, bias=False)
117
+ self.codebook0_head = nn.Linear(backbone_dim, config.audio_vocab_size, bias=False)
118
+ self.audio_head = nn.Parameter(torch.empty(config.audio_num_codebooks - 1, decoder_dim, config.audio_vocab_size))
119
+
120
+ def setup_caches(self, max_batch_size: int) -> torch.Tensor:
121
+ """Setup KV caches and return a causal mask."""
122
+ dtype = next(self.parameters()).dtype
123
+ device = next(self.parameters()).device
124
+
125
+ with device:
126
+ self.backbone.setup_caches(max_batch_size, dtype)
127
+ self.decoder.setup_caches(max_batch_size, dtype, decoder_max_seq_len=self.config.audio_num_codebooks)
128
+
129
+ self.register_buffer("backbone_causal_mask", _create_causal_mask(self.backbone.max_seq_len, device))
130
+ self.register_buffer("decoder_causal_mask", _create_causal_mask(self.config.audio_num_codebooks, device))
131
+
132
+ def generate_frame(
133
+ self,
134
+ tokens: torch.Tensor,
135
+ tokens_mask: torch.Tensor,
136
+ input_pos: torch.Tensor,
137
+ temperature: float,
138
+ topk: int,
139
+ ) -> torch.Tensor:
140
+ """
141
+ Args:
142
+ tokens: (batch_size, seq_len, audio_num_codebooks+1)
143
+ tokens_mask: (batch_size, seq_len, audio_num_codebooks+1)
144
+ input_pos: (batch_size, seq_len) positions for each token
145
+ mask: (batch_size, seq_len, max_seq_len
146
+
147
+ Returns:
148
+ (batch_size, audio_num_codebooks) sampled tokens
149
+ """
150
+ dtype = next(self.parameters()).dtype
151
+ b, s, _ = tokens.size()
152
+
153
+ assert self.backbone.caches_are_enabled(), "backbone caches are not enabled"
154
+ curr_backbone_mask = _index_causal_mask(self.backbone_causal_mask, input_pos)
155
+ embeds = self._embed_tokens(tokens)
156
+ masked_embeds = embeds * tokens_mask.unsqueeze(-1)
157
+ h = masked_embeds.sum(dim=2)
158
+ h = self.backbone(h, input_pos=input_pos, mask=curr_backbone_mask).to(dtype=dtype)
159
+
160
+ last_h = h[:, -1, :]
161
+ c0_logits = self.codebook0_head(last_h)
162
+ c0_sample = sample_topk(c0_logits, topk, temperature)
163
+ c0_embed = self._embed_audio(0, c0_sample)
164
+
165
+ curr_h = torch.cat([last_h.unsqueeze(1), c0_embed], dim=1)
166
+ curr_sample = c0_sample.clone()
167
+ curr_pos = torch.arange(0, curr_h.size(1), device=curr_h.device).unsqueeze(0).repeat(curr_h.size(0), 1)
168
+
169
+ # Decoder caches must be reset every frame.
170
+ self.decoder.reset_caches()
171
+ for i in range(1, self.config.audio_num_codebooks):
172
+ curr_decoder_mask = _index_causal_mask(self.decoder_causal_mask, curr_pos)
173
+ decoder_h = self.decoder(self.projection(curr_h), input_pos=curr_pos, mask=curr_decoder_mask).to(
174
+ dtype=dtype
175
+ )
176
+ ci_logits = torch.mm(decoder_h[:, -1, :], self.audio_head[i - 1])
177
+ ci_sample = sample_topk(ci_logits, topk, temperature)
178
+ ci_embed = self._embed_audio(i, ci_sample)
179
+
180
+ curr_h = ci_embed
181
+ curr_sample = torch.cat([curr_sample, ci_sample], dim=1)
182
+ curr_pos = curr_pos[:, -1:] + 1
183
+
184
+ return curr_sample
185
+
186
+ def reset_caches(self):
187
+ self.backbone.reset_caches()
188
+ self.decoder.reset_caches()
189
+
190
+ def _embed_audio(self, codebook: int, tokens: torch.Tensor) -> torch.Tensor:
191
+ return self.audio_embeddings(tokens + codebook * self.config.audio_vocab_size)
192
+
193
+ def _embed_tokens(self, tokens: torch.Tensor) -> torch.Tensor:
194
+ text_embeds = self.text_embeddings(tokens[:, :, -1]).unsqueeze(-2)
195
+
196
+ audio_tokens = tokens[:, :, :-1] + (
197
+ self.config.audio_vocab_size * torch.arange(self.config.audio_num_codebooks, device=tokens.device)
198
+ )
199
+ audio_embeds = self.audio_embeddings(audio_tokens.view(-1)).reshape(
200
+ tokens.size(0), tokens.size(1), self.config.audio_num_codebooks, -1
201
+ )
202
+
203
+ return torch.cat([audio_embeds, text_embeds], dim=-2)
requirements.txt ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ torch==2.4.0
2
+ torchaudio==2.4.0
3
+ tokenizers==0.21.0
4
+ transformers==4.49.0
5
+ huggingface_hub==0.28.1
6
+ moshi==0.2.2
7
+ torchtune==0.4.0
8
+ torchao==0.9.0
9
+ silentcipher @ git+https://github.com/SesameAILabs/silentcipher@master
10
+ gradio==4.26.0
11
+ huggingface-hub-spaces==0.22.0
test_model.py ADDED
@@ -0,0 +1,73 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+ import torchaudio
4
+ import spaces
5
+ from generator import Segment, load_csm_1b
6
+ from huggingface_hub import login
7
+
8
+ def login_huggingface():
9
+ """Đăng nhập vào Hugging Face Hub sử dụng token từ biến môi trường hoặc nhập từ người dùng"""
10
+ hf_token = os.environ.get("HF_TOKEN")
11
+
12
+ if not hf_token:
13
+ print("Không tìm thấy HF_TOKEN trong biến môi trường.")
14
+ hf_token = input("Vui lòng nhập Hugging Face token của bạn: ")
15
+
16
+ if hf_token:
17
+ print("Đang đăng nhập vào Hugging Face Hub...")
18
+ login(token=hf_token)
19
+ print("Đã đăng nhập thành công!")
20
+ return True
21
+ else:
22
+ print("Không có token. Một số mô hình có thể không truy cập được.")
23
+ return False
24
+
25
+ @spaces.GPU
26
+ def generate_test_audio(text, speaker_id, device):
27
+ """Tạo âm thanh kiểm tra sử dụng ZeroGPU"""
28
+ generator = load_csm_1b(device=device)
29
+ print("Đã tải xong mô hình!")
30
+
31
+ print(f"Đang tạo âm thanh cho văn bản: '{text}'")
32
+ audio = generator.generate(
33
+ text=text,
34
+ speaker=speaker_id,
35
+ context=[],
36
+ max_audio_length_ms=10000,
37
+ temperature=0.9,
38
+ topk=50
39
+ )
40
+
41
+ return audio, generator.sample_rate
42
+
43
+ def test_model():
44
+ print("Kiểm tra mô hình CSM-1B...")
45
+
46
+ # Đăng nhập vào Hugging Face Hub
47
+ login_huggingface()
48
+
49
+ # Kiểm tra xem có GPU không và cấu hình thiết bị phù hợp
50
+ device = "cuda" if torch.cuda.is_available() else "cpu"
51
+ print(f"Sử dụng thiết bị: {device}")
52
+
53
+ # Tải mô hình CSM-1B và tạo âm thanh
54
+ print("Đang tải mô hình CSM-1B...")
55
+ try:
56
+ # Sử dụng ZeroGPU để tạo âm thanh
57
+ text = "Xin chào, đây là bài kiểm tra mô hình CSM-1B."
58
+ speaker_id = 0
59
+
60
+ audio, sample_rate = generate_test_audio(text, speaker_id, device)
61
+
62
+ # Lưu âm thanh thành file
63
+ output_path = "test_output.wav"
64
+ torchaudio.save(output_path, audio.unsqueeze(0), sample_rate)
65
+ print(f"Đã lưu âm thanh vào file: {output_path}")
66
+
67
+ print("Kiểm tra hoàn tất!")
68
+ except Exception as e:
69
+ print(f"Lỗi khi kiểm tra mô hình: {e}")
70
+ print("Vui lòng kiểm tra lại token và quyền truy cập của bạn.")
71
+
72
+ if __name__ == "__main__":
73
+ test_model()
watermarking.py ADDED
@@ -0,0 +1,79 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+
3
+ import silentcipher
4
+ import torch
5
+ import torchaudio
6
+
7
+ # This watermark key is public, it is not secure.
8
+ # If using CSM 1B in another application, use a new private key and keep it secret.
9
+ CSM_1B_GH_WATERMARK = [212, 211, 146, 56, 201]
10
+
11
+
12
+ def cli_check_audio() -> None:
13
+ parser = argparse.ArgumentParser()
14
+ parser.add_argument("--audio_path", type=str, required=True)
15
+ args = parser.parse_args()
16
+
17
+ check_audio_from_file(args.audio_path)
18
+
19
+
20
+ def load_watermarker(device: str = "cuda") -> silentcipher.server.Model:
21
+ model = silentcipher.get_model(
22
+ model_type="44.1k",
23
+ device=device,
24
+ )
25
+ return model
26
+
27
+
28
+ @torch.inference_mode()
29
+ def watermark(
30
+ watermarker: silentcipher.server.Model,
31
+ audio_array: torch.Tensor,
32
+ sample_rate: int,
33
+ watermark_key: list[int],
34
+ ) -> tuple[torch.Tensor, int]:
35
+ audio_array_44khz = torchaudio.functional.resample(audio_array, orig_freq=sample_rate, new_freq=44100)
36
+ encoded, _ = watermarker.encode_wav(audio_array_44khz, 44100, watermark_key, calc_sdr=False, message_sdr=36)
37
+
38
+ output_sample_rate = min(44100, sample_rate)
39
+ encoded = torchaudio.functional.resample(encoded, orig_freq=44100, new_freq=output_sample_rate)
40
+ return encoded, output_sample_rate
41
+
42
+
43
+ @torch.inference_mode()
44
+ def verify(
45
+ watermarker: silentcipher.server.Model,
46
+ watermarked_audio: torch.Tensor,
47
+ sample_rate: int,
48
+ watermark_key: list[int],
49
+ ) -> bool:
50
+ watermarked_audio_44khz = torchaudio.functional.resample(watermarked_audio, orig_freq=sample_rate, new_freq=44100)
51
+ result = watermarker.decode_wav(watermarked_audio_44khz, 44100, phase_shift_decoding=True)
52
+
53
+ is_watermarked = result["status"]
54
+ if is_watermarked:
55
+ is_csm_watermarked = result["messages"][0] == watermark_key
56
+ else:
57
+ is_csm_watermarked = False
58
+
59
+ return is_watermarked and is_csm_watermarked
60
+
61
+
62
+ def check_audio_from_file(audio_path: str) -> None:
63
+ watermarker = load_watermarker(device="cuda")
64
+
65
+ audio_array, sample_rate = load_audio(audio_path)
66
+ is_watermarked = verify(watermarker, audio_array, sample_rate, CSM_1B_GH_WATERMARK)
67
+
68
+ outcome = "Watermarked" if is_watermarked else "Not watermarked"
69
+ print(f"{outcome}: {audio_path}")
70
+
71
+
72
+ def load_audio(audio_path: str) -> tuple[torch.Tensor, int]:
73
+ audio_array, sample_rate = torchaudio.load(audio_path)
74
+ audio_array = audio_array.mean(dim=0)
75
+ return audio_array, int(sample_rate)
76
+
77
+
78
+ if __name__ == "__main__":
79
+ cli_check_audio()