Spaces: Running on CPU Upgrade
Commit · 2e2dda5
RAG fix
This view is limited to 50 files because it contains too many changes.
- .gitattributes +35 -0
- .gitignore +3 -0
- .streamlit/config.toml +21 -0
- RAG/bedrock_agent.py +146 -0
- RAG/generate_csv_for_tables.py +167 -0
- RAG/rag_DocumentLoader.py +395 -0
- RAG/rag_DocumentSearcher.py +338 -0
- README.md +13 -0
- app.py +125 -0
- figures/ukhousingstats/figure-1-1-resized.jpg +0 -0
- figures/ukhousingstats/figure-1-1.jpg +0 -0
- figures/ukhousingstats/figure-1-2-resized.jpg +0 -0
- figures/ukhousingstats/figure-1-2.jpg +0 -0
- figures/ukhousingstats/figure-2-3-resized.jpg +0 -0
- figures/ukhousingstats/figure-2-3.jpg +0 -0
- figures/ukhousingstats/figure-3-4-resized.jpg +0 -0
- figures/ukhousingstats/figure-3-4.jpg +0 -0
- figures/ukhousingstats/figure-3-5-resized.jpg +0 -0
- figures/ukhousingstats/figure-3-5.jpg +0 -0
- figures/ukhousingstats/figure-4-6-resized.jpg +0 -0
- figures/ukhousingstats/figure-4-6.jpg +0 -0
- figures/ukhousingstats/figure-4-7-resized.jpg +0 -0
- figures/ukhousingstats/figure-4-7.jpg +0 -0
- figures/ukhousingstats/figure-5-8-resized.jpg +0 -0
- figures/ukhousingstats/figure-5-8.jpg +0 -0
- figures/ukhousingstats/figure-6-10-resized.jpg +0 -0
- figures/ukhousingstats/figure-6-10.jpg +0 -0
- figures/ukhousingstats/figure-6-11-resized.jpg +0 -0
- figures/ukhousingstats/figure-6-11.jpg +0 -0
- figures/ukhousingstats/figure-6-12-resized.jpg +0 -0
- figures/ukhousingstats/figure-6-12.jpg +0 -0
- figures/ukhousingstats/figure-6-13-resized.jpg +0 -0
- figures/ukhousingstats/figure-6-13.jpg +0 -0
- figures/ukhousingstats/figure-6-14-resized.jpg +0 -0
- figures/ukhousingstats/figure-6-14.jpg +0 -0
- figures/ukhousingstats/figure-6-15-resized.jpg +0 -0
- figures/ukhousingstats/figure-6-15.jpg +0 -0
- figures/ukhousingstats/figure-6-16-resized.jpg +0 -0
- figures/ukhousingstats/figure-6-16.jpg +0 -0
- figures/ukhousingstats/figure-6-17-resized.jpg +0 -0
- figures/ukhousingstats/figure-6-17.jpg +0 -0
- figures/ukhousingstats/figure-6-18-resized.jpg +0 -0
- figures/ukhousingstats/figure-6-18.jpg +0 -0
- figures/ukhousingstats/figure-6-19-resized.jpg +0 -0
- figures/ukhousingstats/figure-6-19.jpg +0 -0
- figures/ukhousingstats/figure-6-20-resized.jpg +0 -0
- figures/ukhousingstats/figure-6-20.jpg +0 -0
- figures/ukhousingstats/figure-6-21-resized.jpg +0 -0
- figures/ukhousingstats/figure-6-21.jpg +0 -0
- figures/ukhousingstats/figure-6-22-resized.jpg +0 -0
.gitattributes
ADDED
@@ -0,0 +1,35 @@
*.7z filter=lfs diff=lfs merge=lfs -text
*.arrow filter=lfs diff=lfs merge=lfs -text
*.bin filter=lfs diff=lfs merge=lfs -text
*.bz2 filter=lfs diff=lfs merge=lfs -text
*.ckpt filter=lfs diff=lfs merge=lfs -text
*.ftz filter=lfs diff=lfs merge=lfs -text
*.gz filter=lfs diff=lfs merge=lfs -text
*.h5 filter=lfs diff=lfs merge=lfs -text
*.joblib filter=lfs diff=lfs merge=lfs -text
*.lfs.* filter=lfs diff=lfs merge=lfs -text
*.mlmodel filter=lfs diff=lfs merge=lfs -text
*.model filter=lfs diff=lfs merge=lfs -text
*.msgpack filter=lfs diff=lfs merge=lfs -text
*.npy filter=lfs diff=lfs merge=lfs -text
*.npz filter=lfs diff=lfs merge=lfs -text
*.onnx filter=lfs diff=lfs merge=lfs -text
*.ot filter=lfs diff=lfs merge=lfs -text
*.parquet filter=lfs diff=lfs merge=lfs -text
*.pb filter=lfs diff=lfs merge=lfs -text
*.pickle filter=lfs diff=lfs merge=lfs -text
*.pkl filter=lfs diff=lfs merge=lfs -text
*.pt filter=lfs diff=lfs merge=lfs -text
*.pth filter=lfs diff=lfs merge=lfs -text
*.rar filter=lfs diff=lfs merge=lfs -text
*.safetensors filter=lfs diff=lfs merge=lfs -text
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
*.tar.* filter=lfs diff=lfs merge=lfs -text
*.tar filter=lfs diff=lfs merge=lfs -text
*.tflite filter=lfs diff=lfs merge=lfs -text
*.tgz filter=lfs diff=lfs merge=lfs -text
*.wasm filter=lfs diff=lfs merge=lfs -text
*.xz filter=lfs diff=lfs merge=lfs -text
*.zip filter=lfs diff=lfs merge=lfs -text
*.zst filter=lfs diff=lfs merge=lfs -text
*tfevents* filter=lfs diff=lfs merge=lfs -text
.gitignore
ADDED
@@ -0,0 +1,3 @@
**/__pycache__/
*.DS_Store
.streamlit/config.toml
ADDED
@@ -0,0 +1,21 @@

