Spaces:
Build error
Build error
| from pathlib import Path | |
| from fastai.vision.all import * | |
| from bs4 import BeautifulSoup | |
| import requests | |
| import re | |
| import nltk | |
| from pprint import pprint | |
| import random | |
| import streamlit as st | |
| import pathlib | |
# Import explicitly rather than relying on a star import to provide it.
import platform

# fastai Learners exported on Windows pickle pathlib.WindowsPath objects, which
# cannot be instantiated on Linux/macOS (e.g. a Hugging Face Space). Aliasing
# WindowsPath to PosixPath there lets load_learner unpickle such a model.
# The alias must NOT be applied on Windows itself: there it would make Path()
# resolve to PosixPath, which cannot be instantiated on Windows.
# NOTE(review): assumes the .pkl was exported on Windows — confirm.
# Named platform_name (not plt) to avoid shadowing matplotlib's pyplot, which
# the fastai star import may expose as `plt`.
platform_name = platform.system()
if platform_name != "Windows":
    pathlib.WindowsPath = pathlib.PosixPath

# Add random seed for answer location #
# Streamlit re-executes the script on every interaction; reseeding here keeps
# the random placement of quiz answers reproducible from run to run.
random.seed(42)

# app layout #
# Keep this as the first Streamlit command in the script.
st.set_page_config(
    page_title="Star Wars Character App"
)

# set session_state for character name prediction #
if "char_label" not in st.session_state:
    st.session_state["char_label"] = ""
# session_state for quiz tests #
if "quiz_corpus" not in st.session_state:
    st.session_state["quiz_corpus"] = ""
# make prediction function #
def make_pred(_model, _image):
    """Classify ``_image`` with ``_model``, report the result, and return the label.

    Parameters
    ----------
    _model : object whose ``predict`` returns ``(label, idx, probs)`` — here the
        fastai Learner returned by ``load_model``.
    _image : the image handed straight to ``_model.predict``.

    Returns
    -------
    The predicted label; the app later uses it as the character name.
    """
    label, idx, preds = _model.predict(_image)
    # float() turns the selected probability into a plain number so the
    # ``:.2f`` format spec does not depend on the tensor type's __format__.
    confidence = float(preds[idx]) * 100
    message = f"This is a picture of {label}, model is {confidence:.2f}% sure."
    print(message)
    st.write(message)
    return label
# Load Model function #
def load_model():
    """Load and return the exported fastai Learner used for classification.

    The pickle file is resolved relative to the current working directory.
    """
    model_file = Path() / 'star-wars-characters-model_res34.pkl'
    return load_learner(model_file)
# scrape web data for summary function #
def _clean_corpus(paragraphs):
    """Join paragraph texts into one line with every ``[...]`` span removed.

    Each paragraph is stripped, the paragraphs are joined with single spaces,
    embedded newlines become spaces, and bracketed spans (such as ``[1]``
    footnote markers) are deleted.
    """
    text = " ".join(paragraph.strip() for paragraph in paragraphs)
    text = text.replace("\n", " ")
    return re.sub(r'\[.*?\]', "", text)


def srape_wiki(star_wars_character, timeout=10):
    """Scrape the introductory text of a character's Wookieepedia page.

    Fetches ``https://starwars.fandom.com/wiki/<star_wars_character>`` and
    collects the ``<p>`` text found between the first ``div.quote`` and the
    table of contents (``div#toc.toc``).

    Parameters
    ----------
    star_wars_character : page name appended verbatim to the wiki URL.
    timeout : seconds to wait for the HTTP response (defaults to 10) so a
        stalled request cannot hang the app indefinitely.

    Returns
    -------
    str : the cleaned introduction text.

    Raises
    ------
    AttributeError
        If the page has no ``div.quote``, or the sibling walk runs past the
        last sibling without meeting the table of contents. The summary
        section catches this to ask for a different spelling.
    requests.RequestException
        On network failures, including the timeout.
    """
    url = "https://starwars.fandom.com/wiki/" + star_wars_character
    print(url)
    response = requests.get(url, timeout=timeout)
    soup = BeautifulSoup(response.text, "html.parser")
    div_top = soup.find("div", class_="quote")
    div_bot = soup.find("div", id="toc", class_="toc")

    # Walk the siblings after the quote up to (not including) the TOC.
    # Identity is used because bs4's ``!=`` compares tag contents, not objects.
    # A missing quote div (div_top is None) raises AttributeError here.
    pieces = []
    item = div_top.next_sibling
    while item is not div_bot:
        pieces.append(str(item))
        item = item.next_sibling

    # Re-parse the collected fragment so nested <p> tags are found too.
    content = BeautifulSoup("".join(pieces), "html.parser")
    return _clean_corpus(sentence.text for sentence in content.find_all("p"))
# make quiz prediction #
def _insert_answers(model_out):
    """Insert each question's answer at a random slot among its options.

    The slot is drawn uniformly from every position of the option list
    (``0 .. len(options)``), so the correct answer is equally likely to land
    anywhere regardless of how many distractors were generated. Mutates and
    returns ``model_out``; raises KeyError if it has no ``"questions"`` entry.
    """
    for question in model_out["questions"]:
        options = question["options"]
        options.insert(random.randint(0, len(options)), question["answer"])
    return model_out


