# Star-wars-app / app.py
# Author: victorbahlangene
# Commit 9adb505: "Update app.py"
from pathlib import Path
from fastai.vision.all import *
from bs4 import BeautifulSoup
import requests
import re
import nltk
from pprint import pprint
import random
import streamlit as st
import pathlib
# Imported explicitly: this name was previously only available (if at all)
# through the fastai star import.
import platform

# NOTE(review): a learner file written by fastai's export presumably pickles
# pathlib.Path objects of the exporting OS's concrete class — confirm for the
# model file in use. Alias the foreign concrete class to the local one so such
# a pickle can be loaded on either OS. (The previous version assigned
# WindowsPath = PosixPath on Windows, which breaks Path() creation there.)
if platform.system() == "Windows":
    # Model exported on Linux/macOS: PosixPath cannot be instantiated here.
    pathlib.PosixPath = pathlib.WindowsPath
else:
    # Model exported on Windows: WindowsPath cannot be instantiated here.
    pathlib.WindowsPath = pathlib.PosixPath
# Seed the global RNG so the index at which model_predict inserts the correct
# answer among the quiz options is reproducible.
random.seed(42)

# App layout: browser-tab title.
st.set_page_config(page_title="Star Wars Character App")

# Session-state slots shared by the sections below, both initialised to "":
#   char_label  - character name from the classifier (or typed by the user)
#   quiz_corpus - scraped wiki summary used as input for the quiz
for _state_key in ("char_label", "quiz_corpus"):
    if _state_key not in st.session_state:
        st.session_state[_state_key] = ""
# make prediction function #
def make_pred(_model, _image):
    """Classify ``_image`` with ``_model`` and return the predicted label.

    The prediction and its confidence (the probability of the predicted class,
    as a percentage) are printed to stdout and written to the Streamlit page.
    """
    label, idx, probs = _model.predict(_image)
    message = f"This is a picture of {label}, model is {probs[idx] * 100:.2f}% sure."
    print(message)
    st.write(message)
    return label
# Load Model function #
@st.experimental_singleton
def load_model():
    """Return the exported fastai image classifier.

    Reads ``star-wars-characters-model_res34.pkl`` relative to the current
    working directory. The ``st.experimental_singleton`` decorator caches the
    result so script reruns reuse the loaded learner.
    """
    # NOTE(review): load_learner reads a pickled file — only point it at a
    # trusted model. st.experimental_singleton is deprecated in newer Streamlit
    # releases (st.cache_resource replaces it); confirm the pinned version.
    model_path = Path().joinpath('star-wars-characters-model_res34.pkl')
    return load_learner(model_path)
# scrape web data for summary function #
@st.experimental_singleton
def srape_wiki(star_wars_character, timeout=10):
    """Scrape the introductory summary of a starwars.fandom.com wiki article.

    The summary is everything between the page's first ``<div class="quote">``
    and its table of contents (``<div id="toc" class="toc">``). Only paragraph
    (``<p>``) text from that span is kept. Newlines are replaced by spaces and
    anything in square brackets (e.g. citation markers like ``[1]``) is removed.

    Args:
        star_wars_character: Article title appended to the wiki URL.
        timeout: Seconds to wait for the HTTP response before giving up.

    Returns:
        The cleaned summary text (possibly empty).

    Raises:
        AttributeError: If the page has no ``div.quote`` element. The summary
            section of this app catches exactly this type to ask the user for
            a different spelling, so the exception type is part of the contract.
        requests.RequestException: On network errors, including a timeout.
    """
    url = "https://starwars.fandom.com/wiki/" + star_wars_character
    print(url)
    # A timeout prevents a stalled request from hanging the app indefinitely.
    response = requests.get(url, timeout=timeout)
    soup = BeautifulSoup(response.text, "html.parser")

    div_top = soup.find("div", class_="quote")
    if div_top is None:
        raise AttributeError(f"No summary start marker (div.quote) found at {url}")
    div_bot = soup.find("div", id="toc", class_="toc")

    # Collect the raw HTML of each sibling after the quote up to the TOC.
    # Compare by identity (bs4 equality is structural), and also stop at the
    # end of the sibling chain in case the TOC is missing or not a sibling.
    chunks = []
    item = div_top.next_sibling
    while item is not None and item is not div_bot:
        chunks.append(str(item))
        item = item.next_sibling
    content = BeautifulSoup("".join(chunks), "html.parser")

    # Join paragraph texts into one continuous string without line breaks.
    text = " ".join(p.text.strip() for p in content.find_all("p"))
    text = text.replace("\n", " ")
    # Remove brackets and everything inside them.
    return re.sub(r'\[.*?\]', "", text)
# make quiz prediction #
@st.experimental_singleton
def _load_qgen():
    """Download the NLTK data used here and build the Questgen MCQ generator.

    Cached separately from ``model_predict`` so the T5-based ``QGen`` model is
    constructed once, not once for every new corpus.
    """
    nltk.download('universal_tagset')
    nltk.download('stopwords')
    # Imported inside the function so Questgen is only loaded when questions
    # are first generated.
    from Questgen import main
    return main.QGen()


