import os
import re
import io
import numpy as np
import networkx as nx
from sympy import symbols
from galgebra.ga import Ga
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Input
import tensorflow as tf
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    pipeline
)
import torch
from PyPDF2 import PdfReader
from concurrent.futures import ThreadPoolExecutor, as_completed
import streamlit as st
# Optionally, set environment variables to optimize CPU parallelism.
os.environ["OMP_NUM_THREADS"] = "4"  # Adjust to your available cores.
os.environ["MKL_NUM_THREADS"] = "4"

# Set up the IBM Granite model without 8-bit quantization (for CPU).
model_name = "ibm-granite/granite-3.1-2b-instruct"
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    device_map="balanced",      # Balanced device mapping (CPU in this deployment).
    torch_dtype=torch.float16   # Use float16 if supported.
)
tokenizer = AutoTokenizer.from_pretrained(model_name)
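# The float16 dtype above assumes the runtime handles half precision; on a
# CPU-only machine a full-precision load is a safer fallback. Sketch only,
# not enabled by default:
#
# model = AutoModelForCausalLM.from_pretrained(
#     model_name,
#     device_map="balanced",
#     torch_dtype=torch.float32,  # slower, but universally supported on CPU
# )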
DIM = 5000
# We use a lower max token count for faster generation.
DEFAULT_MAX_TOKENS = 1000

# Geometric algebra setup (3D Euclidean basis).
coords = symbols('e1 e2 e3')
ga = Ga('e1 e2 e3', g=[1, 1, 1])

# Cache the knowledge graph.
KNOWLEDGE_GRAPH = nx.Graph()
KNOWLEDGE_GRAPH.add_edges_from([
    ("Ambiguous Terms", "Risk of Dispute"),
    ("Lack of Termination Clause", "Prolonged Obligations"),
    ("Non-compliance", "Legal Penalties"),
    ("Confidentiality Breaches", "Reputational Damage"),
    ("Inadequate Indemnification", "High Liability"),
    ("Unclear Jurisdiction", "Compliance Issues"),
    ("Force Majeure", "Risk Mitigation"),
    ("Data Privacy", "Regulatory Compliance"),
    ("Penalty Clauses", "Financial Risk"),
    ("Intellectual Property", "Contract Disputes")
])
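# Illustrative helper (an assumption for demonstration; not called by the app)
# showing how the graph above can map a detected clause issue to its linked risks.
def related_risks(term):
    """Return the risks connected to a clause issue in KNOWLEDGE_GRAPH, or an empty list."""
    if term not in KNOWLEDGE_GRAPH:
        return []
    return sorted(KNOWLEDGE_GRAPH.neighbors(term))
# Example: related_risks("Ambiguous Terms") -> ["Risk of Dispute"]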
# Caches for file content and summaries.
FILE_CACHE = {}
SUMMARY_CACHE = {}

# Initialize a summarization pipeline on CPU (using a lightweight model).
summarizer = pipeline("summarization", model="t5-small", tokenizer="t5-small", device=-1)
def read_file(file_obj):
    """
    Read content from a file. Supports both file paths (str) and Streamlit uploaded files.
    """
    if isinstance(file_obj, str):
        if file_obj in FILE_CACHE:
            return FILE_CACHE[file_obj]
        if not os.path.exists(file_obj):
            st.error(f"File not found: {file_obj}")
            return ""
        content = ""
        try:
            if file_obj.lower().endswith(".pdf"):
                reader = PdfReader(file_obj)
                for page in reader.pages:
                    # extract_text() can return None for pages with no extractable text.
                    content += (page.extract_text() or "") + "\n"
            else:
                with open(file_obj, "r", encoding="utf-8") as f:
                    content = f.read() + "\n"
            FILE_CACHE[file_obj] = content
        except Exception as e:
            st.error(f"Error reading {file_obj}: {e}")
            content = ""
        return content
    else:
        # Assume it's an uploaded file (BytesIO).
        file_name = file_obj.name
        if file_name in FILE_CACHE:
            return FILE_CACHE[file_name]
        try:
            if file_name.lower().endswith(".pdf"):
                reader = PdfReader(io.BytesIO(file_obj.read()))
                content = ""
                for page in reader.pages:
                    content += (page.extract_text() or "") + "\n"
            else:
                content = file_obj.getvalue().decode("utf-8")
            FILE_CACHE[file_name] = content
            return content
        except Exception as e:
            st.error(f"Error reading uploaded file {file_name}: {e}")
            return ""
def summarize_text(text, chunk_size=2000):
    """
    Summarize text if it is longer than chunk_size.
    Uses parallel processing for multiple chunks.
    (Reducing chunk_size may speed up summarization on CPU.)
    """
    if len(text) <= chunk_size:
        return text
    chunks = [text[i:i + chunk_size] for i in range(0, len(text), chunk_size)]
    # Preserve chunk order even though futures may complete out of order.
    summaries = [None] * len(chunks)
    with ThreadPoolExecutor() as executor:
        futures = {
            executor.submit(summarizer, chunk, max_length=100, min_length=30, do_sample=False): idx
            for idx, chunk in enumerate(chunks)
        }
        for future in as_completed(futures):
            summaries[futures[future]] = future.result()[0]["summary_text"]
    return " ".join(summaries)
def read_files(file_objs, max_length=3000):
    """
    Read and, if necessary, summarize file content from a list of file objects or file paths.
    """
    full_text = ""
    for file_obj in file_objs:
        text = read_file(file_obj)
        full_text += text + "\n"
    cache_key = hash(full_text)
    if cache_key in SUMMARY_CACHE:
        return SUMMARY_CACHE[cache_key]
    if len(full_text) > max_length:
        summarized = summarize_text(full_text, chunk_size=max_length)
    else:
        summarized = full_text
    SUMMARY_CACHE[cache_key] = summarized
    return summarized
def build_prompt(system_msg, document_content, user_prompt):
    """
    Build a unified prompt that explicitly delineates the system instructions,
    document content, and user prompt, ending with an assistant marker so the
    generated answer can be separated from the echoed prompt.
    """
    prompt_parts = []
    prompt_parts.append("SYSTEM PROMPT:\n" + system_msg.strip())
    if document_content:
        prompt_parts.append("\nDOCUMENT CONTENT:\n" + document_content.strip())
    prompt_parts.append("\nUSER PROMPT:\n" + user_prompt.strip())
    # Final marker: the UI strips everything up to this section from the output.
    prompt_parts.append("\nASSISTANT PROMPT:")
    return "\n\n".join(prompt_parts)
def speculative_decode(input_text, max_tokens=DEFAULT_MAX_TOKENS, top_p=0.9, temperature=0.7):
    """Generate a single completion with nucleus (top-p) sampling."""
    model_inputs = tokenizer([input_text], return_tensors="pt").to(model.device)
    with torch.no_grad():
        output = model.generate(
            **model_inputs,
            max_new_tokens=max_tokens,
            do_sample=True,
            top_p=top_p,
            temperature=temperature,
            num_return_sequences=1,
            pad_token_id=tokenizer.eos_token_id
        )
    return tokenizer.decode(output[0], skip_special_tokens=True)
def post_process(text):
    """Strip whitespace and drop duplicate lines from the generated text."""
    lines = text.splitlines()
    unique_lines = []
    for line in lines:
        clean_line = line.strip()
        if clean_line and clean_line not in unique_lines:
            unique_lines.append(clean_line)
    return "\n".join(unique_lines)
def granite_analysis(user_prompt, file_objs=None, max_tokens=DEFAULT_MAX_TOKENS, top_p=0.9, temperature=0.7):
    # Read and summarize document content.
    document_content = read_files(file_objs) if file_objs else ""
    # Define a clear system prompt.
    system_prompt = (
        "You are IBM Granite, an enterprise legal and technical analysis assistant. "
        "Your task is to critically analyze the contract document provided below. "
        "Pay special attention to identifying dangerous provisions, legal pitfalls, and potential liabilities. "
        "Make sure to address both the overall contract structure and specific clauses where applicable."
    )
    # Build a unified prompt with explicit sections.
    unified_prompt = build_prompt(system_prompt, document_content, user_prompt)
    # Generate the analysis.
    response = speculative_decode(unified_prompt, max_tokens=max_tokens, top_p=top_p, temperature=temperature)
    final_response = post_process(response)
    return final_response
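# Illustrative call outside the Streamlit UI (assumes a local file named
# "contract.pdf" exists; not executed by this app):
#
# print(granite_analysis("List the riskiest clauses.", ["contract.pdf"]))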
# --------- Streamlit App Interface ---------
st.title("IBM Granite - Contract Analysis Assistant")
st.markdown("Upload a contract document (PDF or text) and adjust the analysis prompt and generation parameters.")

# File uploader (allows drag & drop).
uploaded_files = st.file_uploader("Upload contract file(s)", type=["pdf", "txt"], accept_multiple_files=True)

# Editable prompt text area.
default_prompt = (
    "Please analyze the attached contract document and highlight any clauses "
    "that represent potential dangers, liabilities, or legal pitfalls that may lead to future disputes or significant financial exposure."
)
user_prompt = st.text_area("Analysis Prompt", default_prompt, height=150)

# Sliders for generation parameters.
max_tokens_slider = st.slider("Maximum Tokens", min_value=100, max_value=2000, value=DEFAULT_MAX_TOKENS, step=100)
temperature_slider = st.slider("Temperature", min_value=0.0, max_value=1.0, value=0.7, step=0.1)
top_p_slider = st.slider("Top-p", min_value=0.0, max_value=1.0, value=0.9, step=0.05)
if st.button("Analyze Contract"):
    with st.spinner("Analyzing contract document..."):
        result = granite_analysis(user_prompt, uploaded_files, max_tokens=max_tokens_slider, top_p=top_p_slider, temperature=temperature_slider)
    st.success("Analysis complete!")
    st.markdown("### Analysis Output")
    # The generated text echoes the prompt; keep only what follows the final
    # ASSISTANT PROMPT marker appended by build_prompt.
    keyword = "ASSISTANT PROMPT:"
    text_after_keyword = result.rsplit(keyword, 1)[-1].strip()
    st.text_area("Output", text_after_keyword, height=400)