Spaces:
Build error
Build error
from pathlib import Path
from fastai.vision.all import *
from bs4 import BeautifulSoup
import requests
import re
import nltk
from pprint import pprint
import random
import streamlit as st
import pathlib
import platform
# NOTE(review): presumably the exported learner .pkl was pickled on Windows and
# therefore contains pathlib.WindowsPath objects, which cannot be instantiated
# on other operating systems. Aliasing WindowsPath to PosixPath is only needed
# (and only safe) when NOT running on Windows. The original check was inverted
# and also clobbered the `plt` name that the fastai star import provides.
if platform.system() != "Windows":
    pathlib.WindowsPath = pathlib.PosixPath
# Fixed seed so the random position of the correct answer among the quiz
# options is deterministic.
random.seed(42)

# app layout #
st.set_page_config(page_title="Star Wars Character App")

# Session-state defaults, created only if the key does not exist yet:
#   char_label  - character name (predicted by the classifier or typed in)
#   quiz_corpus - scraped summary text the quiz is generated from
for _state_key in ("char_label", "quiz_corpus"):
    if _state_key not in st.session_state:
        st.session_state[_state_key] = ""
# make prediction function #
def make_pred(_model, _image):
    """Run ``_model.predict`` on ``_image``, report the result, return the label.

    The confidence message is printed to stdout and also shown in the app.

    Args:
        _model: Object whose ``predict`` returns a 3-tuple
            ``(label, index, probabilities)`` (fastai-style learner).
        _image: Image passed unchanged to ``_model.predict``.

    Returns:
        The predicted label, which the caller stores as the character name.
    """
    label, idx, probs = _model.predict(_image)
    confidence = probs[idx] * 100
    message = f"This is a picture of {label}, model is {confidence:.2f}% sure."
    print(message)
    st.write(message)
    return label
# Load Model function #
def load_model():
    """Return the fastai learner exported as ``star-wars-characters-model_res34.pkl``.

    The path is relative, so the file is resolved against the current
    working directory.
    """
    model_file = Path() / 'star-wars-characters-model_res34.pkl'
    return load_learner(model_file)
# scrape web data for summary function #
def srape_wiki(star_wars_character, timeout=10):
    """Scrape the introductory paragraphs of a Star Wars fandom wiki article.

    Collects every sibling node between the article's opening ``div.quote``
    and its table of contents (``div#toc.toc``), or up to the last sibling
    when no TOC follows. It then keeps the text of the ``<p>`` tags in that
    span, joins it into one string, and strips bracketed markers such as
    ``[1]``.

    Args:
        star_wars_character: Page name appended verbatim to the wiki URL.
        timeout: Seconds to wait for the HTTP response (new, optional).

    Returns:
        The intro text as a single string.

    Raises:
        AttributeError: If the page has no ``div.quote``. The Summary section
            catches this to ask for a different spelling.
        requests.RequestException: On connection failure or timeout.
    """
    url = "https://starwars.fandom.com/wiki/" + star_wars_character
    print(url)
    # Without a timeout a stalled connection would block the app indefinitely.
    response = requests.get(url, timeout=timeout)
    soup = BeautifulSoup(response.text, "html.parser")
    div_top = soup.find("div", class_="quote")
    div_bot = soup.find("div", id="toc", class_="toc")

    # div_top is None when the quote is missing -> AttributeError (see Raises).
    # Stop at the TOC by identity: bs4's Tag `==` compares structure, not
    # identity. Also stop at the end of the siblings instead of crashing on
    # None when the TOC is not a sibling of the quote.
    pieces = []
    item = div_top.next_sibling
    while item is not None and item is not div_bot:
        pieces.append(str(item))
        item = item.next_sibling

    # beautify content #
    content_b = BeautifulSoup("".join(pieces), "html.parser")
    # make text one continuous string, with newlines turned into spaces #
    text_str = " ".join(p.text.strip() for p in content_b.find_all("p"))
    text_str = text_str.replace("\n", " ")
    # remove brackets and all inside them #
    return re.sub(r"\[.*?\]", "", text_str)
# make quiz prediction #
def model_predict(payload):
    """Generate multiple-choice questions for ``payload`` with Questgen.

    Args:
        payload: Dict passed straight to ``QGen.predict_mcq``. The quiz
            section sends ``{"input_text": <summary text>}``.

    Returns:
        The Questgen output dict. Each question's correct ``answer`` has been
        inserted at a random position in its ``options`` list.

    Raises:
        KeyError: If the output has no ``"questions"`` entry.
    """
    # Streamlit re-executes this script on every interaction, so module
    # globals do not persist. Keep the loaded generator in session state
    # instead of re-downloading NLTK data and rebuilding the T5 model per call.
    if "_qgen_model" not in st.session_state:
        nltk.download('universal_tagset')
        nltk.download('stopwords')
        from Questgen import main
        st.session_state["_qgen_model"] = main.QGen()
    qg_mcq = st.session_state["_qgen_model"]

    model_out = qg_mcq.predict_mcq(payload)
    for question in model_out["questions"]:
        options = question["options"]
        # randint is inclusive: every slot, including "after the last option",
        # is possible for any number of distractors (== randint(0, 3) for 3).
        options.insert(random.randint(0, len(options)), question["answer"])
    return model_out
