Update app.py
app.py
CHANGED
@@ -27,7 +27,7 @@ model_name = "ibm-granite/granite-3.1-2b-instruct"
 model = AutoModelForCausalLM.from_pretrained(
     model_name,
     device_map="balanced",  # Using balanced CPU mapping.
-    torch_dtype=torch.float16  # Use float16 if supported
+    torch_dtype=torch.float16  # Use float16 if supported.
 )
 tokenizer = AutoTokenizer.from_pretrained(model_name)
 
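Note on the changed line: the neighboring comment says the model is CPU-mapped, and float16 is often slow or unsupported in CPU kernels. A minimal sketch of a hardware-aware fallback (an assumption, not part of this commit; `dtype` is an illustrative name):

import torch
from transformers import AutoModelForCausalLM

# Assumption: use float16 only when a GPU is available; most CPU backends
# handle float32 far better than float16.
dtype = torch.float16 if torch.cuda.is_available() else torch.float32
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    device_map="balanced",
    torch_dtype=dtype,
)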
@@ -64,7 +64,6 @@ def read_file(file_obj):
     """
     Reads content from a file. Supports both file paths (str) and Streamlit uploaded files.
     """
-    # If file_obj is a string path:
     if isinstance(file_obj, str):
         if file_obj in FILE_CACHE:
             return FILE_CACHE[file_obj]
@@ -139,11 +138,17 @@ def read_files(file_objs, max_length=3000):
     SUMMARY_CACHE[cache_key] = summarized
     return summarized
 
-def
-
-
-
-
+def build_prompt(system_msg, document_content, user_prompt):
+    """
+    Build a unified prompt that explicitly delineates the system instructions,
+    document content, and user prompt.
+    """
+    prompt_parts = []
+    prompt_parts.append("SYSTEM PROMPT:\n" + system_msg.strip())
+    if document_content:
+        prompt_parts.append("\nDOCUMENT CONTENT:\n" + document_content.strip())
+    prompt_parts.append("\nUSER PROMPT:\n" + user_prompt.strip())
+    return "\n\n".join(prompt_parts)
 
 def speculative_decode(input_text, max_tokens=DEFAULT_MAX_TOKENS, top_p=0.9, temperature=0.7):
     model_inputs = tokenizer([input_text], return_tensors="pt").to(model.device)
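For reference, an illustrative call to the new build_prompt helper (the sample strings below are invented). Because the later parts begin with "\n" and the parts are joined with "\n\n", sections come out separated by two blank lines:

# Illustrative usage of build_prompt; all strings here are made up.
demo = build_prompt(
    "You are a contract analyst.",
    "Clause 7: Licensee shall indemnify Licensor for all claims.",
    "Flag risky clauses.",
)
# demo ==
# "SYSTEM PROMPT:\nYou are a contract analyst.\n\n"
# "\nDOCUMENT CONTENT:\nClause 7: Licensee shall indemnify Licensor for all claims.\n\n"
# "\nUSER PROMPT:\nFlag risky clauses."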
@@ -168,18 +173,23 @@ def post_process(text):
         unique_lines.append(clean_line)
     return "\n".join(unique_lines)
 
-def granite_analysis(
-
-
-
-
-
-        "
-        "
+def granite_analysis(user_prompt, file_objs=None, max_tokens=DEFAULT_MAX_TOKENS, top_p=0.9, temperature=0.7):
+    # Read and summarize document content.
+    document_content = read_files(file_objs) if file_objs else ""
+
+    # Define a clear system prompt.
+    system_prompt = (
+        "You are IBM Granite, an enterprise legal and technical analysis assistant. "
+        "Your task is to critically analyze the contract document provided below. "
+        "Pay special attention to identifying dangerous provisions, legal pitfalls, and potential liabilities. "
+        "Make sure to address both the overall contract structure and specific clauses where applicable."
     )
-
-
-
+
+    # Build a unified prompt with explicit sections.
+    unified_prompt = build_prompt(system_prompt, document_content, user_prompt)
+
+    # Generate the analysis.
+    response = speculative_decode(unified_prompt, max_tokens=max_tokens, top_p=top_p, temperature=temperature)
     final_response = post_process(response)
     return final_response
 
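An illustrative end-to-end call of the rewritten function (the file path and prompt are hypothetical, not taken from this Space):

# Hypothetical usage; "contract.txt" is an assumed local file.
analysis = granite_analysis(
    "List any unilateral termination or auto-renewal clauses.",
    ["contract.txt"],
    max_tokens=512,
    top_p=0.9,
    temperature=0.7,
)
print(analysis)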
@@ -207,10 +217,4 @@ if st.button("Analyze Contract"):
     result = granite_analysis(user_prompt, uploaded_files, max_tokens=max_tokens_slider, top_p=top_p_slider, temperature=temperature_slider)
     st.success("Analysis complete!")
     st.markdown("### Analysis Output")
-
-    keyword = "assistant"
-    text_after_keyword = result.rsplit(keyword, 1)[-1].strip()
-
-    st.text_area("Output", text_after_keyword, height=400)
-
-
+    st.text_area("Output", result, height=400)
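The deleted rsplit("assistant") step presumably trimmed everything before the model's reply when the prompt went through a chat template; the plain-text prompt from build_prompt contains no "assistant" marker, so the full result is now shown directly. If the decoded text still echoes the prompt (model.generate returns the prompt tokens along with the completion), one common alternative is to slice the prompt tokens off before decoding. A sketch only, assuming generated_ids comes from model.generate and model_inputs is the tokenized prompt inside speculative_decode:

# Sketch, not part of this commit: decode just the newly generated tokens.
prompt_len = model_inputs["input_ids"].shape[1]
new_tokens = generated_ids[0][prompt_len:]
text = tokenizer.decode(new_tokens, skip_special_tokens=True)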
|
| 27 |
model = AutoModelForCausalLM.from_pretrained(
|
| 28 |
model_name,
|
| 29 |
device_map="balanced", # Using balanced CPU mapping.
|
| 30 |
+
torch_dtype=torch.float16 # Use float16 if supported.
|
| 31 |
)
|
| 32 |
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
| 33 |
|
|
|
|
| 64 |
"""
|
| 65 |
Reads content from a file. Supports both file paths (str) and Streamlit uploaded files.
|
| 66 |
"""
|
|
|
|
| 67 |
if isinstance(file_obj, str):
|
| 68 |
if file_obj in FILE_CACHE:
|
| 69 |
return FILE_CACHE[file_obj]
|
|
|
|
| 138 |
SUMMARY_CACHE[cache_key] = summarized
|
| 139 |
return summarized
|
| 140 |
|
| 141 |
+
def build_prompt(system_msg, document_content, user_prompt):
|
| 142 |
+
"""
|
| 143 |
+
Build a unified prompt that explicitly delineates the system instructions,
|
| 144 |
+
document content, and user prompt.
|
| 145 |
+
"""
|
| 146 |
+
prompt_parts = []
|
| 147 |
+
prompt_parts.append("SYSTEM PROMPT:\n" + system_msg.strip())
|
| 148 |
+
if document_content:
|
| 149 |
+
prompt_parts.append("\nDOCUMENT CONTENT:\n" + document_content.strip())
|
| 150 |
+
prompt_parts.append("\nUSER PROMPT:\n" + user_prompt.strip())
|
| 151 |
+
return "\n\n".join(prompt_parts)
|
| 152 |
|
| 153 |
def speculative_decode(input_text, max_tokens=DEFAULT_MAX_TOKENS, top_p=0.9, temperature=0.7):
|
| 154 |
model_inputs = tokenizer([input_text], return_tensors="pt").to(model.device)
|
|
|
|
| 173 |
unique_lines.append(clean_line)
|
| 174 |
return "\n".join(unique_lines)
|
| 175 |
|
| 176 |
+
def granite_analysis(user_prompt, file_objs=None, max_tokens=DEFAULT_MAX_TOKENS, top_p=0.9, temperature=0.7):
|
| 177 |
+
# Read and summarize document content.
|
| 178 |
+
document_content = read_files(file_objs) if file_objs else ""
|
| 179 |
+
|
| 180 |
+
# Define a clear system prompt.
|
| 181 |
+
system_prompt = (
|
| 182 |
+
"You are IBM Granite, an enterprise legal and technical analysis assistant. "
|
| 183 |
+
"Your task is to critically analyze the contract document provided below. "
|
| 184 |
+
"Pay special attention to identifying dangerous provisions, legal pitfalls, and potential liabilities. "
|
| 185 |
+
"Make sure to address both the overall contract structure and specific clauses where applicable."
|
| 186 |
)
|
| 187 |
+
|
| 188 |
+
# Build a unified prompt with explicit sections.
|
| 189 |
+
unified_prompt = build_prompt(system_prompt, document_content, user_prompt)
|
| 190 |
+
|
| 191 |
+
# Generate the analysis.
|
| 192 |
+
response = speculative_decode(unified_prompt, max_tokens=max_tokens, top_p=top_p, temperature=temperature)
|
| 193 |
final_response = post_process(response)
|
| 194 |
return final_response
|
| 195 |
|
|
|
|
| 217 |
result = granite_analysis(user_prompt, uploaded_files, max_tokens=max_tokens_slider, top_p=top_p_slider, temperature=temperature_slider)
|
| 218 |
st.success("Analysis complete!")
|
| 219 |
st.markdown("### Analysis Output")
|
| 220 |
+
st.text_area("Output", result, height=400)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|