File size: 3,889 Bytes
60836c7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
import os
import pandas as pd


# Load DeepSeek-Math 7B Instruct model and tokenizer once, at startup.
model_name = "deepseek-ai/deepseek-math-7b-instruct"


def _load_model_and_tokenizer(name):
    """Load the causal LM and its tokenizer for the Hugging Face repo *name*.

    Returns:
        ``(model, tokenizer)`` when every step succeeds, otherwise
        ``(None, None)``. The failure reason is printed rather than raised so
        the app can still start and report the problem to users.
    """
    try:
        tok = AutoTokenizer.from_pretrained(name)
        # bfloat16 weights; device_map="auto" leaves weight placement to the library.
        mdl = AutoModelForCausalLM.from_pretrained(
            name, torch_dtype=torch.bfloat16, device_map="auto"
        )
        # Use the EOS id as the pad id during generation.
        mdl.generation_config.pad_token_id = tok.eos_token_id
        print("DeepSeek-Math-7B-Instruct loaded successfully!")
    except Exception as e:
        print(f"Error loading model: {e}")
        return None, None
    return mdl, tok


model, tokenizer = _load_model_and_tokenizer(model_name)

# Load training data from math_meme_training.txt
training_file = "math_meme_training.txt"
few_shot_prompt = "Below are examples of correcting incorrect math memes with step-by-step reasoning:\n\n"
if os.path.exists(training_file):
    with open(training_file, 'r') as f:
        lines = f.readlines()
        # Take up to 3 examples to keep prompt manageable
        examples = []
        for i in range(0, min(6, len(lines)), 2):  # Step by 2 for Prompt/Response pairs
            if lines[i].startswith("Prompt: ") and i+1 < len(lines) and lines[i+1].startswith("Response: "):
                prompt = lines[i].replace("Prompt: ", "").strip()
                response = lines[i+1].replace("Response: ", "").strip()
                examples.append(f"{prompt}\nResponse: {response}")
        few_shot_prompt += "\n".join(examples[:3]) + "\n\nNow, correct the following incorrect math meme:\n"
else:
    few_shot_prompt += "No training data found. Using default behavior.\n\nNow, correct the following incorrect math meme:\n"
    print("Warning: math_meme_training.txt not found. Using minimal prompt.")

# Function to generate correction
def generate_deepseek_correction(user_input):
    """Generate a step-by-step correction for an incorrect math meme.

    Args:
        user_input: Text from the Gradio textbox, expected in the form
            ``'expression = wrong_answer'`` (e.g. ``'8 ÷ 2(2+2) = 1'``).

    Returns:
        str: The model's correction followed by a fixed "Funny Error Rating"
        line, or a user-facing error message when the model is not loaded,
        the input is malformed, or generation raises an exception.
    """
    if model is None or tokenizer is None:
        return "Error: Model not loaded. Please check the Space logs and ensure a GPU is enabled."

    # Validate input format: exactly one '=' with non-blank text on both sides.
    # `or ""` treats a missing value (None) as empty instead of raising TypeError.
    sides = (user_input or "").split("=")
    if len(sides) != 2 or not all(side.strip() for side in sides):
        return "Invalid input. Please enter in the format 'expression = wrong_answer' (e.g., '8 ÷ 2(2+2) = 1')."

    # Construct prompt with training data
    meme_prompt = f"Correct this incorrect math meme: '{user_input}'\nPlease reason step by step and provide the correct answer with an explanation."
    full_prompt = few_shot_prompt + meme_prompt

    try:
        inputs = tokenizer(full_prompt, return_tensors="pt").to(model.device)
        outputs = model.generate(
            **inputs,
            max_new_tokens=150,
            temperature=0.7,
            do_sample=True,
            eos_token_id=tokenizer.eos_token_id,
        )
        # generate() returns the prompt tokens followed by the continuation.
        # Decode only the continuation: slicing the decoded text by
        # len(full_prompt) is wrong whenever decode() does not reproduce the
        # prompt character-for-character (e.g. whitespace clean-up), which
        # leaks prompt text into, or truncates the start of, the correction.
        prompt_length = inputs["input_ids"].shape[-1]
        correction = tokenizer.decode(outputs[0][prompt_length:], skip_special_tokens=True).strip()

        if not correction:
            correction = "DeepSeek-Math failed to generate a meaningful correction."

        # Add funny error rating
        error_rating = "Funny Error Rating: '90% fixing memes, 10% wondering why fractions are so hard'"
        return f"{correction}\n\n{error_rating}"
    except Exception as e:
        # UI boundary: show the failure in the output box instead of crashing the request.
        return f"Error processing input: {e}"

# Gradio UI: one textbox in, one textbox out, with clickable example memes.
meme_input = gr.Textbox(
    label="Enter an Incorrect Math Meme",
    placeholder="e.g., '7 + 2 × 3 = 27' or '1/2 + 1/2 = 0.25'"
)
correction_output = gr.Textbox(label="Corrected Answer and Explanation")
example_memes = [
    ["7 + 2 × 3 = 27"],
    ["4 ÷ 2(1+1) = 1"],
    ["1/2 + 1/2 = 0.25"],
]

interface = gr.Interface(
    fn=generate_deepseek_correction,
    inputs=meme_input,
    outputs=correction_output,
    title="Math Meme Repair with DeepSeek-Math",
    description="Input an incorrect math meme (e.g., '8 ÷ 2(2+2) = 1') and get the correct answer with an explanation!",
    examples=example_memes,
)

# Start the server at import time, bound to all network interfaces on port 7860
# (not just localhost).
interface.launch(server_name="0.0.0.0", server_port=7860)