# Hugging Face Spaces demo: stabilityai/stable-code-3b chat UI.
# NOTE: the hosted Space reported a runtime error at capture time.
import gradio as gr
import torch
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    StoppingCriteria,
    StoppingCriteriaList,
)
# Load the tokenizer and model once at import time so Gradio requests reuse
# them. trust_remote_code is required because this checkpoint ships custom
# modeling code; torch_dtype="auto" keeps the checkpoint's native precision.
tokenizer = AutoTokenizer.from_pretrained(
    "stabilityai/stable-code-3b",
    trust_remote_code=True,
)
model = AutoModelForCausalLM.from_pretrained(
    "stabilityai/stable-code-3b",
    trust_remote_code=True,
    torch_dtype="auto",
)
class StopOnTokens(StoppingCriteria):
    """Stopping criterion: halt generation when the newest token is a stop id."""

    # Token ids that end generation. 0 and 2 are presumably this model's
    # EOS/end-of-text ids -- TODO(review): confirm against the tokenizer's
    # special-token map. Hoisted to a class constant so the tuple is not
    # rebuilt on every decoding step.
    STOP_IDS = (0, 2)

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        # Only the most recently generated token can trigger a stop; extract
        # it as a plain int so the membership test is an ordinary comparison.
        return int(input_ids[0, -1]) in self.STOP_IDS
def chat(message, history):
    """Generate a model reply to *message* and append the exchange to *history*.

    Parameters
    ----------
    message : str
        The user's prompt from the text input.
    history : list[tuple[str, str]] | None
        Prior (user, bot) pairs held in the Gradio state component.

    Returns
    -------
    tuple
        The updated history twice: once for the chatbot display, once for
        the state component.
    """
    history = history or []
    inputs = tokenizer(message, return_tensors="pt").to(model.device)
    tokens = model.generate(
        **inputs,
        max_new_tokens=4096,
        temperature=0.2,
        do_sample=True,
        # Bug fix: the StopOnTokens instance was previously constructed but
        # never passed to generate(), so the custom stop condition never ran
        # and generation could run for the full 4096 new tokens.
        stopping_criteria=StoppingCriteriaList([StopOnTokens()]),
    )
    # Bug fix: generate() returns the prompt followed by the completion;
    # slice off the prompt tokens so the user does not see their own input
    # echoed back in the bot's reply.
    prompt_len = inputs["input_ids"].shape[-1]
    response = tokenizer.decode(tokens[0][prompt_len:], skip_special_tokens=True)
    history.append((message, response))
    return history, history
# Wire the chat function into a simple Gradio UI: a text box plus hidden
# state in, a chatbot display plus the updated state out. Flagging is
# disabled since there is nowhere to store flagged samples on a Space.
iface = gr.Interface(
    fn=chat,
    inputs=["text", "state"],
    outputs=["chatbot", "state"],
    allow_flagging="never",
)
iface.launch()