99eren99 commited on
Commit
4751db5
·
verified ·
1 Parent(s): 154f69b

Update assets/evalPytrec.py

Browse files
Files changed (1) hide show
  1. assets/evalPytrec.py +183 -183
assets/evalPytrec.py CHANGED
@@ -1,183 +1,183 @@
1
- import os
2
-
3
- os.environ["CUDA_VISIBLE_DEVICES"] = "0"
4
- os.environ["HF_HOME"] = "../../cache/hgCache"
5
- os.environ["TRANSFORMERS_CACHE"] = "../../cache/transformersCache/"
6
-
7
- import glob
8
- import logging
9
- import sys
10
- from collections import defaultdict
11
-
12
- import numpy as np
13
- import pytrec_eval
14
- import tqdm, torch
15
- import pandas as pd
16
- from pylate import models, rank
17
-
18
-
19
- document_length = 512
20
-
21
- model_name_or_paths = [
22
- "9eren99/TrColBERT",
23
- "jinaai/jina-colbert-v2",
24
- "antoinelouis/colbert-xm",
25
- ]
26
-
27
- datasetnames = [
28
- "fiqa2018",
29
- "climatefever",
30
- "dbpedia",
31
- "fever",
32
- "hotpotqa",
33
- # "msmarco",
34
- "nfcorpus",
35
- "nq",
36
- "quoraretrieval",
37
- "scidocs",
38
- "arguana",
39
- "scifact",
40
- "touche2020",
41
- ]
42
- for datasetname in datasetnames:
43
- print("#############", datasetname, "##############")
44
- evalResultsDf = None
45
- for model_name_or_path in model_name_or_paths:
46
- torch.cuda.empty_cache()
47
- if "jinaai/jina-colbert-v2" == model_name_or_path:
48
- model = models.ColBERT(
49
- model_name_or_path=model_name_or_path,
50
- query_prefix="[QueryMarker]",
51
- document_prefix="[DocumentMarker]",
52
- attend_to_expansion_tokens=True,
53
- trust_remote_code=True,
54
- document_length=document_length,
55
- )
56
- elif "antoinelouis/colbert-xm" == model_name_or_path:
57
- model = models.ColBERT(model_name_or_path="antoinelouis/colbert-xm")
58
- language = "tr_TR" # Use a code from https://huggingface.co/facebook/xmod-base#languages
59
-
60
- backbone = model[0].auto_model
61
- if backbone.__class__.__name__.lower().startswith("xmod"):
62
- backbone.set_default_language(language)
63
- else:
64
- model = models.ColBERT(
65
- model_name_or_path=model_name_or_path,
66
- document_length=document_length,
67
- attend_to_expansion_tokens=(
68
- True if "attend" in model_name_or_path else False
69
- ),
70
- )
71
-
72
- model.eval()
73
- model.to("cuda")
74
-
75
- dfDocs = pd.read_parquet(
76
- f"datasets/{datasetname}/corpus/train-00000-of-00001.parquet"
77
- ).dropna()
78
- dfQueries = pd.read_parquet(
79
- f"datasets/{datasetname}/queries/train-00000-of-00001.parquet"
80
- ).dropna()
81
-
82
- if "checkpoint" in model_name_or_path:
83
- try:
84
- model.tokenizer.model_input_names.remove("token_type_ids")
85
- except:
86
- print(model_name_or_path)
87
- dfDocs.TurkishText = dfDocs.TurkishText.apply(
88
- lambda x: x.replace("İ", "i").replace("I", "ı").lower()
89
- )
90
- dfQueries.TurkishText = dfQueries.TurkishText.apply(
91
- lambda x: x.replace("İ", "i").replace("I", "ı").lower()
92
- )
93
-
94
- # Read test queries
95
- queries = []
96
- documents = []
97
- passage_cand = {}
98
- relevant_qid = []
99
- relevant_docs = defaultdict(lambda: defaultdict(int))
100
-
101
- # read corpus
102
- newId2oldId_Docs = {}
103
- for i, row in enumerate(dfDocs.values):
104
- documents.append(row[2])
105
- newId2oldId_Docs[i] = str(row[0])
106
- relevant_qid.append(str(row[0]))
107
-
108
- # read queries
109
- newId2oldId_Queries = {}
110
- for i, row in enumerate(dfQueries.values):
111
- queries.append(row[2])
112
- newId2oldId_Queries[i] = str(row[0])
113
-
114
- for j, rowDoc in enumerate(dfDocs.values):
115
- relevant_docs[str(row[0])][str(rowDoc[0])] = 0
116
-
117
- # read qrels
118
- dfQrels = pd.read_parquet(
119
- f"datasets/{datasetname}/qrels/train-00000-of-00001.parquet"
120
- )
121
- for i, row in enumerate(dfQrels.values):
122
- relevant_docs[str(row[0])][str(row[1])] = 1
123
-
124
- candidateIds = [[i for i in range(len(documents))]]
125
-
126
- queries_result_list = []
127
- run = {}
128
-
129
- documents_embeddings = model.encode(
130
- [documents], is_query=False, show_progress_bar=True
131
- )
132
-
133
- for i, query in enumerate(tqdm.tqdm(queries)):
134
-
135
- queries_embeddings = model.encode(
136
- [query],
137
- is_query=True,
138
- )
139
-
140
- reranked_documents = rank.rerank(
141
- documents_ids=candidateIds,
142
- queries_embeddings=queries_embeddings,
143
- documents_embeddings=documents_embeddings,
144
- )
145
-
146
- run[newId2oldId_Queries[i]] = {}
147
- for resDict in reranked_documents[0]:
148
- run[newId2oldId_Queries[i]][newId2oldId_Docs[resDict["id"]]] = float(
149
- resDict["score"]
150
- )
151
-
152
- evaluator = pytrec_eval.RelevanceEvaluator(
153
- relevant_docs, pytrec_eval.supported_measures
154
- )
155
- scores = evaluator.evaluate(run)
156
-
157
- def print_line(measure, scope, value):
158
- print("{:25s}{:8s}{:.4f}".format(measure, scope, value))
159
-
160
- for query_id, query_measures in sorted(scores.items()):
161
- break
162
- for measure, value in sorted(query_measures.items()):
163
- print_line(measure, query_id, value)
164
-
165
- # Scope hack: use query_measures of last item in previous loop to
166
- # figure out all unique measure names.
167
- resultsColumns = ["model name"]
168
- resultsRow = [model_name_or_path]
169
- for measure in sorted(query_measures.keys()):
170
- resultsColumns.append(measure)
171
- resultsRow.append(
172
- pytrec_eval.compute_aggregated_measure(
173
- measure,
174
- [query_measures[measure] for query_measures in scores.values()],
175
- )
176
- )
177
-
178
- if evalResultsDf is None:
179
- evalResultsDf = pd.DataFrame(columns=resultsColumns)
180
- evalResultsDf.loc[-1] = resultsRow
181
- evalResultsDf.index = evalResultsDf.index + 1
182
-
183
- evalResultsDf.to_csv(f"resultsn/{datasetname}.csv", encoding="utf-8")
 
