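"""LawCounsel AI - Gradio Space for clause-level contract Q&A.

Builds a FAISS index over CUAD contract text, retrieves the most relevant
chunks for a user question, and answers with Mistral (via the Inference API)
or a locally loaded Llama-2 chat model.
"""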
import os
import gradio as gr
import torch
from datasets import load_dataset
from huggingface_hub import InferenceClient
from transformers import AutoTokenizer, AutoModelForCausalLM
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.vectorstores import FAISS
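# These imports assume the Space's requirements.txt lists (at minimum):
# gradio, torch, datasets, huggingface_hub, transformers, langchain,
# langchain-community, sentence-transformers, faiss-cpu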
# === 1. Build the FAISS vectorstore from CUAD ===
print("Loading CUAD and building index...")
# CUAD is published on the Hub as its own SQuAD-style dataset ("cuad"),
# not as a LexGLUE config, so load it directly and index the contract text
# from the "context" field (deduplicated, since the same contract is repeated
# once per question).
cuad_data = load_dataset("cuad")
texts = list({item["context"] for item in cuad_data["train"]})
splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=50)
docs = splitter.create_documents(texts)
embedding_model = HuggingFaceEmbeddings(model_name="all-MiniLM-L6-v2")
vectorstore = FAISS.from_documents(docs, embedding_model)
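# Optional sketch (not in the original app): persist the index so it is not
# rebuilt on every Space restart. save_local/load_local are standard LangChain
# FAISS methods; the "faiss_index" path is only an example.
# vectorstore.save_local("faiss_index")
# vectorstore = FAISS.load_local("faiss_index", embedding_model)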
# === 2. Model setup ===
USE_LLAMA = os.environ.get("USE_LLAMA", "false").lower() == "true"
HF_TOKEN = os.environ.get("HF_TOKEN")
def load_llama():
    # Llama-2 is a gated repo, so pass the token explicitly
    # (use_auth_token=True is deprecated in recent transformers releases).
    tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf", token=HF_TOKEN)
    model = AutoModelForCausalLM.from_pretrained(
        "meta-llama/Llama-2-7b-chat-hf",
        device_map="auto",
        torch_dtype=torch.float16,
        token=HF_TOKEN,
    )
    return tokenizer, model
def generate_llama_response(prompt):
    # Send inputs to wherever device_map placed the model instead of hard-coding "cuda".
    inputs = llama_tokenizer(prompt, return_tensors="pt").to(llama_model.device)
    outputs = llama_model.generate(**inputs, max_new_tokens=300)
    return llama_tokenizer.decode(outputs[0], skip_special_tokens=True)
def generate_mistral_response(prompt):
    return mistral_client.text_generation(prompt=prompt, max_new_tokens=300).strip()
if USE_LLAMA:
    llama_tokenizer, llama_model = load_llama()
    generate_response = generate_llama_response
else:
    mistral_client = InferenceClient(
        model="mistralai/Mistral-7B-Instruct-v0.1",
        token=HF_TOKEN
    )
    generate_response = generate_mistral_response
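# Note: both models are instruction-tuned, and Llama-2-chat in particular is a
# gated repo (HF_TOKEN must belong to an account with accepted access). The
# plain prompt below is a simplification; these models generally respond better
# when the prompt is wrapped in their instruction format (e.g. [INST] ... [/INST]).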
# === 3. Main QA function ===
def answer_question(user_query):
    # Retrieve the 3 most similar CUAD chunks and use them as grounding context
    # (renamed from `docs` to avoid shadowing the module-level variable).
    retrieved_docs = vectorstore.similarity_search(user_query, k=3)
    context = "\n".join([doc.page_content for doc in retrieved_docs])
    prompt = f"""[Context]
{context}
[User Question]
{user_query}
[Answer]
"""
    return generate_response(prompt)
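# Quick sanity check before wiring up the UI (illustrative query, not from CUAD):
# print(answer_question("What is the termination notice period?"))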
# === 4. Gradio UI ===
iface = gr.Interface(
    fn=answer_question,
    inputs=gr.Textbox(placeholder="Ask a question about your contract..."),
    outputs=gr.Textbox(label="Answer"),
    title="LawCounsel AI",
    description="Ask clause-specific questions from CUAD-trained contracts. Powered by RAG using Mistral or LLaMA.",
)
iface.launch()