Spaces:
Running
on
Zero
Running
on
Zero
Commit
·
8151596
1
Parent(s):
2d7d23d
switched models to using full context
Browse files
- utils/models.py +4 -12
utils/models.py
CHANGED
|
@@ -30,18 +30,13 @@ def generate_summaries(example, model_a_name, model_b_name):
|
|
| 30 |
# Create a plain text version of the contexts for the models
|
| 31 |
context_text = ""
|
| 32 |
context_parts = []
|
| 33 |
-
if "contexts" in example:
|
| 34 |
-
for ctx in example["contexts"]:
|
| 35 |
if isinstance(ctx, dict) and "content" in ctx:
|
| 36 |
context_parts.append(ctx["content"])
|
| 37 |
context_text = "\n---\n".join(context_parts)
|
| 38 |
else:
|
| 39 |
-
|
| 40 |
-
if "full_contexts" in example:
|
| 41 |
-
for ctx in example["full_contexts"]:
|
| 42 |
-
if isinstance(ctx, dict) and "content" in ctx:
|
| 43 |
-
context_parts.append(ctx["content"])
|
| 44 |
-
context_text = "\n---\n".join(context_parts)
|
| 45 |
|
| 46 |
# Pass 'Answerable' status to models (they might use it)
|
| 47 |
answerable = example.get("Answerable", True)
|
|
@@ -85,17 +80,14 @@ def run_inference(model_name, context, question):
|
|
| 85 |
).to(device)
|
| 86 |
|
| 87 |
input_length = actual_input.shape[1]
|
| 88 |
-
|
| 89 |
-
# Create attention mask (1 for all tokens since we're not padding)
|
| 90 |
attention_mask = torch.ones_like(actual_input).to(device)
|
| 91 |
|
| 92 |
# Generate output
|
| 93 |
with torch.inference_mode():
|
| 94 |
-
# Disable gradient calculation for inference
|
| 95 |
outputs = model.generate(
|
| 96 |
actual_input,
|
| 97 |
attention_mask=attention_mask,
|
| 98 |
-
max_new_tokens=512,
|
| 99 |
pad_token_id=tokenizer.pad_token_id,
|
| 100 |
)
|
| 101 |
|
|
|
|
| 30 |
# Create a plain text version of the contexts for the models
|
| 31 |
context_text = ""
|
| 32 |
context_parts = []
|
| 33 |
+
if "full_contexts" in example:
|
| 34 |
+
for ctx in example["full_contexts"]:
|
| 35 |
if isinstance(ctx, dict) and "content" in ctx:
|
| 36 |
context_parts.append(ctx["content"])
|
| 37 |
context_text = "\n---\n".join(context_parts)
|
| 38 |
else:
|
| 39 |
+
raise ValueError("No context found in the example.")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 40 |
|
| 41 |
# Pass 'Answerable' status to models (they might use it)
|
| 42 |
answerable = example.get("Answerable", True)
|
|
|
|
| 80 |
).to(device)
|
| 81 |
|
| 82 |
input_length = actual_input.shape[1]
|
|
|
|
|
|
|
| 83 |
attention_mask = torch.ones_like(actual_input).to(device)
|
| 84 |
|
| 85 |
# Generate output
|
| 86 |
with torch.inference_mode():
|
|
|
|
| 87 |
outputs = model.generate(
|
| 88 |
actual_input,
|
| 89 |
attention_mask=attention_mask,
|
| 90 |
+
max_new_tokens=512,
|
| 91 |
pad_token_id=tokenizer.pad_token_id,
|
| 92 |
)
|
| 93 |
|