Spaces:
Runtime error
Runtime error
import faiss | |
import gradio as gr | |
import numpy as np | |
import pandas as pd | |
from datasets import load_dataset | |
from sentence_transformers import SentenceTransformer, util | |
def build_doc_frame(df, idx=0):
    """Return one document of *df* as a frame with one row per sentence.

    Parameters
    ----------
    df : pandas.DataFrame
        Document table with at least ``url``, ``sentences`` and ``embedding``
        columns. For each document, ``sentences`` and ``embedding`` must hold
        sequences of equal length; multi-column ``DataFrame.explode`` raises
        ``ValueError`` otherwise.
    idx : int, default 0
        Row position (``iloc``) of the document to extract.

    Returns
    -------
    pandas.DataFrame
        Columns ``url``, ``sentences`` and ``embedding``, with one row per
        sentence and a fresh 0..n-1 index.
    """
    # Bug fix: the original read ``df.iloc[0]`` and ignored ``idx``, so every
    # selection searched the first document. ``iloc[[idx]]`` keeps a one-row
    # DataFrame directly, with no Series -> DataFrame -> transpose round-trip.
    doc_df = df.iloc[[idx]][["url", "sentences", "embedding"]]
    # Unpack sentences and their embeddings in lockstep, one row per sentence.
    # explode repeats the source index label, so reset it to positions.
    return doc_df.explode(["sentences", "embedding"]).reset_index(drop=True)
def get_doc_embeddings(doc):
    """Stack the ``embedding`` column of *doc* into a float32 matrix.

    Each cell of ``doc["embedding"]`` is expected to be one vector of the same
    length, so the result has shape ``(len(doc), dim)``. float32 is the dtype
    faiss operates on.
    """
    vectors = list(doc["embedding"])
    return np.array(vectors, dtype=np.float32)
def _resolve_doc_position(doc_idx):
    """Translate a dropdown selection into a 0-based row position in ``df``.

    Accepted shapes (different Gradio versions pass different ones):
    * an int, i.e. the 1-based document number used as the dropdown value;
    * a tuple/list whose first item is that number (the shape the original
      code indexed with ``doc_idx[0]``);
    * a label string that starts with the number, e.g. ``"3: Summary..."``.

    Raises ``gr.Error`` (shown in the UI) for a missing or unusable selection.
    """
    if doc_idx is None:
        raise gr.Error("Please select a court case.")
    if isinstance(doc_idx, (tuple, list)) and doc_idx:
        doc_idx = doc_idx[0]
    if isinstance(doc_idx, str):
        doc_idx = doc_idx.strip().split(":", 1)[0]
    try:
        position = int(doc_idx) - 1
    except (TypeError, ValueError) as err:
        raise gr.Error(f"Unrecognised court case selection: {doc_idx!r}") from err
    if not 0 <= position < len(df):
        raise gr.Error(f"Court case {position + 1} does not exist.")
    return position


def faiss_search(doc_idx, query_str, K=5):
    """Return the sentences of one court case most similar to *query_str*.

    Parameters
    ----------
    doc_idx
        Dropdown selection identifying the document (see
        ``_resolve_doc_position`` for the accepted shapes).
    query_str : str
        Free-text query. It is encoded with the module-level ``model``.
    K : int or float, default 5
        Number of matches requested. It is cast to int because slider values
        may arrive as floats, and it is capped at the number of sentences.

    Returns
    -------
    str
        A header line, one ``score<TAB>sentence`` line per hit (scores are
        cosine similarities rounded to 3 decimals), and an underline.
    """
    doc_pos = _resolve_doc_position(doc_idx)
    # Bug fix: the original overwrote the user's query with a hard-coded string.
    if not query_str or not query_str.strip():
        raise gr.Error("Please enter a query.")

    newdoc = build_doc_frame(df, idx=doc_pos)
    embeddings = get_doc_embeddings(newdoc)
    faiss.normalize_L2(embeddings)
    # Inner product of L2-normalised vectors is cosine similarity. The
    # dimension comes from the data rather than a hard-coded 768.
    index = faiss.IndexFlatIP(embeddings.shape[1])
    index.add(embeddings)

    query_emb = np.asarray(model.encode([query_str]), dtype="float32").reshape(1, -1)
    faiss.normalize_L2(query_emb)
    # faiss needs an int k. Asking for more hits than stored vectors makes it
    # pad results with id -1, which iloc would silently map to the last row.
    k = max(1, min(int(K), index.ntotal))
    D, I = index.search(query_emb, k)
    print(list(zip(D[0], I[0])))

    sentences = newdoc["sentences"]
    result_lines = [
        f"{round(float(score), 3)}\t{sentences.iloc[row]}"
        for row, score in zip(I[0], D[0])
        if row >= 0  # skip any padding ids faiss may still return
    ]
    pretty_results_str = "\n".join(result_lines)
    top_k_str = f"Top {k} results for: {query_str}"
    underlines = "__" * 40
    return f"{top_k_str}\n{pretty_results_str}\n{underlines}"


dataset = load_dataset("tollefj/rettsavgjoerelser_100samples_embeddings")
model = SentenceTransformer("NbAiLab/nb-sbert-base")
df = dataset["train"].to_pandas()

# Gradio treats a (name, value) tuple as: show `name`, pass `value` to fn.
# The original (idx + 1, label) order made fn receive the label string. Here
# the label is shown and the 1-based number is passed. The label also starts
# with "N:" so that a Gradio version passing the label string still resolves.
# Assumes each `summary` cell is a sequence whose first item is a string -- TODO confirm.
dropdown_opts = [
    (f"{pos + 1}: {summary[0][:60]}...", pos + 1)
    for pos, summary in enumerate(df["summary"])
]
iface = gr.Interface(
    fn=faiss_search,
    inputs=[
        gr.Dropdown(label="Select a court case", choices=dropdown_opts),
        gr.Textbox(lines=2, placeholder="Your query here..."),
        gr.Slider(minimum=1, maximum=10, step=1, label="Number of matches", value=5),
    ],
    outputs="text",
    title="Lovdata rettsavgjørelser - semantisk søk",
    description="Velg en rettsak og søk for å hente ut lignende setninger i saken",
)
iface.launch()