Spaces:
				
			
			
	
			
			
		Runtime error
		
	
	
	
			
			
	
	
	
	
		
		
		Runtime error
		
	Update autoregressive/models/generate.py
Browse files
    	
        autoregressive/models/generate.py
    CHANGED
    
    | @@ -68,6 +68,10 @@ def sample(logits, temperature: float=1.0, top_k: int=2000, top_p: float=1.0, sa | |
| 68 | 
             
                # mask = (probs == values).float()
         | 
| 69 | 
             
                # probs = probs * (1 - mask)
         | 
| 70 | 
             
                if sample_logits:
         | 
|  | |
|  | |
|  | |
|  | |
| 71 | 
             
                    idx = torch.multinomial(probs, num_samples=1)
         | 
| 72 | 
             
                else:
         | 
| 73 | 
             
                    _, idx = torch.topk(probs, k=1, dim=-1)
         | 
|  | |
| 68 | 
             
                # mask = (probs == values).float()
         | 
| 69 | 
             
                # probs = probs * (1 - mask)
         | 
| 70 | 
             
                if sample_logits:
         | 
| 71 | 
            +
                    ### add to fix 'nan' and 'inf'
         | 
| 72 | 
            +
                    probs = torch.clamp(probs, min=0, max=None)  
         | 
| 73 | 
            +
                    probs = probs / probs.sum()  
         | 
| 74 | 
            +
                    ###
         | 
| 75 | 
             
                    idx = torch.multinomial(probs, num_samples=1)
         | 
| 76 | 
             
                else:
         | 
| 77 | 
             
                    _, idx = torch.topk(probs, k=1, dim=-1)
         | 
