# Streamlit demo app (Hugging Face Space): LLaVA-1.6-Mistral-7B served on AWS Inferentia2.
import streamlit as st
from utils import convert_to_base64, convert_to_html
import requests
import boto3
import sagemaker
import os
import json
# Deployment configuration is injected via environment variables
# (e.g. Space secrets); each resolves to None when unset.
region = os.getenv("region")
sm_endpoint_name = os.getenv("sm_endpoint_name")  # SageMaker endpoint serving the model
access_key = os.getenv("access_key")
secret_key = os.getenv("secret_key")
hf_token = os.getenv("hf_read_access")  # NOTE(review): read but never used below — confirm it is needed
# Build an AWS session from the injected credentials and derive a
# SageMaker runtime client used by ask_llm() to invoke the endpoint.
session = boto3.Session(
aws_access_key_id=access_key,
aws_secret_access_key=secret_key,
region_name=region
)
sess = sagemaker.Session(boto_session=session)
smr = session.client("sagemaker-runtime")
headers = {'Content-Type': 'application/json'}  # only needed by the (disabled) HTTP request path
# Page chrome: wide layout, title and a short description of the demo.
st.set_page_config(page_title="AWS Inferentia2 Demo", layout="wide")
#st.set_page_config(layout="wide")
st.title("Multimodal Model on AWS Inf2")
st.subheader("LLaVA-1.6-Mistral-7B")
st.text(" LLaVA (or Large Language and Vision Assistant), an open-source large multi-modal model. This demo is running on AWS Inferentia2 built with Llava1.6.")
def upload_image():
    """Let the user pick a preset image or upload one; display it and return it base64-encoded.

    Returns:
        str: base64 encoding of the single chosen image (an upload takes
        precedence over a preset selection).

    Halts the Streamlit script run via ``st.stop()`` when no image has been
    chosen yet, or when more than one file was uploaded.
    """
    image_list = ["./images/view.jpg",
                  "./images/cat.jpg",
                  "./images/olympic.jpg",
                  "./images/usa.jpg",
                  "./images/box.jpg"]
    name_list = ["view(from internet)",
                 "cat(from internet)",
                 "paris 2024(from internet)",
                 "statue of liberty(from internet)",
                 "box(from my camera)"]
    images_all = dict(zip(name_list, image_list))
    user_option = st.selectbox("Select a preset image", ["–Select–"] + name_list)
    if user_option != "–Select–":
        image_names = [images_all[user_option]]
    else:
        image_names = []
    st.text("OR")
    images = st.file_uploader("Upload an image to chat about", type=["png", "jpg", "jpeg"], accept_multiple_files=True)
    # Explicit validation instead of `assert` — asserts are stripped under
    # python -O, and the old tuple-as-message trick relied on side effects.
    if len(images) > 1:
        st.error("Please upload at most 1 image")
        st.stop()
    if images or image_names:
        if images:
            # An uploaded file overrides any preset selection.
            image_names = []
        # convert_to_base64 accepts both uploaded files and preset file paths
        # (presumably — TODO confirm against utils.convert_to_base64).
        images_b64 = [convert_to_base64(image) for image in images + image_names]
        # Only the first image is displayed and processed downstream.
        col = st.columns(len(images_b64))[0]
        col.markdown("**Image 1**")
        col.markdown(convert_to_html(images_b64[0]), unsafe_allow_html=True)
        st.markdown("---")
        return images_b64[0]
    # Nothing chosen yet: end this script run and wait for user input.
    st.stop()
@st.cache_data(show_spinner=False)
def ask_llm(prompt, byte_image):
    """Invoke the SageMaker endpoint with a prompt and a base64 image; return the raw response text.

    Args:
        prompt: user question about the image.
        byte_image: base64-encoded image string.

    Returns:
        str: the endpoint's response body decoded as UTF-8.

    Memoized by ``st.cache_data``: an identical (prompt, image) pair does not
    re-invoke the endpoint within the cache lifetime.
    """
    payload = {
        "prompt": prompt,
        "image": byte_image,
        # Conservative decoding: low top_p/temperature keeps answers focused.
        "parameters": {
            "top_k": 100,
            "top_p": 0.1,
            "temperature": 0.2,
        },
    }
    response_model = smr.invoke_endpoint(
        EndpointName=sm_endpoint_name,
        Body=json.dumps(payload),
        ContentType="application/json",
    )
    return response_model['Body'].read().decode('utf8')
def app():
    """Render the page: image picker in the right column, chat in the left."""
    st.markdown("---")
    left, right = st.columns(2)
    with right:
        # Stops the script run itself until an image has been chosen.
        image_b64 = upload_image()
    with left:
        question = st.chat_input("Ask a question about this image")
    if not question:
        st.stop()
    with left:
        with st.chat_message("question"):
            st.markdown(question, unsafe_allow_html=True)
        with st.spinner("Thinking..."):
            answer = ask_llm(question, image_b64)
        with st.chat_message("response"):
            st.write(answer)
# Script entry point; the stray trailing "|" after app() (extraction
# residue) was removed — it would be a SyntaxError.
if __name__ == "__main__":
    app()