import pytest
import torch
from src.config import (MAX_SEQ_LENGTH, DTYPE, LOAD_IN_4BIT, DEVICE_MAP, EOS_TOKEN, 
                        ALPACA_PROMPT_TEMPLATE, TRAIN_ARGS)
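
# For orientation, these tests assume src/config.py exposes values roughly of
# this shape (a hypothetical sketch, not the project's actual file):
#
#     MAX_SEQ_LENGTH = 2048            # power of two
#     DTYPE = None                     # or e.g. torch.float16
#     LOAD_IN_4BIT = True
#     DEVICE_MAP = {"": 0}             # accelerate-style map to CUDA device 0
#     EOS_TOKEN = tokenizer.eos_token  # populated after loading the tokenizer
#     ALPACA_PROMPT_TEMPLATE = "...###Instruction:\n{}\n\n###Input:\n{}\n\n###Response:\n{}"
#     TRAIN_ARGS = {...}               # see test_train_args below for expected keys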

# Test that required configuration keys are present
def test_required_config_keys():
    assert MAX_SEQ_LENGTH is not None, "MAX_SEQ_LENGTH is not set."
    assert TRAIN_ARGS is not None, "TRAIN_ARGS is not set."
    assert ALPACA_PROMPT_TEMPLATE is not None, "ALPACA_PROMPT_TEMPLATE is not set."
    assert DEVICE_MAP is not None, "DEVICE_MAP is not set."

# Test that MAX_SEQ_LENGTH is a power of two
def test_max_seq_length():
    assert isinstance(MAX_SEQ_LENGTH, int), "MAX_SEQ_LENGTH should be an integer."
    assert MAX_SEQ_LENGTH > 0, "MAX_SEQ_LENGTH should be greater than 0."
    assert (MAX_SEQ_LENGTH & (MAX_SEQ_LENGTH - 1)) == 0, "MAX_SEQ_LENGTH should be a power of two."

# Test that TRAIN_ARGS dictionary contains required fields and types
def test_train_args():
    required_keys = [
        "per_device_train_batch_size", 
        "gradient_accumulation_steps", 
        "warmup_steps", 
        "max_steps", 
        "learning_rate", 
        "fp16", 
        "bf16", 
        "logging_steps", 
        "optim", 
        "weight_decay", 
        "lr_scheduler_type", 
        "seed", 
        "output_dir"
    ]
    
    for key in required_keys:
        assert key in TRAIN_ARGS, f"Missing {key} in TRAIN_ARGS."
    
    # Check types of specific fields
    assert isinstance(TRAIN_ARGS["per_device_train_batch_size"], int), "per_device_train_batch_size should be an integer."
    assert isinstance(TRAIN_ARGS["learning_rate"], float), "learning_rate should be a float."
    assert isinstance(TRAIN_ARGS["output_dir"], str), "output_dir should be a string."

# Test that the DEVICE_MAP references a valid CUDA device
@pytest.mark.gpu
@pytest.mark.skipif(not torch.cuda.is_available(), reason="Requires GPU")
def test_device_map():
    # Assumes the accelerate-style dict form {"": device_id}, where the
    # empty-string key maps the whole model onto a single device.
    device = DEVICE_MAP.get('', None)
    assert device is not None, "DEVICE_MAP should reference a CUDA device."
    assert isinstance(device, int), "DEVICE_MAP device should be an integer (CUDA device ID)."
    # The skipif marker above already guarantees CUDA is available, so check
    # that the configured device ID actually exists instead.
    assert device < torch.cuda.device_count(), "DEVICE_MAP points to a CUDA device ID that is not present."
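
# Note: Hugging Face loaders also accept string device maps such as "auto";
# if src/config.py used that form instead, this test would need a string
# branch, e.g. (hypothetical):
#
#     if isinstance(DEVICE_MAP, str):
#         assert DEVICE_MAP == "auto" or DEVICE_MAP.startswith("cuda")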

# Test that the EOS_TOKEN is set (expected to be populated from the tokenizer)
def test_eos_token():
    assert EOS_TOKEN is not None, "EOS_TOKEN should be populated from the tokenizer."

# Test the ALPACA_PROMPT_TEMPLATE for expected formatting
def test_alpaca_prompt_template():
    test_instruction = "Test Instruction"
    test_input = "Test Input"
    test_output = "Test Output"
    
    formatted_prompt = ALPACA_PROMPT_TEMPLATE.format(test_instruction, test_input, test_output)
    
    # .format() consumes the '{}' placeholders, so check the raw template
    # rather than the formatted prompt.
    assert "{}" in ALPACA_PROMPT_TEMPLATE, "ALPACA_PROMPT_TEMPLATE should contain placeholders."
    
    # The substituted values and the section headers should survive formatting.
    assert test_instruction in formatted_prompt, "Formatted prompt should contain the instruction."
    assert test_input in formatted_prompt, "Formatted prompt should contain the input."
    assert test_output in formatted_prompt, "Formatted prompt should contain the output."
    assert "###Instruction:" in formatted_prompt, "ALPACA_PROMPT_TEMPLATE should contain '###Instruction:'."
    assert "###Input:" in formatted_prompt, "ALPACA_PROMPT_TEMPLATE should contain '###Input:'."
    assert "###Response:" in formatted_prompt, "ALPACA_PROMPT_TEMPLATE should contain '###Response:'."

# Test that the LOAD_IN_4BIT setting is a boolean
def test_load_in_4bit():
    assert isinstance(LOAD_IN_4BIT, bool), "LOAD_IN_4BIT should be a boolean."

# Test the DTYPE (should be None or a valid torch dtype)
def test_dtype():
    # torch dtypes such as torch.float16 are instances of torch.dtype, not of
    # the builtin `type`, so isinstance(DTYPE, type) would reject valid values.
    assert DTYPE is None or isinstance(DTYPE, torch.dtype), "DTYPE should be None or a torch.dtype."