[client]
toolbarMode = "viewer"
showSidebarNavigation = false
showErrorDetails = true

[browser]
gatherUsageStats = false

[theme]
base="dark"
font="sans serif"
primaryColor="#e28743"
backgroundColor ="#000000"

[global]
disableWidgetStateDuplicationWarning = true
showWarningOnDirectExecution = false

[server]
enableXsrfProtection=false
RAG/bedrock_agent.py
ADDED
@@ -0,0 +1,146 @@
import boto3
import json
import time
import zipfile
from io import BytesIO
import uuid
import pprint
import logging
print(boto3.__version__)
from PIL import Image
import os
import base64
import re
import requests
import utilities.re_ranker as re_ranker
import utilities.invoke_models as invoke_models
import streamlit as st
import time as t
import botocore.exceptions

if "inputs_" not in st.session_state:
    st.session_state.inputs_ = {}

parent_dirname = "/".join((os.path.dirname(__file__)).split("/")[0:-1])
region = 'us-east-1'
print(region)
account_id = '445083327804'
# setting logger
logging.basicConfig(format='[%(asctime)s] p%(process)s {%(filename)s:%(lineno)d} %(levelname)s - %(message)s', level=logging.INFO)
logger = logging.getLogger(__name__)
# getting boto3 clients for required AWS services

#bedrock_agent_client = boto3.client('bedrock-agent',region_name=region)
bedrock_agent_runtime_client = boto3.client(
    'bedrock-agent-runtime',
    aws_access_key_id=st.secrets['user_access_key'],
    aws_secret_access_key=st.secrets['user_secret_key'], region_name = 'us-east-1'
)
enable_trace:bool = True
end_session:bool = False

def delete_memory():
    response = bedrock_agent_runtime_client.delete_agent_memory(
        agentAliasId='TSTALIASID',
        agentId='B4Z7BTURC4'
    )

def query_(inputs):
    ## create a random id for session initiator id

    # invoke the agent API
    agentResponse = bedrock_agent_runtime_client.invoke_agent(
        inputText=inputs['shopping_query'],
        agentId='B4Z7BTURC4',
        agentAliasId='TSTALIASID',
        sessionId=st.session_state.session_id_,
        enableTrace=enable_trace,
        endSession= end_session
    )

    logger.info(pprint.pprint(agentResponse))
    print("***agent*****response*********")
    print(agentResponse)
    event_stream = agentResponse['completion']
    total_context = []
    last_tool = ""
    last_tool_name = ""
    agent_answer = ""
    try:
        for event in event_stream:
            print("***event*********")
            print(event)
            # if 'chunk' in event:
            #     data = event['chunk']['bytes']
            #     print("***chunk*********")
            #     print(data)
            #     logger.info(f"Final answer ->\n{data.decode('utf8')}")
            #     agent_answer_ = data.decode('utf8')
            #     print(agent_answer_)
            if 'trace' in event:
                print("trace*****total*********")
                print(event['trace'])
                if('orchestrationTrace' not in event['trace']['trace']):
                    continue
                orchestration_trace = event['trace']['trace']['orchestrationTrace']
                total_context_item = {}
                if('modelInvocationOutput' in orchestration_trace and '<tool_name>' in orchestration_trace['modelInvocationOutput']['rawResponse']['content']):
                    total_context_item['tool'] = orchestration_trace['modelInvocationOutput']['rawResponse']
                if('rationale' in orchestration_trace):
                    total_context_item['rationale'] = orchestration_trace['rationale']['text']
                if('invocationInput' in orchestration_trace):
                    total_context_item['invocationInput'] = orchestration_trace['invocationInput']['actionGroupInvocationInput']
                    last_tool_name = total_context_item['invocationInput']['function']
                if('observation' in orchestration_trace):
                    print("trace****observation******")
                    total_context_item['observation'] = event['trace']['trace']['orchestrationTrace']['observation']
                    tool_output_last_obs = event['trace']['trace']['orchestrationTrace']['observation']
                    print(tool_output_last_obs)
                    if(tool_output_last_obs['type'] == 'ACTION_GROUP'):
                        last_tool = tool_output_last_obs['actionGroupInvocationOutput']['text']
                    if(tool_output_last_obs['type'] == 'FINISH'):
                        agent_answer = tool_output_last_obs['finalResponse']['text']
                if('modelInvocationOutput' in orchestration_trace and '<thinking>' in orchestration_trace['modelInvocationOutput']['rawResponse']['content']):
                    total_context_item['thinking'] = orchestration_trace['modelInvocationOutput']['rawResponse']
                if(total_context_item!={}):
                    total_context.append(total_context_item)
                print("total_context------")
                print(total_context)
    except botocore.exceptions.EventStreamError as error:
        raise error
        # t.sleep(2)
        # query_(st.session_state.inputs_)

        # if 'chunk' in event:
        #     data = event['chunk']['bytes']
        #     final_ans = data.decode('utf8')
        #     print(f"Final answer ->\n{final_ans}")
        #     logger.info(f"Final answer ->\n{final_ans}")
        #     agent_answer = final_ans
        #     end_event_received = True
        #     # End event indicates that the request finished successfully
        # elif 'trace' in event:
        #     logger.info(json.dumps(event['trace'], indent=2))
        # else:
        #     raise Exception("unexpected event.", event)
        # except Exception as e:
        #     raise Exception("unexpected event.", e)
    return {'text':agent_answer,'source':total_context,'last_tool':{'name':last_tool_name,'response':last_tool}}

    ####### Re-Rank ########

    #print("re-rank")

    # if(st.session_state.input_is_rerank == True and len(total_context)):
    #     ques = [{"question":question}]
    #     ans = [{"answer":total_context}]

    #     total_context = re_ranker.re_rank('rag','Cross Encoder',"",ques, ans)

    #     llm_prompt = prompt_template.format(context=total_context[0],question=question)
    #     output = invoke_models.invoke_llm_model( "\n\nHuman: {input}\n\nAssistant:".format(input=llm_prompt) ,False)
    #     #print(output)
    #     if(len(images_2)==0):
    #         images_2 = images
    #     return {'text':output,'source':total_context,'image':images_2,'table':df}
