"""
Modified From https://github.com/XXXXRT666/GPT-SoVITS
"""
from __future__ import annotations
from dataclasses import dataclass
from typing import List, Literal, Optional
import torch
from AR.models.t2s_model_abc import Sampler, T2SDecoderABC
Tensor = torch.Tensor
@dataclass
class T2SResult:
    """Outcome of a single T2S inference run.

    On success, ``result`` holds the generated token tensors and ``status``
    is ``"Success"``; on failure, ``status`` is ``"Error"`` and the raised
    exception plus its formatted traceback are carried alongside.
    """

    # Generated sequences, one tensor per batch sample; None until produced.
    result: Optional[List[torch.Tensor]] = None
    # Measured generation speed (units depend on the caller — TODO confirm).
    infer_speed: float = 0.0
    # Overall outcome flag for the run.
    status: Literal["Success", "Error"] = "Success"
    # Exception instance captured on failure, if any.
    exception: Optional[Exception] = None
    # Formatted traceback string matching ``exception``, if any.
    traceback: Optional[str] = None
@dataclass
class T2SRequest:
    """Inputs and sampling hyper-parameters for one T2S decoding run.

    Required fields:
        x: Per-sample input token tensors, one tensor per batch element.
        x_lens: Lengths of the entries in ``x`` (cast to int32 downstream).
        prompts: Prompt token tensor; its last dimension is treated as the
            prompt length by the session setup.
        bert_feature: Per-sample feature tensors fed to the decoder's
            ``embed`` together with ``x`` and ``prompts``.
        valid_length: Number of valid samples in the batch —
            presumably <= len(x); verify against caller.

    Sampling / runtime options (defaults shown):
        top_k / top_p / temperature / repetition_penalty: Standard sampling
            knobs forwarded to the sampler.
        early_stop_num: Early-stop threshold; -1 presumably disables it.
        use_cuda_graph: Enable CUDA-graph execution paths.
        debug: Enable debug behavior in the decoder.
    """

    x: List[Tensor]
    x_lens: Tensor
    prompts: Tensor
    bert_feature: List[Tensor]
    valid_length: int
    top_k: int = 5
    # Fix: default was the int literal ``1``, contradicting the float
    # annotation; ``1.0`` is value-equal (1 == 1.0) so callers are unaffected.
    top_p: float = 1.0
    early_stop_num: int = -1
    temperature: float = 1.0
    repetition_penalty: float = 1.35
    use_cuda_graph: bool = False
    debug: bool = False
class T2SSession:
    """Per-request decoding state for a T2S decoder.

    Bundles the decoder, the request tensors, sampler state, EOS bookkeeping,
    and the per-sample attention masks used for the prefill step. All tensor
    construction happens inside a ``with device:`` block so new tensors are
    allocated on the target device (torch >= 2.0 device context manager).
    """

    def __init__(self, decoder: T2SDecoderABC, request: T2SRequest, device: torch.device, dtype: torch.dtype):
        """Derive all per-run state from *request* on *device*.

        Args:
            decoder: Project decoder providing ``vocab_size`` and ``embed``.
            request: Inputs and sampling options (see ``T2SRequest``).
            device: Target device, entered as an allocation context.
            dtype: Compute dtype, stored for later use by callers.
        """
        with device:
            self.decoder = decoder
            self.request = request
            self.device = device
            self.dtype = dtype
            # Batch size is the number of per-sample input tensors; y_len is
            # the prompt length taken from the last dim of ``prompts``.
            bsz = len(request.x)
            y_len = request.prompts.size(-1)
            self.bsz = bsz
            self.y_len = y_len
            # Cache: sampler state sized to this batch and vocabulary.
            self.sampler = Sampler(bsz, decoder.vocab_size)
            # Forward args: per-sample inputs plus derived position counters.
            self.x = request.x
            self.x_lens = request.x_lens.to(torch.int32)
            self.y = request.prompts
            self.bert_feature = request.bert_feature
            # Total prefill length per sample = text length + prompt length.
            self.prefill_len = self.x_lens + self.y.size(1)
            # input_pos starts at prefill_len (zeros_like + in-place add keeps
            # the same dtype/device as prefill_len).
            self.input_pos = torch.zeros_like(self.prefill_len)
            self.input_pos.add_(self.prefill_len)
            # EOS: per-sample completion flags and result slots (filled with
            # tensors as samples finish; None placeholders until then).
            self.completed = torch.Tensor([False] * len(self.x)).bool().to(device)
            self.y_results: List[Tensor] = [None] * len(self.x)  # type: ignore
            # Joint embedding of text tokens, prompt tokens, and BERT features
            # produced by the project decoder (semantics live in ``embed``).
            self.xy_pos = decoder.embed(self.x, self.y, self.bert_feature)
            # Build one (pos + y_len) x (pos + y_len) boolean attention mask
            # per sample, where pos is that sample's text length: every row
            # may attend to all text positions (columns [:pos]), and the
            # prompt segment attends to itself causally (lower triangle).
            attn_mask = []
            for bs in range(bsz):
                pos = int(self.x_lens[bs].item())
                mask = torch.zeros(pos + y_len, pos + y_len).bool()
                mask[:, :pos].fill_(True)
                if y_len > 0:
                    mask[-y_len:, -y_len:] = ~torch.triu(torch.ones(y_len, y_len, dtype=torch.bool), diagonal=1)
                attn_mask.append(mask)
            # Samples have different text lengths, so the ragged masks are
            # batched as a nested tensor rather than padded.
            self.attn_mask_nested = torch.nested.nested_tensor(attn_mask)