import os

import gradio as gr
import torch
from datasets import load_dataset
from huggingface_hub import InferenceClient
from transformers import AutoTokenizer, AutoModelForCausalLM
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.vectorstores import FAISS

# === 1. Build the FAISS vectorstore from CUAD ===
print("Loading CUAD and building index...")

# CUAD is published on the Hub as a SQuAD-style QA dataset ("cuad"); the contract
# text for each record lives in the "context" field, and the same contract repeats
# across many questions, so deduplicate before indexing.
cuad_data = load_dataset("cuad")
texts = list(dict.fromkeys(item["context"] for item in cuad_data["train"]))

splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=50)
docs = splitter.create_documents(texts)

embedding_model = HuggingFaceEmbeddings(model_name="all-MiniLM-L6-v2")
vectorstore = FAISS.from_documents(docs, embedding_model)

# === 2. Model setup ===
USE_LLAMA = os.environ.get("USE_LLAMA", "false").lower() == "true"
HF_TOKEN = os.environ.get("HF_TOKEN")


def load_llama():
    # LLaMA-2 is a gated model: HF_TOKEN must belong to an account with access.
    tokenizer = AutoTokenizer.from_pretrained(
        "meta-llama/Llama-2-7b-chat-hf", token=HF_TOKEN
    )
    model = AutoModelForCausalLM.from_pretrained(
        "meta-llama/Llama-2-7b-chat-hf",
        device_map="auto",
        torch_dtype=torch.float16,
        token=HF_TOKEN,
    )
    return tokenizer, model


def generate_llama_response(prompt):
    inputs = llama_tokenizer(prompt, return_tensors="pt").to(llama_model.device)
    outputs = llama_model.generate(**inputs, max_new_tokens=300)
    # Decode only the newly generated tokens so the prompt is not echoed back.
    generated = outputs[0][inputs["input_ids"].shape[-1]:]
    return llama_tokenizer.decode(generated, skip_special_tokens=True).strip()


def generate_mistral_response(prompt):
    return mistral_client.text_generation(prompt, max_new_tokens=300).strip()


if USE_LLAMA:
    llama_tokenizer, llama_model = load_llama()
    generate_response = generate_llama_response
else:
    mistral_client = InferenceClient(
        model="mistralai/Mistral-7B-Instruct-v0.1", token=HF_TOKEN
    )
    generate_response = generate_mistral_response


# === 3. Main QA function ===
def answer_question(user_query):
    # Retrieve the three most similar contract chunks and stuff them into the prompt.
    retrieved = vectorstore.similarity_search(user_query, k=3)
    context = "\n".join(doc.page_content for doc in retrieved)
    prompt = f"""[Context]
{context}

[User Question]
{user_query}

[Answer]
"""
    return generate_response(prompt)


# === 4. Gradio UI ===
iface = gr.Interface(
    fn=answer_question,
    inputs=gr.Textbox(placeholder="Ask a question about your contract..."),
    outputs=gr.Textbox(label="Answer"),
    title="LawCounsel AI",
    description="Ask clause-specific questions about contracts from the CUAD dataset. Powered by RAG using Mistral or LLaMA.",
)

iface.launch()
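
# Example invocations (a sketch; the script name and token value below are
# placeholders, adjust to your environment):
#   HF_TOKEN=hf_xxx python app.py
#       -> answers via the hosted Mistral-7B-Instruct endpoint (Inference API)
#   USE_LLAMA=true HF_TOKEN=hf_xxx python app.py
#       -> loads LLaMA-2-7B-chat locally; requires a GPU and gated-model access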