
# Gemma-IT-Expanded-Unfrozen-Layers

This model uses mergekit's `passthrough` merge method to expand the layer count of `google/gemma-7b-it`. After every block of four layers, the last layer of the block is duplicated and inserted as a new layer, growing the model from 28 to 35 layers. The `o_proj` and `down_proj` weights of each inserted layer are scaled to zero, mirroring the block-expansion initialization used in LLaMA Pro. Note that this checkpoint has not been fine-tuned. When fine-tuning, train only the inserted layers (every fifth layer of the expanded model: indices 4, 9, 14, 19, 24, 29, and 34) and keep all other layers frozen.
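Zeroing exactly these two matrices is what makes each inserted layer harmless at initialization: in a standard transformer decoder block, `o_proj` closes the attention residual branch and `down_proj` closes the MLP residual branch, so a block whose two branch outputs are zero reduces to the identity map. A minimal sketch with a toy block (illustrative only, not the actual Gemma implementation):

```python
import torch
import torch.nn as nn

class ToyBlock(nn.Module):
    """Toy decoder block: each residual branch ends in a projection."""
    def __init__(self, dim=8):
        super().__init__()
        self.attn_in = nn.Linear(dim, dim)        # stand-in for the attention path
        self.o_proj = nn.Linear(dim, dim, bias=False)
        self.mlp_in = nn.Linear(dim, 4 * dim)     # stand-in for the gate/up projections
        self.down_proj = nn.Linear(4 * dim, dim, bias=False)

    def forward(self, x):
        x = x + self.o_proj(torch.tanh(self.attn_in(x)))    # attention residual
        x = x + self.down_proj(torch.tanh(self.mlp_in(x)))  # MLP residual
        return x

block = ToyBlock()
nn.init.zeros_(block.o_proj.weight)     # scale: 0.0 on o_proj
nn.init.zeros_(block.down_proj.weight)  # scale: 0.0 on down_proj

x = torch.randn(2, 8)
print(torch.allclose(block(x), x))  # True: the zeroed block is an identity map
```

Because each inserted layer is an identity map, the expanded model computes exactly the same function as the original gemma-7b-it at initialization, which is what allows fine-tuning only the new layers.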

## 🧩 Configuration

```yaml
slices:
  - sources:
      - model: google/gemma-7b-it
        layer_range: [0, 4]
  - sources:
      - model: google/gemma-7b-it
        layer_range: [3, 4]
        parameters:
          scale:
            - filter: o_proj
              value: 0.0
            - filter: down_proj
              value: 0.0
            - value: 1.0

  - sources:
      - model: google/gemma-7b-it
        layer_range: [4, 8]
  - sources:
      - model: google/gemma-7b-it
        layer_range: [7, 8]
        parameters:
          scale:
            - filter: o_proj
              value: 0.0
            - filter: down_proj
              value: 0.0
            - value: 1.0

  - sources:
      - model: google/gemma-7b-it
        layer_range: [8, 12]
  - sources:
      - model: google/gemma-7b-it
        layer_range: [11, 12]
        parameters:
          scale:
            - filter: o_proj
              value: 0.0
            - filter: down_proj
              value: 0.0
            - value: 1.0

  - sources:
      - model: google/gemma-7b-it
        layer_range: [12, 16]
  - sources:
      - model: google/gemma-7b-it
        layer_range: [15, 16]
        parameters:
          scale:
            - filter: o_proj
              value: 0.0
            - filter: down_proj
              value: 0.0
            - value: 1.0

  - sources:
      - model: google/gemma-7b-it
        layer_range: [16, 20]
  - sources:
      - model: google/gemma-7b-it
        layer_range: [19, 20]
        parameters:
          scale:
            - filter: o_proj
              value: 0.0
            - filter: down_proj
              value: 0.0
            - value: 1.0

  - sources:
      - model: google/gemma-7b-it
        layer_range: [20, 24]
  - sources:
      - model: google/gemma-7b-it
        layer_range: [23, 24]
        parameters:
          scale:
            - filter: o_proj
              value: 0.0
            - filter: down_proj
              value: 0.0
            - value: 1.0

  - sources:
      - model: google/gemma-7b-it
        layer_range: [24, 28]
  - sources:
      - model: google/gemma-7b-it
        layer_range: [27, 28]
        parameters:
          scale:
            - filter: o_proj
              value: 0.0
            - filter: down_proj
              value: 0.0
            - value: 1.0

merge_method: passthrough
dtype: bfloat16
```
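To materialize the merge, save the configuration above (for example as `config.yaml`; the file and output directory names here are just illustrative) and run mergekit's command-line entry point:

```bash
mergekit-yaml config.yaml ./gemma-7b-it-expanded
```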


## Function to freeze layers

```python
from transformers import AutoModelForCausalLM

def update_layer_gradients(model, n):
    """
    Enables gradients only for every nth layer of the model, i.e. the layers
    at indices i where i % n == n - 1. With n=5, these are exactly the layers
    inserted by the merge above; all other layers are frozen.

    :param model: The model instance, assumed to be GemmaForCausalLM or similar.
    :param n: Spacing of the unfrozen layers; the inserted layers occupy every
        nth position of the expanded stack.
    """
    layers = model.model.layers  # the ModuleList of decoder layers

    for i, layer in enumerate(layers):
        is_inserted = i % n == (n - 1)  # True only for the inserted layers
        if is_inserted:
            print(f"Unfreezing layer {i}")
        for param in layer.parameters():
            param.requires_grad = is_inserted

# Load the expanded model (replace this local path with your own merge output)
model = AutoModelForCausalLM.from_pretrained("/Users/gayalshamane/Documents/mergekit/gemma-2b-it-expanded")

# The configuration above inserts a new layer after every 4 original layers,
# so the inserted layers occupy every 5th position of the expanded stack.
n = 5
update_layer_gradients(model, n)
```
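Note that `update_layer_gradients` only touches the decoder stack; modules outside `model.model.layers` (the token embeddings, the final norm, and the tied `lm_head`) keep their default `requires_grad=True`. A short follow-up sketch, assuming a standard `GemmaForCausalLM` module layout, that freezes those as well and verifies the result:

```python
# Optionally freeze everything outside the decoder layers too, so that
# only the inserted layers receive gradient updates.
for name, param in model.named_parameters():
    if not name.startswith("model.layers."):
        param.requires_grad = False

# Sanity check: report the trainable fraction after freezing.
trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
total = sum(p.numel() for p in model.parameters())
print(f"Trainable params: {trainable:,} / {total:,} ({100 * trainable / total:.1f}%)")
```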