CMCenjoyer commited on
Commit
30ca9ee
·
verified ·
1 Parent(s): 2c2b35c

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +114 -1
README.md CHANGED
@@ -6,4 +6,117 @@ language:
6
  - ru
7
  base_model:
8
  - MoritzLaurer/DeBERTa-v3-large-mnli-fever-anli-ling-wanli
9
- ---
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6
  - ru
7
  base_model:
8
  - MoritzLaurer/DeBERTa-v3-large-mnli-fever-anli-ling-wanli
9
+ ---
10
+
11
+
12
+ # DebertaTrace Model
13
+
14
+ Карточка модели для token classification классификации ответов RAG-модели без оконного прохода по тексту, аналогчному в Luna. На выходе — три логита: релевантность, использование и приверженность (правдивость).
15
+ ## Пример использования
16
+
17
+ ```python
18
+ import torch
19
+ from transformers import AutoModel
20
+ from torch import nn
21
+ from huggingface_hub import hf_hub_download
22
+ from transformers import AutoModel, AutoTokenizer
23
+ tokenizer = AutoTokenizer.from_pretrained("CMCenjoyer/deberta-trace")
24
+
25
+
26
+ class DebertaTrace(nn.Module):
27
+ def __init__(self, base_model):
28
+ super().__init__()
29
+ self.base = base_model
30
+ hid = base_model.config.hidden_size
31
+ self.rel_head = nn.Linear(hid,1)
32
+ self.util_head = nn.Linear(hid,1)
33
+ self.adh_head = nn.Linear(hid,1)
34
+
35
+ def forward(self, input_ids, attention_mask):
36
+ out = self.base(input_ids=input_ids, attention_mask=attention_mask)
37
+ hs = out.last_hidden_state
38
+ return {
39
+ 'logits_relevance': self.rel_head(hs).squeeze(-1),
40
+ 'logits_utilization': self.util_head(hs).squeeze(-1),
41
+ 'logits_adherence': self.adh_head(hs).squeeze(-1)
42
+ }
43
+
44
+ base_model = AutoModel.from_pretrained("CMCenjoyer/deberta-trace")
45
+ model = DebertaTrace(base_model)
46
+ # heads_weights.p в локальный кэш
47
+ file_path = hf_hub_download(repo_id="CMCenjoyer/deberta-trace", filename="heads_weights.pt")
48
+ heads_weights = torch.load(file_path, weights_only=True)
49
+ model.rel_head.load_state_dict(heads_weights['rel_head'])
50
+ model.util_head.load_state_dict(heads_weights['util_head'])
51
+ model.adh_head.load_state_dict(heads_weights['adh_head'])
52
+ def preprocess(example, max_length=512):
53
+ '''
54
+ Препроцессим входной элемент в маску контекста, маску ответва и input_ids + attention_mask
55
+ '''
56
+ question_ids = tokenizer.encode(example["question"], add_special_tokens=False)
57
+
58
+ doc_ids = []
59
+ for doc in example["documents_sentences"]:
60
+ for _, sent in doc:
61
+ tokens = tokenizer.encode(sent, add_special_tokens=False)
62
+ doc_ids += tokens
63
+
64
+ response_ids = tokenizer.encode(example["response"], add_special_tokens=False)
65
+
66
+ sep_id = tokenizer.sep_token_id
67
+ input_ids = question_ids + [sep_id] + doc_ids + [sep_id] + response_ids
68
+
69
+ context_mask = [0] * (len(question_ids) + 1) + [1] * len(doc_ids) + [0] + [0] * len(response_ids)
70
+ response_mask = [0] * (len(question_ids) + len(doc_ids) + 2) + [1] * len(response_ids)
71
+
72
+ if len(input_ids) > max_length:
73
+ input_ids = input_ids[:max_length]
74
+ context_mask = context_mask[:max_length]
75
+ response_mask = response_mask[:max_length]
76
+
77
+ return {
78
+ "input_ids": torch.tensor(input_ids, dtype=torch.long),
79
+ "attention_mask": torch.tensor([1] * len(input_ids), dtype=torch.long),
80
+ "context_mask": torch.tensor(context_mask, dtype=torch.bool),
81
+ "response_mask": torch.tensor(response_mask, dtype=torch.bool),
82
+ }
83
+ def compute_trace_metrics_inference(logits, masks, threshold=0.5):
84
+ '''
85
+ подсчет метрик TRACE для каждого элемента батча(все батчи должны быть фиксированной одной длины)
86
+ '''
87
+ rel_pred = (torch.sigmoid(logits['logits_relevance'].detach().cpu()) > threshold)
88
+ util_pred = (torch.sigmoid(logits['logits_utilization'].detach().cpu())> threshold)
89
+ adh_pred = (torch.sigmoid(logits['logits_adherence'].detach().cpu()) > threshold)
90
+
91
+ ctx_m = masks['context_mask'].detach().cpu()
92
+ resp_m = masks['response_mask'].detach().cpu()
93
+
94
+ def rate(pred, mask):
95
+ # sum(pred & mask) / sum(mask)
96
+ num = (pred & mask).sum(dim=1).float()
97
+ den = mask.sum(dim=1).float().clamp(min=1)
98
+ return num.div(den)
99
+
100
+ relevance_rate = rate(rel_pred, ctx_m)
101
+ utilization_rate = rate(util_pred, ctx_m)
102
+ adherence_rate = rate(adh_pred, resp_m)
103
+
104
+ # completeness: из релевантных предсказаний — сколько ещё и util
105
+ num_ru = (rel_pred & util_pred & ctx_m).sum(dim=1).float()
106
+ den_r = rel_pred.sum(dim=1).float().clamp(min=1)
107
+ completeness = num_ru.div(den_r)
108
+
109
+ return {
110
+ 'relevance_rate': relevance_rate,
111
+ 'utilization_rate': utilization_rate,
112
+ 'adherence_rate': adherence_rate,
113
+ 'completeness': completeness
114
+ }
115
+ from datasets import load_dataset
116
+ ds = load_dataset("rungalileo/ragbench", "delucionqa")
117
+ ex = preprocess(ds['train'][9])
118
+ model.eval()
119
+ with torch.no_grad():
120
+ outputs = model(ex["input_ids"].unsqueeze(0), ex["attention_mask"].unsqueeze(0))
121
+ batch_metrics = compute_trace_metrics_inference(outputs, {'context_mask': ex["context_mask"].unsqueeze(0) , 'response_mask':ex["response_mask"].unsqueeze(0)})
122
+ batch_metrics