kakuguo's picture
Upload 52 files
afd4069
raw
history blame
No virus
5.38 kB
from __future__ import annotations
from collections.abc import Iterable
import os
from typing import Any, Protocol
from huggingface_hub.inference._text_generation import TextGenerationStreamResponse, Token
import streamlit as st
import torch
from transformers import AutoModel, AutoTokenizer
from conversation import Conversation
# System-prompt text sent in place of the caller's system prompt whenever
# tools are supplied to a generation request.
TOOL_PROMPT = 'Answer the following questions as best as you can. You have access to the following tools:'

# Checkpoint to load; override by setting the MODEL_PATH environment variable.
MODEL_PATH = os.getenv('MODEL_PATH', 'THUDM/chatglm3-6b')
@st.cache_resource
def get_client() -> Client:
    """Build an ``HFClient`` for ``MODEL_PATH``.

    Wrapped in ``st.cache_resource`` (presumably so the model is loaded once
    and reused across Streamlit reruns rather than rebuilt each time).
    """
    return HFClient(MODEL_PATH)
class Client(Protocol):
    """Structural interface for a streaming chat backend."""

    def generate_stream(self,
                        system: str | None,
                        tools: list[dict] | None,
                        history: list[Conversation],
                        **parameters: Any
                        ) -> Iterable[TextGenerationStreamResponse]:
        """Stream a response for the conversation in *history*.

        Args:
            system: optional system prompt.
            tools: optional list of tool descriptions.
            history: the conversation so far.
            **parameters: backend-specific generation options.

        Returns:
            An iterable of streaming text-generation responses.
        """
def stream_chat(self, tokenizer, query: str, history: list[dict] | None = None, role: str = "user",
                past_key_values=None, max_length: int = 8192, do_sample=True, top_p=0.8, temperature=0.8,
                logits_processor=None, return_past_key_values=False, **kwargs):
    """Stream a chat reply from a ChatGLM3-style model.

    Module-level function: *self* is the loaded model, passed explicitly by the
    caller, not an instance of a class defined in this file.

    Args:
        self: the model. It must provide ``device``,
            ``transformer.pre_seq_len`` and ``stream_generate``.
        tokenizer: it must provide ``eos_token_id``, ``get_command``,
            ``build_chat_input`` and ``decode``.
        query: the new message.
        history: prior turns as ``{"role": ..., "content": ...}`` dicts.
            The query turn is appended to this list in place. The history is
            not used to build the prompt when ``past_key_values`` is given.
        role: role of ``query``.
        past_key_values: cached KV state from a previous call. When given,
            only the new turn is encoded.
        max_length, do_sample, top_p, temperature: forwarded to generation.
        logits_processor: optional processors. The caller's list is copied,
            never modified.
        return_past_key_values: also yield the updated KV cache.
        **kwargs: extra generation kwargs. These override the ones above.

    Yields:
        ``(response, history)`` or ``(response, history, past_key_values)``.
        ``response`` is the text decoded from all tokens generated after the
        prompt so far.
    """
    from transformers.generation.logits_process import LogitsProcessor
    from transformers.generation.utils import LogitsProcessorList

    class InvalidScoreLogitsProcessor(LogitsProcessor):
        # If any score is NaN/inf, zero all scores and heavily favour token
        # id 5 so sampling stays well-defined. (Which token id 5 is depends
        # on the vocabulary; TODO confirm.)
        def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
            if torch.isnan(scores).any() or torch.isinf(scores).any():
                scores.zero_()
                scores[..., 5] = 5e4
            return scores

    if history is None:
        history = []
    # Copy rather than append to the caller's list. Appending in place made
    # the guard processor accumulate across calls that reuse one list.
    logits_processor = LogitsProcessorList(logits_processor or [])
    logits_processor.append(InvalidScoreLogitsProcessor())

    # Stop on EOS or when the model starts a new user/observation turn.
    eos_token_id = [tokenizer.eos_token_id, tokenizer.get_command("<|user|>"),
                    tokenizer.get_command("<|observation|>")]
    gen_kwargs = {"max_length": max_length, "do_sample": do_sample, "top_p": top_p,
                  "temperature": temperature, "logits_processor": logits_processor, **kwargs}

    if past_key_values is None:
        inputs = tokenizer.build_chat_input(query, history=history, role=role)
    else:
        # The earlier turns are already represented by the KV cache.
        inputs = tokenizer.build_chat_input(query, role=role)
    inputs = inputs.to(self.device)

    if past_key_values is not None:
        # Offset positions and extend the attention mask to cover cached tokens.
        # assumes the cache is laid out sequence-first — TODO confirm
        past_length = past_key_values[0][0].shape[0]
        if self.transformer.pre_seq_len is not None:
            past_length -= self.transformer.pre_seq_len
        inputs.position_ids += past_length
        attention_mask = inputs.attention_mask
        attention_mask = torch.cat((attention_mask.new_ones(1, past_length), attention_mask), dim=1)
        inputs['attention_mask'] = attention_mask

    history.append({"role": role, "content": query})
    # Generated sequences include the prompt; strip it before decoding.
    prompt_length = len(inputs["input_ids"][0])
    for outputs in self.stream_generate(**inputs, past_key_values=past_key_values,
                                        eos_token_id=eos_token_id,
                                        return_past_key_values=return_past_key_values,
                                        **gen_kwargs):
        if return_past_key_values:
            outputs, past_key_values = outputs
        response = tokenizer.decode(outputs.tolist()[0][prompt_length:])
        # Hold back output that is empty or ends in U+FFFD (presumably a
        # partially decoded multi-byte character).
        if not response or response[-1] == "�":
            continue
        if return_past_key_values:
            yield response, history, past_key_values
        else:
            yield response, history