RAG/generate_csv_for_tables.py
ADDED
@@ -0,0 +1,167 @@
import os
import json
import boto3
import io
from io import BytesIO
import sys
from pprint import pprint
from PyPDF2 import PdfWriter, PdfReader
import re
import shutil
import streamlit as st

file_content = {}
parent_dirname = "/".join((os.path.dirname(__file__)).split("/")[0:-1])
# if os.path.isdir(parent_dirname+"/split_pdf"):
#     shutil.rmtree(parent_dirname+"/split_pdf")
# os.mkdir(parent_dirname+"/split_pdf")

# if os.path.isdir(parent_dirname+"/split_pdf_csv"):
#     shutil.rmtree(parent_dirname+"/split_pdf_csv")
# os.mkdir(parent_dirname+"/split_pdf_csv")


def get_rows_columns_map(table_result, blocks_map):
    rows = {}
    #scores = []
    for relationship in table_result['Relationships']:
        if relationship['Type'] == 'CHILD':
            for child_id in relationship['Ids']:
                cell = blocks_map[child_id]
                if cell['BlockType'] == 'CELL':
                    row_index = cell['RowIndex']
                    col_index = cell['ColumnIndex']
                    if row_index not in rows:
                        # create new row
                        rows[row_index] = {}

                    # get confidence score
                    #scores.append(str(cell['Confidence']))

                    # get the text value
                    rows[row_index][col_index] = get_text(cell, blocks_map)
    return rows#, scores


def get_text(result, blocks_map):
    text = ''
    if 'Relationships' in result:
        for relationship in result['Relationships']:
            if relationship['Type'] == 'CHILD':
                for child_id in relationship['Ids']:
                    word = blocks_map[child_id]
                    if word['BlockType'] == 'WORD':
                        if "," in word['Text'] and word['Text'].replace(",", "").isnumeric():
                            text += '"' + word['Text'] + '"' +' '
                        else:
                            text += word['Text'] +' '
                    if word['BlockType'] == 'SELECTION_ELEMENT':
                        if word['SelectionStatus'] =='SELECTED':
                            text += 'X '
    return text


def split_pages(file_name):

    inputpdf = PdfReader(open(file_name, "rb"))
    file_name_short = re.sub('[^A-Za-z0-9]+', '', (file_name.split("/")[-1].split(".")[0]).lower())

    for i in range(len(inputpdf.pages)):

        output = PdfWriter()
        output.add_page(inputpdf.pages[i])
        split_file = parent_dirname+"/split_pdf/"+file_name_short+"%s.pdf" % i

        with open(split_file, "wb") as outputStream:
            output.write(outputStream)
        table_csv = get_table_csv_results(split_file)
        if(table_csv != "<b> NO Table FOUND </b>"):

            output_file = parent_dirname+"/split_pdf_csv/"+file_name_short+"%s.csv" % i
            file_content[output_file] = table_csv

            # replace content
            with open(output_file, "wt") as fout:
                fout.write(table_csv)

            # show the results
            print('CSV OUTPUT FILE: ', output_file)
    return file_content

def get_table_csv_results(file_name):

    with open(file_name, 'rb') as file:
        img_test = file.read()
        bytes_test = bytearray(img_test)
        #print('Image loaded', file_name)

    # process using image bytes
    # get the results
    #session = boto3.Session(profile_name='profile-name')
    client = boto3.client('textract',aws_access_key_id=st.secrets['user_access_key'],
        aws_secret_access_key=st.secrets['user_secret_key'], region_name = 'us-east-1')
    # {'S3Object': {
    #     'Bucket': 'ml-search-app-access',
    #     'Name': 'covid19_ie_removed.pdf'
    # }}

    response = client.analyze_document(Document={'Bytes': bytes_test}, FeatureTypes=['TABLES'])

    # Get the text blocks
    blocks=response['Blocks']
    #pprint(blocks)

    blocks_map = {}
    table_blocks = []
    for block in blocks:
        blocks_map[block['Id']] = block
        if block['BlockType'] == "TABLE":
            table_blocks.append(block)

    if len(table_blocks) <= 0:
        return "<b> NO Table FOUND </b>"

    csv = ''
    for index, table in enumerate(table_blocks):
        csv += generate_table_csv(table, blocks_map, index +1)
        csv += '\n\n'

    return csv

def generate_table_csv(table_result, blocks_map, table_index):
    rows = get_rows_columns_map(table_result, blocks_map)

    table_id = 'Table_' + str(table_index)

    # get cells.
    csv = ''#Table: {0}\n\n'.format(table_id)
    for row_index, cols in rows.items():
        for col_index, text in cols.items():
            col_indices = len(cols.items())
            csv += text.strip()+"`" #'{}'.format(text) + ","
        csv += '\n'

    # csv += '\n\n Confidence Scores % (Table Cell) \n'
    # cols_count = 0
    # for score in scores:
    #     cols_count += 1
    #     csv += score + ","
    #     if cols_count == col_indices:
    #         csv += '\n'
    #         cols_count = 0

    csv += '\n\n\n'
    return csv

def main_(file_name):
    table_csv = split_pages(file_name)
    #print(table_csv)
    return table_csv


# if __name__ == "__main__":
#     file_name = "/home/ubuntu/covid19_ie_removed.pdf"
#     main(file_name)
RAG/rag_DocumentLoader.py
ADDED
@@ -0,0 +1,395 @@
import boto3
import json
import os
import shutil
import time
from unstructured.partition.pdf import partition_pdf
from opensearchpy import OpenSearch, RequestsHttpConnection, AWSV4SignerAuth
import streamlit as st
from PIL import Image
import base64
import re
#import torch
import base64
import requests
from requests_aws4auth import AWS4Auth
import re_ranker
import utilities.invoke_models as invoke_models
from requests.auth import HTTPBasicAuth

