Spaces:
Running
Running
anakin87
commited on
Commit
·
a147158
1
Parent(s):
55e565f
little code improvements
Browse files- Rock_fact_checker.py +1 -4
- app_utils/backend_utils.py +4 -6
- app_utils/frontend_utils.py +17 -8
- pages/Info.py +3 -2
Rock_fact_checker.py
CHANGED
|
@@ -12,15 +12,13 @@ from app_utils.frontend_utils import (
|
|
| 12 |
entailment_html_messages,
|
| 13 |
create_df_for_relevant_snippets,
|
| 14 |
create_ternary_plot,
|
| 15 |
-
build_sidebar
|
| 16 |
)
|
| 17 |
from app_utils.config import RETRIEVER_TOP_K
|
| 18 |
|
| 19 |
|
| 20 |
def main():
|
| 21 |
-
|
| 22 |
statements = load_statements()
|
| 23 |
-
|
| 24 |
build_sidebar()
|
| 25 |
|
| 26 |
# Persistent state
|
|
@@ -120,7 +118,6 @@ def main():
|
|
| 120 |
st.markdown(f"###### Most Relevant snippets:")
|
| 121 |
df, urls = create_df_for_relevant_snippets(docs)
|
| 122 |
st.dataframe(df)
|
| 123 |
-
|
| 124 |
str_wiki_pages = "Wikipedia source pages: "
|
| 125 |
for doc, url in urls.items():
|
| 126 |
str_wiki_pages += f"[{doc}]({url}) "
|
|
|
|
| 12 |
entailment_html_messages,
|
| 13 |
create_df_for_relevant_snippets,
|
| 14 |
create_ternary_plot,
|
| 15 |
+
build_sidebar,
|
| 16 |
)
|
| 17 |
from app_utils.config import RETRIEVER_TOP_K
|
| 18 |
|
| 19 |
|
| 20 |
def main():
|
|
|
|
| 21 |
statements = load_statements()
|
|
|
|
| 22 |
build_sidebar()
|
| 23 |
|
| 24 |
# Persistent state
|
|
|
|
| 118 |
st.markdown(f"###### Most Relevant snippets:")
|
| 119 |
df, urls = create_df_for_relevant_snippets(docs)
|
| 120 |
st.dataframe(df)
|
|
|
|
| 121 |
str_wiki_pages = "Wikipedia source pages: "
|
| 122 |
for doc, url in urls.items():
|
| 123 |
str_wiki_pages += f"[{doc}]({url}) "
|
app_utils/backend_utils.py
CHANGED
|
@@ -31,7 +31,7 @@ def load_statements():
|
|
| 31 |
)
|
| 32 |
def start_haystack():
|
| 33 |
"""
|
| 34 |
-
load document store, retriever,
|
| 35 |
"""
|
| 36 |
shutil.copy(f"{INDEX_DIR}/faiss_document_store.db", ".")
|
| 37 |
document_store = FAISSDocumentStore(
|
|
@@ -39,13 +39,11 @@ def start_haystack():
|
|
| 39 |
faiss_config_path=f"{INDEX_DIR}/my_faiss_index.json",
|
| 40 |
)
|
| 41 |
print(f"Index size: {document_store.get_document_count()}")
|
| 42 |
-
|
| 43 |
retriever = EmbeddingRetriever(
|
| 44 |
document_store=document_store,
|
| 45 |
embedding_model=RETRIEVER_MODEL,
|
| 46 |
model_format=RETRIEVER_MODEL_FORMAT,
|
| 47 |
)
|
| 48 |
-
|
| 49 |
entailment_checker = EntailmentChecker(model_name_or_path=NLI_MODEL, use_gpu=False)
|
| 50 |
|
| 51 |
pipe = Pipeline()
|
|
@@ -84,8 +82,8 @@ def query(statement: str, retriever_top_k: int = 5):
|
|
| 84 |
break
|
| 85 |
|
| 86 |
results["agg_entailment_info"] = {
|
| 87 |
-
"contradiction":
|
| 88 |
-
"neutral":
|
| 89 |
-
"entailment":
|
| 90 |
}
|
| 91 |
return results
|
|
|
|
| 31 |
)
|
| 32 |
def start_haystack():
|
| 33 |
"""
|
| 34 |
+
load document store, retriever, entailment checker and create pipeline
|
| 35 |
"""
|
| 36 |
shutil.copy(f"{INDEX_DIR}/faiss_document_store.db", ".")
|
| 37 |
document_store = FAISSDocumentStore(
|
|
|
|
| 39 |
faiss_config_path=f"{INDEX_DIR}/my_faiss_index.json",
|
| 40 |
)
|
| 41 |
print(f"Index size: {document_store.get_document_count()}")
|
|
|
|
| 42 |
retriever = EmbeddingRetriever(
|
| 43 |
document_store=document_store,
|
| 44 |
embedding_model=RETRIEVER_MODEL,
|
| 45 |
model_format=RETRIEVER_MODEL_FORMAT,
|
| 46 |
)
|
|
|
|
| 47 |
entailment_checker = EntailmentChecker(model_name_or_path=NLI_MODEL, use_gpu=False)
|
| 48 |
|
| 49 |
pipe = Pipeline()
|
|
|
|
| 82 |
break
|
| 83 |
|
| 84 |
results["agg_entailment_info"] = {
|
| 85 |
+
"contradiction": round(agg_con / scores, 2),
|
| 86 |
+
"neutral": round(agg_neu / scores, 2),
|
| 87 |
+
"entailment": round(agg_ent / scores, 2),
|
| 88 |
}
|
| 89 |
return results
|
app_utils/frontend_utils.py
CHANGED
|
@@ -9,8 +9,9 @@ entailment_html_messages = {
|
|
| 9 |
"neutral": 'The knowledge base is <span style="color:darkgray">neutral</span> about your statement',
|
| 10 |
}
|
| 11 |
|
|
|
|
| 12 |
def build_sidebar():
|
| 13 |
-
sidebar="""
|
| 14 |
<h1 style='text-align: center'>Fact Checking 🎸 Rocks!</h1>
|
| 15 |
<div style='text-align: center'>
|
| 16 |
<i>Fact checking baseline combining dense retrieval and textual entailment</i>
|
|
@@ -20,6 +21,7 @@ def build_sidebar():
|
|
| 20 |
"""
|
| 21 |
st.sidebar.markdown(sidebar, unsafe_allow_html=True)
|
| 22 |
|
|
|
|
| 23 |
def set_state_if_absent(key, value):
|
| 24 |
if key not in st.session_state:
|
| 25 |
st.session_state[key] = value
|
|
@@ -33,6 +35,9 @@ def reset_results(*args):
|
|
| 33 |
|
| 34 |
|
| 35 |
def create_ternary_plot(entailment_data):
|
|
|
|
|
|
|
|
|
|
| 36 |
hover_text = ""
|
| 37 |
for label, value in entailment_data.items():
|
| 38 |
hover_text += f"{label}: {value}<br>"
|
|
@@ -83,14 +88,11 @@ def makeAxis(title, tickangle):
|
|
| 83 |
}
|
| 84 |
|
| 85 |
|
| 86 |
-
def highlight_cols(s):
|
| 87 |
-
coldict = {"con": "#FFA07A", "neu": "#E5E4E2", "ent": "#a9d39e"}
|
| 88 |
-
if s.name in coldict.keys():
|
| 89 |
-
return ["background-color: {}".format(coldict[s.name])] * len(s)
|
| 90 |
-
return [""] * len(s)
|
| 91 |
-
|
| 92 |
-
|
| 93 |
def create_df_for_relevant_snippets(docs):
|
|
|
|
|
|
|
|
|
|
|
|
|
| 94 |
rows = []
|
| 95 |
urls = {}
|
| 96 |
for doc in docs:
|
|
@@ -106,3 +108,10 @@ def create_df_for_relevant_snippets(docs):
|
|
| 106 |
rows.append(row)
|
| 107 |
df = pd.DataFrame(rows).style.apply(highlight_cols)
|
| 108 |
return df, urls
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 9 |
"neutral": 'The knowledge base is <span style="color:darkgray">neutral</span> about your statement',
|
| 10 |
}
|
| 11 |
|
| 12 |
+
|
| 13 |
def build_sidebar():
|
| 14 |
+
sidebar = """
|
| 15 |
<h1 style='text-align: center'>Fact Checking 🎸 Rocks!</h1>
|
| 16 |
<div style='text-align: center'>
|
| 17 |
<i>Fact checking baseline combining dense retrieval and textual entailment</i>
|
|
|
|
| 21 |
"""
|
| 22 |
st.sidebar.markdown(sidebar, unsafe_allow_html=True)
|
| 23 |
|
| 24 |
+
|
| 25 |
def set_state_if_absent(key, value):
|
| 26 |
if key not in st.session_state:
|
| 27 |
st.session_state[key] = value
|
|
|
|
| 35 |
|
| 36 |
|
| 37 |
def create_ternary_plot(entailment_data):
|
| 38 |
+
"""
|
| 39 |
+
Create a Plotly ternary plot for the given entailment dict.
|
| 40 |
+
"""
|
| 41 |
hover_text = ""
|
| 42 |
for label, value in entailment_data.items():
|
| 43 |
hover_text += f"{label}: {value}<br>"
|
|
|
|
| 88 |
}
|
| 89 |
|
| 90 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 91 |
def create_df_for_relevant_snippets(docs):
|
| 92 |
+
"""
|
| 93 |
+
Create a dataframe that contains all relevant snippets.
|
| 94 |
+
Also returns the URLs
|
| 95 |
+
"""
|
| 96 |
rows = []
|
| 97 |
urls = {}
|
| 98 |
for doc in docs:
|
|
|
|
| 108 |
rows.append(row)
|
| 109 |
df = pd.DataFrame(rows).style.apply(highlight_cols)
|
| 110 |
return df, urls
|
| 111 |
+
|
| 112 |
+
|
| 113 |
+
def highlight_cols(s):
|
| 114 |
+
coldict = {"con": "#FFA07A", "neu": "#E5E4E2", "ent": "#a9d39e"}
|
| 115 |
+
if s.name in coldict.keys():
|
| 116 |
+
return ["background-color: {}".format(coldict[s.name])] * len(s)
|
| 117 |
+
return [""] * len(s)
|
pages/Info.py
CHANGED
|
@@ -1,9 +1,10 @@
|
|
| 1 |
import streamlit as st
|
|
|
|
| 2 |
from app_utils.frontend_utils import build_sidebar
|
| 3 |
|
| 4 |
build_sidebar()
|
| 5 |
|
| 6 |
-
with open(
|
| 7 |
-
readme = fin.read().rpartition(
|
| 8 |
|
| 9 |
st.markdown(readme, unsafe_allow_html=True)
|
|
|
|
| 1 |
import streamlit as st
|
| 2 |
+
|
| 3 |
from app_utils.frontend_utils import build_sidebar
|
| 4 |
|
| 5 |
build_sidebar()
|
| 6 |
|
| 7 |
+
with open("README.md", "r") as fin:
|
| 8 |
+
readme = fin.read().rpartition("---")[-1]
|
| 9 |
|
| 10 |
st.markdown(readme, unsafe_allow_html=True)
|