tokusan2 commited on
Commit
8625ea3
·
verified ·
1 Parent(s): fb8bdd5

Add カスタムハンドラー

Browse files
Files changed (1) hide show
  1. handler.py +265 -0
handler.py ADDED
@@ -0,0 +1,265 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Style-BERT-VITS2 Custom Handler for Hugging Face Inference Endpoints
3
+ 日本語テキスト読み上げ用カスタムハンドラー
4
+ """
5
+
6
+ import os
7
+ import json
8
+ import logging
9
+ import traceback
10
+ from typing import Dict, List, Any, Optional
11
+ import torch
12
+ import numpy as np
13
+ from io import BytesIO
14
+ import base64
15
+
16
+ # ログ設定
17
+ logging.basicConfig(level=logging.INFO)
18
+ logger = logging.getLogger(__name__)
19
+
20
+ class EndpointHandler:
21
+ """Style-BERT-VITS2用のカスタムハンドラー"""
22
+
23
+ def __init__(self, path: str = ""):
24
+ """
25
+ ハンドラーの初期化
26
+
27
+ Args:
28
+ path: モデルファイルのパス
29
+ """
30
+ logger.info("Style-BERT-VITS2 Handler初期化開始")
31
+
32
+ try:
33
+ self.device = "cuda" if torch.cuda.is_available() else "cpu"
34
+ logger.info(f"使用デバイス: {self.device}")
35
+
36
+ # Style-BERT-VITS2の依存関係をインポート
37
+ self._import_dependencies()
38
+
39
+ # モデル初期化
40
+ self._load_model(path)
41
+
42
+ # デフォルト設定
43
+ self.default_config = {
44
+ "speaker_id": 0,
45
+ "emotion": "neutral",
46
+ "speed": 1.0,
47
+ "pitch": 0.0,
48
+ "intonation": 1.0,
49
+ "volume": 1.0,
50
+ "pre_phoneme_length": 0.1,
51
+ "post_phoneme_length": 0.1,
52
+ "sample_rate": 44100
53
+ }
54
+
55
+ logger.info("Handler初期化完了")
56
+
57
+ except Exception as e:
58
+ logger.error(f"Handler初期化エラー: {e}")
59
+ logger.error(traceback.format_exc())
60
+ raise
61
+
62
+ def _import_dependencies(self):
63
+ """必要な依存関係をインポート"""
64
+ try:
65
+ # Style-BERT-VITS2の主要モジュール
66
+ try:
67
+ global style_bert_vits2
68
+ import style_bert_vits2
69
+ self.has_style_bert_vits2 = True
70
+ logger.info("Style-BERT-VITS2依存関係インポート完了")
71
+ except ImportError:
72
+ logger.warning("Style-BERT-VITS2がインストールされていません - モックモードで動作")
73
+ self.has_style_bert_vits2 = False
74
+
75
+ except Exception as e:
76
+ logger.error(f"依存関係インポートエラー: {e}")
77
+ raise
78
+
79
+ def _load_model(self, path: str):
80
+ """モデルをロード"""
81
+ try:
82
+ logger.info(f"モデルロード開始: {path}")
83
+
84
+ # モデル設定ファイルのパス
85
+ config_path = os.path.join(path, "config.json")
86
+ model_path = os.path.join(path, "model.safetensors")
87
+
88
+ if not os.path.exists(config_path):
89
+ logger.warning(f"設定ファイルが見つかりません: {config_path}")
90
+ # デフォルト設定を使用
91
+ self.model_config = self.default_config.copy()
92
+ else:
93
+ with open(config_path, "r", encoding="utf-8") as f:
94
+ self.model_config = json.load(f)
95
+
96
+ # モデルの実際のロード処理
97
+ if self.has_style_bert_vits2:
98
+ # 実際のStyle-BERT-VITS2モデルをロード
99
+ logger.info("実際のStyle-BERT-VITS2モデルロード開始")
100
+ # ここで実際のモデルロード処理を実装
101
+ logger.info("モデルロード完了")
102
+ else:
103
+ # モックモード
104
+ logger.info("モックモードでモデル初期化完了")
105
+
106
+ except Exception as e:
107
+ logger.error(f"モデルロードエラー: {e}")
108
+ raise
109
+
110
+ def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
111
+ """
112
+ 推論実行のメインメソッド
113
+
114
+ Args:
115
+ data: リクエストデータ
116
+ - inputs: テキスト(必須)
117
+ - parameters: 音声生成パラメータ(オプション)
118
+
119
+ Returns:
120
+ 音声データとメタデータのリスト
121
+ """
122
+ try:
123
+ logger.info("推論開始")
124
+
125
+ # 入力データの検証
126
+ inputs = data.get("inputs", "")
127
+ if not inputs or not isinstance(inputs, str):
128
+ raise ValueError("'inputs'に有効なテキストを指定してください")
129
+
130
+ parameters = data.get("parameters", {})
131
+
132
+ # パラメータのマージ
133
+ config = self.default_config.copy()
134
+ config.update(parameters)
135
+
136
+ logger.info(f"入力テキスト: {inputs[:50]}...")
137
+ logger.info(f"使用パラメータ: {config}")
138
+
139
+ # 音声合成実行
140
+ audio_result = self._synthesize_speech(inputs, config)
141
+
142
+ # 結果の準備
143
+ result = [
144
+ {
145
+ "audio_base64": audio_result["audio_base64"],
146
+ "sample_rate": audio_result["sample_rate"],
147
+ "duration": audio_result["duration"],
148
+ "text": inputs,
149
+ "parameters_used": config,
150
+ "model_info": {
151
+ "name": "Style-BERT-VITS2",
152
+ "language": "ja",
153
+ "device": self.device
154
+ }
155
+ }
156
+ ]
157
+
158
+ logger.info("推論完了")
159
+ return result
160
+
161
+ except Exception as e:
162
+ logger.error(f"推論エラー: {e}")
163
+ logger.error(traceback.format_exc())
164
+
165
+ # エラー情報を返す
166
+ return [
167
+ {
168
+ "error": str(e),
169
+ "error_type": type(e).__name__,
170
+ "traceback": traceback.format_exc(),
171
+ "inputs": data.get("inputs", ""),
172
+ "status": "error"
173
+ }
174
+ ]
175
+
176
+ def _synthesize_speech(self, text: str, config: Dict[str, Any]) -> Dict[str, Any]:
177
+ """
178
+ テキストから音声を合成
179
+
180
+ Args:
181
+ text: 合成するテキスト
182
+ config: 音声合成設定
183
+
184
+ Returns:
185
+ 音声データとメタデータ
186
+ """
187
+ try:
188
+ logger.info("音声合成開始")
189
+
190
+ sample_rate = config["sample_rate"]
191
+
192
+ if self.has_style_bert_vits2:
193
+ # 実際のStyle-BERT-VITS2による音声合成
194
+ logger.info("実際のStyle-BERT-VITS2で音声合成実行")
195
+ # ここで実際の音声合成処理を実装
196
+ duration = len(text) * 0.1 # テキスト長に基づく概算時間
197
+ samples = int(sample_rate * duration)
198
+ # 実際の音声データを生成
199
+ audio_data = np.zeros(samples) # プレースホルダー
200
+ else:
201
+ # モックモード - ダミー音声データ(サイン波)
202
+ logger.info("モックモードでダミー音声生成")
203
+ duration = len(text) * 0.1 # テキスト長に基づく概算時間
204
+ samples = int(sample_rate * duration)
205
+ t = np.linspace(0, duration, samples)
206
+ frequency = 440 # A4音程
207
+ audio_data = np.sin(2 * np.pi * frequency * t) * 0.3
208
+
209
+ # 16bit PCMに変換
210
+ audio_int16 = (audio_data * 32767).astype(np.int16)
211
+
212
+ # WAVファイル形式でエンコード
213
+ audio_bytes = self._encode_wav(audio_int16, sample_rate)
214
+
215
+ # Base64エンコード
216
+ audio_base64 = base64.b64encode(audio_bytes).decode('utf-8')
217
+
218
+ result = {
219
+ "audio_base64": audio_base64,
220
+ "sample_rate": sample_rate,
221
+ "duration": duration,
222
+ "format": "wav"
223
+ }
224
+
225
+ logger.info(f"音声合成完了 - 時間: {duration:.2f}秒, サンプル数: {samples}")
226
+ return result
227
+
228
+ except Exception as e:
229
+ logger.error(f"音声合成エラー: {e}")
230
+ raise
231
+
232
+ def _encode_wav(self, audio_data: np.ndarray, sample_rate: int) -> bytes:
233
+ """
234
+ 音声データをWAV形式でエンコード
235
+
236
+ Args:
237
+ audio_data: 音声データ(int16)
238
+ sample_rate: サンプリングレート
239
+
240
+ Returns:
241
+ WAVファイルのバイナリデータ
242
+ """
243
+ import struct
244
+ import wave
245
+
246
+ # BytesIOでWAVファイルを作成
247
+ wav_buffer = BytesIO()
248
+
249
+ with wave.open(wav_buffer, 'wb') as wav_file:
250
+ wav_file.setnchannels(1) # モノラル
251
+ wav_file.setsampwidth(2) # 16bit
252
+ wav_file.setframerate(sample_rate)
253
+ wav_file.writeframes(audio_data.tobytes())
254
+
255
+ wav_buffer.seek(0)
256
+ return wav_buffer.read()
257
+
258
+ def health_check(self) -> Dict[str, Any]:
259
+ """ヘルスチェック"""
260
+ return {
261
+ "status": "healthy",
262
+ "model_loaded": True,
263
+ "device": self.device,
264
+ "timestamp": str(torch.tensor([1.0]).item())
265
+ }