import generate_csv_for_tables
from pdf2image import convert_from_bytes,convert_from_path
#import langchain

bedrock_runtime_client = boto3.client('bedrock-runtime',region_name='us-east-1')
textract_client = boto3.client('textract',region_name='us-east-1')

region = 'us-east-1'
service = 'es'

credentials = boto3.Session().get_credentials()
auth = HTTPBasicAuth('prasadnu',st.secrets['rag_shopping_assistant_os_api_access'])

ospy_client = OpenSearch(
    hosts = [{'host': 'search-opensearchservi-75ucark0bqob-bzk6r6h2t33dlnpgx2pdeg22gi.us-east-1.es.amazonaws.com', 'port': 443}],
    http_auth = auth,
    use_ssl = True,
    verify_certs = True,
    connection_class = RequestsHttpConnection,
    pool_maxsize = 20
)


summary_prompt = """You are an assistant tasked with summarizing tables and text. \
Give a detailed summary of the table or text. Table or text chunk: {element} """

parent_dirname = "/".join((os.path.dirname(__file__)).split("/")[0:-1])


def generate_image_captions_(image_paths):
    images = []
    for image_path in image_paths:
        i_image = Image.open(image_path)
        if i_image.mode != "RGB":
            i_image = i_image.convert(mode="RGB")

        images.append(i_image)

    pixel_values = feature_extractor(images=images, return_tensors="pt").pixel_values
    pixel_values = pixel_values.to(device)

    output_ids = model.generate(pixel_values, **gen_kwargs)

    preds = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
    preds = [pred.strip() for pred in preds]
    return preds


def load_docs(inp):

    print("input_doc")
    print(inp)
    extracted_elements_list = []

    data_dir = parent_dirname+"/pdfs"
    target_files = [os.path.join(data_dir,inp["key"])]

    Image.MAX_IMAGE_PIXELS = 100000000
    width = 2048
    height = 2048

    for target_file in target_files:
        tables_textract = generate_csv_for_tables.main_(target_file)
        #tables_textract = {}
        index_ = re.sub('[^A-Za-z0-9]+', '', (target_file.split("/")[-1].split(".")[0]).lower())
        st.session_state.input_index = index_

        if os.path.isdir(parent_dirname+'/figures/') == False:
            os.mkdir(parent_dirname+'/figures/')

        image_output_dir = parent_dirname+'/figures/'+st.session_state.input_index+"/"

        if os.path.isdir(image_output_dir):
            shutil.rmtree(image_output_dir)

        os.mkdir(image_output_dir)

        print("***")
        print(target_file)
        #image_output_dir_path = os.path.join(image_output_dir,target_file.split('/')[-1].split('.')[0])
        #os.mkdir(image_output_dir_path)

        # with open(target_file, "rb") as pdf_file:
        #     encoded_string_pdf = bytearray(pdf_file.read())

        #images_pdf = convert_from_path(target_file)

        # for index,image in enumerate(images_pdf):
        #     image.save(image_output_dir_pdf+"/"+st.session_state.input_index+"/"+str(index)+"_pdf.jpeg", 'JPEG')
        #     with open(image_output_dir_pdf+"/"+st.session_state.input_index+"/"+str(index)+"_pdf.jpeg", "rb") as read_img:
        #         input_encoded = base64.b64encode(read_img.read())
        # print(encoded_string_pdf)
        # tables_= textract_client.analyze_document(
        #     Document={'Bytes': encoded_string_pdf},
        #     FeatureTypes=['TABLES']
        # )

        # print(tables_)

        table_and_text_elements = partition_pdf(
            filename=target_file,
            extract_images_in_pdf=True,
            infer_table_structure=False,
            chunking_strategy="by_title", #Uses title elements to identify sections within the document for chunking
            max_characters=4000,
            new_after_n_chars=3800,
            combine_text_under_n_chars=2000,
            extract_image_block_output_dir=parent_dirname+'/figures/'+st.session_state.input_index+'/',
        )
        tables = []
        texts = []
        print(table_and_text_elements)

        for table in tables_textract.keys():
            print(table)
            #print(tables_textract[table])
            tables.append({'table_name':table,'raw':tables_textract[table],'summary':invoke_models.invoke_llm_model(summary_prompt.format(element=tables_textract[table]),False)})
            time.sleep(4)

        for element in table_and_text_elements:
            # if "unstructured.documents.elements.Table" in str(type(element)):
            #     tables.append({'raw':str(element),'summary':invoke_models.invoke_llm_model(summary_prompt.format(element=str(element)),False)})
            #     tables_source.append({'raw':element,'summary':invoke_models.invoke_llm_model(summary_prompt.format(element=str(element)),False)})

            if "unstructured.documents.elements.CompositeElement" in str(type(element)):
                texts.append(str(element))
        image_captions = {}

        for image_file in os.listdir(image_output_dir):
            print("image_processing")

            photo_full_path = image_output_dir+image_file
            photo_full_path_no_format = photo_full_path.replace('.jpg',"")

            with Image.open(photo_full_path) as image:
                image.verify()

            with Image.open(photo_full_path) as image:

                file_type = 'jpg'
                path = image.filename.rsplit(".", 1)[0]
                image.thumbnail((width, height))
                image.save(photo_full_path_no_format+"-resized.jpg")

            with open(photo_full_path_no_format+"-resized.jpg", "rb") as read_img:
                input_encoded = base64.b64encode(read_img.read()).decode("utf8")

            image_captions[image_file] = {"caption":invoke_models.generate_image_captions_llm(input_encoded, "What's in this image?"),
                                          "encoding":input_encoded
                                          }
        print("image_processing done")
        #print(image_captions)

        #print(os.path.join('figures',image_file))
    extracted_elements_list = []
    extracted_elements_list.append({
        'source': target_file,
        'tables': tables,
        'texts': texts,
        'images': image_captions
    })
    documents = []
    documents_mm = []
    for extracted_element in extracted_elements_list:
        print("prepping data")
        texts = extracted_element['texts']
        tables = extracted_element['tables']
        images_data = extracted_element['images']
        src_doc = extracted_element['source']
        for text in texts:
            embedding = invoke_models.invoke_model(text)
            document = prep_document(text,text,'text',src_doc,'none',embedding)
            documents.append(document)
        for table in tables:
            table_raw = table['raw']

            table_summary = table['summary']
            embedding = invoke_models.invoke_model(table_summary)

            document = prep_document(table_raw,table_summary,'table*'+table['table_name'],src_doc,'none',embedding)
            documents.append(document)
        for file_name in images_data.keys():
            embedding = invoke_models.invoke_model_mm(image_captions[file_name]['caption'],image_captions[file_name]['encoding'])
            document = prep_document(image_captions[file_name]['caption'],image_captions[file_name]['caption'],'image_'+file_name,src_doc,image_captions[file_name]['encoding'],embedding)
            documents_mm.append(document)

            embedding = invoke_models.invoke_model(image_captions[file_name]['caption'])
            document = prep_document(image_captions[file_name]['caption'],image_captions[file_name]['caption'],'image_'+file_name,src_doc,'none',embedding)
            documents.append(document)

    os_ingest(index_, documents)
    os_ingest_mm(index_, documents_mm)

