# phi2-lora-gsm8k / app.py

import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import PeftModel
import torch

# Load the base model, tokenizer, and LoRA adapter
base_model = AutoModelForCausalLM.from_pretrained(
    "microsoft/phi-2", torch_dtype="auto", device_map="auto"
)
tokenizer = AutoTokenizer.from_pretrained("microsoft/phi-2")
model = PeftModel.from_pretrained(base_model, "darshjoshi16/phi2-lora-math")
tokenizer.pad_token = tokenizer.eos_token  # phi-2 ships without a pad token; reuse EOS
model.eval()
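
# Optional: merge the LoRA weights into the base model for slightly faster
# inference. A sketch using PEFT's merge_and_unload(); left commented out,
# since keeping the adapter separate also works and preserves the base weights.
# model = model.merge_and_unload()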

def solve_math_question(question):
    # Format the problem in the Q/A style the adapter was trained on
    prompt = f"Q: {question.strip()}\nA:"
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=100,
            pad_token_id=tokenizer.eos_token_id,  # avoid the missing-pad-token warning
        )
    return tokenizer.decode(outputs[0], skip_special_tokens=True)
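
# Quick sanity check outside Gradio (hypothetical example question; uncomment to try):
# print(solve_math_question(
#     "Natalia sold 48 clips in April and half as many in May. "
#     "How many clips did she sell altogether?"
# ))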

# Gradio app
demo = gr.Interface(
    fn=solve_math_question,
    inputs=gr.Textbox(label="Enter Math Word Problem"),
    outputs=gr.Textbox(label="Model's Answer"),
    title="Phi-2 LoRA on GSM8K",
    description="Phi-2 fine-tuned with LoRA on GSM8K to solve math word problems.",
)
demo.launch()