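# Streamlit front end for the AI shopping assistant demo. The page wires a chat UI
# to an Amazon Bedrock agent (via the local bedrock_agent module) and renders the
# agent's tool outputs, such as retrieved products and generated images.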
import streamlit as st
import uuid
import os
import re
import sys
from io import BytesIO
sys.path.insert(1, "/".join(os.path.realpath(__file__).split("/")[0:-2])+"/semantic_search")
sys.path.insert(1, "/".join(os.path.realpath(__file__).split("/")[0:-2])+"/RAG")
sys.path.insert(1, "/".join(os.path.realpath(__file__).split("/")[0:-2])+"/utilities")
import boto3
import requests
from boto3 import Session
import botocore.session
import json
import random
import string
# import rag_DocumentLoader
# import rag_DocumentSearcher
import pandas as pd
from PIL import Image
import shutil
import base64
import time
import botocore
#from langchain.callbacks.base import BaseCallbackHandler
#import streamlit_nested_layout
#from IPython.display import clear_output, display, display_markdown, Markdown
from requests_aws4auth import AWS4Auth
#import copali
from requests.auth import HTTPBasicAuth
import bedrock_agent
import warnings

warnings.filterwarnings("ignore", category=DeprecationWarning)
st.set_page_config(
    layout="wide",
    page_icon="images/opensearch_mark_default.png"
)
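# Local asset paths and the S3 bucket used for PDF uploads. parent_dirname assumes
# the sample's own EC2 deployment layout; adjust it if the repo lives elsewhere.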
parent_dirname = '/home/ubuntu/AI-search-with-amazon-opensearch-service/OpenSearchApp'
USER_ICON = "images/user.png"
AI_ICON = "images/opensearch-twitter-card.png"
REGENERATE_ICON = "images/regenerate.png"
s3_bucket_ = "pdf-repo-uploads"

polly_client = boto3.Session(region_name='us-east-1').client('polly')

# Check if the user ID is already stored in the session state
if 'user_id' in st.session_state:
    user_id = st.session_state['user_id']
# If the user ID is not yet stored in the session state, generate a random UUID
else:
    user_id = str(uuid.uuid4())
    st.session_state['user_id'] = user_id

if 'session_id_' not in st.session_state:
    st.session_state['session_id_'] = str(uuid.uuid1())
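# Seed default values for the conversation history and the agent's input flags
# (keys prefixed with "input_" are collected and forwarded to the agent in handle_input).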
| if "chats" not in st.session_state: | |
| st.session_state.chats = [ | |
| { | |
| 'id': 0, | |
| 'question': '', | |
| 'answer': '' | |
| } | |
| ] | |
| if "questions__" not in st.session_state: | |
| st.session_state.questions__ = [] | |
| if "answers__" not in st.session_state: | |
| st.session_state.answers__ = [] | |
| if "input_is_rerank" not in st.session_state: | |
| st.session_state.input_is_rerank = True | |
| if "input_copali_rerank" not in st.session_state: | |
| st.session_state.input_copali_rerank = False | |
| if "input_table_with_sql" not in st.session_state: | |
| st.session_state.input_table_with_sql = False | |
| if "inputs_" not in st.session_state: | |
| st.session_state.inputs_ = {} | |
| if "input_shopping_query" not in st.session_state: | |
| st.session_state.input_shopping_query="get me shoes suitable for trekking" | |
| if "input_rag_searchType" not in st.session_state: | |
| st.session_state.input_rag_searchType = ["Sparse Search"] | |
region = 'us-east-1'
output = []
service = 'es'

st.markdown("""
    <style>
    [data-testid=column]:nth-of-type(2) [data-testid=stVerticalBlock]{
        gap: 0rem;
    }
    [data-testid=column]:nth-of-type(1) [data-testid=stVerticalBlock]{
        gap: 0rem;
    }
    </style>
    """, unsafe_allow_html=True)
def write_logo():
    col1, col2, col3 = st.columns([5, 1, 5])
    with col2:
        st.image(AI_ICON, use_column_width='always')

def write_top_bar():
    col1, col2 = st.columns([77, 23])
    with col1:
        st.page_link("app.py", label=":orange[Home]", icon="🏠")
        st.header("AI Shopping assistant", divider='rainbow')
    with col2:
        st.write("")
        st.write("")
        clear = st.button("Clear")
        st.write("")
        st.write("")
    return clear
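# Pressing "Clear" resets the conversation history, starts a new session ID,
# and asks the agent module to delete its stored memory.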
clear = write_top_bar()

if clear:
    st.session_state.questions__ = []
    st.session_state.answers__ = []
    st.session_state.input_shopping_query = ""
    st.session_state.session_id_ = str(uuid.uuid1())
    bedrock_agent.delete_memory()
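# Collect every "input_*" session key, send the inputs to the Bedrock agent, and
# append the question/answer pair to the chat history. The agent response is
# assumed (from how it is consumed here and in render_answer) to be a dict like:
#   {'text': '...', 'source': {...}, 'last_tool': {'name': '...', 'response': "..."}}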
def handle_input():
    if st.session_state.input_shopping_query == '':
        return ""
    inputs = {}
    for key in st.session_state:
        if key.startswith('input_'):
            inputs[key.removeprefix('input_')] = st.session_state[key]
    st.session_state.inputs_ = inputs
    question_with_id = {
        'question': inputs["shopping_query"],
        'id': len(st.session_state.questions__)
    }
    st.session_state.questions__.append(question_with_id)
    print(inputs)
    out_ = bedrock_agent.query_(inputs)
    st.session_state.answers__.append({
        'answer': out_['text'],
        'source': out_['source'],
        'last_tool': out_['last_tool'],
        'id': len(st.session_state.questions__)
    })
    st.session_state.input_shopping_query = ""
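# Chat rendering helpers: the user's question and the agent's answer are drawn
# as separate rows, each with an avatar column and a content column.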
def write_user_message(md):
    col1, col2 = st.columns([3, 97])
    with col1:
        st.image(USER_ICON, use_column_width='always')
    with col2:
        st.markdown(
            "<div style='color:#e28743;font-size:18px;padding:3px 7px 3px 7px;border-width:0px;border-color:red;border-style:solid;width:fit-content;height:fit-content;border-radius:10px;font-style:italic;'>"
            + md['question'] + "</div>",
            unsafe_allow_html=True)
def render_answer(question, answer, index):
    col1, col2, col_3 = st.columns([4, 74, 22])
    with col1:
        st.image(AI_ICON, use_column_width='always')
    with col2:
        use_interim_results = False
        src_dict = {}
        ans_ = answer['answer']
        span_ans = ans_.replace('<question>', "<span style='font-size:18px;color:#f37709;font-style:italic;'>").replace("</question>", "</span>")
        st.markdown("<p>" + span_ans + "</p>", unsafe_allow_html=True)
        if answer['last_tool']['name'] in ["generate_images", "get_relevant_items_for_image", "get_relevant_items_for_text", "retrieve_with_hybrid_search", "retrieve_with_keyword_search", "get_any_general_recommendation"]:
            use_interim_results = True
            src_dict = json.loads(answer['last_tool']['response'].replace("'", '"'))
        if use_interim_results and answer['last_tool']['name'] != 'generate_images' and answer['last_tool']['name'] != 'get_any_general_recommendation':
            key_ = answer['last_tool']['name']
            st.write("<br><br>", unsafe_allow_html=True)
            img_col1, img_col2, img_col3 = st.columns([30, 30, 40])
            for idx, item in enumerate(src_dict[key_]):
                response_ = requests.get(item['image'])
                img = Image.open(BytesIO(response_.content))
                resizedImg = img.resize((230, 180), Image.Resampling.LANCZOS)
                if idx == 0:
                    with img_col1:
                        st.image(resizedImg, use_column_width=True, caption=item['title'])
                if idx == 1:
                    with img_col2:
                        st.image(resizedImg, use_column_width=True, caption=item['title'])
        if answer['last_tool']['name'] == "generate_images" or answer['last_tool']['name'] == "get_any_general_recommendation":
            st.write("<br>", unsafe_allow_html=True)
            gen_img_col1, gen_img_col2, gen_img_col3 = st.columns([30, 30, 30])
            res = src_dict['generate_images'].replace('s3://', '')
            s3_ = boto3.resource('s3',
                                 aws_access_key_id=st.secrets['user_access_key'],
                                 aws_secret_access_key=st.secrets['user_secret_key'],
                                 region_name='us-east-1')
            key = res.split('/')[1]
            s3_stream = s3_.Object("bedrock-video-generation-us-east-1-lbxkrh", key).get()['Body'].read()
            img_ = Image.open(BytesIO(s3_stream))
            resizedImg = img_.resize((230, 180), Image.Resampling.LANCZOS)
            with gen_img_col1:
                st.image(resizedImg, caption="Generated image for " + key.split(".")[0], use_column_width=True)
            st.write("<br>", unsafe_allow_html=True)
    colu1, colu2, colu3 = st.columns([4, 82, 20])
    if answer['source'] != {}:
        with colu2:
            with st.expander("Agent Traces:"):
                st.write(answer['source'])
# Each answer carries the context of the question asked, so that any feedback can be
# associated with the respective question.
def write_chat_message(md, q, index):
    chat = st.container()
    with chat:
        render_answer(q, md, index)
def render_all():
    index = 0
    for (q, a) in zip(st.session_state.questions__, st.session_state.answers__):
        index = index + 1
        write_user_message(q)
        write_chat_message(a, q, index)
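# Main page body: replay the conversation so far, then show the query box and
# the "Go" button that triggers handle_input.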
placeholder = st.empty()
with placeholder.container():
    render_all()
st.markdown("")

col_2, col_3 = st.columns([75, 20])
with col_2:
    input = st.text_input("Ask here", label_visibility="collapsed", key="input_shopping_query")
with col_3:
    play = st.button("Go", on_click=handle_input, key="play")