import os
import gradio as gr
import torch
from datasets import load_dataset
from huggingface_hub import InferenceClient
from transformers import AutoTokenizer, AutoModelForCausalLM
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.vectorstores import FAISS
# === 1. Build the FAISS vectorstore from CUAD ===
print(" Loading CUAD and building index...")
#new
from datasets import load_dataset
cuad_data = load_dataset("lex_glue", "cuad")
texts = [item["text"] for item in cuad_data["train"] if "text" in item]
splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=50)
docs = splitter.create_documents(texts)
embedding_model = HuggingFaceEmbeddings(model_name="all-MiniLM-L6-v2")
vectorstore = FAISS.from_documents(docs, embedding_model)
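# Optional (sketch; path and flag are assumptions): persist the FAISS index so the
# Space can reload it instead of rebuilding on every restart.
# vectorstore.save_local("cuad_faiss_index")
# vectorstore = FAISS.load_local("cuad_faiss_index", embedding_model,
#                                allow_dangerous_deserialization=True)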
# === 2. Model setup ===
USE_LLAMA = os.environ.get("USE_LLAMA", "false").lower() == "true"
HF_TOKEN = os.environ.get("HF_TOKEN")
def load_llama():
    # meta-llama/Llama-2-7b-chat-hf is gated, so pass the HF token explicitly.
    tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf", token=HF_TOKEN)
    model = AutoModelForCausalLM.from_pretrained(
        "meta-llama/Llama-2-7b-chat-hf",
        token=HF_TOKEN,
        device_map="auto",
        torch_dtype=torch.float16
    )
    return tokenizer, model
def generate_llama_response(prompt):
    # Send inputs to whichever device device_map="auto" placed the model on.
    inputs = llama_tokenizer(prompt, return_tensors="pt").to(llama_model.device)
    outputs = llama_model.generate(**inputs, max_new_tokens=300)
    return llama_tokenizer.decode(outputs[0], skip_special_tokens=True)
def generate_mistral_response(prompt):
    return mistral_client.text_generation(prompt=prompt, max_new_tokens=300).strip()
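# Note (assumption): Mistral-7B-Instruct generally follows instructions better when the
# prompt is wrapped in its [INST] chat format, e.g.
# def format_mistral_prompt(prompt):
#     return f"<s>[INST] {prompt} [/INST]"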
if USE_LLAMA:
    llama_tokenizer, llama_model = load_llama()
    generate_response = generate_llama_response
else:
    mistral_client = InferenceClient(
        model="mistralai/Mistral-7B-Instruct-v0.1",
        token=HF_TOKEN
    )
    generate_response = generate_mistral_response
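# USE_LLAMA and HF_TOKEN are read from the environment; on a Hugging Face Space they can
# be set as Space variables/secrets. Without HF_TOKEN, the gated LLaMA weights and the
# hosted Inference API call may fail.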
# === 3. Main QA function ===
def answer_question(user_query):
    # Retrieve the 3 most similar CUAD chunks and answer from that context.
    docs = vectorstore.similarity_search(user_query, k=3)
    context = "\n".join([doc.page_content for doc in docs])
    prompt = f"""[Context]
{context}
[User Question]
{user_query}
[Answer]
"""
    return generate_response(prompt)
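# Quick local smoke test (hypothetical question) before wiring up the UI:
# print(answer_question("What is the governing law of this agreement?"))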
# === 4. Gradio UI ===
iface = gr.Interface(
    fn=answer_question,
    inputs=gr.Textbox(placeholder="Ask a question about your contract..."),
    outputs=gr.Textbox(label="Answer"),
    title="LawCounsel AI",
    description="Ask clause-specific questions from CUAD-trained contracts. Powered by RAG using Mistral or LLaMA.",
)
iface.launch()
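# To run outside a Space (sketch; package names inferred from the imports above):
#   pip install gradio torch datasets huggingface_hub transformers langchain \
#       langchain-community sentence-transformers faiss-cpu
#   python app.py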