Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
|
@@ -7,10 +7,9 @@ import torch
|
|
| 7 |
from pytorch_lightning import LightningModule
|
| 8 |
from safetensors.torch import save_file
|
| 9 |
from torch import nn
|
| 10 |
-
from modelalign import BERTAlignModel
|
| 11 |
|
| 12 |
import gradio as gr
|
| 13 |
-
|
| 14 |
|
| 15 |
# ===========================
|
| 16 |
# Utility Functions
|
|
@@ -48,7 +47,11 @@ def load_checkpoint(model: LightningModule, checkpoint_path: str, device: str =
|
|
| 48 |
try:
|
| 49 |
# Load the checkpoint; adjust map_location based on device
|
| 50 |
checkpoint = torch.load(checkpoint_path, map_location=device)
|
| 51 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 52 |
return True, "Checkpoint loaded successfully."
|
| 53 |
except Exception as e:
|
| 54 |
return False, f"Failed to load checkpoint: {str(e)}"
|
|
@@ -80,31 +83,29 @@ def convert_checkpoint_to_safetensors(checkpoint_url: str, model_name: str):
|
|
| 80 |
# Step 1: Download the checkpoint
|
| 81 |
success, message = download_checkpoint(checkpoint_url, checkpoint_path)
|
| 82 |
if not success:
|
| 83 |
-
return
|
| 84 |
|
| 85 |
# Step 2: Initialize the model
|
| 86 |
success, model_or_msg = initialize_model(model_name)
|
| 87 |
if not success:
|
| 88 |
-
return
|
| 89 |
model = model_or_msg
|
| 90 |
|
| 91 |
# Step 3: Load the checkpoint
|
| 92 |
success, message = load_checkpoint(model, checkpoint_path)
|
| 93 |
if not success:
|
| 94 |
-
return
|
| 95 |
|
| 96 |
# Step 4: Convert to SafeTensors
|
| 97 |
success, message = convert_to_safetensors(model, safetensors_path)
|
| 98 |
if not success:
|
| 99 |
-
return
|
| 100 |
|
| 101 |
# Step 5: Read the safetensors file for download
|
| 102 |
try:
|
| 103 |
-
|
| 104 |
-
safetensors_bytes = f.read()
|
| 105 |
-
return safetensors_bytes, "Conversion successful! Download your SafeTensors file below."
|
| 106 |
except Exception as e:
|
| 107 |
-
return
|
| 108 |
|
| 109 |
# ===========================
|
| 110 |
# Gradio Interface Setup
|
|
@@ -125,12 +126,20 @@ Convert your PyTorch Lightning `.ckpt` checkpoints to the secure `safetensors` f
|
|
| 125 |
iface = gr.Interface(
|
| 126 |
fn=convert_checkpoint_to_safetensors,
|
| 127 |
inputs=[
|
| 128 |
-
gr.
|
| 129 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 130 |
],
|
| 131 |
outputs=[
|
| 132 |
-
gr.
|
| 133 |
-
gr.
|
| 134 |
],
|
| 135 |
title=title,
|
| 136 |
description=description,
|
|
|
|
| 7 |
from pytorch_lightning import LightningModule
|
| 8 |
from safetensors.torch import save_file
|
| 9 |
from torch import nn
|
|
|
|
| 10 |
|
| 11 |
import gradio as gr
|
| 12 |
+
from modelalign import BERTAlignModel
|
| 13 |
|
| 14 |
# ===========================
|
| 15 |
# Utility Functions
|
|
|
|
| 47 |
try:
|
| 48 |
# Load the checkpoint; adjust map_location based on device
|
| 49 |
checkpoint = torch.load(checkpoint_path, map_location=device)
|
| 50 |
+
# Assuming the checkpoint has a 'state_dict' key
|
| 51 |
+
if 'state_dict' in checkpoint:
|
| 52 |
+
model.load_state_dict(checkpoint['state_dict'], strict=False)
|
| 53 |
+
else:
|
| 54 |
+
model.load_state_dict(checkpoint, strict=False)
|
| 55 |
return True, "Checkpoint loaded successfully."
|
| 56 |
except Exception as e:
|
| 57 |
return False, f"Failed to load checkpoint: {str(e)}"
|
|
|
|
| 83 |
# Step 1: Download the checkpoint
|
| 84 |
success, message = download_checkpoint(checkpoint_url, checkpoint_path)
|
| 85 |
if not success:
|
| 86 |
+
return None, message
|
| 87 |
|
| 88 |
# Step 2: Initialize the model
|
| 89 |
success, model_or_msg = initialize_model(model_name)
|
| 90 |
if not success:
|
| 91 |
+
return None, model_or_msg
|
| 92 |
model = model_or_msg
|
| 93 |
|
| 94 |
# Step 3: Load the checkpoint
|
| 95 |
success, message = load_checkpoint(model, checkpoint_path)
|
| 96 |
if not success:
|
| 97 |
+
return None, message
|
| 98 |
|
| 99 |
# Step 4: Convert to SafeTensors
|
| 100 |
success, message = convert_to_safetensors(model, safetensors_path)
|
| 101 |
if not success:
|
| 102 |
+
return None, message
|
| 103 |
|
| 104 |
# Step 5: Read the safetensors file for download
|
| 105 |
try:
|
| 106 |
+
return safetensors_path, "Conversion successful! Download your SafeTensors file below."
|
|
|
|
|
|
|
| 107 |
except Exception as e:
|
| 108 |
+
return None, f"Failed to prepare download: {str(e)}"
|
| 109 |
|
| 110 |
# ===========================
|
| 111 |
# Gradio Interface Setup
|
|
|
|
| 126 |
iface = gr.Interface(
|
| 127 |
fn=convert_checkpoint_to_safetensors,
|
| 128 |
inputs=[
|
| 129 |
+
gr.Textbox(
|
| 130 |
+
lines=2,
|
| 131 |
+
placeholder="Enter the checkpoint URL here...",
|
| 132 |
+
label="Checkpoint URL"
|
| 133 |
+
),
|
| 134 |
+
gr.Textbox(
|
| 135 |
+
lines=1,
|
| 136 |
+
placeholder="e.g., roberta-base",
|
| 137 |
+
label="Model Name"
|
| 138 |
+
)
|
| 139 |
],
|
| 140 |
outputs=[
|
| 141 |
+
gr.File(label="Download SafeTensors File"),
|
| 142 |
+
gr.Textbox(label="Status")
|
| 143 |
],
|
| 144 |
title=title,
|
| 145 |
description=description,
|