def prep_document(raw_element,processed_element,doc_type,src_doc,encoding,embedding):
    if('image' in doc_type):
        img_ = doc_type.split("_")[1]
    else:
        img_ = "None"
    document = {
        "processed_element": re.sub(r"[^a-zA-Z0-9]+", ' ', processed_element) ,
        "raw_element_type": doc_type.split("*")[0],
        "raw_element": re.sub(r"[^a-zA-Z0-9]+", ' ', raw_element) ,
        "src_doc": src_doc.replace(","," "),
        "image": img_,

    }

    if(encoding!="none"):
        document["image_encoding"] = encoding
        document["processed_element_embedding_bedrock-multimodal"] = embedding
    else:
        document["processed_element_embedding"] = embedding

    if('table' in doc_type):
        document["table"] = doc_type.split("*")[1]

    return document


def os_ingest(index_,documents):
    print("ingesting data")
    #host = 'your collection id.region.aoss.amazonaws.com'
    if(ospy_client.indices.exists(index=index_)):
        ospy_client.indices.delete(index = index_)
    index_body = {
        "settings": {
            "index": {
                "knn": True,
                "default_pipeline": "rag-ingest-pipeline",
                "number_of_shards": 4
            }
        },
        "mappings": {
            "properties": {
                "processed_element": {
                    "type": "text"
                },
                "raw_element": {
                    "type": "text"
                },
                "processed_element_embedding": {
                    "type": "knn_vector",
                    "dimension":1536,
                    "method": {
                        "engine": "faiss",
                        "space_type": "l2",
                        "name": "hnsw",
                        "parameters": {}
                    }
                },
                # "processed_element_embedding_bedrock-multimodal": {
                #     "type": "knn_vector",
                #     "dimension": 1024,
                #     "method": {
                #         "engine": "faiss",
                #         "space_type": "l2",
                #         "name": "hnsw",
                #         "parameters": {}
                #     }
                # },
                # "image_encoding": {
                #     "type": "binary"
                # },
                "raw_element_type": {
                    "type": "text"
                },
                "processed_element_embedding_sparse": {
                    "type": "rank_features"
                },
                "src_doc": {
                    "type": "text"
                },
                "image":{ "type": "text"}

            }
        }
    }
    response = ospy_client.indices.create(index_, body=index_body)

    for doc in documents:
        print("----------doc------------")
        if(doc['image']!='None'):
            print("image insert")
            print(doc['image'])

        response = ospy_client.index(
            index = index_,
            body = doc,
        )


def os_ingest_mm(index_,documents_mm):
    #host = 'your collection id.region.aoss.amazonaws.com'
    index_ = index_+"_mm"
    if(ospy_client.indices.exists(index=index_)):
        ospy_client.indices.delete(index = index_)
    index_body = {
        "settings": {
            "index": {
                "knn": True,
                # "default_pipeline": "rag-ingest-pipeline",
                "number_of_shards": 4
            }
        },
        "mappings": {
            "properties": {
                "processed_element": {
                    "type": "text"
                },
                "raw_element": {
                    "type": "text"
                },

                "processed_element_embedding_bedrock-multimodal": {
                    "type": "knn_vector",
                    "dimension": 1024,
                    "method": {
                        "engine": "faiss",
                        "space_type": "l2",
                        "name": "hnsw",
                        "parameters": {}
                    }
                },
                "image_encoding": {
                    "type": "binary"
                },
                "raw_element_type": {
                    "type": "text"
                },

                "src_doc": {
                    "type": "text"
                },
                "image":{ "type": "text"}

            }
        }
    }
    response = ospy_client.indices.create(index_, body=index_body)

    for doc in documents_mm:
        #print("----------doc------------")
        #print(doc)

        response = ospy_client.index(
            index = index_,
            body = doc,
        )
RAG/rag_DocumentSearcher.py
ADDED
@@ -0,0 +1,338 @@
import boto3
import json
import os
import shutil
from unstructured.partition.pdf import partition_pdf
from opensearchpy import OpenSearch, RequestsHttpConnection, AWSV4SignerAuth
import streamlit as st
from PIL import Image
import base64
import re
#from transformers import VisionEncoderDecoderModel, ViTImageProcessor, AutoTokenizer
import torch
import base64
import requests
import utilities.re_ranker as re_ranker
import utilities.invoke_models as invoke_models
#import langchain
headers = {"Content-Type": "application/json"}
host = "https://search-opensearchservi-75ucark0bqob-bzk6r6h2t33dlnpgx2pdeg22gi.us-east-1.es.amazonaws.com/"

