Spaces:
Running
Running
#!/usr/bin/env python3 | |
""" | |
Script to integrate improved monitoring with HF Datasets into training scripts | |
""" | |
import os | |
import sys | |
import re | |
from pathlib import Path | |
# Code snippets spliced into the target training script.  They are stored
# dedented and re-indented at insertion time so that they line up with the
# statement they are attached to, whatever nesting level that statement is at.
_MONITOR_INIT_SNIPPET = '''
# Initialize monitoring
monitor = None
if config.enable_tracking:
    try:
        monitor = create_monitor_from_config(config, getattr(config, 'experiment_name', None))
        logger.info(f"β Monitoring initialized for experiment: {monitor.experiment_name}")
        logger.info(f"π Dataset repository: {monitor.dataset_repo}")
        # Log configuration
        config_dict = {k: v for k, v in vars(config).items() if not k.startswith('_')}
        monitor.log_configuration(config_dict)
    except Exception as e:
        logger.error(f"Failed to initialize monitoring: {e}")
        logger.warning("Continuing without monitoring...")
'''

_CALLBACK_SNIPPET = '''
# Add monitoring callback if available
if monitor:
    try:
        callback = monitor.create_monitoring_callback()
        trainer.add_callback(callback)
        logger.info("β Monitoring callback added to trainer")
    except Exception as e:
        logger.error(f"Failed to add monitoring callback: {e}")
'''

_SUMMARY_SNIPPET = '''
# Log training summary
if monitor:
    try:
        summary = {
            'final_loss': getattr(trainer, 'final_loss', None),
            'total_steps': getattr(trainer, 'total_steps', None),
            'training_duration': getattr(trainer, 'training_duration', None),
            'model_path': output_path,
            'config_file': config_path
        }
        monitor.log_training_summary(summary)
        logger.info("β Training summary logged")
    except Exception as e:
        logger.error(f"Failed to log training summary: {e}")
'''

_ERROR_SNIPPET = '''
# Log error to monitoring
if monitor:
    try:
        error_summary = {
            'error': str(e),
            'status': 'failed',
            'model_path': output_path,
            'config_file': config_path
        }
        monitor.log_training_summary(error_summary)
    except Exception as log_error:
        logger.error(f"Failed to log error to monitoring: {log_error}")
'''

_CLEANUP_SNIPPET = '''
finally:
    # Close monitoring
    if monitor:
        try:
            monitor.close()
            logger.info("β Monitoring session closed")
        except Exception as e:
            logger.error(f"Failed to close monitoring: {e}")
'''

_INDENT_RE = re.compile(r'[ \t]*')


def _line_start(text: str, pos: int) -> int:
    """Return the offset of the first character of the line containing *pos*."""
    return text.rfind('\n', 0, pos) + 1


def _line_indent(text: str, pos: int) -> str:
    """Return the leading spaces/tabs of the line containing *pos*."""
    return _INDENT_RE.match(text, _line_start(text, pos)).group()


def _indent_snippet(snippet: str, indent: str) -> str:
    """Prefix every non-blank line of *snippet* with *indent*."""
    return ''.join(
        indent + line if line.strip() else line
        for line in snippet.splitlines(keepends=True)
    )


def _balanced_end(text: str, start: int, end: int) -> int:
    """Extend *end* until the parentheses opened in text[start:end] are closed.

    The anchor regexes stop at the first ')' and therefore cut a call with
    nested parentheses short.  Parentheses inside string literals are not
    recognised; if the text never balances the original *end* is returned.
    """
    depth = text.count('(', start, end) - text.count(')', start, end)
    cursor = end
    while depth > 0 and cursor < len(text):
        char = text[cursor]
        if char == '(':
            depth += 1
        elif char == ')':
            depth -= 1
        cursor += 1
    return cursor if depth == 0 else end


def _insert_after_statement(text: str, start: int, end: int, snippet: str) -> str:
    """Insert *snippet* after the line on which the statement at start..end finishes."""
    end = _balanced_end(text, start, end)
    eol = text.find('\n', end)
    if eol == -1:
        eol = len(text)
    return text[:eol] + _indent_snippet(snippet, _line_indent(text, start)) + text[eol:]


def _insert_before_line(text: str, pos: int, snippet: str) -> str:
    """Insert *snippet* on its own lines directly above the line containing *pos*."""
    start = _line_start(text, pos)
    block = _indent_snippet(snippet.lstrip('\n'), _line_indent(text, pos))
    return text[:start] + block + text[start:]


