Spaces:
				
			
			
	
			
			
		Runtime error
		
	
	
	
			
			
	
	
	
	
		
		
		Runtime error
		
	| from typing import List, Tuple | |
| from openai.types.chat import ChatCompletionMessageParam | |
| from transformers.generation.logits_process import ( | |
| LogitsProcessorList, | |
| RepetitionPenaltyLogitsProcessor, | |
| TemperatureLogitsWarper, | |
| TopKLogitsWarper, | |
| TopPLogitsWarper, | |
| ) | |
| from api.utils.protocol import Role | |
def parse_messages(
    messages: List[ChatCompletionMessageParam], split_role=Role.USER
) -> Tuple[str, List[List[ChatCompletionMessageParam]]]:
    """
    Parse a list of chat completion messages into a system prompt and rounds.

    Args:
        messages (List[ChatCompletionMessageParam]): The list of chat completion messages.
        split_role: The role that opens a new round. Defaults to Role.USER.

    Returns:
        Tuple[str, List[List[ChatCompletionMessageParam]]]: The system message
        content ("" when no system message is present) and the list of rounds,
        each round being a list of consecutive non-system messages.
    """
    system, rounds = "", []
    current_round = []
    for message in messages:
        if message["role"] == Role.SYSTEM:
            # System messages never join a round; the last one seen wins.
            system = message["content"]
            continue
        # A message with the split role closes the round in progress, unless
        # that round is still empty (e.g. the very first message).
        if message["role"] == split_role and current_round:
            rounds.append(current_round)
            current_round = []
        current_round.append(message)
    # Flush the trailing round, which has no following split message.
    if current_round:
        rounds.append(current_round)
    return system, rounds
def prepare_logits_processor(
    temperature: float, repetition_penalty: float, top_p: float, top_k: int
) -> LogitsProcessorList:
    """
    Build the logits processors selected by the given sampling parameters.

    Args:
        temperature (float): Temperature for temperature warping.
        repetition_penalty (float): Repetition penalty; applied only above 1.0.
        top_p (float): Top-p threshold; applied only in the range [1e-8, 1.0).
        top_k (int): Top-k cutoff; applied only when positive.

    Returns:
        LogitsProcessorList: The enabled processors, in the order temperature,
        repetition penalty, top-p, top-k.
    """
    # Each row is (enabled, processor class, constructor argument). Classes are
    # instantiated only for enabled rows, so rejected values are never passed.
    stages = (
        # TemperatureLogitsWarper doesn't accept 0.0, and 1.0 makes it a no-op,
        # so both cases are skipped.
        (temperature >= 1e-5 and temperature != 1.0, TemperatureLogitsWarper, temperature),
        (repetition_penalty > 1.0, RepetitionPenaltyLogitsProcessor, repetition_penalty),
        (1e-8 <= top_p < 1.0, TopPLogitsWarper, top_p),
        (top_k > 0, TopKLogitsWarper, top_k),
    )

    processors = LogitsProcessorList()
    for enabled, processor_cls, value in stages:
        if enabled:
            processors.append(processor_cls(value))
    return processors
def is_partial_stop(output: str, stop_str: str) -> bool:
    """
    Check whether the output ends with a non-empty prefix of the stop string.

    Args:
        output (str): The text generated so far.
        stop_str (str): The stop string that may be in the middle of being emitted.

    Returns:
        bool: True when some suffix of ``output`` (length 1 up to
        ``min(len(output), len(stop_str))``) is a prefix of ``stop_str``.
        False when either string is empty.
    """
    longest = min(len(output), len(stop_str))
    # Suffix lengths start at 1: ``output[-0:]`` is the whole string, not an
    # empty suffix, so a zero length would compare the wrong text.
    return any(
        stop_str.startswith(output[-size:]) for size in range(1, longest + 1)
    )
# Models don't use the same configuration key for determining the maximum
# sequence length. Store them here so we can sanely check them.
# NOTE: The ordering here is important. Some models have two of these, and we
# have a preference for which value gets used.
SEQUENCE_LENGTH_KEYS = [
    "max_sequence_length",
    "seq_length",
    "max_position_embeddings",
    "max_seq_len",
    "model_max_length",
]


def get_context_length(config) -> int:
    """
    Get the context length of a model from a huggingface model config.

    Args:
        config: A config object whose attributes are probed with ``getattr``.

    Returns:
        int: The value of the first attribute in ``SEQUENCE_LENGTH_KEYS`` that
        is set, multiplied by ``config.rope_scaling["factor"]`` when that
        factor is present. Falls back to 2048 when none of the keys is set.
    """
    rope_scaling = getattr(config, "rope_scaling", None)
    # Not every rope_scaling dict carries a "factor" entry; treat a missing
    # (or None) factor as "no scaling" instead of raising KeyError.
    factor = rope_scaling.get("factor") if rope_scaling else None
    if factor is None:
        factor = 1

    for key in SEQUENCE_LENGTH_KEYS:
        val = getattr(config, key, None)
        if val is not None:
            return int(factor * val)
    return 2048
def apply_stopping_strings(reply: str, stop_strings: List[str]) -> Tuple[str, bool]:
    """
    Apply stopping strings to the reply and check if a stop string is found.

    Args:
        reply (str): The reply to apply stopping strings to.
        stop_strings (List[str]): The list of stopping strings to check for.
            Empty strings are ignored.

    Returns:
        Tuple[str, bool]: The modified reply and a boolean indicating if a
        complete stop string was found. When one or more stop strings occur,
        the reply is cut at the earliest occurrence. Otherwise, a trailing
        partial stop string is trimmed and the boolean is False.
    """
    # An empty stop string would match at index 0 and erase every reply.
    stops = [stop for stop in stop_strings if stop]

    # Cut at the earliest occurrence across all stop strings, so that no stop
    # string survives in the returned text regardless of list order.
    hits = [pos for pos in (reply.find(stop) for stop in stops) if pos != -1]
    if hits:
        return reply[: min(hits)], True

    # If something like "\nYo" is generated just before "\nYou:" is completed,
    # trim it. Longer prefixes are tried first; the first stop string with a
    # partial match wins.
    for stop in stops:
        for size in range(len(stop) - 1, 0, -1):
            if reply.endswith(stop[:size]):
                return reply[:-size], False
    return reply, False