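"""Checkpoint upload monitor for Gradience T1 3B training runs.

Watches the working directory for new `checkpoint-<step>` folders produced by
the trainer, prepares each one (patches adapter_config.json, regenerates the
loss plot and README), uploads it to the Hugging Face Hub, and removes the
local copy after a successful upload.
"""
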
import os
import re
import numpy as np
import time
import shutil
import json
import matplotlib.pyplot as plt
from huggingface_hub import login, create_repo, upload_folder, HfFolder
from pathlib import Path # Using pathlib for easier path manipulation

# --- Configuration Constants ---
# Model and Repo Details
BASE_MODEL_NAME = "Qwen/Qwen2.5-3B-Instruct"
TARGET_REPO_NAME = "Tesslate/Gradience-T1-3B-Checkpoint" # Specify your target repo

# Training Parameters (Update if necessary)
TOTAL_STEPS = 9838 # Total expected steps for progress calculation

# File Names
README_FILENAME = "README.md"
ADAPTER_CONFIG_FILENAME = "adapter_config.json"
TRAINER_STATE_FILENAME = "trainer_state.json"
LOSS_PLOT_FILENAME = "loss.png"

# Plotting Configuration
LOSS_SMOOTHING_WINDOW = 40

# Monitoring Configuration
CHECKPOINT_DIR_PATTERN = re.compile(r"^checkpoint-(\d+)$")
POLL_INTERVAL_SECONDS = 30
PRE_UPLOAD_DELAY_SECONDS = 10 # Delay after finding checkpoint before processing

# --- Global State ---
# Set to track uploaded checkpoints (using Path objects for consistency)
uploaded_checkpoints = set()

# --- Helper Functions ---

def get_huggingface_token():
    """Retrieves the Hugging Face token from environment variable or login cache."""
    token = os.getenv('HUGGINGFACE_TOKEN')
    if token:
        print("Using Hugging Face token from HUGGINGFACE_TOKEN environment variable.")
        return token
    token = HfFolder.get_token()
    if token:
        print("Using Hugging Face token from saved credentials.")
        return token
    raise ValueError("Hugging Face token not found. Set HUGGINGFACE_TOKEN environment variable or login using `huggingface-cli login`.")
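
# Usage: either `export HUGGINGFACE_TOKEN=hf_...` before launching this script,
# or run `huggingface-cli login` once so the token is cached locally.
# Note: HfFolder.get_token() is deprecated in recent huggingface_hub releases;
# on newer versions, huggingface_hub.get_token() is the drop-in replacement.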

def update_adapter_config(config_path: Path, base_model_name: str):
    """
    Reads adapter_config.json, updates the base_model_name_or_path field,
    and saves it back.

    Args:
        config_path (Path): Path to the adapter_config.json file.
        base_model_name (str): The base model name to set.
    """
    try:
        with open(config_path, 'r') as file:
            config = json.load(file)

        config['base_model_name_or_path'] = base_model_name

        with open(config_path, 'w') as file:
            json.dump(config, file, indent=2)
        print(f"Updated 'base_model_name_or_path' in {config_path}")

    except FileNotFoundError:
        print(f"Error: Adapter config file not found at {config_path}")
    except json.JSONDecodeError:
        print(f"Error: Could not decode JSON from {config_path}. Is it valid?")
    except KeyError:
        print(f"Error: 'base_model_name_or_path' key not found in {config_path}")
    except Exception as e:
        print(f"An unexpected error occurred while updating {config_path}: {e}")

def generate_readme_content(checkpoint_number: int, total_steps: int, base_model: str, loss_plot_filename: str) -> str:
    """Generates the README content with updated progress."""
    if total_steps <= 0:
        progress_percentage = 0.0
    else:
        progress_percentage = min(100.0, (checkpoint_number / total_steps) * 100) # Ensure percentage doesn't exceed 100

    progress_width = f"{progress_percentage:.2f}%"
    progress_text = f"Progress: {checkpoint_number} out of {total_steps} steps"

    # Using an f-string for the template makes insertions cleaner. Model card
    # renderers typically strip <head>/<style> blocks, so the progress bar
    # below relies on inline styles only.
    readme_template = f"""
---
base_model: {base_model}
library_name: peft
---
# Gradience T1 3B (Step {checkpoint_number} Checkpoint)

> [!NOTE]
> Training in progress...

<div style="width: 100%; background-color: #e0e0e0; border-radius: 25px; overflow: hidden; margin: 20px 0;">
  <div style="height: 30px; width: {progress_width}; background-color: #76c7c0; text-align: center; line-height: 30px; color: white; border-radius: 25px 0 0 25px;">
    <!-- {progress_percentage:.2f}% -->
  </div>
</div>
<p style="font-family: Arial, sans-serif; font-size: 16px;">{progress_text}</p>

## Training Loss
![Training Loss Chart]({loss_plot_filename})
    """.strip()
    return readme_template

def plot_loss_from_json(
    json_file_path: Path,
    output_image_path: Path,
    smooth_steps: int = LOSS_SMOOTHING_WINDOW
):
    """
    Reads training log data from a JSON file (trainer_state.json),
    extracts loss and step values, plots the original loss and a smoothed
    version (running average), and saves the plot to a PNG file.

    Args:
        json_file_path (Path): Path to the input trainer_state.json file.
        output_image_path (Path): Path where the output PNG plot will be saved.
        smooth_steps (int): Window size for running average smoothing.
                            If <= 0, no smoothing is applied.
    """
    print(f"Reading training log data from: {json_file_path}")
    print(f"Smoothing window: {smooth_steps if smooth_steps > 0 else 'Disabled'}")

    try:
        with open(json_file_path, 'r') as f:
            data = json.load(f)
    except FileNotFoundError:
        print(f"Error: JSON file not found at {json_file_path}")
        return
    except json.JSONDecodeError:
        print(f"Error: Could not decode JSON from {json_file_path}. Is it valid?")
        return
    except Exception as e:
        print(f"An unexpected error occurred while reading {json_file_path}: {e}")
        return

    log_history = data.get("log_history") # Use .get for safer access
    if not isinstance(log_history, list):
        print(f"Error: 'log_history' key not found or not a list in {json_file_path}")
        return

    steps, losses = [], []
    for entry in log_history:
        if isinstance(entry, dict) and "step" in entry and "loss" in entry and entry["loss"] is not None:
            try:
                steps.append(int(entry["step"]))
                losses.append(float(entry["loss"]))
            except (ValueError, TypeError):
                print(f"Warning: Skipping entry with non-numeric step/loss: {entry}")
        # else: # Optionally log skipped entries
            # print(f"Info: Skipping log entry missing 'step'/'loss' or loss is null: {entry}")

    if not steps:
        print("No valid step/loss data found in the log history to plot.")
        return

    # Convert to numpy arrays and sort by step (good practice)
    steps = np.array(steps)
    losses = np.array(losses)
    sorted_indices = np.argsort(steps)
    steps = steps[sorted_indices]
    losses = losses[sorted_indices]

    print(f"Found {len(steps)} valid data points to plot.")

    # Calculate Running Average
    smoothed_losses = None
    smoothed_steps = None
    apply_smoothing = smooth_steps > 0 and len(losses) >= smooth_steps

    if apply_smoothing:
        try:
            weights = np.ones(smooth_steps) / smooth_steps
            smoothed_losses = np.convolve(losses, weights, mode='valid')
            smoothed_steps = steps[smooth_steps - 1:] # Steps corresponding to the smoothed values
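            # Illustrative check: with losses [1, 2, 3, 4] and smooth_steps=2,
            # np.convolve(..., mode='valid') gives [1.5, 2.5, 3.5]; pairing these
            # with steps[1:] plots each average at the last step it covers.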
            print(f"Calculated smoothed loss over {len(smoothed_steps)} points.")
        except Exception as e:
            print(f"Warning: Could not calculate smoothed loss. Error: {e}")
            apply_smoothing = False # Disable if calculation fails
    elif smooth_steps > 0:
        print(f"Warning: Not enough data points ({len(losses)}) for smoothing window ({smooth_steps}). Skipping smoothing.")

    # Plotting
    plt.style.use('seaborn-v0_8-darkgrid') # Use a nice style
    plt.figure(figsize=(10, 6)) # Standard figure size

    plt.plot(steps, losses, linestyle='-', color='skyblue', alpha=0.5, label='Original Loss')

    if apply_smoothing and smoothed_losses is not None and smoothed_steps is not None:
        plt.plot(smoothed_steps, smoothed_losses, linestyle='-', color='dodgerblue', alpha=1.0, linewidth=1.5,
                 label=f'Smoothed Loss ({smooth_steps}-step avg)')

    plt.xlabel("Step")
    plt.ylabel("Loss")
    plt.title("Training Loss Progression")
    plt.legend()
    plt.tight_layout() # Adjust layout

    # Saving
    try:
        plt.savefig(output_image_path, format='png', dpi=150)
        print(f"Plot successfully saved to: {output_image_path}")
    except Exception as e:
        print(f"Error saving plot to {output_image_path}: {e}")
    finally:
        plt.close() # Ensure figure is closed to free memory

def prepare_checkpoint_folder(checkpoint_path: Path, checkpoint_number: int):
    """
    Updates README.md, adapter_config.json, and generates the loss plot
    within the specified checkpoint folder.
    """
    print(f"Preparing checkpoint folder: {checkpoint_path}")

    # 1. Update adapter config
    adapter_config_path = checkpoint_path / ADAPTER_CONFIG_FILENAME
    update_adapter_config(adapter_config_path, BASE_MODEL_NAME)

    # 2. Generate loss plot
    trainer_state_path = checkpoint_path / TRAINER_STATE_FILENAME
    loss_plot_path = checkpoint_path / LOSS_PLOT_FILENAME
    plot_loss_from_json(trainer_state_path, loss_plot_path, smooth_steps=LOSS_SMOOTHING_WINDOW)

    # 3. Generate and write README
    readme_path = checkpoint_path / README_FILENAME
    readme_content = generate_readme_content(checkpoint_number, TOTAL_STEPS, BASE_MODEL_NAME, LOSS_PLOT_FILENAME)
    try:
        with open(readme_path, 'w', encoding='utf-8') as file:
            file.write(readme_content)
        print(f"Generated and saved {README_FILENAME} in {checkpoint_path}")
    except Exception as e:
        print(f"Error writing README file to {readme_path}: {e}")

# --- Core Logic ---

def find_new_checkpoint(current_dir: Path = Path('.')) -> tuple[int, Path] | None:
    """
    Finds the checkpoint folder in the specified directory with the highest
    step number that has not been previously uploaded.

    Args:
        current_dir (Path): The directory to scan for checkpoints.

    Returns:
        tuple[int, Path] | None: A tuple containing the (checkpoint_number, folder_path)
                                 or None if no new checkpoint is found.
    """
    new_checkpoints = []
    try:
        for item in current_dir.iterdir():
            if item.is_dir():
                match = CHECKPOINT_DIR_PATTERN.match(item.name)
                # Check if it matches the pattern AND has not been uploaded
                if match and item not in uploaded_checkpoints:
                    checkpoint_number = int(match.group(1))
                    new_checkpoints.append((checkpoint_number, item))
    except FileNotFoundError:
        print(f"Error: Directory not found: {current_dir}")
        return None
    except Exception as e:
        print(f"Error scanning directory {current_dir}: {e}")
        return None

    if new_checkpoints:
        new_checkpoints.sort(key=lambda x: x[0], reverse=True) # Sort by step number, highest first
        return new_checkpoints[0] # Return the one with the highest step number
    return None
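
# e.g. with untracked folders checkpoint-500 and checkpoint-1000 on disk,
# find_new_checkpoint() returns (1000, Path('checkpoint-1000')).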

def upload_checkpoint_to_hf(folder_path: Path, checkpoint_number: int, repo_id: str):
    """
    Uploads the prepared checkpoint folder to Hugging Face Hub and deletes
    the folder locally upon successful upload.

    Args:
        folder_path (Path): Path to the local checkpoint folder.
        checkpoint_number (int): The checkpoint step number.
        repo_id (str): The Hugging Face repository ID (e.g., "username/repo-name").
    """
    print(f"\nAttempting to upload {folder_path.name} to Hugging Face repository: {repo_id}...")

    try:
        # Ensure repository exists
        create_repo(repo_id, repo_type="model", exist_ok=True)
        print(f"Repository {repo_id} exists or was created.")

        # Upload the folder contents
        upload_folder(
            folder_path=str(folder_path), # upload_folder expects string path
            repo_id=repo_id,
            commit_message=f"Upload checkpoint {checkpoint_number}",
            repo_type="model" # Explicitly set repo type
        )
        print(f"Successfully uploaded contents of {folder_path.name} to {repo_id}.")

        # Delete the local folder ONLY after successful upload
        try:
            shutil.rmtree(folder_path)
            print(f"Successfully deleted local folder: {folder_path}")
            return True # Indicate success
        except OSError as e:
            print(f"Error deleting local folder {folder_path}: {e}. Please delete manually.")
            return True # Upload succeeded, but deletion failed

    except Exception as e:
        print(f"ERROR during Hugging Face upload for {folder_path.name}: {e}")
        print("Upload failed. Local folder will not be deleted.")
        return False # Indicate failure

# --- Main Execution ---

def main():
    """
    Main loop to monitor for new checkpoints, prepare them, upload them to
    Hugging Face Hub, and clean up locally.
    """
    try:
        hf_token = get_huggingface_token()
        login(hf_token)
        print("\nSuccessfully logged into Hugging Face Hub.")
    except ValueError as e:
        print(f"Error: {e}")
        return # Exit if login fails
    except Exception as e:
        print(f"An unexpected error occurred during Hugging Face login: {e}")
        return

    print("\nStarting checkpoint monitor...")
    print(f"Will check for new checkpoints matching '{CHECKPOINT_DIR_PATTERN.pattern}' every {POLL_INTERVAL_SECONDS} seconds.")
    print(f"Target repository: {TARGET_REPO_NAME}")
    print(f"Found checkpoints will be tracked (not re-uploaded): {uploaded_checkpoints or 'None yet'}")
    print("-" * 30)

    while True:
        new_checkpoint_info = find_new_checkpoint()

        if new_checkpoint_info:
            checkpoint_number, folder_path = new_checkpoint_info
            print(f"\nFound new checkpoint: {folder_path.name} (Step {checkpoint_number})")

            # Optional delay: wait a bit in case files are still being written
            print(f"Waiting {PRE_UPLOAD_DELAY_SECONDS} seconds before processing...")
            time.sleep(PRE_UPLOAD_DELAY_SECONDS)

            # Prepare the folder (update README, config, generate plot)
            prepare_checkpoint_folder(folder_path, checkpoint_number)

            # Attempt upload and deletion
            upload_successful = upload_checkpoint_to_hf(
                folder_path=folder_path,
                checkpoint_number=checkpoint_number,
                repo_id=TARGET_REPO_NAME
            )

            if upload_successful:
                # Track as processed only after a successful upload (even if local deletion failed)
                uploaded_checkpoints.add(folder_path)
                print(f"Added {folder_path.name} to the set of processed checkpoints.")

            print("-" * 30) # Separator after processing a checkpoint

        else:
            # Use \r + flush for an inline status update when no checkpoint is found
            print(f"\rNo new checkpoints found. Checking again in {POLL_INTERVAL_SECONDS} seconds... ", end="", flush=True)

        # Wait before the next check
        time.sleep(POLL_INTERVAL_SECONDS)

if __name__ == "__main__":
    try:
        main()
    except KeyboardInterrupt:
        print("\nMonitoring stopped by user.")