parent_dirname = "/".join((os.path.dirname(__file__)).split("/")[0:-1])

def query_(awsauth,inputs, session_id,search_types):

    print("using index: "+st.session_state.input_index)

    question = inputs['query']

    k=1
    embedding = invoke_models.invoke_model_mm(question,"none")

    query_mm = {
        "size": k,
        "_source": {
            "exclude": [
                "processed_element_embedding_bedrock-multimodal","processed_element_embedding_sparse","image_encoding","processed_element_embedding"
            ]
        },
        "query": {
            "knn": {
                "processed_element_embedding_bedrock-multimodal": {
                    "vector": embedding,
                    "k": k}
            }
        }
    }

    path = st.session_state.input_index+"_mm/_search"
    url = host+path
    r = requests.get(url, auth=awsauth, json=query_mm, headers=headers)
    response_mm = json.loads(r.text)
    # response_mm = ospy_client.search(
    #     body = query_mm,
    #     index = st.session_state.input_index+"_mm"
    # )

    hits = response_mm['hits']['hits']
    context = []
    context_tables = []
    images = []

    for hit in hits:
        #context.append(hit['_source']['caption'])
        images.append({'file':hit['_source']['image'],'caption':hit['_source']['processed_element']})

    ####### SEARCH ########

    path = "_search/pipeline/rag-search-pipeline"
    url = host + path

    num_queries = len(search_types)

    weights = []

    searches = ['Keyword','Vector','NeuralSparse']
    equal_weight = (int(100/num_queries) )/100
    if(num_queries>1):
        for index,search in enumerate(search_types):

            if(index != (num_queries-1)):
                weight = equal_weight
            else:
                weight = 1-sum(weights)

            weights.append(weight)

    #print(weights)

    s_pipeline_payload = {
        "description": "Post processor for hybrid search",
        "phase_results_processors": [
            {
                "normalization-processor": {
                    "normalization": {
                        "technique": "min_max"
                    },
                    "combination": {
                        "technique": "arithmetic_mean",
                        "parameters": {
                            "weights": weights
                        }
                    }
                }
            }
        ]
    }

    r = requests.put(url, auth=awsauth, json=s_pipeline_payload, headers=headers)
    #print(r.status_code)
    #print(r.text)

    SIZE = 5

    hybrid_payload = {
        "_source": {
            "exclude": [
                "processed_element_embedding","processed_element_embedding_sparse"
            ]
        },
        "query": {
            "hybrid": {
                "queries": [

                    #1. keyword query
                    #2. vector search query
                    #3. Sparse query

                ]
            }
        },"size":SIZE,
    }

    if('Keyword Search' in search_types):

        keyword_payload = {
            "match": {
                "processed_element": {
                    "query": question
                }
            }
        }

        hybrid_payload["query"]["hybrid"]["queries"].append(keyword_payload)

    if('Vector Search' in search_types):

        embedding = embedding = invoke_models.invoke_model(question)

        vector_payload = {
            "knn": {
                "processed_element_embedding": {
                    "vector": embedding,
                    "k": 2}
            }
        }

        hybrid_payload["query"]["hybrid"]["queries"].append(vector_payload)

    if('Sparse Search' in search_types):

        #print("text expansion is enabled")
        sparse_payload = { "neural_sparse": {
            "processed_element_embedding_sparse": {
                "query_text": question,
                "model_id": "srrJ-owBQhe1aB-khx2n"
            }
        }}

        hybrid_payload["query"]["hybrid"]["queries"].append(sparse_payload)

        # path2 = "_plugins/_ml/models/srrJ-owBQhe1aB-khx2n/_predict"
        # url2 = host+path2
        # payload2 = {
        #     "parameters": {
        #         "inputs": question
        #     }
        # }
        # r2 = requests.post(url2, auth=awsauth, json=payload2, headers=headers)
        # sparse_ = json.loads(r2.text)
        # query_sparse = sparse_["inference_results"][0]["output"][0]["dataAsMap"]["response"][0]

    # print("hybrid_payload")
    # print("---------------")
    #print(hybrid_payload)
    hits = []
    if(num_queries>1):
        path = st.session_state.input_index+"/_search?search_pipeline=rag-search-pipeline"
    else:
        path = st.session_state.input_index+"/_search"
    url = host+path
    if(len(hybrid_payload["query"]["hybrid"]["queries"])==1):
        single_query = hybrid_payload["query"]["hybrid"]["queries"][0]
        del hybrid_payload["query"]["hybrid"]
        hybrid_payload["query"] = single_query
        r = requests.get(url, auth=awsauth, json=hybrid_payload, headers=headers)
        #print(r.status_code)
        response_ = json.loads(r.text)
        #print("-------------------------------------------------------------------")
        #print(r.text)
        hits = response_['hits']['hits']

    else:
        r = requests.get(url, auth=awsauth, json=hybrid_payload, headers=headers)
        #print(r.status_code)
        response_ = json.loads(r.text)
        #print("-------------------------------------------------------------------")
        #print(response_)
        hits = response_['hits']['hits']

    ##### GET reference tables separately like *_mm index search for images ######
    def lazy_get_table():
        #print("Forcing table analysis")
        table_ref = []
        any_table_exists = False
        for fname in os.listdir(parent_dirname+"/split_pdf_csv"):
            if fname.startswith(st.session_state.input_index):
                any_table_exists = True
                break
        if(any_table_exists):
            #################### Basic Match query #################
            # payload_tables = {
            #     "query": {
            #         "bool":{

            #             "must":{"match": {
            #                 "processed_element": question

            #             }},

            #             "filter":{"term":{"raw_element_type": "table"}}

            #     }}}

            #################### Neural Sparse query #################
            payload_tables = {"query":{"neural_sparse": {
                "processed_element_embedding_sparse": {
                    "query_text": question,
                    "model_id": "srrJ-owBQhe1aB-khx2n"
                }
            } } }

            r_ = requests.get(url, auth=awsauth, json=payload_tables, headers=headers)
            r_tables = json.loads(r_.text)

            for res_ in r_tables['hits']['hits']:
                if(res_["_source"]['raw_element_type'] == 'table'):
                    table_ref.append({'name':res_["_source"]['table'],'text':res_["_source"]['processed_element']})
                if(len(table_ref) == 2):
                    break

        return table_ref

    ########################### LLM Generation ########################
    prompt_template = """
    The following is a friendly conversation between a human and an AI.
    The AI is talkative and provides lots of specific details from its context.
    {context}
    Instruction: Based on the above documents, provide a detailed answer for, {question}. Answer "don't know",
    if not present in the context.
    Solution:"""

    idx = 0
    images_2 = []
    is_table_in_result = False
    df = []
    for hit in hits[0:3]:

        if(hit["_source"]["raw_element_type"] == 'table'):
            #print("Need to analyse table")
            is_table_in_result = True
            table_res = invoke_models.read_from_table(hit["_source"]["table"],question)
            df.append({'name':hit["_source"]["table"],'text':hit["_source"]["processed_element"]})
            context_tables.append(table_res+"\n\n"+hit["_source"]["processed_element"])

        else:
            if(hit["_source"]["image"]!="None"):
                with open(parent_dirname+'/figures/'+st.session_state.input_index+"/"+hit["_source"]["raw_element_type"].split("_")[1].replace(".jpg","")+"-resized.jpg", "rb") as read_img:
                    input_encoded = base64.b64encode(read_img.read()).decode("utf8")
                context.append(invoke_models.generate_image_captions_llm(input_encoded,question))
            else:
                context.append(hit["_source"]["processed_element"])

        if(hit["_source"]["image"]!="None"):
            images_2.append({'file':hit["_source"]["image"],'caption':hit["_source"]["processed_element"]})

        idx = idx +1
        #images.append(hit['_source']['image'])

    # if(is_table_in_result == False):
    #     df = lazy_get_table()
    #     print("forcefully selected top 2 tables")
    #     print(df)

    #     for pos,table in enumerate(df):
    #         table_res = invoke_models.read_from_table(table['name'],question)
    #         context_tables.append(table_res)#+"\n\n"+table['text']

    total_context = context_tables + context

    ####### Re-Rank ########

    #print("re-rank")

    if(st.session_state.input_is_rerank == True and len(total_context)):
        ques = [{"question":question}]
        ans = [{"answer":total_context}]

        total_context = re_ranker.re_rank('rag','Cross Encoder',"",ques, ans)

    llm_prompt = prompt_template.format(context=total_context[0],question=question)
    output = invoke_models.invoke_llm_model( "\n\nHuman: {input}\n\nAssistant:".format(input=llm_prompt) ,False)
    #print(output)
    if(len(images_2)==0):
        images_2 = images
    return {'text':output,'source':total_context,'image':images_2,'table':df}
