__all__ = ['model_checkpoint', 'tokenizer', 'model', 'dataset', 'train_dataset', 'iface', 'transform', 'cls_pooling',
           'get_embeddings', 'search_arxiv']

import gradio as gr
from datasets import load_dataset
import numpy as np
from transformers import AutoTokenizer, AutoModel

# Sentence-embedding model used to encode search queries.
model_checkpoint = 'sentence-transformers/all-MiniLM-L6-v2'
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)
model = AutoModel.from_pretrained(model_checkpoint)

# ArXiv cs.LG papers with a precomputed 'embeddings' column.
dataset = load_dataset('Tarun-1999M/arxiv_cs_lg_embeddings')
train_dataset = dataset['train']
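
# Quick inspection (optional sketch): the split is expected to expose the
# 'title', 'abstract', and 'embeddings' columns used below.
# print(train_dataset.column_names)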

# Cast the stored embeddings to float32 arrays on the fly, as FAISS expects.
def transform(example):
    example['embeddings'] = np.array(example['embeddings'], dtype=np.float32)
    return example

train_dataset.set_transform(transform)

# Build a FAISS index over the embeddings column for nearest-neighbour search.
train_dataset.add_faiss_index(column='embeddings')
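
# Optional (sketch): the index can be persisted and reloaded to avoid rebuilding
# it on every startup; the file name 'arxiv_embeddings.faiss' is illustrative.
# train_dataset.save_faiss_index('embeddings', 'arxiv_embeddings.faiss')
# train_dataset.load_faiss_index('embeddings', 'arxiv_embeddings.faiss')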

# Use the [CLS] token embedding as the sentence representation.
def cls_pooling(model_output):
    return model_output.last_hidden_state[:, 0]

# Tokenize a list of queries and embed them with the model.
def get_embeddings(query_list):
    encoded_input = tokenizer(query_list, padding=True, truncation=True, return_tensors='pt')
    model_output = model(**encoded_input)
    return cls_pooling(model_output)
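
# Sanity check (sketch): all-MiniLM-L6-v2 has a hidden size of 384, so a single
# query should yield an embedding of shape (1, 384).
# print(get_embeddings(["contrastive learning"]).shape)  # torch.Size([1, 384])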

def search_arxiv(query):
    # Embed the query and convert it to a NumPy array for FAISS.
    question_embedding = get_embeddings([query]).cpu().detach().numpy()

    # Retrieve the five nearest papers from the indexed embeddings.
    scores, samples = train_dataset.get_nearest_examples("embeddings", question_embedding, k=5)

    # Sort results by score, highest first.
    sorted_results = sorted(zip(scores, samples['title'], samples['abstract']), reverse=True)

    # Format each hit as a Markdown block.
    results = []
    for score, title, abstract in sorted_results:
        result = f"\n**Title:** {title}\n**Abstract:** {abstract}\n**Score:** {score:.4f}"
        results.append(result)

    return "\n\n".join(results)
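
# Example (sketch): the search function can also be called directly,
# outside the Gradio interface.
# print(search_arxiv("graph neural networks for molecular property prediction"))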

dataset_info = """
### About the Dataset

This dataset contains a subset of ArXiv papers carrying the "cs.LG" tag, indicating that they are about Machine Learning. The core dataset is filtered from the full ArXiv dataset hosted on Kaggle: [ArXiv Dataset on Kaggle](https://www.kaggle.com/datasets/Cornell-University/arxiv). The original dataset contains roughly 2 million papers; after category filtering, this subset contains approximately 100,000 papers.

The dataset is kept up to date by making requests to the ArXiv API. The current iteration only includes the title and abstract of each paper.

**Dataset Source:** The dataset is sourced from Hugging Face: [CShorten/ML-ArXiv-Papers](https://huggingface.co/datasets/CShorten/ML-ArXiv-Papers).
"""

# Gradio interface: a single text query in, Markdown-formatted results out.
iface = gr.Interface(
    fn=search_arxiv,
    inputs=gr.components.Textbox(lines=1, placeholder="Enter your query..."),
    outputs="markdown",
    title="Semantic Search in ArXiv ML Papers",
    description="Enter a query to find relevant ML papers from the ArXiv dataset.",
    article=dataset_info,
    theme="huggingface",
    css="""
    body {
        font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif;
        background-color: #FDF6E3; /* Warm background color */
        color: #333333; /* Text color */
    }
    .input-textbox {
        background-color: #FFFAF0; /* Lighter warm input box */
        color: #333333; /* Input text color */
    }
    h1, h2, h3, h4, h5, h6 {
        color: #D2691E; /* Warm, inviting heading color */
    }
    """
)

# share=True also creates a temporary public link to the running app.
iface.launch(share=True)