import copy
import logging
from typing import Dict, List, Optional, Union

from lagent.schema import ModelStatusCode

from .base_api import APITemplateParser
from .base_llm import BaseLLM

logger = logging.getLogger(__name__)


class HFTransformer(BaseLLM):
    """Model wrapper around HuggingFace general models.

    Adapted from InternLM (https://github.com/InternLM/InternLM/blob/main/
    chat/web_demo.py)

    Args:
        path (str): The name or path to HuggingFace's model.
        tokenizer_path (str): The path to the tokenizer. Defaults to None.
        tokenizer_kwargs (dict): Keyword arguments for the tokenizer.
            Defaults to {}.
        tokenizer_only (bool): If True, only the tokenizer will be
            initialized. Defaults to False.
        model_kwargs (dict): Keyword arguments for the model, used in the
            loader. Defaults to dict(device_map='auto').
        meta_template (Dict, optional): The model's meta prompt template,
            if needed, e.g. to inject or wrap meta instructions around
            the user input.
    """

    def __init__(self,
                 path: str,
                 tokenizer_path: Optional[str] = None,
                 tokenizer_kwargs: dict = dict(),
                 tokenizer_only: bool = False,
                 model_kwargs: dict = dict(device_map='auto'),
                 meta_template: Optional[Dict] = None,
                 stop_words_id: Union[List[int], int] = None,
                 **kwargs):
        super().__init__(
            path=path,
            tokenizer_only=tokenizer_only,
            meta_template=meta_template,
            **kwargs)
        if isinstance(stop_words_id, int):
            stop_words_id = [stop_words_id]
        self.gen_params.update(stop_words_id=stop_words_id)
        if self.gen_params['stop_words'] is not None and \
                self.gen_params['stop_words_id'] is not None:
            logger.warning('Both stop_words and stop_words_id are specified; '
                           'only stop_words_id will be used.')

        self._load_tokenizer(
            path=path,
            tokenizer_path=tokenizer_path,
            tokenizer_kwargs=tokenizer_kwargs)
        if not tokenizer_only:
            self._load_model(path=path, model_kwargs=model_kwargs)

        from transformers.generation.utils import LogitsProcessorList, StoppingCriteriaList  # noqa: E501
        self.logits_processor = LogitsProcessorList()
        self.stopping_criteria = StoppingCriteriaList()
        self.prefix_allowed_tokens_fn = None

        # Resolve the ids that should terminate generation in addition to
        # the model's own eos_token_id.
        stop_words_id = []
        if self.gen_params.get('stop_words_id'):
            stop_words_id = self.gen_params.get('stop_words_id')
        elif self.gen_params.get('stop_words'):
            for sw in self.gen_params.get('stop_words'):
                stop_words_id.append(self.tokenizer(sw)['input_ids'][-1])
        self.additional_eos_token_id = stop_words_id
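
    # NOTE: when only `stop_words` is given, each stop word is mapped to the
    # id of its *last* sub-token (a rough, tokenizer-dependent sketch with a
    # placeholder marker string):
    #
    #     ids = self.tokenizer('<eos_marker>')['input_ids']
    #     self.additional_eos_token_id = [ids[-1]]
    #
    # Multi-token stop words are therefore only matched on their final piece;
    # pass `stop_words_id` explicitly if that approximation is too loose.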

    def _load_tokenizer(self, path: str, tokenizer_path: Optional[str],
                        tokenizer_kwargs: dict):
        from transformers import AutoTokenizer
        self.tokenizer = AutoTokenizer.from_pretrained(
            tokenizer_path if tokenizer_path else path,
            trust_remote_code=True,
            **tokenizer_kwargs)

        if self.tokenizer.pad_token_id is None:
            if self.tokenizer.eos_token is not None:
                logger.warning(
                    f'Using eos_token {self.tokenizer.eos_token} '
                    'as pad_token.')
                self.tokenizer.pad_token = self.tokenizer.eos_token
            else:
                from transformers.generation import GenerationConfig
                self.gcfg = GenerationConfig.from_pretrained(path)
                if self.gcfg.pad_token_id is not None:
                    logger.warning(
                        f'Using pad_token_id {self.gcfg.pad_token_id} '
                        'from the generation config as pad_token_id.')
                    self.tokenizer.pad_token_id = self.gcfg.pad_token_id
                else:
                    raise ValueError(
                        'pad_token_id is not set for this tokenizer. Try to '
                        'set pad_token_id via passing '
                        '`pad_token_id={PAD_TOKEN_ID}` in model_cfg.')
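
    # The padding fallback above is: reuse the tokenizer's eos_token, then
    # the pad_token_id from the checkpoint's GenerationConfig, and finally
    # raise so the caller has to supply an explicit `pad_token_id`.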

    def _load_model(self, path: str, model_kwargs: dict):
        import torch
        from transformers import AutoModel
        model_kwargs.setdefault('torch_dtype', torch.float16)
        self.model = AutoModel.from_pretrained(
            path, trust_remote_code=True, **model_kwargs)
        self.model.eval()

    def tokenize(self, inputs: str):
        assert isinstance(inputs, str)
        inputs = self.tokenizer(
            inputs, return_tensors='pt', return_length=True)
        return inputs['input_ids'].tolist()

    def generate(
        self,
        inputs: Union[str, List[str]],
        do_sample: bool = True,
        **kwargs,
    ):
        """Return the chat completions in non-stream mode.

        Args:
            inputs (Union[str, List[str]]): input texts to be completed.
            do_sample (bool): do sampling if enabled.
        Returns:
            (a list of/batched) text/chat completion
        """
        for status, chunk, _ in self.stream_generate(inputs, do_sample,
                                                     **kwargs):
            response = chunk
        return response
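
    # Usage sketch (the model path is a placeholder, not part of this module):
    #
    #     llm = HFTransformerCasualLM(path='your/hf-checkpoint')
    #     print(llm.generate('Hello, ', do_sample=False))
    #
    # `generate` simply drains `stream_generate` and returns the last chunk,
    # which already contains the full completion.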

    def stream_generate(
        self,
        inputs: Union[str, List[str]],
        do_sample: bool = True,
        **kwargs,
    ):
        """Return the chat completions in stream mode.

        Args:
            inputs (Union[str, List[str]]): input texts to be completed.
            do_sample (bool): do sampling if enabled.
        Returns:
            tuple(Status, str, int): status, text/chat completion,
            generated token number
        """
        import torch
        from torch import nn
        with torch.no_grad():
            batched = True
            if isinstance(inputs, str):
                inputs = [inputs]
                batched = False
            inputs = self.tokenizer(
                inputs, padding=True, return_tensors='pt', return_length=True)
            input_length = inputs['length']
            for k, v in inputs.items():
                inputs[k] = v.cuda()
            input_ids = inputs['input_ids']
            attention_mask = inputs['attention_mask']
            batch_size = input_ids.shape[0]
            input_ids_seq_length = input_ids.shape[-1]
            generation_config = self.model.generation_config
            generation_config = copy.deepcopy(generation_config)
            new_gen_params = self.update_gen_params(**kwargs)
            generation_config.update(**new_gen_params)
            generation_config.update(**kwargs)
            model_kwargs = generation_config.to_dict()
            model_kwargs['attention_mask'] = attention_mask
            _, eos_token_id = (  # noqa: F841  # pylint: disable=W0612
                generation_config.bos_token_id,
                generation_config.eos_token_id,
            )
            if eos_token_id is None:
                if self.gcfg.eos_token_id is not None:
                    eos_token_id = self.gcfg.eos_token_id
                else:
                    eos_token_id = []
            if isinstance(eos_token_id, int):
                eos_token_id = [eos_token_id]
            if self.additional_eos_token_id is not None:
                eos_token_id.extend(self.additional_eos_token_id)
            eos_token_id_tensor = torch.tensor(eos_token_id).to(
                input_ids.device) if eos_token_id is not None else None
            generation_config.max_length = (
                generation_config.max_new_tokens + input_ids_seq_length)
            # Set generation parameters if not already defined
            logits_processor = self.logits_processor
            stopping_criteria = self.stopping_criteria

            logits_processor = self.model._get_logits_processor(
                generation_config=generation_config,
                input_ids_seq_length=input_ids_seq_length,
                encoder_input_ids=input_ids,
                prefix_allowed_tokens_fn=self.prefix_allowed_tokens_fn,
                logits_processor=logits_processor,
            )
            stopping_criteria = self.model._get_stopping_criteria(
                generation_config=generation_config,
                stopping_criteria=stopping_criteria)
            logits_warper = self.model._get_logits_warper(generation_config)

            unfinished_sequences = input_ids.new(batch_size).fill_(1)
            scores = None
            while True:
                model_inputs = self.model.prepare_inputs_for_generation(
                    input_ids, **model_kwargs)
                # forward pass to get the next token
                outputs = self.model(
                    **model_inputs,
                    return_dict=True,
                    output_attentions=False,
                    output_hidden_states=False,
                )
                next_token_logits = outputs.logits[:, -1, :]

                # pre-process the distribution
                next_token_scores = logits_processor(input_ids,
                                                     next_token_logits)
                next_token_scores = logits_warper(input_ids,
                                                  next_token_scores)

                # sample or take the argmax
                probs = nn.functional.softmax(next_token_scores, dim=-1)
                if do_sample:
                    next_tokens = torch.multinomial(
                        probs, num_samples=1).squeeze(1)
                else:
                    next_tokens = torch.argmax(probs, dim=-1)

                # update generated ids, model inputs,
                # and length for the next step
                input_ids = torch.cat([input_ids, next_tokens[:, None]],
                                      dim=-1)
                model_kwargs = self.model._update_model_kwargs_for_generation(  # noqa: E501
                    outputs, model_kwargs, is_encoder_decoder=False)
                unfinished_sequences = unfinished_sequences.mul(
                    next_tokens.tile(eos_token_id_tensor.shape[0], 1).ne(
                        eos_token_id_tensor.unsqueeze(1)).prod(dim=0))
                output_token_ids = input_ids.cpu().tolist()
                for i in range(len(output_token_ids)):
                    output_token_ids[i] = output_token_ids[i][
                        input_length[i]:]
                    # Find the first occurrence of an EOS token
                    # in the sequence
                    first_eos_idx = next(
                        (idx
                         for idx, token_id in enumerate(output_token_ids[i])
                         if token_id in eos_token_id), None)
                    # If an EOS token is found, keep only the part
                    # preceding it
                    if first_eos_idx is not None:
                        output_token_ids[i] = output_token_ids[
                            i][:first_eos_idx]
                response = self.tokenizer.batch_decode(output_token_ids)
                if not batched:
                    response = response[0]
                yield ModelStatusCode.STREAM_ING, response, None
                # stop when every sequence is finished
                # or the maximum length is exceeded
                if (unfinished_sequences.max() == 0
                        or stopping_criteria(input_ids, scores)):
                    break
            yield ModelStatusCode.END, response, None
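
    # Consuming the stream (sketch): each yielded chunk is the cumulative
    # decoded text so far, not a delta.
    #
    #     for status, text, _ in llm.stream_generate('Hello, '):
    #         if status == ModelStatusCode.STREAM_ING:
    #             print(text, end='\r', flush=True)
    #
    # The generator ends with a final (ModelStatusCode.END, full_text, None)
    # item.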

    def stream_chat(
        self,
        inputs: List[dict],
        do_sample: bool = True,
        **kwargs,
    ):
        """Return the chat completions in stream mode.

        Args:
            inputs (List[dict]): input messages to be completed.
            do_sample (bool): do sampling if enabled.
        Returns:
            tuple(Status, str, int): status, text/chat completion,
            generated token number
        """
        prompt = self.template_parser(inputs)
        yield from self.stream_generate(prompt, do_sample, **kwargs)


class HFTransformerCasualLM(HFTransformer):

    def _load_model(self, path: str, model_kwargs: dict):
        import torch
        from transformers import AutoModelForCausalLM
        model_kwargs.setdefault('torch_dtype', torch.float16)
        self.model = AutoModelForCausalLM.from_pretrained(
            path, trust_remote_code=True, **model_kwargs)
        self.model.eval()


class HFTransformerChat(HFTransformerCasualLM):

    def __init__(self, template_parser=APITemplateParser, **kwargs):
        super().__init__(template_parser=template_parser, **kwargs)

    def chat(self,
             inputs: Union[List[dict], List[List[dict]]],
             do_sample: bool = True,
             **kwargs):
        """Return the chat completions in non-stream mode.

        Args:
            inputs (Union[List[dict], List[List[dict]]]): input messages
                to be completed.
            do_sample (bool): do sampling if enabled.
        Returns:
            the text/chat completion
        """
        # handle batch inference with a vanilla for loop
        if isinstance(inputs[0], list):
            resps = []
            for input in inputs:
                resps.append(self.chat(input, do_sample, **kwargs))
            return resps
        prompt = self.template_parser(inputs)
        query = prompt[-1]['content']
        history = prompt[:-1]
        try:
            response, history = self.model.chat(
                self.tokenizer, query, history=history)
        except Exception as e:
            # handle over-length input errors
            logger.warning(str(e))
            response = ''
        return response
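

if __name__ == '__main__':
    # Minimal, hedged smoke test: the checkpoint passed via --path is a
    # placeholder and must point to a model whose remote code exposes
    # `.chat()` (e.g. an InternLM- or ChatGLM-style checkpoint); a CUDA
    # device is assumed for streaming generation.
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument('--path', required=True, help='HF model name or path')
    args = parser.parse_args()

    llm = HFTransformerChat(path=args.path)
    messages = [dict(role='user', content='Hello!')]
    # Non-stream chat; falls back to an empty string on over-length input.
    print(llm.chat(messages, do_sample=False))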