README.md
ADDED
@@ -0,0 +1,13 @@
---
title: OpenSearch AI
emoji: 🔍
colorFrom: pink
colorTo: purple
sdk: streamlit
sdk_version: 1.41.1
app_file: app.py
pinned: false
license: apache-2.0
---

Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py
ADDED
@@ -0,0 +1,125 @@
import streamlit as st
from PIL import Image
import base64
import yaml
import os
import urllib.request
import tarfile
import subprocess
from yaml.loader import SafeLoader


st.set_page_config(

    #page_title="Semantic Search using OpenSearch",
    layout="wide",
    page_icon="/home/ubuntu/images/opensearch_mark_default.png"
)

st.markdown("""<style>
@import url('https://fonts.cdnfonts.com/css/amazon-ember');
</style>
""",unsafe_allow_html=True)

# with open('/home/ubuntu/AI-search-with-amazon-opensearch-service/OpenSearchApp/auth.yaml') as file:
#     config = yaml.load(file, Loader=SafeLoader)
# authenticator = Authenticate(
#     config['credentials'],
#     config['cookie']['name'],
#     config['cookie']['key'],
#     config['cookie']['expiry_days'],
#     config['preauthorized']
# )
# name, authentication_status, username = authenticator.login('Login', 'main')


AI_ICON = "images/opensearch-twitter-card.png"
col_0_1,col_0_2,col_0_3= st.columns([10,50,85])
with col_0_1:
    st.image(AI_ICON, use_container_width='always')
with col_0_2:
    st.markdown('<p style="fontSize:40px;color:#FF9900;fontFamily:\'Amazon Ember Display 500\', sans-serif;">OpenSearch AI demos</p>',unsafe_allow_html=True)
    #st.header("OpenSearch AI demos")#,divider = 'rainbow'
# with col_0_3:
#     st.markdown("<a style = 'font-size:150%;background-color: #e28743;color: white;padding: 5px 10px;text-align: center;text-decoration: none;margin: 10px 20px;border-radius: 12px;display: inline-block;' href = 'https://catalog.workshops.aws/opensearch-ml-search'>Workshop</a>",unsafe_allow_html=True)


#st.header(":rewind: Demos available")
st.write("")
#st.write("----")
#st.write("Choose a demo")
st.write("")
col_1_1,col_1_2,col_1_3 = st.columns([3,40,65])
with col_1_1:
    st.subheader(" ")
with col_1_2:
    st.markdown('<p style="fontSize:28px;color:#c5c3c0;fontFamily:\'Amazon Ember Cd RC 250\', sans-serif;">Neural Search</p>',unsafe_allow_html=True)
with col_1_3:
    demo_1 = st.button(":arrow_forward:",key = "demo_1")
