DoctorChaos committed
Commit d82fdf0 · verified · 1 Parent(s): aca1be6

Upload spa.py with huggingface_hub

Files changed (1)
  1. spa.py +530 -0
spa.py ADDED
@@ -0,0 +1,530 @@
+ import time
+ import re
+ import torch
+ import torch.nn.functional as F
+ from transformers import (
+     AutoModelForCausalLM,
+     AutoTokenizer,
+     BitsAndBytesConfig,
+     LogitsProcessor,
+     GenerationConfig,
+     TextIteratorStreamer,
+ )
+
+ # --- Helper Function for Input Preparation ---
+
+ def create_masked_attention(input_ids, target_strings, tokenizer):
+     """
+     Creates an attention mask where tokens corresponding to any of the target strings have 0 attention.
+     """
+     # Ensure input_ids is 2D
+     if len(input_ids.shape) == 1:
+         input_ids = input_ids.unsqueeze(0)
+
+     # Create default attention mask (all 1s)
+     attention_mask = torch.ones_like(input_ids)
+
+     # Convert single string to list for uniform processing
+     if isinstance(target_strings, str):
+         target_strings = [target_strings]
+
+     # Get the input IDs as a list
+     input_ids_list = input_ids[0].tolist()
+
+     # Decode each token individually for comparison
+     token_texts = []
+     for token_id in input_ids_list:
+         token_texts.append(tokenizer.decode([token_id]))
+
+     masked_indices = []
+
+     # Try tokenizing each target string to find its exact token representation
+     for target_string in target_strings:
+         if not target_string:
+             continue
+
+         # Tokenize the target string to get its expected token IDs
+         target_ids = tokenizer.encode(target_string, add_special_tokens=False)
+         target_tokens = [tokenizer.decode([tok_id]) for tok_id in target_ids]
+
+         # First approach: direct token sequence matching.
+         # Look for the sequence of target tokens in the input.
+         for i in range(len(token_texts) - len(target_tokens) + 1):
+             # Check if this position starts a matching sequence
+             all_match = True
+             for j, target_token in enumerate(target_tokens):
+                 if i + j >= len(token_texts) or target_token != token_texts[i + j]:
+                     all_match = False
+                     break
+
+             if all_match:
+                 for j in range(len(target_tokens)):
+                     attention_mask[0, i + j] = 0
+                     masked_indices.append(i + j)
+
+         # Second approach: look for individual tokens that make up the target
+         for i, token_text in enumerate(token_texts):
+             if token_text.strip() in target_tokens:
+                 attention_mask[0, i] = 0
+                 masked_indices.append(i)
+
+         # Third approach: if the target is split across tokens, try to detect it.
+         # For example 'MASKTOKEN' might be split as ' MASK' and 'TOKEN'.
+         if len(target_tokens) == 1 and len(target_tokens[0]) > 2:  # Only for substantial single tokens
+             # Look for token pairs that might contain the target
+             for i in range(len(token_texts) - 1):
+                 pair = token_texts[i].strip() + token_texts[i + 1].strip()
+                 if target_string in pair:
+                     attention_mask[0, i] = 0
+                     attention_mask[0, i + 1] = 0
+                     masked_indices.extend([i, i + 1])
+
+                 # Check for triplet if possible
+                 if i < len(token_texts) - 2:
+                     triplet = token_texts[i].strip() + token_texts[i + 1].strip() + token_texts[i + 2].strip()
+                     if target_string in triplet:
+                         attention_mask[0, i] = 0
+                         attention_mask[0, i + 1] = 0
+                         attention_mask[0, i + 2] = 0
+                         masked_indices.extend([i, i + 1, i + 2])
+
+     # Collect the final mask positions
+     mask_positions = list(set(masked_indices))  # Remove duplicates
+     mask_positions.sort()
+
+     if mask_positions:
+         masked_text = [token_texts[idx] for idx in mask_positions]
+     else:
+         print("WARNING: No tokens were masked!")
+         # Last resort - just mask any token containing part of the target
+         for target_string in target_strings:
+             for i, token_text in enumerate(token_texts):
+                 if (target_string in token_text) or (token_text.strip() in target_string and len(token_text.strip()) > 2):
+                     attention_mask[0, i] = 0
+                     masked_indices.append(i)
+
+         # Check again
+         mask_positions = list(set(masked_indices))
+         mask_positions.sort()
+
+     return attention_mask
+
+
+ def preprocess_anchors(anchors):
+     # remove duplicates in anchors
+     anchors = list(set(anchors))
+     # remove "", " " in anchors
+     anchors = [anchor for anchor in anchors if anchor != "" and anchor != " "]
+     # sort the anchors by length
+     anchors = sorted(anchors, key=len, reverse=True)
+     return anchors
+
+
+ # Define a wrapper function to handle different cases
+ # The provided anchors are viewed as global anchors
+ def format_spa_input(input, anchors, mask_token, whole_word_only=True):
+     # Check if the input is a string or a list of messages
+     if isinstance(input, str):
+         # 1. Collect all anchors
+         current_anchors = list(anchors)  # Start with global anchors
+         tag_anchors = []
+         if re.search(r"<anchor>", input):
+             tag_anchors = re.findall(r"<anchor>(.*?)</anchor>", input, flags=re.DOTALL)
+             current_anchors.extend(tag_anchors)
+
+         # 2. Clean the input string (remove tags)
+         cleaned_input = re.sub(r"<anchor>|</anchor>", "", input)
+
+         # 3. Preprocess all collected anchors (unique, non-empty, sorted desc)
+         final_anchors = preprocess_anchors(current_anchors)
+
+         # 4. Escape anchors for regex and build pattern (longest first)
+         masked_input = cleaned_input  # Initialize with cleaned input
+         if final_anchors:
+             if whole_word_only:
+                 # Use lookarounds to assert boundaries without consuming them (Fix 1)
+                 escaped_anchors = [rf"(?<!\w){re.escape(a)}(?!\w)" for a in final_anchors]
+             else:
+                 escaped_anchors = [re.escape(a) for a in final_anchors]
+
+             pattern = "|".join(escaped_anchors)
+             # 5. Perform anchor replacement in one pass
+             masked_input = re.sub(pattern, mask_token, cleaned_input)
+
+         # 6. Post-processing: Merge consecutive mask tokens (separated by space)
+         if mask_token:  # Avoid processing if mask_token is empty
+             escaped_mask_token = re.escape(mask_token)
+             # Improved merging logic (Fix 2); raw string so \s is a regex escape
+             merge_pattern = rf"{escaped_mask_token}\s+{escaped_mask_token}"
+             while re.search(merge_pattern, masked_input):
+                 masked_input = re.sub(merge_pattern, mask_token, masked_input)
+             # Optional: merge masks without space if needed, e.g., mask_token+mask_token -> mask_token
+             # merge_pattern_no_space = rf"{escaped_mask_token}{escaped_mask_token}"
+             # while re.search(merge_pattern_no_space, masked_input):
+             #     masked_input = re.sub(merge_pattern_no_space, mask_token, masked_input)
+
+         return cleaned_input, masked_input
+
+     elif isinstance(input, list):
+         cleaned_input_list = []
+         masked_input_list = []
+
+         for msg in input:
+             msg_copy = msg.copy()  # Work on a copy
+             content = msg_copy.get("content", "")
+
+             # 1. Collect all anchors for this message
+             current_anchors = list(anchors)  # Start with global anchors
+             if "anchors" in msg_copy:
+                 dict_anchors = msg_copy.get("anchors", [])
+                 if isinstance(dict_anchors, list):
+                     current_anchors.extend(dict_anchors)
+             tag_anchors = []
+             if re.search(r"<anchor>", content):
+                 tag_anchors = re.findall(r"<anchor>(.*?)</anchor>", content, flags=re.DOTALL)
+                 current_anchors.extend(tag_anchors)
+
+             # 2. Clean the message content (remove tags)
+             cleaned_content = re.sub(r"<anchor>|</anchor>", "", content)
+
+             # 3. Preprocess all collected anchors for this message
+             final_anchors = preprocess_anchors(current_anchors)
+
+             # 4. Escape anchors, build pattern, and replace in one pass
+             masked_content = cleaned_content  # Initialize
+             if final_anchors:
+                 if whole_word_only:
+                     # Use lookarounds to assert boundaries without consuming them (Fix 1)
+                     escaped_anchors = [rf"(?<!\w){re.escape(a)}(?!\w)" for a in final_anchors]
+                 else:
+                     escaped_anchors = [re.escape(a) for a in final_anchors]
+
+                 pattern = "|".join(escaped_anchors)
+                 masked_content = re.sub(pattern, mask_token, cleaned_content)
+
+             # 5. Post-processing: Merge consecutive mask tokens (separated by space) for this message
+             if mask_token:
+                 escaped_mask_token = re.escape(mask_token)
+                 # Improved merging logic (Fix 2); raw string so \s is a regex escape
+                 merge_pattern = rf"{escaped_mask_token}\s+{escaped_mask_token}"
+                 while re.search(merge_pattern, masked_content):
+                     masked_content = re.sub(merge_pattern, mask_token, masked_content)
+                 # Optional: merge masks without space if needed
+                 # merge_pattern_no_space = rf"{escaped_mask_token}{escaped_mask_token}"
+                 # while re.search(merge_pattern_no_space, masked_content):
+                 #     masked_content = re.sub(merge_pattern_no_space, mask_token, masked_content)
+
+             # 6. Prepare output dictionaries
+             final_cleaned_msg = msg_copy.copy()
+             final_cleaned_msg["content"] = cleaned_content
+             if "anchors" in final_cleaned_msg:
+                 del final_cleaned_msg["anchors"]
+
+             final_masked_msg = msg_copy.copy()
+             final_masked_msg["content"] = masked_content
+             if "anchors" in final_masked_msg:
+                 del final_masked_msg["anchors"]
+
+             cleaned_input_list.append(final_cleaned_msg)
+             masked_input_list.append(final_masked_msg)
+
+         return cleaned_input_list, masked_input_list
+     else:
+         raise ValueError("Invalid input type. Must be string or list of dictionaries.")
+
+
+ def get_mask_messages(messages, mask_token):
+     # Copy each message dict so the original messages are not mutated
+     mask_msg = [msg.copy() for msg in messages]
+
+     for msg in mask_msg:
+         if "anchors" in msg:
+             # Keep the pre-replacement content for the debug check below
+             original_content = msg["content"]
+
+             # Sort anchors by length (descending) to replace longest matches first
+             anchors = sorted(msg["anchors"], key=len, reverse=True)
+
+             for anchor in anchors:
+                 if anchor in msg["content"]:
+                     # Replace the anchor with the mask token
+                     msg["content"] = msg["content"].replace(anchor, mask_token)
+
+             # Debug post-replacement content
+             if original_content == msg["content"]:
+                 print(f"WARNING: No anchors were replaced in message: {original_content[:50]}...")
+                 print(f"Anchors: {anchors}")
+
+     return mask_msg
+
+
+ def convert_to_tensor_format(inputs, device=None):
+     # Case 1: Already a tensor in correct format
+     if isinstance(inputs, torch.Tensor) and len(inputs.shape) == 2:
+         if device is not None:
+             inputs = inputs.to(device)
+         return inputs
+
+     # Case 2: Object with input_ids attribute
+     if hasattr(inputs, 'input_ids'):
+         inputs = inputs.input_ids
+
+     # Case 3: Dictionary with input_ids key
+     elif isinstance(inputs, dict) and 'input_ids' in inputs:
+         inputs = inputs['input_ids']
+
+     # Case 4: List of token IDs
+     elif isinstance(inputs, list):
+         inputs = torch.tensor([inputs], device=device)
+
+     # Case 5: Single tensor but needs reshaping
+     elif isinstance(inputs, torch.Tensor):
+         if len(inputs.shape) == 1:
+             inputs = inputs.unsqueeze(0)
+
+     # Ensure it's on the correct device
+     if isinstance(inputs, torch.Tensor) and device is not None:
+         inputs = inputs.to(device)
+
+     return inputs
+
+ def create_default_attention_mask(input_ids, device=None):
+     """
+     Creates a default attention mask (all 1s) for the given input_ids tensor.
+
+     Args:
+         input_ids (torch.Tensor): The input IDs tensor, shape (batch_size, seq_len)
+         device: The device to place the attention mask on
+
+     Returns:
+         torch.Tensor: Attention mask with the same shape as input_ids, all values set to 1
+     """
+     # Ensure input_ids is on the right device if specified
+     if device is not None and input_ids.device != device:
+         input_ids = input_ids.to(device)
+
+     # Create attention mask filled with 1s (all tokens attend to all positions)
+     attention_mask = torch.ones_like(input_ids)
+
+     return attention_mask
+
+ def spa_tokenize(prompt_with_anchors, global_anchors, tokenizer, device):
+
+     # Set pad token if missing
+     if tokenizer.pad_token is None:
+         print("Setting pad token to EOS token")
+         tokenizer.pad_token = tokenizer.eos_token
+         # Remove reference to global model variable
+         # model.config.pad_token_id = model.config.eos_token_id
+
+     if tokenizer.mask_token:
+         mask_token = tokenizer.mask_token
+     else:
+         mask_token = "MASKTOKEN"
+
+     main_prompt, aux_prompt = format_spa_input(
+         input=prompt_with_anchors,
+         anchors=global_anchors,
+         mask_token=mask_token,
+         whole_word_only=False
+     )
+
+     # Detect if the tokenizer has a chat_template
+     if isinstance(main_prompt, list):
+         # Expected for chat models
+         # print("--- Message list processed by chat template")
+         if hasattr(tokenizer, "chat_template") and tokenizer.chat_template:
+
+             main_inputs = tokenizer.apply_chat_template(
+                 main_prompt,
+                 tokenize=True,
+                 add_generation_prompt=True,
+                 return_tensors="pt"
+             ).to(device)
+
+             aux_inputs = tokenizer.apply_chat_template(
+                 aux_prompt,
+                 tokenize=True,
+                 add_generation_prompt=True,
+                 return_tensors="pt"
+             ).to(device)
+
+         else:
+             # Non-chat models: convert the message list to a flat string prompt
+             # print("--- Message list processed by flat prompt")
+             flat_prompt_main = ""
+             for msg in main_prompt:
+                 flat_prompt_main += f"{msg['role']}: {msg['content']}\n"
+             flat_prompt_main += "Assistant: "  # Add assistant prefix for generation
+
+             flat_prompt_aux = ""
+             for msg in aux_prompt:
+                 flat_prompt_aux += f"{msg['role']}: {msg['content']}\n"
+             flat_prompt_aux += "Assistant: "  # Add assistant prefix for generation
+
+             # Tokenize the flattened prompts
+             main_inputs = tokenizer(flat_prompt_main, return_tensors="pt").to(device)
+             aux_inputs = tokenizer(flat_prompt_aux, return_tensors="pt").to(device)
+
+     # User provides a string prompt
+     elif isinstance(prompt_with_anchors, str):
+         if hasattr(tokenizer, "chat_template") and tokenizer.chat_template:
+             # print("--- String prompt processed by chat template")
+
+             # If the user only provides a string prompt, convert it to a chat prompt
+             main_prompt = [{"role": "user", "content": main_prompt}]
+             aux_prompt = [{"role": "user", "content": aux_prompt}]
+
+             main_inputs = tokenizer.apply_chat_template(
+                 main_prompt,
+                 tokenize=True,
+                 add_generation_prompt=True,
+                 return_tensors="pt"
+             ).to(device)
+
+             aux_inputs = tokenizer.apply_chat_template(
+                 aux_prompt,
+                 tokenize=True,
+                 add_generation_prompt=True,
+                 return_tensors="pt"
+             ).to(device)
+
+         else:
+             # Non-chat models: tokenize the string prompt directly
+             # print("--- String prompt processed by flat prompt")
+             main_inputs = tokenizer(main_prompt, return_tensors="pt").to(device)
+             aux_inputs = tokenizer(aux_prompt, return_tensors="pt").to(device)
+
+     else:
+         raise ValueError("Invalid prompt format")
+
+     # Make sure the returned input_ids follow the expected format: tensor([[1, 2, 3]], device='x')
+     # Handle all possible tokenizer output formats
+     main_inputs = convert_to_tensor_format(main_inputs, device)
+     aux_inputs = convert_to_tensor_format(aux_inputs, device)
+
+     return main_inputs, aux_inputs, mask_token
+
+
+ class SPALogitsProcessor(LogitsProcessor):
+     """Processor that combines logits from a main and auxiliary model."""
+
+     def __init__(self, aux_model, aux_input_ids, mask_token, strength=1.5, modulated_by_prob=True, tokenizer=None, use_attention_mask=True):
+         self.aux_model = aux_model  # Same model, used for aux inputs
+         self.aux_input_ids = aux_input_ids
+         self.aux_past_key_values = None
+         self.strength = strength
+         self.modulated_by_prob = modulated_by_prob  # Whether to modulate weight by probability
+         self.tokenizer = tokenizer  # Optional, for debug printing
+         self.mask_token = mask_token  # Store mask_token
+         # Store the device of the input_ids to use consistently
+         self.device = aux_input_ids.device
+         self.use_attention_mask = use_attention_mask
+         if self.use_attention_mask:
+             self.attention_mask = create_masked_attention(self.aux_input_ids, [mask_token], self.tokenizer)
+         else:
+             self.attention_mask = None
+
+     def __call__(self, input_ids, scores):
+         # Get aux model outputs for the current step
+         if self.aux_past_key_values is None:
+             # First step, run on full aux prompt
+             aux_outputs = self.aux_model(
+                 input_ids=self.aux_input_ids,
+                 use_cache=True,
+                 return_dict=True,
+                 attention_mask=self.attention_mask
+             )
+             self.aux_past_key_values = aux_outputs.past_key_values
+             aux_logits = aux_outputs.logits[:, -1, :]
+         else:
+             # Subsequent steps, run only on the new token with past_key_values
+             last_token = input_ids[:, -1].unsqueeze(-1).to(self.device)  # Ensure same device
+             # For subsequent tokens, we don't need to pass the attention mask
+             aux_outputs = self.aux_model(
+                 input_ids=last_token,
+                 past_key_values=self.aux_past_key_values,
+                 use_cache=True,
+                 return_dict=True
+             )
+             self.aux_past_key_values = aux_outputs.past_key_values
+             aux_logits = aux_outputs.logits[:, -1, :]
+
+         # Special case: strength = 1 means use only main logits
+         if abs(self.strength - 1.0) < 1e-4:
+             return scores
+
+         # If strength is 0, return the aux logits
+         if abs(self.strength - 0.0) < 1e-4:
+             return aux_logits
+
+         # Ensure scores and aux_logits are on the same device
+         if scores.device != aux_logits.device:
+             aux_logits = aux_logits.to(scores.device)
+
+         # Check for NaNs in the inputs
+         if torch.isnan(scores).any() or torch.isnan(aux_logits).any():
+             print("Warning: NaN values detected in input scores or aux_logits")
+             scores = torch.nan_to_num(scores, nan=0.0)
+             aux_logits = torch.nan_to_num(aux_logits, nan=0.0)
+
+         # Calculate the difference between main and aux logits
+         diff = scores - aux_logits
+
+         # Calculate the base weight
+         base_weight = self.strength - 1.0
+
+         # Modulate the weight by probability if enabled.
+         # Only do this when |strength| > 1, which is what can cause random behavior;
+         # if -1 < strength < 1 it is semantic diminishment, so disable this for more precise control.
+         if self.modulated_by_prob and (self.strength > 1 or self.strength < -1):
+             # Convert logits to probabilities with temperature scaling for stability
+             temperature = 1.0
+             scaled_logits = scores / temperature
+             main_probs = F.softmax(scaled_logits, dim=-1)
+
+             # Clamp probabilities to avoid numerical issues
+             main_probs = torch.clamp(main_probs, min=1e-6, max=1.0)
+
+             # Each token's weight is scaled by its probability:
+             # get the max probability
+             max_prob = torch.max(main_probs)
+             # normalize the base weight by the max probability
+             base_weight = base_weight / max_prob
+             # get different weights for each token based on their main probability
+             token_weights = base_weight * main_probs
+
+             # Apply the weighted adjustment
+             adjustment = token_weights * diff
+
+             # Clamp the adjustment to avoid extreme values
+             adjustment = torch.clamp(adjustment, min=-1e2, max=1e2)
+
+             # Compute final scores
+             final_scores = scores + adjustment
+         else:
+             # Safe computation of weighted difference
+             weighted_diff = base_weight * diff
+             # Check for and handle any NaNs that might have appeared
+             weighted_diff = torch.nan_to_num(weighted_diff, nan=0.0)
+             # Clamp to avoid extreme values
+             weighted_diff = torch.clamp(weighted_diff, min=-1e3, max=1e3)
+             final_scores = scores + weighted_diff
+
+         # Final stability check
+         final_scores = torch.clamp(final_scores, min=-1e3, max=1e3)
+
+         return final_scores
+
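
A minimal usage sketch (not part of spa.py itself) of how spa_tokenize and SPALogitsProcessor could be wired into transformers generation. The checkpoint name, prompt, anchors, and strength value below are illustrative assumptions, not something this commit prescribes.

# Hypothetical end-to-end example; adjust checkpoint, prompt, anchors, and strength as needed.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, LogitsProcessorList
from spa import spa_tokenize, SPALogitsProcessor

model_name = "Qwen/Qwen2.5-0.5B-Instruct"  # assumed checkpoint, any causal LM should work
device = "cuda" if torch.cuda.is_available() else "cpu"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name).to(device)

# Anchors can be marked inline with <anchor> tags and/or passed as global anchors.
prompt = "Translate <anchor>bonjour</anchor> into English and explain the word."
main_inputs, aux_inputs, mask_token = spa_tokenize(prompt, ["bonjour"], tokenizer, device)

processor = SPALogitsProcessor(
    aux_model=model,          # the same model scores the anchor-masked prompt
    aux_input_ids=aux_inputs,
    mask_token=mask_token,
    strength=1.5,             # strength > 1 amplifies the anchors' influence
    tokenizer=tokenizer,
)

output_ids = model.generate(
    input_ids=main_inputs,
    logits_processor=LogitsProcessorList([processor]),
    max_new_tokens=64,
    do_sample=False,
)
print(tokenizer.decode(output_ids[0][main_inputs.shape[1]:], skip_special_tokens=True))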