|
import html
import os
import re
import traceback

import gradio as gr
import torch
from peft import PeftModel
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
|
|
|
|
|
# Hugging Face Hub ID of the base model that the adapter is applied on top of.
BASE_MODEL = "deepseek-ai/deepseek-math-7b-instruct"
# Hub repo from which both the fine-tuned PEFT adapter and the tokenizer are loaded.
REPO_ID = "danxh/math-mcq-generator-v1"


# Module-level handles filled in by load_model(); both start out as None.
model, tokenizer = None, None
|
|
|
def load_model():
    """Load the quantized base model, attach the fine-tuned adapter, and load the tokenizer.

    On success, the module-level ``model`` and ``tokenizer`` globals are
    assigned together and ``True`` is returned. On any failure, the error and
    its traceback are printed, the globals are left untouched, and ``False``
    is returned so the caller can keep the UI running in a degraded state.

    Returns:
        bool: whether loading succeeded.
    """
    global model, tokenizer

    try:
        print("🔄 Loading model and tokenizer...")

        # 4-bit NF4 weights with nested (double) quantization; compute dtype fp16.
        bnb_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_use_double_quant=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype=torch.float16,
        )

        # NOTE(review): trust_remote_code=True lets the Hub repo execute its own
        # Python code at load time; keep BASE_MODEL pinned to a trusted repo.
        base_model = AutoModelForCausalLM.from_pretrained(
            BASE_MODEL,
            quantization_config=bnb_config,
            device_map="auto",
            torch_dtype=torch.float16,
            trust_remote_code=True,
        )

        # Wrap the base model with the fine-tuned adapter weights from REPO_ID.
        peft_model = PeftModel.from_pretrained(base_model, REPO_ID)

        loaded_tokenizer = AutoTokenizer.from_pretrained(REPO_ID)
        if loaded_tokenizer.pad_token is None:
            # Reuse EOS as the pad token when the tokenizer does not define one.
            loaded_tokenizer.pad_token = loaded_tokenizer.eos_token

    except Exception as e:
        # Best-effort startup: report the failure (with traceback) and let the
        # UI come up anyway; the generators check for missing globals.
        print(f"❌ Error loading model: {str(e)}")
        traceback.print_exc()
        return False

    # Publish both globals only after everything loaded, so callers never see
    # a model without its tokenizer (or vice versa).
    model, tokenizer = peft_model, loaded_tokenizer
    print("✅ Model loaded successfully!")
    return True
|
|
|
def generate_mcq(chapter, topics, difficulty="medium", cognitive_skill="direct_application"):
    """Generate one MCQ as raw text using the fine-tuned model.

    Args:
        chapter: Chapter name interpolated into the prompt.
        topics: Topics interpolated into the prompt with an f-string (a list
            is therefore rendered using its Python repr).
        difficulty: Difficulty label, e.g. "easy", "medium" or "hard".
        cognitive_skill: Cognitive-skill label, e.g. "direct_application".

    Returns:
        str: the model's completion (prompt excluded) with surrounding
        whitespace stripped, or a message starting with "❌" when the model
        is not loaded or generation raised.
    """
    if model is None or tokenizer is None:
        return "❌ Model not loaded. Please wait for initialization."

    try:
        # NOTE(review): field names/casing presumably mirror the fine-tuning
        # data format — confirm before changing any of them.
        input_text = f"chapter: {chapter}\ntopics: {topics}\nDifficulty: {difficulty}\nCognitive Skill: {cognitive_skill}"

        prompt = f"""### Instruction:
Generate a math MCQ similar in style to the provided examples.

### Input:
{input_text}

### Response:
"""

        inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=512)
        inputs = {k: v.to(model.device) for k, v in inputs.items()}
        prompt_length = inputs["input_ids"].shape[1]

        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_new_tokens=300,
                temperature=0.7,
                do_sample=True,
                pad_token_id=tokenizer.eos_token_id,
                eos_token_id=tokenizer.eos_token_id,
                repetition_penalty=1.1,
            )

        # Decode only the tokens produced after the prompt. Searching the full
        # decoded text for "### Response:" breaks when truncation removed that
        # marker: find() returns -1 and the slice would start inside the prompt.
        new_tokens = outputs[0][prompt_length:]
        return tokenizer.decode(new_tokens, skip_special_tokens=True).strip()

    except Exception as e:
        return f"❌ Error generating MCQ: {str(e)}"
