File size: 2,945 Bytes
a05279b
 
 
 
 
8eef9b2
a05279b
4d794c6
a05279b
4d794c6
 
 
 
 
 
 
a05279b
 
 
 
 
 
 
 
 
 
 
 
 
4d794c6
 
 
 
 
 
a05279b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4d794c6
a05279b
 
 
4d794c6
a05279b
 
4d794c6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2de00c6
4d794c6
2de00c6
4d794c6
2de00c6
4d794c6
 
 
 
 
 
 
a05279b
4d794c6
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
import faiss
import gradio as gr
import numpy as np
import pandas as pd
from datasets import load_dataset
from sentence_transformers import SentenceTransformer

# Module-level document index; faiss_search rebinds it via `global idx`.
idx = 0

# Loaded once at import time and shared by every search request.
dataset = load_dataset("tollefj/rettsavgjoerelser_100samples_embeddings")
model = SentenceTransformer("NbAiLab/nb-sbert-base")
# The "train" split as a DataFrame; the code below reads its `url`,
# `sentences` and `embedding` columns.
df = dataset["train"].to_pandas()


def build_doc_frame(df, idx):
    """Return one document as a frame with one row per sentence.

    The row at position ``idx`` of ``df`` is turned into a single-row frame,
    reduced to the ``url``, ``sentences`` and ``embedding`` columns, and then
    exploded so that each sentence sits next to its own embedding. The ``url``
    value is repeated on every resulting row.
    """
    wanted_columns = ["url", "sentences", "embedding"]
    # Series -> one-column frame -> transpose gives a single-row frame.
    single_row = df.iloc[idx].to_frame().T
    per_sentence = single_row[wanted_columns].explode(["sentences", "embedding"])
    return per_sentence


def get_doc_embeddings(doc):
    """Collect the ``embedding`` column of ``doc`` into one float32 array."""
    vectors = doc.embedding.tolist()
    return np.array(vectors, dtype=np.float32)


def faiss_search(doc_url, query_str, K=5):
    """Find the sentences of one document that best match a query.

    The document is looked up by URL, its sentence embeddings are
    L2-normalised and put into a flat inner-product FAISS index (so scores
    are cosine similarities), and the encoded query is searched against it.

    Args:
        doc_url: URL identifying the document in the module-level ``df``.
        query_str: Free-text query, encoded with the module-level ``model``.
        K: Requested number of hits. Cast to ``int`` and clamped to the number
            of sentences in the document.

    Returns:
        A printable string: a header line followed by one
        ``Score: <score>\t\t<sentence>`` line per hit, or a short message when
        the URL is unknown or the document has no sentences.

    Side effects:
        Rebinds the module-level ``idx`` to the position of the selected
        document.
    """
    global idx
    # Resolve the URL to a *position*: build_doc_frame uses ``iloc``, so an
    # index label would only be correct for a default RangeIndex.
    matches = np.flatnonzero((df.url == doc_url).to_numpy())
    if matches.size == 0:
        return f"No document found for URL: {doc_url}"
    idx = int(matches[0])

    newdoc = build_doc_frame(df, idx)
    embeddings = get_doc_embeddings(newdoc)
    if embeddings.ndim != 2 or embeddings.shape[0] == 0:
        return f"No sentences available for URL: {doc_url}"

    faiss.normalize_L2(embeddings)
    # Take the dimension from the data instead of hard-coding the model size.
    index = faiss.IndexFlatIP(embeddings.shape[1])
    index.add(embeddings)

    # np.array copies, so the in-place normalisation never touches an array
    # owned by the model.
    target_emb = np.array(model.encode([query_str]), dtype="float32").reshape(1, -1)
    faiss.normalize_L2(target_emb)

    # UI sliders may deliver floats, and asking for more neighbours than there
    # are vectors makes FAISS pad the result with -1 labels.
    k = max(1, min(int(K), index.ntotal))
    scores, positions = index.search(target_emb, k)

    sentences = newdoc.sentences.tolist()
    # Loop variables are deliberately not named ``idx`` so the global set
    # above is not overwritten with a sentence position.
    result_lines = [
        f"Score: {round(float(score), 3)}\t\t{sentences[pos]}"
        for pos, score in zip(positions[0], scores[0])
        if pos >= 0  # defensive: -1 marks "no neighbour"
    ]
    pretty_results_str = "\n".join(result_lines)
    top_k_str = f"Top {k} results for: {query_str}"

    return f"{top_k_str}\n{pretty_results_str}"


# def DropdownSummary():
#     next_opts = df.iloc[idx].summary.tolist()
#     return gr.Dropdown.update(choices=next_opts, label="Velg fra oppsummeringene")


# One dropdown entry per document, in DataFrame row order.
dropdown_opts = df.url.tolist()

with gr.Blocks() as demo:
    gr.Label("Lovdata rettsavgjørelser - semantisk søk")

    # `value` is the keyword that preselects an entry. With the unsupported
    # `default=` nothing was selected and the first search received None.
    case_dropdown = gr.Dropdown(
        label="Velg en rettsavgjørelse",
        choices=dropdown_opts,
        value=dropdown_opts[0] if dropdown_opts else None,
    )

    query = gr.Textbox(
        label="Søk etter setninger",
        lines=1,
        placeholder="Kollisjon mellom to kjøretøy.",
    )
    k_slider = gr.Slider(minimum=1, maximum=10, label="Antall treff", value=5, step=1)

    search_btn = gr.Button("Søk")

    output = gr.Textbox(label="Resultater", lines=10)

    # The selected URL, the query text and the hit count are passed to
    # faiss_search; its returned string is shown in the results box.
    search_btn.click(faiss_search, inputs=[case_dropdown, query, k_slider], outputs=[output])

demo.launch()