st.markdown("<h1 style='text-align: center; color: grey;'>Star Wars Character App</h1>",
            unsafe_allow_html=True)

# containers #
# NOTE(review): col1..col3 are not referenced anywhere in the visible code.
col1, col2, col3 = st.columns(3)

# loading fastai model #
# Load the learner once per session rather than on every rerun of this
# script; load_learner deserializes the whole exported model each time.
if "learn_inf" not in st.session_state:
    st.session_state["learn_inf"] = load_model()
learn_inf = st.session_state["learn_inf"]
# CLASSIFICATION SECTION #
with st.expander("Image Classification"):
    st.markdown("<h2 style='text-align: center; color: black;'>Character Classification</h2>",
                unsafe_allow_html=True)
    # upload image #
    uploaded_file1 = st.file_uploader(
        "Upload Star wars character", type=['png', 'jpeg', 'jpg'])
    # Bound up front so pressing Submit without an upload cannot hit an
    # unbound name.
    image_file1 = None
    if uploaded_file1 is not None:
        image_file1 = PILImage.create(uploaded_file1)
        st.image(image_file1.to_thumb(200, 200),
                 caption='Uploaded Image')
    with st.form(key="image_class"):
        classify_img = st.form_submit_button("Submit")
        if classify_img:
            if image_file1 is None:
                st.warning("Please upload an image before submitting.")
            else:
                st.session_state["char_label"] = make_pred(learn_inf, image_file1)
st.markdown("<br></br>", unsafe_allow_html=True)
# SUMMARY SECTION #
def _show_summary(character_name):
    """Scrape, store and display the wiki summary for ``character_name``.

    On success the text is saved to ``st.session_state["quiz_corpus"]`` (the
    quiz input) and written to the page. An empty name is skipped, so the
    wiki front page is never scraped. Failures are reported with ``st.error``
    and leave the previous corpus unchanged.
    """
    if not character_name:
        return
    try:
        corpus = srape_wiki(character_name)
    except AttributeError:
        # srape_wiki dereferences the page's quote <div>; presumably it is
        # missing when the name does not match an article.
        st.error(
            "Please choose a different variation of the character name")
        return
    except requests.RequestException:
        st.error("Could not reach the Star Wars wiki, please try again.")
        return
    st.session_state["quiz_corpus"] = corpus
    st.write(corpus)


with st.expander("Summary"):
    st.markdown("<h2 style='text-align: center; color: black;'>Character Summary</h2>",
                unsafe_allow_html=True)
    st.write("Summary of: ", st.session_state["char_label"])
    _show_summary(st.session_state["char_label"])
    out_text_area = st.text_input(
        "Character name", st.session_state["char_label"])
    with st.form(key="summary"):
        summary = st.form_submit_button("Summary")
        if summary:
            st.write(out_text_area)
            st.session_state["char_label"] = out_text_area
            _show_summary(out_text_area)
st.markdown("<br></br>", unsafe_allow_html=True)
# QUIZ SECTION #
with st.expander("Quiz"):
    st.markdown("<h2 style='text-align: center; color: black;'>Character Quiz</h2>",
                unsafe_allow_html=True)
    quiz_corpus = st.session_state["quiz_corpus"]
    if not quiz_corpus:
        # Nothing to build questions from until a summary has been scraped.
        st.info("Load a character summary first to generate a quiz.")
    else:
        # Generate questions once per corpus. Otherwise the expensive generator
        # runs again on every widget interaction (each answer submission).
        if ("quiz_cache" not in st.session_state
                or st.session_state["quiz_cache"][0] != quiz_corpus):
            try:
                quiz_questions = model_predict(
                    {"input_text": quiz_corpus})["questions"]
            except KeyError:
                # The generator output had no "questions" entry for this text.
                quiz_questions = []
            st.session_state["quiz_cache"] = (quiz_corpus, quiz_questions)
        quiz_questions = st.session_state["quiz_cache"][1]
        if not quiz_questions:
            st.warning("No quiz questions could be generated for this character.")
        for question in quiz_questions:
            with st.form(key=str(question["id"])):
                st.write(f"Question {question['id']}")
                entry = st.radio(label=question["question_statement"],
                                 options=question["options"],
                                 key=str(question["id"]))
                # Explicit key: every form has a checkbox with the same label.
                checkbox_val = st.checkbox("Do you want a clue?",
                                           key=f"clue_{question['id']}")
                submitted = st.form_submit_button(label='Submit')
                if submitted:
                    if question["answer"] == entry:
                        st.write("CORRECT!")
                    else:
                        st.write("Wrong, check clue")
                    if checkbox_val:
                        st.write(question["context"])