import os
import gradio as gr
import torch
from datasets import load_dataset
from huggingface_hub import InferenceClient
from transformers import AutoTokenizer, AutoModelForCausalLM
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.vectorstores import FAISS
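# Pipeline overview: CUAD contracts -> chunked documents -> MiniLM embeddings -> FAISS index,
# then retrieval-augmented generation with either Llama-2 (local) or Mistral (Inference API).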
# === 1. Build the FAISS vectorstore from CUAD ===
print(" Loading CUAD and building index...")
#new
from datasets import load_dataset
cuad_data = load_dataset("lex_glue", "cuad")
texts = [item["text"] for item in cuad_data["train"] if "text" in item]
splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=50)
docs = splitter.create_documents(texts)
embedding_model = HuggingFaceEmbeddings(model_name="all-MiniLM-L6-v2")
vectorstore = FAISS.from_documents(docs, embedding_model)
# === 2. Model setup ===
USE_LLAMA = os.environ.get("USE_LLAMA", "false").lower() == "true"
HF_TOKEN = os.environ.get("HF_TOKEN")
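# USE_LLAMA=true loads Llama-2 locally (needs a GPU and access to the gated repo);
# otherwise the app calls the Hugging Face Inference API with Mistral using HF_TOKEN.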
def load_llama():
    # Llama-2 is a gated repo, so pass the token explicitly.
    tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf", token=HF_TOKEN)
    model = AutoModelForCausalLM.from_pretrained(
        "meta-llama/Llama-2-7b-chat-hf",
        device_map="auto",
        torch_dtype=torch.float16,
        token=HF_TOKEN,
    )
    return tokenizer, model
def generate_llama_response(prompt):
    # Move inputs to whatever device device_map="auto" placed the model on.
    inputs = llama_tokenizer(prompt, return_tensors="pt").to(llama_model.device)
    outputs = llama_model.generate(**inputs, max_new_tokens=300)
    return llama_tokenizer.decode(outputs[0], skip_special_tokens=True)
def generate_mistral_response(prompt):
    # Remote generation via the serverless Inference API; no local weights needed.
    return mistral_client.text_generation(prompt=prompt, max_new_tokens=300).strip()
if USE_LLAMA:
    llama_tokenizer, llama_model = load_llama()
    generate_response = generate_llama_response
else:
    mistral_client = InferenceClient(
        model="mistralai/Mistral-7B-Instruct-v0.1",
        token=HF_TOKEN,
    )
    generate_response = generate_mistral_response
# === 3. Main QA function ===
def answer_question(user_query):
    # Retrieve the three most similar contract chunks and stuff them into the prompt.
    docs = vectorstore.similarity_search(user_query, k=3)
    context = "\n".join(doc.page_content for doc in docs)
    prompt = f"""[Context]
{context}
[User Question]
{user_query}
[Answer]
"""
    return generate_response(prompt)
# === 4. Gradio UI ===
iface = gr.Interface(
    fn=answer_question,
    inputs=gr.Textbox(placeholder="Ask a question about your contract..."),
    outputs=gr.Textbox(label="Answer"),
    title="LawCounsel AI",
    description="Ask clause-specific questions from CUAD-trained contracts. Powered by RAG using Mistral or LLaMA.",
)
iface.launch()