Spaces:
				
			
			
	
			
			
		Runtime error
		
	
	
	
			
			
	
	
	
	
		
		
		Runtime error
		
	Added unconditional generation
Browse files
    	
        app.py
    CHANGED
    
    | @@ -180,29 +180,49 @@ def make_npz_prompt(name, uploaded_audio, recorded_audio, transcript_content): | |
| 180 | 
             
            def infer_from_audio(text, language, accent, audio_prompt, record_audio_prompt, transcript_content):
         | 
| 181 | 
             
                if len(text) > 150:
         | 
| 182 | 
             
                    return "Rejected, Text too long (should be less than 150 characters)", None
         | 
| 183 | 
            -
                 | 
| 184 | 
            -
             | 
| 185 | 
            -
             | 
| 186 | 
            -
                     | 
| 187 | 
            -
                if not isinstance(wav_pr, torch.FloatTensor):
         | 
| 188 | 
            -
                    wav_pr = torch.FloatTensor(wav_pr)
         | 
| 189 | 
            -
                if wav_pr.abs().max() > 1:
         | 
| 190 | 
            -
                    wav_pr /= wav_pr.abs().max()
         | 
| 191 | 
            -
                if wav_pr.size(-1) == 2:
         | 
| 192 | 
            -
                    wav_pr = wav_pr[:, 0]
         | 
| 193 | 
            -
                if wav_pr.ndim == 1:
         | 
| 194 | 
            -
                    wav_pr = wav_pr.unsqueeze(0)
         | 
| 195 | 
            -
                assert wav_pr.ndim and wav_pr.size(0) == 1
         | 
| 196 | 
            -
             | 
| 197 | 
            -
                if transcript_content == "":
         | 
| 198 | 
            -
                    lang_pr, text_pr = transcribe_one(wav_pr, sr)
         | 
| 199 | 
            -
                    lang_token = lang2token[lang_pr]
         | 
| 200 | 
            -
                    text_pr = lang_token + text_pr + lang_token
         | 
| 201 | 
             
                else:
         | 
| 202 | 
            -
                     | 
| 203 | 
            -
                     | 
| 204 | 
            -
                     | 
| 205 | 
            -
             | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 206 |  | 
| 207 | 
             
                if language == 'auto-detect':
         | 
| 208 | 
             
                    lang_token = lang2token[langid.classify(text)[0]]
         | 
| @@ -212,13 +232,6 @@ def infer_from_audio(text, language, accent, audio_prompt, record_audio_prompt, | |
| 212 | 
             
                text = text.replace("\n", "")
         | 
| 213 | 
             
                text = lang_token + text + lang_token
         | 
| 214 |  | 
| 215 | 
            -
                if lang_pr not in ['ja', 'zh', 'en']:
         | 
| 216 | 
            -
                    return f"Reference audio must be a speech of one of model-supported languages, got {lang_pr} instead", None
         | 
| 217 | 
            -
             | 
| 218 | 
            -
                # tokenize audio
         | 
| 219 | 
            -
                encoded_frames = tokenize_audio(audio_tokenizer, (wav_pr, sr))
         | 
| 220 | 
            -
                audio_prompts = encoded_frames[0][0].transpose(2, 1).to(device)
         | 
| 221 | 
            -
             | 
| 222 | 
             
                # tokenize text
         | 
| 223 | 
             
                logging.info(f"synthesize text: {text}")
         | 
| 224 | 
             
                phone_tokens, langs = text_tokenizer.tokenize(text=f"_{text}".strip())
         | 
| @@ -228,14 +241,7 @@ def infer_from_audio(text, language, accent, audio_prompt, record_audio_prompt, | |
| 228 | 
             
                    ]
         | 
| 229 | 
             
                )
         | 
| 230 |  | 
| 231 | 
            -
             | 
| 232 | 
            -
                if text_pr:
         | 
| 233 | 
            -
                    text_prompts, _ = text_tokenizer.tokenize(text=f"{text_pr}".strip())
         | 
| 234 | 
            -
                    text_prompts, enroll_x_lens = text_collater(
         | 
| 235 | 
            -
                        [
         | 
| 236 | 
            -
                            text_prompts
         | 
| 237 | 
            -
                        ]
         | 
| 238 | 
            -
                    )
         | 
| 239 | 
             
                text_tokens = torch.cat([text_prompts, text_tokens], dim=-1)
         | 
| 240 | 
             
                text_tokens_lens += enroll_x_lens
         | 
| 241 | 
             
                lang = lang if accent == "no-accent" else token2lang[langdropdown2token[accent]]
         | 
|  | |
| 180 | 
             
            def infer_from_audio(text, language, accent, audio_prompt, record_audio_prompt, transcript_content):
         | 
| 181 | 
             
                if len(text) > 150:
         | 
