Spaces:
Sleeping
Sleeping
Update tasks/text.py
Browse files- tasks/text.py +8 -3
tasks/text.py
CHANGED
|
@@ -158,10 +158,15 @@ async def evaluate_text(request: TextEvaluationRequest):
|
|
| 158 |
model.eval()
|
| 159 |
predictions = []
|
| 160 |
for batch in tqdm(test_dataloader):
|
| 161 |
-
|
| 162 |
with torch.no_grad():
|
| 163 |
-
|
| 164 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 165 |
logits = logits.detach().cpu().numpy()
|
| 166 |
predictions.extend(logits.argmax(1))
|
| 167 |
|
|
|
|
| 158 |
model.eval()
|
| 159 |
predictions = []
|
| 160 |
for batch in tqdm(test_dataloader):
|
| 161 |
+
|
| 162 |
with torch.no_grad():
|
| 163 |
+
if MODEL =="mlp":
|
| 164 |
+
b_texts = batch
|
| 165 |
+
logits = model(b_texts)
|
| 166 |
+
elif MODEL == "ct":
|
| 167 |
+
b_input_ids, b_input_mask, b_token_type_ids, b_labels = batch
|
| 168 |
+
logits = model(b_input_ids, b_token_type_ids, b_input_mask)
|
| 169 |
+
|
| 170 |
logits = logits.detach().cpu().numpy()
|
| 171 |
predictions.extend(logits.argmax(1))
|
| 172 |
|