Spaces:
Running
Running
A Le Thanh Son
commited on
Commit
·
6d75162
1
Parent(s):
d1fc75f
fix
Browse files- .gitignore +54 -0
- README.md +94 -4
- app.py +294 -0
- generator.py +186 -0
- hf_requirements.txt +11 -0
- models.py +203 -0
- requirements.txt +11 -0
- test_model.py +73 -0
- watermarking.py +79 -0
.gitignore
ADDED
@@ -0,0 +1,54 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Python
|
2 |
+
__pycache__/
|
3 |
+
*.py[cod]
|
4 |
+
*$py.class
|
5 |
+
*.so
|
6 |
+
.Python
|
7 |
+
build/
|
8 |
+
develop-eggs/
|
9 |
+
dist/
|
10 |
+
downloads/
|
11 |
+
eggs/
|
12 |
+
.eggs/
|
13 |
+
lib/
|
14 |
+
lib64/
|
15 |
+
parts/
|
16 |
+
sdist/
|
17 |
+
var/
|
18 |
+
wheels/
|
19 |
+
*.egg-info/
|
20 |
+
.installed.cfg
|
21 |
+
*.egg
|
22 |
+
|
23 |
+
# Environments
|
24 |
+
.env
|
25 |
+
.venv
|
26 |
+
env/
|
27 |
+
venv/
|
28 |
+
ENV/
|
29 |
+
env.bak/
|
30 |
+
venv.bak/
|
31 |
+
|
32 |
+
# Jupyter Notebook
|
33 |
+
.ipynb_checkpoints
|
34 |
+
|
35 |
+
# PyCharm
|
36 |
+
.idea/
|
37 |
+
|
38 |
+
# VS Code
|
39 |
+
.vscode/
|
40 |
+
|
41 |
+
# Temporary files
|
42 |
+
*.tmp
|
43 |
+
*.wav
|
44 |
+
*.mp3
|
45 |
+
*.ogg
|
46 |
+
temp/
|
47 |
+
|
48 |
+
# Logs
|
49 |
+
logs/
|
50 |
+
*.log
|
51 |
+
|
52 |
+
# HuggingFace
|
53 |
+
.cache/
|
54 |
+
huggingface/
|
README.md
CHANGED
@@ -1,12 +1,102 @@
|
|
1 |
---
|
2 |
-
title:
|
3 |
-
emoji:
|
4 |
colorFrom: indigo
|
5 |
colorTo: green
|
6 |
sdk: gradio
|
7 |
-
sdk_version:
|
8 |
app_file: app.py
|
9 |
pinned: false
|
10 |
---
|
11 |
|
12 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
---
|
2 |
+
title: CSM-1B Gradio Demo
|
3 |
+
emoji: 🎙️
|
4 |
colorFrom: indigo
|
5 |
colorTo: green
|
6 |
sdk: gradio
|
7 |
+
sdk_version: 4.26.0
|
8 |
app_file: app.py
|
9 |
pinned: false
|
10 |
---
|
11 |
|
12 |
+
# CSM-1B Text-to-Speech Demo
|
13 |
+
|
14 |
+
Ứng dụng này sử dụng mô hình CSM-1B (Collaborative Speech Model) để chuyển đổi văn bản thành giọng nói với chất lượng cao.
|
15 |
+
|
16 |
+
## Tính năng
|
17 |
+
|
18 |
+
- **Tạo âm thanh đơn giản**: Chuyển đổi văn bản thành giọng nói với các tùy chọn về ID người nói, thời lượng, temperature và top-k.
|
19 |
+
- **Tạo âm thanh với ngữ cảnh**: Cung cấp các đoạn âm thanh và văn bản làm ngữ cảnh để mô hình tạo ra âm thanh phù hợp hơn.
|
20 |
+
- **Tối ưu GPU**: Sử dụng ZeroGPU của Hugging Face Spaces để tối ưu việc sử dụng GPU.
|
21 |
+
|
22 |
+
## Cài đặt và Cấu hình
|
23 |
+
|
24 |
+
### Yêu cầu truy cập
|
25 |
+
|
26 |
+
Để sử dụng mô hình CSM-1B, bạn cần có quyền truy cập vào các mô hình sau trên Hugging Face:
|
27 |
+
|
28 |
+
- [meta-llama/Llama-3.2-1B](https://huggingface.co/meta-llama/Llama-3.2-1B)
|
29 |
+
- [sesame/csm-1b](https://huggingface.co/sesame/csm-1b)
|
30 |
+
|
31 |
+
### Cấu hình Hugging Face Token
|
32 |
+
|
33 |
+
1. Tạo tài khoản Hugging Face nếu bạn chưa có.
|
34 |
+
2. Truy cập vào [Hugging Face Settings](https://huggingface.co/settings/tokens) để tạo token.
|
35 |
+
3. Yêu cầu quyền truy cập vào các mô hình nếu cần.
|
36 |
+
4. Đặt biến môi trường `HF_TOKEN` với giá trị là token của bạn:
|
37 |
+
```bash
|
38 |
+
export HF_TOKEN=your_token_here
|
39 |
+
```
|
40 |
+
5. Hoặc bạn có thể nhập token trực tiếp trong tab "Cấu hình" của ứng dụng.
|
41 |
+
|
42 |
+
### Cài đặt
|
43 |
+
|
44 |
+
```bash
|
45 |
+
git clone https://github.com/yourusername/csm-1b-gradio.git
|
46 |
+
cd csm-1b-gradio
|
47 |
+
pip install -r requirements.txt
|
48 |
+
```
|
49 |
+
|
50 |
+
## Cách sử dụng
|
51 |
+
|
52 |
+
1. Khởi động ứng dụng:
|
53 |
+
```bash
|
54 |
+
python app.py
|
55 |
+
```
|
56 |
+
2. Mở trình duyệt web và truy cập địa chỉ được hiển thị (thường là http://127.0.0.1:7860).
|
57 |
+
3. Nhập văn bản bạn muốn chuyển thành giọng nói.
|
58 |
+
4. Chọn ID người nói (từ 0-10).
|
59 |
+
5. Điều chỉnh các tham số như thời lượng tối đa, temperature và top-k.
|
60 |
+
6. Nhấn nút "Tạo âm thanh" để tạo giọng nói.
|
61 |
+
|
62 |
+
## Thông tin về mô hình
|
63 |
+
|
64 |
+
CSM-1B là một mô hình text-to-speech tiên tiến được phát triển bởi Sesame AI Labs. Mô hình này có khả năng tạo giọng nói tự nhiên từ văn bản với nhiều giọng nói khác nhau.
|
65 |
+
|
66 |
+
## ZeroGPU
|
67 |
+
|
68 |
+
Ứng dụng này sử dụng ZeroGPU của Hugging Face Spaces để tối ưu việc sử dụng GPU. ZeroGPU giúp giải phóng bộ nhớ GPU khi không sử dụng, giúp tiết kiệm tài nguyên và cải thiện hiệu suất.
|
69 |
+
|
70 |
+
```python
|
71 |
+
import spaces
|
72 |
+
|
73 |
+
@spaces.GPU
|
74 |
+
def my_gpu_function():
|
75 |
+
# Hàm này sẽ chỉ sử dụng GPU khi được gọi
|
76 |
+
# và giải phóng GPU sau khi hoàn thành
|
77 |
+
pass
|
78 |
+
```
|
79 |
+
|
80 |
+
Khi triển khai trên Hugging Face Spaces, ZeroGPU sẽ tự động quản lý việc sử dụng GPU, giúp ứng dụng hoạt động hiệu quả hơn.
|
81 |
+
|
82 |
+
## Lưu ý
|
83 |
+
|
84 |
+
- Mô hình này sử dụng watermarking để đánh dấu âm thanh được tạo ra bởi AI.
|
85 |
+
- Thời gian tạo âm thanh phụ thuộc vào độ dài văn bản và cấu hình phần cứng.
|
86 |
+
- Bạn cần có quyền truy cập vào mô hình CSM-1B trên Hugging Face để sử dụng ứng dụng này.
|
87 |
+
|
88 |
+
## Triển khai trên Hugging Face Spaces
|
89 |
+
|
90 |
+
Để triển khai ứng dụng này trên Hugging Face Spaces:
|
91 |
+
|
92 |
+
1. Tạo một Space mới trên Hugging Face với SDK là Gradio.
|
93 |
+
2. Tải lên tất cả các file của dự án.
|
94 |
+
3. Trong phần cài đặt của Space, thêm biến môi trường `HF_TOKEN` với giá trị là token của bạn.
|
95 |
+
4. Chọn cấu hình phần cứng phù hợp (khuyến nghị sử dụng GPU).
|
96 |
+
|
97 |
+
## Tài nguyên
|
98 |
+
|
99 |
+
- [GitHub Repository](https://github.com/SesameAILabs/csm-1b)
|
100 |
+
- [Hugging Face Model](https://huggingface.co/sesame/csm-1b)
|
101 |
+
- [Hugging Face Space Demo](https://huggingface.co/spaces/sesame/csm-1b)
|
102 |
+
- [Hugging Face Spaces ZeroGPU](https://huggingface.co/docs/hub/spaces-sdks-docker-zero-gpu)
|
app.py
ADDED
@@ -0,0 +1,294 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import tempfile
|
3 |
+
import time
|
4 |
+
from typing import List, Tuple
|
5 |
+
|
6 |
+
import gradio as gr
|
7 |
+
import torch
|
8 |
+
import torchaudio
|
9 |
+
import spaces
|
10 |
+
from dataclasses import dataclass
|
11 |
+
from generator import Segment, load_csm_1b
|
12 |
+
from huggingface_hub import login
|
13 |
+
|
14 |
+
# Kiểm tra xem có GPU không và cấu hình thiết bị phù hợp
|
15 |
+
device = "cuda" if torch.cuda.is_available() else "cpu"
|
16 |
+
print(f"Sử dụng thiết bị: {device}")
|
17 |
+
|
18 |
+
# Đăng nhập vào Hugging Face Hub nếu có token
|
19 |
+
def login_huggingface():
|
20 |
+
hf_token = os.environ.get("HF_TOKEN")
|
21 |
+
if hf_token:
|
22 |
+
print("Đang đăng nhập vào Hugging Face Hub...")
|
23 |
+
login(token=hf_token)
|
24 |
+
print("Đã đăng nhập thành công!")
|
25 |
+
else:
|
26 |
+
print("Không tìm thấy HF_TOKEN trong biến môi trường. Một số mô hình có thể không truy cập được.")
|
27 |
+
|
28 |
+
# Đăng nhập khi khởi động
|
29 |
+
login_huggingface()
|
30 |
+
|
31 |
+
# Tải mô hình CSM-1B
|
32 |
+
generator = None
|
33 |
+
|
34 |
+
def load_model():
|
35 |
+
global generator
|
36 |
+
if generator is None:
|
37 |
+
print("Đang tải mô hình CSM-1B...")
|
38 |
+
generator = load_csm_1b(device=device)
|
39 |
+
print("Đã tải xong mô hình!")
|
40 |
+
return generator
|
41 |
+
|
42 |
+
# Hàm chuyển đổi âm thanh thành tensor
|
43 |
+
def audio_to_tensor(audio_path: str) -> Tuple[torch.Tensor, int]:
|
44 |
+
waveform, sample_rate = torchaudio.load(audio_path)
|
45 |
+
waveform = waveform.mean(dim=0) # Chuyển stereo thành mono nếu cần
|
46 |
+
return waveform, sample_rate
|
47 |
+
|
48 |
+
# Hàm lưu tensor âm thanh thành file
|
49 |
+
def save_audio(audio_tensor: torch.Tensor, sample_rate: int) -> str:
|
50 |
+
temp_dir = tempfile.gettempdir()
|
51 |
+
output_path = os.path.join(temp_dir, f"csm1b_output_{int(time.time())}.wav")
|
52 |
+
torchaudio.save(output_path, audio_tensor.unsqueeze(0), sample_rate)
|
53 |
+
return output_path
|
54 |
+
|
55 |
+
# Hàm tạo âm thanh từ văn bản sử dụng ZeroGPU
|
56 |
+
@spaces.GPU
|
57 |
+
def generate_speech(
|
58 |
+
text: str,
|
59 |
+
speaker_id: int,
|
60 |
+
context_audio_files: List[Tuple[str, str, int]],
|
61 |
+
max_duration_ms: float = 30000,
|
62 |
+
temperature: float = 0.9,
|
63 |
+
top_k: int = 50,
|
64 |
+
progress=gr.Progress()
|
65 |
+
) -> str:
|
66 |
+
# Tải mô hình nếu chưa tải
|
67 |
+
generator = load_model()
|
68 |
+
|
69 |
+
# Chuẩn bị ngữ cảnh (context)
|
70 |
+
context = []
|
71 |
+
progress(0.1, "Đang xử lý ngữ cảnh...")
|
72 |
+
|
73 |
+
for audio_file, text_content, speaker in context_audio_files:
|
74 |
+
if audio_file and text_content:
|
75 |
+
waveform, sample_rate = audio_to_tensor(audio_file)
|
76 |
+
# Resample nếu cần
|
77 |
+
if sample_rate != generator.sample_rate:
|
78 |
+
waveform = torchaudio.functional.resample(waveform, orig_freq=sample_rate, new_freq=generator.sample_rate)
|
79 |
+
context.append(Segment(speaker=speaker, text=text_content, audio=waveform))
|
80 |
+
|
81 |
+
progress(0.3, "Đang tạo âm thanh...")
|
82 |
+
# Tạo âm thanh từ văn bản
|
83 |
+
audio = generator.generate(
|
84 |
+
text=text,
|
85 |
+
speaker=speaker_id,
|
86 |
+
context=context,
|
87 |
+
max_audio_length_ms=max_duration_ms,
|
88 |
+
temperature=temperature,
|
89 |
+
topk=top_k
|
90 |
+
)
|
91 |
+
|
92 |
+
progress(0.8, "Đang lưu âm thanh...")
|
93 |
+
# Lưu âm thanh thành file
|
94 |
+
output_path = save_audio(audio, generator.sample_rate)
|
95 |
+
|
96 |
+
progress(1.0, "Hoàn thành!")
|
97 |
+
return output_path
|
98 |
+
|
99 |
+
# Tạo giao diện Gradio
|
100 |
+
def create_demo():
|
101 |
+
with gr.Blocks(title="CSM-1B Text-to-Speech") as demo:
|
102 |
+
gr.Markdown("# CSM-1B Text-to-Speech Demo")
|
103 |
+
gr.Markdown("Mô hình CSM-1B (Collaborative Speech Model) là một mô hình text-to-speech tiên tiến có khả năng tạo giọng nói tự nhiên từ văn bản.")
|
104 |
+
|
105 |
+
with gr.Tab("Tạo âm thanh đơn giản"):
|
106 |
+
with gr.Row():
|
107 |
+
with gr.Column():
|
108 |
+
text_input = gr.Textbox(
|
109 |
+
label="Văn bản cần chuyển thành giọng nói",
|
110 |
+
placeholder="Nhập văn bản bạn muốn chuyển thành giọng nói...",
|
111 |
+
lines=5
|
112 |
+
)
|
113 |
+
speaker_id = gr.Number(
|
114 |
+
label="ID người nói",
|
115 |
+
value=0,
|
116 |
+
precision=0,
|
117 |
+
minimum=0,
|
118 |
+
maximum=10
|
119 |
+
)
|
120 |
+
|
121 |
+
with gr.Row():
|
122 |
+
max_duration = gr.Slider(
|
123 |
+
label="Thời lượng tối đa (ms)",
|
124 |
+
minimum=1000,
|
125 |
+
maximum=90000,
|
126 |
+
value=30000,
|
127 |
+
step=1000
|
128 |
+
)
|
129 |
+
temperature = gr.Slider(
|
130 |
+
label="Temperature",
|
131 |
+
minimum=0.1,
|
132 |
+
maximum=1.5,
|
133 |
+
value=0.9,
|
134 |
+
step=0.1
|
135 |
+
)
|
136 |
+
top_k = gr.Slider(
|
137 |
+
label="Top-K",
|
138 |
+
minimum=1,
|
139 |
+
maximum=100,
|
140 |
+
value=50,
|
141 |
+
step=1
|
142 |
+
)
|
143 |
+
|
144 |
+
generate_btn = gr.Button("Tạo âm thanh")
|
145 |
+
|
146 |
+
with gr.Column():
|
147 |
+
output_audio = gr.Audio(label="Âm thanh đầu ra", type="filepath")
|
148 |
+
|
149 |
+
with gr.Tab("Tạo âm thanh với ngữ cảnh"):
|
150 |
+
gr.Markdown("Tính năng này cho phép bạn cung cấp các đoạn âm thanh và văn bản làm ngữ cảnh để mô hình tạo ra âm thanh phù hợp hơn.")
|
151 |
+
|
152 |
+
with gr.Row():
|
153 |
+
with gr.Column():
|
154 |
+
context_text1 = gr.Textbox(label="Văn bản ngữ cảnh 1", lines=2)
|
155 |
+
context_audio1 = gr.Audio(label="Âm thanh ngữ cảnh 1", type="filepath")
|
156 |
+
context_speaker1 = gr.Number(label="ID người nói 1", value=0, precision=0)
|
157 |
+
|
158 |
+
context_text2 = gr.Textbox(label="Văn bản ngữ cảnh 2", lines=2)
|
159 |
+
context_audio2 = gr.Audio(label="Âm thanh ngữ cảnh 2", type="filepath")
|
160 |
+
context_speaker2 = gr.Number(label="ID người nói 2", value=1, precision=0)
|
161 |
+
|
162 |
+
text_input_context = gr.Textbox(
|
163 |
+
label="Văn bản cần chuyển thành giọng nói",
|
164 |
+
placeholder="Nhập văn bản bạn muốn chuyển thành giọng nói...",
|
165 |
+
lines=3
|
166 |
+
)
|
167 |
+
speaker_id_context = gr.Number(
|
168 |
+
label="ID người nói",
|
169 |
+
value=0,
|
170 |
+
precision=0
|
171 |
+
)
|
172 |
+
|
173 |
+
with gr.Row():
|
174 |
+
max_duration_context = gr.Slider(
|
175 |
+
label="Thời lượng tối đa (ms)",
|
176 |
+
minimum=1000,
|
177 |
+
maximum=90000,
|
178 |
+
value=30000,
|
179 |
+
step=1000
|
180 |
+
)
|
181 |
+
temperature_context = gr.Slider(
|
182 |
+
label="Temperature",
|
183 |
+
minimum=0.1,
|
184 |
+
maximum=1.5,
|
185 |
+
value=0.9,
|
186 |
+
step=0.1
|
187 |
+
)
|
188 |
+
top_k_context = gr.Slider(
|
189 |
+
label="Top-K",
|
190 |
+
minimum=1,
|
191 |
+
maximum=100,
|
192 |
+
value=50,
|
193 |
+
step=1
|
194 |
+
)
|
195 |
+
|
196 |
+
generate_context_btn = gr.Button("Tạo âm thanh với ngữ cảnh")
|
197 |
+
|
198 |
+
with gr.Column():
|
199 |
+
output_audio_context = gr.Audio(label="Âm thanh đầu ra", type="filepath")
|
200 |
+
|
201 |
+
# Thêm tab cấu hình Hugging Face
|
202 |
+
with gr.Tab("Cấu hình"):
|
203 |
+
gr.Markdown("### Cấu hình Hugging Face Token")
|
204 |
+
gr.Markdown("""
|
205 |
+
Để sử dụng mô hình CSM-1B, bạn cần có quyền truy cập vào mô hình trên Hugging Face.
|
206 |
+
|
207 |
+
Bạn có thể cấu hình token của mình bằng cách:
|
208 |
+
1. Tạo token tại [Hugging Face Settings](https://huggingface.co/settings/tokens)
|
209 |
+
2. Đặt biến môi trường `HF_TOKEN` với giá trị là token của bạn
|
210 |
+
|
211 |
+
Lưu ý: Trong Hugging Face Spaces, bạn có thể đặt biến môi trường trong phần Cài đặt của Space.
|
212 |
+
""")
|
213 |
+
|
214 |
+
hf_token_input = gr.Textbox(
|
215 |
+
label="Hugging Face Token (Chỉ sử dụng trong phiên này)",
|
216 |
+
placeholder="Nhập token của bạn...",
|
217 |
+
type="password"
|
218 |
+
)
|
219 |
+
|
220 |
+
def set_token(token):
|
221 |
+
if token:
|
222 |
+
os.environ["HF_TOKEN"] = token
|
223 |
+
login(token=token)
|
224 |
+
return "Đã đặt token thành công! Bạn có thể tải mô hình bây giờ."
|
225 |
+
return "Token không hợp lệ. Vui lòng nhập token hợp lệ."
|
226 |
+
|
227 |
+
set_token_btn = gr.Button("Đặt Token")
|
228 |
+
token_status = gr.Textbox(label="Trạng thái", interactive=False)
|
229 |
+
|
230 |
+
set_token_btn.click(fn=set_token, inputs=hf_token_input, outputs=token_status)
|
231 |
+
|
232 |
+
# Thêm tab thông tin về ZeroGPU
|
233 |
+
with gr.Tab("Thông tin GPU"):
|
234 |
+
gr.Markdown("### Thông tin về ZeroGPU")
|
235 |
+
gr.Markdown("""
|
236 |
+
Ứng dụng này sử dụng ZeroGPU của Hugging Face Spaces để tối ưu việc sử dụng GPU.
|
237 |
+
|
238 |
+
ZeroGPU giúp giải phóng bộ nhớ GPU khi không sử dụng, giúp tiết kiệm tài nguyên và cải thiện hiệu suất.
|
239 |
+
|
240 |
+
Khi bạn tạo âm thanh, GPU sẽ được sử dụng tự động và giải phóng sau khi hoàn thành.
|
241 |
+
""")
|
242 |
+
|
243 |
+
def check_gpu():
|
244 |
+
if torch.cuda.is_available():
|
245 |
+
gpu_name = torch.cuda.get_device_name(0)
|
246 |
+
gpu_memory = torch.cuda.get_device_properties(0).total_memory / (1024**3)
|
247 |
+
return f"GPU: {gpu_name}\nBộ nhớ: {gpu_memory:.2f} GB"
|
248 |
+
else:
|
249 |
+
return "Không tìm thấy GPU. Ứng dụng sẽ chạy trên CPU."
|
250 |
+
|
251 |
+
check_gpu_btn = gr.Button("Kiểm tra GPU")
|
252 |
+
gpu_info = gr.Textbox(label="Thông tin GPU", interactive=False)
|
253 |
+
|
254 |
+
check_gpu_btn.click(fn=check_gpu, inputs=None, outputs=gpu_info)
|
255 |
+
|
256 |
+
# Kết nối các thành phần
|
257 |
+
generate_btn.click(
|
258 |
+
fn=generate_speech,
|
259 |
+
inputs=[
|
260 |
+
text_input,
|
261 |
+
speaker_id,
|
262 |
+
gr.State([]), # Không có ngữ cảnh
|
263 |
+
max_duration,
|
264 |
+
temperature,
|
265 |
+
top_k
|
266 |
+
],
|
267 |
+
outputs=output_audio
|
268 |
+
)
|
269 |
+
|
270 |
+
generate_context_btn.click(
|
271 |
+
fn=generate_speech,
|
272 |
+
inputs=[
|
273 |
+
text_input_context,
|
274 |
+
speaker_id_context,
|
275 |
+
gr.State([
|
276 |
+
(context_audio1, context_text1, context_speaker1),
|
277 |
+
(context_audio2, context_text2, context_speaker2)
|
278 |
+
]),
|
279 |
+
max_duration_context,
|
280 |
+
temperature_context,
|
281 |
+
top_k_context
|
282 |
+
],
|
283 |
+
outputs=output_audio_context
|
284 |
+
)
|
285 |
+
|
286 |
+
# Tải mô hình khi khởi động
|
287 |
+
demo.load(fn=load_model)
|
288 |
+
|
289 |
+
return demo
|
290 |
+
|
291 |
+
# Khởi chạy ứng dụng
|
292 |
+
if __name__ == "__main__":
|
293 |
+
demo = create_demo()
|
294 |
+
demo.queue().launch()
|
generator.py
ADDED
@@ -0,0 +1,186 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from dataclasses import dataclass
|
2 |
+
from typing import List, Tuple
|
3 |
+
|
4 |
+
import torch
|
5 |
+
import torchaudio
|
6 |
+
from huggingface_hub import hf_hub_download, login
|
7 |
+
from models import Model
|
8 |
+
from moshi.models import loaders
|
9 |
+
from tokenizers.processors import TemplateProcessing
|
10 |
+
from transformers import AutoTokenizer
|
11 |
+
from watermarking import CSM_1B_GH_WATERMARK, load_watermarker, watermark
|
12 |
+
|
13 |
+
|
14 |
+
@dataclass
|
15 |
+
class Segment:
|
16 |
+
speaker: int
|
17 |
+
text: str
|
18 |
+
# (num_samples,), sample_rate = 24_000
|
19 |
+
audio: torch.Tensor
|
20 |
+
|
21 |
+
|
22 |
+
def load_llama3_tokenizer():
|
23 |
+
"""
|
24 |
+
https://github.com/huggingface/transformers/issues/22794#issuecomment-2092623992
|
25 |
+
"""
|
26 |
+
tokenizer_name = "meta-llama/Llama-3.2-1B"
|
27 |
+
tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
|
28 |
+
bos = tokenizer.bos_token
|
29 |
+
eos = tokenizer.eos_token
|
30 |
+
tokenizer._tokenizer.post_processor = TemplateProcessing(
|
31 |
+
single=f"{bos}:0 $A:0 {eos}:0",
|
32 |
+
pair=f"{bos}:0 $A:0 {eos}:0 {bos}:1 $B:1 {eos}:1",
|
33 |
+
special_tokens=[(f"{bos}", tokenizer.bos_token_id), (f"{eos}", tokenizer.eos_token_id)],
|
34 |
+
)
|
35 |
+
|
36 |
+
return tokenizer
|
37 |
+
|
38 |
+
|
39 |
+
class Generator:
|
40 |
+
def __init__(
|
41 |
+
self,
|
42 |
+
model: Model,
|
43 |
+
):
|
44 |
+
self._model = model
|
45 |
+
self._model.setup_caches(1)
|
46 |
+
|
47 |
+
self._text_tokenizer = load_llama3_tokenizer()
|
48 |
+
|
49 |
+
device = next(model.parameters()).device
|
50 |
+
mimi_weight = hf_hub_download(loaders.DEFAULT_REPO, loaders.MIMI_NAME)
|
51 |
+
mimi = loaders.get_mimi(mimi_weight, device=device)
|
52 |
+
mimi.set_num_codebooks(32)
|
53 |
+
self._audio_tokenizer = mimi
|
54 |
+
|
55 |
+
self._watermarker = load_watermarker(device=device)
|
56 |
+
|
57 |
+
self.sample_rate = mimi.sample_rate
|
58 |
+
self.device = device
|
59 |
+
|
60 |
+
def _tokenize_text_segment(self, text: str, speaker: int) -> Tuple[torch.Tensor, torch.Tensor]:
|
61 |
+
frame_tokens = []
|
62 |
+
frame_masks = []
|
63 |
+
|
64 |
+
text_tokens = self._text_tokenizer.encode(f"[{speaker}]{text}")
|
65 |
+
text_frame = torch.zeros(len(text_tokens), 33).long()
|
66 |
+
text_frame_mask = torch.zeros(len(text_tokens), 33).bool()
|
67 |
+
text_frame[:, -1] = torch.tensor(text_tokens)
|
68 |
+
text_frame_mask[:, -1] = True
|
69 |
+
|
70 |
+
frame_tokens.append(text_frame.to(self.device))
|
71 |
+
frame_masks.append(text_frame_mask.to(self.device))
|
72 |
+
|
73 |
+
return torch.cat(frame_tokens, dim=0), torch.cat(frame_masks, dim=0)
|
74 |
+
|
75 |
+
def _tokenize_audio(self, audio: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
|
76 |
+
frame_tokens = []
|
77 |
+
frame_masks = []
|
78 |
+
|
79 |
+
# (K, T)
|
80 |
+
audio = audio.to(self.device)
|
81 |
+
audio_tokens = self._audio_tokenizer.encode(audio.unsqueeze(0).unsqueeze(0))[0]
|
82 |
+
# add EOS frame
|
83 |
+
eos_frame = torch.zeros(audio_tokens.size(0), 1).to(self.device)
|
84 |
+
audio_tokens = torch.cat([audio_tokens, eos_frame], dim=1)
|
85 |
+
|
86 |
+
audio_frame = torch.zeros(audio_tokens.size(1), 33).long().to(self.device)
|
87 |
+
audio_frame_mask = torch.zeros(audio_tokens.size(1), 33).bool().to(self.device)
|
88 |
+
audio_frame[:, :-1] = audio_tokens.transpose(0, 1)
|
89 |
+
audio_frame_mask[:, :-1] = True
|
90 |
+
|
91 |
+
frame_tokens.append(audio_frame)
|
92 |
+
frame_masks.append(audio_frame_mask)
|
93 |
+
|
94 |
+
return torch.cat(frame_tokens, dim=0), torch.cat(frame_masks, dim=0)
|
95 |
+
|
96 |
+
def _tokenize_segment(self, segment: Segment) -> Tuple[torch.Tensor, torch.Tensor]:
|
97 |
+
"""
|
98 |
+
Returns:
|
99 |
+
(seq_len, 33), (seq_len, 33)
|
100 |
+
"""
|
101 |
+
text_tokens, text_masks = self._tokenize_text_segment(segment.text, segment.speaker)
|
102 |
+
audio_tokens, audio_masks = self._tokenize_audio(segment.audio)
|
103 |
+
|
104 |
+
return torch.cat([text_tokens, audio_tokens], dim=0), torch.cat([text_masks, audio_masks], dim=0)
|
105 |
+
|
106 |
+
@torch.inference_mode()
|
107 |
+
def generate(
|
108 |
+
self,
|
109 |
+
text: str,
|
110 |
+
speaker: int,
|
111 |
+
context: List[Segment],
|
112 |
+
max_audio_length_ms: float = 90_000,
|
113 |
+
temperature: float = 0.9,
|
114 |
+
topk: int = 50,
|
115 |
+
) -> torch.Tensor:
|
116 |
+
self._model.reset_caches()
|
117 |
+
|
118 |
+
max_audio_frames = int(max_audio_length_ms / 80)
|
119 |
+
tokens, tokens_mask = [], []
|
120 |
+
for segment in context:
|
121 |
+
segment_tokens, segment_tokens_mask = self._tokenize_segment(segment)
|
122 |
+
tokens.append(segment_tokens)
|
123 |
+
tokens_mask.append(segment_tokens_mask)
|
124 |
+
|
125 |
+
gen_segment_tokens, gen_segment_tokens_mask = self._tokenize_text_segment(text, speaker)
|
126 |
+
tokens.append(gen_segment_tokens)
|
127 |
+
tokens_mask.append(gen_segment_tokens_mask)
|
128 |
+
|
129 |
+
prompt_tokens = torch.cat(tokens, dim=0).long().to(self.device)
|
130 |
+
prompt_tokens_mask = torch.cat(tokens_mask, dim=0).bool().to(self.device)
|
131 |
+
|
132 |
+
samples = []
|
133 |
+
curr_tokens = prompt_tokens.unsqueeze(0)
|
134 |
+
curr_tokens_mask = prompt_tokens_mask.unsqueeze(0)
|
135 |
+
curr_pos = torch.arange(0, prompt_tokens.size(0)).unsqueeze(0).long().to(self.device)
|
136 |
+
|
137 |
+
max_seq_len = 2048 - max_audio_frames
|
138 |
+
if curr_tokens.size(1) >= max_seq_len:
|
139 |
+
raise ValueError(f"Inputs too long, must be below max_seq_len - max_audio_frames: {max_seq_len}")
|
140 |
+
|
141 |
+
for _ in range(max_audio_frames):
|
142 |
+
sample = self._model.generate_frame(curr_tokens, curr_tokens_mask, curr_pos, temperature, topk)
|
143 |
+
if torch.all(sample == 0):
|
144 |
+
break # eos
|
145 |
+
|
146 |
+
samples.append(sample)
|
147 |
+
|
148 |
+
curr_tokens = torch.cat([sample, torch.zeros(1, 1).long().to(self.device)], dim=1).unsqueeze(1)
|
149 |
+
curr_tokens_mask = torch.cat(
|
150 |
+
[torch.ones_like(sample).bool(), torch.zeros(1, 1).bool().to(self.device)], dim=1
|
151 |
+
).unsqueeze(1)
|
152 |
+
curr_pos = curr_pos[:, -1:] + 1
|
153 |
+
|
154 |
+
audio = self._audio_tokenizer.decode(torch.stack(samples).permute(1, 2, 0)).squeeze(0).squeeze(0)
|
155 |
+
|
156 |
+
# This applies an imperceptible watermark to identify audio as AI-generated.
|
157 |
+
# Watermarking ensures transparency, dissuades misuse, and enables traceability.
|
158 |
+
# Please be a responsible AI citizen and keep the watermarking in place.
|
159 |
+
# If using CSM 1B in another application, use your own private key and keep it secret.
|
160 |
+
audio, wm_sample_rate = watermark(self._watermarker, audio, self.sample_rate, CSM_1B_GH_WATERMARK)
|
161 |
+
audio = torchaudio.functional.resample(audio, orig_freq=wm_sample_rate, new_freq=self.sample_rate)
|
162 |
+
|
163 |
+
return audio
|
164 |
+
|
165 |
+
|
166 |
+
def load_csm_1b(device: str = "cuda") -> Generator:
|
167 |
+
"""
|
168 |
+
Tải mô hình CSM-1B từ Hugging Face Hub.
|
169 |
+
|
170 |
+
Args:
|
171 |
+
device: Thiết bị để chạy mô hình (cuda hoặc cpu)
|
172 |
+
|
173 |
+
Returns:
|
174 |
+
Generator: Đối tượng Generator để tạo âm thanh từ văn bản
|
175 |
+
"""
|
176 |
+
try:
|
177 |
+
model = Model.from_pretrained("sesame/csm-1b")
|
178 |
+
model.to(device=device, dtype=torch.bfloat16)
|
179 |
+
|
180 |
+
generator = Generator(model)
|
181 |
+
return generator
|
182 |
+
except Exception as e:
|
183 |
+
print(f"Lỗi khi tải mô hình: {e}")
|
184 |
+
print("Vui lòng kiểm tra xem bạn đã đăng nhập vào Hugging Face Hub chưa.")
|
185 |
+
print("Bạn có thể cần phải yêu cầu quyền truy cập vào mô hình tại: https://huggingface.co/sesame/csm-1b")
|
186 |
+
raise e
|
hf_requirements.txt
ADDED
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
torch>=2.0.0
|
2 |
+
torchaudio>=2.0.0
|
3 |
+
tokenizers>=0.13.0
|
4 |
+
transformers>=4.30.0
|
5 |
+
huggingface_hub>=0.16.0
|
6 |
+
moshi>=0.2.2
|
7 |
+
torchtune>=0.4.0
|
8 |
+
torchao>=0.9.0
|
9 |
+
silentcipher @ git+https://github.com/SesameAILabs/silentcipher@master
|
10 |
+
gradio>=4.13.0
|
11 |
+
huggingface-hub-spaces>=0.19.0
|
models.py
ADDED
@@ -0,0 +1,203 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from dataclasses import dataclass
|
2 |
+
|
3 |
+
import torch
|
4 |
+
import torch.nn as nn
|
5 |
+
import torchtune
|
6 |
+
from huggingface_hub import PyTorchModelHubMixin
|
7 |
+
from torchtune.models import llama3_2
|
8 |
+
|
9 |
+
|
10 |
+
def llama3_2_1B() -> torchtune.modules.transformer.TransformerDecoder:
|
11 |
+
return llama3_2.llama3_2(
|
12 |
+
vocab_size=128_256,
|
13 |
+
num_layers=16,
|
14 |
+
num_heads=32,
|
15 |
+
num_kv_heads=8,
|
16 |
+
embed_dim=2048,
|
17 |
+
max_seq_len=2048,
|
18 |
+
intermediate_dim=8192,
|
19 |
+
attn_dropout=0.0,
|
20 |
+
norm_eps=1e-5,
|
21 |
+
rope_base=500_000,
|
22 |
+
scale_factor=32,
|
23 |
+
)
|
24 |
+
|
25 |
+
|
26 |
+
def llama3_2_100M() -> torchtune.modules.transformer.TransformerDecoder:
|
27 |
+
return llama3_2.llama3_2(
|
28 |
+
vocab_size=128_256,
|
29 |
+
num_layers=4,
|
30 |
+
num_heads=8,
|
31 |
+
num_kv_heads=2,
|
32 |
+
embed_dim=1024,
|
33 |
+
max_seq_len=2048,
|
34 |
+
intermediate_dim=8192,
|
35 |
+
attn_dropout=0.0,
|
36 |
+
norm_eps=1e-5,
|
37 |
+
rope_base=500_000,
|
38 |
+
scale_factor=32,
|
39 |
+
)
|
40 |
+
|
41 |
+
|
42 |
+
FLAVORS = {
|
43 |
+
"llama-1B": llama3_2_1B,
|
44 |
+
"llama-100M": llama3_2_100M,
|
45 |
+
}
|
46 |
+
|
47 |
+
|
48 |
+
def _prepare_transformer(model):
|
49 |
+
embed_dim = model.tok_embeddings.embedding_dim
|
50 |
+
model.tok_embeddings = nn.Identity()
|
51 |
+
model.output = nn.Identity()
|
52 |
+
return model, embed_dim
|
53 |
+
|
54 |
+
|
55 |
+
def _create_causal_mask(seq_len: int, device: torch.device):
|
56 |
+
return torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool, device=device))
|
57 |
+
|
58 |
+
|
59 |
+
def _index_causal_mask(mask: torch.Tensor, input_pos: torch.Tensor):
|
60 |
+
"""
|
61 |
+
Args:
|
62 |
+
mask: (max_seq_len, max_seq_len)
|
63 |
+
input_pos: (batch_size, seq_len)
|
64 |
+
|
65 |
+
Returns:
|
66 |
+
(batch_size, seq_len, max_seq_len)
|
67 |
+
"""
|
68 |
+
r = mask[input_pos, :]
|
69 |
+
return r
|
70 |
+
|
71 |
+
|
72 |
+
def _multinomial_sample_one_no_sync(probs): # Does multinomial sampling without a cuda synchronization
|
73 |
+
q = torch.empty_like(probs).exponential_(1)
|
74 |
+
return torch.argmax(probs / q, dim=-1, keepdim=True).to(dtype=torch.int)
|
75 |
+
|
76 |
+
|
77 |
+
def sample_topk(logits: torch.Tensor, topk: int, temperature: float):
|
78 |
+
logits = logits / temperature
|
79 |
+
|
80 |
+
filter_value: float = -float("Inf")
|
81 |
+
indices_to_remove = logits < torch.topk(logits, topk)[0][..., -1, None]
|
82 |
+
scores_processed = logits.masked_fill(indices_to_remove, filter_value)
|
83 |
+
scores_processed = torch.nn.functional.log_softmax(scores_processed, dim=-1)
|
84 |
+
probs = torch.nn.functional.softmax(scores_processed, dim=-1)
|
85 |
+
|
86 |
+
sample_token = _multinomial_sample_one_no_sync(probs)
|
87 |
+
return sample_token
|
88 |
+
|
89 |
+
|
90 |
+
@dataclass
|
91 |
+
class ModelArgs:
|
92 |
+
backbone_flavor: str
|
93 |
+
decoder_flavor: str
|
94 |
+
text_vocab_size: int
|
95 |
+
audio_vocab_size: int
|
96 |
+
audio_num_codebooks: int
|
97 |
+
|
98 |
+
|
99 |
+
class Model(
|
100 |
+
nn.Module,
|
101 |
+
PyTorchModelHubMixin,
|
102 |
+
repo_url="https://github.com/SesameAILabs/csm",
|
103 |
+
pipeline_tag="text-to-speech",
|
104 |
+
license="apache-2.0",
|
105 |
+
):
|
106 |
+
def __init__(self, config: ModelArgs):
|
107 |
+
super().__init__()
|
108 |
+
self.config = config
|
109 |
+
|
110 |
+
self.backbone, backbone_dim = _prepare_transformer(FLAVORS[config.backbone_flavor]())
|
111 |
+
self.decoder, decoder_dim = _prepare_transformer(FLAVORS[config.decoder_flavor]())
|
112 |
+
|
113 |
+
self.text_embeddings = nn.Embedding(config.text_vocab_size, backbone_dim)
|
114 |
+
self.audio_embeddings = nn.Embedding(config.audio_vocab_size * config.audio_num_codebooks, backbone_dim)
|
115 |
+
|
116 |
+
self.projection = nn.Linear(backbone_dim, decoder_dim, bias=False)
|
117 |
+
self.codebook0_head = nn.Linear(backbone_dim, config.audio_vocab_size, bias=False)
|
118 |
+
self.audio_head = nn.Parameter(torch.empty(config.audio_num_codebooks - 1, decoder_dim, config.audio_vocab_size))
|
119 |
+
|
120 |
+
def setup_caches(self, max_batch_size: int) -> torch.Tensor:
|
121 |
+
"""Setup KV caches and return a causal mask."""
|
122 |
+
dtype = next(self.parameters()).dtype
|
123 |
+
device = next(self.parameters()).device
|
124 |
+
|
125 |
+
with device:
|
126 |
+
self.backbone.setup_caches(max_batch_size, dtype)
|
127 |
+
self.decoder.setup_caches(max_batch_size, dtype, decoder_max_seq_len=self.config.audio_num_codebooks)
|
128 |
+
|
129 |
+
self.register_buffer("backbone_causal_mask", _create_causal_mask(self.backbone.max_seq_len, device))
|
130 |
+
self.register_buffer("decoder_causal_mask", _create_causal_mask(self.config.audio_num_codebooks, device))
|
131 |
+
|
132 |
+
def generate_frame(
|
133 |
+
self,
|
134 |
+
tokens: torch.Tensor,
|
135 |
+
tokens_mask: torch.Tensor,
|
136 |
+
input_pos: torch.Tensor,
|
137 |
+
temperature: float,
|
138 |
+
topk: int,
|
139 |
+
) -> torch.Tensor:
|
140 |
+
"""
|
141 |
+
Args:
|
142 |
+
tokens: (batch_size, seq_len, audio_num_codebooks+1)
|
143 |
+
tokens_mask: (batch_size, seq_len, audio_num_codebooks+1)
|
144 |
+
input_pos: (batch_size, seq_len) positions for each token
|
145 |
+
mask: (batch_size, seq_len, max_seq_len
|
146 |
+
|
147 |
+
Returns:
|
148 |
+
(batch_size, audio_num_codebooks) sampled tokens
|
149 |
+
"""
|
150 |
+
dtype = next(self.parameters()).dtype
|
151 |
+
b, s, _ = tokens.size()
|
152 |
+
|
153 |
+
assert self.backbone.caches_are_enabled(), "backbone caches are not enabled"
|
154 |
+
curr_backbone_mask = _index_causal_mask(self.backbone_causal_mask, input_pos)
|
155 |
+
embeds = self._embed_tokens(tokens)
|
156 |
+
masked_embeds = embeds * tokens_mask.unsqueeze(-1)
|
157 |
+
h = masked_embeds.sum(dim=2)
|
158 |
+
h = self.backbone(h, input_pos=input_pos, mask=curr_backbone_mask).to(dtype=dtype)
|
159 |
+
|
160 |
+
last_h = h[:, -1, :]
|
161 |
+
c0_logits = self.codebook0_head(last_h)
|
162 |
+
c0_sample = sample_topk(c0_logits, topk, temperature)
|
163 |
+
c0_embed = self._embed_audio(0, c0_sample)
|
164 |
+
|
165 |
+
curr_h = torch.cat([last_h.unsqueeze(1), c0_embed], dim=1)
|
166 |
+
curr_sample = c0_sample.clone()
|
167 |
+
curr_pos = torch.arange(0, curr_h.size(1), device=curr_h.device).unsqueeze(0).repeat(curr_h.size(0), 1)
|
168 |
+
|
169 |
+
# Decoder caches must be reset every frame.
|
170 |
+
self.decoder.reset_caches()
|
171 |
+
for i in range(1, self.config.audio_num_codebooks):
|
172 |
+
curr_decoder_mask = _index_causal_mask(self.decoder_causal_mask, curr_pos)
|
173 |
+
decoder_h = self.decoder(self.projection(curr_h), input_pos=curr_pos, mask=curr_decoder_mask).to(
|
174 |
+
dtype=dtype
|
175 |
+
)
|
176 |
+
ci_logits = torch.mm(decoder_h[:, -1, :], self.audio_head[i - 1])
|
177 |
+
ci_sample = sample_topk(ci_logits, topk, temperature)
|
178 |
+
ci_embed = self._embed_audio(i, ci_sample)
|
179 |
+
|
180 |
+
curr_h = ci_embed
|
181 |
+
curr_sample = torch.cat([curr_sample, ci_sample], dim=1)
|
182 |
+
curr_pos = curr_pos[:, -1:] + 1
|
183 |
+
|
184 |
+
return curr_sample
|
185 |
+
|
186 |
+
def reset_caches(self):
|
187 |
+
self.backbone.reset_caches()
|
188 |
+
self.decoder.reset_caches()
|
189 |
+
|
190 |
+
def _embed_audio(self, codebook: int, tokens: torch.Tensor) -> torch.Tensor:
|
191 |
+
return self.audio_embeddings(tokens + codebook * self.config.audio_vocab_size)
|
192 |
+
|
193 |
+
def _embed_tokens(self, tokens: torch.Tensor) -> torch.Tensor:
|
194 |
+
text_embeds = self.text_embeddings(tokens[:, :, -1]).unsqueeze(-2)
|
195 |
+
|
196 |
+
audio_tokens = tokens[:, :, :-1] + (
|
197 |
+
self.config.audio_vocab_size * torch.arange(self.config.audio_num_codebooks, device=tokens.device)
|
198 |
+
)
|
199 |
+
audio_embeds = self.audio_embeddings(audio_tokens.view(-1)).reshape(
|
200 |
+
tokens.size(0), tokens.size(1), self.config.audio_num_codebooks, -1
|
201 |
+
)
|
202 |
+
|
203 |
+
return torch.cat([audio_embeds, text_embeds], dim=-2)
|
requirements.txt
ADDED
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
torch==2.4.0
|
2 |
+
torchaudio==2.4.0
|
3 |
+
tokenizers==0.21.0
|
4 |
+
transformers==4.49.0
|
5 |
+
huggingface_hub==0.28.1
|
6 |
+
moshi==0.2.2
|
7 |
+
torchtune==0.4.0
|
8 |
+
torchao==0.9.0
|
9 |
+
silentcipher @ git+https://github.com/SesameAILabs/silentcipher@master
|
10 |
+
gradio==4.26.0
|
11 |
+
huggingface-hub-spaces==0.22.0
|
test_model.py
ADDED
@@ -0,0 +1,73 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import torch
|
3 |
+
import torchaudio
|
4 |
+
import spaces
|
5 |
+
from generator import Segment, load_csm_1b
|
6 |
+
from huggingface_hub import login
|
7 |
+
|
8 |
+
def login_huggingface():
|
9 |
+
"""Đăng nhập vào Hugging Face Hub sử dụng token từ biến môi trường hoặc nhập từ người dùng"""
|
10 |
+
hf_token = os.environ.get("HF_TOKEN")
|
11 |
+
|
12 |
+
if not hf_token:
|
13 |
+
print("Không tìm thấy HF_TOKEN trong biến môi trường.")
|
14 |
+
hf_token = input("Vui lòng nhập Hugging Face token của bạn: ")
|
15 |
+
|
16 |
+
if hf_token:
|
17 |
+
print("Đang đăng nhập vào Hugging Face Hub...")
|
18 |
+
login(token=hf_token)
|
19 |
+
print("Đã đăng nhập thành công!")
|
20 |
+
return True
|
21 |
+
else:
|
22 |
+
print("Không có token. Một số mô hình có thể không truy cập được.")
|
23 |
+
return False
|
24 |
+
|
25 |
+
@spaces.GPU
|
26 |
+
def generate_test_audio(text, speaker_id, device):
|
27 |
+
"""Tạo âm thanh kiểm tra sử dụng ZeroGPU"""
|
28 |
+
generator = load_csm_1b(device=device)
|
29 |
+
print("Đã tải xong mô hình!")
|
30 |
+
|
31 |
+
print(f"Đang tạo âm thanh cho văn bản: '{text}'")
|
32 |
+
audio = generator.generate(
|
33 |
+
text=text,
|
34 |
+
speaker=speaker_id,
|
35 |
+
context=[],
|
36 |
+
max_audio_length_ms=10000,
|
37 |
+
temperature=0.9,
|
38 |
+
topk=50
|
39 |
+
)
|
40 |
+
|
41 |
+
return audio, generator.sample_rate
|
42 |
+
|
43 |
+
def test_model():
|
44 |
+
print("Kiểm tra mô hình CSM-1B...")
|
45 |
+
|
46 |
+
# Đăng nhập vào Hugging Face Hub
|
47 |
+
login_huggingface()
|
48 |
+
|
49 |
+
# Kiểm tra xem có GPU không và cấu hình thiết bị phù hợp
|
50 |
+
device = "cuda" if torch.cuda.is_available() else "cpu"
|
51 |
+
print(f"Sử dụng thiết bị: {device}")
|
52 |
+
|
53 |
+
# Tải mô hình CSM-1B và tạo âm thanh
|
54 |
+
print("Đang tải mô hình CSM-1B...")
|
55 |
+
try:
|
56 |
+
# Sử dụng ZeroGPU để tạo âm thanh
|
57 |
+
text = "Xin chào, đây là bài kiểm tra mô hình CSM-1B."
|
58 |
+
speaker_id = 0
|
59 |
+
|
60 |
+
audio, sample_rate = generate_test_audio(text, speaker_id, device)
|
61 |
+
|
62 |
+
# Lưu âm thanh thành file
|
63 |
+
output_path = "test_output.wav"
|
64 |
+
torchaudio.save(output_path, audio.unsqueeze(0), sample_rate)
|
65 |
+
print(f"Đã lưu âm thanh vào file: {output_path}")
|
66 |
+
|
67 |
+
print("Kiểm tra hoàn tất!")
|
68 |
+
except Exception as e:
|
69 |
+
print(f"Lỗi khi kiểm tra mô hình: {e}")
|
70 |
+
print("Vui lòng kiểm tra lại token và quyền truy cập của bạn.")
|
71 |
+
|
72 |
+
if __name__ == "__main__":
|
73 |
+
test_model()
|
watermarking.py
ADDED
@@ -0,0 +1,79 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
|
3 |
+
import silentcipher
|
4 |
+
import torch
|
5 |
+
import torchaudio
|
6 |
+
|
7 |
+
# This watermark key is public, it is not secure.
|
8 |
+
# If using CSM 1B in another application, use a new private key and keep it secret.
|
9 |
+
CSM_1B_GH_WATERMARK = [212, 211, 146, 56, 201]
|
10 |
+
|
11 |
+
|
12 |
+
def cli_check_audio() -> None:
|
13 |
+
parser = argparse.ArgumentParser()
|
14 |
+
parser.add_argument("--audio_path", type=str, required=True)
|
15 |
+
args = parser.parse_args()
|
16 |
+
|
17 |
+
check_audio_from_file(args.audio_path)
|
18 |
+
|
19 |
+
|
20 |
+
def load_watermarker(device: str = "cuda") -> silentcipher.server.Model:
|
21 |
+
model = silentcipher.get_model(
|
22 |
+
model_type="44.1k",
|
23 |
+
device=device,
|
24 |
+
)
|
25 |
+
return model
|
26 |
+
|
27 |
+
|
28 |
+
@torch.inference_mode()
|
29 |
+
def watermark(
|
30 |
+
watermarker: silentcipher.server.Model,
|
31 |
+
audio_array: torch.Tensor,
|
32 |
+
sample_rate: int,
|
33 |
+
watermark_key: list[int],
|
34 |
+
) -> tuple[torch.Tensor, int]:
|
35 |
+
audio_array_44khz = torchaudio.functional.resample(audio_array, orig_freq=sample_rate, new_freq=44100)
|
36 |
+
encoded, _ = watermarker.encode_wav(audio_array_44khz, 44100, watermark_key, calc_sdr=False, message_sdr=36)
|
37 |
+
|
38 |
+
output_sample_rate = min(44100, sample_rate)
|
39 |
+
encoded = torchaudio.functional.resample(encoded, orig_freq=44100, new_freq=output_sample_rate)
|
40 |
+
return encoded, output_sample_rate
|
41 |
+
|
42 |
+
|
43 |
+
@torch.inference_mode()
|
44 |
+
def verify(
|
45 |
+
watermarker: silentcipher.server.Model,
|
46 |
+
watermarked_audio: torch.Tensor,
|
47 |
+
sample_rate: int,
|
48 |
+
watermark_key: list[int],
|
49 |
+
) -> bool:
|
50 |
+
watermarked_audio_44khz = torchaudio.functional.resample(watermarked_audio, orig_freq=sample_rate, new_freq=44100)
|
51 |
+
result = watermarker.decode_wav(watermarked_audio_44khz, 44100, phase_shift_decoding=True)
|
52 |
+
|
53 |
+
is_watermarked = result["status"]
|
54 |
+
if is_watermarked:
|
55 |
+
is_csm_watermarked = result["messages"][0] == watermark_key
|
56 |
+
else:
|
57 |
+
is_csm_watermarked = False
|
58 |
+
|
59 |
+
return is_watermarked and is_csm_watermarked
|
60 |
+
|
61 |
+
|
62 |
+
def check_audio_from_file(audio_path: str) -> None:
|
63 |
+
watermarker = load_watermarker(device="cuda")
|
64 |
+
|
65 |
+
audio_array, sample_rate = load_audio(audio_path)
|
66 |
+
is_watermarked = verify(watermarker, audio_array, sample_rate, CSM_1B_GH_WATERMARK)
|
67 |
+
|
68 |
+
outcome = "Watermarked" if is_watermarked else "Not watermarked"
|
69 |
+
print(f"{outcome}: {audio_path}")
|
70 |
+
|
71 |
+
|
72 |
+
def load_audio(audio_path: str) -> tuple[torch.Tensor, int]:
|
73 |
+
audio_array, sample_rate = torchaudio.load(audio_path)
|
74 |
+
audio_array = audio_array.mean(dim=0)
|
75 |
+
return audio_array, int(sample_rate)
|
76 |
+
|
77 |
+
|
78 |
+
if __name__ == "__main__":
|
79 |
+
cli_check_audio()
|