| 182 | 
             
                    return "Rejected, Text too long (should be less than 150 characters)", None
         | 
| 183 | 
            +
                if audio_prompt is None and record_audio_prompt is None:
         | 
| 184 | 
            +
                    audio_prompts = torch.zeros([1, 0, NUM_QUANTIZERS]).type(torch.int32).to(device)
         | 
| 185 | 
            +
                    text_prompts = torch.zeros([1, 0]).type(torch.int32)
         | 
| 186 | 
            +
                    lang_pr = language if language != 'mix' else 'en'
         | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 187 | 
             
                else:
         | 
| 188 | 
            +
                    audio_prompt = audio_prompt if audio_prompt is not None else record_audio_prompt
         | 
| 189 | 
            +
                    sr, wav_pr = audio_prompt
         | 
| 190 | 
            +
                    if len(wav_pr) / sr > 15:
         | 
| 191 | 
            +
                        return "Rejected, Audio too long (should be less than 15 seconds)", None
         | 
| 192 | 
            +
                    if not isinstance(wav_pr, torch.FloatTensor):
         | 
| 193 | 
            +
                        wav_pr = torch.FloatTensor(wav_pr)
         | 
| 194 | 
            +
                    if wav_pr.abs().max() > 1:
         | 
| 195 | 
            +
                        wav_pr /= wav_pr.abs().max()
         | 
| 196 | 
            +
                    if wav_pr.size(-1) == 2:
         | 
| 197 | 
            +
                        wav_pr = wav_pr[:, 0]
         | 
| 198 | 
            +
                    if wav_pr.ndim == 1:
         | 
| 199 | 
            +
                        wav_pr = wav_pr.unsqueeze(0)
         | 
| 200 | 
            +
                    assert wav_pr.ndim and wav_pr.size(0) == 1
         | 
| 201 | 
            +
             | 
| 202 | 
            +
                    if transcript_content == "":
         | 
| 203 | 
            +
                        lang_pr, text_pr = transcribe_one(wav_pr, sr)
         | 
| 204 | 
            +
                        lang_token = lang2token[lang_pr]
         | 
| 205 | 
            +
                        text_pr = lang_token + text_pr + lang_token
         | 
| 206 | 
            +
                    else:
         | 
| 207 | 
            +
                        lang_pr = langid.classify(str(transcript_content))[0]
         | 
| 208 | 
            +
                        text_pr = transcript_content.replace("\n", "")
         | 
| 209 | 
            +
                        if lang_pr not in ['ja', 'zh', 'en']:
         | 
| 210 | 
            +
                            return f"Reference audio must be a speech of one of model-supported languages, got {lang_pr} instead", None
         | 
| 211 | 
            +
                        lang_token = lang2token[lang_pr]
         | 
| 212 | 
            +
                        text_pr = lang_token + text_pr + lang_token
         | 
| 213 | 
            +
             | 
| 214 | 
            +
                    # tokenize audio
         | 
| 215 | 
            +
                    encoded_frames = tokenize_audio(audio_tokenizer, (wav_pr, sr))
         | 
| 216 | 
            +
                    audio_prompts = encoded_frames[0][0].transpose(2, 1).to(device)
         | 
| 217 | 
            +
             | 
| 218 | 
            +
                    enroll_x_lens = None
         | 
| 219 | 
            +
                    if text_pr:
         | 
| 220 | 
            +
                        text_prompts, _ = text_tokenizer.tokenize(text=f"{text_pr}".strip())
         | 
| 221 | 
            +
                        text_prompts, enroll_x_lens = text_collater(
         | 
| 222 | 
            +
                            [
         | 
| 223 | 
            +
                                text_prompts
         | 
| 224 | 
            +
                            ]
         | 
| 225 | 
            +
                        )
         | 
| 226 |  | 
| 227 | 
             
                if language == 'auto-detect':
         | 
| 228 | 
             
                    lang_token = lang2token[langid.classify(text)[0]]
         | 
|  | |
| 232 | 
             
                text = text.replace("\n", "")
         | 
| 233 | 
             
                text = lang_token + text + lang_token
         | 
| 234 |  | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 235 | 
             
                # tokenize text
         | 
| 236 | 
             
                logging.info(f"synthesize text: {text}")
         | 
| 237 | 
             
                phone_tokens, langs = text_tokenizer.tokenize(text=f"_{text}".strip())
         | 
|  | |
| 241 | 
             
                    ]
         | 
| 242 | 
             
                )
         | 
| 243 |  | 
| 244 | 
            +
             | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 245 | 
             
                text_tokens = torch.cat([text_prompts, text_tokens], dim=-1)
         | 
| 246 | 
             
                text_tokens_lens += enroll_x_lens
         | 
| 247 | 
             
                lang = lang if accent == "no-accent" else token2lang[langdropdown2token[accent]]
         | 
