HusseinBashir commited on
Commit
857794c
·
verified ·
1 Parent(s): 6d901db

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +144 -0
app.py ADDED
@@ -0,0 +1,144 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+
3
+ os.environ["HF_HOME"] = "/tmp"
4
+ os.environ["TRANSFORMERS_CACHE"] = "/tmp"
5
+ os.environ["TORCH_HOME"] = "/tmp"
6
+ os.environ["XDG_CACHE_HOME"] = "/tmp"
7
+
8
+ import io
9
+ import re
10
+ import math
11
+ import numpy as np
12
+ import scipy.io.wavfile
13
+ import torch
14
+ from fastapi import FastAPI, Query
15
+ from fastapi.responses import StreamingResponse
16
+ from pydantic import BaseModel
17
+ from transformers import VitsModel, AutoTokenizer
18
+
19
+ app = FastAPI()
20
+
21
+ model = VitsModel.from_pretrained("Somali-tts/somali_tts_model")
22
+ tokenizer = AutoTokenizer.from_pretrained("saleolow/somali-mms-tts")
23
+
24
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
25
+ model.to(device)
26
+ model.eval()
27
+
28
+ number_words = {
29
+ 0: "eber", 1: "koow", 2: "labo", 3: "seddex", 4: "afar", 5: "shan",
30
+ 6: "lix", 7: "todobo", 8: "sideed", 9: "sagaal", 10: "toban",
31
+ 11: "toban iyo koow", 12: "toban iyo labo", 13: "toban iyo seddex",
32
+ 14: "toban iyo afar", 15: "toban iyo shan", 16: "toban iyo lix",
33
+ 17: "toban iyo todobo", 18: "toban iyo sideed", 19: "toban iyo sagaal",
34
+ 20: "labaatan", 30: "sodon", 40: "afartan", 50: "konton",
35
+ 60: "lixdan", 70: "todobaatan", 80: "sideetan", 90: "sagaashan",
36
+ 100: "boqol", 1000: "kun"
37
+ }
38
+
39
+ def number_to_words(number: int) -> str:
40
+ if number < 20:
41
+ return number_words[number]
42
+ elif number < 100:
43
+ tens, unit = divmod(number, 10)
44
+ return number_words[tens * 10] + (" iyo " + number_words[unit] if unit else "")
45
+ elif number < 1000:
46
+ hundreds, remainder = divmod(number, 100)
47
+ part = (number_words[hundreds] + " boqol") if hundreds > 1 else "boqol"
48
+ if remainder:
49
+ part += " iyo " + number_to_words(remainder)
50
+ return part
51
+ elif number < 1000000:
52
+ thousands, remainder = divmod(number, 1000)
53
+ words = []
54
+ if thousands == 1:
55
+ words.append("kun")
56
+ else:
57
+ words.append(number_to_words(thousands) + " kun")
58
+ if remainder:
59
+ words.append("iyo " + number_to_words(remainder))
60
+ return " ".join(words)
61
+ elif number < 1000000000:
62
+ millions, remainder = divmod(number, 1000000)
63
+ words = []
64
+ if millions == 1:
65
+ words.append("milyan")
66
+ else:
67
+ words.append(number_to_words(millions) + " milyan")
68
+ if remainder:
69
+ words.append(number_to_words(remainder))
70
+ return " ".join(words)
71
+ else:
72
+ return str(number)
73
+
74
+ def normalize_text(text: str) -> str:
75
+ numbers = re.findall(r'\d+', text)
76
+ for num in numbers:
77
+ text = text.replace(num, number_to_words(int(num)))
78
+ text = text.replace("KH", "qa").replace("Z", "S")
79
+ text = text.replace("SH", "SHa'a").replace("DH", "Dha'a")
80
+ text = text.replace("ZamZam", "SamSam")
81
+ return text
82
+
83
+ def waveform_to_wav_bytes(waveform: torch.Tensor, sample_rate: int = 22050) -> bytes:
84
+ np_waveform = waveform.cpu().numpy()
85
+ if np_waveform.ndim == 3:
86
+ np_waveform = np_waveform[0]
87
+ if np_waveform.ndim == 2:
88
+ np_waveform = np_waveform.mean(axis=0)
89
+ np_waveform = np.clip(np_waveform, -1.0, 1.0).astype(np.float32)
90
+ pcm_waveform = (np_waveform * 32767).astype(np.int16)
91
+ buf = io.BytesIO()
92
+ scipy.io.wavfile.write(buf, rate=sample_rate, data=pcm_waveform)
93
+ buf.seek(0)
94
+ return buf.read()
95
+
96
+ class TextIn(BaseModel):
97
+ inputs: str
98
+
99
+ @app.post("/synthesize")
100
+ async def synthesize_post(data: TextIn):
101
+ text = normalize_text(data.inputs)
102
+ inputs = tokenizer(text, return_tensors="pt").to(device)
103
+ with torch.no_grad():
104
+ output = model(**inputs)
105
+ if hasattr(output, "waveform"):
106
+ waveform = output.waveform
107
+ elif isinstance(output, dict) and "waveform" in output:
108
+ waveform = output["waveform"]
109
+ elif isinstance(output, (tuple, list)):
110
+ waveform = output[0]
111
+ else:
112
+ return {"error": "Waveform not found in model output"}
113
+ sample_rate = getattr(model.config, "sampling_rate", 22050)
114
+ wav_bytes = waveform_to_wav_bytes(waveform, sample_rate=sample_rate)
115
+ return StreamingResponse(io.BytesIO(wav_bytes), media_type="audio/wav")
116
+
117
+ @app.get("/synthesize")
118
+ async def synthesize_get(text: str = Query(..., description="Text to synthesize"), test: bool = Query(False)):
119
+ if test:
120
+ duration_s = 2.0
121
+ sample_rate = 22050
122
+ t = np.linspace(0, duration_s, int(sample_rate * duration_s), endpoint=False)
123
+ freq = 440
124
+ waveform = 0.5 * np.sin(2 * math.pi * freq * t).astype(np.float32)
125
+ pcm_waveform = (waveform * 32767).astype(np.int16)
126
+ buf = io.BytesIO()
127
+ scipy.io.wavfile.write(buf, rate=sample_rate, data=pcm_waveform)
128
+ buf.seek(0)
129
+ return StreamingResponse(buf, media_type="audio/wav")
130
+ normalized = normalize_text(text)
131
+ inputs = tokenizer(normalized, return_tensors="pt").to(device)
132
+ with torch.no_grad():
133
+ output = model(**inputs)
134
+ if hasattr(output, "waveform"):
135
+ waveform = output.waveform
136
+ elif isinstance(output, dict) and "waveform" in output:
137
+ waveform = output["waveform"]
138
+ elif isinstance(output, (tuple, list)):
139
+ waveform = output[0]
140
+ else:
141
+ return {"error": "Waveform not found in model output"}
142
+ sample_rate = getattr(model.config, "sampling_rate", 22050)
143
+ wav_bytes = waveform_to_wav_bytes(waveform, sample_rate=sample_rate)
144
+ return StreamingResponse(io.BytesIO(wav_bytes), media_type="audio/wav")