Spaces:
Runtime error
Runtime error
Commit
·
dc7ce01
1
Parent(s):
6ef6e5f
nlg models removal
Browse files
app.py
CHANGED
@@ -1,14 +1,13 @@
|
|
1 |
import gradio as gr
|
2 |
|
3 |
-
from models import DST_MODELS,
|
4 |
|
5 |
|
6 |
def predict(text: str, model_name: str) -> str:
|
7 |
return PIPELINES[model_name](text)
|
8 |
|
9 |
|
10 |
-
with gr.Blocks(title="CLARIN-PL
|
11 |
-
gr.Markdown("Dialogue State Tracking Modules")
|
12 |
for model_name in DST_MODELS:
|
13 |
with gr.Row():
|
14 |
gr.Markdown(f"## {model_name}")
|
@@ -21,18 +20,5 @@ with gr.Blocks(title="CLARIN-PL Dialogue System Modules") as demo:
|
|
21 |
predict_button.click(fn=predict, inputs=[text_input, model_name_component], outputs=output)
|
22 |
|
23 |
|
24 |
-
gr.Markdown("Natural Language Generation / Paraphrasing Modules")
|
25 |
-
for model_name in NLG_MODELS:
|
26 |
-
with gr.Row():
|
27 |
-
gr.Markdown(f"## {model_name}")
|
28 |
-
model_name_component = gr.Textbox(value=model_name, visible=False)
|
29 |
-
with gr.Row():
|
30 |
-
text_input = gr.Textbox(label="Input Text", value=NLG_MODELS[model_name]["default_input"])
|
31 |
-
output = gr.Textbox(label="Slot Value", value="")
|
32 |
-
with gr.Row():
|
33 |
-
predict_button = gr.Button("Predict")
|
34 |
-
predict_button.click(fn=predict, inputs=[text_input, model_name_component], outputs=output)
|
35 |
-
|
36 |
-
|
37 |
demo.queue(concurrency_count=3)
|
38 |
demo.launch()
|
|
|
1 |
import gradio as gr
|
2 |
|
3 |
+
from models import DST_MODELS, PIPELINES
|
4 |
|
5 |
|
6 |
def predict(text: str, model_name: str) -> str:
|
7 |
return PIPELINES[model_name](text)
|
8 |
|
9 |
|
10 |
+
with gr.Blocks(title="CLARIN-PL DST Modules") as demo:
|
|
|
11 |
for model_name in DST_MODELS:
|
12 |
with gr.Row():
|
13 |
gr.Markdown(f"## {model_name}")
|
|
|
20 |
predict_button.click(fn=predict, inputs=[text_input, model_name_component], outputs=output)
|
21 |
|
22 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
23 |
demo.queue(concurrency_count=3)
|
24 |
demo.launch()
|
models.py
CHANGED
@@ -2,7 +2,7 @@ import os
|
|
2 |
from typing import Any, Dict
|
3 |
|
4 |
from transformers import (Pipeline, T5ForConditionalGeneration, T5Tokenizer,
|
5 |
-
pipeline
|
6 |
|
7 |
auth_token = os.environ.get("CLARIN_KNEXT")
|
8 |
|
@@ -60,48 +60,6 @@ DST_MODELS: Dict[str, Dict[str, Any]] = {
|
|
60 |
}
|
61 |
|
62 |
|
63 |
-
DEFAULT_ENCODER_DECODER_INPUT_EN = "The alarm is set for 6 am. The alarm's name is name \"Get up\"."
|
64 |
-
DEFAULT_DECODER_ONLY_INPUT_EN = f"[BOS]{DEFAULT_ENCODER_DECODER_INPUT_EN}[SEP]"
|
65 |
-
DEFAULT_ENCODER_DECODER_INPUT_PL = "Alarm jest o godzinie 6 rano. Alarm ma nazwę \"Obudź się\"."
|
66 |
-
DEFAULT_DECODER_ONLY_INPUT_PL = f"[BOS]{DEFAULT_ENCODER_DECODER_INPUT_PL}[SEP]"
|
67 |
-
|
68 |
-
|
69 |
-
|
70 |
-
NLG_MODELS: Dict[str, Dict[str, Any]] = {
|
71 |
-
# English
|
72 |
-
"t5-large": {
|
73 |
-
"model": AutoModelForSeq2SeqLM.from_pretrained("clarin-knext/utterance-rewriting-t5-large", use_auth_token=auth_token),
|
74 |
-
"tokenizer": AutoTokenizer.from_pretrained("clarin-knext/utterance-rewriting-t5-large", use_auth_token=auth_token),
|
75 |
-
"default_input": DEFAULT_ENCODER_DECODER_INPUT_EN,
|
76 |
-
},
|
77 |
-
"en-mt5-large": {
|
78 |
-
"model": AutoModelForSeq2SeqLM.from_pretrained("clarin-knext/utterance-rewriting-en-mt5-large", use_auth_token=auth_token),
|
79 |
-
"tokenizer": AutoTokenizer.from_pretrained("clarin-knext/utterance-rewriting-en-mt5-large", use_auth_token=auth_token),
|
80 |
-
"default_input": DEFAULT_ENCODER_DECODER_INPUT_EN,
|
81 |
-
},
|
82 |
-
"gpt2": {
|
83 |
-
"model": AutoModelForCausalLM.from_pretrained("clarin-knext/utterance-rewriting-gpt2", use_auth_token=auth_token),
|
84 |
-
"tokenizer": AutoTokenizer.from_pretrained("clarin-knext/utterance-rewriting-gpt2", use_auth_token=auth_token),
|
85 |
-
"default_input": DEFAULT_DECODER_ONLY_INPUT_EN,
|
86 |
-
},
|
87 |
-
|
88 |
-
"pt5-large": {
|
89 |
-
"model": AutoModelForSeq2SeqLM.from_pretrained("clarin-knext/utterance-rewriting-pt5-large", use_auth_token=auth_token),
|
90 |
-
"tokenizer": AutoTokenizer.from_pretrained("clarin-knext/utterance-rewriting-pt5-large", use_auth_token=auth_token),
|
91 |
-
"default_input": DEFAULT_ENCODER_DECODER_INPUT_PL,
|
92 |
-
},
|
93 |
-
"pl-mt5-large": {
|
94 |
-
"model": AutoModelForSeq2SeqLM.from_pretrained("clarin-knext/utterance-rewriting-pl-mt5-large", use_auth_token=auth_token),
|
95 |
-
"tokenizer": AutoTokenizer.from_pretrained("clarin-knext/utterance-rewriting-pl-mt5-large", use_auth_token=auth_token),
|
96 |
-
"default_input": DEFAULT_ENCODER_DECODER_INPUT_PL,
|
97 |
-
},
|
98 |
-
"polish-gpt2": {
|
99 |
-
"model": AutoModelForCausalLM.from_pretrained("clarin-knext/utterance-rewriting-polish-gpt2", use_auth_token=auth_token),
|
100 |
-
"tokenizer": AutoTokenizer.from_pretrained("clarin-knext/utterance-rewriting-polish-gpt2", use_auth_token=auth_token),
|
101 |
-
"default_input": DEFAULT_DECODER_ONLY_INPUT_PL,
|
102 |
-
},
|
103 |
-
}
|
104 |
-
|
105 |
PIPELINES: Dict[str, Pipeline] = {
|
106 |
model_name: pipeline(
|
107 |
"text2text-generation", model=DST_MODELS[model_name]["model"], tokenizer=DST_MODELS[model_name]["tokenizer"]
|
|
|
2 |
from typing import Any, Dict
|
3 |
|
4 |
from transformers import (Pipeline, T5ForConditionalGeneration, T5Tokenizer,
|
5 |
+
pipeline)
|
6 |
|
7 |
auth_token = os.environ.get("CLARIN_KNEXT")
|
8 |
|
|
|
60 |
}
|
61 |
|
62 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
63 |
PIPELINES: Dict[str, Pipeline] = {
|
64 |
model_name: pipeline(
|
65 |
"text2text-generation", model=DST_MODELS[model_name]["model"], tokenizer=DST_MODELS[model_name]["tokenizer"]
|