import time
import re

import torch
import torch.nn.functional as F
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    LogitsProcessor,
    GenerationConfig,
    TextIteratorStreamer,
)

def create_masked_attention(input_ids, target_strings, tokenizer):
    """
    Creates an attention mask where tokens corresponding to any of the target
    strings receive 0 attention; all other positions receive 1.

    Only the first sequence in the batch is scanned, so the function is intended
    for batch size 1.
    """
    if len(input_ids.shape) == 1:
        input_ids = input_ids.unsqueeze(0)

    attention_mask = torch.ones_like(input_ids)

    if isinstance(target_strings, str):
        target_strings = [target_strings]

    # Decode each token individually so positions can be compared against the targets.
    input_ids_list = input_ids[0].tolist()
    token_texts = [tokenizer.decode([token_id]) for token_id in input_ids_list]

    masked_indices = []

    for target_string in target_strings:
        if not target_string:
            continue

        target_ids = tokenizer.encode(target_string, add_special_tokens=False)
        target_tokens = [tokenizer.decode([token_id]) for token_id in target_ids]

        # 1) Exact contiguous match of the target's token sequence.
        for i in range(len(token_texts) - len(target_tokens) + 1):
            all_match = True
            for j, target_token in enumerate(target_tokens):
                if i + j >= len(token_texts) or target_token != token_texts[i + j]:
                    all_match = False
                    break

            if all_match:
                for j in range(len(target_tokens)):
                    attention_mask[0, i + j] = 0
                    masked_indices.append(i + j)

        # 2) Single tokens whose stripped text equals one of the target tokens
        #    (catches leading-whitespace variants of the same word).
        for i, token_text in enumerate(token_texts):
            if token_text.strip() in target_tokens:
                attention_mask[0, i] = 0
                masked_indices.append(i)

        # 3) Targets split across two or three adjacent tokens by the tokenizer.
        if len(target_tokens) == 1 and len(target_tokens[0]) > 2:
            for i in range(len(token_texts) - 1):
                pair = token_texts[i].strip() + token_texts[i + 1].strip()
                if target_string in pair:
                    attention_mask[0, i] = 0
                    attention_mask[0, i + 1] = 0
                    masked_indices.extend([i, i + 1])

                if i < len(token_texts) - 2:
                    triplet = token_texts[i].strip() + token_texts[i + 1].strip() + token_texts[i + 2].strip()
                    if target_string in triplet:
                        attention_mask[0, i] = 0
                        attention_mask[0, i + 1] = 0
                        attention_mask[0, i + 2] = 0
                        masked_indices.extend([i, i + 1, i + 2])

    mask_positions = sorted(set(masked_indices))

    if not mask_positions:
        print("WARNING: No tokens were masked!")
        # Fallback: loose substring matching between decoded tokens and the targets.
        for target_string in target_strings:
            for i, token_text in enumerate(token_texts):
                if (target_string in token_text) or (token_text.strip() in target_string and len(token_text.strip()) > 2):
                    attention_mask[0, i] = 0
                    masked_indices.append(i)

    return attention_mask

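# Illustrative usage (a sketch; the "gpt2" checkpoint is only an assumption for
# the demo): the returned mask has the same shape as the input ids and is zero
# only at the positions that encode "Paris".
#
#   from transformers import AutoTokenizer
#   tok = AutoTokenizer.from_pretrained("gpt2")
#   ids = tok("The capital of France is Paris.", return_tensors="pt").input_ids
#   mask = create_masked_attention(ids, ["Paris"], tok)
#   assert mask.shape == ids.shape and (mask == 0).sum() >= 1
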
def preprocess_anchors(anchors):
    """Deduplicate anchors, drop empty entries, and sort longest-first so that
    longer anchors are matched before their substrings."""
    anchors = list(set(anchors))
    anchors = [anchor for anchor in anchors if anchor not in ("", " ")]
    anchors = sorted(anchors, key=len, reverse=True)
    return anchors

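# Example: preprocess_anchors(["foo", "foobar", "", "foo"]) -> ["foobar", "foo"]
# (duplicates and blank entries dropped, longest anchors first so longer matches win).
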
def format_spa_input(input, anchors, mask_token, whole_word_only=True):
    """
    Splits the input into a cleaned version (anchor tags removed) and a masked
    version (anchor text replaced by ``mask_token``).

    ``input`` may be a plain prompt string or a list of chat messages (dicts with
    "role"/"content" and an optional per-message "anchors" list).
    """
    if isinstance(input, str):
        # Collect anchors passed explicitly plus any marked inline with <anchor>...</anchor>.
        current_anchors = list(anchors)
        tag_anchors = []
        if re.search(r"<anchor>", input):
            tag_anchors = re.findall(r"<anchor>(.*?)</anchor>", input, flags=re.DOTALL)
            current_anchors.extend(tag_anchors)

        # Remove the tags themselves; the anchor text stays in the cleaned prompt.
        cleaned_input = re.sub(r"<anchor>|</anchor>", "", input)

        final_anchors = preprocess_anchors(current_anchors)

        # Replace every anchor occurrence with the mask token.
        masked_input = cleaned_input
        if final_anchors:
            if whole_word_only:
                # Match anchors only at word boundaries.
                escaped_anchors = [rf"(?<!\w){re.escape(a)}(?!\w)" for a in final_anchors]
            else:
                escaped_anchors = [re.escape(a) for a in final_anchors]

            pattern = "|".join(escaped_anchors)
            masked_input = re.sub(pattern, mask_token, cleaned_input)

        # Collapse runs of adjacent mask tokens into a single one.
        if mask_token:
            escaped_mask_token = re.escape(mask_token)
            merge_pattern = rf"{escaped_mask_token}\s+{escaped_mask_token}"
            while re.search(merge_pattern, masked_input):
                masked_input = re.sub(merge_pattern, mask_token, masked_input)

        return cleaned_input, masked_input

    elif isinstance(input, list):
        cleaned_input_list = []
        masked_input_list = []

        for msg in input:
            msg_copy = msg.copy()
            content = msg_copy.get("content", "")

            # Per-message anchors: global anchors, a message-level "anchors" list,
            # and inline <anchor>...</anchor> tags.
            current_anchors = list(anchors)
            if "anchors" in msg_copy:
                dict_anchors = msg_copy.get("anchors", [])
                if isinstance(dict_anchors, list):
                    current_anchors.extend(dict_anchors)
            tag_anchors = []
            if re.search(r"<anchor>", content):
                tag_anchors = re.findall(r"<anchor>(.*?)</anchor>", content, flags=re.DOTALL)
                current_anchors.extend(tag_anchors)

            cleaned_content = re.sub(r"<anchor>|</anchor>", "", content)

            final_anchors = preprocess_anchors(current_anchors)

            masked_content = cleaned_content
            if final_anchors:
                if whole_word_only:
                    escaped_anchors = [rf"(?<!\w){re.escape(a)}(?!\w)" for a in final_anchors]
                else:
                    escaped_anchors = [re.escape(a) for a in final_anchors]

                pattern = "|".join(escaped_anchors)
                masked_content = re.sub(pattern, mask_token, cleaned_content)

            # Collapse runs of adjacent mask tokens into a single one.
            if mask_token:
                escaped_mask_token = re.escape(mask_token)
                merge_pattern = rf"{escaped_mask_token}\s+{escaped_mask_token}"
                while re.search(merge_pattern, masked_content):
                    masked_content = re.sub(merge_pattern, mask_token, masked_content)

            # Build clean/masked copies of the message without the helper "anchors" key.
            final_cleaned_msg = msg_copy.copy()
            final_cleaned_msg["content"] = cleaned_content
            final_cleaned_msg.pop("anchors", None)

            final_masked_msg = msg_copy.copy()
            final_masked_msg["content"] = masked_content
            final_masked_msg.pop("anchors", None)

            cleaned_input_list.append(final_cleaned_msg)
            masked_input_list.append(final_masked_msg)

        return cleaned_input_list, masked_input_list
    else:
        raise ValueError("Invalid input type. Must be string or list of dictionaries.")

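# Illustrative usage (a sketch; "[MASK]" stands in for whatever mask token the
# caller supplies): anchors can be passed explicitly or tagged inline.
#
#   clean, masked = format_spa_input(
#       "Write a <anchor>haiku</anchor> about Paris.",
#       anchors=["Paris"],
#       mask_token="[MASK]",
#   )
#   # clean  == "Write a haiku about Paris."
#   # masked == "Write a [MASK] about [MASK]."
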
def get_mask_messages(messages, mask_token):
    """Return a copy of ``messages`` with every anchor substring replaced by ``mask_token``."""
    # Copy each message dict so the caller's messages are not mutated in place.
    mask_msg = [msg.copy() for msg in messages]

    for msg in mask_msg:
        if "anchors" in msg:
            original_content = msg["content"]

            # Replace longer anchors first so substrings of other anchors are not masked prematurely.
            anchors = sorted(msg["anchors"], key=len, reverse=True)

            for anchor in anchors:
                if anchor in msg["content"]:
                    msg["content"] = msg["content"].replace(anchor, mask_token)

            if original_content == msg["content"]:
                print(f"WARNING: No anchors were replaced in message: {original_content[:50]}...")
                print(f"Anchors: {anchors}")

    return mask_msg

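# Example (sketch): a message-level "anchors" list drives the replacement.
#
#   get_mask_messages(
#       [{"role": "user", "content": "Summarize the plot of Dune.", "anchors": ["Dune"]}],
#       mask_token="[MASK]",
#   )
#   # -> the returned message's content becomes "Summarize the plot of [MASK]."
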
def convert_to_tensor_format(inputs, device=None):
    """Normalize tokenizer/chat-template outputs to a 2-D ``input_ids`` tensor."""
    # Already a batched (batch, seq_len) tensor: just move it to the target device.
    if isinstance(inputs, torch.Tensor) and len(inputs.shape) == 2:
        if device is not None:
            inputs = inputs.to(device)
        return inputs

    # BatchEncoding-style objects expose .input_ids.
    if hasattr(inputs, 'input_ids'):
        inputs = inputs.input_ids
    # Plain dicts returned by a tokenizer call.
    elif isinstance(inputs, dict) and 'input_ids' in inputs:
        inputs = inputs['input_ids']
    # A bare list of token ids becomes a single-example batch.
    elif isinstance(inputs, list):
        inputs = torch.tensor([inputs], device=device)
    # A 1-D tensor gets a batch dimension.
    elif isinstance(inputs, torch.Tensor):
        if len(inputs.shape) == 1:
            inputs = inputs.unsqueeze(0)

    if isinstance(inputs, torch.Tensor) and device is not None:
        inputs = inputs.to(device)

    return inputs

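# Example: convert_to_tensor_format([1, 2, 3]) -> tensor([[1, 2, 3]]); a
# tokenizer's BatchEncoding (or a dict containing "input_ids") is reduced to
# its input_ids tensor.
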
def create_default_attention_mask(input_ids, device=None):
    """
    Creates a default attention mask (all 1s) for the given input_ids tensor.

    Args:
        input_ids (torch.Tensor): The input IDs tensor, shape (batch_size, seq_len)
        device: The device to place the attention mask on

    Returns:
        torch.Tensor: Attention mask with the same shape as input_ids, all values set to 1
    """
    if device is not None and input_ids.device != device:
        input_ids = input_ids.to(device)

    attention_mask = torch.ones_like(input_ids)

    return attention_mask

def spa_tokenize(prompt_with_anchors, global_anchors, tokenizer, device):
    """
    Tokenizes both the full (main) prompt and the anchor-masked (auxiliary) prompt,
    returning ``(main_input_ids, aux_input_ids, mask_token)``.
    """
    if tokenizer.pad_token is None:
        print("Setting pad token to EOS token")
        tokenizer.pad_token = tokenizer.eos_token

    # Use the tokenizer's own mask token if it has one, otherwise a placeholder string.
    if tokenizer.mask_token:
        mask_token = tokenizer.mask_token
    else:
        mask_token = "MASKTOKEN"

    main_prompt, aux_prompt = format_spa_input(
        input=prompt_with_anchors,
        anchors=global_anchors,
        mask_token=mask_token,
        whole_word_only=False
    )

    if isinstance(main_prompt, list):
        # Chat-style input: prefer the tokenizer's chat template when available.
        if hasattr(tokenizer, "chat_template") and tokenizer.chat_template:
            main_inputs = tokenizer.apply_chat_template(
                main_prompt,
                tokenize=True,
                add_generation_prompt=True,
                return_tensors="pt"
            ).to(device)

            aux_inputs = tokenizer.apply_chat_template(
                aux_prompt,
                tokenize=True,
                add_generation_prompt=True,
                return_tensors="pt"
            ).to(device)
        else:
            # No chat template: flatten messages into a simple "role: content" transcript.
            flat_prompt_main = ""
            for msg in main_prompt:
                flat_prompt_main += f"{msg['role']}: {msg['content']}\n"
            flat_prompt_main += "Assistant: "

            flat_prompt_aux = ""
            for msg in aux_prompt:
                flat_prompt_aux += f"{msg['role']}: {msg['content']}\n"
            flat_prompt_aux += "Assistant: "

            main_inputs = tokenizer(flat_prompt_main, return_tensors="pt").to(device)
            aux_inputs = tokenizer(flat_prompt_aux, return_tensors="pt").to(device)

    elif isinstance(prompt_with_anchors, str):
        if hasattr(tokenizer, "chat_template") and tokenizer.chat_template:
            # Wrap the plain strings as single-turn user messages for the chat template.
            main_prompt = [{"role": "user", "content": main_prompt}]
            aux_prompt = [{"role": "user", "content": aux_prompt}]

            main_inputs = tokenizer.apply_chat_template(
                main_prompt,
                tokenize=True,
                add_generation_prompt=True,
                return_tensors="pt"
            ).to(device)

            aux_inputs = tokenizer.apply_chat_template(
                aux_prompt,
                tokenize=True,
                add_generation_prompt=True,
                return_tensors="pt"
            ).to(device)
        else:
            main_inputs = tokenizer(main_prompt, return_tensors="pt").to(device)
            aux_inputs = tokenizer(aux_prompt, return_tensors="pt").to(device)
    else:
        raise ValueError("Invalid prompt format")

    # Normalize both outputs (BatchEncoding, dict, or tensor) to 2-D input_ids tensors.
    main_inputs = convert_to_tensor_format(main_inputs, device)
    aux_inputs = convert_to_tensor_format(aux_inputs, device)

    return main_inputs, aux_inputs, mask_token

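# Illustrative usage (a sketch; the device string is an assumption): the anchor
# is tagged inline, so the auxiliary prompt carries the mask token in its place.
#
#   main_ids, aux_ids, mask_token = spa_tokenize(
#       "Explain <anchor>quantum entanglement</anchor> to a child.",
#       global_anchors=[],
#       tokenizer=tokenizer,
#       device="cuda",
#   )
#   # main_ids / aux_ids are 2-D input_ids tensors for the original and masked prompts.
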
class SPALogitsProcessor(LogitsProcessor):
    """Processor that combines logits from a main and an auxiliary (anchor-masked) model."""

    def __init__(self, aux_model, aux_input_ids, mask_token, strength=1.5, modulated_by_prob=True, tokenizer=None, use_attention_mask=True):
        self.aux_model = aux_model
        self.aux_input_ids = aux_input_ids
        self.aux_past_key_values = None
        self.strength = strength
        self.modulated_by_prob = modulated_by_prob
        self.tokenizer = tokenizer
        self.mask_token = mask_token

        self.device = aux_input_ids.device
        self.use_attention_mask = use_attention_mask
        if self.use_attention_mask:
            # Zero out attention on the mask-token positions of the auxiliary prompt.
            self.attention_mask = create_masked_attention(self.aux_input_ids, [mask_token], self.tokenizer)
        else:
            self.attention_mask = None

    def __call__(self, input_ids, scores):
        if self.aux_past_key_values is None:
            # First step: run the full auxiliary prompt and cache its key/value states.
            aux_outputs = self.aux_model(
                input_ids=self.aux_input_ids,
                use_cache=True,
                return_dict=True,
                attention_mask=self.attention_mask
            )
            self.aux_past_key_values = aux_outputs.past_key_values
            aux_logits = aux_outputs.logits[:, -1, :]
        else:
            # Later steps: feed only the newly generated token and reuse the cache.
            last_token = input_ids[:, -1].unsqueeze(-1).to(self.device)

            aux_outputs = self.aux_model(
                input_ids=last_token,
                past_key_values=self.aux_past_key_values,
                use_cache=True,
                return_dict=True
            )
            self.aux_past_key_values = aux_outputs.past_key_values
            aux_logits = aux_outputs.logits[:, -1, :]

        # strength == 1: pass the main model's scores through unchanged
        # (the auxiliary forward pass above still runs to keep its cache in sync).
        if abs(self.strength - 1.0) < 1e-4:
            return scores

        # strength == 0: use the auxiliary (anchor-masked) logits only.
        if abs(self.strength - 0.0) < 1e-4:
            return aux_logits

        if scores.device != aux_logits.device:
            aux_logits = aux_logits.to(scores.device)

        if torch.isnan(scores).any() or torch.isnan(aux_logits).any():
            print("Warning: NaN values detected in input scores or aux_logits")
            scores = torch.nan_to_num(scores, nan=0.0)
            aux_logits = torch.nan_to_num(aux_logits, nan=0.0)

        # Difference between the anchored (main) and anchor-masked (auxiliary) logits.
        diff = scores - aux_logits

        # final = scores + (strength - 1) * diff, optionally modulated per token below.
        base_weight = self.strength - 1.0

        if self.modulated_by_prob and (self.strength > 1 or self.strength < -1):
            # Weight the adjustment by the main model's own token probabilities,
            # normalized so the most likely token receives the full base weight.
            temperature = 1.0
            scaled_logits = scores / temperature
            main_probs = F.softmax(scaled_logits, dim=-1)
            main_probs = torch.clamp(main_probs, min=1e-6, max=1.0)

            max_prob = torch.max(main_probs)
            base_weight = base_weight / max_prob

            token_weights = base_weight * main_probs

            adjustment = token_weights * diff
            adjustment = torch.clamp(adjustment, min=-1e2, max=1e2)

            final_scores = scores + adjustment
        else:
            weighted_diff = base_weight * diff
            weighted_diff = torch.nan_to_num(weighted_diff, nan=0.0)
            weighted_diff = torch.clamp(weighted_diff, min=-1e3, max=1e3)
            final_scores = scores + weighted_diff

        final_scores = torch.clamp(final_scores, min=-1e3, max=1e3)

        return final_scores

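# Minimal end-to-end sketch. The checkpoint name, device choice, anchoring
# strength, and sampling settings below are illustrative assumptions, not values
# fixed by this module; the model serves as its own auxiliary model.
if __name__ == "__main__":
    from transformers import LogitsProcessorList

    model_name = "Qwen/Qwen2.5-0.5B-Instruct"  # illustrative checkpoint
    device = "cuda" if torch.cuda.is_available() else "cpu"

    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(model_name).to(device)

    prompt = "Write one sentence about <anchor>the Eiffel Tower</anchor>."
    main_ids, aux_ids, mask_token = spa_tokenize(prompt, [], tokenizer, device)

    # The auxiliary model sees the prompt with the anchor replaced by the mask
    # token, and attention on the mask positions zeroed out.
    spa_processor = SPALogitsProcessor(
        aux_model=model,
        aux_input_ids=aux_ids,
        mask_token=mask_token,
        strength=1.5,
        tokenizer=tokenizer,
    )

    output_ids = model.generate(
        main_ids,
        attention_mask=create_default_attention_mask(main_ids, device),
        logits_processor=LogitsProcessorList([spa_processor]),
        max_new_tokens=64,
        do_sample=False,
    )
    print(tokenizer.decode(output_ids[0, main_ids.shape[1]:], skip_special_tokens=True))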