# Tarun-1999M's picture
# Update app.py
# 9845739 verified
# AUTOGENERATED! DO NOT EDIT! File to edit: Project_CUDA_Enabled (1).ipynb.
# %% auto 0
__all__ = [  # names exported by `from <module> import *`
    'model_checkpoint',
    'tokenizer',
    'model',
    'dataset',
    'train_dataset',
    'iface',
    'transform',
    'cls_pooling',
    'get_embeddings',
    'search_arxiv',
]
# %% Project_CUDA_Enabled (1).ipynb 49
import gradio as gr
from datasets import load_dataset
import numpy as np
from transformers import AutoTokenizer, AutoModel
# Hugging Face checkpoint used to encode user queries; the tokenizer and the
# encoder are both loaded from it.
model_checkpoint = "sentence-transformers/all-MiniLM-L6-v2"
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)
model = AutoModel.from_pretrained(model_checkpoint)

# Paper records with precomputed embeddings, pulled from the Hugging Face Hub;
# only the "train" split is searched.
dataset = load_dataset("Tarun-1999M/arxiv_cs_lg_embeddings")
train_dataset = dataset["train"]
# Lazy row transform: registered below with `set_transform`, so it runs
# on-the-fly whenever rows are read from the dataset.
def transform(example):
    """Cast ``example['embeddings']`` to a float32 NumPy array.

    The mapping is modified in place and the same object is returned. All
    other keys are left untouched.
    """
    raw_vectors = example['embeddings']
    example['embeddings'] = np.array(raw_vectors, dtype=np.float32)
    return example
# Register the lazy float32 conversion so rows read from train_dataset expose
# 'embeddings' as NumPy arrays.
train_dataset.set_transform(transform)
# Build an in-memory FAISS index over the 'embeddings' column. No metric_type
# or string_factory is passed, so the `datasets` default (a flat L2 index) is
# used: search scores are L2 distances, where a lower value means a closer match.
train_dataset.add_faiss_index(column='embeddings')
# Reduce a batch of token states to one vector per input by keeping only the
# first ([CLS]) position.
# NOTE(review): the all-MiniLM-L6-v2 model card uses mean pooling; CLS pooling
# is presumably kept so query vectors match how the stored dataset embeddings
# were produced — confirm before changing either side.
def cls_pooling(model_output):
    """Return the hidden state at token position 0 for every batch item.

    Args:
        model_output: Model output exposing ``last_hidden_state``; assumed to
            be shaped (batch, seq_len, hidden) — TODO confirm.

    Returns:
        ``model_output.last_hidden_state[:, 0]``.
    """
    hidden_states = model_output.last_hidden_state
    first_token_states = hidden_states[:, 0]
    return first_token_states
# Function to get the embeddings for the query
def get_embeddings(query_list):
    """Encode a batch of query strings into pooled embedding vectors.

    Args:
        query_list: List of query strings. The batch is padded to its longest
            item and truncated to the model's maximum length.

    Returns:
        The CLS-pooled hidden states from ``cls_pooling``, as a PyTorch tensor
        (``return_tensors='pt'``) with one row per query.
    """
    # Local import: torch is not imported at the top of this module.
    import torch

    encoded_input = tokenizer(query_list, padding=True, truncation=True, return_tensors='pt')
    # Inference only: disable autograd so no computation graph is recorded
    # for every search request.
    with torch.no_grad():
        model_output = model(**encoded_input)
    return cls_pooling(model_output)
# Function to search the ArXiv papers
def search_arxiv(query):
    """Find the 5 papers closest to *query* and format them as Markdown.

    Args:
        query: Free-text search string from the Gradio textbox.

    Returns:
        A Markdown string with the title, abstract and FAISS score of each
        match, nearest first. The index is a flat L2 index, so the score is a
        distance and lower means more similar. A short prompt is returned
        instead when the query is empty or whitespace-only.
    """
    # A blank query carries no meaning; embedding it would just return
    # arbitrary "nearest" papers.
    if not (query and query.strip()):
        return "Please enter a query."

    # Embed the query; FAISS expects a NumPy array on the CPU.
    question_embedding = get_embeddings([query]).cpu().detach().numpy()
    # Search for similar papers
    scores, samples = train_dataset.get_nearest_examples("embeddings", question_embedding, k=5)
    # The index was built with add_faiss_index(column='embeddings') and no
    # metric, i.e. a flat L2 index. Scores are distances (lower = closer), and
    # FAISS already returns them nearest-first. Re-sorting them in descending
    # order would list the least similar paper first, so the FAISS order is kept.
    results = [
        f"\n**Title:** {title}\n**Abstract:** {abstract}\n**Score:** {score:.4f}"
        for score, title, abstract in zip(scores, samples['title'], samples['abstract'])
    ]
    return "\n\n".join(results)
# Dataset information
# Markdown rendered beneath the search UI (passed as `article` to gr.Interface).
# NOTE(review): the app loads 'Tarun-1999M/arxiv_cs_lg_embeddings', while this
# text credits CShorten/ML-ArXiv-Papers as the source. The former is presumably
# an embedded copy of the latter; confirm the attribution.
dataset_info = """
### About the Dataset
This dataset contains a subset of ArXiv papers with the "cs.LG" tag, indicating that the paper is about Machine Learning. The core dataset is filtered from the full ArXiv dataset hosted on Kaggle: [ArXiv Dataset on Kaggle](https://www.kaggle.com/datasets/Cornell-University/arxiv). The original dataset contains roughly 2 million papers, and this dataset contains approximately 100,000 papers after category filtering.
The dataset is maintained by making requests to the ArXiv API. The current iteration only includes the title and abstract of each paper.
**Dataset Source:** The dataset is sourced from Hugging Face: [CShorten/ML-ArXiv-Papers](https://huggingface.co/datasets/CShorten/ML-ArXiv-Papers).
"""
# Create the Gradio interface
# Single-textbox UI: the submitted text is passed to search_arxiv and the
# returned string is rendered as Markdown, with dataset_info shown below.
iface = gr.Interface(
fn=search_arxiv,
inputs=gr.components.Textbox(lines=1, placeholder="Enter your query..."),
outputs="markdown",
title="Semantic Search in ArXiv ML Papers",
description="Enter a query to find relevant ML papers from the ArXiv dataset.",
article=dataset_info,
# NOTE(review): "huggingface" is not one of Gradio's built-in theme names
# (e.g. "default", "soft"). Newer Gradio versions try to load unknown
# strings from the Hub and otherwise fall back to the default theme with a
# warning. Confirm against the pinned gradio version.
theme="huggingface",
# NOTE(review): nothing in this call sets elem_classes="input-textbox" on the
# Textbox, so the .input-textbox rule below likely matches nothing. Verify in
# the rendered page.
css="""
body {
font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif;
background-color: #FDF6E3; /* Warm background color */
color: #333333; /* Text color */
}
.input-textbox {
background-color: #FFFAF0; /* Lighter warm input box */
color: #333333; /* Input text color */
}
h1, h2, h3, h4, h5, h6 {
color: #D2691E; /* Warm, inviting heading color */
}
"""
)
# Launch the interface.
# This runs at module level (there is no `if __name__ == "__main__":` guard),
# so executing or importing this file starts the Gradio server.
# share=True requests a public share link when run locally.
# NOTE(review): on Hugging Face Spaces this flag is reportedly ignored with a
# warning; confirm.
iface.launch(share=True)