# Page header captured with the file (not code): huzaifa113's picture
# Update app.py
# c652a26 verified
import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
import os
import pandas as pd
# Load the DeepSeek-Math 7B Instruct tokenizer and model once at start-up.
# On any failure both globals are set to None so the request handler can
# report the problem instead of crashing.
model_name = "deepseek-ai/deepseek-math-7b-instruct"
try:
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype=torch.bfloat16,
        device_map="auto",
    )
    # Use the end-of-sequence token id as the padding token id for generation.
    model.generation_config.pad_token_id = tokenizer.eos_token_id
    print("DeepSeek-Math-7B-Instruct loaded successfully!")
except Exception as load_error:
    print(f"Error loading model: {load_error}")
    model = tokenizer = None
# Build the few-shot prompt prefix from math_meme_training.txt
training_file = "math_meme_training.txt"
_PROMPT_PREFIX = "Prompt: "
_RESPONSE_PREFIX = "Response: "


def _load_few_shot_examples(path, max_examples=3):
    """Return up to *max_examples* formatted Prompt/Response examples from *path*.

    The file is expected to contain alternating lines starting with
    ``"Prompt: "`` and ``"Response: "``. Blank lines are ignored, and only the
    first ``max_examples`` pairs of non-blank lines are considered so the
    prompt stays short. A pair whose lines do not carry the expected prefixes
    is skipped.

    Args:
        path: Path of the training text file.
        max_examples: Maximum number of Prompt/Response pairs to read.

    Returns:
        A list of strings, each formatted as ``"<prompt>\\nResponse: <response>"``.
    """
    # Read as UTF-8 explicitly: the examples contain symbols such as the
    # division and multiplication signs, which a platform-default codec may
    # not be able to decode.
    with open(path, "r", encoding="utf-8") as f:
        # Dropping blank lines keeps a spacer line between pairs from
        # shifting the Prompt/Response alignment.
        lines = [line.rstrip("\r\n") for line in f if line.strip()]

    lines = lines[: 2 * max_examples]
    examples = []
    for prompt_line, response_line in zip(lines[0::2], lines[1::2]):
        if not prompt_line.startswith(_PROMPT_PREFIX):
            continue
        if not response_line.startswith(_RESPONSE_PREFIX):
            continue
        # Slice off only the leading marker; the example text itself may
        # legitimately contain the words "Prompt: " or "Response: ".
        prompt = prompt_line[len(_PROMPT_PREFIX):].strip()
        response = response_line[len(_RESPONSE_PREFIX):].strip()
        examples.append(f"{prompt}\nResponse: {response}")
    return examples


few_shot_prompt = "Below are examples of correcting incorrect math memes with step-by-step reasoning:\n\n"
if os.path.exists(training_file):
    examples = _load_few_shot_examples(training_file)
    few_shot_prompt += "\n".join(examples) + "\n\nNow, correct the following incorrect math meme:\n"
else:
    few_shot_prompt += "No training data found. Using default behavior.\n\nNow, correct the following incorrect math meme:\n"
    print("Warning: math_meme_training.txt not found. Using minimal prompt.")
# Function to generate correction
def generate_deepseek_correction(user_input):
    """Generate a step-by-step correction for an incorrect math meme.

    Args:
        user_input: Text of the form ``expression = wrong_answer``
            (exactly one ``=`` with non-empty text on both sides).

    Returns:
        The generated correction followed by the "Funny Error Rating" line.
        If the model or tokenizer is not loaded, the input is malformed, or
        generation raises, a human-readable error message is returned
        instead; this function does not raise for those cases.
    """
    if model is None or tokenizer is None:
        return "Error: Model not loaded. Please check the Space logs and ensure a GPU is enabled."

    # Validate input format: a string with exactly one '=' and something on
    # each side of it. Non-string input (e.g. None) is rejected rather than
    # allowed to raise.
    invalid_message = "Invalid input. Please enter in the format 'expression = wrong_answer' (e.g., '8 ÷ 2(2+2) = 1')."
    if not isinstance(user_input, str):
        return invalid_message
    expression, separator, wrong_answer = user_input.partition("=")
    if (
        not separator
        or "=" in wrong_answer
        or not expression.strip()
        or not wrong_answer.strip()
    ):
        return invalid_message

    # Construct prompt with training data
    meme = user_input.strip()
    meme_prompt = f"Correct this incorrect math meme: '{meme}'\nPlease reason step by step and provide the correct answer with an explanation."
    full_prompt = few_shot_prompt + meme_prompt

    try:
        inputs = tokenizer(full_prompt, return_tensors="pt").to(model.device)
        prompt_token_count = inputs["input_ids"].shape[-1]
        outputs = model.generate(
            **inputs,
            max_new_tokens=150,
            temperature=0.7,
            do_sample=True,
            eos_token_id=tokenizer.eos_token_id
        )
        # Decode only the newly generated tokens. Slicing the decoded string
        # by len(full_prompt) is unreliable because decoding the prompt
        # tokens is not guaranteed to reproduce the prompt text exactly.
        new_tokens = outputs[0][prompt_token_count:]
        correction = tokenizer.decode(new_tokens, skip_special_tokens=True).strip()
        if not correction:
            correction = "DeepSeek-Math failed to generate a meaningful correction."
        # Add funny error rating
        error_rating = "Funny Error Rating: '90% fixing memes, 10% wondering why fractions are so hard'"
        return f"{correction}\n\n{error_rating}"
    except Exception as e:
        # UI boundary: surface the failure as text instead of crashing the app.
        return f"Error processing input: {str(e)}"
# Gradio UI: one textbox in, one textbox out, wired to the correction function.
meme_textbox = gr.Textbox(
    label="Enter an Incorrect Math Meme",
    placeholder="e.g., '7 + 2 × 3 = 27' or '1/2 + 1/2 = 0.25'",
)
correction_textbox = gr.Textbox(label="Corrected Answer and Explanation")

# Clickable sample inputs shown with the form.
sample_memes = [
    ["7 + 2 × 3 = 27"],
    ["4 ÷ 2(1+1) = 1"],
    ["1/2 + 1/2 = 0.25"],
]

interface = gr.Interface(
    fn=generate_deepseek_correction,
    inputs=meme_textbox,
    outputs=correction_textbox,
    title="Math Meme Repair with DeepSeek-Math",
    description="Input an incorrect math meme (e.g., '8 ÷ 2(2+2) = 1') and get the correct answer with an explanation!",
    examples=sample_memes,
)

# Start the web server on all network interfaces, port 7860.
interface.launch(server_name="0.0.0.0", server_port=7860)