|
|
|
def _split_options(options_text):
    """Split an options section such as "(A) 1 (B) 2 (C) 3 (D) 4" into option texts.

    Labels (A)-(D) are located in order, each searched only after the previous
    one. A marker at the start of the text or after whitespace is preferred;
    otherwise the first plain occurrence is used. This keeps parentheses
    inside option text (e.g. "f(2)", "P(B)") from being taken as option
    boundaries. Entry i of the result belongs to label "ABCD"[i], so an empty
    option keeps its slot instead of shifting later options. Scanning stops at
    the first label that cannot be found.
    """
    spans = []  # (marker_start, text_start) for each label found
    search_from = 0
    for label in "ABCD":
        marker = f"({label})"
        # (?<!\S): marker at the very start or preceded by whitespace.
        standalone = re.compile(r"(?<!\S)" + re.escape(marker)).search(options_text, search_from)
        start = standalone.start() if standalone else options_text.find(marker, search_from)
        if start == -1:
            break
        spans.append((start, start + len(marker)))
        search_from = start + len(marker)

    options = []
    for idx, (_, text_start) in enumerate(spans):
        text_end = spans[idx + 1][0] if idx + 1 < len(spans) else len(options_text)
        options.append(options_text[text_start:text_end].strip())
    return options


def parse_mcq_response(response):
    """Parse raw model output into question, options and correct answer.

    Expects text shaped like::

        Question: ...
        Options: (A) ... (B) ... (C) ... (D) ...
        Answer: B

    Args:
        response: Raw text returned by the model.

    Returns:
        dict with keys "question" (str), "options" (list of str, where index
        i corresponds to label "ABCD"[i]) and "correct_answer" (a letter
        A-D). Fields that cannot be found are filled with placeholder text.
        If parsing raises (e.g. ``response`` is not a string), the dict
        instead holds placeholder values plus an "error" key with the message.
    """
    try:
        question_match = re.search(r'Question:\s*(.*?)(?=\nOptions:|Options:)', response, re.DOTALL)
        question = question_match.group(1).strip() if question_match else "Question not found"

        # Options run until "Answer:" or, if the output was cut off before the
        # answer line, until the end of the text.
        options_match = re.search(r'Options:\s*(.*?)(?=Answer:|\Z)', response, re.DOTALL)
        if options_match:
            options = _split_options(options_match.group(1).strip())
        else:
            options = ["Options not found"]

        # Accept both "Answer: B" and "Answer: (B)"; \b rejects words such as "Because".
        answer_match = re.search(r'Answer:\s*\(?([A-D])\b', response)
        answer = answer_match.group(1) if answer_match else "Answer not found"

        return {
            "question": question,
            "options": options,
            "correct_answer": answer,
        }
    except Exception as e:
        return {
            "question": "Parsing error",
            "options": ["Error parsing options"],
            "correct_answer": "N/A",
            "error": str(e),
        }
|
|
|
def _mcq_card_html(number, parsed):
    """Render one parsed MCQ as an HTML card; all model-produced text is HTML-escaped."""
    options = parsed["options"]
    option_items = []
    for idx, label in enumerate("ABCD"):
        # Escaping keeps math such as "x < 3" from being parsed as markup.
        text = html.escape(str(options[idx])) if idx < len(options) else "N/A"
        option_items.append(
            '<li style="margin: 8px 0; padding: 8px; background: #ecf0f1; border-radius: 5px;">'
            f"<strong>({label})</strong> {text}</li>"
        )

    return f"""
<div style="border: 2px solid #e1e5e9; border-radius: 10px; padding: 20px; margin: 10px 0; background: #f8f9fa;">
<h3 style="color: #2c3e50; margin-top: 0;">📚 Question {number}</h3>
<p style="font-size: 16px; line-height: 1.6; margin: 15px 0;"><strong>{html.escape(str(parsed['question']))}</strong></p>
<div style="margin: 15px 0;">
<h4 style="color: #34495e;">Options:</h4>
<ul style="list-style: none; padding: 0;">
{''.join(option_items)}
</ul>
</div>
<div style="margin-top: 15px; padding: 10px; background: #d5edda; border-radius: 5px; border-left: 4px solid #28a745;">
<strong>✅ Correct Answer: {html.escape(str(parsed['correct_answer']))}</strong>
</div>
</div>
"""


def _error_card_html(title, message):
    """Render a red error card with a fixed ``title`` and an HTML-escaped ``message``."""
    return f"""
<div style="border: 2px solid #dc3545; border-radius: 10px; padding: 20px; margin: 10px 0; background: #f8d7da;">
<h3 style="color: #721c24;">{title}</h3>
<p>{html.escape(str(message))}</p>
</div>
"""


