File size: 3,689 Bytes
25a9d18
 
 
859cc25
 
5c84ace
6f53a5f
859cc25
 
 
 
 
 
 
 
 
 
 
 
 
 
 
25a9d18
 
 
 
 
 
 
ee3ffed
25a9d18
 
 
 
 
 
 
 
aabed80
 
 
 
25a9d18
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a48575a
25a9d18
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
da4ed7a
25a9d18
 
 
 
 
 
 
 
 
 
859cc25
ee6f528
859cc25
 
 
 
25a9d18
859cc25
ee6f528
25a9d18
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
import streamlit as st
from utils import convert_to_base64, convert_to_html
import requests
import boto3
import sagemaker
import os
import json

# Runtime configuration is read from environment variables; os.getenv
# returns None for any variable that is not set.
region = os.getenv("region")
sm_endpoint_name = os.getenv("sm_endpoint_name")
access_key = os.getenv("access_key")
secret_key = os.getenv("secret_key")
# NOTE(review): hf_token is not referenced by the code visible in this file.
hf_token = os.getenv("hf_read_access")

# boto3 session built from the explicit credentials and region read above.
session = boto3.Session(
    aws_access_key_id=access_key,
    aws_secret_access_key=secret_key,
    region_name=region
)
# NOTE(review): sess is not referenced by the code visible in this file.
sess = sagemaker.Session(boto_session=session)

# SageMaker runtime client; ask_llm() uses it to invoke the endpoint.
smr = session.client("sagemaker-runtime")

# NOTE(review): headers is only mentioned in a commented-out requests.post
# call in ask_llm(); no live code in this file reads it.
headers = {'Content-Type': 'application/json'}
# Page configuration is applied before any other Streamlit output below.
st.set_page_config(page_title="AWS Inferentia2 Demo", layout="wide")

# Static page header shown above the chat UI.
st.title("Multimodal Model on AWS Inf2")
st.subheader("LLaVA-1.6-Mistral-7B")
st.text(" LLaVA (or Large Language and Vision Assistant), an open-source large multi-modal model. This demo is running on AWS Inferentia2 built with Llava1.6.")


def upload_image():
    """Let the user choose one image and return it base64-encoded.

    Renders a preset-image selector and a file uploader. An uploaded file
    takes precedence over a selected preset. The chosen image is displayed
    and then returned in the form produced by ``convert_to_base64``.

    Returns:
        The result of ``convert_to_base64`` for the chosen image.

    The script run is halted with ``st.stop()`` (so this function does not
    return) when more than one file is uploaded, after showing an error,
    or when no image has been chosen yet.
    """
    placeholder = "–Select–"
    # Display name -> bundled file path; dict order defines the menu order.
    presets = {
        "view(from internet)": "./images/view.jpg",
        "cat(from internet)": "./images/cat.jpg",
        "paris 2024(from internet)": "./images/olympic.jpg",
        "statue of liberty(from internet)": "./images/usa.jpg",
        "box(from my camera)": "./images/box.jpg",
    }

    user_option = st.selectbox(
        "Select a preset image", [placeholder] + list(presets)
    )

    st.text("OR")

    images = st.file_uploader(
        "Upload an image to chat about",
        type=["png", "jpg", "jpeg"],
        accept_multiple_files=True,
    )

    # Explicit check instead of `assert`: assertions are removed when Python
    # runs with -O, which would silently disable this validation.
    if len(images) > 1:
        st.error("Please upload at most 1 image")
        st.stop()

    # An uploaded file wins over the preset selection.
    if images:
        source = images[0]
    elif user_option != placeholder:
        source = presets[user_option]
    else:
        # Nothing chosen yet: end this script run until the user picks one.
        st.stop()

    image_b64 = convert_to_base64(source)

    # Only one image is ever processed, so a single column is enough.
    (col,) = st.columns(1)
    col.markdown("**Image 1**")
    col.markdown(convert_to_html(image_b64), unsafe_allow_html=True)
    st.markdown("---")
    return image_b64


@st.cache_data(show_spinner=False)
def ask_llm(prompt, byte_image):
    """Query the SageMaker endpoint with a prompt and an image.

    Args:
        prompt: Question text sent under the ``"prompt"`` key.
        byte_image: Image payload sent under the ``"image"`` key; ``app()``
            passes the value returned by ``upload_image()``.

    Returns:
        The endpoint's response body decoded as UTF-8 text.

    Results are cached by Streamlit via ``st.cache_data``.
    """
    generation_params = {"top_k": 100, "top_p": 0.1, "temperature": 0.2}
    request_body = json.dumps(
        {
            "prompt": prompt,
            "image": byte_image,
            "parameters": generation_params,
        }
    )

    reply = smr.invoke_endpoint(
        EndpointName=sm_endpoint_name,
        Body=request_body,
        ContentType="application/json",
    )

    raw_bytes = reply["Body"].read()
    return raw_bytes.decode("utf8")

def app():
    """Render the two-column chat page and answer one question per run.

    The right column (``c2``) hosts the image picker from ``upload_image()``;
    the left column (``c1``) hosts the chat input, the echoed question and
    the model's answer from ``ask_llm()``. The script run is halted with
    ``st.stop()`` until the user has submitted a question.
    """
    st.markdown("---")
    c1, c2 = st.columns(2)
    with c2:
        image_b64 = upload_image()
    with c1:
        question = st.chat_input("Ask a question about this image")

    # No question submitted in this run: nothing more to render.
    if not question:
        st.stop()

    with c1:
        with st.chat_message("question"):
            # Leave unsafe_allow_html at its default so HTML typed by the
            # user is not injected into the page.
            st.markdown(question)
        with st.spinner("Thinking..."):
            res = ask_llm(question, image_b64)
            with st.chat_message("response"):
                st.write(res)

# Build the page only when this file is executed as the main script.
if __name__ == "__main__":
    app()