import asyncio

import streamlit as st

from text_processing import segment_text
from keyword_extraction import extract_keywords
from utils import QuestionGenerationError
from mapping_keywords import map_keywords_to_sentences
from option_generation import gen_options, generate_options_async
from fill_in_the_blanks_generation import generate_fill_in_the_blank_questions
from load_models import load_nlp_models, load_qa_models, load_model

nlp, s2v = load_nlp_models()
similarity_model, spell = load_qa_models()
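# nlp (spaCy) backs the relevance score and spell (a spell checker whose
# .unknown() API matches pyspellchecker) backs the spelling score below;
# s2v and similarity_model are loaded here but not referenced in this module.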
def assess_question_quality(context, question, answer):
    """Score a generated question on relevance, complexity, and spelling.

    Note: `answer` is currently unused; the score depends only on the
    context and the question text.
    """
    # Relevance: cosine similarity between the context and question vectors
    context_doc = nlp(context)
    question_doc = nlp(question)
    relevance_score = context_doc.similarity(question_doc)

    # Complexity: token length as a simple proxy, normalized to 0-1
    complexity_score = min(len(question_doc) / 20, 1)

    # Spelling correctness: fraction of recognized words, normalized to 0-1
    words = question.split()
    misspelled = spell.unknown(words)
    # Guard against an empty question to avoid division by zero
    spelling_correctness = 1 - (len(misspelled) / len(words)) if words else 0

    # Weighted overall score (adjust weights as needed)
    overall_score = (
        0.4 * relevance_score +
        0.4 * complexity_score +
        0.2 * spelling_correctness
    )
    return overall_score, relevance_score, complexity_score, spelling_correctness
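# For illustration: a question with relevance 0.8, 10 tokens (complexity 0.5),
# and no misspellings scores 0.4*0.8 + 0.4*0.5 + 0.2*1.0 = 0.72, which clears
# the 0.5 acceptance threshold used in process_batch below.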
async def process_batch(batch, keywords, context_window_size, num_beams, num_questions, modelname):
    """Generate scored questions for one batch of texts, stopping at num_questions."""
    questions = []
    done = False
    for text in batch:
        if done:
            break
        keyword_sentence_mapping = map_keywords_to_sentences(text, keywords, context_window_size)
        print(keyword_sentence_mapping)
        for keyword, context in keyword_sentence_mapping.items():
            print("Questions generated so far:", len(questions))
            if len(questions) >= num_questions:
                done = True
                break
            question = await generate_question_async(context, keyword, num_beams, modelname)
            options = await generate_options_async(keyword, context)
            # options = gen_options(keyword, context, question)
            blank_question = await generate_fill_in_the_blank_questions(context, keyword)
            overall_score, relevance_score, complexity_score, spelling_correctness = assess_question_quality(context, question, keyword)
            # Keep only questions that clear the quality threshold
            if overall_score >= 0.5:
                questions.append({
                    "question": question,
                    "context": context,
                    "answer": keyword,
                    "options": options,
                    "overall_score": overall_score,
                    "relevance_score": relevance_score,
                    "complexity_score": complexity_score,
                    "spelling_correctness": spelling_correctness,
                    "blank_question": blank_question,
                })
    return questions
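# Because low-scoring questions are discarded, a batch can yield fewer than
# num_questions items; generate_questions_async below keeps consuming batches
# until the quota is met or the text runs out.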
async def generate_question_async(context, answer, num_beams, modelname):
    """Generate a single question for (context, answer) using beam search."""
    model, tokenizer = load_model(modelname)
    try:
        input_text = f"<context> {context} <answer> {answer}"
        print(f"\n{input_text}\n")
        input_ids = tokenizer.encode(input_text, return_tensors='pt')
        # Run the blocking generate() call in a worker thread so it
        # does not stall the event loop
        outputs = await asyncio.to_thread(model.generate, input_ids, num_beams=num_beams, early_stopping=True, max_length=250)
        question = tokenizer.decode(outputs[0], skip_special_tokens=True)
        print(f"\n{question}\n")
        return question
    except Exception as e:
        raise QuestionGenerationError(f"Error in question generation: {str(e)}") from e
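# The prompt format suggests a seq2seq model fine-tuned on inputs of the form
# "<context> ... <answer> ..." (a T5-style setup is an assumption here). For
# example, with answer "chlorophyll" the model input would be:
#   <context> Chlorophyll absorbs light for photosynthesis. <answer> chlorophyll
# and the decoded output is the question text itself.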
# Generate questions across all batches using beam search, reporting progress in Streamlit
async def generate_questions_async(text, num_questions, context_window_size, num_beams, extract_all_keywords, modelname):
    try:
        batches = segment_text(text.lower())
        keywords = extract_keywords(text, extract_all_keywords)
        all_questions = []
        progress_bar = st.progress(0)
        status_text = st.empty()
        print("Final keywords:", keywords)
        print("Number of questions to generate:", num_questions)
        print("Total number of batches:", len(batches))
        for i, batch in enumerate(batches):
            print("Processing batch:", i + 1)
            status_text.text(f"Processing batch {i + 1} of {len(batches)}...")
            batch_questions = await process_batch(batch, keywords, context_window_size, num_beams, num_questions, modelname)
            all_questions.extend(batch_questions)
            progress_bar.progress((i + 1) / len(batches))
            print("Total questions collected:", len(all_questions))
            if len(all_questions) >= num_questions:
                break
        progress_bar.empty()
        status_text.empty()
        return all_questions[:num_questions]
    except QuestionGenerationError as e:
        st.error(f"An error occurred during question generation: {str(e)}")
        return []
    except Exception as e:
        st.error(f"An unexpected error occurred: {str(e)}")
        return []
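# A minimal sketch of how a Streamlit entry point might drive this module.
# The widget labels, parameter values, and model name below are assumptions;
# the real app's UI is defined elsewhere.
#
#     text = st.text_area("Enter source text")
#     if st.button("Generate questions"):
#         questions = asyncio.run(
#             generate_questions_async(
#                 text,
#                 num_questions=5,
#                 context_window_size=1,
#                 num_beams=4,
#                 extract_all_keywords=False,
#                 modelname="t5-base",  # hypothetical model name
#             )
#         )
#         for q in questions:
#             st.write(q["question"], q["options"])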