from __future__ import annotations

from collections.abc import Iterable
import os
from typing import Any, Protocol

from huggingface_hub.inference._text_generation import TextGenerationStreamResponse, Token
import streamlit as st
import torch
from transformers import AutoModel, AutoTokenizer

from conversation import Conversation

TOOL_PROMPT = 'Answer the following questions as best as you can. You have access to the following tools:'

MODEL_PATH = os.environ.get('MODEL_PATH', 'THUDM/chatglm3-6b')
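# MODEL_PATH may also point at a local checkpoint directory instead of the Hub id.
# Launch-time override (sketch; the entry-point file name is hypothetical):
#   MODEL_PATH=/path/to/chatglm3-6b streamlit run app.py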
|
|
|
@st.cache_resource
def get_client() -> Client:
    # Cached by Streamlit so the model and tokenizer are loaded only once per process.
    client = HFClient(MODEL_PATH)
    return client
|
|
|
class Client(Protocol):
    """Protocol for chat backends that stream TextGenerationStreamResponse chunks."""

    def generate_stream(self,
                        system: str | None,
                        tools: list[dict] | None,
                        history: list[Conversation],
                        **parameters: Any
                        ) -> Iterable[TextGenerationStreamResponse]:
        ...
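# Example (sketch) of how a caller might consume the stream. `placeholder` stands in
# for a hypothetical st.empty() container created by the page code:
#
#   client = get_client()
#   for response in client.generate_stream(system=None, tools=None, history=history):
#       placeholder.markdown(response.generated_text)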
|
|
|
# Standalone streaming chat loop; `self` is the ChatGLM model instance and is passed
# explicitly by the caller (see HFClient.generate_stream below).
def stream_chat(self, tokenizer, query: str, history: list[dict] | None = None, role: str = "user",
                past_key_values=None, max_length: int = 8192, do_sample=True, top_p=0.8, temperature=0.8,
                logits_processor=None, return_past_key_values=False, **kwargs):
    from transformers.generation.logits_process import LogitsProcessor
    from transformers.generation.utils import LogitsProcessorList

    class InvalidScoreLogitsProcessor(LogitsProcessor):
        # Guard against NaN/Inf logits by zeroing the scores and forcing a single token.
        def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
            if torch.isnan(scores).any() or torch.isinf(scores).any():
                scores.zero_()
                scores[..., 5] = 5e4
            return scores

    if history is None:
        history = []
    if logits_processor is None:
        logits_processor = LogitsProcessorList()
    logits_processor.append(InvalidScoreLogitsProcessor())
    # Stop on end-of-sequence or when the model hands the turn to the user/observation role.
    eos_token_id = [tokenizer.eos_token_id, tokenizer.get_command("<|user|>"),
                    tokenizer.get_command("<|observation|>")]
    gen_kwargs = {"max_length": max_length, "do_sample": do_sample, "top_p": top_p,
                  "temperature": temperature, "logits_processor": logits_processor, **kwargs}
    if past_key_values is None:
        # First turn: encode the full conversation history together with the new query.
        inputs = tokenizer.build_chat_input(query, history=history, role=role)
    else:
        # Incremental decoding: only the new query is encoded; the cache covers the rest.
        inputs = tokenizer.build_chat_input(query, role=role)
    inputs = inputs.to(self.device)
    if past_key_values is not None:
        past_length = past_key_values[0][0].shape[0]
        if self.transformer.pre_seq_len is not None:
            past_length -= self.transformer.pre_seq_len
        # Shift position ids past the cached tokens and extend the attention mask to cover them.
        inputs.position_ids += past_length
        attention_mask = inputs.attention_mask
        attention_mask = torch.cat((attention_mask.new_ones(1, past_length), attention_mask), dim=1)
        inputs['attention_mask'] = attention_mask
    history.append({"role": role, "content": query})
    for outputs in self.stream_generate(**inputs, past_key_values=past_key_values,
                                        eos_token_id=eos_token_id, return_past_key_values=return_past_key_values,
                                        **gen_kwargs):
        if return_past_key_values:
            outputs, past_key_values = outputs
        # Keep only the newly generated tokens (everything after the prompt).
        outputs = outputs.tolist()[0][len(inputs["input_ids"][0]):]
        response = tokenizer.decode(outputs)
        # Skip chunks that end in the replacement character, i.e. a partially decoded multi-byte token.
        if response and response[-1] != "�":
            new_history = history
            if return_past_key_values:
                yield response, new_history, past_key_values
            else:
                yield response, new_history
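# Direct call sketch (model and tokenizer assumed loaded elsewhere, e.g. via HFClient below):
#
#   for response, new_history in stream_chat(model, tokenizer, 'Hello',
#                                            history=[], role='user'):
#       print(response)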
|
|
|
class HFClient(Client):
    def __init__(self, model_path: str):
        self.model_path = model_path
        self.tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
        # Prefer CUDA, then Apple Silicon (MPS), and fall back to CPU.
        self.model = AutoModel.from_pretrained(model_path, trust_remote_code=True).to(
            'cuda' if torch.cuda.is_available() else
            'mps' if torch.backends.mps.is_available() else
            'cpu'
        )
        self.model = self.model.eval()

    def generate_stream(self,
                        system: str | None,
                        tools: list[dict] | None,
                        history: list[Conversation],
                        **parameters: Any
                        ) -> Iterable[TextGenerationStreamResponse]:
        # When tools are supplied, the system prompt is replaced by the fixed tool-use prompt.
        chat_history = [{
            'role': 'system',
            'content': system if not tools else TOOL_PROMPT,
        }]

        if tools:
            chat_history[0]['tools'] = tools

        # All turns except the last become context; roles are stripped of their <|...|> markers.
        for conversation in history[:-1]:
            chat_history.append({
                'role': str(conversation.role).removeprefix('<|').removesuffix('|>'),
                'content': conversation.content,
            })

        query = history[-1].content
        role = str(history[-1].role).removeprefix('<|').removesuffix('|>')

        text = ''

        for new_text, _ in stream_chat(self.model,
                                       self.tokenizer,
                                       query,
                                       chat_history,
                                       role,
                                       **parameters,
                                       ):
            # stream_chat yields cumulative text; emit only the newly generated suffix.
            word = new_text.removeprefix(text)
            word_stripped = word.strip()
            text = new_text
            yield TextGenerationStreamResponse(
                generated_text=text,
                token=Token(
                    id=0,
                    logprob=0,
                    text=word,
                    special=word_stripped.startswith('<|') and word_stripped.endswith('|>'),
                )
            )
|