class HFClient(Client):
    """``Client`` backed by a local Hugging Face checkpoint.

    The tokenizer and model are loaded with ``trust_remote_code=True``. The
    model is moved to CUDA if available, else MPS, else CPU, and is put in
    eval mode.
    """

    def __init__(self, model_path: str):
        self.model_path = model_path
        self.tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
        model = AutoModel.from_pretrained(model_path, trust_remote_code=True)
        if torch.cuda.is_available():
            device = 'cuda'
        elif torch.backends.mps.is_available():
            device = 'mps'
        else:
            device = 'cpu'
        self.model = model.to(device)
        self.model = self.model.eval()

    @staticmethod
    def _role_name(role) -> str:
        # Drop a surrounding '<|' / '|>' wrapper, e.g. '<|user|>' -> 'user'.
        return str(role).removeprefix('<|').removesuffix('|>')

    def generate_stream(self,
                        system: str | None,
                        tools: list[dict] | None,
                        history: list[Conversation],
                        **parameters: Any
                        ) -> Iterable[TextGenerationStreamResponse]:
        """Stream the model's reply to the last entry of *history*.

        The last history entry is the query. Earlier entries are sent as chat
        history after a system turn. When *tools* is non-empty, the system
        turn's content is ``TOOL_PROMPT`` (not *system*) and the tools are
        attached to that turn.

        Yields one ``TextGenerationStreamResponse`` per streamed update.
        ``generated_text`` is the full text so far. ``token.text`` is the
        portion added since the previous update, or the whole text if the
        earlier prefix no longer matches. ``token.id`` and ``token.logprob``
        are always 0. ``token.special`` is true when the stripped piece is
        wrapped in ``<|...|>``.
        """
        system_turn = {
            'role': 'system',
            'content': TOOL_PROMPT if tools else system,
        }
        if tools:
            system_turn['tools'] = tools
        chat_history = [system_turn]
        chat_history.extend(
            {'role': self._role_name(turn.role), 'content': turn.content}
            for turn in history[:-1]
        )

        last_turn = history[-1]
        query = last_turn.content
        role = self._role_name(last_turn.role)

        previous = ''
        for new_text, _ in stream_chat(self.model, self.tokenizer, query, chat_history, role, **parameters):
            piece = new_text.removeprefix(previous)
            previous = new_text
            bare = piece.strip()
            is_special = bare.startswith('<|') and bare.endswith('|>')
            yield TextGenerationStreamResponse(
                generated_text=previous,
                token=Token(id=0, logprob=0, text=piece, special=is_special),
            )