Runtime error
Update app.py
app.py CHANGED
@@ -1,28 +1,29 @@
-# app.py
-
import os
-import json
import gradio as gr
import torch
-from
-from langchain.docstore.document import Document
-from langchain_huggingface import HuggingFaceEmbeddings
+from datasets import load_dataset
from huggingface_hub import InferenceClient
from transformers import AutoTokenizer, AutoModelForCausalLM
+from langchain.text_splitter import RecursiveCharacterTextSplitter
+from langchain.embeddings import HuggingFaceEmbeddings
+from langchain.vectorstores import FAISS

-#
-
-db = FAISS.load_local("faiss_index_lawcounsel", embedding_model, allow_dangerous_deserialization=True)
+# === 1. Build the FAISS vectorstore from CUAD ===
+print("Loading CUAD and building index...")

-# Load CUAD from hf itself
-from datasets import load_dataset
cuad_data = load_dataset("cuad")
+texts = [item["text"] for item in cuad_data["train"] if "text" in item]
+
+splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=50)
+docs = splitter.create_documents(texts)
+
+embedding_model = HuggingFaceEmbeddings(model_name="all-MiniLM-L6-v2")
+vectorstore = FAISS.from_documents(docs, embedding_model)

-# Model setup
+# === 2. Model setup ===
USE_LLAMA = os.environ.get("USE_LLAMA", "false").lower() == "true"
HF_TOKEN = os.environ.get("HF_TOKEN")

-# Define generation logic
def load_llama():
    tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf", use_auth_token=True)
    model = AutoModelForCausalLM.from_pretrained(
@@ -40,7 +41,6 @@ def generate_llama_response(prompt):
def generate_mistral_response(prompt):
    return mistral_client.text_generation(prompt=prompt, max_new_tokens=300).strip()

-# Load selected model
if USE_LLAMA:
    llama_tokenizer, llama_model = load_llama()
    generate_response = generate_llama_response
@@ -51,9 +51,9 @@ else:
    )
    generate_response = generate_mistral_response

-# Main QA function
+# === 3. Main QA function ===
def answer_question(user_query):
-    docs =
+    docs = vectorstore.similarity_search(user_query, k=3)
    context = "\n".join([doc.page_content for doc in docs])
    prompt = f"""[Context]
{context}
@@ -65,13 +65,13 @@ def answer_question(user_query):
"""
    return generate_response(prompt)

-# Gradio UI
+# === 4. Gradio UI ===
iface = gr.Interface(
    fn=answer_question,
    inputs=gr.Textbox(placeholder="Ask a question about your contract..."),
    outputs=gr.Textbox(label="Answer"),
    title="LawCounsel AI",
-    description="
+    description="Ask clause-specific questions from CUAD-trained contracts. Powered by RAG using Mistral or LLaMA.",
)

iface.launch()
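A note on the index-building block added above: the "cuad" dataset on the Hugging Face Hub uses a SQuAD-style schema (id, title, context, question, answers), so filtering items on a "text" key would most likely leave texts empty and make FAISS.from_documents raise during startup, which is one plausible cause of the Space's runtime error. A minimal sketch of that step reading contract text from the context column instead; the deduplication and the cap on contract count are illustrative assumptions, not part of the commit:

from datasets import load_dataset
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.vectorstores import FAISS

cuad_data = load_dataset("cuad")
# Many CUAD rows share the same contract, so deduplicate the contexts first.
texts = list(dict.fromkeys(item["context"] for item in cuad_data["train"]))
texts = texts[:50]  # illustrative cap to keep the index build fast on a CPU Space

splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=50)
docs = splitter.create_documents(texts)

embedding_model = HuggingFaceEmbeddings(model_name="all-MiniLM-L6-v2")
vectorstore = FAISS.from_documents(docs, embedding_model)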
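The diff hides roughly new lines 30 to 40 and 47 to 50, which presumably finish load_llama, define generate_llama_response, and construct the mistral_client used by generate_mistral_response. A sketch of what that elided glue typically looks like; the Mistral repo id, dtype, device_map, and decoding details are assumptions rather than the Space's actual code:

def load_llama():
    tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf", use_auth_token=True)
    model = AutoModelForCausalLM.from_pretrained(
        "meta-llama/Llama-2-7b-chat-hf",
        torch_dtype=torch.float16,
        device_map="auto",
        use_auth_token=True,
    )
    return tokenizer, model

def generate_llama_response(prompt):
    # Tokenize, generate, and return only the newly generated tokens.
    inputs = llama_tokenizer(prompt, return_tensors="pt").to(llama_model.device)
    output_ids = llama_model.generate(**inputs, max_new_tokens=300)
    new_tokens = output_ids[0][inputs["input_ids"].shape[1]:]
    return llama_tokenizer.decode(new_tokens, skip_special_tokens=True).strip()

# Assumed else branch: a serverless Inference API client for a Mistral instruct model.
mistral_client = InferenceClient(model="mistralai/Mistral-7B-Instruct-v0.2", token=HF_TOKEN)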
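New lines 60 to 64, the middle of the prompt template, are hidden as well. The retrieval-plus-generation flow in answer_question written out end to end, with an assumed [Question]/[Answer] layout standing in for the hidden part:

def answer_question(user_query):
    # Retrieve the three chunks closest to the query and place them in the prompt.
    docs = vectorstore.similarity_search(user_query, k=3)
    context = "\n".join(doc.page_content for doc in docs)
    prompt = f"""[Context]
{context}

[Question]
{user_query}

[Answer]
"""
    return generate_response(prompt)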
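Missing packages are another common way a Space ends up with the Runtime error badge; faiss-cpu and sentence-transformers are not pulled in by langchain on its own. A plausible requirements.txt matching the imports in this commit, unpinned and illustrative only:

gradio
torch
transformers
datasets
huggingface_hub
langchain
sentence-transformers
faiss-cpu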
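Once the Space builds, the gr.Interface exposes a single /predict endpoint, so the app can also be queried programmatically. A usage sketch with gradio_client; the Space id is a placeholder:

from gradio_client import Client

client = Client("your-username/lawcounsel-ai")  # placeholder Space id
answer = client.predict("Which clause covers governing law?", api_name="/predict")
print(answer)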