1
+ import os
2
+
3
+ os.environ["CUDA_VISIBLE_DEVICES"] = "0"
4
+ os.environ["HF_HOME"] = "../../cache/hgCache"
5
+ os.environ["TRANSFORMERS_CACHE"] = "../../cache/transformersCache/"
6
+
7
+ import glob
8
+ import logging
9
+ import sys
10
+ from collections import defaultdict
11
+
12
+ import numpy as np
13
+ import pytrec_eval
14
+ import tqdm, torch
15
+ import pandas as pd
16
+ from pylate import models, rank
17
+
18
+
19
+ document_length = 512
20
+
21
+ model_name_or_paths = [
22
+ "9eren99/TrColBERT",
23
+ "jinaai/jina-colbert-v2",
24
+ "antoinelouis/colbert-xm",
25
+ ]
26
+
27
+ datasetnames = [
28
+ "fiqa2018",
29
+ "climatefever",
30
+ "dbpedia",
31
+ "fever",
32
+ "hotpotqa",
33
+ # "msmarco",
34
+ "nfcorpus",
35
+ "nq",
36
+ "quoraretrieval",
37
+ "scidocs",
38
+ "arguana",
39
+ "scifact",
40
+ "touche2020",
41
+ ]
42
+ for datasetname in datasetnames:
43
+ print("#############", datasetname, "##############")
44
+ evalResultsDf = None
45
+ for model_name_or_path in model_name_or_paths:
46
+ torch.cuda.empty_cache()
47
+ if "jinaai/jina-colbert-v2" == model_name_or_path:
48
+ model = models.ColBERT(
49
+ model_name_or_path=model_name_or_path,
50
+ query_prefix="[QueryMarker]",
51
+ document_prefix="[DocumentMarker]",
52
+ attend_to_expansion_tokens=True,
53
+ trust_remote_code=True,
54
+ document_length=document_length,
55
+ )
56
+ elif "antoinelouis/colbert-xm" == model_name_or_path:
57
+ model = models.ColBERT(model_name_or_path="antoinelouis/colbert-xm")
58
+ language = "tr_TR" # Use a code from https://huggingface.co/facebook/xmod-base#languages
59
+
60
+ backbone = model[0].auto_model
61
+ if backbone.__class__.__name__.lower().startswith("xmod"):
62
+ backbone.set_default_language(language)
63
+ else:
64
+ model = models.ColBERT(
65
+ model_name_or_path=model_name_or_path,
66
+ document_length=document_length,
67
+ attend_to_expansion_tokens=(
68
+ True if "attend" in model_name_or_path else False
69
+ ),
70
+ )
71
+
72
+ model.eval()
73
+ model.to("cuda")
74
+
75
+ dfDocs = pd.read_parquet(
76
+ f"datasets/{datasetname}/corpus/train-00000-of-00001.parquet"
77
+ ).dropna()
78
+ dfQueries = pd.read_parquet(
79
+ f"datasets/{datasetname}/queries/train-00000-of-00001.parquet"
80
+ ).dropna()
81
+
82
+ if "99eren99/TrColBERT" == model_name_or_path:
83
+ try:
84
+ model.tokenizer.model_input_names.remove("token_type_ids")
85
+ except:
86
+ print(model_name_or_path)
87
+ dfDocs.TurkishText = dfDocs.TurkishText.apply(
88
+ lambda x: x.replace("İ", "i").replace("I", "ı").lower()
89
+ )
90
+ dfQueries.TurkishText = dfQueries.TurkishText.apply(
91
+ lambda x: x.replace("İ", "i").replace("I", "ı").lower()
92
+ )
93
+
94
+ # Read test queries
95
+ queries = []
96
+ documents = []
97
+ passage_cand = {}
98
+ relevant_qid = []
99
+ relevant_docs = defaultdict(lambda: defaultdict(int))
100
+
101
+ # read corpus
102
+ newId2oldId_Docs = {}
103
+ for i, row in enumerate(dfDocs.values):
104
+ documents.append(row[2])
105
+ newId2oldId_Docs[i] = str(row[0])
106
+ relevant_qid.append(str(row[0]))
107
+
108
+ # read queries
109
+ newId2oldId_Queries = {}
110
+ for i, row in enumerate(dfQueries.values):
111
+ queries.append(row[2])
112
+ newId2oldId_Queries[i] = str(row[0])
113
+
114
+ for j, rowDoc in enumerate(dfDocs.values):
115
+ relevant_docs[str(row[0])][str(rowDoc[0])] = 0
116
+
117
+ # read qrels
118
+ dfQrels = pd.read_parquet(
119
+ f"datasets/{datasetname}/qrels/train-00000-of-00001.parquet"
120
+ )
121
+ for i, row in enumerate(dfQrels.values):
122
+ relevant_docs[str(row[0])][str(row[1])] = 1
123
+
124
+ candidateIds = [[i for i in range(len(documents))]]
125
+
126
+ queries_result_list = []
127
+ run = {}
128
+
129
+ documents_embeddings = model.encode(
130
+ [documents], is_query=False, show_progress_bar=True
131
+ )
132
+
133
+ for i, query in enumerate(tqdm.tqdm(queries)):
134
+
135
+ queries_embeddings = model.encode(
136
+ [query],
137
+ is_query=True,
138
+ )
139
+
140
+ reranked_documents = rank.rerank(
141
+ documents_ids=candidateIds,
142
+ queries_embeddings=queries_embeddings,
143
+ documents_embeddings=documents_embeddings,
144
+ )
145
+
146
+ run[newId2oldId_Queries[i]] = {}
147
+ for resDict in reranked_documents[0]:
148
+ run[newId2oldId_Queries[i]][newId2oldId_Docs[resDict["id"]]] = float(
149
+ resDict["score"]
150
+ )
151
+
152
+ evaluator = pytrec_eval.RelevanceEvaluator(
153
+ relevant_docs, pytrec_eval.supported_measures
154
+ )
155
+ scores = evaluator.evaluate(run)
156
+
157
+ def print_line(measure, scope, value):
158
+ print("{:25s}{:8s}{:.4f}".format(measure, scope, value))
159
+
160
+ for query_id, query_measures in sorted(scores.items()):
161
+ break
162
+ for measure, value in sorted(query_measures.items()):
163
+ print_line(measure, query_id, value)
164
+
165
+ # Scope hack: use query_measures of last item in previous loop to
166
+ # figure out all unique measure names.
167
+ resultsColumns = ["model name"]
168
+ resultsRow = [model_name_or_path]
169
+ for measure in sorted(query_measures.keys()):
170
+ resultsColumns.append(measure)
171
+ resultsRow.append(
172
+ pytrec_eval.compute_aggregated_measure(
173
+ measure,
174
+ [query_measures[measure] for query_measures in scores.values()],
175
+ )
176
+ )
177
+
178
+ if evalResultsDf is None:
179
+ evalResultsDf = pd.DataFrame(columns=resultsColumns)
180
+ evalResultsDf.loc[-1] = resultsRow
181
+ evalResultsDf.index = evalResultsDf.index + 1
182
+
183
+ evalResultsDf.to_csv(f"resultsn/{datasetname}.csv", encoding="utf-8")