|  | """ | 
					
						
						|  | This module contains unit tests for the `freeze_layers_except` function. | 
					
						
						|  |  | 
					
						
						|  | The `freeze_layers_except` function is used to freeze layers in a model, except for the specified layers. | 
					
						
						|  | The unit tests in this module verify the behavior of the `freeze_layers_except` function in different scenarios. | 
					
						
						|  | """ | 
					
						
						|  |  | 
					
						
						|  | import unittest | 
					
						
						|  |  | 
					
						
						|  | import torch | 
					
						
						|  | from torch import nn | 
					
						
						|  |  | 
					
						
						|  | from axolotl.utils.freeze import freeze_layers_except | 
					
						
						|  |  | 
					
						
						|  | ZERO = [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0] | 
					
						
						|  | ONE_TO_TEN = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0] | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | class TestFreezeLayersExcept(unittest.TestCase): | 
					
						
						|  | """ | 
					
						
						|  | A test case class for the `freeze_layers_except` function. | 
					
						
						|  | """ | 
					
						
						|  |  | 
					
						
						|  | def setUp(self): | 
					
						
						|  | self.model = _TestModel() | 
					
						
						|  |  | 
					
						
						|  | def test_freeze_layers_with_dots_in_name(self): | 
					
						
						|  | freeze_layers_except(self.model, ["features.layer"]) | 
					
						
						|  | self.assertTrue( | 
					
						
						|  | self.model.features.layer.weight.requires_grad, | 
					
						
						|  | "model.features.layer should be trainable.", | 
					
						
						|  | ) | 
					
						
						|  | self.assertFalse( | 
					
						
						|  | self.model.classifier.weight.requires_grad, | 
					
						
						|  | "model.classifier should be frozen.", | 
					
						
						|  | ) | 
					
						
						|  |  | 
					
						
						|  | def test_freeze_layers_without_dots_in_name(self): | 
					
						
						|  | freeze_layers_except(self.model, ["classifier"]) | 
					
						
						|  | self.assertFalse( | 
					
						
						|  | self.model.features.layer.weight.requires_grad, | 
					
						
						|  | "model.features.layer should be trainable.", | 
					
						
						|  | ) | 
					
						
						|  | self.assertTrue( | 
					
						
						|  | self.model.classifier.weight.requires_grad, | 
					
						
						|  | "model.classifier should be frozen.", | 
					
						
						|  | ) | 
					
						
						|  |  | 
					
						
						|  | def test_freeze_layers_regex_patterns(self): | 
					
						
						|  |  | 
					
						
						|  | freeze_layers_except(self.model, [r"^features.[a-z]+.weight$", r"class[a-c]+"]) | 
					
						
						|  | self.assertTrue( | 
					
						
						|  | self.model.features.layer.weight.requires_grad, | 
					
						
						|  | "model.features.layer should be trainable.", | 
					
						
						|  | ) | 
					
						
						|  | self.assertFalse( | 
					
						
						|  | self.model.classifier.weight.requires_grad, | 
					
						
						|  | "model.classifier should be frozen.", | 
					
						
						|  | ) | 
					
						
						|  |  | 
					
						
						|  | def test_all_layers_frozen(self): | 
					
						
						|  | freeze_layers_except(self.model, []) | 
					
						
						|  | self.assertFalse( | 
					
						
						|  | self.model.features.layer.weight.requires_grad, | 
					
						
						|  | "model.features.layer should be frozen.", | 
					
						
						|  | ) | 
					
						
						|  | self.assertFalse( | 
					
						
						|  | self.model.classifier.weight.requires_grad, | 
					
						
						|  | "model.classifier should be frozen.", | 
					
						
						|  | ) | 
					
						
						|  |  | 
					
						
						|  | def test_all_layers_unfrozen(self): | 
					
						
						|  | freeze_layers_except(self.model, ["features.layer", "classifier"]) | 
					
						
						|  | self.assertTrue( | 
					
						
						|  | self.model.features.layer.weight.requires_grad, | 
					
						
						|  | "model.features.layer should be trainable.", | 
					
						
						|  | ) | 
					
						
						|  | self.assertTrue( | 
					
						
						|  | self.model.classifier.weight.requires_grad, | 
					
						
						|  | "model.classifier should be trainable.", | 
					
						
						|  | ) | 
					
						
						|  |  | 
					
						
						|  | def test_freeze_layers_with_range_pattern_start_end(self): | 
					
						
						|  | freeze_layers_except(self.model, ["features.layer[1:5]"]) | 
					
						
						|  | self.assertTrue( | 
					
						
						|  | self.model.features.layer.weight.requires_grad, | 
					
						
						|  | "model.features.layer should be trainable.", | 
					
						
						|  | ) | 
					
						
						|  | self.assertFalse( | 
					
						
						|  | self.model.classifier.weight.requires_grad, | 
					
						
						|  | "model.classifier should be frozen.", | 
					
						
						|  | ) | 
					
						
						|  |  | 
					
						
						|  | self._assert_gradient_output( | 
					
						
						|  | [ | 
					
						
						|  | ZERO, | 
					
						
						|  | ONE_TO_TEN, | 
					
						
						|  | ONE_TO_TEN, | 
					
						
						|  | ONE_TO_TEN, | 
					
						
						|  | ONE_TO_TEN, | 
					
						
						|  | ZERO, | 
					
						
						|  | ZERO, | 
					
						
						|  | ZERO, | 
					
						
						|  | ZERO, | 
					
						
						|  | ZERO, | 
					
						
						|  | ] | 
					
						
						|  | ) | 
					
						
						|  |  | 
					
						
						|  | def test_freeze_layers_with_range_pattern_single_index(self): | 
					
						
						|  | freeze_layers_except(self.model, ["features.layer[5]"]) | 
					
						
						|  | self.assertTrue( | 
					
						
						|  | self.model.features.layer.weight.requires_grad, | 
					
						
						|  | "model.features.layer should be trainable.", | 
					
						
						|  | ) | 
					
						
						|  | self.assertFalse( | 
					
						
						|  | self.model.classifier.weight.requires_grad, | 
					
						
						|  | "model.classifier should be frozen.", | 
					
						
						|  | ) | 
					
						
						|  |  | 
					
						
						|  | self._assert_gradient_output( | 
					
						
						|  | [ZERO, ZERO, ZERO, ZERO, ZERO, ONE_TO_TEN, ZERO, ZERO, ZERO, ZERO] | 
					
						
						|  | ) | 
					
						
						|  |  | 
					
						
						|  | def test_freeze_layers_with_range_pattern_start_omitted(self): | 
					
						
						|  | freeze_layers_except(self.model, ["features.layer[:5]"]) | 
					
						
						|  | self.assertTrue( | 
					
						
						|  | self.model.features.layer.weight.requires_grad, | 
					
						
						|  | "model.features.layer should be trainable.", | 
					
						
						|  | ) | 
					
						
						|  | self.assertFalse( | 
					
						
						|  | self.model.classifier.weight.requires_grad, | 
					
						
						|  | "model.classifier should be frozen.", | 
					
						
						|  | ) | 
					
						
						|  |  | 
					
						
						|  | self._assert_gradient_output( | 
					
						
						|  | [ | 
					
						
						|  | ONE_TO_TEN, | 
					
						
						|  | ONE_TO_TEN, | 
					
						
						|  | ONE_TO_TEN, | 
					
						
						|  | ONE_TO_TEN, | 
					
						
						|  | ONE_TO_TEN, | 
					
						
						|  | ZERO, | 
					
						
						|  | ZERO, | 
					
						
						|  | ZERO, | 
					
						
						|  | ZERO, | 
					
						
						|  | ZERO, | 
					
						
						|  | ] | 
					
						
						|  | ) | 
					
						
						|  |  | 
					
						
						|  | def test_freeze_layers_with_range_pattern_end_omitted(self): | 
					
						
						|  | freeze_layers_except(self.model, ["features.layer[4:]"]) | 
					
						
						|  | self.assertTrue( | 
					
						
						|  | self.model.features.layer.weight.requires_grad, | 
					
						
						|  | "model.features.layer should be trainable.", | 
					
						
						|  | ) | 
					
						
						|  | self.assertFalse( | 
					
						
						|  | self.model.classifier.weight.requires_grad, | 
					
						
						|  | "model.classifier should be frozen.", | 
					
						
						|  | ) | 
					
						
						|  |  | 
					
						
						|  | self._assert_gradient_output( | 
					
						
						|  | [ | 
					
						
						|  | ZERO, | 
					
						
						|  | ZERO, | 
					
						
						|  | ZERO, | 
					
						
						|  | ZERO, | 
					
						
						|  | ONE_TO_TEN, | 
					
						
						|  | ONE_TO_TEN, | 
					
						
						|  | ONE_TO_TEN, | 
					
						
						|  | ONE_TO_TEN, | 
					
						
						|  | ONE_TO_TEN, | 
					
						
						|  | ONE_TO_TEN, | 
					
						
						|  | ] | 
					
						
						|  | ) | 
					
						
						|  |  | 
					
						
						|  | def test_freeze_layers_with_range_pattern_merge_included(self): | 
					
						
						|  | freeze_layers_except(self.model, ["features.layer[4:]", "features.layer[5:6]"]) | 
					
						
						|  | self.assertTrue( | 
					
						
						|  | self.model.features.layer.weight.requires_grad, | 
					
						
						|  | "model.features.layer should be trainable.", | 
					
						
						|  | ) | 
					
						
						|  | self.assertFalse( | 
					
						
						|  | self.model.classifier.weight.requires_grad, | 
					
						
						|  | "model.classifier should be frozen.", | 
					
						
						|  | ) | 
					
						
						|  |  | 
					
						
						|  | self._assert_gradient_output( | 
					
						
						|  | [ | 
					
						
						|  | ZERO, | 
					
						
						|  | ZERO, | 
					
						
						|  | ZERO, | 
					
						
						|  | ZERO, | 
					
						
						|  | ONE_TO_TEN, | 
					
						
						|  | ONE_TO_TEN, | 
					
						
						|  | ONE_TO_TEN, | 
					
						
						|  | ONE_TO_TEN, | 
					
						
						|  | ONE_TO_TEN, | 
					
						
						|  | ONE_TO_TEN, | 
					
						
						|  | ] | 
					
						
						|  | ) | 
					
						
						|  |  | 
					
						
						|  | def test_freeze_layers_with_range_pattern_merge_intersect(self): | 
					
						
						|  | freeze_layers_except(self.model, ["features.layer[4:7]", "features.layer[6:8]"]) | 
					
						
						|  | self.assertTrue( | 
					
						
						|  | self.model.features.layer.weight.requires_grad, | 
					
						
						|  | "model.features.layer should be trainable.", | 
					
						
						|  | ) | 
					
						
						|  | self.assertFalse( | 
					
						
						|  | self.model.classifier.weight.requires_grad, | 
					
						
						|  | "model.classifier should be frozen.", | 
					
						
						|  | ) | 
					
						
						|  |  | 
					
						
						|  | self._assert_gradient_output( | 
					
						
						|  | [ | 
					
						
						|  | ZERO, | 
					
						
						|  | ZERO, | 
					
						
						|  | ZERO, | 
					
						
						|  | ZERO, | 
					
						
						|  | ONE_TO_TEN, | 
					
						
						|  | ONE_TO_TEN, | 
					
						
						|  | ONE_TO_TEN, | 
					
						
						|  | ONE_TO_TEN, | 
					
						
						|  | ZERO, | 
					
						
						|  | ZERO, | 
					
						
						|  | ] | 
					
						
						|  | ) | 
					
						
						|  |  | 
					
						
						|  | def test_freeze_layers_with_range_pattern_merge_separate(self): | 
					
						
						|  | freeze_layers_except( | 
					
						
						|  | self.model, | 
					
						
						|  | ["features.layer[1:2]", "features.layer[3:4]", "features.layer[5:6]"], | 
					
						
						|  | ) | 
					
						
						|  | self.assertTrue( | 
					
						
						|  | self.model.features.layer.weight.requires_grad, | 
					
						
						|  | "model.features.layer should be trainable.", | 
					
						
						|  | ) | 
					
						
						|  | self.assertFalse( | 
					
						
						|  | self.model.classifier.weight.requires_grad, | 
					
						
						|  | "model.classifier should be frozen.", | 
					
						
						|  | ) | 
					
						
						|  |  | 
					
						
						|  | self._assert_gradient_output( | 
					
						
						|  | [ | 
					
						
						|  | ZERO, | 
					
						
						|  | ONE_TO_TEN, | 
					
						
						|  | ZERO, | 
					
						
						|  | ONE_TO_TEN, | 
					
						
						|  | ZERO, | 
					
						
						|  | ONE_TO_TEN, | 
					
						
						|  | ZERO, | 
					
						
						|  | ZERO, | 
					
						
						|  | ZERO, | 
					
						
						|  | ZERO, | 
					
						
						|  | ] | 
					
						
						|  | ) | 
					
						
						|  |  | 
					
						
						|  | def _assert_gradient_output(self, expected): | 
					
						
						|  | input_tensor = torch.tensor([ONE_TO_TEN], dtype=torch.float32) | 
					
						
						|  |  | 
					
						
						|  | self.model.features.layer.weight.grad = None | 
					
						
						|  | output = self.model.features.layer(input_tensor) | 
					
						
						|  | loss = output.sum() | 
					
						
						|  | loss.backward() | 
					
						
						|  |  | 
					
						
						|  | expected_grads = torch.tensor(expected) | 
					
						
						|  | torch.testing.assert_close( | 
					
						
						|  | self.model.features.layer.weight.grad, expected_grads | 
					
						
						|  | ) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | class _SubLayerModule(nn.Module): | 
					
						
						|  | def __init__(self): | 
					
						
						|  | super().__init__() | 
					
						
						|  | self.layer = nn.Linear(10, 10) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | class _TestModel(nn.Module): | 
					
						
						|  | def __init__(self): | 
					
						
						|  | super().__init__() | 
					
						
						|  | self.features = _SubLayerModule() | 
					
						
						|  | self.classifier = nn.Linear(10, 2) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
# Allow running this test module directly (e.g. `python test_freeze.py`).
if __name__ == "__main__":
    unittest.main()
					
						
						|  |  |