if(demo_1):
    st.switch_page('pages/Semantic_Search.py')
st.write("")
#st.page_link("pages/1_Semantic_Search.py", label=":orange[1. Semantic Search] :arrow_forward:")
#st.button("1. Semantic Search")
# image_ = Image.open('/home/ubuntu/images/Semantic_SEarch.png')
# new_image = image_.resize((1500, 1000))
# new_image.save('images/semantic_search_resize.png')
# st.image("images/semantic_search_resize.png")
col_2_1,col_2_2,col_2_3 = st.columns([3,40,65])
with col_2_1:
    st.subheader(" ")
with col_2_2:
    st.markdown('<p style="fontSize:28px;color:#c5c3c0;fontFamily:\'Amazon Ember Cd RC 250\', sans-serif;">Multimodal Conversational Search</p>',unsafe_allow_html=True)

with col_2_3:
    demo_2 = st.button(":arrow_forward:",key = "demo_2")
if(demo_2):
    st.switch_page('pages/Multimodal_Conversational_Search.py')
st.write("")
#st.header("2. Multimodal Conversational Search")
# image_ = Image.open('images/RAG_.png')
# new_image = image_.resize((1500, 1000))
# new_image.save('images/RAG_resize.png')
# st.image("images/RAG_resize.png")

col_3_1,col_3_2,col_3_3 = st.columns([3,40,65])
with col_3_1:
    st.subheader(" ")
with col_3_2:
    st.markdown('<div style="fontSize:28px;color:#c5c3c0;fontFamily:\'Amazon Ember Cd RC 250\', sans-serif;">Agentic Shopping Assistant</div>',unsafe_allow_html=True)#<span style="fontSize:14px;color:#099ef3;fontWeight:bold;textDecorationLine:underline;textDecorationStyle: dashed;">New</span>
with col_3_3:
    demo_3 = st.button(":arrow_forward:",key = "demo_3")
if(demo_3):
    st.switch_page('pages/AI_Shopping_Assistant.py')
# with st.sidebar:
#     st.subheader("Choose a demo !")


# """
# <style>

# [data-testid="stHeader"]::after {
#     content: "My Company Name";
#     margin-left: 0px;
#     margin-top: 0px;
#     font-size: 30px;
#     position: relative;
#     left: 90%;
#     top: 30%;
# }
# </style>
# """,

isExist = os.path.exists("/home/user/images_retail")
if not isExist:
    os.makedirs("/home/user/images_retail")
    metadata_file = urllib.request.urlretrieve('https://aws-blogs-artifacts-public.s3.amazonaws.com/BDB-3144/products-data.yml', '/home/user/products.yaml')
    img_filename,headers= urllib.request.urlretrieve('https://aws-blogs-artifacts-public.s3.amazonaws.com/BDB-3144/images.tar.gz', '/home/user/images_retail/images.tar.gz')
    print(img_filename)
    file = tarfile.open('/home/user/images_retail/images.tar.gz')
    file.extractall('/home/user/images_retail/')
    file.close()
    #remove images.tar.gz
    os.remove('/home/user/images_retail/images.tar.gz')
figures/ukhousingstats/figure-1-1-resized.jpg ADDED
figures/ukhousingstats/figure-1-1.jpg ADDED
figures/ukhousingstats/figure-1-2-resized.jpg ADDED
figures/ukhousingstats/figure-1-2.jpg ADDED
figures/ukhousingstats/figure-2-3-resized.jpg ADDED
figures/ukhousingstats/figure-2-3.jpg ADDED
figures/ukhousingstats/figure-3-4-resized.jpg ADDED
figures/ukhousingstats/figure-3-4.jpg ADDED
figures/ukhousingstats/figure-3-5-resized.jpg ADDED
figures/ukhousingstats/figure-3-5.jpg ADDED
figures/ukhousingstats/figure-4-6-resized.jpg ADDED
figures/ukhousingstats/figure-4-6.jpg ADDED
figures/ukhousingstats/figure-4-7-resized.jpg ADDED
figures/ukhousingstats/figure-4-7.jpg ADDED
figures/ukhousingstats/figure-5-8-resized.jpg ADDED
figures/ukhousingstats/figure-5-8.jpg ADDED
figures/ukhousingstats/figure-6-10-resized.jpg ADDED
figures/ukhousingstats/figure-6-10.jpg ADDED
figures/ukhousingstats/figure-6-11-resized.jpg ADDED
figures/ukhousingstats/figure-6-11.jpg ADDED
figures/ukhousingstats/figure-6-12-resized.jpg ADDED
figures/ukhousingstats/figure-6-12.jpg ADDED
figures/ukhousingstats/figure-6-13-resized.jpg ADDED
figures/ukhousingstats/figure-6-13.jpg ADDED
figures/ukhousingstats/figure-6-14-resized.jpg ADDED
figures/ukhousingstats/figure-6-14.jpg ADDED
figures/ukhousingstats/figure-6-15-resized.jpg ADDED
figures/ukhousingstats/figure-6-15.jpg ADDED
figures/ukhousingstats/figure-6-16-resized.jpg ADDED
figures/ukhousingstats/figure-6-16.jpg ADDED
figures/ukhousingstats/figure-6-17-resized.jpg ADDED
figures/ukhousingstats/figure-6-17.jpg ADDED
figures/ukhousingstats/figure-6-18-resized.jpg ADDED
figures/ukhousingstats/figure-6-18.jpg ADDED
figures/ukhousingstats/figure-6-19-resized.jpg ADDED
figures/ukhousingstats/figure-6-19.jpg ADDED
figures/ukhousingstats/figure-6-20-resized.jpg ADDED
figures/ukhousingstats/figure-6-20.jpg ADDED
figures/ukhousingstats/figure-6-21-resized.jpg ADDED
figures/ukhousingstats/figure-6-21.jpg ADDED
figures/ukhousingstats/figure-6-22-resized.jpg ADDED