import time
import re

import torch
import torch.nn.functional as F
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    LogitsProcessor,
    GenerationConfig,
    TextIteratorStreamer,
)


# --- Helper Function for Input Preparation ---
def create_masked_attention(input_ids, target_strings, tokenizer):
    """
    Creates an attention mask where tokens corresponding to any of the
    target strings have 0 attention.
    """
    # Ensure input_ids is 2D
    if len(input_ids.shape) == 1:
        input_ids = input_ids.unsqueeze(0)

    # Create the default attention mask (all 1s)
    attention_mask = torch.ones_like(input_ids)

    # Convert a single string to a list for uniform processing
    if isinstance(target_strings, str):
        target_strings = [target_strings]

    # Get the input IDs as a list
    input_ids_list = input_ids[0].tolist()

    # Decode each token individually for comparison
    token_texts = [tokenizer.decode([token_id]) for token_id in input_ids_list]

    masked_indices = []

    # Tokenize each target string to find its exact token representation
    for target_string in target_strings:
        if not target_string:
            continue

        # Tokenize the target string to get its expected token IDs
        target_ids = tokenizer.encode(target_string, add_special_tokens=False)
        target_tokens = [tokenizer.decode([id]) for id in target_ids]

        # First approach: direct token-sequence matching.
        # Look for the sequence of target tokens in the input.
        for i in range(len(token_texts) - len(target_tokens) + 1):
            # Check if this position starts a matching sequence
            all_match = True
            for j, target_token in enumerate(target_tokens):
                if i + j >= len(token_texts) or target_token != token_texts[i + j]:
                    all_match = False
                    break
            if all_match:
                for j in range(len(target_tokens)):
                    attention_mask[0, i + j] = 0
                    masked_indices.append(i + j)

        # Second approach: look for individual tokens that make up the target
        for i, token_text in enumerate(token_texts):
            if token_text.strip() in target_tokens:
                attention_mask[0, i] = 0
                masked_indices.append(i)

        # Third approach: if the target is split between tokens, try to detect it.
        # For example, 'MASKTOKEN' might be split as ' MASK' and 'TOKEN'.
        if len(target_tokens) == 1 and len(target_tokens[0]) > 2:  # Only for substantial single tokens
            # Look for token pairs that might contain the target
            for i in range(len(token_texts) - 1):
                pair = token_texts[i].strip() + token_texts[i + 1].strip()
                if target_string in pair:
                    attention_mask[0, i] = 0
                    attention_mask[0, i + 1] = 0
                    masked_indices.extend([i, i + 1])
                # Check a triplet as well, if possible
                if i < len(token_texts) - 2:
                    triplet = (
                        token_texts[i].strip()
                        + token_texts[i + 1].strip()
                        + token_texts[i + 2].strip()
                    )
                    if target_string in triplet:
                        attention_mask[0, i] = 0
                        attention_mask[0, i + 1] = 0
                        attention_mask[0, i + 2] = 0
                        masked_indices.extend([i, i + 1, i + 2])

    # Collect the final set of masked positions (deduplicated and sorted)
    mask_positions = sorted(set(masked_indices))

    if not mask_positions:
        print("WARNING: No tokens were masked!")
        # Last resort: mask any token containing part of the target
        for target_string in target_strings:
            for i, token_text in enumerate(token_texts):
                if (target_string in token_text) or (
                    token_text.strip() in target_string and len(token_text.strip()) > 2
                ):
                    attention_mask[0, i] = 0
                    masked_indices.append(i)

        # Check again
        mask_positions = sorted(set(masked_indices))

    return attention_mask
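

# Illustrative sketch (not part of the original module): how create_masked_attention
# is meant to be used. The checkpoint name and the anchor phrase are placeholder
# assumptions; any causal-LM tokenizer should behave the same way.
def _demo_create_masked_attention():
    tokenizer = AutoTokenizer.from_pretrained("gpt2")  # placeholder checkpoint
    prompt = "Write a function that sorts a list in descending order."
    input_ids = tokenizer(prompt, return_tensors="pt").input_ids

    # Zero out attention for the tokens spelling the anchor phrase
    masked_attention = create_masked_attention(input_ids, ["descending order"], tokenizer)

    # Positions whose decoded tokens matched the anchor are 0, everything else is 1
    print(masked_attention)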


def preprocess_anchors(anchors):
    # Remove duplicate anchors
    anchors = list(set(anchors))
    # Remove empty and whitespace-only anchors
    anchors = [anchor for anchor in anchors if anchor != "" and anchor != " "]
    # Sort the anchors by length, longest first
    anchors = sorted(anchors, key=len, reverse=True)
    return anchors


# Define a wrapper function to handle the different input cases.
# The provided anchors are viewed as global anchors.
def format_spa_input(input, anchors, mask_token, whole_word_only=True):
    # NOTE: the inline tag format "<anchor>...</anchor>" used below is assumed;
    # the literal tag strings did not survive in this copy of the code.
    # Check whether the input is a string or a list of messages
    if isinstance(input, str):
        # 1. Collect all anchors
        current_anchors = list(anchors)  # Start with global anchors
        tag_anchors = []
        if re.search(r"<anchor>", input):
            tag_anchors = re.findall(r"<anchor>(.*?)</anchor>", input, flags=re.DOTALL)
            current_anchors.extend(tag_anchors)

        # 2. Clean the input string (remove tags)
        cleaned_input = re.sub(r"<anchor>|</anchor>", "", input)

        # 3. Preprocess all collected anchors (unique, non-empty, sorted by length descending)
        final_anchors = preprocess_anchors(current_anchors)

        # 4. Escape anchors for regex and build the pattern (longest first)
        masked_input = cleaned_input  # Initialize with the cleaned input
        if final_anchors:
            if whole_word_only:
                # Use lookarounds to assert word boundaries without consuming them
                escaped_anchors = [
                    rf"(?<!\w){re.escape(anchor)}(?!\w)" for anchor in final_anchors
                ]
            else:
                escaped_anchors = [re.escape(anchor) for anchor in final_anchors]
            pattern = "|".join(escaped_anchors)
            masked_input = re.sub(pattern, mask_token, cleaned_input)

        # Optional: collapse adjacent mask tokens, e.g. mask_token + mask_token -> mask_token
        # escaped_mask_token = re.escape(mask_token)
        # merge_pattern_no_space = f"{escaped_mask_token}{escaped_mask_token}"
        # while re.search(merge_pattern_no_space, masked_input):
        #     masked_input = re.sub(merge_pattern_no_space, mask_token, masked_input)

        return cleaned_input, masked_input

    elif isinstance(input, list):
        cleaned_input_list = []
        masked_input_list = []
        for msg in input:
            msg_copy = msg.copy()  # Work on a copy
            content = msg_copy.get("content", "")

            # 1. Collect all anchors for this message
            current_anchors = list(anchors)  # Start with global anchors
            if "anchors" in msg_copy:
                dict_anchors = msg_copy.get("anchors", [])
                if isinstance(dict_anchors, list):
                    current_anchors.extend(dict_anchors)
            tag_anchors = []
            if re.search(r"<anchor>", content):
                tag_anchors = re.findall(r"<anchor>(.*?)</anchor>", content, flags=re.DOTALL)
                current_anchors.extend(tag_anchors)

            # 2. Clean the message content (remove tags)
            cleaned_content = re.sub(r"<anchor>|</anchor>", "", content)

            # 3. Preprocess all collected anchors for this message
            final_anchors = preprocess_anchors(current_anchors)

            # 4. Escape anchors, build the pattern, and replace in one pass
            masked_content = cleaned_content  # Initialize
            if final_anchors:
                if whole_word_only:
                    # Use lookarounds to assert word boundaries without consuming them
                    escaped_anchors = [
                        rf"(?<!\w){re.escape(anchor)}(?!\w)" for anchor in final_anchors
                    ]
                else:
                    escaped_anchors = [re.escape(anchor) for anchor in final_anchors]
                pattern = "|".join(escaped_anchors)
                masked_content = re.sub(pattern, mask_token, cleaned_content)

            # NOTE: the message reassembly below is reconstructed from context:
            # both output lists mirror the input messages, with any per-message
            # "anchors" key dropped and the content cleaned / masked respectively.
            cleaned_msg = msg_copy.copy()
            cleaned_msg.pop("anchors", None)
            cleaned_msg["content"] = cleaned_content
            cleaned_input_list.append(cleaned_msg)

            masked_msg = msg_copy.copy()
            masked_msg.pop("anchors", None)
            masked_msg["content"] = masked_content
            masked_input_list.append(masked_msg)

        return cleaned_input_list, masked_input_list
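

# Illustrative sketch (assumed usage, not from the original file): the same
# prompt formatted for SPA, once as a plain string and once as chat messages.
# The "<anchor>" tag format and the "MASKTOKEN" mask string are assumptions.
def _demo_format_spa_input():
    prompt = "Sort the list <anchor>in descending order</anchor> and return it."

    # String input: returns the cleaned prompt and the anchor-masked prompt
    cleaned, masked = format_spa_input(prompt, anchors=[], mask_token="MASKTOKEN")
    # cleaned -> "Sort the list in descending order and return it."
    # masked  -> "Sort the list MASKTOKEN and return it."

    # Chat-style input: per-message "anchors" keys are collected as well
    messages = [
        {"role": "system", "content": "You are a coding assistant."},
        {"role": "user", "content": prompt, "anchors": ["return it"]},
    ]
    cleaned_msgs, masked_msgs = format_spa_input(messages, anchors=[], mask_token="MASKTOKEN")
    print(cleaned, masked, cleaned_msgs, masked_msgs)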


# --- SPA Logits Processor ---
# NOTE: the class name, the constructor, and the computation of `diff` and
# `base_weight` below are reconstructed (assumed) from the surviving body of
# __call__: the processor compares the logits over the full prompt with the
# logits over the anchor-masked prompt and amplifies their difference by the
# anchoring strength.
class SPALogitsProcessor(LogitsProcessor):
    """
    Steers generation toward the anchored text by adding a weighted version of
    (full-prompt logits - anchor-masked logits) to the scores at each step.
    """

    def __init__(self, get_masked_scores, strength, modulated_by_prob=True):
        # Assumed interface: `get_masked_scores(input_ids)` returns the
        # next-token logits produced with the anchor-masked attention mask.
        self.get_masked_scores = get_masked_scores
        self.strength = strength
        self.modulated_by_prob = modulated_by_prob

    def __call__(self, input_ids, scores):
        # Next-token logits for the anchor-masked prompt (assumed source)
        masked_scores = self.get_masked_scores(input_ids)

        # The anchored information lives in the gap between the two distributions
        diff = scores - masked_scores

        # Assumed weighting: strength == 1 leaves the original scores unchanged
        base_weight = self.strength - 1.0

        # Only modulate by probability when strength > 1 or strength < -1
        # (that's what can cause random behavior. If -1 < strength < 1, it is
        # semantic diminishment; disable this for more precise control.)
        if self.modulated_by_prob and (self.strength > 1 or self.strength < -1):
            # Convert logits to probabilities with temperature scaling for stability
            temperature = 1.0
            scaled_logits = scores / temperature
            main_probs = F.softmax(scaled_logits, dim=-1)

            # Clamp probabilities to avoid numerical issues
            main_probs = torch.clamp(main_probs, min=1e-6, max=1.0)

            # Each token's weight is scaled by its probability:
            # normalize the base weight by the maximum probability,
            # then derive per-token weights from the main distribution
            max_prob = torch.max(main_probs)
            base_weight = base_weight / max_prob
            token_weights = base_weight * main_probs

            # Apply the weighted adjustment
            adjustment = token_weights * diff

            # Clamp the adjustment to avoid extreme values
            adjustment = torch.clamp(adjustment, min=-1e2, max=1e2)

            # Compute the final scores
            final_scores = scores + adjustment
        else:
            # Safe computation of the weighted difference
            weighted_diff = base_weight * diff

            # Check for and handle any NaNs that might have appeared
            weighted_diff = torch.nan_to_num(weighted_diff, nan=0.0)

            # Clamp to avoid extreme values
            weighted_diff = torch.clamp(weighted_diff, min=-1e3, max=1e3)

            final_scores = scores + weighted_diff

        # Final stability check
        final_scores = torch.clamp(final_scores, min=-1e3, max=1e3)
        return final_scores
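

# End-to-end sketch (assumptions: the checkpoint name, the mask/anchor strings,
# and the closure-based `get_masked_scores` wiring are illustrative, not the
# original driver code). Shows how the pieces above can be combined with
# `model.generate` and a LogitsProcessorList.
def _demo_spa_generation(strength=3.0):
    from transformers import LogitsProcessorList

    model_name = "Qwen/Qwen2.5-Coder-1.5B-Instruct"  # placeholder checkpoint
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype="auto", device_map="auto")

    prompt = "Write a Python function that sorts a list <anchor>in descending order</anchor>."
    cleaned, _ = format_spa_input(prompt, anchors=[], mask_token="MASKTOKEN")

    inputs = tokenizer(cleaned, return_tensors="pt").to(model.device)
    prompt_len = inputs.input_ids.shape[1]

    # Attention mask with the anchor tokens zeroed out
    masked_attention = create_masked_attention(inputs.input_ids, ["in descending order"], tokenizer)

    def get_masked_scores(current_ids):
        # Re-run the model on the same tokens, but with the anchor positions
        # hidden; newly generated tokens keep full attention.
        new_len = current_ids.shape[1] - prompt_len
        step_mask = torch.cat(
            [
                masked_attention,
                torch.ones((1, new_len), dtype=masked_attention.dtype, device=masked_attention.device),
            ],
            dim=1,
        )
        with torch.no_grad():
            out = model(input_ids=current_ids, attention_mask=step_mask)
        return out.logits[:, -1, :]

    processor = SPALogitsProcessor(get_masked_scores, strength=strength)
    output_ids = model.generate(
        **inputs,
        max_new_tokens=128,
        do_sample=False,
        logits_processor=LogitsProcessorList([processor]),
    )
    print(tokenizer.decode(output_ids[0][prompt_len:], skip_special_tokens=True))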