# NOTE: The original paste included HuggingFace Spaces page-header residue
# ("Spaces: / Sleeping / Sleeping") here — it is not part of the application.
import gradio as gr
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
from sentence_transformers import SentenceTransformer, util
import torch
from datasets import load_dataset

# Seq2seq model used to generate the final answer from retrieved context.
model_name = "google/flan-t5-xl"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)

# Bhagavad Gita corpus: parallel lists of chapter label, verse range, and text.
ds = load_dataset("knowrohit07/gita_dataset")
chapters = ds["train"]["Chapter"]
sentence_ranges = ds["train"]["sentence_range"]
texts = ds["train"]["Text"]

# Sentence-embedding model used for semantic retrieval over the corpus.
sentence_model = SentenceTransformer("all-MiniLM-L6-v2")

# Pre-compute corpus embeddings once so each query only encodes itself.
text_embeddings = sentence_model.encode(texts, convert_to_tensor=True)
def find_relevant_texts(query, top_k=3):
    """Return the top_k Gita passages most semantically similar to *query*.

    Each passage is formatted as "Chapter C, Verses R: text"; passages are
    joined with blank lines for direct inclusion in a prompt.

    Args:
        query: Free-text question to search for.
        top_k: Maximum number of passages to return (clamped to corpus size).

    Returns:
        A single string of the formatted passages, separated by blank lines.
    """
    query_embedding = sentence_model.encode(query, convert_to_tensor=True)
    cos_scores = util.cos_sim(query_embedding, text_embeddings)[0]
    # torch.topk raises if k exceeds the number of scores — clamp defensively.
    k = min(top_k, len(texts))
    top_results = torch.topk(cos_scores, k=k)
    relevant_texts = []
    for idx in top_results.indices:
        i = int(idx)  # tensor scalar -> plain int for list indexing
        relevant_texts.append(
            f"Chapter {chapters[i]}, Verses {sentence_ranges[i]}: {texts[i]}"
        )
    return "\n\n".join(relevant_texts)
def generate_response(question):
    """Answer *question*, grounding the model on retrieved Gita excerpts.

    Retrieves the most relevant passages via semantic search, builds a
    retrieval-augmented prompt, and generates an answer with FLAN-T5.

    Args:
        question: User's free-text question about the Bhagavad Gita.

    Returns:
        The model's decoded answer string.
    """
    relevant_texts = find_relevant_texts(question)
    prompt = f"""Based on the following excerpts from the Bhagavad Gita, answer the question.
Relevant excerpts:
{relevant_texts}
Question: {question}
Answer:"""
    input_ids = tokenizer(prompt, return_tensors="pt").input_ids
    # Sampled decoding (temperature/top-p) gives mildly varied but bounded answers.
    outputs = model.generate(
        input_ids,
        max_new_tokens=200,
        do_sample=True,
        temperature=0.7,
        top_p=0.95,
    )
    return tokenizer.decode(outputs[0], skip_special_tokens=True)
# Minimal Gradio UI: one question textbox in, generated answer text out.
iface = gr.Interface(
    fn=generate_response,
    inputs=gr.Textbox(
        lines=2,
        placeholder="Enter your question about the Bhagavad Gita here...",
    ),
    outputs="text",
)

iface.launch()