def _enclosing_except_indent(text: str, raise_pos: int):
    """Return the indent of the ``except`` line directly enclosing *raise_pos*.

    Walks upwards to the first code line that is indented less than the
    ``raise`` line; returns its indentation if it is an ``except`` clause and
    ``None`` otherwise.
    """
    raise_indent = _line_indent(text, raise_pos)
    for line in reversed(text[:_line_start(text, raise_pos)].splitlines()):
        stripped = line.strip()
        if not stripped or stripped.startswith('#'):
            continue
        indent = line[:len(line) - len(line.lstrip(' \t'))]
        if len(indent) >= len(raise_indent):
            continue
        return indent if re.match(r'except\b', stripped) else None
    return None


def _add_monitoring_import(content: str) -> str:
    """Return *content* with the monitoring import added after the first import group."""
    import_line = 'from monitoring import create_monitor_from_config\n'
    # Anchored to line start so an indented, function-local import is never
    # chosen as the place for a module-level import.
    match = re.search(
        r'^(from \w+ import.*?)(\n\n|\n$)', content, re.MULTILINE | re.DOTALL
    )
    if match is None:
        # Add at the beginning if no imports found
        return import_line + '\n' + content
    # Splice at the matched position only; str.replace() would also rewrite
    # every later occurrence of the same import text.
    split_at = match.end(1)
    return content[:split_at] + '\n' + import_line + content[split_at:]


def _inject_monitoring_hooks(content: str) -> str:
    """Return *content* with the monitoring snippets attached to their anchors."""
    config_match = re.search(r'(config\s*=\s*get_config\([^)]+\))', content)
    if config_match is None:
        # Every other snippet reads `monitor`, which only the init snippet
        # defines; inserting them alone would make the script fail at runtime.
        return content
    content = _insert_after_statement(
        content, config_match.start(), config_match.end(), _MONITOR_INIT_SNIPPET
    )

    # Add monitoring callback to trainer
    trainer_match = re.search(r'(trainer\s*=\s*[^)]+\))', content)
    if trainer_match:
        content = _insert_after_statement(
            content, trainer_match.start(), trainer_match.end(), _CALLBACK_SNIPPET
        )

    # Add training summary logging after the success message
    train_call_match = re.search(
        r'(trainer\.train\(\)\s*\n\s*logger\.info\("Training completed successfully!"\))',
        content,
    )
    if train_call_match:
        content = _insert_after_statement(
            content, train_call_match.start(), train_call_match.end(), _SUMMARY_SNIPPET
        )

    # Report the failure *before* the re-raise; code after `raise` never runs.
    error_match = re.search(
        r'(except Exception as e:\s*\n\s*logger\.error\(f"Training failed: {e}"\)\s*\n\s*)(raise)',
        content,
    )
    if error_match:
        content = _insert_before_line(content, error_match.start(2), _ERROR_SNIPPET)

    # Add a finally clause after the trailing `raise`, aligned with its except
    finally_match = re.search(
        r'(raise)\s*\n\s*if __name__ == [\'"]__main__[\'"]:', content
    )
    if finally_match:
        except_indent = _enclosing_except_indent(content, finally_match.start(1))
        if except_indent is not None:
            eol = content.find('\n', finally_match.end(1))
            content = (
                content[:eol]
                + _indent_snippet(_CLEANUP_SNIPPET, except_indent)
                + content[eol:]
            )
    return content


def update_training_script(script_path: str):
    """Update a training script in place to include improved monitoring.

    The script is patched textually: the monitoring import is added after the
    first top-level ``from ... import`` group, and monitoring snippets are
    attached to the ``config = get_config(...)`` assignment, the
    ``trainer = ...`` assignment, the ``trainer.train()`` success message,
    the ``Training failed`` handler and the final ``raise``.  Anchors that are
    not found are skipped.

    Args:
        script_path: Path of the training script to rewrite.

    Returns:
        True if the file was rewritten; False if it already imports from
        ``monitoring`` or defines none of ``main()``, ``train()`` or
        ``run_training()`` (the file is left untouched in both cases).
    """
    print(f"π§ Updating {script_path}...")
    with open(script_path, 'r', encoding='utf-8') as f:
        content = f.read()

    # Check if monitoring is already imported
    if 'from monitoring import' in content:
        print(f" β οΈ Monitoring already imported in {script_path}")
        return False

    content = _add_monitoring_import(content)

    # Only scripts with a recognisable training entry point are patched
    main_patterns = [
        r'def main\(\):',
        r'def train\(\):',
        r'def run_training\(\):'
    ]
    if not any(re.search(pattern, content) for pattern in main_patterns):
        print(f" β οΈ Could not find main training function in {script_path}")
        return False

    content = _inject_monitoring_hooks(content)

    # Write updated content
    with open(script_path, 'w', encoding='utf-8') as f:
        f.write(content)
    print(f" β Updated {script_path} with monitoring integration")
    return True
