File size: 3,500 Bytes
853c001
 
 
 
 
2e7af1a
 
853c001
 
 
 
 
2e7af1a
 
 
059a058
2e7af1a
 
 
 
 
 
 
 
 
 
853c001
 
 
 
 
2e7af1a
 
853c001
059a058
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
853c001
2e7af1a
853c001
059a058
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
853c001
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
# app.py
import os

# Keras selects its backend once, when `keras` is first imported, so these
# variables must be set BEFORE the keras / keras_hub imports below. Setting
# them after the imports leaves Keras on its default backend.
os.environ["KERAS_BACKEND"] = "jax"
# Fraction of GPU memory JAX preallocates when its backend initializes
# (JAX's default is 0.75); 1.00 reserves all of it for this model.
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "1.00"

import gradio as gr  # noqa: E402  (must follow the environment setup above)
import keras  # noqa: E402
import keras_hub  # noqa: E402
from huggingface_hub import hf_hub_download  # noqa: E402

# --- 1. LOAD THE MERGED MODEL FROM THE HUB ---

# Hub repository holding the merged model, and the name of the saved file in it.
repo_id = "Tarive/lora_research_abstracts"
model_filename = "model.keras"  # file name used when the model was uploaded

print(f"Downloading model file '{model_filename}' from Hub repo: {repo_id}")
# hf_hub_download fetches this single file into the local Hugging Face cache
# and returns the local path of the cached copy.
model_path = hf_hub_download(repo_id=repo_id, filename=model_filename)

print(f"Loading merged model from local path: {model_path}")
# Load from the concrete .keras file path (not a directory).
gemma_lm = keras.models.load_model(model_path)

# Attach a top-k sampler (k=5); gemma_lm.generate() below samples with it.
gemma_lm.compile(sampler=keras_hub.samplers.TopKSampler(k=5))
print("Model loaded and compiled successfully.")


# --- 2. DEFINE THE INFERENCE FUNCTION ---
def revise_abstract(draft_abstract, grant_type):
    """Rewrite a draft grant abstract with the fine-tuned model.

    Args:
        draft_abstract: Draft abstract text from the input textbox.
        grant_type: Activity code chosen in the dropdown (e.g. "R01"); falsy
            when nothing has been selected.

    Returns:
        The model's revision (the text generated after the prompt's final
        "Revised Abstract:" marker, stripped), the whole stripped output if
        that marker cannot be located, or an error message when the draft is
        missing/blank or no grant type was given.
    """
    # Whitespace-only drafts are rejected too; they would otherwise send an
    # empty "Input Draft" section to the model.
    if not draft_abstract or not draft_abstract.strip() or not grant_type:
        return "Error: Please provide both a draft abstract and a grant type."

    marker = "Revised Abstract:"
    template = (
        "Instruction:\n"
        "You are an expert grant writer. Rewrite the following draft abstract to be more impactful and clear, "
        "following the specific conventions of a {activity_code} grant. Ensure the most compelling claims are front-loaded.\n\n"
        "Input Draft:\n"
        "{unoptimized_abstract}\n\n"
        "Revised Abstract:"
    )
    prompt = template.format(unoptimized_abstract=draft_abstract, activity_code=grant_type)
    # max_length bounds the total generated sequence (prompt + completion).
    output = gemma_lm.generate(prompt, max_length=1024)

    # The generated text starts with the prompt. Skip every marker occurrence
    # that belongs to the prompt itself: the template's own trailing marker
    # plus any the user pasted into the draft. The answer is the segment after
    # those; a further marker emitted by the model ends it.
    prompt_markers = prompt.count(marker)
    segments = output.split(marker)
    if len(segments) > prompt_markers:
        return segments[prompt_markers].strip()
    return output.strip()

# --- 3. BUILD THE GRADIO INTERFACE ---
# Gradio lays components out in the order they are created inside the Blocks
# context: title, description, a row with the draft textbox beside the
# grant-type dropdown, the submit button, the read-only result box, and
# finally the clickable examples.
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    # Plain Python data (not components), so defining it here adds nothing to
    # the page.
    activity_codes = ["R01", "R21", "F32", "T32", "P30", "R41", "R43", "R44", "K99"]
    sample_rows = [
        [
            "SUMMARY \nA pressing concern exists regarding lead poisoning...This study aimed to optimize and validate a dried blood spot collection device...",
            "R41",
        ],
        [
            "This project is about figuring out how macrophages and S. flexneri interact...Our study will look at this in vitro and in vivo...",
            "R21",
        ],
    ]

    gr.Markdown("# Grant Abstract Revision Tool (Fine-Tuned on Gemma)")
    gr.Markdown("Enter a draft abstract and select its grant type. The model will rewrite it to be more impactful, based on patterns from successfully funded NIH grants.")

    with gr.Row():
        abstract_box = gr.Textbox(
            label="Input Draft Abstract",
            lines=15,
            placeholder="Paste your draft abstract here...",
        )
        activity_code_menu = gr.Dropdown(
            activity_codes,
            label="Grant Type (Activity Code)",
            info="Select the grant mechanism you are targeting.",
        )

    revise_button = gr.Button("Revise Abstract", variant="primary")
    result_box = gr.Textbox(
        label="Model's Revised Abstract",
        lines=15,
        interactive=False,
    )

    # Clicking the button feeds (draft, activity code) to revise_abstract and
    # shows its return value in the result box.
    revise_button.click(
        fn=revise_abstract,
        inputs=[abstract_box, activity_code_menu],
        outputs=result_box,
    )

    # Each example row pre-fills the same two inputs.
    gr.Examples(
        examples=sample_rows,
        inputs=[abstract_box, activity_code_menu],
    )

# Start the web server only when this file is executed as a script
# (e.g. `python app.py`); importing the module no longer blocks on a server.
if __name__ == "__main__":
    print("Launching Gradio app...")
    demo.launch()