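# Streamlit app: question answering over a set of web pages.
# Paste comma-separated links, type a question, and the app fetches each page,
# picks the most relevant passage with a cross-encoder, and answers with either
# GPT-3 (text-davinci-003) or a Hugging Face question-answering model.
#
# To run locally (assuming this file is saved as app.py):
#   pip install streamlit openai requests beautifulsoup4 sentence-transformers transformers
#   streamlit run app.py
# The GPT-3 option expects an `openai_key` entry in .streamlit/secrets.toml.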
import re

import openai
import requests
import streamlit as st
from bs4 import BeautifulSoup
from sentence_transformers import CrossEncoder
from transformers import pipeline

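# Global index mapping each extracted passage to the URL it came from.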
all_documents = {}


def qa_gpt3(question, context):
    """Answer a question given a context passage, via the legacy OpenAI Completions API (openai<1.0)."""
    print(question, context)  # debug logging
    openai.api_key = st.secrets["openai_key"]

    response = openai.Completion.create(
        model="text-davinci-003",
        prompt=f"Answer given the following context: {context}\n\nQuestion: {question}",
        temperature=0.7,
        max_tokens=256,
        top_p=1,
        frequency_penalty=0,
        presence_penalty=0
    )
    print(response)  # debug logging
    # Return the same shape as a transformers question-answering pipeline: a dict with an 'answer' key.
    return {'answer': response['choices'][0]['text'].strip()}


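# Streamlit UI: page title, link and query inputs, and the Q/A model / tokenizing selectors.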
st.title('Document Question Answering System')

qa_model = None

crawl_urls = st.checkbox('Crawl?', value=False)

document_text = st.text_area(
    label="Links (Comma separated)", height=100,
    value='https://www.databricks.com/blog/2022/11/15/values-define-databricks-culture.html, https://databricks.com/product/databricks-runtime-for-machine-learning/faq'
)
query = st.text_input("Query")

qa_option = st.selectbox('Q/A Answerer', ('gpt3', 'a-ware/bart-squadv2'))
tokenizing = st.selectbox('How to Tokenize',
                          ("Don't (use entire body as document)", 'Newline (split by newline character)', 'Combo'))

if qa_option == 'gpt3':
    qa_model = qa_gpt3
else:
    qa_model = pipeline("question-answering", qa_option)
st.write(f'Using {qa_option} as the Q/A model')

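# Cross-encoder re-ranker that scores (question, passage) pairs; used to pick the reference passage.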
encoder = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2')


def get_relevent_passage(question, documents):
    """Score every candidate passage against the question with the cross-encoder and return the best one."""
    query_paragraph_list = [(question, para) for para in list(documents.keys()) if len(para.strip()) > 0]

    scores = encoder.predict(query_paragraph_list)
    # argsort is ascending, so the last five indices are the highest-scoring pairs; reversing puts the best first.
    top_5_indices = scores.argsort()[-5:]
    top_5_query_paragraph_list = [query_paragraph_list[i] for i in top_5_indices]
    top_5_query_paragraph_list.reverse()
    return top_5_query_paragraph_list[0][1]


def answer_question(query, context):
    """Run the selected Q/A model on a single (question, context) pair."""
    answer = qa_model(question=query, context=context)['answer']
    return answer


def get_documents(document_text, crawl=crawl_urls):
    """Fetch each comma-separated URL, optionally follow same-site links, and index its text by passage."""
    urls = [u.strip() for u in document_text.split(',')]
    for url in urls:
        st.write(f'Crawling {url}')
        # Skip URLs that have already been fetched and indexed.
        if url in set(all_documents.values()):
            continue
        html = requests.get(url).text
        soup = BeautifulSoup(html, 'html.parser')

        if crawl:
            st.write('Give me a sec, crawling..')
            # Pull every absolute http(s) URL out of the raw HTML.
            more_urls = re.findall('http[s]?://(?:[a-zA-Z]|[0-9]|[$-_@.&+]|[!*\(\),]|(?:%[0-9a-fA-F][0-9a-fA-F]))+',
                                   html)
            # Keep only same-site links that do not look like file downloads, then fetch each one
            # without crawling further (crawl=False keeps the recursion one level deep).
            more_urls = list(
                set([m for m in more_urls if m[-4] != '.' and m[-3] != '.' and m.split('/')[:3] == url.split('/')[:3]]))
            for more_url in more_urls:
                all_documents.update(get_documents(more_url, crawl=False))

        # Keep only lines of the page body longer than 10 characters to drop navigation cruft.
        body = "\n".join([x for x in soup.body.get_text().split('\n') if len(x) > 10])
        print(body)  # debug logging

        # Split the page into candidate passages according to the selected tokenizing strategy.
        if tokenizing == "Don't (use entire body as document)":
            document_paragraphs = [body]
        elif tokenizing == 'Newline (split by newline character)':
            document_paragraphs = [n for n in body.split('\n') if len(n) > 250]
        elif tokenizing == 'Combo':
            document_paragraphs = [body] + [n for n in body.split('\n') if len(n) > 250]

        # Remember which URL each passage came from so the answer can cite its source.
        for document_paragraph in document_paragraphs:
            all_documents[document_paragraph] = url

    return all_documents


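# Main flow: once links and a query are provided, fetch the documents, retrieve the best passage, and answer.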
if len(document_text.strip()) > 0 and len(query.strip()) > 0 and qa_model and encoder:
    st.write('Hmmm let me think about that..')
    document_text = document_text.strip()
    documents = get_documents(document_text)
    st.write(f'I am looking through {len(set(documents.values()))} sites')

    query = query.strip()
    context = get_relevent_passage(query, documents)
    answer = answer_question(query, context)

    relevant_url = documents[context]

    st.write('Check the answer below...with reference text')
    st.header("ANSWER: " + answer)
    st.subheader("REFERENCE: " + context)
    st.subheader("REFERENCE URL: " + relevant_url)