yashphogat4all committed
Commit 65d665a · verified · 1 Parent(s): 49ae1df

Update app.py

Files changed (1)
  1. app.py +18 -18
app.py CHANGED
@@ -1,28 +1,29 @@
- # app.py
-
  import os
- import json
  import gradio as gr
  import torch
- from langchain.vectorstores import FAISS
- from langchain.docstore.document import Document
- from langchain_huggingface import HuggingFaceEmbeddings
+ from datasets import load_dataset
  from huggingface_hub import InferenceClient
  from transformers import AutoTokenizer, AutoModelForCausalLM
+ from langchain.text_splitter import RecursiveCharacterTextSplitter
+ from langchain.embeddings import HuggingFaceEmbeddings
+ from langchain.vectorstores import FAISS

- # Load embedding model and FAISS index
- embedding_model = HuggingFaceEmbeddings(model_name="all-MiniLM-L6-v2")
- db = FAISS.load_local("faiss_index_lawcounsel", embedding_model, allow_dangerous_deserialization=True)
+ # === 1. Build the FAISS vectorstore from CUAD ===
+ print("🔄 Loading CUAD and building index...")

- # Load CUAD from hf itself
- from datasets import load_dataset
  cuad_data = load_dataset("cuad")
+ texts = [item["text"] for item in cuad_data["train"] if "text" in item]
+
+ splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=50)
+ docs = splitter.create_documents(texts)
+
+ embedding_model = HuggingFaceEmbeddings(model_name="all-MiniLM-L6-v2")
+ vectorstore = FAISS.from_documents(docs, embedding_model)

- # Model setup flags
+ # === 2. Model setup ===
  USE_LLAMA = os.environ.get("USE_LLAMA", "false").lower() == "true"
  HF_TOKEN = os.environ.get("HF_TOKEN")

- # Define generation logic
  def load_llama():
      tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf", use_auth_token=True)
      model = AutoModelForCausalLM.from_pretrained(
@@ -40,7 +41,6 @@ def generate_llama_response(prompt):
  def generate_mistral_response(prompt):
      return mistral_client.text_generation(prompt=prompt, max_new_tokens=300).strip()

- # Load selected model
  if USE_LLAMA:
      llama_tokenizer, llama_model = load_llama()
      generate_response = generate_llama_response
@@ -51,9 +51,9 @@ else:
      )
      generate_response = generate_mistral_response

- # Main QA function
+ # === 3. Main QA function ===
  def answer_question(user_query):
-     docs = db.similarity_search(user_query, k=3)
+     docs = vectorstore.similarity_search(user_query, k=3)
      context = "\n".join([doc.page_content for doc in docs])
      prompt = f"""[Context]
  {context}
@@ -65,13 +65,13 @@ def answer_question(user_query):
  """
      return generate_response(prompt)

- # Gradio UI
+ # === 4. Gradio UI ===
  iface = gr.Interface(
      fn=answer_question,
      inputs=gr.Textbox(placeholder="Ask a question about your contract..."),
      outputs=gr.Textbox(label="Answer"),
      title="LawCounsel AI",
-     description="Choose clause-related questions from your uploaded contract. Powered by RAG with Mistral or LLaMA.",
+     description="Ask clause-specific questions from CUAD-trained contracts. Powered by RAG using Mistral or LLaMA.",
  )

  iface.launch()
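
Note on the new indexing step: the Hub's "cuad" dataset is distributed as SQuAD-style records, so the contract body may be exposed under a key such as "context" rather than "text", in which case the `if "text" in item` guard would silently produce an empty `texts` list. Below is a minimal sketch, not part of this commit, for checking the available columns before building the index; the "context" field name is an assumption about the dataset, not something this diff confirms.

from datasets import load_dataset

cuad_train = load_dataset("cuad", split="train")
print(cuad_train.column_names)  # inspect what the split actually exposes

# Hypothetical fallback: prefer "text" if present, otherwise the SQuAD-style "context".
field = "text" if "text" in cuad_train.column_names else "context"
texts = list(dict.fromkeys(cuad_train[field]))  # de-duplicate repeated contracts
print(f"{len(texts)} unique passages ready for splitting and embedding")

The de-duplication is worth the extra line because SQuAD-style datasets repeat the same context across many question records, which would otherwise fill the FAISS index with identical chunks.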