ArneBinder committed · verified
Commit e7eaeed · 1 Parent(s): 6a6bb2a

upload https://github.com/ArneBinder/pie-document-level/pull/452

src/analysis/combine_job_returns.py CHANGED
@@ -47,6 +47,7 @@ def main(
     transpose: bool = False,
     unpack_multirun_results: bool = False,
     in_percent: bool = False,
+    reset_index: bool = False,
 ):
     file_paths = get_file_paths(
         paths_file=paths_file, file_name=file_name, use_aggregated=use_aggregated
@@ -97,9 +98,6 @@ def main(
     data = data.unstack(index_name)
     data = data.T
 
-    if transpose:
-        data = data.T
-
     # needs to happen before rounding, otherwise the rounding will be off
     if in_percent:
         data = data * 100
@@ -107,20 +105,23 @@ def main(
     if round_precision is not None:
         data = data.round(round_precision)
 
-    if format == "markdown":
-        print(data.to_markdown())
-    elif format == "markdown_mean_and_std":
-        if transpose:
-            data = data.T
+    # needs to happen before transposing
+    if format == "markdown_mean_and_std":
         if "mean" not in data.columns or "std" not in data.columns:
             raise ValueError("Columns 'mean' and 'std' are required for this format.")
         # create a single column with mean and std in the format: mean ± std
         data = pd.DataFrame(
            data["mean"].astype(str) + " ± " + data["std"].astype(str), columns=["mean ± std"]
         )
-        if transpose:
-            data = data.T
-        print(data.to_markdown())
+
+    if transpose:
+        data = data.T
+
+    if reset_index:
+        data = data.reset_index()
+
+    if format in ["markdown", "markdown_mean_and_std"]:
+        print(data.to_markdown(index=not reset_index))
     elif format == "json":
         print(data.to_json())
     else:
@@ -156,6 +157,9 @@ if __name__ == "__main__":
     parser.add_argument(
         "--in-percent", action="store_true", help="Show the values in percent (multiply by 100)"
     )
+    parser.add_argument(
+        "--reset-index", action="store_true", help="Reset the index of the combined job returns"
+    )
     parser.add_argument(
         "--format",
         type=str,
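Note on the reordering above: the mean ± std column is built first, and only afterwards the table is transposed, the index reset, and the markdown printed. A minimal standalone pandas sketch of that output pipeline (toy numbers, not part of the commit):

import pandas as pd

# toy aggregated multirun results: one row per metric
data = pd.DataFrame(
    {"mean": [0.8123, 0.7456], "std": [0.0121, 0.0233]},
    index=["f1", "precision"],
)
in_percent, round_precision, transpose, reset_index = True, 2, True, True

if in_percent:
    data = data * 100  # before rounding, otherwise the rounding would be off
if round_precision is not None:
    data = data.round(round_precision)

# combine mean and std before transposing, mirroring the new code path
data = pd.DataFrame(
    data["mean"].astype(str) + " ± " + data["std"].astype(str), columns=["mean ± std"]
)
if transpose:
    data = data.T
if reset_index:
    data = data.reset_index()

# the index is suppressed when it was reset, as in to_markdown(index=not reset_index)
print(data.to_markdown(index=not reset_index))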
src/analysis/show_score_distribution.py ADDED
@@ -0,0 +1,99 @@
+import pyrootutils
+
+root = pyrootutils.setup_root(
+    search_from=__file__,
+    indicator=[".project-root"],
+    pythonpath=True,
+    dotenv=False,
+)
+
+import argparse
+from typing import List, Optional
+
+import pandas as pd
+import plotly.figure_factory as ff
+from pie_datasets import DatasetDict
+
+pd.options.plotting.backend = "plotly"
+
+if __name__ == "__main__":
+
+    parser = argparse.ArgumentParser(
+        description="Show score distribution of annotations per layer"
+    )
+    # --data-dir predictions/default/2025-02-26_14-28-17
+    parser.add_argument(
+        "--data-dir", type=str, required=True, help="Path to the dataset directory"
+    )
+    parser.add_argument("--split", type=str, default="test", help="Dataset split to use")
+    parser.add_argument(
+        "--layers",
+        nargs="+",
+        default=["labeled_spans", "binary_relations"],
+        help="Annotation layers to use",
+    )
+    # --layer-captions ADUs "Argumentative Relations"
+    parser.add_argument(
+        "--layer-captions", nargs="+", help="Captions for the figure traces per layer"
+    )
+    # --layer-colors "rgb(31,119,180)" "rgb(255,127,14)"
+    parser.add_argument("--layer-colors", nargs="+", help="Colors for the figure traces per layer")
+
+    args = parser.parse_args()
+
+    # Load the dataset
+    ds = DatasetDict.from_json(data_dir=args.data_dir)[args.split]
+
+    # get scores per annotation layer and label
+    layers = args.layers
+    all_scores = []
+    all_scores_idx = []
+    for doc in ds:
+        for layer in layers:
+            for ann in doc[layer].predictions:
+                all_scores.append(ann.score)
+                all_scores_idx.append((doc.id, layer, getattr(ann, "label", None)))
+    scores = pd.Series(
+        all_scores,
+        index=pd.MultiIndex.from_tuples(all_scores_idx, names=["doc_id", "layer", "label"]),
+        name="score",
+    )
+
+    if args.layer_captions is not None:
+        if len(args.layer_captions) < len(layers):
+            raise ValueError("Not enough captions provided for all layers")
+        name_mapping = dict(zip(layers, args.layer_captions))
+    else:
+        name_mapping = dict(zip(layers, layers))
+
+    colors: Optional[List[str]] = None
+    if args.layer_colors is not None:
+        if len(args.layer_colors) < len(layers):
+            raise ValueError("Not enough colors provided for all layers")
+        color_mapping = dict(zip(layers, args.layer_colors))
+        colors = [color_mapping[layer] for layer in layers]
+    else:
+        colors = None
+
+    score_groups = {layer: scores.xs(layer, level="layer").to_numpy() for layer in layers}
+    group_labels, hist_data = zip(*score_groups.items())
+    group_labels_renamed = [name_mapping[label] for label in group_labels]
+    fig = ff.create_distplot(
+        hist_data,
+        group_labels=group_labels_renamed,
+        show_hist=True,
+        colors=colors,
+        bin_size=0.025,
+    )
+
+    fig.update_layout(
+        height=600,
+        width=800,
+        title_text="Score Distribution per Annotation Layer",
+        title_x=0.5,
+        barmode="group",
+    )
+    fig.update_layout(legend=dict(yanchor="top", y=0.99, xanchor="left", x=0.01))
+
+    fig.show()
+    print("done")
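A minimal sketch of what the script does once the scores are collected, using made-up scores instead of a PIE dataset (layer names and values are illustrative only):

import pandas as pd
import plotly.figure_factory as ff

# toy prediction scores, indexed the same way as in the script above
records = [
    ("doc1", "labeled_spans", "claim", 0.91),
    ("doc1", "labeled_spans", "data", 0.62),
    ("doc2", "labeled_spans", "claim", 0.77),
    ("doc1", "binary_relations", "supports", 0.48),
    ("doc2", "binary_relations", "supports", 0.85),
    ("doc2", "binary_relations", "attacks", 0.59),
]
index = pd.MultiIndex.from_tuples(
    [(d, layer, lab) for d, layer, lab, _ in records], names=["doc_id", "layer", "label"]
)
scores = pd.Series([s for *_, s in records], index=index, name="score")

# one histogram/KDE trace per annotation layer, as in the script
layers = ["labeled_spans", "binary_relations"]
hist_data = [scores.xs(layer, level="layer").to_numpy() for layer in layers]
fig = ff.create_distplot(hist_data, group_labels=layers, show_hist=True, bin_size=0.025)
fig.show()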
src/data/calc_iaa_for_brat.py CHANGED
@@ -92,6 +92,7 @@ def calc_brat_iaas(
             create_multi_spans=True,
             result_document_type=BratDocument,
             result_field_mapping={"spans": "spans", "relations": "relations"},
+            combine_scores_method="product",
         )
     else:
         merger = None
src/datamodules/datamodule.py CHANGED
@@ -1,3 +1,4 @@
+import logging
 from typing import Any, Dict, Generic, Optional, Sequence, TypeVar, Union
 
 from pytorch_ie.core import Document
@@ -21,6 +22,8 @@ DatasetType: TypeAlias = Union[
     IterableTaskEncodingDataset[TaskEncoding[DocumentType, InputEncoding, TargetEncoding]],
 ]
 
+logger = logging.getLogger(__name__)
+
 
 class PieDataModule(LightningDataModule, Generic[DocumentType, InputEncoding, TargetEncoding]):
     """A simple LightningDataModule for PIE document datasets.
@@ -49,6 +52,7 @@ class PieDataModule(LightningDataModule, Generic[DocumentType, InputEncoding, Ta
         test_split: Optional[str] = "test",
         show_progress_for_encode: bool = False,
         train_sampler: Optional[str] = None,
+        dont_shuffle_train: bool = False,
         **dataloader_kwargs,
     ):
         super().__init__()
@@ -62,6 +66,7 @@ class PieDataModule(LightningDataModule, Generic[DocumentType, InputEncoding, Ta
         self.show_progress_for_encode = show_progress_for_encode
         self.train_sampler_name = train_sampler
         self.dataloader_kwargs = dataloader_kwargs
+        self.dont_shuffle_train = dont_shuffle_train
 
         self._data: Dict[str, DatasetType] = {}
 
@@ -128,12 +133,17 @@ class PieDataModule(LightningDataModule, Generic[DocumentType, InputEncoding, Ta
             sampler = self.get_train_sampler(sampler_name=self.train_sampler_name, dataset=ds)
         else:
             sampler = None
+        # don't shuffle streamed datasets or if we use a sampler or if we explicitly set dont_shuffle_train
+        shuffle = not self.dont_shuffle_train and not (
+            isinstance(ds, IterableTaskEncodingDataset) or sampler is not None
+        )
+        if not shuffle:
+            logger.warning("not shuffling train dataloader")
         return DataLoader(
             dataset=ds,
             sampler=sampler,
             collate_fn=self.taskmodule.collate,
-            # don't shuffle streamed datasets or if we use a sampler
-            shuffle=not (isinstance(ds, IterableTaskEncodingDataset) or sampler is not None),
+            shuffle=shuffle,
             **self.dataloader_kwargs,
         )
 
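The new shuffle behaviour boils down to a single predicate; this is a standalone sketch of that decision logic (the function name is hypothetical, not part of the module):

def resolve_train_shuffle(
    dont_shuffle_train: bool, is_iterable_dataset: bool, has_sampler: bool
) -> bool:
    # shuffle only when nothing forbids it: not explicitly disabled,
    # not a streamed (iterable) dataset, and no custom sampler in use
    return not dont_shuffle_train and not (is_iterable_dataset or has_sampler)

assert resolve_train_shuffle(False, False, False) is True
assert resolve_train_shuffle(True, False, False) is False   # explicitly disabled
assert resolve_train_shuffle(False, True, False) is False   # streamed dataset
assert resolve_train_shuffle(False, False, True) is False   # custom sampler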
src/demo/annotation_utils.py CHANGED
@@ -37,6 +37,7 @@ def get_merger() -> SpansViaRelationMerger:
             "binary_relations": "binary_relations",
             "labeled_partitions": "labeled_partitions",
         },
+        combine_scores_method="product",
     )
 
 
src/demo/retrieve_and_dump_all_relevant.py CHANGED
@@ -10,9 +10,17 @@ root = pyrootutils.setup_root(
 import argparse
 import logging
 import os
+from typing import Dict, List, Optional, Tuple
 
 import pandas as pd
+from pie_datasets import Dataset, DatasetDict
+from pytorch_ie import Annotation
+from pytorch_ie.annotations import BinaryRelation, MultiSpan, Span
 
+from document.types import (
+    RelatedRelation,
+    TextDocumentWithLabeledMultiSpansBinaryRelationsLabeledPartitionsAndRelatedRelations,
+)
 from src.demo.retriever_utils import (
     retrieve_all_relevant_spans,
     retrieve_all_relevant_spans_for_all_documents,
@@ -23,6 +31,168 @@ from src.langchain_modules import DocumentAwareSpanRetrieverWithRelations
 logger = logging.getLogger(__name__)
 
 
+def get_original_doc_id_and_offsets(doc_id: str) -> Tuple[str, int, Optional[int]]:
+    original_doc_id, middle, start_end, ext = doc_id.split(".")
+    if middle == "remaining":
+        return original_doc_id, int(start_end), None
+    elif middle == "abstract":
+        start, end = start_end.split("_")
+        return original_doc_id, int(start), int(end)
+    else:
+        raise ValueError(f"unexpected doc_id format: {doc_id}")
+
+
+def add_base_annotations(
+    documents: Dict[
+        str, TextDocumentWithLabeledMultiSpansBinaryRelationsLabeledPartitionsAndRelatedRelations
+    ],
+    retrieved_doc_ids: List[str],
+    retriever: DocumentAwareSpanRetrieverWithRelations,
+) -> Dict[Tuple[str, Annotation], Tuple[str, Annotation]]:
+    # (retrieved_doc_id, retrieved_annotation) -> (original_doc_id, original_annotation)
+    annotation_mapping = {}
+    for retrieved_doc_id in retrieved_doc_ids:
+        pie_doc = retriever.get_document(retrieved_doc_id).metadata["pie_document"].copy()
+        original_doc_id, offset, _ = get_original_doc_id_and_offsets(retrieved_doc_id)
+        document = documents[original_doc_id]
+        span_mapping = {}
+        for span in pie_doc.labeled_multi_spans.predictions:
+            if isinstance(span, MultiSpan):
+                new_span = span.copy(
+                    slices=[(start + offset, end + offset) for start, end in span.slices]
+                )
+            elif isinstance(span, Span):
+                new_span = span.copy(start=span.start + offset, end=span.end + offset)
+            else:
+                raise ValueError(f"unexpected span type: {span}")
+            span_mapping[span] = new_span
+        document.labeled_multi_spans.predictions.extend(span_mapping.values())
+        for relation in pie_doc.binary_relations.predictions:
+            new_relation = relation.copy(
+                head=span_mapping[relation.head], tail=span_mapping[relation.tail]
+            )
+            document.binary_relations.predictions.append(new_relation)
+        for old_ann, new_ann in span_mapping.items():
+            annotation_mapping[(retrieved_doc_id, old_ann)] = (original_doc_id, new_ann)
+
+    return annotation_mapping
+
+
+def get_doc_and_span_id2annotation_mapping(
+    span_ids: pd.Series,
+    doc_ids: pd.Series,
+    retriever: DocumentAwareSpanRetrieverWithRelations,
+    base_annotation_mapping: Dict[Tuple[str, Annotation], Tuple[str, Annotation]],
+) -> Dict[Tuple[str, str], Tuple[str, Annotation]]:
+    if len(doc_ids) != len(span_ids):
+        raise ValueError("doc_ids and span_ids must have the same length")
+    doc_and_span_ids = zip(doc_ids.tolist(), span_ids.tolist())
+    return {
+        (doc_id, span_id): base_annotation_mapping[(doc_id, retriever.get_span_by_id(span_id))]
+        for doc_id, span_id in set(doc_and_span_ids)
+    }
+
+
+def add_result_to_gold_data(
+    result: pd.DataFrame,
+    gold_dataset_dir: str,
+    dataset_out_dir: str,
+    retriever: DocumentAwareSpanRetrieverWithRelations,
+    split: Optional[str] = None,
+    link_relation_label: str = "semantically_same",
+    reversed_relation_suffix: str = "_reversed",
+):
+
+    if not os.path.exists(gold_dataset_dir):
+        raise ValueError(f"gold dataset directory does not exist: {gold_dataset_dir}")
+
+    dataset_dict = DatasetDict.from_json(data_dir=gold_dataset_dir)
+    if split is None and len(dataset_dict) == 1:
+        split = list(dataset_dict.keys())[0]
+    if split is None:
+        raise ValueError("need to provide split name to add results to gold dataset")
+
+    dataset = dataset_dict[split]
+
+    doc_id2doc = {doc.id: doc for doc in dataset}
+    retriever_doc_ids = (
+        result["doc_id"].unique().tolist() + result["query_doc_id"].unique().tolist()
+    )
+    base_annotation_mapping = add_base_annotations(
+        documents=doc_id2doc, retrieved_doc_ids=retriever_doc_ids, retriever=retriever
+    )
+    # (retriever_doc_id, retriever_span_id) -> (original_doc_id, original_span)
+    doc_and_span_id2annotation = {}
+    doc_and_span_id2annotation.update(
+        get_doc_and_span_id2annotation_mapping(
+            span_ids=result["span_id"],
+            doc_ids=result["doc_id"],
+            retriever=retriever,
+            base_annotation_mapping=base_annotation_mapping,
+        )
+    )
+    doc_and_span_id2annotation.update(
+        get_doc_and_span_id2annotation_mapping(
+            span_ids=result["ref_span_id"],
+            doc_ids=result["doc_id"],
+            retriever=retriever,
+            base_annotation_mapping=base_annotation_mapping,
+        )
+    )
+    doc_and_span_id2annotation.update(
+        get_doc_and_span_id2annotation_mapping(
+            span_ids=result["query_span_id"],
+            doc_ids=result["query_doc_id"],
+            retriever=retriever,
+            base_annotation_mapping=base_annotation_mapping,
+        )
+    )
+    doc_id2head_tail2relation = {}
+    for doc_id, doc in doc_id2doc.items():
+        head_and_tail2relation = {}
+        for relation in doc.binary_relations.predictions:
+            head_and_tail2relation[(relation.head, relation.tail)] = relation
+        doc_id2head_tail2relation[doc_id] = head_and_tail2relation
+
+    for row in result.itertuples():
+        query_doc_id, query_span = doc_and_span_id2annotation[
+            (row.query_doc_id, row.query_span_id)
+        ]
+        doc_id, span = doc_and_span_id2annotation[(row.doc_id, row.span_id)]
+        doc_id2, ref_span = doc_and_span_id2annotation[(row.doc_id, row.ref_span_id)]
+        if doc_id != query_doc_id:
+            raise ValueError("doc_id and query_doc_id must be the same")
+        if doc_id != doc_id2:
+            raise ValueError("doc_id and ref_doc_id must be the same")
+        doc = doc_id2doc[doc_id]
+        link_rel = BinaryRelation(
+            head=query_span, tail=ref_span, label=link_relation_label, score=row.sim_score
+        )
+        doc.binary_relations.predictions.append(link_rel)
+        head_and_tail2relation = doc_id2head_tail2relation[doc_id]
+        related_rel_label = row.type
+        if related_rel_label.endswith(reversed_relation_suffix):
+            base_rel = head_and_tail2relation[(span, ref_span)]
+        else:
+            base_rel = head_and_tail2relation[(ref_span, span)]
+        related_rel = RelatedRelation(
+            head=query_span,
+            tail=span,
+            link_relation=link_rel,
+            relation=base_rel,
+            label=related_rel_label,
+            score=link_rel.score * base_rel.score,
+        )
+        doc.related_relations.predictions.append(related_rel)
+
+    dataset = Dataset.from_documents(list(doc_id2doc.values()))
+    dataset_dict = DatasetDict({split: dataset})
+    if not os.path.exists(dataset_out_dir):
+        os.makedirs(dataset_out_dir, exist_ok=True)
+
+    dataset_dict.to_json(dataset_out_dir)
+
+
 if __name__ == "__main__":
 
     parser = argparse.ArgumentParser()
@@ -81,6 +251,19 @@ if __name__ == "__main__":
        '(each separated by ":") to retrieve spans for. If provided, '
        "--query_doc_id and --query_span_id are ignored.",
     )
+    parser.add_argument(
+        "--gold_dataset_dir",
+        type=str,
+        default=None,
+        help="If provided, add the spans and base relations from the retriever data as well "
+        "as the related relations to the gold dataset.",
+    )
+    parser.add_argument(
+        "--dataset_out_dir",
+        type=str,
+        default=None,
+        help="If provided, save the enriched gold dataset to this directory.",
+    )
     args = parser.parse_args()
 
     logging.basicConfig(
@@ -157,4 +340,17 @@ if __name__ == "__main__":
        os.makedirs(os.path.dirname(args.output_path), exist_ok=True)
        all_spans_for_all_documents.to_json(args.output_path)
 
+    if args.gold_dataset_dir is not None:
+        logger.info(
+            f"reading gold data from {args.gold_dataset_dir} and adding results as predictions ..."
+        )
+        if args.dataset_out_dir is None:
+            raise ValueError("need to provide --dataset_out_dir to save the enriched dataset")
+        add_result_to_gold_data(
+            all_spans_for_all_documents,
+            gold_dataset_dir=args.gold_dataset_dir,
+            dataset_out_dir=args.dataset_out_dir,
+            retriever=retriever,
+        )
+
     logger.info("done")
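The helpers above assume retriever document ids of the form <original_id>.<abstract|remaining>.<offsets>.<ext>. A standalone sketch of that parsing convention, with made-up ids (the convention itself is taken from get_original_doc_id_and_offsets in this commit):

from typing import Optional, Tuple

def get_original_doc_id_and_offsets(doc_id: str) -> Tuple[str, int, Optional[int]]:
    # mirrors the helper added above
    original_doc_id, middle, start_end, ext = doc_id.split(".")
    if middle == "remaining":
        return original_doc_id, int(start_end), None
    elif middle == "abstract":
        start, end = start_end.split("_")
        return original_doc_id, int(start), int(end)
    else:
        raise ValueError(f"unexpected doc_id format: {doc_id}")

# made-up ids, only to show the expected shape of the two variants
print(get_original_doc_id_and_offsets("A01.abstract.0_1234.txt"))   # ('A01', 0, 1234)
print(get_original_doc_id_and_offsets("A01.remaining.1234.txt"))    # ('A01', 1234, None)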
src/demo/retriever_utils.py CHANGED
@@ -51,6 +51,7 @@ def load_retriever(
 def retrieve_similar_spans(
     retriever: DocumentAwareSpanRetriever,
     query_span_id: str,
+    min_score: float = 0.0,
     **kwargs,
 ) -> pd.DataFrame:
     if not query_span_id.strip():
@@ -60,21 +61,42 @@ def retrieve_similar_spans(
         records = []
         for similar_span_doc in retrieval_result:
             pie_doc, metadata = retriever.docstore.unwrap_with_metadata(similar_span_doc)
+            query_span = retriever.get_span_by_id(span_id=query_span_id)
+            query_span_score = query_span.score
             span_ann = metadata["attached_span"]
+            sim_score = metadata["relevance_score"]
+            span_score = span_ann.score
+            score = query_span_score * sim_score * span_score
             records.append(
                 {
+                    "score": score,
                     "doc_id": pie_doc.id,
                     "span_id": similar_span_doc.id,
-                    "score": metadata["relevance_score"],
+                    "sim_score": sim_score,
+                    "query_span_score": query_span_score,
+                    "span_score": span_score,
                     "label": span_ann.label,
                     "text": str(span_ann),
                 }
             )
-        return (
-            pd.DataFrame(records, columns=["doc_id", "score", "label", "text", "span_id"])
+        result = (
+            pd.DataFrame(
+                records,
+                columns=[
+                    "score",
+                    "text",
+                    "label",
+                    "sim_score",
+                    "span_score",
+                    "query_span_score",
+                    "doc_id",
+                    "span_id",
+                ],
+            )
             .sort_values(by="score", ascending=False)
             .round(3)
         )
+        return result[result["score"] >= min_score]
     except Exception as e:
         raise gr.Error(f"Failed to retrieve similar ADUs: {e}")
 
@@ -83,6 +105,7 @@ def retrieve_relevant_spans(
     retriever: DocumentAwareSpanRetriever,
     query_span_id: str,
     relation_label_mapping: Optional[dict[str, str]] = None,
+    min_score: float = 0.0,
     **kwargs,
 ) -> pd.DataFrame:
     if not query_span_id.strip():
@@ -98,40 +121,57 @@ def retrieve_relevant_spans(
             mapped_relation_label = relation_label_mapping.get(
                 metadata["relation_label"], metadata["relation_label"]
             )
+
+            query_span = retriever.get_span_by_id(span_id=query_span_id)
+            query_span_score = query_span.score
+            sim_score = metadata["relevance_score"]
+            ref_span_score = span_ann.score
+            rel_score = metadata["relation_score"]
+            span_score = tail_span_ann.score
+            score = query_span_score * sim_score * ref_span_score * rel_score * span_score
             records.append(
                 {
                     "doc_id": pie_doc.id,
                     "type": mapped_relation_label,
-                    "rel_score": metadata["relation_score"],
+                    "score": score,
+                    "rel_score": rel_score,
                     "text": str(tail_span_ann),
                     "span_id": relevant_span_doc.id,
+                    "span_score": span_score,
                     "label": tail_span_ann.label,
-                    "ref_score": metadata["relevance_score"],
+                    "sim_score": sim_score,
                     "ref_label": span_ann.label,
                     "ref_text": str(span_ann),
                     "ref_span_id": metadata["head_id"],
+                    "ref_span_score": ref_span_score,
+                    "query_span_score": query_span_score,
                 }
             )
-        return (
+        result = (
             pd.DataFrame(
                 records,
                 columns=[
+                    "score",
                     "type",
-                    # omitted for now, we get no valid relation scores for the generative model
-                    # "rel_score",
-                    "ref_score",
-                    "label",
                     "text",
+                    "label",
+                    "rel_score",
+                    "sim_score",
                     "ref_label",
+                    "ref_span_score",
                     "ref_text",
                     "doc_id",
                     "span_id",
+                    "span_score",
                     "ref_span_id",
+                    "query_span_score",
                 ],
             )
-            .sort_values(by=["ref_score"], ascending=False)
+            .sort_values(by=["score"], ascending=False)
             .round(3)
         )
+        return result[result["score"] >= min_score]
+
     except Exception as e:
         raise gr.Error(f"Failed to retrieve relevant ADUs: {e}")
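Both retrieval helpers now report a combined "score" that is the product of the individual scores, and rows below min_score are dropped after sorting and rounding. A toy pandas sketch of that post-processing (column values are made up):

import pandas as pd

records = pd.DataFrame(
    {
        "query_span_score": [0.99, 0.95],
        "sim_score": [0.87, 0.42],
        "span_score": [0.93, 0.90],
    }
)
# combined score = product of the individual scores, as in the updated helpers
records["score"] = (
    records["query_span_score"] * records["sim_score"] * records["span_score"]
)

min_score = 0.5
result = records.sort_values(by="score", ascending=False).round(3)
print(result[result["score"] >= min_score])  # only the first row survives the threshold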
src/document/processing.py CHANGED
@@ -1,16 +1,20 @@
 from __future__ import annotations
 
 import logging
-from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, TypeVar
+from collections import defaultdict
+from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, TypeVar, Union
 
-from pie_modules.document.processing.merge_spans_via_relation import _merge_spans_via_relation
-from pie_modules.documents import TextDocumentWithLabeledMultiSpansAndBinaryRelations
 from pie_modules.utils.span import have_overlap
 from pytorch_ie import AnnotationLayer
+from pytorch_ie.annotations import LabeledMultiSpan, LabeledSpan, MultiSpan, Span
 from pytorch_ie.core import Document
 from pytorch_ie.core.document import Annotation, _enumerate_dependencies
 
-from src.utils import distance
+from src.document.types import (
+    RelatedRelation,
+    TextDocumentWithLabeledMultiSpansBinaryRelationsLabeledPartitionsAndRelatedRelations,
+)
+from src.utils import distance, distance_slices
 from src.utils.span_utils import get_overlap_len
 
 logger = logging.getLogger(__name__)
@@ -68,58 +72,6 @@ def remove_overlapping_entities(
     return new_doc
 
 
-# TODO: remove and use pie_modules.document.processing.SpansViaRelationMerger instead
-def merge_spans_via_relation(
-    document: D,
-    relation_layer: str,
-    link_relation_label: str,
-    use_predicted_spans: bool = False,
-    process_predictions: bool = True,
-    create_multi_spans: bool = False,
-) -> D:
-
-    rel_layer = document[relation_layer]
-    span_layer = rel_layer.target_layer
-    new_gold_spans, new_gold_relations = _merge_spans_via_relation(
-        spans=span_layer,
-        relations=rel_layer,
-        link_relation_label=link_relation_label,
-        create_multi_spans=create_multi_spans,
-    )
-    if process_predictions:
-        new_pred_spans, new_pred_relations = _merge_spans_via_relation(
-            spans=span_layer.predictions if use_predicted_spans else span_layer,
-            relations=rel_layer.predictions,
-            link_relation_label=link_relation_label,
-            create_multi_spans=create_multi_spans,
-        )
-    else:
-        assert not use_predicted_spans
-        new_pred_spans = set(span_layer.predictions.clear())
-        new_pred_relations = set(rel_layer.predictions.clear())
-
-    relation_layer_name = relation_layer
-    span_layer_name = document[relation_layer].target_name
-    if create_multi_spans:
-        doc_dict = document.asdict()
-        for f in document.annotation_fields():
-            doc_dict.pop(f.name)
-
-        result = TextDocumentWithLabeledMultiSpansAndBinaryRelations.fromdict(doc_dict)
-        result.labeled_multi_spans.extend(new_gold_spans)
-        result.labeled_multi_spans.predictions.extend(new_pred_spans)
-        result.binary_relations.extend(new_gold_relations)
-        result.binary_relations.predictions.extend(new_pred_relations)
-    else:
-        result = document.copy(with_annotations=False)
-        result[span_layer_name].extend(new_gold_spans)
-        result[span_layer_name].predictions.extend(new_pred_spans)
-        result[relation_layer_name].extend(new_gold_relations)
-        result[relation_layer_name].predictions.extend(new_pred_relations)
-
-    return result
-
-
 def remove_partitions_by_labels(
     document: D, partition_layer: str, label_blacklist: List[str], span_layer: Optional[str] = None
 ) -> D:
@@ -249,31 +201,19 @@ def relabel_annotations(
 DWithSpans = TypeVar("DWithSpans", bound=Document)
 
 
-def align_predicted_span_annotations(
-    document: DWithSpans, span_layer: str, distance_type: str = "center", verbose: bool = False
-) -> DWithSpans:
-    """
-    Aligns predicted span annotations with the closest gold spans in a document.
-
-    First, calculates the distance between each predicted span and each gold span. Then,
-    for each predicted span, the gold span with the smallest distance is selected. If the
-    predicted span and the gold span have an overlap of at least half of the maximum length
-    of the two spans, the predicted span is aligned with the gold span.
-
-    Args:
-        document: The document to process.
-        span_layer: The name of the span layer.
-        distance_type: The type of distance to calculate. One of: center, inner, outer
-        verbose: Whether to print debug information.
-
-    Returns:
-        The processed document.
-    """
-    gold_spans = document[span_layer]
-    if len(gold_spans) == 0:
-        return document.copy()
-
-    pred_spans = document[span_layer].predictions
+def get_start_end(span: Union[Span, MultiSpan]) -> Tuple[int, int]:
+    if isinstance(span, Span):
+        return span.start, span.end
+    elif isinstance(span, MultiSpan):
+        starts, ends = zip(*span.slices)
+        return min(starts), max(ends)
+    else:
+        raise ValueError(f"Unsupported span type: {type(span)}")
+
+
+def _get_aligned_span_mappings(
+    gold_spans: Iterable[Span], pred_spans: Iterable[Span], distance_type: str
+) -> Tuple[Dict[int, Span], Dict[int, Span]]:
     old2new_pred_span = {}
     span_id2gold_span = {}
     for pred_span in pred_spans:
@@ -282,29 +222,32 @@ def align_predicted_span_annotations(
             (
                 gold_span,
                 distance(
-                    start_end=(pred_span.start, pred_span.end),
-                    other_start_end=(gold_span.start, gold_span.end),
+                    start_end=get_start_end(pred_span),
+                    other_start_end=get_start_end(gold_span),
                     distance_type=distance_type,
                 ),
            )
            for gold_span in gold_spans
        ]
+        if len(gold_spans_with_distance) == 0:
+            continue
 
         closest_gold_span, min_distance = min(gold_spans_with_distance, key=lambda x: x[1])
         # if the closest gold span is the same as the predicted span, we don't need to align
         if min_distance == 0.0:
             continue
 
+        pred_start_end = get_start_end(pred_span)
+        closest_gold_start_end = get_start_end(closest_gold_span)
+
         if have_overlap(
-            start_end=(pred_span.start, pred_span.end),
-            other_start_end=(closest_gold_span.start, closest_gold_span.end),
+            start_end=pred_start_end,
+            other_start_end=closest_gold_start_end,
         ):
-            overlap_len = get_overlap_len(
-                (pred_span.start, pred_span.end), (closest_gold_span.start, closest_gold_span.end)
-            )
-            # get the maximum length of the two spans
+            overlap_len = get_overlap_len(pred_start_end, closest_gold_start_end)
             l_max = max(
-                pred_span.end - pred_span.start, closest_gold_span.end - closest_gold_span.start
+                pred_start_end[1] - pred_start_end[0],
+                closest_gold_start_end[1] - closest_gold_start_end[0],
             )
             # if the overlap is at least half of the maximum length, we consider it a valid match for alignment
             valid_match = overlap_len >= (l_max / 2)
@@ -312,12 +255,140 @@
             valid_match = False
 
         if valid_match:
-            aligned_pred_span = pred_span.copy(
-                start=closest_gold_span.start, end=closest_gold_span.end
-            )
+            if isinstance(pred_span, Span):
+                aligned_pred_span = pred_span.copy(
+                    start=closest_gold_span.start, end=closest_gold_span.end
+                )
+            elif isinstance(pred_span, MultiSpan):
+                aligned_pred_span = pred_span.copy(slices=closest_gold_span.slices)
+            else:
+                raise ValueError(f"Unsupported span type: {type(pred_span)}")
             old2new_pred_span[pred_span._id] = aligned_pred_span
             span_id2gold_span[pred_span._id] = closest_gold_span
 
+    return old2new_pred_span, span_id2gold_span
+
+
+def get_spans2multi_spans_mapping(multi_spans: Iterable[MultiSpan]) -> Dict[Span, MultiSpan]:
+    result = {}
+    for multi_span in multi_spans:
+        for start, end in multi_span.slices:
+            span_kwargs = dict(start=start, end=end, score=multi_span.score)
+            if isinstance(multi_span, LabeledMultiSpan):
+                result[LabeledSpan(label=multi_span.label, **span_kwargs)] = multi_span
+            else:
+                result[Span(**span_kwargs)] = multi_span
+
+    return result
+
+
+def align_predicted_span_annotations(
+    document: DWithSpans,
+    span_layer: str,
+    distance_type: str = "center",
+    simple_multi_span: bool = False,
+    verbose: bool = False,
+) -> DWithSpans:
+    """
+    Aligns predicted span annotations with the closest gold spans in a document.
+
+    First, calculates the distance between each predicted span and each gold span. Then,
+    for each predicted span, the gold span with the smallest distance is selected. If the
+    predicted span and the gold span have an overlap of at least half of the maximum length
+    of the two spans, the predicted span is aligned with the gold span.
+
+    This also works for MultiSpan annotations, where the slices of the MultiSpan are used
+    to align the predicted spans. If any of the slices is aligned with a gold slice,
+    the MultiSpan is aligned with the respective gold MultiSpan. However, this may result in
+    the predicted MultiSpan being aligned with multiple gold MultiSpans, in which case the
+    closest gold MultiSpan is selected. A simplified version of this alignment can be achieved
+    by setting `simple_multi_span=True`, which treats MultiSpan annotations as simple Spans
+    by using their maximum and minimum start and end indices.
+
+    Args:
+        document: The document to process.
+        span_layer: The name of the span layer.
+        distance_type: The type of distance to calculate. One of: center, inner, outer
+        simple_multi_span: Whether to treat MultiSpan annotations as simple Spans by using their
+            maximum and minimum start and end indices.
+        verbose: Whether to print debug information.
+
+    Returns:
+        The processed document.
+    """
+    gold_spans = document[span_layer]
+    if len(gold_spans) == 0:
+        return document.copy()
+
+    pred_spans = document[span_layer].predictions
+    span_annotation_type = document.annotation_types()[span_layer]
+    if issubclass(span_annotation_type, Span) or simple_multi_span:
+        old2new_pred_span, span_id2gold_span = _get_aligned_span_mappings(
+            gold_spans=gold_spans, pred_spans=pred_spans, distance_type=distance_type
+        )
+    elif issubclass(span_annotation_type, MultiSpan):
+        # create Span objects from MultiSpan slices
+        gold_single_spans2multi_spans = get_spans2multi_spans_mapping(gold_spans)
+        pred_single_spans2multi_spans = get_spans2multi_spans_mapping(pred_spans)
+        # create the alignment mappings for the single spans
+        single_old2new_pred_span, single_span_id2gold_span = _get_aligned_span_mappings(
+            gold_spans=gold_single_spans2multi_spans.keys(),
+            pred_spans=pred_single_spans2multi_spans.keys(),
+            distance_type=distance_type,
+        )
+        # collect all Spans that are part of the same MultiSpan
+        pred_multi_span2single_spans: Dict[MultiSpan, List[Span]] = defaultdict(list)
+        for pred_span, multi_span in pred_single_spans2multi_spans.items():
+            pred_multi_span2single_spans[multi_span].append(pred_span)
+
+        # create the new mappings for the MultiSpans
+        old2new_pred_span = {}
+        span_id2gold_span = {}
+        for pred_multi_span, pred_single_spans in pred_multi_span2single_spans.items():
+            # if any of the single spans is aligned with a gold span, align the multi span
+            if any(
+                pred_single_span._id in single_old2new_pred_span
+                for pred_single_span in pred_single_spans
+            ):
+                # get aligned gold multi spans
+                aligned_gold_multi_spans = set()
+                for pred_single_span in pred_single_spans:
+                    if pred_single_span._id in single_old2new_pred_span:
+                        aligned_gold_single_span = single_span_id2gold_span[pred_single_span._id]
+                        aligned_gold_multi_span = gold_single_spans2multi_spans[
+                            aligned_gold_single_span
+                        ]
+                        aligned_gold_multi_spans.add(aligned_gold_multi_span)
+
+                # calculate distances between the predicted multi span and the aligned gold multi spans
+                gold_multi_spans_with_distance = [
+                    (
+                        gold_multi_span,
+                        distance_slices(
+                            slices=pred_multi_span.slices,
+                            other_slices=gold_multi_span.slices,
+                            distance_type=distance_type,
+                        ),
+                    )
+                    for gold_multi_span in aligned_gold_multi_spans
+                ]
+
+                if len(aligned_gold_multi_spans) > 1:
+                    logger.warning(
+                        f"Multiple gold multi spans aligned with predicted multi span ({pred_multi_span}): "
+                        f"{aligned_gold_multi_spans}"
+                    )
+                # get the closest gold multi span
+                closest_gold_multi_span, min_distance = min(
+                    gold_multi_spans_with_distance, key=lambda x: x[1]
+                )
+                old2new_pred_span[pred_multi_span._id] = pred_multi_span.copy(
+                    slices=closest_gold_multi_span.slices
+                )
+                span_id2gold_span[pred_multi_span._id] = closest_gold_multi_span
+    else:
+        raise ValueError(f"Unsupported span annotation type: {span_annotation_type}")
+
     result = document.copy(with_annotations=False)
 
     # multiple predicted spans can be aligned with the same gold span,
@@ -356,3 +427,88 @@
     )
 
     return result
+
+
+def add_related_relations_from_binary_relations(
+    document: TextDocumentWithLabeledMultiSpansBinaryRelationsLabeledPartitionsAndRelatedRelations,
+    link_relation_label: str,
+    link_partition_whitelist: Optional[List[List[str]]] = None,
+    relation_label_whitelist: Optional[List[str]] = None,
+    reversed_relation_suffix: str = "_reversed",
+    symmetric_relations: Optional[List[str]] = None,
+) -> TextDocumentWithLabeledMultiSpansBinaryRelationsLabeledPartitionsAndRelatedRelations:
+    span2partition = {}
+    for multi_span in document.labeled_multi_spans:
+        found_partition = False
+        for partition in document.labeled_partitions or [
+            LabeledSpan(start=0, end=len(document.text), label="ALL")
+        ]:
+            starts, ends = zip(*multi_span.slices)
+            if partition.start <= min(starts) and max(ends) <= partition.end:
+                span2partition[multi_span] = partition
+                found_partition = True
+                break
+        if not found_partition:
+            raise ValueError(f"No partition found for multi_span {multi_span}")
+
+    rel_head2rels = defaultdict(list)
+    rel_tail2rels = defaultdict(list)
+    for rel in document.binary_relations:
+        rel_head2rels[rel.head].append(rel)
+        rel_tail2rels[rel.tail].append(rel)
+
+    link_partition_whitelist_tuples = None
+    if link_partition_whitelist is not None:
+        link_partition_whitelist_tuples = {tuple(pair) for pair in link_partition_whitelist}
+
+    skipped_labels = []
+    for link_rel in document.binary_relations:
+        if link_rel.label == link_relation_label:
+            head_partition = span2partition[link_rel.head]
+            tail_partition = span2partition[link_rel.tail]
+            if link_partition_whitelist_tuples is None or (
+                (head_partition.label, tail_partition.label) in link_partition_whitelist_tuples
+            ):
+                # link_head -> link_tail == rel_head -> rel_tail
+                for rel in rel_head2rels.get(link_rel.tail, []):
+                    label = rel.label
+                    if relation_label_whitelist is None or label in relation_label_whitelist:
+                        new_rel = RelatedRelation(
+                            head=link_rel.head,
+                            tail=rel.tail,
+                            link_relation=link_rel,
+                            relation=rel,
+                            label=label,
+                        )
+                        document.related_relations.append(new_rel)
+                    else:
+                        skipped_labels.append(label)
+
+                # link_head -> link_tail == rel_tail -> rel_head
+                if reversed_relation_suffix is not None:
+                    for reversed_rel in rel_tail2rels.get(link_rel.tail, []):
+                        label = reversed_rel.label
+                        if not (symmetric_relations is not None and label in symmetric_relations):
+                            label = f"{label}{reversed_relation_suffix}"
+                        if relation_label_whitelist is None or label in relation_label_whitelist:
+                            new_rel = RelatedRelation(
+                                head=link_rel.head,
+                                tail=reversed_rel.head,
+                                link_relation=link_rel,
+                                relation=reversed_rel,
+                                label=label,
+                            )
+                            document.related_relations.append(new_rel)
+                        else:
+                            skipped_labels.append(label)
+
+            else:
+                logger.warning(
+                    f"Skipping related relation because of partition whitelist ({[head_partition.label, tail_partition.label]}): {link_rel.resolve()}"
+                )
+    if len(skipped_labels) > 0:
+        logger.warning(
+            f"Skipped relations with labels not in whitelist: {sorted(set(skipped_labels))}"
+        )
+
+    return document
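For MultiSpan alignment, a multi span can also be collapsed to its bounding interval (simple_multi_span=True), after which the overlap criterion works exactly as for plain spans. A small standalone sketch of that bounding-interval check (helper names are illustrative, not from the module):

from typing import List, Tuple

def bounding_start_end(slices: List[Tuple[int, int]]) -> Tuple[int, int]:
    # same idea as get_start_end for MultiSpan: collapse all slices
    # into their overall (min start, max end) interval
    starts, ends = zip(*slices)
    return min(starts), max(ends)

def overlap_len(a: Tuple[int, int], b: Tuple[int, int]) -> int:
    return max(0, min(a[1], b[1]) - max(a[0], b[0]))

pred = bounding_start_end([(10, 15), (20, 30)])   # -> (10, 30)
gold = bounding_start_end([(12, 18), (25, 31)])   # -> (12, 31)
l_max = max(pred[1] - pred[0], gold[1] - gold[0])
# aligned only if the overlap covers at least half of the longer span
print(overlap_len(pred, gold) >= l_max / 2)       # True: overlap 18 >= 20 / 2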
src/document/types.py ADDED
@@ -0,0 +1,46 @@
+import dataclasses
+
+from pytorch_ie import AnnotationLayer, annotation_field
+from pytorch_ie.annotations import BinaryRelation
+from pytorch_ie.documents import (
+    TextDocumentWithLabeledMultiSpansBinaryRelationsAndLabeledPartitions,
+)
+
+
+@dataclasses.dataclass(eq=True, frozen=True)
+class RelatedRelation(BinaryRelation):
+    link_relation: BinaryRelation = dataclasses.field(default=None, compare=False)
+    relation: BinaryRelation = dataclasses.field(default=None, compare=False)
+
+    def __post_init__(self):
+        super().__post_init__()
+        # check if the reference_span is correct
+        self.reference_span
+
+    @property
+    def reference_span(self):
+        if self.link_relation is None:
+            raise ValueError(
+                "No semantically_same_relation available, cannot return reference_span"
+            )
+        if self.link_relation.head == self.head:
+            return self.link_relation.tail
+        elif self.link_relation.tail == self.head:
+            return self.link_relation.head
+        elif self.link_relation.head == self.tail:
+            return self.link_relation.tail
+        elif self.link_relation.tail == self.tail:
+            return self.link_relation.head
+        else:
+            raise ValueError(
+                "The semantically_same_relation is neither linked to head nor tail of the current relation"
+            )
+
+
+@dataclasses.dataclass
+class TextDocumentWithLabeledMultiSpansBinaryRelationsLabeledPartitionsAndRelatedRelations(
+    TextDocumentWithLabeledMultiSpansBinaryRelationsAndLabeledPartitions,
+):
+    related_relations: AnnotationLayer[RelatedRelation] = annotation_field(
+        targets=["labeled_multi_spans", "binary_relations"]
+    )
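A minimal usage sketch for RelatedRelation.reference_span, assuming the repository root is on PYTHONPATH; the span offsets and labels below are made up for illustration:

from pytorch_ie.annotations import BinaryRelation, LabeledSpan

from src.document.types import RelatedRelation

# hypothetical spans: a query ADU, the span it is linked to, and a related ADU
query_span = LabeledSpan(start=0, end=5, label="claim")
ref_span = LabeledSpan(start=100, end=110, label="claim")
other_span = LabeledSpan(start=200, end=210, label="data")

link_rel = BinaryRelation(head=query_span, tail=ref_span, label="semantically_same")
base_rel = BinaryRelation(head=ref_span, tail=other_span, label="supports")

related = RelatedRelation(
    head=query_span,
    tail=other_span,
    link_relation=link_rel,
    relation=base_rel,
    label="supports",
)
# the reference span is the other end of the link relation, seen from the head
assert related.reference_span == ref_span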
src/metrics/__init__.py CHANGED
@@ -1,2 +1,3 @@
 from .coref_sklearn import CorefMetricsSKLearn
 from .coref_torchmetrics import CorefMetricsTorchmetrics
+from .score_distribution import ScoreDistribution
src/metrics/score_distribution.py ADDED
@@ -0,0 +1,345 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from collections import defaultdict
2
+ from typing import Any, Dict, List, Optional, Tuple
3
+
4
+ import pandas as pd
5
+ from pytorch_ie import Document, DocumentMetric
6
+
7
+
8
+ class ScoreDistribution(DocumentMetric):
9
+ """Computes the distribution of prediction scores for annotations in a layer. The scores are
10
+ separated into true positives (TP) and false positives (FP) based on the gold annotations.
11
+
12
+ Args:
13
+ layer: The name of the annotation layer to analyze.
14
+ per_label: If True, the scores are separated per label. Default is False.
15
+ label_field: The field name of the label to use for separating the scores per label. Default is "label".
16
+ equal_sample_size_binning: If True, the scores are binned into equal sample sizes. If False,
17
+ the scores are binned into equal width. The former is useful when the distribution of scores is skewed.
18
+ Default is True.
19
+ show_plot: If True, a plot of the score distribution is shown. Default is False.
20
+ plotting_backend: The plotting backend to use. Default is "plotly".
21
+ plotting_caption_mapping: A mapping to rename any caption entries for plotting, i.e., the layer name,
22
+ labels, or TP/FP. Default is None.
23
+ plotting_colors: A dictionary mapping from gold scores to colors for plotting. Default is None.
24
+ """
25
+
26
+ def __init__(
27
+ self,
28
+ layer: str,
29
+ label_field: str = "label",
30
+ per_label: bool = False,
31
+ show_plot: bool = False,
32
+ equal_sample_size_binning: bool = True,
33
+ plotting_backend: str = "plotly",
34
+ plotting_caption_mapping: Optional[Dict[str, str]] = None,
35
+ plotting_colors: Optional[Dict[str, str]] = None,
36
+ plotly_use_create_distplot: bool = True,
37
+ plotly_barmode: Optional[str] = None,
38
+ plotly_marginal: Optional[str] = "violin",
39
+ plotly_font_size: int = 18,
40
+ plotly_font_family: Optional[str] = None,
41
+ plotly_background_color: Optional[str] = None,
42
+ ):
43
+ super().__init__()
44
+ self.layer = layer
45
+ self.label_field = label_field
46
+ self.per_label = per_label
47
+ self.equal_sample_size_binning = equal_sample_size_binning
48
+ self.plotting_backend = plotting_backend
49
+ self.show_plot = show_plot
50
+ self.plotting_caption_mapping = plotting_caption_mapping or {}
51
+ self.plotting_colors = plotting_colors
52
+ self.plotly_use_create_distplot = plotly_use_create_distplot
53
+ self.plotly_barmode = plotly_barmode
54
+ self.plotly_marginal = plotly_marginal
55
+ self.plotly_font_size = plotly_font_size
56
+ self.plotly_font_family = plotly_font_family
57
+ self.plotly_background_color = plotly_background_color
58
+ self.scores: Dict[str, Dict[str, List[float]]] = defaultdict(lambda: defaultdict(list))
59
+
60
+ def reset(self):
61
+ self.scores = defaultdict(lambda: defaultdict(list))
62
+
63
+ def _update(self, document: Document):
64
+
65
+ gold_annotations = set(document[self.layer])
66
+ for ann in document[self.layer].predictions:
67
+ if self.per_label:
68
+ label = getattr(ann, self.label_field)
69
+ else:
70
+ label = "ALL"
71
+ if ann in gold_annotations:
72
+ self.scores[label]["TP"].append(ann.score)
73
+ else:
74
+ self.scores[label]["FP"].append(ann.score)
75
+
76
+ def _combine_scores(
77
+ self,
78
+ scores_tp: List[float],
79
+ score_fp: List[float],
80
+ col_name_pred: str = "prediction",
81
+ col_name_gold: str = "gold",
82
+ ) -> pd.DataFrame:
83
+ scores_tp_df = pd.DataFrame(scores_tp, columns=[col_name_pred])
84
+ scores_tp_df[col_name_gold] = 1.0
85
+ scores_fp_df = pd.DataFrame(score_fp, columns=[col_name_pred])
86
+ scores_fp_df[col_name_gold] = 0.0
87
+ scores_df = pd.concat([scores_tp_df, scores_fp_df])
88
+ return scores_df
89
+
90
+ def _get_calibration_data_and_metrics(
91
+ self, scores: pd.DataFrame, q: int = 20
92
+ ) -> Tuple[pd.DataFrame, pd.Series]:
93
+ from sklearn.metrics import brier_score_loss
94
+
95
+ if self.equal_sample_size_binning:
96
+ # Create bins with equal number of samples.
97
+ scores["bin"] = pd.qcut(scores["prediction"], q=q, labels=False)
98
+ else:
99
+ # Create bins with equal width.
100
+ scores["bin"] = pd.cut(
101
+ scores["prediction"],
102
+ bins=q,
103
+ include_lowest=True,
104
+ right=True,
105
+ labels=False,
106
+ )
107
+
108
+ calibration_data = (
109
+ scores.groupby("bin")
110
+ .apply(
111
+ lambda x: pd.Series(
112
+ {
113
+ "avg_score": x["prediction"].mean(),
114
+ "fraction_positive": x["gold"].mean(),
115
+ "count": len(x),
116
+ }
117
+ )
118
+ )
119
+ .reset_index()
120
+ )
121
+
122
+ total_count = scores.shape[0]
123
+ calibration_data["bin_weight"] = calibration_data["count"] / total_count
124
+
125
+ # Calculate the absolute differences and squared differences.
126
+ calibration_data["abs_diff"] = abs(
127
+ calibration_data["avg_score"] - calibration_data["fraction_positive"]
128
+ )
129
+ calibration_data["squared_diff"] = (
130
+ calibration_data["avg_score"] - calibration_data["fraction_positive"]
131
+ ) ** 2
132
+
133
+ # Compute Expected Calibration Error (ECE): weighted average of absolute differences.
134
+ ece = (calibration_data["abs_diff"] * calibration_data["bin_weight"]).sum()
135
+
136
+ # Compute Maximum Calibration Error (MCE): maximum absolute difference.
137
+ mce = calibration_data["abs_diff"].max()
138
+
139
+ # Compute Mean Squared Error (MSE): weighted average of squared differences.
140
+ mse = (calibration_data["squared_diff"] * calibration_data["bin_weight"]).sum()
141
+
142
+ # Compute the Brier Score on the raw predictions.
143
+ brier = brier_score_loss(scores["gold"], scores["prediction"])
144
+
145
+ values = {
146
+ "ece": ece,
147
+ "mce": mce,
148
+ "mse": mse,
149
+ "brier": brier,
150
+ }
151
+ return calibration_data, pd.Series(values)
152
+
153
+ def calculate_calibration_metrics(self, scores_combined: pd.DataFrame) -> pd.DataFrame:
154
+
155
+ calibration_data_dict = {}
156
+ calibration_metrics_dict = {}
157
+ for label, current_scores in scores_combined.groupby("label"):
158
+ calibration_data, calibration_metrics = self._get_calibration_data_and_metrics(
159
+ current_scores, q=20
160
+ )
161
+ calibration_data_dict[label] = calibration_data
162
+ calibration_metrics_dict[label] = calibration_metrics
163
+ all_calibration_data = pd.concat(
164
+ calibration_data_dict, names=["label", "idx"]
165
+ ).reset_index(level=0)
166
+ all_calibration_metrics = pd.concat(calibration_metrics_dict, axis=1).T
167
+
168
+ if self.show_plot:
169
+ self.plot_calibration_data(calibration_data=all_calibration_data)
170
+
171
+ return all_calibration_metrics
172
+
173
+ def calculate_correlation(self, scores: pd.DataFrame) -> pd.Series:
174
+ result_dict = {}
175
+ for label, current_scores in scores.groupby("label"):
176
+ result_dict[label] = current_scores.drop("label", axis=1).corr()["prediction"]["gold"]
177
+
178
+ return pd.Series(result_dict, name="correlation")
179
+
180
+ @property
181
+ def mapped_layer(self):
182
+ return self.plotting_caption_mapping.get(self.layer, self.layer)
183
+
184
+ def plot_score_distribution(self, scores: pd.DataFrame):
185
+ if self.plotting_backend == "plotly":
186
+ for label in scores["label"].unique():
187
+ description = f"Distribution of Predicted Scores for {self.mapped_layer}"
188
+ if self.per_label:
189
+ label_mapped = self.plotting_caption_mapping.get(label, label)
190
+ description += f" ({label_mapped})"
191
+ if self.plotly_use_create_distplot:
192
+ import plotly.figure_factory as ff
193
+
194
+ current_scores = scores[scores["label"] == label]
195
+ # group by gold score
196
+ scores_dict = (
197
+ current_scores.groupby("gold")["prediction"].apply(list).to_dict()
198
+ )
199
+ group_labels, hist_data = zip(*scores_dict.items())
200
+ group_labels_renamed = [
201
+ self.plotting_caption_mapping.get(label, label) for label in group_labels
202
+ ]
203
+ if self.plotting_colors is not None:
204
+ colors = [
205
+ self.plotting_colors[group_label] for group_label in group_labels
206
+ ]
207
+ else:
208
+ colors = None
209
+ fig = ff.create_distplot(
210
+ hist_data,
211
+ group_labels=group_labels_renamed,
212
+ show_hist=True,
213
+ colors=colors,
214
+ bin_size=0.025,
215
+ )
216
+ else:
217
+ import plotly.express as px
218
+
219
+ fig = px.histogram(
220
+ scores,
221
+ x="prediction",
222
+ color="gold",
223
+ marginal=self.plotly_marginal, # "violin", # or box, violin, rug
224
+ hover_data=scores.columns,
225
+ color_discrete_map=self.plotting_colors,
226
+ nbins=50,
227
+ )
228
+
229
+ fig.update_layout(
230
+ height=600,
231
+ width=800,
232
+ title_text=description,
233
+ title_x=0.5,
234
+ font=dict(size=self.plotly_font_size),
235
+ legend=dict(yanchor="top", y=0.99, xanchor="left", x=0.01),
236
+ )
237
+ if self.plotly_barmode is not None:
238
+ fig.update_layout(barmode=self.plotly_barmode)
239
+ if self.plotly_font_family is not None:
240
+ fig.update_layout(font_family=self.plotly_font_family)
241
+ if self.plotly_background_color is not None:
242
+ fig.update_layout(
243
+ plot_bgcolor=self.plotly_background_color,
244
+ paper_bgcolor=self.plotly_background_color,
245
+ )
246
+
247
+ fig.show()
248
+ else:
249
+ raise NotImplementedError(f"Plotting backend {self.plotting_backend} not implemented")
250
+
251
+ def plot_calibration_data(self, calibration_data: pd.DataFrame):
252
+ import plotly.express as px
253
+ import plotly.graph_objects as go
254
+
255
+ color = "label" if self.per_label else None
256
+ x_col = "avg_score"
257
+ y_col = "fraction_positive"
258
+ fig = px.scatter(
259
+ calibration_data,
260
+ x=x_col,
261
+ y=y_col,
262
+ color=color,
263
+ trendline="ols",
264
+ labels=self.plotting_caption_mapping,
265
+ )
266
+ if not self.per_label:
267
+ fig["data"][1]["name"] = "prediction vs. gold"
268
+
269
+ # show legend only for trendlines
270
+ for idx, trace_data in enumerate(fig["data"]):
271
+ if idx % 2 == 0:
272
+ trace_data["showlegend"] = False
273
+ else:
274
+ trace_data["showlegend"] = True
275
+
276
+ # add the optimal line
277
+ minimum = calibration_data[x_col].min()
278
+ maximum = calibration_data[x_col].max()
279
+ fig.add_trace(
280
+ go.Scatter(
281
+ x=[minimum, maximum],
282
+ y=[minimum, maximum],
283
+ mode="lines",
284
+ name="optimal",
285
+ line=dict(color="black", dash="dash"),
286
+ )
287
+ )
288
+ fig.update_layout(
289
+ height=600,
290
+ width=800,
291
+ title_text=f"Mean Binned Scores for {self.mapped_layer}",
292
+ title_x=0.5,
293
+ font=dict(size=self.plotly_font_size),
294
+ )
295
+ fig.update_layout(
296
+ legend=dict(
297
+ yanchor="top",
298
+ y=0.99,
299
+ xanchor="left",
300
+ x=0.01,
301
+ title="OLS trendline" + ("s" if self.per_label else ""),
302
+ ),
303
+ )
304
+ if self.plotly_background_color is not None:
305
+ fig.update_layout(
306
+ plot_bgcolor=self.plotly_background_color,
307
+ paper_bgcolor=self.plotly_background_color,
308
+ )
309
+
310
+ if self.plotly_font_family is not None:
311
+ fig.update_layout(font_family=self.plotly_font_family)
312
+
313
+ fig.show()
314
+
315
+ def _compute(self) -> Dict[str, Dict[str, Any]]:
316
+ scores_combined = pd.concat(
317
+ {
318
+ label: self._combine_scores(scores["TP"], scores["FP"])
319
+ for label, scores in self.scores.items()
320
+ },
321
+ names=["label", "idx"],
322
+ ).reset_index(level=0)
323
+
324
+ result_df = scores_combined.groupby("label")["prediction"].agg(["mean", "std", "count"])
325
+ if self.show_plot:
326
+ self.plot_score_distribution(scores=scores_combined)
327
+
328
+ calibration_metrics = self.calculate_calibration_metrics(scores_combined)
329
+ calibration_metrics["correlation"] = self.calculate_correlation(scores_combined)
330
+
331
+ result_df = pd.concat(
332
+ {"prediction": result_df, "prediction vs. gold": calibration_metrics}, axis=1
333
+ )
334
+
335
+ if not self.per_label:
336
+ result = result_df.xs("ALL")
337
+ else:
338
+ result = result_df.T.stack().unstack()
339
+
340
+ result_dict = {
341
+ main_key: result.xs(main_key).T.to_dict()
342
+ for main_key in result.index.get_level_values(0).unique()
343
+ }
344
+
345
+ return result_dict
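For orientation, a minimal self-contained sketch of the binned calibration metrics assembled above. The toy numbers and bin layout are made up; the assumption is that abs_diff/squared_diff hold the per-bin gap between the mean predicted score and the observed positive rate, weighted by bin_weight.

import pandas as pd

# toy binned calibration data: mean predicted score, observed positive rate, bin weight
calibration_data = pd.DataFrame(
    {
        "avg_score": [0.1, 0.5, 0.9],
        "fraction_positive": [0.15, 0.45, 0.85],
        "bin_weight": [0.2, 0.5, 0.3],
    }
)
gap = calibration_data["avg_score"] - calibration_data["fraction_positive"]
calibration_data["abs_diff"] = gap.abs()
calibration_data["squared_diff"] = gap**2

ece = (calibration_data["abs_diff"] * calibration_data["bin_weight"]).sum()  # weighted mean gap
mce = calibration_data["abs_diff"].max()  # worst-case gap over bins
mse = (calibration_data["squared_diff"] * calibration_data["bin_weight"]).sum()
print(ece, mce, mse)  # 0.05 0.05 0.0025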
src/models/__init__.py CHANGED
@@ -1,3 +1,4 @@
 
1
  from .sequence_classification_with_pooler import (
2
  SequencePairSimilarityModelWithMaxCosineSim,
3
  SequencePairSimilarityModelWithPooler2,
 
1
+ from .sequence_classification import SimpleSequenceClassificationModelWithInputTypeIds
2
  from .sequence_classification_with_pooler import (
3
  SequencePairSimilarityModelWithMaxCosineSim,
4
  SequencePairSimilarityModelWithPooler2,
src/models/sequence_classification.py ADDED
@@ -0,0 +1,94 @@
1
+ from typing import Optional
2
+
3
+ from pie_modules.models import SimpleSequenceClassificationModel
4
+ from pie_modules.models.simple_sequence_classification import InputType, OutputType, TargetType
5
+ from pytorch_ie import PyTorchIEModel
6
+ from torch import nn
7
+ from transformers import BertModel
8
+ from transformers.utils import is_accelerate_available
9
+
10
+ if is_accelerate_available():
11
+ from accelerate.hooks import add_hook_to_module
12
+
13
+
14
+ @PyTorchIEModel.register()
15
+ class SimpleSequenceClassificationModelWithInputTypeIds(SimpleSequenceClassificationModel):
16
+
17
+ def __init__(
18
+ self, num_token_type_ids: int, use_as_token_type_ids: str = "token_type_ids", **kwargs
19
+ ):
20
+ super().__init__(**kwargs)
21
+ self.num_token_type_ids = num_token_type_ids
22
+ self.token_type_ids_key = use_as_token_type_ids
23
+ self.resize_type_embeddings(num_token_type_ids)
24
+
25
+ def get_input_type_embeddings(self) -> nn.Module:
26
+ base_model: BertModel = getattr(self.model, self.model.base_model_prefix)
27
+ if base_model is None:
28
+ raise ValueError("Model has no base model.")
29
+ return base_model.embeddings.token_type_embeddings
30
+
31
+ def set_input_type_embeddings(self, value):
32
+ base_model: BertModel = getattr(self.model, self.model.base_model_prefix)
33
+ if base_model is None:
34
+ raise ValueError("Model has no base model.")
35
+ base_model.embeddings.token_type_embeddings = value
36
+
37
+ def _resize_type_embeddings(self, new_num_tokens, pad_to_multiple_of=None):
38
+ old_embeddings = self.get_input_type_embeddings()
39
+ new_embeddings = self.model._get_resized_embeddings(
40
+ old_embeddings, new_num_tokens, pad_to_multiple_of
41
+ )
42
+ if hasattr(old_embeddings, "_hf_hook"):
43
+ hook = old_embeddings._hf_hook
44
+ add_hook_to_module(new_embeddings, hook)
45
+ old_embeddings_requires_grad = old_embeddings.weight.requires_grad
46
+ new_embeddings.requires_grad_(old_embeddings_requires_grad)
47
+ self.set_input_type_embeddings(new_embeddings)
48
+
49
+ return self.get_input_type_embeddings()
50
+
51
+ def resize_type_embeddings(
52
+ self, new_num_types: Optional[int] = None, pad_to_multiple_of: Optional[int] = None
53
+ ) -> nn.Embedding:
54
+ """
55
+ Same as resize_token_embeddings but for the token type embeddings.
56
+
57
+ Resizes input token type embeddings matrix of the model if `new_num_types != config.type_vocab_size`.
58
+
59
+ Takes care of tying weights embeddings afterwards if the model class has a `tie_weights()` method.
60
+
61
+ Arguments:
62
+ new_num_types (`int`, *optional*):
63
+ The number of new token types in the embedding matrix. Increasing the size will add newly initialized
64
+ vectors at the end. Reducing the size will remove vectors from the end. If not provided or `None`, just
65
+ returns a pointer to the input tokens `torch.nn.Embedding` module of the model without doing anything.
66
+ pad_to_multiple_of (`int`, *optional*):
67
+ If set will pad the embedding matrix to a multiple of the provided value.If `new_num_tokens` is set to
68
+ `None` will just pad the embedding to a multiple of `pad_to_multiple_of`.
69
+
70
+ This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability
71
+ `>= 7.5` (Volta), or on TPUs which benefit from having sequence lengths be a multiple of 128. For more
72
+ details about this, or help on choosing the correct value for resizing, refer to this guide:
73
+ https://docs.nvidia.com/deeplearning/performance/dl-performance-matrix-multiplication/index.html#requirements-tc
74
+
75
+ Return:
76
+ `torch.nn.Embedding`: Pointer to the input tokens Embeddings Module of the model.
77
+ """
78
+ model_embeds = self._resize_type_embeddings(new_num_types, pad_to_multiple_of)
79
+ if new_num_types is None and pad_to_multiple_of is None:
80
+ return model_embeds
81
+
82
+ # Update base model and current model config
83
+ self.model.config.type_vocab_size = model_embeds.weight.shape[0]
84
+
85
+ # Tie weights again if needed
86
+ self.model.tie_weights()
87
+
88
+ return model_embeds
89
+
90
+ def forward(self, inputs: InputType, targets: Optional[TargetType] = None) -> OutputType:
91
+ kwargs = {**inputs, **(targets or {})}
92
+ # rename the configured key to token_type_ids
93
+ kwargs["token_type_ids"] = kwargs.pop(self.token_type_ids_key)
94
+ return self.model(**kwargs)
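As a rough standalone illustration of what the type-embedding resize above amounts to (plain PyTorch, not the transformers helper used by the model): the pretrained rows are kept and the extra rows start freshly initialized. The embedding sizes below are hypothetical.

import torch
from torch import nn

old = nn.Embedding(2, 8)   # BERT-style default with type_vocab_size=2
new = nn.Embedding(6, 8)   # room for BIO-derived type ids plus padding (hypothetical size)
with torch.no_grad():
    new.weight[:2] = old.weight  # copy the pretrained rows; rows 2..5 keep their fresh init
print(new.weight.shape)  # torch.Size([6, 8])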
src/pipeline/ner_re_pipeline.py CHANGED
@@ -15,6 +15,7 @@ from typing import (
15
  overload,
16
  )
17
 
 
18
  from pie_modules.utils import resolve_type
19
  from pytorch_ie import AutoPipeline, WithDocumentTypeMixin
20
  from pytorch_ie.core import Document
@@ -53,31 +54,105 @@ def move_annotations_to_predictions(doc: D, layer_names: List[str]) -> None:
53
  doc[layer_name].predictions.extend(annotations)
54
 
55
 
56
  def add_annotations_from_other_documents(
57
  docs: Iterable[D],
58
  other_docs: Sequence[Document],
59
- layer_names: List[str],
60
- from_predictions: bool = False,
61
- to_predictions: bool = False,
62
- clear_before: bool = True,
63
- ) -> None:
64
- for i, doc in enumerate(docs):
65
- other_doc = other_docs[i]
66
- # copy to not modify the input
67
- other_doc = type(other_doc).fromdict(other_doc.asdict())
68
-
69
- for layer_name in layer_names:
70
- if clear_before:
71
- doc[layer_name].clear()
72
- other_layer = other_doc[layer_name]
73
- if from_predictions:
74
- other_layer = other_layer.predictions
75
- other_annotations = list(other_layer)
76
- other_layer.clear()
77
- if to_predictions:
78
- doc[layer_name].predictions.extend(other_annotations)
 
 
79
  else:
80
- doc[layer_name].extend(other_annotations)
81
 
82
 
83
  def process_pipeline_steps(
@@ -227,6 +302,9 @@ class NerRePipeline:
227
  "re_add_gold_data": partial(
228
  add_annotations_from_other_documents,
229
  other_docs=original_docs,
 
 
 
230
  layer_names=[self.entity_layer, self.relation_layer],
231
  **self.processor_kwargs.get("re_add_gold_data", {}),
232
  ),
 
15
  overload,
16
  )
17
 
18
+ from pie_datasets import Dataset
19
  from pie_modules.utils import resolve_type
20
  from pytorch_ie import AutoPipeline, WithDocumentTypeMixin
21
  from pytorch_ie.core import Document
 
54
  doc[layer_name].predictions.extend(annotations)
55
 
56
 
57
+ def _add_annotations_from_other_document(
58
+ doc: D,
59
+ from_predictions: bool,
60
+ to_predictions: bool,
61
+ clear_before: bool,
62
+ other_doc: Optional[D] = None,
63
+ other_docs_dict: Optional[Dict[str, D]] = None,
64
+ layer_names: Optional[List[str]] = None,
65
+ ) -> D:
66
+ if other_doc is None:
67
+ if other_docs_dict is None:
68
+ raise ValueError("Either other_doc or other_docs_dict must be provided")
69
+ other_doc = other_docs_dict.get(doc.id)
70
+ if other_doc is None:
71
+ logger.warning(f"Document with ID {doc.id} not found in other_docs")
72
+ return doc
73
+
74
+ # copy to not modify the input
75
+ other_doc_copy = type(other_doc).fromdict(other_doc.asdict())
76
+
77
+ if layer_names is None:
78
+ layer_names = [field.name for field in doc.annotation_fields()]
79
+
80
+ for layer_name in layer_names:
81
+ layer = doc[layer_name]
82
+ if to_predictions:
83
+ layer = layer.predictions
84
+ if clear_before:
85
+ layer.clear()
86
+ other_layer = other_doc_copy[layer_name]
87
+ if from_predictions:
88
+ other_layer = other_layer.predictions
89
+ other_annotations = list(other_layer)
90
+ other_layer.clear()
91
+ layer.extend(other_annotations)
92
+
93
+ return doc
94
+
95
+
96
  def add_annotations_from_other_documents(
97
  docs: Iterable[D],
98
  other_docs: Sequence[Document],
99
+ get_other_doc_by_id: bool = False,
100
+ **kwargs,
101
+ ) -> Sequence[D]:
102
+ other_id2doc = None
103
+ if get_other_doc_by_id:
104
+ other_id2doc = {doc.id: doc for doc in other_docs}
105
+
106
+ if isinstance(docs, Dataset):
107
+ if other_id2doc is None:
108
+ raise ValueError("get_other_doc_by_id must be True when passing a Dataset")
109
+ result = docs.map(
110
+ _add_annotations_from_other_document,
111
+ fn_kwargs=dict(other_docs_dict=other_id2doc, **kwargs),
112
+ )
113
+ elif isinstance(docs, list):
114
+ result = []
115
+ for i, doc in enumerate(docs):
116
+ if other_id2doc is not None:
117
+ other_doc = other_id2doc.get(doc.id)
118
+ if other_doc is None:
119
+ logger.warning(f"Document with ID {doc.id} not found in other_docs")
120
+ continue
121
  else:
122
+ other_doc = other_docs[i]
123
+
124
+ # check if the IDs of the documents match
125
+ doc_id = getattr(doc, "id", None)
126
+ other_doc_id = getattr(other_doc, "id", None)
127
+ if doc_id is not None and doc_id != other_doc_id:
128
+ raise ValueError(
129
+ f"IDs of the documents do not match: {doc_id} != {other_doc_id}"
130
+ )
131
+
132
+ current_result = _add_annotations_from_other_document(
133
+ doc, other_doc=other_doc, **kwargs
134
+ )
135
+ result.append(current_result)
136
+ else:
137
+ raise ValueError(f"Unsupported type: {type(docs)}")
138
+
139
+ return result
140
+
141
+
142
+ DM = TypeVar("DM", bound=Dict[str, Iterable[Document]])
143
+
144
+
145
+ def add_annotations_from_other_documents_dict(
146
+ docs: DM, other_docs: Dict[str, Sequence[Document]], **kwargs
147
+ ) -> DM:
148
+ if set(docs.keys()) != set(other_docs.keys()):
149
+ raise ValueError("Keys of the documents do not match")
150
+
151
+ result_dict = {
152
+ key: add_annotations_from_other_documents(doc_list, other_docs[key], **kwargs)
153
+ for key, doc_list in docs.items()
154
+ }
155
+ return type(docs)(result_dict)
156
 
157
 
158
  def process_pipeline_steps(
 
302
  "re_add_gold_data": partial(
303
  add_annotations_from_other_documents,
304
  other_docs=original_docs,
305
+ from_predictions=False,
306
+ to_predictions=False,
307
+ clear_before=False,
308
  layer_names=[self.entity_layer, self.relation_layer],
309
  **self.processor_kwargs.get("re_add_gold_data", {}),
310
  ),
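A toy sketch of the ID-based alignment enabled by get_other_doc_by_id, with hypothetical stand-in documents that only carry an id: documents are matched by ID instead of by position, so reordered collections still line up.

from dataclasses import dataclass

@dataclass
class Doc:
    id: str

docs = [Doc("a"), Doc("b")]
other_docs = [Doc("b"), Doc("a")]  # same documents, different order
other_id2doc = {doc.id: doc for doc in other_docs}
for doc in docs:
    other_doc = other_id2doc.get(doc.id)
    # positional pairing would mismatch here; ID lookup keeps the pairs consistent
    assert other_doc is not None and other_doc.id == doc.id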
src/predict.py CHANGED
@@ -34,14 +34,13 @@ root = pyrootutils.setup_root(
34
  # ------------------------------------------------------------------------------------ #
35
 
36
  import os
37
- import timeit
38
  from collections.abc import Iterable, Sequence
39
  from typing import Any, Dict, Optional, Tuple, Union
40
 
41
  import hydra
42
  import pytorch_lightning as pl
43
  from omegaconf import DictConfig, OmegaConf
44
- from pie_datasets import Dataset, DatasetDict
45
  from pie_modules.models import * # noqa: F403
46
  from pie_modules.taskmodules import * # noqa: F403
47
  from pytorch_ie import Document, Pipeline
@@ -132,38 +131,13 @@ def predict(cfg: DictConfig) -> Tuple[dict, dict]:
132
  "pipeline": pipeline,
133
  "serializer": serializer,
134
  }
135
- result: Dict[str, Any] = {}
136
- if pipeline is not None:
137
- log.info("Starting inference!")
138
- prediction_time = 0.0
139
- else:
140
- log.warning("No prediction pipeline is defined, skip inference!")
141
- prediction_time = None
142
- document_batch_size = cfg.get("document_batch_size", None)
143
- for docs_batch in (
144
- document_batch_iter(dataset_predict, document_batch_size)
145
- if document_batch_size
146
- else [dataset_predict]
147
- ):
148
- if pipeline is not None:
149
- t_start = timeit.default_timer()
150
- docs_batch = pipeline(docs_batch, inplace=False)
151
- prediction_time += timeit.default_timer() - t_start # type: ignore
152
-
153
- # serialize the documents
154
- if serializer is not None:
155
- # the serializer should not return the serialized documents, but write them to disk
156
- # and instead return some metadata such as the path to the serialized documents
157
- serializer_result = serializer(docs_batch)
158
- if "serializer" in result and result["serializer"] != serializer_result:
159
- log.warning(
160
- f"serializer result changed from {result['serializer']} to {serializer_result}"
161
- " during prediction. Only the last result is returned."
162
- )
163
- result["serializer"] = serializer_result
164
-
165
- if prediction_time is not None:
166
- result["prediction_time"] = prediction_time
167
 
168
  # serialize config with resolved paths
169
  if cfg.get("config_out_path"):
 
34
  # ------------------------------------------------------------------------------------ #
35
 
36
  import os
 
37
  from collections.abc import Iterable, Sequence
38
  from typing import Any, Dict, Optional, Tuple, Union
39
 
40
  import hydra
41
  import pytorch_lightning as pl
42
  from omegaconf import DictConfig, OmegaConf
43
+ from pie_datasets import DatasetDict
44
  from pie_modules.models import * # noqa: F403
45
  from pie_modules.taskmodules import * # noqa: F403
46
  from pytorch_ie import Document, Pipeline
 
131
  "pipeline": pipeline,
132
  "serializer": serializer,
133
  }
134
+ # predict and serialize
135
+ result: Dict[str, Any] = utils.predict_and_serialize(
136
+ pipeline=pipeline,
137
+ serializer=serializer,
138
+ dataset=dataset_predict,
139
+ document_batch_size=cfg.get("document_batch_size", None),
140
+ )
141
 
142
  # serialize config with resolved paths
143
  if cfg.get("config_out_path"):
src/serializer/interface.py CHANGED
@@ -1,5 +1,5 @@
1
  from abc import ABC, abstractmethod
2
- from typing import Any, Sequence
3
 
4
  from pytorch_ie.core import Document
5
 
@@ -12,5 +12,5 @@ class DocumentSerializer(ABC):
12
  """
13
 
14
  @abstractmethod
15
- def __call__(self, documents: Sequence[Document]) -> Any:
16
  pass
 
1
  from abc import ABC, abstractmethod
2
+ from typing import Any, Iterable
3
 
4
  from pytorch_ie.core import Document
5
 
 
12
  """
13
 
14
  @abstractmethod
15
+ def __call__(self, documents: Iterable[Document]) -> Any:
16
  pass
src/serializer/json.py CHANGED
@@ -1,6 +1,6 @@
1
  import json
2
  import os
3
- from typing import Dict, List, Optional, Sequence, Type, TypeVar
4
 
5
  from pie_datasets import Dataset, DatasetDict, IterableDataset
6
  from pie_datasets.core.dataset_dict import METADATA_FILE_NAME
@@ -8,7 +8,7 @@ from pytorch_ie.core import Document
8
  from pytorch_ie.utils.hydra import resolve_optional_document_type, serialize_document_type
9
 
10
  from src.serializer.interface import DocumentSerializer
11
- from src.utils import get_pylogger
12
 
13
  log = get_pylogger(__name__)
14
 
@@ -31,7 +31,7 @@ class JsonSerializer(DocumentSerializer):
31
  @classmethod
32
  def write(
33
  cls,
34
- documents: Sequence[Document],
35
  path: str,
36
  file_name: str = "documents.jsonl",
37
  metadata_file_name: str = METADATA_FILE_NAME,
@@ -42,6 +42,9 @@ class JsonSerializer(DocumentSerializer):
42
  log.info(f'serialize documents to "{realpath}" ...')
43
  os.makedirs(realpath, exist_ok=True)
44
 
 
 
 
45
  # dump metadata including the document_type
46
  if len(documents) == 0:
47
  raise Exception("cannot serialize empty list of documents")
@@ -130,7 +133,7 @@ class JsonSerializer(DocumentSerializer):
130
  all_kwargs = {**self.default_kwargs, **kwargs}
131
  return self.write(**all_kwargs)
132
 
133
- def __call__(self, documents: Sequence[Document], **kwargs) -> Dict[str, str]:
134
  return self.write_with_defaults(documents=documents, **kwargs)
135
 
136
 
@@ -141,12 +144,15 @@ class JsonSerializer2(DocumentSerializer):
141
  @classmethod
142
  def write(
143
  cls,
144
- documents: Sequence[Document],
145
  path: str,
146
  split: str = "train",
147
  ) -> Dict[str, str]:
148
  if not isinstance(documents, (Dataset, IterableDataset)):
149
- documents = Dataset.from_documents(documents)
 
 
 
150
  dataset_dict = DatasetDict({split: documents})
151
  dataset_dict.to_json(path=path)
152
  return {"path": path, "split": split}
@@ -175,5 +181,5 @@ class JsonSerializer2(DocumentSerializer):
175
  all_kwargs = {**self.default_kwargs, **kwargs}
176
  return self.write(**all_kwargs)
177
 
178
- def __call__(self, documents: Sequence[Document], **kwargs) -> Dict[str, str]:
179
  return self.write_with_defaults(documents=documents, **kwargs)
 
1
  import json
2
  import os
3
+ from typing import Dict, Iterable, List, Optional, Sequence, Type, TypeVar
4
 
5
  from pie_datasets import Dataset, DatasetDict, IterableDataset
6
  from pie_datasets.core.dataset_dict import METADATA_FILE_NAME
 
8
  from pytorch_ie.utils.hydra import resolve_optional_document_type, serialize_document_type
9
 
10
  from src.serializer.interface import DocumentSerializer
11
+ from src.utils.logging_utils import get_pylogger
12
 
13
  log = get_pylogger(__name__)
14
 
 
31
  @classmethod
32
  def write(
33
  cls,
34
+ documents: Iterable[Document],
35
  path: str,
36
  file_name: str = "documents.jsonl",
37
  metadata_file_name: str = METADATA_FILE_NAME,
 
42
  log.info(f'serialize documents to "{realpath}" ...')
43
  os.makedirs(realpath, exist_ok=True)
44
 
45
+ if not isinstance(documents, Sequence):
46
+ documents = list(documents)
47
+
48
  # dump metadata including the document_type
49
  if len(documents) == 0:
50
  raise Exception("cannot serialize empty list of documents")
 
133
  all_kwargs = {**self.default_kwargs, **kwargs}
134
  return self.write(**all_kwargs)
135
 
136
+ def __call__(self, documents: Iterable[Document], **kwargs) -> Dict[str, str]:
137
  return self.write_with_defaults(documents=documents, **kwargs)
138
 
139
 
 
144
  @classmethod
145
  def write(
146
  cls,
147
+ documents: Iterable[Document],
148
  path: str,
149
  split: str = "train",
150
  ) -> Dict[str, str]:
151
  if not isinstance(documents, (Dataset, IterableDataset)):
152
+ if not isinstance(documents, Sequence):
153
+ documents = IterableDataset.from_documents(documents)
154
+ else:
155
+ documents = Dataset.from_documents(documents)
156
  dataset_dict = DatasetDict({split: documents})
157
  dataset_dict.to_json(path=path)
158
  return {"path": path, "split": split}
 
181
  all_kwargs = {**self.default_kwargs, **kwargs}
182
  return self.write(**all_kwargs)
183
 
184
+ def __call__(self, documents: Iterable[Document], **kwargs) -> Dict[str, str]:
185
  return self.write_with_defaults(documents=documents, **kwargs)
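The Sequence check above is what lets the serializers accept generators as well as lists; a minimal sketch of that materialization step, detached from the PIE dataset classes:

from collections.abc import Sequence

def ensure_sequence(documents):
    # generators and other one-shot iterables are collected into a list,
    # sequences (e.g. lists) pass through unchanged
    if not isinstance(documents, Sequence):
        return list(documents)
    return documents

print(ensure_sequence(doc for doc in ["doc1", "doc2"]))  # ['doc1', 'doc2']
print(ensure_sequence(["doc1", "doc2"]))                 # ['doc1', 'doc2']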
src/start_demo.py CHANGED
@@ -99,6 +99,7 @@ def main(cfg: DictConfig) -> None:
99
  render_caption2mode = {v: k for k, v in render_mode2caption.items()}
100
  default_min_similarity = cfg["default_min_similarity"]
101
  default_top_k = cfg["default_top_k"]
 
102
  layer_caption_mapping = cfg["layer_caption_mapping"]
103
  relation_name_mapping = cfg["relation_name_mapping"]
104
 
@@ -287,6 +288,13 @@ def main(cfg: DictConfig) -> None:
287
  step=1,
288
  value=default_top_k,
289
  )
 
 
 
 
 
 
 
290
  retrieve_similar_adus_btn = gr.Button(
291
  "Retrieve *similar* ADUs for *selected* ADU"
292
  )
@@ -361,18 +369,23 @@ def main(cfg: DictConfig) -> None:
361
  load_pie_dataset_btn = gr.Button("Load & Embed PIE Dataset")
362
 
363
  render_event_kwargs = dict(
364
- fn=lambda _retriever, _document_id, _render_as, _render_kwargs, _all_relevant_adus_df, _all_relevant_adus_query_doc_id: render_annotated_document(
365
- retriever=_retriever[0],
366
- document_id=_document_id,
367
- render_with=render_caption2mode[_render_as],
368
- render_kwargs_json=_render_kwargs,
369
- highlight_span_ids=(
370
- _all_relevant_adus_df["query_span_id"].tolist()
371
- if _document_id == _all_relevant_adus_query_doc_id
372
- else None
373
- ),
 
 
 
 
374
  ),
375
  inputs=[
 
376
  retriever_state,
377
  selected_document_id,
378
  render_as,
@@ -583,10 +596,11 @@ def main(cfg: DictConfig) -> None:
583
  ).success(**show_stats_kwargs)
584
 
585
  retrieve_relevant_adus_event_kwargs = dict(
586
- fn=lambda _retriever, _selected_adu_id, _min_similarity, _top_k: retrieve_relevant_spans(
587
  retriever=_retriever[0],
588
  query_span_id=_selected_adu_id,
589
  k=_top_k,
 
590
  score_threshold=_min_similarity,
591
  relation_label_mapping=relation_name_mapping,
592
  # columns=relevant_adus.headers
@@ -596,6 +610,7 @@ def main(cfg: DictConfig) -> None:
596
  selected_adu_id,
597
  min_similarity,
598
  top_k,
 
599
  ],
600
  outputs=[relevant_adus_df],
601
  )
@@ -614,10 +629,11 @@ def main(cfg: DictConfig) -> None:
614
  ).success(**retrieve_relevant_adus_event_kwargs)
615
 
616
  retrieve_similar_adus_btn.click(
617
- fn=lambda _retriever, _selected_adu_id, _min_similarity, _tok_k: retrieve_similar_spans(
618
  retriever=_retriever[0],
619
  query_span_id=_selected_adu_id,
620
  k=_tok_k,
 
621
  score_threshold=_min_similarity,
622
  ),
623
  inputs=[
@@ -625,6 +641,7 @@ def main(cfg: DictConfig) -> None:
625
  selected_adu_id,
626
  min_similarity,
627
  top_k,
 
628
  ],
629
  outputs=[similar_adus_df],
630
  )
@@ -635,10 +652,11 @@ def main(cfg: DictConfig) -> None:
635
  )
636
 
637
  retrieve_all_similar_adus_btn.click(
638
- fn=lambda _retriever, _document_id, _min_similarity, _tok_k: retrieve_all_similar_spans(
639
  retriever=_retriever[0],
640
  query_doc_id=_document_id,
641
  k=_tok_k,
 
642
  score_threshold=_min_similarity,
643
  query_span_id_column="query_span_id",
644
  ),
@@ -647,16 +665,18 @@ def main(cfg: DictConfig) -> None:
647
  selected_document_id,
648
  min_similarity,
649
  top_k,
 
650
  ],
651
  outputs=[all_similar_adus_df],
652
  )
653
 
654
  retrieve_all_relevant_adus_btn.click(
655
- fn=lambda _retriever, _document_id, _min_similarity, _tok_k: (
656
  retrieve_all_relevant_spans(
657
  retriever=_retriever[0],
658
  query_doc_id=_document_id,
659
  k=_tok_k,
 
660
  score_threshold=_min_similarity,
661
  query_span_id_column="query_span_id",
662
  query_span_text_column="query_span_text",
@@ -668,6 +688,7 @@ def main(cfg: DictConfig) -> None:
668
  selected_document_id,
669
  min_similarity,
670
  top_k,
 
671
  ],
672
  outputs=[all_relevant_adus_df, all_relevant_adus_query_doc_id],
673
  )
 
99
  render_caption2mode = {v: k for k, v in render_mode2caption.items()}
100
  default_min_similarity = cfg["default_min_similarity"]
101
  default_top_k = cfg["default_top_k"]
102
+ default_min_score = cfg["default_min_score"]
103
  layer_caption_mapping = cfg["layer_caption_mapping"]
104
  relation_name_mapping = cfg["relation_name_mapping"]
105
 
 
288
  step=1,
289
  value=default_top_k,
290
  )
291
+ min_score = gr.Slider(
292
+ label="Minimum Score",
293
+ minimum=0.0,
294
+ maximum=1.0,
295
+ step=0.01,
296
+ value=default_min_score,
297
+ )
298
  retrieve_similar_adus_btn = gr.Button(
299
  "Retrieve *similar* ADUs for *selected* ADU"
300
  )
 
369
  load_pie_dataset_btn = gr.Button("Load & Embed PIE Dataset")
370
 
371
  render_event_kwargs = dict(
372
+ fn=lambda _rendered_output, _retriever, _document_id, _render_as, _render_kwargs, _all_relevant_adus_df, _all_relevant_adus_query_doc_id: (
373
+ render_annotated_document(
374
+ retriever=_retriever[0],
375
+ document_id=_document_id,
376
+ render_with=render_caption2mode[_render_as],
377
+ render_kwargs_json=_render_kwargs,
378
+ highlight_span_ids=(
379
+ _all_relevant_adus_df["query_span_id"].tolist()
380
+ if _document_id == _all_relevant_adus_query_doc_id
381
+ else None
382
+ ),
383
+ )
384
+ if _document_id.strip() != ""
385
+ else _rendered_output
386
  ),
387
  inputs=[
388
+ rendered_output,
389
  retriever_state,
390
  selected_document_id,
391
  render_as,
 
596
  ).success(**show_stats_kwargs)
597
 
598
  retrieve_relevant_adus_event_kwargs = dict(
599
+ fn=lambda _retriever, _selected_adu_id, _min_similarity, _top_k, _min_score: retrieve_relevant_spans(
600
  retriever=_retriever[0],
601
  query_span_id=_selected_adu_id,
602
  k=_top_k,
603
+ min_score=_min_score,
604
  score_threshold=_min_similarity,
605
  relation_label_mapping=relation_name_mapping,
606
  # columns=relevant_adus.headers
 
610
  selected_adu_id,
611
  min_similarity,
612
  top_k,
613
+ min_score,
614
  ],
615
  outputs=[relevant_adus_df],
616
  )
 
629
  ).success(**retrieve_relevant_adus_event_kwargs)
630
 
631
  retrieve_similar_adus_btn.click(
632
+ fn=lambda _retriever, _selected_adu_id, _min_similarity, _tok_k, _min_score: retrieve_similar_spans(
633
  retriever=_retriever[0],
634
  query_span_id=_selected_adu_id,
635
  k=_tok_k,
636
+ min_score=_min_score,
637
  score_threshold=_min_similarity,
638
  ),
639
  inputs=[
 
641
  selected_adu_id,
642
  min_similarity,
643
  top_k,
644
+ min_score,
645
  ],
646
  outputs=[similar_adus_df],
647
  )
 
652
  )
653
 
654
  retrieve_all_similar_adus_btn.click(
655
+ fn=lambda _retriever, _document_id, _min_similarity, _tok_k, _min_score: retrieve_all_similar_spans(
656
  retriever=_retriever[0],
657
  query_doc_id=_document_id,
658
  k=_tok_k,
659
+ min_score=_min_score,
660
  score_threshold=_min_similarity,
661
  query_span_id_column="query_span_id",
662
  ),
 
665
  selected_document_id,
666
  min_similarity,
667
  top_k,
668
+ min_score,
669
  ],
670
  outputs=[all_similar_adus_df],
671
  )
672
 
673
  retrieve_all_relevant_adus_btn.click(
674
+ fn=lambda _retriever, _document_id, _min_similarity, _tok_k, _min_score: (
675
  retrieve_all_relevant_spans(
676
  retriever=_retriever[0],
677
  query_doc_id=_document_id,
678
  k=_tok_k,
679
+ min_score=_min_score,
680
  score_threshold=_min_similarity,
681
  query_span_id_column="query_span_id",
682
  query_span_text_column="query_span_text",
 
688
  selected_document_id,
689
  min_similarity,
690
  top_k,
691
+ min_score,
692
  ],
693
  outputs=[all_relevant_adus_df, all_relevant_adus_query_doc_id],
694
  )
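The render callback above falls back to the previously rendered output when no document is selected; a small stand-in sketch of that guard (the render function here is hypothetical, not render_annotated_document itself):

def render_or_keep(rendered_output: str, document_id: str) -> str:
    # keep the previous rendering when the document id is empty or whitespace
    if document_id.strip() == "":
        return rendered_output
    return f"<rendered {document_id}>"  # placeholder for the actual rendering call

print(render_or_keep("<previous>", "   "))    # <previous>
print(render_or_keep("<previous>", "doc-1"))  # <rendered doc-1>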
src/taskmodules/cross_text_binary_coref_nli.py CHANGED
@@ -62,6 +62,9 @@ class CrossTextBinaryCorefTaskModuleByNli(RelationStatisticsMixin, TaskModuleTyp
62
  tokenizer_name_or_path: str,
63
  labels: List[str],
64
  entailment_label: str,
 
 
 
65
  **kwargs,
66
  ) -> None:
67
  super().__init__(**kwargs)
@@ -69,6 +72,9 @@ class CrossTextBinaryCorefTaskModuleByNli(RelationStatisticsMixin, TaskModuleTyp
69
 
70
  self.labels = labels
71
  self.entailment_label = entailment_label
 
 
 
72
  self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name_or_path)
73
 
74
  def _post_prepare(self):
@@ -118,9 +124,18 @@ class CrossTextBinaryCorefTaskModuleByNli(RelationStatisticsMixin, TaskModuleTyp
118
  for task_encoding in task_encodings:
119
  all_texts.extend(task_encoding.inputs["text"])
120
  all_texts_pair.extend(task_encoding.inputs["text_pair"])
 
 
 
 
 
 
 
 
 
121
  inputs = self.tokenizer(
122
- text=all_texts,
123
- text_pair=all_texts_pair,
124
  truncation=True,
125
  padding=True,
126
  return_tensors="pt",
@@ -159,8 +174,20 @@ class CrossTextBinaryCorefTaskModuleByNli(RelationStatisticsMixin, TaskModuleTyp
159
  task_encoding: TaskEncoding[DocumentType, InputEncodingType, TargetEncodingType],
160
  task_output: TaskOutputType,
161
  ) -> Iterator[Tuple[str, Annotation]]:
162
- if all(label == self.entailment_label for label in task_output["label_pair"]):
 
 
 
163
  probs = task_output["entailment_probability_pair"]
164
- score = (probs[0] + probs[1]) / 2
 
 
 
 
 
 
 
 
 
165
  new_coref_rel = task_encoding.metadata["candidate_annotation"].copy(score=score)
166
  yield "binary_coref_relations", new_coref_rel
 
62
  tokenizer_name_or_path: str,
63
  labels: List[str],
64
  entailment_label: str,
65
+ combine_score_method: str = "average",
66
+ keep_all_relations: bool = False,
67
+ as_text_pair: bool = True,
68
  **kwargs,
69
  ) -> None:
70
  super().__init__(**kwargs)
 
72
 
73
  self.labels = labels
74
  self.entailment_label = entailment_label
75
+ self.combine_score_method = combine_score_method
76
+ self.keep_all_relations = keep_all_relations
77
+ self.as_text_pair = as_text_pair
78
  self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name_or_path)
79
 
80
  def _post_prepare(self):
 
124
  for task_encoding in task_encodings:
125
  all_texts.extend(task_encoding.inputs["text"])
126
  all_texts_pair.extend(task_encoding.inputs["text_pair"])
127
+ if self.as_text_pair:
128
+ text = all_texts
129
+ text_pair = all_texts_pair
130
+ else:
131
+ text = [
132
+ f"{text}{self.tokenizer.sep_token}{text_pair}"
133
+ for text, text_pair in zip(all_texts, all_texts_pair)
134
+ ]
135
+ text_pair = None
136
  inputs = self.tokenizer(
137
+ text=text,
138
+ text_pair=text_pair,
139
  truncation=True,
140
  padding=True,
141
  return_tensors="pt",
 
174
  task_encoding: TaskEncoding[DocumentType, InputEncodingType, TargetEncodingType],
175
  task_output: TaskOutputType,
176
  ) -> Iterator[Tuple[str, Annotation]]:
177
+ if (
178
+ all(label == self.entailment_label for label in task_output["label_pair"])
179
+ or self.keep_all_relations
180
+ ):
181
  probs = task_output["entailment_probability_pair"]
182
+ if self.combine_score_method == "average":
183
+ score = (probs[0] + probs[1]) / 2
184
+ elif self.combine_score_method == "min":
185
+ score = min(probs)
186
+ elif self.combine_score_method == "max":
187
+ score = max(probs)
188
+ elif self.combine_score_method == "product":
189
+ score = probs[0] * probs[1]
190
+ else:
191
+ raise ValueError(f"Unsupported combine_score_method: {self.combine_score_method}")
192
  new_coref_rel = task_encoding.metadata["candidate_annotation"].copy(score=score)
193
  yield "binary_coref_relations", new_coref_rel
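The new combine_score_method options reduce the two directional entailment probabilities to a single relation score; a standalone sketch of that combination logic:

def combine_scores(probs, method: str = "average") -> float:
    # probs holds the entailment probability for each direction of the candidate pair
    if method == "average":
        return (probs[0] + probs[1]) / 2
    if method == "min":
        return min(probs)
    if method == "max":
        return max(probs)
    if method == "product":
        return probs[0] * probs[1]
    raise ValueError(f"Unsupported combine_score_method: {method}")

print(round(combine_scores((0.9, 0.7), "average"), 3))  # 0.8
print(round(combine_scores((0.9, 0.7), "product"), 3))  # 0.63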
src/train.py CHANGED
@@ -38,13 +38,14 @@ from typing import Any, Dict, List, Optional, Tuple
38
 
39
  import hydra
40
  import pytorch_lightning as pl
41
- from omegaconf import DictConfig
42
  from pie_datasets import DatasetDict
43
  from pie_modules.models import * # noqa: F403
44
  from pie_modules.models import SimpleGenerativeModel
45
  from pie_modules.models.interface import RequiresTaskmoduleConfig
46
  from pie_modules.taskmodules import * # noqa: F403
47
  from pie_modules.taskmodules import PointerNetworkTaskModuleForEnd2EndRE
 
48
  from pytorch_ie.core import PyTorchIEModel, TaskModule
49
  from pytorch_ie.models import * # noqa: F403
50
  from pytorch_ie.models.interface import RequiresModelNameOrPath, RequiresNumClasses
@@ -56,6 +57,7 @@ from pytorch_lightning.loggers import Logger
56
  from src import utils
57
  from src.datamodules import PieDataModule
58
  from src.models import * # noqa: F403
 
59
  from src.taskmodules import * # noqa: F403
60
 
61
  log = utils.get_pylogger(__name__)
@@ -81,6 +83,27 @@ def get_metric_value(metric_dict: dict, metric_name: str) -> Optional[float]:
81
  return metric_value
82
 
83
 
84
  @utils.task_wrapper
85
  def train(cfg: DictConfig) -> Tuple[dict, dict]:
86
  """Trains the model. Can additionally evaluate on a testset, using best weights obtained during
@@ -179,6 +202,11 @@ def train(cfg: DictConfig) -> Tuple[dict, dict]:
179
  )
180
  additional_model_kwargs["base_model_config"] = base_model_config
181
 
 
 
 
 
 
182
  # initialize the model
183
  model: PyTorchIEModel = hydra.utils.instantiate(
184
  cfg.model, _convert_="partial", **additional_model_kwargs
@@ -207,9 +235,11 @@ def train(cfg: DictConfig) -> Tuple[dict, dict]:
207
  log.info("Logging hyperparameters!")
208
  utils.log_hyperparameters(logger=logger, model=model, taskmodule=taskmodule, config=cfg)
209
 
210
- if cfg.model_save_dir is not None:
211
- log.info(f"Save taskmodule to {cfg.model_save_dir} [push_to_hub={cfg.push_to_hub}]")
212
- taskmodule.save_pretrained(save_directory=cfg.model_save_dir, push_to_hub=cfg.push_to_hub)
 
 
213
  else:
214
  log.warning("the taskmodule is not saved because no save_dir is specified")
215
 
@@ -238,15 +268,17 @@ def train(cfg: DictConfig) -> Tuple[dict, dict]:
238
  f"Expected format: " + '"epoch_{best_epoch}.ckpt"'
239
  )
240
 
241
- if not cfg.trainer.get("fast_dev_run"):
242
- if cfg.model_save_dir is not None:
243
  if best_ckpt_path == "":
244
  log.warning("Best ckpt not found! Using current weights for saving...")
245
  else:
246
  model = type(model).load_from_checkpoint(best_ckpt_path)
247
 
248
- log.info(f"Save model to {cfg.model_save_dir} [push_to_hub={cfg.push_to_hub}]")
249
- model.save_pretrained(save_directory=cfg.model_save_dir, push_to_hub=cfg.push_to_hub)
 
 
250
  else:
251
  log.warning("the model is not saved because no save_dir is specified")
252
 
@@ -275,8 +307,36 @@ def train(cfg: DictConfig) -> Tuple[dict, dict]:
275
 
276
  # add model_save_dir to the result so that it gets dumped to job_return_value.json
277
  # if we use hydra_callbacks.SaveJobReturnValueCallback
278
- if cfg.get("model_save_dir") is not None:
279
- metric_dict["model_save_dir"] = cfg.model_save_dir
280
 
281
  return metric_dict, object_dict
282
 
@@ -301,4 +361,5 @@ def main(cfg: DictConfig) -> Optional[float]:
301
  if __name__ == "__main__":
302
  utils.replace_sys_args_with_values_from_files()
303
  utils.prepare_omegaconf()
 
304
  main()
 
38
 
39
  import hydra
40
  import pytorch_lightning as pl
41
+ from omegaconf import DictConfig, OmegaConf
42
  from pie_datasets import DatasetDict
43
  from pie_modules.models import * # noqa: F403
44
  from pie_modules.models import SimpleGenerativeModel
45
  from pie_modules.models.interface import RequiresTaskmoduleConfig
46
  from pie_modules.taskmodules import * # noqa: F403
47
  from pie_modules.taskmodules import PointerNetworkTaskModuleForEnd2EndRE
48
+ from pytorch_ie import Pipeline
49
  from pytorch_ie.core import PyTorchIEModel, TaskModule
50
  from pytorch_ie.models import * # noqa: F403
51
  from pytorch_ie.models.interface import RequiresModelNameOrPath, RequiresNumClasses
 
57
  from src import utils
58
  from src.datamodules import PieDataModule
59
  from src.models import * # noqa: F403
60
+ from src.serializer.interface import DocumentSerializer
61
  from src.taskmodules import * # noqa: F403
62
 
63
  log = utils.get_pylogger(__name__)
 
83
  return metric_value
84
 
85
 
86
+ def flatten_nested_dict(d: Dict[str, Any], parent_key: str = "", sep: str = ".") -> Dict[str, Any]:
87
+ """Flatten a nested dictionary.
88
+
89
+ Args:
90
+ d (Dict[str, Any]): The dictionary to flatten.
91
+ parent_key (str): The parent key.
92
+ sep (str): The separator.
93
+
94
+ Returns:
95
+ Dict[str, Any]: The flattened dictionary.
96
+ """
97
+ items: List[Tuple[str, Any]] = []
98
+ for k, v in d.items():
99
+ new_key = f"{parent_key}{sep}{k}" if parent_key else k
100
+ if isinstance(v, dict):
101
+ items.extend(flatten_nested_dict(v, new_key, sep=sep).items())
102
+ else:
103
+ items.append((new_key, v))
104
+ return dict(items)
105
+
106
+
107
  @utils.task_wrapper
108
  def train(cfg: DictConfig) -> Tuple[dict, dict]:
109
  """Trains the model. Can additionally evaluate on a testset, using best weights obtained during
 
202
  )
203
  additional_model_kwargs["base_model_config"] = base_model_config
204
 
205
+ if issubclass(model_cls, SimpleSequenceClassificationModelWithInputTypeIds): # noqa: F405
206
+ # add the number of input type ids to the model:
207
+ # 2 for B- and I-labels for each entity type, 1 for O labels, 1 for padding
208
+ additional_model_kwargs["num_token_type_ids"] = len(taskmodule.entity_labels) * 2 + 1 + 1
209
+
210
  # initialize the model
211
  model: PyTorchIEModel = hydra.utils.instantiate(
212
  cfg.model, _convert_="partial", **additional_model_kwargs
 
235
  log.info("Logging hyperparameters!")
236
  utils.log_hyperparameters(logger=logger, model=model, taskmodule=taskmodule, config=cfg)
237
 
238
+ if cfg.paths.model_save_dir is not None:
239
+ log.info(f"Save taskmodule to {cfg.paths.model_save_dir} [push_to_hub={cfg.push_to_hub}]")
240
+ taskmodule.save_pretrained(
241
+ save_directory=cfg.paths.model_save_dir, push_to_hub=cfg.push_to_hub
242
+ )
243
  else:
244
  log.warning("the taskmodule is not saved because no save_dir is specified")
245
 
 
268
  f"Expected format: " + '"epoch_{best_epoch}.ckpt"'
269
  )
270
 
271
+ if not cfg.trainer.get("fast_dev_run") or cfg.get("predict", False):
272
+ if cfg.paths.model_save_dir is not None:
273
  if best_ckpt_path == "":
274
  log.warning("Best ckpt not found! Using current weights for saving...")
275
  else:
276
  model = type(model).load_from_checkpoint(best_ckpt_path)
277
 
278
+ log.info(f"Save model to {cfg.paths.model_save_dir} [push_to_hub={cfg.push_to_hub}]")
279
+ model.save_pretrained(
280
+ save_directory=cfg.paths.model_save_dir, push_to_hub=cfg.push_to_hub
281
+ )
282
  else:
283
  log.warning("the model is not saved because no save_dir is specified")
284
 
 
307
 
308
  # add model_save_dir to the result so that it gets dumped to job_return_value.json
309
  # if we use hydra_callbacks.SaveJobReturnValueCallback
310
+ if cfg.paths.get("model_save_dir") is not None:
311
+ metric_dict["model_save_dir"] = cfg.paths.model_save_dir
312
+
313
+ if cfg.get("predict"):
314
+ # Init the inference pipeline
315
+ pipeline: Optional[Pipeline] = None
316
+ if cfg.get("pipeline") and cfg.pipeline.get("_target_"):
317
+ log.info(f"Instantiating inference pipeline <{cfg.pipeline._target_}>")
318
+ pipeline = hydra.utils.instantiate(cfg.pipeline, _convert_="partial")
319
+ # Init the serializer
320
+ serializer: Optional[DocumentSerializer] = None
321
+ if cfg.get("serializer") and cfg.serializer.get("_target_"):
322
+ log.info(f"Instantiating serializer <{cfg.serializer._target_}>")
323
+ serializer = hydra.utils.instantiate(cfg.serializer, _convert_="partial")
324
+ # predict and serialize
325
+ predict_metrics: Dict[str, Any] = utils.predict_and_serialize(
326
+ pipeline=pipeline,
327
+ serializer=serializer,
328
+ dataset=dataset[cfg.dataset_split],
329
+ document_batch_size=cfg.get("document_batch_size", None),
330
+ )
331
+ # flatten the predict_metrics dict
332
+ predict_metrics_flat = flatten_nested_dict(predict_metrics, sep="/")
333
+ metric_dict.update(predict_metrics_flat)
334
+
335
+ if cfg.get("delete_model_dir"):
336
+ import shutil
337
+
338
+ log.info(f"Deleting model directory {cfg.paths.model_save_dir}")
339
+ shutil.rmtree(cfg.paths.model_save_dir)
340
 
341
  return metric_dict, object_dict
342
 
 
361
  if __name__ == "__main__":
362
  utils.replace_sys_args_with_values_from_files()
363
  utils.prepare_omegaconf()
364
+ OmegaConf.register_new_resolver("eval", eval)
365
  main()
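Flattening the prediction metrics turns nested job return values into scalar keys that are easy to log; a self-contained sketch using the same recursion as flatten_nested_dict above (repeated here so the snippet runs standalone):

from typing import Any, Dict

def flatten(d: Dict[str, Any], parent_key: str = "", sep: str = "/") -> Dict[str, Any]:
    # same recursion as flatten_nested_dict, with "/" as the separator
    items = []
    for k, v in d.items():
        new_key = f"{parent_key}{sep}{k}" if parent_key else k
        if isinstance(v, dict):
            items.extend(flatten(v, new_key, sep=sep).items())
        else:
            items.append((new_key, v))
    return dict(items)

print(flatten({"serializer": {"path": "out/preds", "split": "test"}, "prediction_time": 1.2}))
# {'serializer/path': 'out/preds', 'serializer/split': 'test', 'prediction_time': 1.2}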
src/utils/__init__.py CHANGED
@@ -5,7 +5,8 @@ from .config_utils import (
5
  prepare_omegaconf,
6
  )
7
  from .data_utils import download_and_unzip, filter_dataframe_and_get_column
 
8
  from .logging_utils import close_loggers, get_pylogger, log_hyperparameters
9
  from .rich_utils import enforce_tags, print_config_tree
10
- from .span_utils import distance
11
  from .task_utils import extras, replace_sys_args_with_values_from_files, save_file, task_wrapper
 
5
  prepare_omegaconf,
6
  )
7
  from .data_utils import download_and_unzip, filter_dataframe_and_get_column
8
+ from .inference_utils import predict_and_serialize
9
  from .logging_utils import close_loggers, get_pylogger, log_hyperparameters
10
  from .rich_utils import enforce_tags, print_config_tree
11
+ from .span_utils import distance, distance_slices
12
  from .task_utils import extras, replace_sys_args_with_values_from_files, save_file, task_wrapper
src/utils/inference_utils.py ADDED
@@ -0,0 +1,74 @@
1
+ import timeit
2
+ from collections.abc import Iterable, Sequence
3
+ from typing import Any, Dict, Optional, Union
4
+
5
+ from pytorch_ie import Document, Pipeline
6
+
7
+ from src.serializer.interface import DocumentSerializer
8
+
9
+ from .logging_utils import get_pylogger
10
+
11
+ log = get_pylogger(__name__)
12
+
13
+
14
+ def document_batch_iter(
15
+ dataset: Iterable[Document], batch_size: int
16
+ ) -> Iterable[Sequence[Document]]:
17
+ if isinstance(dataset, Sequence):
18
+ for i in range(0, len(dataset), batch_size):
19
+ yield dataset[i : i + batch_size]
20
+ elif isinstance(dataset, Iterable):
21
+ docs = []
22
+ for doc in dataset:
23
+ docs.append(doc)
24
+ if len(docs) == batch_size:
25
+ yield docs
26
+ docs = []
27
+ if docs:
28
+ yield docs
29
+ else:
30
+ raise ValueError(f"Unsupported dataset type: {type(dataset)}")
31
+
32
+
33
+ def predict_and_serialize(
34
+ pipeline: Optional[Pipeline],
35
+ serializer: Optional[DocumentSerializer],
36
+ dataset: Iterable[Document],
37
+ document_batch_size: Optional[int] = None,
38
+ ) -> Dict[str, Any]:
39
+ result: Dict[str, Any] = {}
40
+ if pipeline is not None:
41
+ log.info("Starting inference!")
42
+ prediction_time = 0.0
43
+ else:
44
+ log.warning("No prediction pipeline is defined, skip inference!")
45
+ prediction_time = None
46
+ docs_batch: Union[Iterable[Document], Sequence[Document]]
47
+
48
+ batch_iter: Union[Sequence[Iterable[Document]], Iterable[Sequence[Document]]]
49
+ if document_batch_size is None:
50
+ batch_iter = [dataset]
51
+ else:
52
+ batch_iter = document_batch_iter(dataset=dataset, batch_size=document_batch_size)
53
+ for docs_batch in batch_iter:
54
+ if pipeline is not None:
55
+ t_start = timeit.default_timer()
56
+ docs_batch = pipeline(docs_batch, inplace=False)
57
+ prediction_time += timeit.default_timer() - t_start # type: ignore
58
+
59
+ # serialize the documents
60
+ if serializer is not None:
61
+ # the serializer should not return the serialized documents, but write them to disk
62
+ # and instead return some metadata such as the path to the serialized documents
63
+ serializer_result = serializer(docs_batch)
64
+ if "serializer" in result and result["serializer"] != serializer_result:
65
+ log.warning(
66
+ f"serializer result changed from {result['serializer']} to {serializer_result}"
67
+ " during prediction. Only the last result is returned."
68
+ )
69
+ result["serializer"] = serializer_result
70
+
71
+ if prediction_time is not None:
72
+ result["prediction_time"] = prediction_time
73
+
74
+ return result
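The batching helper above yields fixed-size chunks from either a sequence or a one-shot iterable; a tiny sketch of the sequence path with plain integers standing in for documents:

def batch_iter(items, batch_size):
    # slice in steps of batch_size; the final chunk may be smaller
    for i in range(0, len(items), batch_size):
        yield items[i : i + batch_size]

print(list(batch_iter([1, 2, 3, 4, 5], 2)))  # [[1, 2], [3, 4], [5]]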
src/utils/span_utils.py CHANGED
@@ -58,3 +58,17 @@ def distance(
58
  raise ValueError(
59
  f"unknown distance_type={distance_type}. use one of: center, inner, outer"
60
  )
58
  raise ValueError(
59
  f"unknown distance_type={distance_type}. use one of: center, inner, outer"
60
  )
61
+
62
+
63
+ def distance_slices(
64
+ slices: Tuple[Tuple[int, int], ...],
65
+ other_slices: Tuple[Tuple[int, int], ...],
66
+ distance_type: str,
67
+ ) -> float:
68
+ starts, ends = zip(*slices)
69
+ other_starts, other_ends = zip(*other_slices)
70
+ return distance(
71
+ start_end=(min(starts), max(ends)),
72
+ other_start_end=(min(other_starts), max(other_ends)),
73
+ distance_type=distance_type,
74
+ )
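distance_slices reduces each multi-slice span to its overall extent before delegating to distance; a short sketch of that reduction step (the toy slices are made up):

slices = ((0, 5), (10, 15))
other_slices = ((20, 25),)
starts, ends = zip(*slices)
other_starts, other_ends = zip(*other_slices)
# each span collapses to (min start, max end) before the distance is computed
print((min(starts), max(ends)))              # (0, 15)
print((min(other_starts), max(other_ends)))  # (20, 25)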