danxh committed · Commit 2ae78f5 · verified · 1 Parent(s): cbd0cb5

Upload app.py with huggingface_hub
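The commit message refers to the huggingface_hub client. A minimal sketch of what such an upload call could look like follows; the repo_id (reused from REPO_ID inside app.py) and the repo_type are assumptions, since the target repo is not shown on this page:

    from huggingface_hub import HfApi

    api = HfApi()
    api.upload_file(
        path_or_fileobj="app.py",                # local file to push
        path_in_repo="app.py",                   # destination path in the repo
        repo_id="danxh/math-mcq-generator-v1",   # assumption: reuses REPO_ID from app.py
        repo_type="space",                       # assumption: a Gradio app.py suggests a Space
        commit_message="Upload app.py with huggingface_hub",
    )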

Files changed (1): app.py +124 -0
app.py ADDED
@@ -0,0 +1,124 @@
+ import os
+ import torch
+ import gradio as gr
+ from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
+ from peft import PeftModel
+ import re
+
+ # Model configuration
+ BASE_MODEL = "deepseek-ai/deepseek-math-7b-instruct"
+ REPO_ID = "danxh/math-mcq-generator-v1"
+
+ # Load model and tokenizer
+ @torch.no_grad()
+ def load_model():
+     """Load the fine-tuned model"""
+
+     # 4-bit NF4 quantization keeps the 7B base model within a small GPU's memory
+     bnb_config = BitsAndBytesConfig(
+         load_in_4bit=True,
+         bnb_4bit_use_double_quant=True,
+         bnb_4bit_quant_type="nf4",
+         bnb_4bit_compute_dtype="bfloat16"
+     )
+
+     # Load base model
+     base_model = AutoModelForCausalLM.from_pretrained(
+         BASE_MODEL,
+         quantization_config=bnb_config,
+         device_map="auto",
+         torch_dtype=torch.bfloat16
+     )
+
+     # Load LoRA adapter
+     model = PeftModel.from_pretrained(base_model, REPO_ID)
+
+     # Load tokenizer
+     tokenizer = AutoTokenizer.from_pretrained(REPO_ID)
+     if tokenizer.pad_token is None:
+         tokenizer.pad_token = tokenizer.eos_token
+
+     return model, tokenizer
+
+ # Initialize model
+ print("🔄 Loading model...")
+ model, tokenizer = load_model()
+ print("✅ Model loaded successfully!")
+
+ def generate_mcq(chapter, topics, difficulty="medium", cognitive_skill="direct_application"):
+     """Generate MCQ using the fine-tuned model"""
+
+     input_text = f"chapter: {chapter}\ntopics: {topics}\nDifficulty: {difficulty}\nCognitive Skill: {cognitive_skill}"
+
+     prompt = f"""### Instruction:
+ Generate a math MCQ similar in style to the provided examples.
+
+ ### Input:
+ {input_text}
+
+ ### Response:
+ """
+
+     inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=512)
+     inputs = {k: v.to(model.device) for k, v in inputs.items()}
+
+     with torch.no_grad():
+         outputs = model.generate(
+             **inputs,
+             max_new_tokens=300,
+             temperature=0.7,
+             do_sample=True,
+             pad_token_id=tokenizer.eos_token_id,
+             eos_token_id=tokenizer.eos_token_id,
+             repetition_penalty=1.1
+         )
+
+     generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
+     # Keep only the text generated after the "### Response:" marker
+     response_start = generated_text.find("### Response:") + len("### Response:")
+     response = generated_text[response_start:].strip()
+
+     return response
+
+ def parse_mcq_response(response):
+     """Parse the model response"""
+     try:
+         question_match = re.search(r'Question:\s*(.*?)(?=\nOptions:|Options:)', response, re.DOTALL)
+         question = question_match.group(1).strip() if question_match else "Question not found"
+
+         options_match = re.search(r'Options:\s*(.*?)(?=\nAnswer:|Answer:)', response, re.DOTALL)
+         if options_match:
+             options_text = options_match.group(1).strip()
+             # Split "(A) ... (B) ..." into individual option strings
+             option_pattern = r'\([A-D]\)\s*([^(]*?)(?=\s*\([A-D]\)|$)'
+             options = []
+             for match in re.finditer(option_pattern, options_text):
+                 option_text = match.group(1).strip()
+                 if option_text:
+                     options.append(option_text)
+         else:
+             options = ["Options not found"]
+
+         answer_match = re.search(r'Answer:\s*([A-D])', response)
+         answer = answer_match.group(1) if answer_match else "Answer not found"
+
+         return {
+             "question": question,
+             "options": options,
+             "correct_answer": answer
+         }
+     except Exception as e:
+         return {
+             "question": "Parsing error",
+             "options": ["Error parsing options"],
+             "correct_answer": "N/A",
+             "error": str(e)
+         }
+
+ # The uploaded file left the web-interface functions as placeholders
+ # ("copy from above"). The bodies below are a minimal reconstruction so the
+ # script runs end to end; labels and choice lists beyond the defaults used
+ # elsewhere in this file are assumptions.
+ def generate_mcq_web(chapter, topics_text, difficulty, cognitive_skill, num_questions=1):
+     """Generate one or more MCQs and format them for display."""
+     results = []
+     for i in range(int(num_questions)):
+         raw = generate_mcq(chapter, topics_text, difficulty, cognitive_skill)
+         mcq = parse_mcq_response(raw)
+         options = "\n".join(
+             f"({chr(65 + j)}) {opt}" for j, opt in enumerate(mcq["options"])
+         )
+         results.append(
+             f"Question {i + 1}: {mcq['question']}\n{options}\nAnswer: {mcq['correct_answer']}"
+         )
+     return "\n\n".join(results)
+
+ def create_gradio_interface():
+     """Build the Gradio interface (minimal reconstruction)."""
+     return gr.Interface(
+         fn=generate_mcq_web,
+         inputs=[
+             gr.Textbox(label="Chapter"),
+             gr.Textbox(label="Topics"),
+             gr.Dropdown(["easy", "medium", "hard"], value="medium", label="Difficulty"),
+             gr.Dropdown(
+                 ["direct_application", "conceptual", "multi_step"],
+                 value="direct_application",
+                 label="Cognitive Skill",
+             ),
+             gr.Slider(1, 5, value=1, step=1, label="Number of Questions"),
+         ],
+         outputs=gr.Textbox(label="Generated MCQs", lines=12),
+         title="Math MCQ Generator",
+     )
+
+ # Create and launch interface
+ interface = create_gradio_interface()
+
+ if __name__ == "__main__":
+     interface.launch()