def generate_mcq_web(chapter, topics_text, difficulty, cognitive_skill, num_questions=1):
    """Gradio callback: generate up to 3 MCQs and return them as one HTML string.

    Args:
        chapter: Chapter name from the UI.
        topics_text: Comma-separated topics; blank entries are dropped, and
            ["General"] is used when nothing remains.
        difficulty: Difficulty label passed through to ``generate_mcq``.
        cognitive_skill: Cognitive-skill label passed through to ``generate_mcq``.
        num_questions: Requested number of questions, capped at 3.

    Returns:
        str: HTML with one card per question. Per-question failures become
        error cards. If the model is not loaded, a "still loading" notice is
        returned instead, and unexpected failures yield a single
        "System Error" card.
    """
    if model is None or tokenizer is None:
        return """
<div style="border: 2px solid #ffc107; border-radius: 10px; padding: 20px; margin: 10px 0; background: #fff3cd;">
<h3 style="color: #856404;">⏳ Model Loading</h3>
<p>The model is still loading. Please wait a moment and try again.</p>
</div>
"""

    try:
        topics_list = [t.strip() for t in (topics_text or "").split(',') if t.strip()]
        if not topics_list:
            topics_list = ["General"]

        # The slider value may arrive as a float (e.g. 2.0); range() needs an int.
        question_count = min(int(num_questions), 3)

        results = []
        for number in range(1, question_count + 1):
            raw_response = generate_mcq(chapter, topics_list, difficulty, cognitive_skill)

            # generate_mcq reports failures as a "❌"-prefixed message rather
            # than raising; surface it instead of parsing it as an MCQ.
            if raw_response.startswith("❌"):
                results.append(_error_card_html(f"❌ Error generating question {number}", raw_response))
                continue

            parsed = parse_mcq_response(raw_response)
            if "error" in parsed:
                results.append(_error_card_html(
                    f"❌ Error generating question {number}",
                    parsed.get('error', 'Unknown error occurred'),
                ))
            else:
                results.append(_mcq_card_html(number, parsed))

        return "".join(results)

    except Exception as e:
        return _error_card_html("❌ System Error", f"Error: {str(e)}")
|
|
|
|
|
def _build_interface():
    """Assemble the Gradio front end that drives ``generate_mcq_web``.

    The five input widgets map positionally onto ``generate_mcq_web``'s
    parameters (chapter, topics_text, difficulty, cognitive_skill,
    num_questions); its HTML string is shown in a single HTML output.
    """
    chapter_input = gr.Textbox(
        label="📚 Chapter",
        placeholder="e.g., Applications of Trigonometry, Conic Sections",
        value="Applications of Trigonometry",
        info="Enter the mathematics chapter or topic area",
    )
    topics_input = gr.Textbox(
        label="📝 Topics (comma-separated)",
        placeholder="e.g., Heights and Distances, Circle, Tangents",
        value="Heights and Distances",
        info="Enter specific topics within the chapter, separated by commas",
    )
    difficulty_input = gr.Dropdown(
        choices=["easy", "medium", "hard"],
        label="⚡ Difficulty Level",
        value="medium",
        info="Select the difficulty level for the questions",
    )
    skill_input = gr.Dropdown(
        choices=["recall", "direct_application", "pattern_recognition", "strategic_reasoning", "trap_aware"],
        label="🧠 Cognitive Skill",
        value="direct_application",
        info="Select the type of thinking skill required",
    )
    count_input = gr.Slider(
        minimum=1,
        maximum=3,
        step=1,
        label="🔢 Number of Questions",
        value=1,
        info="How many questions to generate (max 3)",
    )

    description_md = """
Generate high-quality mathematics multiple choice questions using AI.
This model has been fine-tuned specifically for educational content creation.

**Note**: Model loading may take a few minutes on first startup.
"""

    article_md = """
### 🔬 About This Model

This MCQ generator is powered by a fine-tuned version of DeepSeek-Math-7B, specifically adapted for mathematics education.

### 💡 Tips for Best Results:
- Be specific with chapter and topic names
- Try different cognitive skill levels for variety
- Start with 1 question to test, then generate more

### 🤝 Collaboration
This is part of a collaborative project to create specialized educational AI tools.
"""

    # Each example row follows the input order above.
    example_rows = [
        ["Applications of Trigonometry", "Heights and Distances", "easy", "recall", 1],
        ["Conic Sections", "Circle", "medium", "pattern_recognition", 1],
        ["Applications of Trigonometry", "Angle of Elevation and Depression", "hard", "strategic_reasoning", 1],
    ]

    return gr.Interface(
        fn=generate_mcq_web,
        inputs=[chapter_input, topics_input, difficulty_input, skill_input, count_input],
        outputs=gr.HTML(label="Generated MCQ(s)"),
        title="🧮 Mathematics MCQ Generator",
        description=description_md,
        article=article_md,
        theme=gr.themes.Soft(primary_hue="blue", secondary_hue="purple"),
        examples=example_rows,
    )


interface = _build_interface()
|
|
|
|
|
# Load the model at import time, before the UI is launched; load_model()
# returns a bool and never raises, so a failure only degrades the UI.
print("🚀 Starting model loading...")
model_loaded = load_model()
print(
    "✅ Ready to generate MCQs!"
    if model_loaded
    else "❌ Model loading failed. The interface may not work properly."
)


if __name__ == "__main__":
    interface.launch()
|
|