upload
https://github.com/ArneBinder/pie-document-level/pull/452

Files changed:
- src/analysis/combine_job_returns.py +15 -11
- src/analysis/show_score_distribution.py +99 -0
- src/data/calc_iaa_for_brat.py +1 -0
- src/datamodules/datamodule.py +12 -2
- src/demo/annotation_utils.py +1 -0
- src/demo/retrieve_and_dump_all_relevant.py +196 -0
- src/demo/retriever_utils.py +51 -11
- src/document/processing.py +247 -91
- src/document/types.py +46 -0
- src/metrics/__init__.py +1 -0
- src/metrics/score_distribution.py +345 -0
- src/models/__init__.py +1 -0
- src/models/sequence_classification.py +94 -0
- src/pipeline/ner_re_pipeline.py +99 -21
- src/predict.py +8 -34
- src/serializer/interface.py +2 -2
- src/serializer/json.py +13 -7
- src/start_demo.py +35 -14
- src/taskmodules/cross_text_binary_coref_nli.py +31 -4
- src/train.py +71 -10
- src/utils/__init__.py +2 -1
- src/utils/inference_utils.py +74 -0
- src/utils/span_utils.py +14 -0
src/analysis/combine_job_returns.py
CHANGED

@@ -47,6 +47,7 @@ def main(
     transpose: bool = False,
     unpack_multirun_results: bool = False,
     in_percent: bool = False,
+    reset_index: bool = False,
 ):
     file_paths = get_file_paths(
         paths_file=paths_file, file_name=file_name, use_aggregated=use_aggregated
@@ -97,9 +98,6 @@ def main(
     data = data.unstack(index_name)
     data = data.T

-    if transpose:
-        data = data.T
-
     # needs to happen before rounding, otherwise the rounding will be off
     if in_percent:
         data = data * 100
@@ -107,20 +105,23 @@ def main(
     if round_precision is not None:
         data = data.round(round_precision)

-    elif format == "markdown_mean_and_std":
-        if transpose:
-            data = data.T
+    # needs to happen before transposing
+    if format == "markdown_mean_and_std":
         if "mean" not in data.columns or "std" not in data.columns:
             raise ValueError("Columns 'mean' and 'std' are required for this format.")
         # create a single column with mean and std in the format: mean ± std
         data = pd.DataFrame(
             data["mean"].astype(str) + " ± " + data["std"].astype(str), columns=["mean ± std"]
         )
+
+    if transpose:
+        data = data.T
+
+    if reset_index:
+        data = data.reset_index()
+
+    if format in ["markdown", "markdown_mean_and_std"]:
+        print(data.to_markdown(index=not reset_index))
     elif format == "json":
         print(data.to_json())
     else:
@@ -156,6 +157,9 @@ if __name__ == "__main__":
     parser.add_argument(
         "--in-percent", action="store_true", help="Show the values in percent (multiply by 100)"
     )
+    parser.add_argument(
+        "--reset-index", action="store_true", help="Reset the index of the combined job returns"
+    )
     parser.add_argument(
         "--format",
         type=str,
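Note on the reordering above: the "mean ± std" column has to be built while the original "mean"/"std" columns still exist, i.e. before transposing or resetting the index. A minimal standalone sketch of the same flow on a toy frame (values are made up, independent of the script's aggregation logic):

import pandas as pd

# toy aggregated results: metrics as rows, mean/std as columns (hypothetical numbers)
data = pd.DataFrame({"mean": [81.2, 74.5], "std": [0.8, 1.3]}, index=["f1", "precision"])

# combine mean and std first (requires the original "mean"/"std" columns) ...
data = pd.DataFrame(
    data["mean"].astype(str) + " ± " + data["std"].astype(str), columns=["mean ± std"]
)

# ... then transpose and/or reset the index afterwards
transpose, reset_index = True, False
if transpose:
    data = data.T
if reset_index:
    data = data.reset_index()

# the index is only printed when it was not reset
print(data.to_markdown(index=not reset_index))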
src/analysis/show_score_distribution.py
ADDED
@@ -0,0 +1,99 @@

import pyrootutils

root = pyrootutils.setup_root(
    search_from=__file__,
    indicator=[".project-root"],
    pythonpath=True,
    dotenv=False,
)

import argparse
from typing import List, Optional

import pandas as pd
import plotly.figure_factory as ff
from pie_datasets import DatasetDict

pd.options.plotting.backend = "plotly"

if __name__ == "__main__":

    parser = argparse.ArgumentParser(
        description="Show score distribution of annotations per layer"
    )
    # --data-dir predictions/default/2025-02-26_14-28-17
    parser.add_argument(
        "--data-dir", type=str, required=True, help="Path to the dataset directory"
    )
    parser.add_argument("--split", type=str, default="test", help="Dataset split to use")
    parser.add_argument(
        "--layers",
        nargs="+",
        default=["labeled_spans", "binary_relations"],
        help="Annotation layers to use",
    )
    # --layer-captions ADUs "Argumentative Relations"
    parser.add_argument(
        "--layer-captions", nargs="+", help="Captions for the figure traces per layer"
    )
    # --layer-colors "rgb(31,119,180)" "rgb(255,127,14)"
    parser.add_argument("--layer-colors", nargs="+", help="Colors for the figure traces per layer")

    args = parser.parse_args()

    # Load the dataset
    ds = DatasetDict.from_json(data_dir=args.data_dir)[args.split]

    # get scores per annotation layer and label
    layers = args.layers
    all_scores = []
    all_scores_idx = []
    for doc in ds:
        for layer in layers:
            for ann in doc[layer].predictions:
                all_scores.append(ann.score)
                all_scores_idx.append((doc.id, layer, getattr(ann, "label", None)))
    scores = pd.Series(
        all_scores,
        index=pd.MultiIndex.from_tuples(all_scores_idx, names=["doc_id", "layer", "label"]),
        name="score",
    )

    if args.layer_captions is not None:
        if len(args.layer_captions) < len(layers):
            raise ValueError("Not enough captions provided for all layers")
        name_mapping = dict(zip(layers, args.layer_captions))
    else:
        name_mapping = dict(zip(layers, layers))

    colors: Optional[List[str]] = None
    if args.layer_colors is not None:
        if len(args.layer_colors) < len(layers):
            raise ValueError("Not enough colors provided for all layers")
        color_mapping = dict(zip(layers, args.layer_colors))
        colors = [color_mapping[layer] for layer in layers]
    else:
        colors = None

    score_groups = {layer: scores.xs(layer, level="layer").to_numpy() for layer in layers}
    group_labels, hist_data = zip(*score_groups.items())
    group_labels_renamed = [name_mapping[label] for label in group_labels]
    fig = ff.create_distplot(
        hist_data,
        group_labels=group_labels_renamed,
        show_hist=True,
        colors=colors,
        bin_size=0.025,
    )

    fig.update_layout(
        height=600,
        width=800,
        title_text="Score Distribution per Annotation Layer",
        title_x=0.5,
        barmode="group",
    )
    fig.update_layout(legend=dict(yanchor="top", y=0.99, xanchor="left", x=0.01))

    fig.show()
    print("done")
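For readers without a prediction dump at hand, the plotting core of the script can be exercised on synthetic scores. The sketch below only assumes numpy and plotly; the beta-distributed scores are invented stand-ins for the per-layer prediction scores the script collects:

import numpy as np
import plotly.figure_factory as ff

# synthetic scores standing in for per-layer prediction scores (made-up distributions)
rng = np.random.default_rng(0)
score_groups = {
    "labeled_spans": rng.beta(8, 2, size=500),     # mostly confident span predictions
    "binary_relations": rng.beta(4, 3, size=300),  # broader relation score distribution
}

group_labels, hist_data = zip(*score_groups.items())
fig = ff.create_distplot(
    list(hist_data), group_labels=list(group_labels), show_hist=True, bin_size=0.025
)
fig.update_layout(title_text="Score Distribution per Annotation Layer", title_x=0.5, barmode="group")
fig.show()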
src/data/calc_iaa_for_brat.py
CHANGED

@@ -92,6 +92,7 @@ def calc_brat_iaas(
             create_multi_spans=True,
             result_document_type=BratDocument,
             result_field_mapping={"spans": "spans", "relations": "relations"},
+            combine_scores_method="product",
         )
     else:
         merger = None
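The diff itself does not show how the merger implements combine_scores_method, only that "product" is selected. A plausible reading, sketched independently of the pie_modules API (the helper name and the "mean" alternative are illustrative, not taken from the library):

from functools import reduce
from typing import Sequence

def combine_scores(scores: Sequence[float], method: str = "product") -> float:
    # presumed behavior: reduce the scores of the spans being merged into one score
    if method == "product":
        return reduce(lambda a, b: a * b, scores, 1.0)
    elif method == "mean":
        return sum(scores) / len(scores)
    raise ValueError(f"unknown method: {method}")

print(combine_scores([0.9, 0.8]))  # product of the merged span scores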
src/datamodules/datamodule.py
CHANGED

@@ -1,3 +1,4 @@
+import logging
 from typing import Any, Dict, Generic, Optional, Sequence, TypeVar, Union

 from pytorch_ie.core import Document
@@ -21,6 +22,8 @@ DatasetType: TypeAlias = Union[
     IterableTaskEncodingDataset[TaskEncoding[DocumentType, InputEncoding, TargetEncoding]],
 ]

+logger = logging.getLogger(__name__)
+

 class PieDataModule(LightningDataModule, Generic[DocumentType, InputEncoding, TargetEncoding]):
     """A simple LightningDataModule for PIE document datasets.
@@ -49,6 +52,7 @@ class PieDataModule(LightningDataModule, Generic[DocumentType, InputEncoding, TargetEncoding]):
         test_split: Optional[str] = "test",
         show_progress_for_encode: bool = False,
         train_sampler: Optional[str] = None,
+        dont_shuffle_train: bool = False,
         **dataloader_kwargs,
     ):
         super().__init__()
@@ -62,6 +66,7 @@ class PieDataModule(LightningDataModule, Generic[DocumentType, InputEncoding, TargetEncoding]):
         self.show_progress_for_encode = show_progress_for_encode
         self.train_sampler_name = train_sampler
         self.dataloader_kwargs = dataloader_kwargs
+        self.dont_shuffle_train = dont_shuffle_train

         self._data: Dict[str, DatasetType] = {}

@@ -128,12 +133,17 @@ class PieDataModule(LightningDataModule, Generic[DocumentType, InputEncoding, TargetEncoding]):
             sampler = self.get_train_sampler(sampler_name=self.train_sampler_name, dataset=ds)
         else:
             sampler = None
+        # don't shuffle streamed datasets or if we use a sampler or if we explicitly set dont_shuffle_train
+        shuffle = not self.dont_shuffle_train and not (
+            isinstance(ds, IterableTaskEncodingDataset) or sampler is not None
+        )
+        if not shuffle:
+            logger.warning("not shuffling train dataloader")
         return DataLoader(
             dataset=ds,
             sampler=sampler,
             collate_fn=self.taskmodule.collate,
-            shuffle=not (isinstance(ds, IterableTaskEncodingDataset) or sampler is not None),
+            shuffle=shuffle,
             **self.dataloader_kwargs,
         )
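The shuffle decision above reads as a small truth table; a standalone restatement (the helper name is illustrative, not part of the module):

def should_shuffle_train(is_iterable_dataset: bool, has_sampler: bool, dont_shuffle_train: bool) -> bool:
    # mirrors the dataloader logic: shuffling is disabled for streamed (iterable)
    # datasets, when a sampler is used, or when it is explicitly switched off
    return not dont_shuffle_train and not (is_iterable_dataset or has_sampler)

assert should_shuffle_train(False, False, False) is True
assert should_shuffle_train(True, False, False) is False   # iterable dataset
assert should_shuffle_train(False, True, False) is False   # custom sampler
assert should_shuffle_train(False, False, True) is False   # explicit opt-out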
src/demo/annotation_utils.py
CHANGED

@@ -37,6 +37,7 @@ def get_merger() -> SpansViaRelationMerger:
             "binary_relations": "binary_relations",
             "labeled_partitions": "labeled_partitions",
         },
+        combine_scores_method="product",
     )
src/demo/retrieve_and_dump_all_relevant.py
CHANGED

@@ -10,9 +10,17 @@ root = pyrootutils.setup_root(
 import argparse
 import logging
 import os
+from typing import Dict, List, Optional, Tuple

 import pandas as pd
+from pie_datasets import Dataset, DatasetDict
+from pytorch_ie import Annotation
+from pytorch_ie.annotations import BinaryRelation, MultiSpan, Span

+from document.types import (
+    RelatedRelation,
+    TextDocumentWithLabeledMultiSpansBinaryRelationsLabeledPartitionsAndRelatedRelations,
+)
 from src.demo.retriever_utils import (
     retrieve_all_relevant_spans,
     retrieve_all_relevant_spans_for_all_documents,

@@ -23,6 +31,168 @@ from src.langchain_modules import DocumentAwareSpanRetrieverWithRelations
The following helper functions are added between the logger setup and the __main__ block:

def get_original_doc_id_and_offsets(doc_id: str) -> Tuple[str, int, Optional[int]]:
    original_doc_id, middle, start_end, ext = doc_id.split(".")
    if middle == "remaining":
        return original_doc_id, int(start_end), None
    elif middle == "abstract":
        start, end = start_end.split("_")
        return original_doc_id, int(start), int(end)
    else:
        raise ValueError(f"unexpected doc_id format: {doc_id}")


def add_base_annotations(
    documents: Dict[
        str, TextDocumentWithLabeledMultiSpansBinaryRelationsLabeledPartitionsAndRelatedRelations
    ],
    retrieved_doc_ids: List[str],
    retriever: DocumentAwareSpanRetrieverWithRelations,
) -> Dict[Tuple[str, Annotation], Tuple[str, Annotation]]:
    # (retrieved_doc_id, retrieved_annotation) -> (original_doc_id, original_annotation)
    annotation_mapping = {}
    for retrieved_doc_id in retrieved_doc_ids:
        pie_doc = retriever.get_document(retrieved_doc_id).metadata["pie_document"].copy()
        original_doc_id, offset, _ = get_original_doc_id_and_offsets(retrieved_doc_id)
        document = documents[original_doc_id]
        span_mapping = {}
        for span in pie_doc.labeled_multi_spans.predictions:
            if isinstance(span, MultiSpan):
                new_span = span.copy(
                    slices=[(start + offset, end + offset) for start, end in span.slices]
                )
            elif isinstance(span, Span):
                new_span = span.copy(start=span.start + offset, end=span.end + offset)
            else:
                raise ValueError(f"unexpected span type: {span}")
            span_mapping[span] = new_span
        document.labeled_multi_spans.predictions.extend(span_mapping.values())
        for relation in pie_doc.binary_relations.predictions:
            new_relation = relation.copy(
                head=span_mapping[relation.head], tail=span_mapping[relation.tail]
            )
            document.binary_relations.predictions.append(new_relation)
        for old_ann, new_ann in span_mapping.items():
            annotation_mapping[(retrieved_doc_id, old_ann)] = (original_doc_id, new_ann)

    return annotation_mapping


def get_doc_and_span_id2annotation_mapping(
    span_ids: pd.Series,
    doc_ids: pd.Series,
    retriever: DocumentAwareSpanRetrieverWithRelations,
    base_annotation_mapping: Dict[Tuple[str, Annotation], Tuple[str, Annotation]],
) -> Dict[Tuple[str, str], Tuple[str, Annotation]]:
    if len(doc_ids) != len(span_ids):
        raise ValueError("doc_ids and span_ids must have the same length")
    doc_and_span_ids = zip(doc_ids.tolist(), span_ids.tolist())
    return {
        (doc_id, span_id): base_annotation_mapping[(doc_id, retriever.get_span_by_id(span_id))]
        for doc_id, span_id in set(doc_and_span_ids)
    }


def add_result_to_gold_data(
    result: pd.DataFrame,
    gold_dataset_dir: str,
    dataset_out_dir: str,
    retriever: DocumentAwareSpanRetrieverWithRelations,
    split: Optional[str] = None,
    link_relation_label: str = "semantically_same",
    reversed_relation_suffix: str = "_reversed",
):

    if not os.path.exists(gold_dataset_dir):
        raise ValueError(f"gold dataset directory does not exist: {gold_dataset_dir}")

    dataset_dict = DatasetDict.from_json(data_dir=gold_dataset_dir)
    if split is None and len(dataset_dict) == 1:
        split = list(dataset_dict.keys())[0]
    if split is None:
        raise ValueError("need to provide split name to add results to gold dataset")

    dataset = dataset_dict[split]

    doc_id2doc = {doc.id: doc for doc in dataset}
    retriever_doc_ids = (
        result["doc_id"].unique().tolist() + result["query_doc_id"].unique().tolist()
    )
    base_annotation_mapping = add_base_annotations(
        documents=doc_id2doc, retrieved_doc_ids=retriever_doc_ids, retriever=retriever
    )
    # (retriever_doc_id, retriever_span_id) -> (original_doc_id, original_span)
    doc_and_span_id2annotation = {}
    doc_and_span_id2annotation.update(
        get_doc_and_span_id2annotation_mapping(
            span_ids=result["span_id"],
            doc_ids=result["doc_id"],
            retriever=retriever,
            base_annotation_mapping=base_annotation_mapping,
        )
    )
    doc_and_span_id2annotation.update(
        get_doc_and_span_id2annotation_mapping(
            span_ids=result["ref_span_id"],
            doc_ids=result["doc_id"],
            retriever=retriever,
            base_annotation_mapping=base_annotation_mapping,
        )
    )
    doc_and_span_id2annotation.update(
        get_doc_and_span_id2annotation_mapping(
            span_ids=result["query_span_id"],
            doc_ids=result["query_doc_id"],
            retriever=retriever,
            base_annotation_mapping=base_annotation_mapping,
        )
    )
    doc_id2head_tail2relation = {}
    for doc_id, doc in doc_id2doc.items():
        head_and_tail2relation = {}
        for relation in doc.binary_relations.predictions:
            head_and_tail2relation[(relation.head, relation.tail)] = relation
        doc_id2head_tail2relation[doc_id] = head_and_tail2relation

    for row in result.itertuples():
        query_doc_id, query_span = doc_and_span_id2annotation[
            (row.query_doc_id, row.query_span_id)
        ]
        doc_id, span = doc_and_span_id2annotation[(row.doc_id, row.span_id)]
        doc_id2, ref_span = doc_and_span_id2annotation[(row.doc_id, row.ref_span_id)]
        if doc_id != query_doc_id:
            raise ValueError("doc_id and query_doc_id must be the same")
        if doc_id != doc_id2:
            raise ValueError("doc_id and ref_doc_id must be the same")
        doc = doc_id2doc[doc_id]
        link_rel = BinaryRelation(
            head=query_span, tail=ref_span, label=link_relation_label, score=row.sim_score
        )
        doc.binary_relations.predictions.append(link_rel)
        head_and_tail2relation = doc_id2head_tail2relation[doc_id]
        related_rel_label = row.type
        if related_rel_label.endswith(reversed_relation_suffix):
            base_rel = head_and_tail2relation[(span, ref_span)]
        else:
            base_rel = head_and_tail2relation[(ref_span, span)]
        related_rel = RelatedRelation(
            head=query_span,
            tail=span,
            link_relation=link_rel,
            relation=base_rel,
            label=related_rel_label,
            score=link_rel.score * base_rel.score,
        )
        doc.related_relations.predictions.append(related_rel)

    dataset = Dataset.from_documents(list(doc_id2doc.values()))
    dataset_dict = DatasetDict({split: dataset})
    if not os.path.exists(dataset_out_dir):
        os.makedirs(dataset_out_dir, exist_ok=True)

    dataset_dict.to_json(dataset_out_dir)

@@ -81,6 +251,19 @@ if __name__ == "__main__":
         '(each separated by ":") to retrieve spans for. If provided, '
         "--query_doc_id and --query_span_id are ignored.",
     )
+    parser.add_argument(
+        "--gold_dataset_dir",
+        type=str,
+        default=None,
+        help="If provided, add the spans and base relations from the retriever data as well "
+        "as the related relations to the gold dataset.",
+    )
+    parser.add_argument(
+        "--dataset_out_dir",
+        type=str,
+        default=None,
+        help="If provided, save the enriched gold dataset to this directory.",
+    )
     args = parser.parse_args()

     logging.basicConfig(
@@ -157,4 +340,17 @@ if __name__ == "__main__":
     os.makedirs(os.path.dirname(args.output_path), exist_ok=True)
     all_spans_for_all_documents.to_json(args.output_path)

+    if args.gold_dataset_dir is not None:
+        logger.info(
+            f"reading gold data from {args.gold_dataset_dir} and adding results as predictions ..."
+        )
+        if args.dataset_out_dir is None:
+            raise ValueError("need to provide --dataset_out_dir to save the enriched dataset")
+        add_result_to_gold_data(
+            all_spans_for_all_documents,
+            gold_dataset_dir=args.gold_dataset_dir,
+            dataset_out_dir=args.dataset_out_dir,
+            retriever=retriever,
+        )
+
     logger.info("done")
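The retriever indexes abstract/remaining slices of each document under derived doc ids, and get_original_doc_id_and_offsets recovers the original id and character offsets from them. The parsing can be checked in isolation; the example ids below are hypothetical, but they follow the "orig.middle.offsets.ext" format the function splits on:

from typing import Optional, Tuple

def get_original_doc_id_and_offsets(doc_id: str) -> Tuple[str, int, Optional[int]]:
    # same parsing as in the script above
    original_doc_id, middle, start_end, ext = doc_id.split(".")
    if middle == "remaining":
        return original_doc_id, int(start_end), None
    elif middle == "abstract":
        start, end = start_end.split("_")
        return original_doc_id, int(start), int(end)
    else:
        raise ValueError(f"unexpected doc_id format: {doc_id}")

# hypothetical ids in the two supported formats
print(get_original_doc_id_and_offsets("A01.abstract.0_850.txt"))   # ('A01', 0, 850)
print(get_original_doc_id_and_offsets("A01.remaining.850.txt"))    # ('A01', 850, None)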
src/demo/retriever_utils.py
CHANGED

@@ -51,6 +51,7 @@ def load_retriever(
 def retrieve_similar_spans(
     retriever: DocumentAwareSpanRetriever,
     query_span_id: str,
+    min_score: float = 0.0,
     **kwargs,
 ) -> pd.DataFrame:
     if not query_span_id.strip():
@@ -60,21 +61,42 @@ def retrieve_similar_spans(
         records = []
         for similar_span_doc in retrieval_result:
             pie_doc, metadata = retriever.docstore.unwrap_with_metadata(similar_span_doc)
+            query_span = retriever.get_span_by_id(span_id=query_span_id)
+            query_span_score = query_span.score
             span_ann = metadata["attached_span"]
+            sim_score = metadata["relevance_score"]
+            span_score = span_ann.score
+            score = query_span_score * sim_score * span_score
             records.append(
                 {
+                    "score": score,
                     "doc_id": pie_doc.id,
                     "span_id": similar_span_doc.id,
+                    "sim_score": sim_score,
+                    "query_span_score": query_span_score,
+                    "span_score": span_score,
                     "label": span_ann.label,
                     "text": str(span_ann),
                 }
             )
+        result = (
+            pd.DataFrame(
+                records,
+                columns=[
+                    "score",
+                    "text",
+                    "label",
+                    "sim_score",
+                    "span_score",
+                    "query_span_score",
+                    "doc_id",
+                    "span_id",
+                ],
+            )
             .sort_values(by="score", ascending=False)
             .round(3)
         )
+        return result[result["score"] >= min_score]
     except Exception as e:
         raise gr.Error(f"Failed to retrieve similar ADUs: {e}")

@@ -83,6 +105,7 @@ def retrieve_relevant_spans(
     retriever: DocumentAwareSpanRetriever,
     query_span_id: str,
     relation_label_mapping: Optional[dict[str, str]] = None,
+    min_score: float = 0.0,
     **kwargs,
 ) -> pd.DataFrame:
     if not query_span_id.strip():
@@ -98,40 +121,57 @@ def retrieve_relevant_spans(
             mapped_relation_label = relation_label_mapping.get(
                 metadata["relation_label"], metadata["relation_label"]
             )
+
+            query_span = retriever.get_span_by_id(span_id=query_span_id)
+            query_span_score = query_span.score
+            sim_score = metadata["relevance_score"]
+            ref_span_score = span_ann.score
+            rel_score = metadata["relation_score"]
+            span_score = tail_span_ann.score
+            score = query_span_score * sim_score * ref_span_score * rel_score * span_score
             records.append(
                 {
                     "doc_id": pie_doc.id,
                     "type": mapped_relation_label,
+                    "score": score,
+                    "rel_score": rel_score,
                     "text": str(tail_span_ann),
                     "span_id": relevant_span_doc.id,
+                    "span_score": span_score,
                     "label": tail_span_ann.label,
+                    "sim_score": sim_score,
                     "ref_label": span_ann.label,
                     "ref_text": str(span_ann),
                     "ref_span_id": metadata["head_id"],
+                    "ref_span_score": ref_span_score,
+                    "query_span_score": query_span_score,
                 }
             )
+        result = (
             pd.DataFrame(
                 records,
                 columns=[
+                    "score",
                     "type",
-                    # omitted for now, we get no valid relation scores for the generative model
-                    # "rel_score",
-                    "ref_score",
-                    "label",
                     "text",
+                    "label",
+                    "rel_score",
+                    "sim_score",
                     "ref_label",
+                    "ref_span_score",
                     "ref_text",
                     "doc_id",
                     "span_id",
+                    "span_score",
                     "ref_span_id",
+                    "query_span_score",
                 ],
             )
+            .sort_values(by=["score"], ascending=False)
             .round(3)
         )
+        return result[result["score"] >= min_score]
+
     except Exception as e:
         raise gr.Error(f"Failed to retrieve relevant ADUs: {e}")
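Both helpers now report a single "score" that is the product of every probability along the retrieval chain and drop rows below min_score. Restated on made-up numbers:

# standalone restatement of the score combination used above (all values are invented)
query_span_score = 0.98   # score of the query ADU
sim_score = 0.87          # retrieval relevance score
ref_span_score = 0.95     # score of the reference ADU
rel_score = 0.91          # score of the base relation between reference and candidate
span_score = 0.93         # score of the candidate (tail) ADU

similar_score = query_span_score * sim_score * span_score
relevant_score = query_span_score * sim_score * ref_span_score * rel_score * span_score

min_score = 0.5
print(f"similar:  {similar_score:.3f}  keep={similar_score >= min_score}")
print(f"relevant: {relevant_score:.3f}  keep={relevant_score >= min_score}")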
src/document/processing.py
CHANGED

@@ -1,16 +1,20 @@
 from __future__ import annotations

 import logging
+from collections import defaultdict
+from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, TypeVar, Union

-from pie_modules.document.processing.merge_spans_via_relation import _merge_spans_via_relation
-from pie_modules.documents import TextDocumentWithLabeledMultiSpansAndBinaryRelations
 from pie_modules.utils.span import have_overlap
 from pytorch_ie import AnnotationLayer
+from pytorch_ie.annotations import LabeledMultiSpan, LabeledSpan, MultiSpan, Span
 from pytorch_ie.core import Document
 from pytorch_ie.core.document import Annotation, _enumerate_dependencies

+from src.document.types import (
+    RelatedRelation,
+    TextDocumentWithLabeledMultiSpansBinaryRelationsLabeledPartitionsAndRelatedRelations,
+)
+from src.utils import distance, distance_slices
 from src.utils.span_utils import get_overlap_len

 logger = logging.getLogger(__name__)

@@ -68,58 +72,6 @@ def remove_overlapping_entities(
     return new_doc

-# TODO: remove and use pie_modules.document.processing.SpansViaRelationMerger instead
-def merge_spans_via_relation(
-    document: D,
-    relation_layer: str,
-    link_relation_label: str,
-    use_predicted_spans: bool = False,
-    process_predictions: bool = True,
-    create_multi_spans: bool = False,
-) -> D:
-
-    rel_layer = document[relation_layer]
-    span_layer = rel_layer.target_layer
-    new_gold_spans, new_gold_relations = _merge_spans_via_relation(
-        spans=span_layer,
-        relations=rel_layer,
-        link_relation_label=link_relation_label,
-        create_multi_spans=create_multi_spans,
-    )
-    if process_predictions:
-        new_pred_spans, new_pred_relations = _merge_spans_via_relation(
-            spans=span_layer.predictions if use_predicted_spans else span_layer,
-            relations=rel_layer.predictions,
-            link_relation_label=link_relation_label,
-            create_multi_spans=create_multi_spans,
-        )
-    else:
-        assert not use_predicted_spans
-        new_pred_spans = set(span_layer.predictions.clear())
-        new_pred_relations = set(rel_layer.predictions.clear())
-
-    relation_layer_name = relation_layer
-    span_layer_name = document[relation_layer].target_name
-    if create_multi_spans:
-        doc_dict = document.asdict()
-        for f in document.annotation_fields():
-            doc_dict.pop(f.name)
-
-        result = TextDocumentWithLabeledMultiSpansAndBinaryRelations.fromdict(doc_dict)
-        result.labeled_multi_spans.extend(new_gold_spans)
-        result.labeled_multi_spans.predictions.extend(new_pred_spans)
-        result.binary_relations.extend(new_gold_relations)
-        result.binary_relations.predictions.extend(new_pred_relations)
-    else:
-        result = document.copy(with_annotations=False)
-        result[span_layer_name].extend(new_gold_spans)
-        result[span_layer_name].predictions.extend(new_pred_spans)
-        result[relation_layer_name].extend(new_gold_relations)
-        result[relation_layer_name].predictions.extend(new_pred_relations)
-
-    return result
-
-
 def remove_partitions_by_labels(
     document: D, partition_layer: str, label_blacklist: List[str], span_layer: Optional[str] = None
 ) -> D:

The span alignment code (hunks @@ -249,31 +201,19 @@, @@ -282,29 +222,32 @@ and @@ -312,12 +255,140 @@) is refactored into helper functions and extended to support MultiSpan annotations; the rewritten region now reads:

def get_start_end(span: Union[Span, MultiSpan]) -> Tuple[int, int]:
    if isinstance(span, Span):
        return span.start, span.end
    elif isinstance(span, MultiSpan):
        starts, ends = zip(*span.slices)
        return min(starts), max(ends)
    else:
        raise ValueError(f"Unsupported span type: {type(span)}")


def _get_aligned_span_mappings(
    gold_spans: Iterable[Span], pred_spans: Iterable[Span], distance_type: str
) -> Tuple[Dict[int, Span], Dict[int, Span]]:
    old2new_pred_span = {}
    span_id2gold_span = {}
    for pred_span in pred_spans:
        gold_spans_with_distance = [
            (
                gold_span,
                distance(
                    start_end=get_start_end(pred_span),
                    other_start_end=get_start_end(gold_span),
                    distance_type=distance_type,
                ),
            )
            for gold_span in gold_spans
        ]
        if len(gold_spans_with_distance) == 0:
            continue

        closest_gold_span, min_distance = min(gold_spans_with_distance, key=lambda x: x[1])
        # if the closest gold span is the same as the predicted span, we don't need to align
        if min_distance == 0.0:
            continue

        pred_start_end = get_start_end(pred_span)
        closest_gold_start_end = get_start_end(closest_gold_span)

        if have_overlap(
            start_end=pred_start_end,
            other_start_end=closest_gold_start_end,
        ):
            overlap_len = get_overlap_len(pred_start_end, closest_gold_start_end)
            l_max = max(
                pred_start_end[1] - pred_start_end[0],
                closest_gold_start_end[1] - closest_gold_start_end[0],
            )
            # if the overlap is at least half of the maximum length, we consider it a valid match for alignment
            valid_match = overlap_len >= (l_max / 2)
        else:
            valid_match = False

        if valid_match:
            if isinstance(pred_span, Span):
                aligned_pred_span = pred_span.copy(
                    start=closest_gold_span.start, end=closest_gold_span.end
                )
            elif isinstance(pred_span, MultiSpan):
                aligned_pred_span = pred_span.copy(slices=closest_gold_span.slices)
            else:
                raise ValueError(f"Unsupported span type: {type(pred_span)}")
            old2new_pred_span[pred_span._id] = aligned_pred_span
            span_id2gold_span[pred_span._id] = closest_gold_span

    return old2new_pred_span, span_id2gold_span


def get_spans2multi_spans_mapping(multi_spans: Iterable[MultiSpan]) -> Dict[Span, MultiSpan]:
    result = {}
    for multi_span in multi_spans:
        for start, end in multi_span.slices:
            span_kwargs = dict(start=start, end=end, score=multi_span.score)
            if isinstance(multi_span, LabeledMultiSpan):
                result[LabeledSpan(label=multi_span.label, **span_kwargs)] = multi_span
            else:
                result[Span(**span_kwargs)] = multi_span

    return result


def align_predicted_span_annotations(
    document: DWithSpans,
    span_layer: str,
    distance_type: str = "center",
    simple_multi_span: bool = False,
    verbose: bool = False,
) -> DWithSpans:
    """
    Aligns predicted span annotations with the closest gold spans in a document.

    First, calculates the distance between each predicted span and each gold span. Then,
    for each predicted span, the gold span with the smallest distance is selected. If the
    predicted span and the gold span have an overlap of at least half of the maximum length
    of the two spans, the predicted span is aligned with the gold span.

    This also works for MultiSpan annotations, where the slices of the MultiSpan are used
    to align the predicted spans. If any of the slices is aligned with a gold slice,
    the MultiSpan is aligned with the respective gold MultiSpan. However, this may result in
    the predicted MultiSpan being aligned with multiple gold MultiSpans, in which case the
    closest gold MultiSpan is selected. A simplified version of this alignment can be achieved
    by setting `simple_multi_span=True`, which treats MultiSpan annotations as simple Spans
    by using their maximum and minimum start and end indices.

    Args:
        document: The document to process.
        span_layer: The name of the span layer.
        distance_type: The type of distance to calculate. One of: center, inner, outer
        simple_multi_span: Whether to treat MultiSpan annotations as simple Spans by using their
            maximum and minimum start and end indices.
        verbose: Whether to print debug information.

    Returns:
        The processed document.
    """
    gold_spans = document[span_layer]
    if len(gold_spans) == 0:
        return document.copy()

    pred_spans = document[span_layer].predictions
    span_annotation_type = document.annotation_types()[span_layer]
    if issubclass(span_annotation_type, Span) or simple_multi_span:
        old2new_pred_span, span_id2gold_span = _get_aligned_span_mappings(
            gold_spans=gold_spans, pred_spans=pred_spans, distance_type=distance_type
        )
    elif issubclass(span_annotation_type, MultiSpan):
        # create Span objects from MultiSpan slices
        gold_single_spans2multi_spans = get_spans2multi_spans_mapping(gold_spans)
        pred_single_spans2multi_spans = get_spans2multi_spans_mapping(pred_spans)
        # create the alignment mappings for the single spans
        single_old2new_pred_span, single_span_id2gold_span = _get_aligned_span_mappings(
            gold_spans=gold_single_spans2multi_spans.keys(),
            pred_spans=pred_single_spans2multi_spans.keys(),
            distance_type=distance_type,
        )
        # collect all Spans that are part of the same MultiSpan
        pred_multi_span2single_spans: Dict[MultiSpan, List[Span]] = defaultdict(list)
        for pred_span, multi_span in pred_single_spans2multi_spans.items():
            pred_multi_span2single_spans[multi_span].append(pred_span)

        # create the new mappings for the MultiSpans
        old2new_pred_span = {}
        span_id2gold_span = {}
        for pred_multi_span, pred_single_spans in pred_multi_span2single_spans.items():
            # if any of the single spans is aligned with a gold span, align the multi span
            if any(
                pred_single_span._id in single_old2new_pred_span
                for pred_single_span in pred_single_spans
            ):
                # get aligned gold multi spans
                aligned_gold_multi_spans = set()
                for pred_single_span in pred_single_spans:
                    if pred_single_span._id in single_old2new_pred_span:
                        aligned_gold_single_span = single_span_id2gold_span[pred_single_span._id]
                        aligned_gold_multi_span = gold_single_spans2multi_spans[
                            aligned_gold_single_span
                        ]
                        aligned_gold_multi_spans.add(aligned_gold_multi_span)

                # calculate distances between the predicted multi span and the aligned gold multi spans
                gold_multi_spans_with_distance = [
                    (
                        gold_multi_span,
                        distance_slices(
                            slices=pred_multi_span.slices,
                            other_slices=gold_multi_span.slices,
                            distance_type=distance_type,
                        ),
                    )
                    for gold_multi_span in aligned_gold_multi_spans
                ]

                if len(aligned_gold_multi_spans) > 1:
                    logger.warning(
                        f"Multiple gold multi spans aligned with predicted multi span ({pred_multi_span}): "
                        f"{aligned_gold_multi_spans}"
                    )
                # get the closest gold multi span
                closest_gold_multi_span, min_distance = min(
                    gold_multi_spans_with_distance, key=lambda x: x[1]
                )
                old2new_pred_span[pred_multi_span._id] = pred_multi_span.copy(
                    slices=closest_gold_multi_span.slices
                )
                span_id2gold_span[pred_multi_span._id] = closest_gold_multi_span
    else:
        raise ValueError(f"Unsupported span annotation type: {span_annotation_type}")

    result = document.copy(with_annotations=False)

    # multiple predicted spans can be aligned with the same gold span,
    ...
    return result

@@ -356,3 +427,88 @@ A new function is appended at the end of the module:

def add_related_relations_from_binary_relations(
    document: TextDocumentWithLabeledMultiSpansBinaryRelationsLabeledPartitionsAndRelatedRelations,
    link_relation_label: str,
    link_partition_whitelist: Optional[List[List[str]]] = None,
    relation_label_whitelist: Optional[List[str]] = None,
    reversed_relation_suffix: str = "_reversed",
    symmetric_relations: Optional[List[str]] = None,
) -> TextDocumentWithLabeledMultiSpansBinaryRelationsLabeledPartitionsAndRelatedRelations:
    span2partition = {}
    for multi_span in document.labeled_multi_spans:
        found_partition = False
        for partition in document.labeled_partitions or [
            LabeledSpan(start=0, end=len(document.text), label="ALL")
        ]:
            starts, ends = zip(*multi_span.slices)
            if partition.start <= min(starts) and max(ends) <= partition.end:
                span2partition[multi_span] = partition
                found_partition = True
                break
        if not found_partition:
            raise ValueError(f"No partition found for multi_span {multi_span}")

    rel_head2rels = defaultdict(list)
    rel_tail2rels = defaultdict(list)
    for rel in document.binary_relations:
        rel_head2rels[rel.head].append(rel)
        rel_tail2rels[rel.tail].append(rel)

    link_partition_whitelist_tuples = None
    if link_partition_whitelist is not None:
        link_partition_whitelist_tuples = {tuple(pair) for pair in link_partition_whitelist}

    skipped_labels = []
    for link_rel in document.binary_relations:
        if link_rel.label == link_relation_label:
            head_partition = span2partition[link_rel.head]
            tail_partition = span2partition[link_rel.tail]
            if link_partition_whitelist_tuples is None or (
                (head_partition.label, tail_partition.label) in link_partition_whitelist_tuples
            ):
                # link_head -> link_tail == rel_head -> rel_tail
                for rel in rel_head2rels.get(link_rel.tail, []):
                    label = rel.label
                    if relation_label_whitelist is None or label in relation_label_whitelist:
                        new_rel = RelatedRelation(
                            head=link_rel.head,
                            tail=rel.tail,
                            link_relation=link_rel,
                            relation=rel,
                            label=label,
                        )
                        document.related_relations.append(new_rel)
                    else:
                        skipped_labels.append(label)

                # link_head -> link_tail == rel_tail -> rel_head
                if reversed_relation_suffix is not None:
                    for reversed_rel in rel_tail2rels.get(link_rel.tail, []):
                        label = reversed_rel.label
                        if not (symmetric_relations is not None and label in symmetric_relations):
                            label = f"{label}{reversed_relation_suffix}"
                        if relation_label_whitelist is None or label in relation_label_whitelist:
                            new_rel = RelatedRelation(
                                head=link_rel.head,
                                tail=reversed_rel.head,
                                link_relation=link_rel,
                                relation=reversed_rel,
                                label=label,
                            )
                            document.related_relations.append(new_rel)
                        else:
                            skipped_labels.append(label)

            else:
                logger.warning(
                    f"Skipping related relation because of partition whitelist ({[head_partition.label, tail_partition.label]}): {link_rel.resolve()}"
                )
    if len(skipped_labels) > 0:
        logger.warning(
            f"Skipped relations with labels not in whitelist: {sorted(set(skipped_labels))}"
        )

    return document
src/document/types.py
ADDED
@@ -0,0 +1,46 @@

import dataclasses

from pytorch_ie import AnnotationLayer, annotation_field
from pytorch_ie.annotations import BinaryRelation
from pytorch_ie.documents import (
    TextDocumentWithLabeledMultiSpansBinaryRelationsAndLabeledPartitions,
)


@dataclasses.dataclass(eq=True, frozen=True)
class RelatedRelation(BinaryRelation):
    link_relation: BinaryRelation = dataclasses.field(default=None, compare=False)
    relation: BinaryRelation = dataclasses.field(default=None, compare=False)

    def __post_init__(self):
        super().__post_init__()
        # check if the reference_span is correct
        self.reference_span

    @property
    def reference_span(self):
        if self.link_relation is None:
            raise ValueError(
                "No semantically_same_relation available, cannot return reference_span"
            )
        if self.link_relation.head == self.head:
            return self.link_relation.tail
        elif self.link_relation.tail == self.head:
            return self.link_relation.head
        elif self.link_relation.head == self.tail:
            return self.link_relation.tail
        elif self.link_relation.tail == self.tail:
            return self.link_relation.head
        else:
            raise ValueError(
                "The semantically_same_relation is neither linked to head nor tail of the current relation"
            )


@dataclasses.dataclass
class TextDocumentWithLabeledMultiSpansBinaryRelationsLabeledPartitionsAndRelatedRelations(
    TextDocumentWithLabeledMultiSpansBinaryRelationsAndLabeledPartitions,
):
    related_relations: AnnotationLayer[RelatedRelation] = annotation_field(
        targets=["labeled_multi_spans", "binary_relations"]
    )
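RelatedRelation.reference_span resolves which end of the link relation is the reference span, i.e. the end that is not shared with the related relation itself. The same case analysis on plain strings, using a stand-in dataclass rather than the pytorch_ie type:

from dataclasses import dataclass

@dataclass(frozen=True)
class Rel:  # minimal stand-in for BinaryRelation, only for illustrating reference_span
    head: str
    tail: str

def reference_span(related_head: str, related_tail: str, link: Rel) -> str:
    # same case analysis as RelatedRelation.reference_span
    if link.head == related_head:
        return link.tail
    if link.tail == related_head:
        return link.head
    if link.head == related_tail:
        return link.tail
    if link.tail == related_tail:
        return link.head
    raise ValueError("link relation is connected to neither head nor tail")

# query span q is linked (semantically_same) to reference span r,
# and the related relation connects q to the candidate span c
print(reference_span("q", "c", Rel(head="q", tail="r")))  # -> "r"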
src/metrics/__init__.py
CHANGED

@@ -1,2 +1,3 @@
 from .coref_sklearn import CorefMetricsSKLearn
 from .coref_torchmetrics import CorefMetricsTorchmetrics
+from .score_distribution import ScoreDistribution
src/metrics/score_distribution.py
ADDED
@@ -0,0 +1,345 @@
from collections import defaultdict
from typing import Any, Dict, List, Optional, Tuple

import pandas as pd
from pytorch_ie import Document, DocumentMetric


class ScoreDistribution(DocumentMetric):
    """Computes the distribution of prediction scores for annotations in a layer. The scores are
    separated into true positives (TP) and false positives (FP) based on the gold annotations.

    Args:
        layer: The name of the annotation layer to analyze.
        per_label: If True, the scores are separated per label. Default is False.
        label_field: The field name of the label to use for separating the scores per label. Default is "label".
        equal_sample_size_binning: If True, the scores are binned into equal sample sizes. If False,
            the scores are binned into equal width. The former is useful when the distribution of scores is skewed.
            Default is True.
        show_plot: If True, a plot of the score distribution is shown. Default is False.
        plotting_backend: The plotting backend to use. Default is "plotly".
        plotting_caption_mapping: A mapping to rename any caption entries for plotting, i.e., the layer name,
            labels, or TP/FP. Default is None.
        plotting_colors: A dictionary mapping from gold scores to colors for plotting. Default is None.
    """

    def __init__(
        self,
        layer: str,
        label_field: str = "label",
        per_label: bool = False,
        show_plot: bool = False,
        equal_sample_size_binning: bool = True,
        plotting_backend: str = "plotly",
        plotting_caption_mapping: Optional[Dict[str, str]] = None,
        plotting_colors: Optional[Dict[str, str]] = None,
        plotly_use_create_distplot: bool = True,
        plotly_barmode: Optional[str] = None,
        plotly_marginal: Optional[str] = "violin",
        plotly_font_size: int = 18,
        plotly_font_family: Optional[str] = None,
        plotly_background_color: Optional[str] = None,
    ):
        super().__init__()
        self.layer = layer
        self.label_field = label_field
        self.per_label = per_label
        self.equal_sample_size_binning = equal_sample_size_binning
        self.plotting_backend = plotting_backend
        self.show_plot = show_plot
        self.plotting_caption_mapping = plotting_caption_mapping or {}
        self.plotting_colors = plotting_colors
        self.plotly_use_create_distplot = plotly_use_create_distplot
        self.plotly_barmode = plotly_barmode
        self.plotly_marginal = plotly_marginal
        self.plotly_font_size = plotly_font_size
        self.plotly_font_family = plotly_font_family
        self.plotly_background_color = plotly_background_color
        self.scores: Dict[str, Dict[str, List[float]]] = defaultdict(lambda: defaultdict(list))

    def reset(self):
        self.scores = defaultdict(lambda: defaultdict(list))

    def _update(self, document: Document):

        gold_annotations = set(document[self.layer])
        for ann in document[self.layer].predictions:
            if self.per_label:
                label = getattr(ann, self.label_field)
            else:
                label = "ALL"
            if ann in gold_annotations:
                self.scores[label]["TP"].append(ann.score)
            else:
                self.scores[label]["FP"].append(ann.score)

    def _combine_scores(
        self,
        scores_tp: List[float],
        score_fp: List[float],
        col_name_pred: str = "prediction",
        col_name_gold: str = "gold",
    ) -> pd.DataFrame:
        scores_tp_df = pd.DataFrame(scores_tp, columns=[col_name_pred])
        scores_tp_df[col_name_gold] = 1.0
        scores_fp_df = pd.DataFrame(score_fp, columns=[col_name_pred])
        scores_fp_df[col_name_gold] = 0.0
        scores_df = pd.concat([scores_tp_df, scores_fp_df])
        return scores_df

    def _get_calibration_data_and_metrics(
        self, scores: pd.DataFrame, q: int = 20
    ) -> Tuple[pd.DataFrame, pd.Series]:
        from sklearn.metrics import brier_score_loss

        if self.equal_sample_size_binning:
            # Create bins with equal number of samples.
            scores["bin"] = pd.qcut(scores["prediction"], q=q, labels=False)
        else:
            # Create bins with equal width.
            scores["bin"] = pd.cut(
                scores["prediction"],
                bins=q,
                include_lowest=True,
                right=True,
                labels=False,
            )

        calibration_data = (
            scores.groupby("bin")
            .apply(
                lambda x: pd.Series(
                    {
                        "avg_score": x["prediction"].mean(),
                        "fraction_positive": x["gold"].mean(),
                        "count": len(x),
                    }
                )
            )
            .reset_index()
        )

        total_count = scores.shape[0]
        calibration_data["bin_weight"] = calibration_data["count"] / total_count

        # Calculate the absolute differences and squared differences.
        calibration_data["abs_diff"] = abs(
            calibration_data["avg_score"] - calibration_data["fraction_positive"]
        )
        calibration_data["squared_diff"] = (
            calibration_data["avg_score"] - calibration_data["fraction_positive"]
        ) ** 2

        # Compute Expected Calibration Error (ECE): weighted average of absolute differences.
        ece = (calibration_data["abs_diff"] * calibration_data["bin_weight"]).sum()

        # Compute Maximum Calibration Error (MCE): maximum absolute difference.
        mce = calibration_data["abs_diff"].max()

        # Compute Mean Squared Error (MSE): weighted average of squared differences.
        mse = (calibration_data["squared_diff"] * calibration_data["bin_weight"]).sum()

        # Compute the Brier Score on the raw predictions.
        brier = brier_score_loss(scores["gold"], scores["prediction"])

        values = {
            "ece": ece,
            "mce": mce,
            "mse": mse,
            "brier": brier,
        }
        return calibration_data, pd.Series(values)

    def calculate_calibration_metrics(self, scores_combined: pd.DataFrame) -> pd.DataFrame:

        calibration_data_dict = {}
        calibration_metrics_dict = {}
        for label, current_scores in scores_combined.groupby("label"):
            calibration_data, calibration_metrics = self._get_calibration_data_and_metrics(
                current_scores, q=20
            )
            calibration_data_dict[label] = calibration_data
            calibration_metrics_dict[label] = calibration_metrics
        all_calibration_data = pd.concat(
            calibration_data_dict, names=["label", "idx"]
        ).reset_index(level=0)
        all_calibration_metrics = pd.concat(calibration_metrics_dict, axis=1).T

        if self.show_plot:
            self.plot_calibration_data(calibration_data=all_calibration_data)

        return all_calibration_metrics

    def calculate_correlation(self, scores: pd.DataFrame) -> pd.Series:
        result_dict = {}
        for label, current_scores in scores.groupby("label"):
            result_dict[label] = current_scores.drop("label", axis=1).corr()["prediction"]["gold"]

        return pd.Series(result_dict, name="correlation")

    @property
    def mapped_layer(self):
        return self.plotting_caption_mapping.get(self.layer, self.layer)

    def plot_score_distribution(self, scores: pd.DataFrame):
        if self.plotting_backend == "plotly":
            for label in scores["label"].unique():
                description = f"Distribution of Predicted Scores for {self.mapped_layer}"
                if self.per_label:
                    label_mapped = self.plotting_caption_mapping.get(label, label)
                    description += f" ({label_mapped})"
                if self.plotly_use_create_distplot:
                    import plotly.figure_factory as ff

                    current_scores = scores[scores["label"] == label]
                    # group by gold score
                    scores_dict = (
                        current_scores.groupby("gold")["prediction"].apply(list).to_dict()
                    )
                    group_labels, hist_data = zip(*scores_dict.items())
                    group_labels_renamed = [
                        self.plotting_caption_mapping.get(label, label) for label in group_labels
                    ]
                    if self.plotting_colors is not None:
                        colors = [
                            self.plotting_colors[group_label] for group_label in group_labels
                        ]
                    else:
                        colors = None
                    fig = ff.create_distplot(
                        hist_data,
                        group_labels=group_labels_renamed,
                        show_hist=True,
                        colors=colors,
                        bin_size=0.025,
                    )
                else:
                    import plotly.express as px

                    fig = px.histogram(
                        scores,
                        x="prediction",
                        color="gold",
                        marginal=self.plotly_marginal,  # "violin", # or box, violin, rug
                        hover_data=scores.columns,
                        color_discrete_map=self.plotting_colors,
                        nbins=50,
                    )

                fig.update_layout(
                    height=600,
                    width=800,
                    title_text=description,
                    title_x=0.5,
                    font=dict(size=self.plotly_font_size),
                    legend=dict(yanchor="top", y=0.99, xanchor="left", x=0.01),
                )
                if self.plotly_barmode is not None:
                    fig.update_layout(barmode=self.plotly_barmode)
                if self.plotly_font_family is not None:
                    fig.update_layout(font_family=self.plotly_font_family)
                if self.plotly_background_color is not None:
                    fig.update_layout(
                        plot_bgcolor=self.plotly_background_color,
                        paper_bgcolor=self.plotly_background_color,
                    )

                fig.show()
        else:
            raise NotImplementedError(f"Plotting backend {self.plotting_backend} not implemented")

    def plot_calibration_data(self, calibration_data: pd.DataFrame):
        import plotly.express as px
        import plotly.graph_objects as go

        color = "label" if self.per_label else None
        x_col = "avg_score"
        y_col = "fraction_positive"
        fig = px.scatter(
            calibration_data,
            x=x_col,
            y=y_col,
            color=color,
            trendline="ols",
            labels=self.plotting_caption_mapping,
        )
        if not self.per_label:
            fig["data"][1]["name"] = "prediction vs. gold"

        # show legend only for trendlines
        for idx, trace_data in enumerate(fig["data"]):
            if idx % 2 == 0:
                trace_data["showlegend"] = False
            else:
                trace_data["showlegend"] = True

        # add the optimal line
        minimum = calibration_data[x_col].min()
        maximum = calibration_data[x_col].max()
        fig.add_trace(
            go.Scatter(
                x=[minimum, maximum],
                y=[minimum, maximum],
                mode="lines",
                name="optimal",
                line=dict(color="black", dash="dash"),
            )
        )
        fig.update_layout(
            height=600,
            width=800,
            title_text=f"Mean Binned Scores for {self.mapped_layer}",
            title_x=0.5,
            font=dict(size=self.plotly_font_size),
        )
        fig.update_layout(
            legend=dict(
                yanchor="top",
                y=0.99,
                xanchor="left",
                x=0.01,
                title="OLS trendline" + ("s" if self.per_label else ""),
            ),
        )
        if self.plotly_background_color is not None:
            fig.update_layout(
                plot_bgcolor=self.plotly_background_color,
                paper_bgcolor=self.plotly_background_color,
            )

        if self.plotly_font_family is not None:
            fig.update_layout(font_family=self.plotly_font_family)

        fig.show()

    def _compute(self) -> Dict[str, Dict[str, Any]]:
        scores_combined = pd.concat(
            {
                label: self._combine_scores(scores["TP"], scores["FP"])
                for label, scores in self.scores.items()
            },
            names=["label", "idx"],
        ).reset_index(level=0)

        result_df = scores_combined.groupby("label")["prediction"].agg(["mean", "std", "count"])
        if self.show_plot:
            self.plot_score_distribution(scores=scores_combined)

        calibration_metrics = self.calculate_calibration_metrics(scores_combined)
        calibration_metrics["correlation"] = self.calculate_correlation(scores_combined)

        result_df = pd.concat(
            {"prediction": result_df, "prediction vs. gold": calibration_metrics}, axis=1
        )

        if not self.per_label:
            result = result_df.xs("ALL")
        else:
            result = result_df.T.stack().unstack()

        result_dict = {
            main_key: result.xs(main_key).T.to_dict()
            for main_key in result.index.get_level_values(0).unique()
        }

        return result_dict
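A minimal usage sketch for the new metric, assuming the usual pytorch_ie DocumentMetric interface (calling the metric on a document dispatches to _update, compute() aggregates via _compute) and that the new src/metrics/__init__.py re-exports the class; the layer name and the document list are placeholders:

from src.metrics import ScoreDistribution

metric = ScoreDistribution(layer="binary_relations", per_label=True)
for doc in docs_with_predictions:  # documents carrying both gold annotations and scored predictions
    metric(doc)
result = metric.compute()
# result contains, per label, the prediction score mean/std/count plus the calibration
# metrics computed above (ece, mce, mse, brier, correlation)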
src/models/__init__.py
CHANGED
@@ -1,3 +1,4 @@
+from .sequence_classification import SimpleSequenceClassificationModelWithInputTypeIds
 from .sequence_classification_with_pooler import (
     SequencePairSimilarityModelWithMaxCosineSim,
     SequencePairSimilarityModelWithPooler2,
src/models/sequence_classification.py
ADDED
@@ -0,0 +1,94 @@
from typing import Optional

from pie_modules.models import SimpleSequenceClassificationModel
from pie_modules.models.simple_sequence_classification import InputType, OutputType, TargetType
from pytorch_ie import PyTorchIEModel
from torch import nn
from transformers import BertModel
from transformers.utils import is_accelerate_available

if is_accelerate_available():
    from accelerate.hooks import add_hook_to_module


@PyTorchIEModel.register()
class SimpleSequenceClassificationModelWithInputTypeIds(SimpleSequenceClassificationModel):

    def __init__(
        self, num_token_type_ids: int, use_as_token_type_ids: str = "token_type_ids", **kwargs
    ):
        super().__init__(**kwargs)
        self.num_token_type_ids = num_token_type_ids
        self.token_type_ids_key = use_as_token_type_ids
        self.resize_type_embeddings(num_token_type_ids)

    def get_input_type_embeddings(self) -> nn.Module:
        base_model: BertModel = getattr(self.model, self.model.base_model_prefix)
        if base_model is None:
            raise ValueError("Model has no base model.")
        return base_model.embeddings.token_type_embeddings

    def set_input_type_embeddings(self, value):
        base_model: BertModel = getattr(self.model, self.model.base_model_prefix)
        if base_model is None:
            raise ValueError("Model has no base model.")
        base_model.embeddings.token_type_embeddings = value

    def _resize_type_embeddings(self, new_num_tokens, pad_to_multiple_of=None):
        old_embeddings = self.get_input_type_embeddings()
        new_embeddings = self.model._get_resized_embeddings(
            old_embeddings, new_num_tokens, pad_to_multiple_of
        )
        if hasattr(old_embeddings, "_hf_hook"):
            hook = old_embeddings._hf_hook
            add_hook_to_module(new_embeddings, hook)
        old_embeddings_requires_grad = old_embeddings.weight.requires_grad
        new_embeddings.requires_grad_(old_embeddings_requires_grad)
        self.set_input_type_embeddings(new_embeddings)

        return self.get_input_type_embeddings()

    def resize_type_embeddings(
        self, new_num_types: Optional[int] = None, pad_to_multiple_of: Optional[int] = None
    ) -> nn.Embedding:
        """
        Same as resize_token_embeddings but for the token type embeddings.

        Resizes input token type embeddings matrix of the model if `new_num_types != config.type_vocab_size`.

        Takes care of tying weights embeddings afterwards if the model class has a `tie_weights()` method.

        Arguments:
            new_num_types (`int`, *optional*):
                The number of new token types in the embedding matrix. Increasing the size will add newly initialized
                vectors at the end. Reducing the size will remove vectors from the end. If not provided or `None`, just
                returns a pointer to the input tokens `torch.nn.Embedding` module of the model without doing anything.
            pad_to_multiple_of (`int`, *optional*):
                If set will pad the embedding matrix to a multiple of the provided value.If `new_num_tokens` is set to
                `None` will just pad the embedding to a multiple of `pad_to_multiple_of`.

                This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability
                `>= 7.5` (Volta), or on TPUs which benefit from having sequence lengths be a multiple of 128. For more
                details about this, or help on choosing the correct value for resizing, refer to this guide:
                https://docs.nvidia.com/deeplearning/performance/dl-performance-matrix-multiplication/index.html#requirements-tc

        Return:
            `torch.nn.Embedding`: Pointer to the input tokens Embeddings Module of the model.
        """
        model_embeds = self._resize_type_embeddings(new_num_types, pad_to_multiple_of)
        if new_num_types is None and pad_to_multiple_of is None:
            return model_embeds

        # Update base model and current model config
        self.model.config.type_vocab_size = model_embeds.weight.shape[0]

        # Tie weights again if needed
        self.model.tie_weights()

        return model_embeds

    def forward(self, inputs: InputType, targets: Optional[TargetType] = None) -> OutputType:
        kwargs = {**inputs, **(targets or {})}
        # rename key to input_type_ids
        kwargs["token_type_ids"] = kwargs.pop(self.token_type_ids_key)
        return self.model(**kwargs)
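A hedged construction sketch for the new model class; model_name_or_path and num_classes are assumed to be the usual SimpleSequenceClassificationModel constructor arguments, the checkpoint name is only an example, and the "entity_tag_ids" input key is hypothetical (it has to match whatever extra input the taskmodule produces). The token-type count mirrors the sizing used in the train.py change below:

num_entity_types = 4
model = SimpleSequenceClassificationModelWithInputTypeIds(
    model_name_or_path="bert-base-cased",  # assumption: a BERT-style checkpoint with token type embeddings
    num_classes=5,                         # assumption: number of target labels
    num_token_type_ids=num_entity_types * 2 + 1 + 1,  # B-/I- per entity type + O + padding
    use_as_token_type_ids="entity_tag_ids",  # hypothetical extra input key produced by the taskmodule
)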
src/pipeline/ner_re_pipeline.py
CHANGED
@@ -15,6 +15,7 @@ from typing import (
     overload,
 )

+from pie_datasets import Dataset
 from pie_modules.utils import resolve_type
 from pytorch_ie import AutoPipeline, WithDocumentTypeMixin
 from pytorch_ie.core import Document
@@ -53,31 +54,105 @@ def move_annotations_to_predictions(doc: D, layer_names: List[str]) -> None:
         doc[layer_name].predictions.extend(annotations)


+def _add_annotations_from_other_document(
+    doc: D,
+    from_predictions: bool,
+    to_predictions: bool,
+    clear_before: bool,
+    other_doc: Optional[D] = None,
+    other_docs_dict: Optional[Dict[str, D]] = None,
+    layer_names: Optional[List[str]] = None,
+) -> D:
+    if other_doc is None:
+        if other_docs_dict is None:
+            raise ValueError("Either other_doc or other_docs_dict must be provided")
+        other_doc = other_docs_dict.get(doc.id)
+        if other_doc is None:
+            logger.warning(f"Document with ID {doc.id} not found in other_docs")
+            return doc
+
+    # copy to not modify the input
+    other_doc_copy = type(other_doc).fromdict(other_doc.asdict())
+
+    if layer_names is None:
+        layer_names = [field.name for field in doc.annotation_fields()]
+
+    for layer_name in layer_names:
+        layer = doc[layer_name]
+        if to_predictions:
+            layer = layer.predictions
+        if clear_before:
+            layer.clear()
+        other_layer = other_doc_copy[layer_name]
+        if from_predictions:
+            other_layer = other_layer.predictions
+        other_annotations = list(other_layer)
+        other_layer.clear()
+        layer.extend(other_annotations)
+
+    return doc
+
+
 def add_annotations_from_other_documents(
     docs: Iterable[D],
     other_docs: Sequence[Document],
+    get_other_doc_by_id: bool = False,
+    **kwargs,
+) -> Sequence[D]:
+    other_id2doc = None
+    if get_other_doc_by_id:
+        other_id2doc = {doc.id: doc for doc in other_docs}
+
+    if isinstance(docs, Dataset):
+        if other_id2doc is None:
+            raise ValueError("get_other_doc_by_id must be True when passing a Dataset")
+        result = docs.map(
+            _add_annotations_from_other_document,
+            fn_kwargs=dict(other_docs_dict=other_id2doc, **kwargs),
+        )
+    elif isinstance(docs, list):
+        result = []
+        for i, doc in enumerate(docs):
+            if other_id2doc is not None:
+                other_doc = other_id2doc.get(doc.id)
+                if other_doc is None:
+                    logger.warning(f"Document with ID {doc.id} not found in other_docs")
+                    continue
             else:
+                other_doc = other_docs[i]
+
+            # check if the IDs of the documents match
+            doc_id = getattr(doc, "id", None)
+            other_doc_id = getattr(other_doc, "id", None)
+            if doc_id is not None and doc_id != other_doc_id:
+                raise ValueError(
+                    f"IDs of the documents do not match: {doc_id} != {other_doc_id}"
+                )
+
+            current_result = _add_annotations_from_other_document(
+                doc, other_doc=other_doc, **kwargs
+            )
+            result.append(current_result)
+    else:
+        raise ValueError(f"Unsupported type: {type(docs)}")
+
+    return result
+
+
+DM = TypeVar("DM", bound=Dict[str, Iterable[Document]])
+
+
+def add_annotations_from_other_documents_dict(
+    docs: DM, other_docs: Dict[str, Sequence[Document]], **kwargs
+) -> DM:
+    if set(docs.keys()) != set(other_docs.keys()):
+        raise ValueError("Keys of the documents do not match")
+
+    result_dict = {
+        key: add_annotations_from_other_documents(doc_list, other_docs[key], **kwargs)
+        for key, doc_list in docs.items()
+    }
+    return type(docs)(result_dict)


 def process_pipeline_steps(
@@ -227,6 +302,9 @@ class NerRePipeline:
             "re_add_gold_data": partial(
                 add_annotations_from_other_documents,
                 other_docs=original_docs,
+                from_predictions=False,
+                to_predictions=False,
+                clear_before=False,
                 layer_names=[self.entity_layer, self.relation_layer],
                 **self.processor_kwargs.get("re_add_gold_data", {}),
             ),
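A minimal sketch of how the reworked helper can copy gold annotations back into pipeline outputs, matching documents by ID; predicted_docs, original_docs and the layer names are placeholders (layer_names defaults to all annotation fields when omitted):

docs_with_gold = add_annotations_from_other_documents(
    docs=predicted_docs,        # a list (or pie_datasets Dataset) of documents returned by the pipeline
    other_docs=original_docs,   # the documents that still carry the gold annotations
    get_other_doc_by_id=True,   # required when docs is a Dataset
    from_predictions=False,
    to_predictions=False,
    clear_before=False,
    layer_names=["labeled_spans", "binary_relations"],  # assumption: typical PIE layer names
)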
src/predict.py
CHANGED
@@ -34,14 +34,13 @@ root = pyrootutils.setup_root(
 # ------------------------------------------------------------------------------------ #

 import os
-import timeit
 from collections.abc import Iterable, Sequence
 from typing import Any, Dict, Optional, Tuple, Union

 import hydra
 import pytorch_lightning as pl
 from omegaconf import DictConfig, OmegaConf
-from pie_datasets import
+from pie_datasets import DatasetDict
 from pie_modules.models import *  # noqa: F403
 from pie_modules.taskmodules import *  # noqa: F403
 from pytorch_ie import Document, Pipeline
@@ -132,38 +131,13 @@ def predict(cfg: DictConfig) -> Tuple[dict, dict]:
         "pipeline": pipeline,
         "serializer": serializer,
     }
+    # predict and serialize
+    result: Dict[str, Any] = utils.predict_and_serialize(
+        pipeline=pipeline,
+        serializer=serializer,
+        dataset=dataset_predict,
+        document_batch_size=cfg.get("document_batch_size", None),
+    )
-    document_batch_size = cfg.get("document_batch_size", None)
-    for docs_batch in (
-        document_batch_iter(dataset_predict, document_batch_size)
-        if document_batch_size
-        else [dataset_predict]
-    ):
-        if pipeline is not None:
-            t_start = timeit.default_timer()
-            docs_batch = pipeline(docs_batch, inplace=False)
-            prediction_time += timeit.default_timer() - t_start  # type: ignore
-
-        # serialize the documents
-        if serializer is not None:
-            # the serializer should not return the serialized documents, but write them to disk
-            # and instead return some metadata such as the path to the serialized documents
-            serializer_result = serializer(docs_batch)
-            if "serializer" in result and result["serializer"] != serializer_result:
-                log.warning(
-                    f"serializer result changed from {result['serializer']} to {serializer_result}"
-                    " during prediction. Only the last result is returned."
-                )
-            result["serializer"] = serializer_result
-
-    if prediction_time is not None:
-        result["prediction_time"] = prediction_time

     # serialize config with resolved paths
     if cfg.get("config_out_path"):
src/serializer/interface.py
CHANGED
@@ -1,5 +1,5 @@
 from abc import ABC, abstractmethod
-from typing import Any,
+from typing import Any, Iterable

 from pytorch_ie.core import Document

@@ -12,5 +12,5 @@ class DocumentSerializer(ABC):
     """

     @abstractmethod
-    def __call__(self, documents:
+    def __call__(self, documents: Iterable[Document]) -> Any:
         pass
src/serializer/json.py
CHANGED
@@ -1,6 +1,6 @@
 import json
 import os
-from typing import Dict, List, Optional, Sequence, Type, TypeVar
+from typing import Dict, Iterable, List, Optional, Sequence, Type, TypeVar

 from pie_datasets import Dataset, DatasetDict, IterableDataset
 from pie_datasets.core.dataset_dict import METADATA_FILE_NAME
@@ -8,7 +8,7 @@ from pytorch_ie.core import Document
 from pytorch_ie.utils.hydra import resolve_optional_document_type, serialize_document_type

 from src.serializer.interface import DocumentSerializer
-from src.utils import get_pylogger
+from src.utils.logging_utils import get_pylogger

 log = get_pylogger(__name__)

@@ -31,7 +31,7 @@ class JsonSerializer(DocumentSerializer):
     @classmethod
     def write(
         cls,
-        documents:
+        documents: Iterable[Document],
         path: str,
         file_name: str = "documents.jsonl",
         metadata_file_name: str = METADATA_FILE_NAME,
@@ -42,6 +42,9 @@ class JsonSerializer(DocumentSerializer):
         log.info(f'serialize documents to "{realpath}" ...')
         os.makedirs(realpath, exist_ok=True)

+        if not isinstance(documents, Sequence):
+            documents = list(documents)
+
         # dump metadata including the document_type
         if len(documents) == 0:
             raise Exception("cannot serialize empty list of documents")
@@ -130,7 +133,7 @@ class JsonSerializer(DocumentSerializer):
         all_kwargs = {**self.default_kwargs, **kwargs}
         return self.write(**all_kwargs)

-    def __call__(self, documents:
+    def __call__(self, documents: Iterable[Document], **kwargs) -> Dict[str, str]:
         return self.write_with_defaults(documents=documents, **kwargs)


@@ -141,12 +144,15 @@ class JsonSerializer2(DocumentSerializer):
     @classmethod
     def write(
         cls,
-        documents:
+        documents: Iterable[Document],
         path: str,
         split: str = "train",
     ) -> Dict[str, str]:
         if not isinstance(documents, (Dataset, IterableDataset)):
+            if not isinstance(documents, Sequence):
+                documents = IterableDataset.from_documents(documents)
+            else:
+                documents = Dataset.from_documents(documents)
         dataset_dict = DatasetDict({split: documents})
         dataset_dict.to_json(path=path)
         return {"path": path, "split": split}
@@ -175,5 +181,5 @@ class JsonSerializer2(DocumentSerializer):
         all_kwargs = {**self.default_kwargs, **kwargs}
         return self.write(**all_kwargs)

-    def __call__(self, documents:
+    def __call__(self, documents: Iterable[Document], **kwargs) -> Dict[str, str]:
         return self.write_with_defaults(documents=documents, **kwargs)
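A small usage sketch for the updated serializer, assuming its constructor simply collects default kwargs such as path (they are merged into write() via write_with_defaults); the output directory and predicted_docs are placeholders:

serializer = JsonSerializer(path="predictions/dev")   # assumption: path ends up in default_kwargs
metadata = serializer(doc for doc in predicted_docs)  # generators work now: non-Sequence iterables are materialized
# metadata is a Dict[str, str] describing where the documents were written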
src/start_demo.py
CHANGED
@@ -99,6 +99,7 @@ def main(cfg: DictConfig) -> None:
     render_caption2mode = {v: k for k, v in render_mode2caption.items()}
     default_min_similarity = cfg["default_min_similarity"]
     default_top_k = cfg["default_top_k"]
+    default_min_score = cfg["default_min_score"]
     layer_caption_mapping = cfg["layer_caption_mapping"]
     relation_name_mapping = cfg["relation_name_mapping"]

@@ -287,6 +288,13 @@ def main(cfg: DictConfig) -> None:
                 step=1,
                 value=default_top_k,
             )
+            min_score = gr.Slider(
+                label="Minimum Score",
+                minimum=0.0,
+                maximum=1.0,
+                step=0.01,
+                value=default_min_score,
+            )
             retrieve_similar_adus_btn = gr.Button(
                 "Retrieve *similar* ADUs for *selected* ADU"
             )
@@ -361,18 +369,23 @@ def main(cfg: DictConfig) -> None:
         load_pie_dataset_btn = gr.Button("Load & Embed PIE Dataset")

     render_event_kwargs = dict(
-        fn=lambda _retriever, _document_id, _render_as, _render_kwargs, _all_relevant_adus_df, _all_relevant_adus_query_doc_id:
+        fn=lambda _rendered_output, _retriever, _document_id, _render_as, _render_kwargs, _all_relevant_adus_df, _all_relevant_adus_query_doc_id: (
+            render_annotated_document(
+                retriever=_retriever[0],
+                document_id=_document_id,
+                render_with=render_caption2mode[_render_as],
+                render_kwargs_json=_render_kwargs,
+                highlight_span_ids=(
+                    _all_relevant_adus_df["query_span_id"].tolist()
+                    if _document_id == _all_relevant_adus_query_doc_id
+                    else None
+                ),
+            )
+            if _document_id.strip() != ""
+            else _rendered_output
         ),
         inputs=[
+            rendered_output,
             retriever_state,
             selected_document_id,
             render_as,
@@ -583,10 +596,11 @@ def main(cfg: DictConfig) -> None:
     ).success(**show_stats_kwargs)

     retrieve_relevant_adus_event_kwargs = dict(
-        fn=lambda _retriever, _selected_adu_id, _min_similarity, _top_k: retrieve_relevant_spans(
+        fn=lambda _retriever, _selected_adu_id, _min_similarity, _top_k, _min_score: retrieve_relevant_spans(
            retriever=_retriever[0],
            query_span_id=_selected_adu_id,
            k=_top_k,
+           min_score=_min_score,
            score_threshold=_min_similarity,
            relation_label_mapping=relation_name_mapping,
            # columns=relevant_adus.headers
@@ -596,6 +610,7 @@ def main(cfg: DictConfig) -> None:
            selected_adu_id,
            min_similarity,
            top_k,
+           min_score,
        ],
        outputs=[relevant_adus_df],
    )
@@ -614,10 +629,11 @@ def main(cfg: DictConfig) -> None:
    ).success(**retrieve_relevant_adus_event_kwargs)

    retrieve_similar_adus_btn.click(
-       fn=lambda _retriever, _selected_adu_id, _min_similarity, _tok_k: retrieve_similar_spans(
+       fn=lambda _retriever, _selected_adu_id, _min_similarity, _tok_k, _min_score: retrieve_similar_spans(
           retriever=_retriever[0],
           query_span_id=_selected_adu_id,
           k=_tok_k,
+          min_score=_min_score,
          score_threshold=_min_similarity,
       ),
       inputs=[
@@ -625,6 +641,7 @@ def main(cfg: DictConfig) -> None:
          selected_adu_id,
          min_similarity,
          top_k,
+         min_score,
      ],
      outputs=[similar_adus_df],
  )
@@ -635,10 +652,11 @@ def main(cfg: DictConfig) -> None:
  )

  retrieve_all_similar_adus_btn.click(
-     fn=lambda _retriever, _document_id, _min_similarity, _tok_k: retrieve_all_similar_spans(
+     fn=lambda _retriever, _document_id, _min_similarity, _tok_k, _min_score: retrieve_all_similar_spans(
          retriever=_retriever[0],
          query_doc_id=_document_id,
          k=_tok_k,
+         min_score=_min_score,
          score_threshold=_min_similarity,
          query_span_id_column="query_span_id",
      ),
@@ -647,16 +665,18 @@ def main(cfg: DictConfig) -> None:
          selected_document_id,
          min_similarity,
          top_k,
+         min_score,
      ],
      outputs=[all_similar_adus_df],
  )

  retrieve_all_relevant_adus_btn.click(
-     fn=lambda _retriever, _document_id, _min_similarity, _tok_k: (
+     fn=lambda _retriever, _document_id, _min_similarity, _tok_k, _min_score: (
          retrieve_all_relevant_spans(
              retriever=_retriever[0],
              query_doc_id=_document_id,
              k=_tok_k,
+             min_score=_min_score,
              score_threshold=_min_similarity,
              query_span_id_column="query_span_id",
              query_span_text_column="query_span_text",
@@ -668,6 +688,7 @@ def main(cfg: DictConfig) -> None:
          selected_document_id,
          min_similarity,
          top_k,
+         min_score,
      ],
      outputs=[all_relevant_adus_df, all_relevant_adus_query_doc_id],
  )
src/taskmodules/cross_text_binary_coref_nli.py
CHANGED
@@ -62,6 +62,9 @@ class CrossTextBinaryCorefTaskModuleByNli(RelationStatisticsMixin, TaskModuleTyp
         tokenizer_name_or_path: str,
         labels: List[str],
         entailment_label: str,
+        combine_score_method: str = "average",
+        keep_all_relations: bool = False,
+        as_text_pair: bool = True,
         **kwargs,
     ) -> None:
         super().__init__(**kwargs)
@@ -69,6 +72,9 @@ class CrossTextBinaryCorefTaskModuleByNli(RelationStatisticsMixin, TaskModuleTyp

         self.labels = labels
         self.entailment_label = entailment_label
+        self.combine_score_method = combine_score_method
+        self.keep_all_relations = keep_all_relations
+        self.as_text_pair = as_text_pair
         self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name_or_path)

     def _post_prepare(self):
@@ -118,9 +124,18 @@ class CrossTextBinaryCorefTaskModuleByNli(RelationStatisticsMixin, TaskModuleTyp
         for task_encoding in task_encodings:
             all_texts.extend(task_encoding.inputs["text"])
             all_texts_pair.extend(task_encoding.inputs["text_pair"])
+        if self.as_text_pair:
+            text = all_texts
+            text_pair = all_texts_pair
+        else:
+            text = [
+                f"{text}{self.tokenizer.sep_token}{text_pair}"
+                for text, text_pair in zip(all_texts, all_texts_pair)
+            ]
+            text_pair = None
         inputs = self.tokenizer(
-            text=
-            text_pair=
+            text=text,
+            text_pair=text_pair,
             truncation=True,
             padding=True,
             return_tensors="pt",
@@ -159,8 +174,20 @@ class CrossTextBinaryCorefTaskModuleByNli(RelationStatisticsMixin, TaskModuleTyp
         task_encoding: TaskEncoding[DocumentType, InputEncodingType, TargetEncodingType],
         task_output: TaskOutputType,
     ) -> Iterator[Tuple[str, Annotation]]:
-        if
+        if (
+            all(label == self.entailment_label for label in task_output["label_pair"])
+            or self.keep_all_relations
+        ):
             probs = task_output["entailment_probability_pair"]
+            if self.combine_score_method == "average":
+                score = (probs[0] + probs[1]) / 2
+            elif self.combine_score_method == "min":
+                score = min(probs)
+            elif self.combine_score_method == "max":
+                score = max(probs)
+            elif self.combine_score_method == "product":
+                score = probs[0] * probs[1]
+            else:
+                raise ValueError(f"Unsupported combine_score_method: {self.combine_score_method}")
             new_coref_rel = task_encoding.metadata["candidate_annotation"].copy(score=score)
             yield "binary_coref_relations", new_coref_rel
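The four combine_score_method variants reduce the pair of entailment probabilities (one per text order) to a single relation score; a quick, self-contained illustration of the arithmetic:

probs = (0.75, 0.5)                        # entailment probability for each direction
assert (probs[0] + probs[1]) / 2 == 0.625  # "average" (the default)
assert min(probs) == 0.5                   # "min"
assert max(probs) == 0.75                  # "max"
assert probs[0] * probs[1] == 0.375        # "product"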
src/train.py
CHANGED
@@ -38,13 +38,14 @@ from typing import Any, Dict, List, Optional, Tuple

 import hydra
 import pytorch_lightning as pl
-from omegaconf import DictConfig
+from omegaconf import DictConfig, OmegaConf
 from pie_datasets import DatasetDict
 from pie_modules.models import *  # noqa: F403
 from pie_modules.models import SimpleGenerativeModel
 from pie_modules.models.interface import RequiresTaskmoduleConfig
 from pie_modules.taskmodules import *  # noqa: F403
 from pie_modules.taskmodules import PointerNetworkTaskModuleForEnd2EndRE
+from pytorch_ie import Pipeline
 from pytorch_ie.core import PyTorchIEModel, TaskModule
 from pytorch_ie.models import *  # noqa: F403
 from pytorch_ie.models.interface import RequiresModelNameOrPath, RequiresNumClasses
@@ -56,6 +57,7 @@ from pytorch_lightning.loggers import Logger
 from src import utils
 from src.datamodules import PieDataModule
 from src.models import *  # noqa: F403
+from src.serializer.interface import DocumentSerializer
 from src.taskmodules import *  # noqa: F403

 log = utils.get_pylogger(__name__)
@@ -81,6 +83,27 @@ def get_metric_value(metric_dict: dict, metric_name: str) -> Optional[float]:
     return metric_value


+def flatten_nested_dict(d: Dict[str, Any], parent_key: str = "", sep: str = ".") -> Dict[str, Any]:
+    """Flatten a nested dictionary.
+
+    Args:
+        d (Dict[str, Any]): The dictionary to flatten.
+        parent_key (str): The parent key.
+        sep (str): The separator.
+
+    Returns:
+        Dict[str, Any]: The flattened dictionary.
+    """
+    items: List[Tuple[str, Any]] = []
+    for k, v in d.items():
+        new_key = f"{parent_key}{sep}{k}" if parent_key else k
+        if isinstance(v, dict):
+            items.extend(flatten_nested_dict(v, new_key, sep=sep).items())
+        else:
+            items.append((new_key, v))
+    return dict(items)
+
+
 @utils.task_wrapper
 def train(cfg: DictConfig) -> Tuple[dict, dict]:
     """Trains the model. Can additionally evaluate on a testset, using best weights obtained during
@@ -179,6 +202,11 @@ def train(cfg: DictConfig) -> Tuple[dict, dict]:
         )
         additional_model_kwargs["base_model_config"] = base_model_config

+    if issubclass(model_cls, SimpleSequenceClassificationModelWithInputTypeIds):  # noqa: F405
+        # add the number of input type ids to the model:
+        # 2 for B- and I-labels for each entity type, 1 for O labels, 1 for padding
+        additional_model_kwargs["num_token_type_ids"] = len(taskmodule.entity_labels) * 2 + 1 + 1
+
     # initialize the model
     model: PyTorchIEModel = hydra.utils.instantiate(
         cfg.model, _convert_="partial", **additional_model_kwargs
@@ -207,9 +235,11 @@ def train(cfg: DictConfig) -> Tuple[dict, dict]:
     log.info("Logging hyperparameters!")
     utils.log_hyperparameters(logger=logger, model=model, taskmodule=taskmodule, config=cfg)

-    if cfg.model_save_dir is not None:
-        log.info(f"Save taskmodule to {cfg.model_save_dir} [push_to_hub={cfg.push_to_hub}]")
-        taskmodule.save_pretrained(
+    if cfg.paths.model_save_dir is not None:
+        log.info(f"Save taskmodule to {cfg.paths.model_save_dir} [push_to_hub={cfg.push_to_hub}]")
+        taskmodule.save_pretrained(
+            save_directory=cfg.paths.model_save_dir, push_to_hub=cfg.push_to_hub
+        )
     else:
         log.warning("the taskmodule is not saved because no save_dir is specified")

@@ -238,15 +268,17 @@ def train(cfg: DictConfig) -> Tuple[dict, dict]:
             f"Expected format: " + '"epoch_{best_epoch}.ckpt"'
         )

-    if not cfg.trainer.get("fast_dev_run"):
-        if cfg.model_save_dir is not None:
+    if not cfg.trainer.get("fast_dev_run") or cfg.get("predict", False):
+        if cfg.paths.model_save_dir is not None:
             if best_ckpt_path == "":
                 log.warning("Best ckpt not found! Using current weights for saving...")
             else:
                 model = type(model).load_from_checkpoint(best_ckpt_path)

-            log.info(f"Save model to {cfg.model_save_dir} [push_to_hub={cfg.push_to_hub}]")
-            model.save_pretrained(
+            log.info(f"Save model to {cfg.paths.model_save_dir} [push_to_hub={cfg.push_to_hub}]")
+            model.save_pretrained(
+                save_directory=cfg.paths.model_save_dir, push_to_hub=cfg.push_to_hub
+            )
         else:
             log.warning("the model is not saved because no save_dir is specified")

@@ -275,8 +307,36 @@ def train(cfg: DictConfig) -> Tuple[dict, dict]:

     # add model_save_dir to the result so that it gets dumped to job_return_value.json
     # if we use hydra_callbacks.SaveJobReturnValueCallback
-    if cfg.get("model_save_dir") is not None:
-        metric_dict["model_save_dir"] = cfg.model_save_dir
+    if cfg.paths.get("model_save_dir") is not None:
+        metric_dict["model_save_dir"] = cfg.paths.model_save_dir
+
+    if cfg.get("predict"):
+        # Init the inference pipeline
+        pipeline: Optional[Pipeline] = None
+        if cfg.get("pipeline") and cfg.pipeline.get("_target_"):
+            log.info(f"Instantiating inference pipeline <{cfg.pipeline._target_}>")
+            pipeline = hydra.utils.instantiate(cfg.pipeline, _convert_="partial")
+        # Init the serializer
+        serializer: Optional[DocumentSerializer] = None
+        if cfg.get("serializer") and cfg.serializer.get("_target_"):
+            log.info(f"Instantiating serializer <{cfg.serializer._target_}>")
+            serializer = hydra.utils.instantiate(cfg.serializer, _convert_="partial")
+        # predict and serialize
+        predict_metrics: Dict[str, Any] = utils.predict_and_serialize(
+            pipeline=pipeline,
+            serializer=serializer,
+            dataset=dataset[cfg.dataset_split],
+            document_batch_size=cfg.get("document_batch_size", None),
+        )
+        # flatten the predict_metrics dict
+        predict_metrics_flat = flatten_nested_dict(predict_metrics, sep="/")
+        metric_dict.update(predict_metrics_flat)
+
+    if cfg.get("delete_model_dir"):
+        import shutil
+
+        log.info(f"Deleting model directory {cfg.paths.model_save_dir}")
+        shutil.rmtree(cfg.paths.model_save_dir)

     return metric_dict, object_dict

@@ -301,4 +361,5 @@ def main(cfg: DictConfig) -> Optional[float]:
 if __name__ == "__main__":
     utils.replace_sys_args_with_values_from_files()
     utils.prepare_omegaconf()
+    OmegaConf.register_new_resolver("eval", eval)
     main()
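How flatten_nested_dict turns the nested prediction results into flat metric keys before they are merged into metric_dict (sep="/" as used above); the values are illustrative only:

nested = {"serializer": {"path": "serialized/predictions"}, "prediction_time": 12.3}
assert flatten_nested_dict(nested, sep="/") == {
    "serializer/path": "serialized/predictions",
    "prediction_time": 12.3,
}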
src/utils/__init__.py
CHANGED
@@ -5,7 +5,8 @@ from .config_utils import (
     prepare_omegaconf,
 )
 from .data_utils import download_and_unzip, filter_dataframe_and_get_column
+from .inference_utils import predict_and_serialize
 from .logging_utils import close_loggers, get_pylogger, log_hyperparameters
 from .rich_utils import enforce_tags, print_config_tree
-from .span_utils import distance
+from .span_utils import distance, distance_slices
 from .task_utils import extras, replace_sys_args_with_values_from_files, save_file, task_wrapper
src/utils/inference_utils.py
ADDED
@@ -0,0 +1,74 @@
import timeit
from collections.abc import Iterable, Sequence
from typing import Any, Dict, Optional, Union

from pytorch_ie import Document, Pipeline

from src.serializer.interface import DocumentSerializer

from .logging_utils import get_pylogger

log = get_pylogger(__name__)


def document_batch_iter(
    dataset: Iterable[Document], batch_size: int
) -> Iterable[Sequence[Document]]:
    if isinstance(dataset, Sequence):
        for i in range(0, len(dataset), batch_size):
            yield dataset[i : i + batch_size]
    elif isinstance(dataset, Iterable):
        docs = []
        for doc in dataset:
            docs.append(doc)
            if len(docs) == batch_size:
                yield docs
                docs = []
        if docs:
            yield docs
    else:
        raise ValueError(f"Unsupported dataset type: {type(dataset)}")


def predict_and_serialize(
    pipeline: Optional[Pipeline],
    serializer: Optional[DocumentSerializer],
    dataset: Iterable[Document],
    document_batch_size: Optional[int] = None,
) -> Dict[str, Any]:
    result: Dict[str, Any] = {}
    if pipeline is not None:
        log.info("Starting inference!")
        prediction_time = 0.0
    else:
        log.warning("No prediction pipeline is defined, skip inference!")
        prediction_time = None
    docs_batch: Union[Iterable[Document], Sequence[Document]]

    batch_iter: Union[Sequence[Iterable[Document]], Iterable[Sequence[Document]]]
    if document_batch_size is None:
        batch_iter = [dataset]
    else:
        batch_iter = document_batch_iter(dataset=dataset, batch_size=document_batch_size)
    for docs_batch in batch_iter:
        if pipeline is not None:
            t_start = timeit.default_timer()
            docs_batch = pipeline(docs_batch, inplace=False)
            prediction_time += timeit.default_timer() - t_start  # type: ignore

        # serialize the documents
        if serializer is not None:
            # the serializer should not return the serialized documents, but write them to disk
            # and instead return some metadata such as the path to the serialized documents
            serializer_result = serializer(docs_batch)
            if "serializer" in result and result["serializer"] != serializer_result:
                log.warning(
                    f"serializer result changed from {result['serializer']} to {serializer_result}"
                    " during prediction. Only the last result is returned."
                )
            result["serializer"] = serializer_result

    if prediction_time is not None:
        result["prediction_time"] = prediction_time

    return result
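A minimal sketch of calling the extracted helper directly, e.g. from a notebook; pipeline, serializer and documents are placeholders set up as in predict.py, and batching is optional:

result = predict_and_serialize(
    pipeline=pipeline,          # a pytorch_ie Pipeline, or None to skip inference
    serializer=serializer,      # a DocumentSerializer, or None to skip writing
    dataset=documents,          # any Sequence or Iterable of Documents
    document_batch_size=8,      # None processes everything in one go
)
# result may contain "prediction_time" and the last "serializer" return value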
src/utils/span_utils.py
CHANGED
@@ -58,3 +58,17 @@ def distance(
     raise ValueError(
         f"unknown distance_type={distance_type}. use one of: center, inner, outer"
     )
+
+
+def distance_slices(
+    slices: Tuple[Tuple[int, int], ...],
+    other_slices: Tuple[Tuple[int, int], ...],
+    distance_type: str,
+) -> float:
+    starts, ends = zip(*slices)
+    other_starts, other_ends = zip(*other_slices)
+    return distance(
+        start_end=(min(starts), max(ends)),
+        other_start_end=(min(other_starts), max(other_ends)),
+        distance_type=distance_type,
+    )
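distance_slices reduces two multi-slice spans to their overall (min start, max end) extents and delegates to distance(); a small sketch using the "outer" distance type listed in the error message above:

gap = distance_slices(
    slices=((0, 5), (10, 15)),   # merged extent: (0, 15)
    other_slices=((20, 25),),    # merged extent: (20, 25)
    distance_type="outer",
)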