import os
import gradio as gr
import torch
from datasets import load_dataset
from huggingface_hub import InferenceClient
from transformers import AutoTokenizer, AutoModelForCausalLM
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.vectorstores import FAISS

# === 1. Build the FAISS vectorstore from CUAD ===
print(" Loading CUAD and building index...")
#new
from datasets import load_dataset

cuad_data = load_dataset("lex_glue", "cuad")

texts = [item["text"] for item in cuad_data["train"] if "text" in item]

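# Split contracts into overlapping ~500-character chunks so retrieval returns
# focused, clause-level passages.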
splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=50)
docs = splitter.create_documents(texts)

embedding_model = HuggingFaceEmbeddings(model_name="all-MiniLM-L6-v2")
vectorstore = FAISS.from_documents(docs, embedding_model)
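# The index is rebuilt in memory on every start; vectorstore.save_local() and
# FAISS.load_local() could be used to persist it between restarts.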

# === 2. Model setup ===
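# USE_LLAMA=true selects a local LLaMA-2 checkpoint (GPU required); otherwise the
# app calls the hosted Mistral-7B-Instruct endpoint via the Inference API using HF_TOKEN.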
USE_LLAMA = os.environ.get("USE_LLAMA", "false").lower() == "true"
HF_TOKEN = os.environ.get("HF_TOKEN")

def load_llama():
    # LLaMA-2 is a gated checkpoint, so the token must belong to an approved account.
    tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf", token=HF_TOKEN)
    model = AutoModelForCausalLM.from_pretrained(
        "meta-llama/Llama-2-7b-chat-hf",
        device_map="auto",
        torch_dtype=torch.float16,
        token=HF_TOKEN,
    )
    return tokenizer, model

def generate_llama_response(prompt):
    inputs = llama_tokenizer(prompt, return_tensors="pt").to(llama_model.device)
    outputs = llama_model.generate(**inputs, max_new_tokens=300)
    # Decode only the newly generated tokens so the answer does not echo the prompt.
    new_tokens = outputs[0][inputs["input_ids"].shape[-1]:]
    return llama_tokenizer.decode(new_tokens, skip_special_tokens=True).strip()

def generate_mistral_response(prompt):
    return mistral_client.text_generation(prompt=prompt, max_new_tokens=300).strip()

if USE_LLAMA:
    llama_tokenizer, llama_model = load_llama()
    generate_response = generate_llama_response
else:
    mistral_client = InferenceClient(
        model="mistralai/Mistral-7B-Instruct-v0.1",
        token=HF_TOKEN
    )
    generate_response = generate_mistral_response

# === 3. Main QA function ===
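# Standard RAG flow: embed the query, retrieve the top-k most similar contract
# chunks, and pass them to the LLM as context for the answer.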
def answer_question(user_query):
    retrieved_docs = vectorstore.similarity_search(user_query, k=3)
    context = "\n".join(doc.page_content for doc in retrieved_docs)
    prompt = f"""[Context]
{context}

[User Question]
{user_query}

[Answer]
"""
    return generate_response(prompt)

# === 4. Gradio UI ===
iface = gr.Interface(
    fn=answer_question,
    inputs=gr.Textbox(placeholder="Ask a question about your contract..."),
    outputs=gr.Textbox(label="Answer"),
    title="LawCounsel AI",
    description="Ask clause-specific questions from CUAD-trained contracts. Powered by RAG using Mistral or LLaMA.",
)

iface.launch()
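
# On Hugging Face Spaces the default launch() settings are used as-is; for local
# testing, launch(share=True) can expose a temporary public URL.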