#!/usr/bin/env python3
"""
Export AuraMind models for mobile deployment
Creates optimized .ptl files for PyTorch Mobile
"""

import json
from datetime import date

import torch
from torch.utils.mobile_optimizer import optimize_for_mobile
from transformers import AutoTokenizer, AutoModelForCausalLM

def export_for_mobile(model_name: str, variant: str):
    """Export model for mobile deployment"""
    
    print(f"Exporting {model_name} ({variant}) for mobile...")
    
    # Load tokenizer and model; torchscript=True makes the forward pass return
    # plain tuples instead of ModelOutput objects so torch.jit.trace can handle it
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype=torch.float16,  # switch to torch.float32 if half-precision ops fail on CPU during tracing
        device_map="cpu",
        low_cpu_mem_usage=True,
        torchscript=True
    )
    
    # Prepare for mobile export
    model.eval()
    
    # Create example input
    example_text = "[Assistant Mode] Help me with my tasks"
    example_input = tokenizer(
        example_text, 
        return_tensors="pt", 
        max_length=512,
        truncation=True
    )["input_ids"]
    
    # Trace the model on the example input (inference only, so skip autograd)
    with torch.no_grad():
        traced_model = torch.jit.trace(model, example_input)
    
    # Optimize the traced graph for mobile (operator fusion, dropout removal, etc.)
    optimized_model = optimize_for_mobile(traced_model)
    
    # Save mobile-optimized model
    output_path = f"auramind_{variant}_mobile.ptl"
    optimized_model._save_for_lite_interpreter(output_path)
    
    print(f"✅ Mobile model saved: {output_path}")
    
    # Create metadata file
    metadata = {
        "model_name": model_name,
        "variant": variant,
        "tokenizer_vocab_size": tokenizer.vocab_size,
        "max_length": 512,
        "export_date": date.today().isoformat(),
        "pytorch_version": torch.__version__
    }

    with open(f"auramind_{variant}_metadata.json", "w") as f:
        json.dump(metadata, f, indent=2)
    
    return output_path

if __name__ == "__main__":
    variants = ["270m", "180m", "90m"]
    
    # NOTE: every variant currently exports from the same repo; point model_name
    # at the matching checkpoint if each size ships as a separate repo or revision
    for variant in variants:
        export_for_mobile("zail-ai/Auramind", variant)
        
    print("\n✅ All mobile exports completed!")
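
# Optional sanity check (a sketch, not part of the export flow): load an exported
# .ptl back through the lite interpreter and run the example prompt. Note that
# _load_for_lite_interpreter is a private torch.jit.mobile helper and may move
# between releases; on-device loading goes through PyTorch Mobile's LiteModuleLoader.
#
#     from torch.jit.mobile import _load_for_lite_interpreter
#
#     lite_model = _load_for_lite_interpreter("auramind_270m_mobile.ptl")
#     ids = AutoTokenizer.from_pretrained("zail-ai/Auramind")(
#         "[Assistant Mode] Help me with my tasks", return_tensors="pt"
#     )["input_ids"]
#     logits = lite_model(ids)[0]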