Spaces:
Runtime error
Runtime error
chicham
commited on
Modify the way the results a shown (#4)
Browse files
app.py
CHANGED
|
@@ -2,8 +2,6 @@
|
|
| 2 |
from __future__ import annotations
|
| 3 |
|
| 4 |
import functools
|
| 5 |
-
from collections import defaultdict
|
| 6 |
-
from itertools import chain
|
| 7 |
from typing import Any
|
| 8 |
from typing import Callable
|
| 9 |
from typing import Mapping
|
|
@@ -13,7 +11,6 @@ import attr
|
|
| 13 |
import environ
|
| 14 |
import fasttext # not working with python3.9
|
| 15 |
import gradio as gr
|
| 16 |
-
from tokenizers.pre_tokenizers import Whitespace
|
| 17 |
from transformers.pipelines import pipeline
|
| 18 |
from transformers.pipelines.base import Pipeline
|
| 19 |
from transformers.pipelines.token_classification import AggregationStrategy
|
|
@@ -127,8 +124,8 @@ def predict(
|
|
| 127 |
supported_languages: tuple[str, ...] = ("fr", "de"),
|
| 128 |
) -> tuple[
|
| 129 |
Mapping[str, float],
|
| 130 |
-
str,
|
| 131 |
Mapping[str, float],
|
|
|
|
| 132 |
Sequence[tuple[str, str | None]],
|
| 133 |
Sequence[tuple[str, str | None]],
|
| 134 |
]:
|
|
@@ -189,27 +186,23 @@ def predict(
|
|
| 189 |
predict_fn: Callable,
|
| 190 |
query: str,
|
| 191 |
) -> Sequence[tuple[str, str | None]]:
|
| 192 |
-
def get_entity(pred: Mapping[str, str]):
|
| 193 |
-
return pred.get("entity", pred.get("entity_group", None))
|
| 194 |
|
| 195 |
-
|
| 196 |
-
mapping.update(**{pred["word"]: get_entity(pred) for pred in predict_fn(query)})
|
| 197 |
|
| 198 |
-
|
| 199 |
-
|
| 200 |
-
|
| 201 |
-
|
| 202 |
-
|
| 203 |
-
|
| 204 |
-
|
| 205 |
-
return res
|
| 206 |
|
| 207 |
languages = predict_lang(query)
|
| 208 |
translation = translate_query(query, languages)
|
| 209 |
classifications = classify_query(translation, categories)
|
| 210 |
general_entities = extract_entities(models.ner, query)
|
| 211 |
recipe_entities = extract_entities(models.recipe, translation)
|
| 212 |
-
return languages,
|
| 213 |
|
| 214 |
|
| 215 |
def main():
|
|
@@ -254,7 +247,7 @@ def main():
|
|
| 254 |
load_fn=lambda: pipeline(
|
| 255 |
"ner",
|
| 256 |
model=cfg.ner.general,
|
| 257 |
-
aggregation_strategy=AggregationStrategy.
|
| 258 |
),
|
| 259 |
),
|
| 260 |
recipe=Predictor(
|
|
@@ -282,15 +275,15 @@ def main():
|
|
| 282 |
type="auto",
|
| 283 |
label="Language identification",
|
| 284 |
),
|
| 285 |
-
gr.outputs.Textbox(
|
| 286 |
-
label="English query",
|
| 287 |
-
type="auto",
|
| 288 |
-
),
|
| 289 |
gr.outputs.Label(
|
| 290 |
num_top_classes=cfg.classification.max_results,
|
| 291 |
type="auto",
|
| 292 |
label="Predicted categories",
|
| 293 |
),
|
|
|
|
|
|
|
|
|
|
|
|
|
| 294 |
gr.outputs.HighlightedText(label="NER generic"),
|
| 295 |
gr.outputs.HighlightedText(label="NER Recipes"),
|
| 296 |
],
|
|
|
|
| 2 |
from __future__ import annotations
|
| 3 |
|
| 4 |
import functools
|
|
|
|
|
|
|
| 5 |
from typing import Any
|
| 6 |
from typing import Callable
|
| 7 |
from typing import Mapping
|
|
|
|
| 11 |
import environ
|
| 12 |
import fasttext # not working with python3.9
|
| 13 |
import gradio as gr
|
|
|
|
| 14 |
from transformers.pipelines import pipeline
|
| 15 |
from transformers.pipelines.base import Pipeline
|
| 16 |
from transformers.pipelines.token_classification import AggregationStrategy
|
|
|
|
| 124 |
supported_languages: tuple[str, ...] = ("fr", "de"),
|
| 125 |
) -> tuple[
|
| 126 |
Mapping[str, float],
|
|
|
|
| 127 |
Mapping[str, float],
|
| 128 |
+
str,
|
| 129 |
Sequence[tuple[str, str | None]],
|
| 130 |
Sequence[tuple[str, str | None]],
|
| 131 |
]:
|
|
|
|
| 186 |
predict_fn: Callable,
|
| 187 |
query: str,
|
| 188 |
) -> Sequence[tuple[str, str | None]]:
|
|
|
|
|
|
|
| 189 |
|
| 190 |
+
predictions = predict_fn(query)
|
|
|
|
| 191 |
|
| 192 |
+
if len(predictions) == 0:
|
| 193 |
+
return [(query, None)]
|
| 194 |
+
else:
|
| 195 |
+
return [
|
| 196 |
+
(pred["word"], pred.get("entity_group", pred.get("entity", None)))
|
| 197 |
+
for pred in predictions
|
| 198 |
+
]
|
|
|
|
| 199 |
|
| 200 |
languages = predict_lang(query)
|
| 201 |
translation = translate_query(query, languages)
|
| 202 |
classifications = classify_query(translation, categories)
|
| 203 |
general_entities = extract_entities(models.ner, query)
|
| 204 |
recipe_entities = extract_entities(models.recipe, translation)
|
| 205 |
+
return languages, classifications, translation, general_entities, recipe_entities
|
| 206 |
|
| 207 |
|
| 208 |
def main():
|
|
|
|
| 247 |
load_fn=lambda: pipeline(
|
| 248 |
"ner",
|
| 249 |
model=cfg.ner.general,
|
| 250 |
+
aggregation_strategy=AggregationStrategy.SIMPLE,
|
| 251 |
),
|
| 252 |
),
|
| 253 |
recipe=Predictor(
|
|
|
|
| 275 |
type="auto",
|
| 276 |
label="Language identification",
|
| 277 |
),
|
|
|
|
|
|
|
|
|
|
|
|
|
| 278 |
gr.outputs.Label(
|
| 279 |
num_top_classes=cfg.classification.max_results,
|
| 280 |
type="auto",
|
| 281 |
label="Predicted categories",
|
| 282 |
),
|
| 283 |
+
gr.outputs.Textbox(
|
| 284 |
+
label="English query",
|
| 285 |
+
type="auto",
|
| 286 |
+
),
|
| 287 |
gr.outputs.HighlightedText(label="NER generic"),
|
| 288 |
gr.outputs.HighlightedText(label="NER Recipes"),
|
| 289 |
],
|