def update_config_files(config_dir: str = "config"):
    """Add the HF Datasets fields to every config module in *config_dir*.

    Each ``*.py`` file whose name does not start with ``__`` is scanned for
    the ``# Trackio monitoring configuration`` section; the ``hf_token`` and
    ``dataset_repo`` fields are inserted after that section's
    ``experiment_name: Optional[str] = None`` line, using the same
    indentation as that line.  Files that already mention
    ``TRACKIO_DATASET_REPO`` or already carry the inserted block are skipped,
    so running the function repeatedly does not duplicate the fields.

    Args:
        config_dir: Directory holding the config modules.  Defaults to
            ``"config"`` relative to the current working directory.
    """
    # Stored dedented; re-indented below to match the anchor field.
    hf_config = '''
# HF Datasets configuration
hf_token: Optional[str] = None
dataset_repo: Optional[str] = None
'''
    trackio_pattern = re.compile(
        r'(# Trackio monitoring configuration.*?experiment_name: Optional\[str\] = None)',
        re.DOTALL,
    )

    print(f"π§ Updating configuration files...")
    for config_file in sorted(Path(config_dir).glob("*.py")):
        if config_file.name.startswith("__"):
            continue
        print(f" π Checking {config_file.name}...")
        with open(config_file, 'r', encoding='utf-8') as f:
            content = f.read()

        # The second marker is the header this function itself writes; without
        # it a second run would not recognise its own output.
        if 'TRACKIO_DATASET_REPO' in content or '# HF Datasets configuration' in content:
            print(f" β οΈ HF Datasets config already present in {config_file.name}")
            continue

        trackio_match = trackio_pattern.search(content)
        if trackio_match is None:
            print(f" β οΈ Could not find Trackio config section in {config_file.name}")
            continue

        # Reuse the indentation of the `experiment_name` line for the new fields
        anchor_end = trackio_match.end()
        line_start = content.rfind('\n', 0, anchor_end) + 1
        anchor_line = content[line_start:anchor_end]
        indent = anchor_line[:len(anchor_line) - len(anchor_line.lstrip(' \t'))]
        indented_config = ''.join(
            indent + line if line.strip() else line
            for line in hf_config.splitlines(keepends=True)
        )

        # Insert after the end of the anchor line so trailing text stays on it
        insert_point = content.find('\n', anchor_end)
        if insert_point == -1:
            insert_point = len(content)
        content = content[:insert_point] + indented_config + content[insert_point:]

        # Write updated content
        with open(config_file, 'w', encoding='utf-8') as f:
            f.write(content)
        print(f" β Added HF Datasets config to {config_file.name}")
def main():
    """Patch train.py, the config modules and the per-config training scripts.

    Runs, in order: the top-level ``train.py`` (if present), every config
    module via ``update_config_files()``, and a fixed list of training
    scripts under ``config/``.  Missing files are reported and skipped.
    """
    print("π Integrating improved monitoring with HF Datasets...")
    print("=" * 60)

    # Top-level trainer first
    primary_script = "train.py"
    if not os.path.exists(primary_script):
        print(f"β οΈ Main training script {primary_script} not found")
    else:
        update_training_script(primary_script)

    # Then the configuration modules
    update_config_files()

    # Finally the training scripts that live next to the configs
    extra_scripts = (
        "train_smollm3_openhermes_fr.py",
        "train_smollm3_openhermes_fr_a100_balanced.py",
        "train_smollm3_openhermes_fr_a100_large.py",
        "train_smollm3_openhermes_fr_a100_max_performance.py",
        "train_smollm3_openhermes_fr_a100_multiple_passes.py",
    )
    scripts_root = Path("config")
    print(f"\nπ§ Updating training scripts in config directory...")
    for name in extra_scripts:
        candidate = scripts_root / name
        if candidate.exists():
            update_training_script(str(candidate))
            continue
        print(f" β οΈ Training script {name} not found")

    # Closing summary for the operator
    closing_lines = (
        f"\nβ Monitoring integration completed!",
        f"\nπ Next steps:",
        f"1. Set HF_TOKEN environment variable",
        f"2. Optionally set TRACKIO_DATASET_REPO",
        f"3. Run your training scripts with monitoring enabled",
        f"4. Check your HF Dataset repository for experiment data",
    )
    for line in closing_lines:
        print(line)
# Run the integration only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()