@st.experimental_singleton
def model_predict(payload):
    """Generate multiple-choice questions for ``payload`` with Questgen.

    Args:
        payload: Dict passed to ``QGen.predict_mcq``; this app sends
            ``{"input_text": <corpus>}``.

    Returns:
        The generator output. For each entry in ``["questions"]`` the correct
        ``"answer"`` is inserted into its ``"options"`` list at a random index.
        The RNG is seeded at the top of the script, so this is reproducible.

    Raises:
        KeyError: If the generator output has no ``"questions"`` key; the quiz
            section of this app catches this.
    """
    qg_mcq = _load_qgen()
    model_out = qg_mcq.predict_mcq(payload)
    for question in model_out["questions"]:
        options = question["options"]
        # Any position from 0 to len(options) inclusive is equally likely,
        # whatever the number of distractors.
        options.insert(random.randint(0, len(options)), question["answer"])
    return model_out
# Page heading, rendered as raw HTML so it can be centred and coloured.
st.markdown(
    "<h1 style='text-align: center; color: grey;'>Star Wars Character App</h1>",
    unsafe_allow_html=True,
)

# Three layout columns; col1/col2/col3 are not referenced elsewhere in this file.
col1, col2, col3 = st.columns(3)

# Image classifier used by the classification section (cached by load_model).
learn_inf = load_model()
# CLASSIFICATION SECTION #
with st.expander("Image Classification"):
    st.markdown(
        "<h2 style='text-align: center; color: black;'>Character Classification</h2>",
        unsafe_allow_html=True,
    )
    uploaded_image = st.file_uploader(
        "Upload Star wars character", type=['png', 'jpeg', 'jpg'])
    if uploaded_image is not None:
        image_file1 = PILImage.create(uploaded_image)
        st.image(image_file1.to_thumb(200, 200), caption='Uploaded Image')
        with st.form(key="image_class"):
            if st.form_submit_button("Submit"):
                # Remember the prediction so the summary section can look it up.
                st.session_state["char_label"] = make_pred(learn_inf, image_file1)

st.markdown("<br></br>", unsafe_allow_html=True)
# SUMMARY SECTION #
def _show_summary(character_name):
    """Scrape, store and display the wiki summary for ``character_name``.

    On success the text is saved in ``st.session_state["quiz_corpus"]`` (the
    quiz section's input) and written to the page. An empty name is skipped,
    so no request is made for it. If ``srape_wiki`` raises ``AttributeError``
    (raised when the expected wiki markup is missing), an error is shown and
    the previous corpus is left unchanged.
    """
    if not character_name:
        return
    try:
        corpus = srape_wiki(character_name)
    except AttributeError:
        st.error(
            "Please choose a different variation of the character name")
        return
    st.session_state["quiz_corpus"] = corpus
    st.write(corpus)


with st.expander("Summary"):
    st.markdown("<h2 style='text-align: center; color: black;'>Character Summary</h2>",
                unsafe_allow_html=True)
    st.write("Summary of: ", st.session_state["char_label"])
    _show_summary(st.session_state["char_label"])
    # Let the user correct the name (e.g. to match the wiki article title)
    # and re-run the lookup with the form below.
    out_text_area = st.text_input(
        "Character name", st.session_state["char_label"])
    with st.form(key="summary"):
        summary = st.form_submit_button("Summary")
        if summary:
            st.write(out_text_area)
            st.session_state["char_label"] = out_text_area
            # Same error handling as above, instead of an uncaught traceback.
            _show_summary(out_text_area)

st.markdown("<br></br>", unsafe_allow_html=True)
# QUIZ SECTION #
with st.expander("Quiz"):
    st.markdown("<h2 style='text-align: center; color: black;'>Character Quiz</h2>",
                unsafe_allow_html=True)
    payload = {
        "input_text": st.session_state["quiz_corpus"]
    }
    # Only generate questions once a summary has been scraped. An empty corpus
    # presumably yields no questions, and skipping the call avoids loading the
    # question-generation model on the first page load.
    if payload["input_text"]:
        try:
            model_output = model_predict(payload)
        except KeyError:
            # model_predict indexes model_out["questions"]; a KeyError means
            # the generator output had no "questions" entry for this corpus.
            print("error caught")
        else:
            for question in model_output["questions"]:
                question_key = str(question["id"])
                # One form per question so each is answered and submitted
                # independently.
                with st.form(key=question_key):
                    st.write(f"Question {question['id']}")
                    entry = st.radio(label=question["question_statement"],
                                     options=question["options"],
                                     key=question_key)
                    checkbox_val = st.checkbox("Do you want a clue?")
                    submitted = st.form_submit_button(label='Submit')
                    if submitted:
                        if question["answer"] == entry:
                            st.write("CORRECT!")
                        else:
                            st.write("Wrong, check clue")
                        if checkbox_val:
                            # Presumably the source passage the question was
                            # generated from — confirm against Questgen output.
                            st.write(question["context"])