alexandrlukashov commited on
Commit
398beaa
·
verified ·
1 Parent(s): 9127367
interfaces/compare_pipeline.py CHANGED
@@ -100,14 +100,18 @@ def compute_scores(*args):
100
  def compute_table(*args):
101
  gliclass_results, st_results = compute_scores(*args)
102
  max_docs = int(os.getenv("MAX_DOCS"))
 
 
103
  gliclass_labels = gliclass_results[:max_docs]
104
  st_labels = st_results[:max_docs]
 
 
 
105
  df = pd.DataFrame({
106
- "Rank": list(range(1, max_docs + 1)),
107
- "GLiClass Label": gliclass_labels,
108
- "CrossEncoder Label": st_labels,
109
  })
110
-
111
  return df
112
 
113
  examples = [
@@ -116,13 +120,14 @@ examples = [
116
  ]
117
 
118
  with gr.Blocks(title="GLiClass-Reranker") as compare_pipeline:
 
119
  inputs = []
120
  query = gr.Textbox(
121
  value=examples[0][0], label="Text query", placeholder="Enter your query here", lines=4
122
  )
123
- labels = [gr.Textbox(value=label, label=f"Label {i+1}") for i, label in enumerate(examples[0][1:])]
124
  submit_btn = gr.Button("Compare")
125
- result_table = gr.Dataframe(headers=["Rank", "GLiClass Label", "CrossEncoder Label"],
126
  label="Comparison Table",
127
  interactive=False)
128
 
 
100
  def compute_table(*args):
101
  gliclass_results, st_results = compute_scores(*args)
102
  max_docs = int(os.getenv("MAX_DOCS"))
103
+ labels = args[1:]
104
+
105
  gliclass_labels = gliclass_results[:max_docs]
106
  st_labels = st_results[:max_docs]
107
+
108
+ label_rank_gliclass = {label: rank + 1 for rank, label in enumerate(gliclass_labels) if label}
109
+ label_rank_st = {label: rank + 1 for rank, label in enumerate(st_labels) if label}
110
  df = pd.DataFrame({
111
+ "Document": labels,
112
+ "GLiClass Rank": [label_rank_gliclass.get(label, "") for label in labels],
113
+ "Cross-Encoder Rank": [label_rank_st.get(label, "") for label in labels],
114
  })
 
115
  return df
116
 
117
  examples = [
 
120
  ]
121
 
122
  with gr.Blocks(title="GLiClass-Reranker") as compare_pipeline:
123
+ example_state = gr.State(value=examples)
124
  inputs = []
125
  query = gr.Textbox(
126
  value=examples[0][0], label="Text query", placeholder="Enter your query here", lines=4
127
  )
128
+ labels = [gr.Textbox(value=label, label=f"Document {i+1}") for i, label in enumerate(examples[0][1:])]
129
  submit_btn = gr.Button("Compare")
130
+ result_table = gr.Dataframe(headers=["Document", "GLiClass Rank", "Cross-Encoder Rank"],
131
  label="Comparison Table",
132
  interactive=False)
133
 
interfaces/scores_pipeline.py CHANGED
@@ -86,6 +86,7 @@ def classification(*args) -> List[str]:
86
  return docs + scores
87
 
88
  with gr.Blocks(title="GLiClass-Reranker") as scores_pipeline:
 
89
  inputs = []
90
  outputs = []
91
  query = gr.Textbox(
 
86
  return docs + scores
87
 
88
  with gr.Blocks(title="GLiClass-Reranker") as scores_pipeline:
89
+ example_state = gr.State(value=examples)
90
  inputs = []
91
  outputs = []
92
  query = gr.Textbox(