def model_predict(payload):
    """Generate multiple-choice questions from ``payload`` with Questgen.

    Parameters
    ----------
    payload : dict passed to ``QGen.predict_mcq`` (this app supplies
        ``{"input_text": <corpus>}``).

    Returns
    -------
    The generator output, with each question's answer inserted into its
    ``"options"`` list at a random position.

    Raises
    ------
    KeyError
        If the output has no ``"questions"`` entry (presumably when no
        questions could be generated); the quiz section handles this.
    """
    # quiet=True: this runs on every call, so suppress the download log noise.
    nltk.download('universal_tagset', quiet=True)
    nltk.download('stopwords', quiet=True)
    # Imported lazily, after the NLTK resources above are in place.
    from Questgen import main
    # load t5 model #
    qg_mcq = main.QGen()
    return _insert_answers(qg_mcq.predict_mcq(payload))
# Page title, written as raw HTML so it can be centred and coloured.
st.markdown(
    "<h1 style='text-align: center; color: grey;'>Star Wars Character App</h1>",
    unsafe_allow_html=True,
)

# containers #
col1, col2, col3 = st.columns(3)

# loading fastai model #
# NOTE(review): Streamlit re-executes the script on each interaction, so this
# reloads the model every time; consider caching it.
learn_inf = load_model()
# CLASSIFICATION SECTION #
# Upload an image, preview it, and on submit store the predicted character
# name in session state for the summary section.
with st.expander("Image Classification"):
    st.markdown(
        "<h2 style='text-align: center; color: black;'>Character Classification</h2>",
        unsafe_allow_html=True,
    )
    # upload image #
    uploaded_file1 = st.file_uploader(
        "Upload Star wars character",
        type=['png', 'jpeg', 'jpg'],
    )
    if uploaded_file1 is not None:
        image_file1 = PILImage.create(uploaded_file1)
        st.image(image_file1.to_thumb(200, 200), caption='Uploaded Image')
        with st.form(key="image_class"):
            classify_img = st.form_submit_button("Submit")
        if classify_img:
            st.session_state["char_label"] = make_pred(learn_inf, image_file1)
st.markdown("<br></br>", unsafe_allow_html=True)
# SUMMARY SECTION #
def _show_summary(character_name):
    """Scrape (or reuse) and display the wiki summary for ``character_name``.

    The text is stored in ``st.session_state["quiz_corpus"]`` for the quiz
    section. Streamlit re-executes the script on every interaction, so the
    scrape is only repeated when the name changes: the name that produced the
    stored corpus is remembered in ``st.session_state["quiz_corpus_label"]``.
    An empty name makes no request. A page that ``srape_wiki`` cannot parse
    (it raises AttributeError) shows an error instead of crashing the app.
    """
    if not character_name:
        return
    if st.session_state.get("quiz_corpus_label") != character_name:
        try:
            corpus = srape_wiki(character_name)
        except AttributeError:
            st.error(
                "Please choose a different variation of the character name")
            return
        st.session_state["quiz_corpus"] = corpus
        st.session_state["quiz_corpus_label"] = character_name
    st.write(st.session_state["quiz_corpus"])


with st.expander("Summary"):
    st.markdown(
        "<h2 style='text-align: center; color: black;'>Character Summary</h2>",
        unsafe_allow_html=True,
    )
    st.write("Summary of: ", st.session_state["char_label"])
    _show_summary(st.session_state["char_label"])
    # Lets the user correct the predicted name or type another character.
    out_text_area = st.text_input(
        "Character name", st.session_state["char_label"])
    with st.form(key="summary"):
        summary = st.form_submit_button("Summary")
    if summary:
        st.write(out_text_area)
        st.session_state["char_label"] = out_text_area
        _show_summary(out_text_area)
st.markdown("<br></br>", unsafe_allow_html=True)
# QUIZ SECTION #
def _get_quiz(payload):
    """Return the quiz questions for ``payload``, generating them only once.

    Question generation loads a T5 model (see ``model_predict``), and
    Streamlit re-executes the script on every widget interaction. The
    questions are therefore cached in session state, keyed by the input text.
    An empty corpus yields no questions and skips the model entirely.
    A ``model_predict`` output without a ``"questions"`` entry (KeyError) is
    cached as an empty quiz.
    """
    corpus = payload["input_text"]
    if not corpus:
        return []
    if st.session_state.get("quiz_source") != corpus:
        try:
            questions = model_predict(payload)["questions"]
        except KeyError:
            questions = []
        st.session_state["quiz_questions"] = questions
        st.session_state["quiz_source"] = corpus
    return st.session_state["quiz_questions"]


with st.expander("Quiz"):
    st.markdown(
        "<h2 style='text-align: center; color: black;'>Character Quiz</h2>",
        unsafe_allow_html=True,
    )
    payload = {
        "input_text": st.session_state["quiz_corpus"]
    }
    questions = _get_quiz(payload)
    if payload["input_text"] and not questions:
        st.warning("No quiz questions could be generated for this summary.")
    for i in questions:
        with st.form(key=str(i["id"])):
            st.write(f"Question {i['id']}")
            entry = st.radio(label=i["question_statement"],
                             options=i["options"], key=str(i["id"]))
            checkbox_val = st.checkbox("Do you want a clue?")
            submitted = st.form_submit_button(label='Submit')
            if submitted:
                if i["answer"] == entry:
                    st.write("CORRECT!")
                else:
                    st.write("Wrong, check clue")
                    if checkbox_val:
                        st.write(i["context"])