---
license: other
license_name: ngen4-community-license
license_link: https://tnsaai-builds.framer.website/community/licenses/ngen4
tags:
- ngen4
- causal-lm
- large-language-model
- transformers
- pytorch
- rope
# - wikitext-2 # Add if this is the version being shared
# - redpajama # Add if a RedPajama trained version is shared
# - instruction-tuned # Add if an instruction-tuned version is shared
pipeline_tag: text-generation
# model-index:
# - name: ngen4-50m-wikitext2 # Example
#   results:
#   - task:
#       type: text-generation
#       name: Perplexity
#     dataset:
#       name: Wikitext-2 (Validation)
#       type: wikitext
#       config: wikitext-2-raw-v1
#       split: validation
#     metrics:
#     - type: perplexity
#       value: [TODO: Add perplexity from training] # e.g., 1.7781
#     - type: eval_loss
#       value: [TODO: Add eval loss from training] # e.g., 0.5755
---

# NGen4: A Causal Language Model with Rotary Positional Embeddings

**NGen4** is a decoder-only causal language model architecture designed with a focus on modern techniques, particularly Rotary Positional Embeddings (RoPE) for effective handling of long contexts and flexible attention mechanisms. This repository contains the PyTorch implementation using the Hugging Face Transformers library.

This model card describes the NGen4 architecture. Specific instances (e.g., a 50M parameter model pre-trained on Wikitext-2) will have their details in respective sections or separate model cards if uploaded individually.

## Model Approach & Architecture

NGen4 is built on the standard Transformer decoder architecture, incorporating several key design choices for performance, flexibility, and effective long-context modeling.

### 1. Core Architecture: Decoder-Only Causal LM

* **Type**: Autoregressive, decoder-only Transformer.
* **Objective**: Trained for next-token prediction, making it suitable for text generation tasks.
* **Causal Masking**: A causal attention mask is applied in the self-attention layers to ensure that the prediction for a token at position `i` can only depend on tokens at positions less than `i`.

### 2. Positional Embeddings: Rotary Positional Embeddings (RoPE)

NGen4 replaces traditional absolute or learned positional embeddings with **Rotary Positional Embeddings (RoPE)**.

* **Mechanism**: RoPE encodes absolute positional information by rotating pairs of features in the query and key projections based on their position. This is applied *after* the Q/K projections but *before* the attention dot product.
* **Advantages**:
  * **Long Context Scalability**: RoPE has shown excellent performance in generalizing to sequence lengths longer than those seen during training.
  * **Relative Position Encoding**: While encoding absolute positions, it implicitly captures relative positional information in the self-attention mechanism through the rotational property.
  * **No Trainable Parameters**: RoPE itself does not add trainable parameters for positional encoding.
* **Implementation Details** (see the sketch after this list):
  * `rope_theta`: The base period for the rotary encodings (default: `10000.0`).
  * `rope_pct`: The fraction of head dimensions to which RoPE is applied (default: `1.0`, meaning all dimensions). The actual `rope_dim` is calculated as `head_dim * rope_pct`.
  * The implementation caches the `sin` and `cos` values to improve efficiency.
  * The configuration includes `rope_scaling` (default: `None`) as a placeholder for future integration of RoPE scaling strategies (e.g., NTK-aware scaling, YaRN) to further enhance long-context capabilities.
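To make the rotation step concrete, below is a minimal, illustrative sketch of how rotary embeddings are commonly applied to the query and key tensors just before the attention dot product. It is not the exact NGen4 implementation; the helper names (`build_rope_cache`, `rotate_half`, `apply_rope`) and the half-split rotation convention are assumptions for illustration only.

```python
import torch

def build_rope_cache(seq_len, rope_dim, theta=10000.0, device="cpu"):
    # Inverse frequencies for each rotated feature pair, then per-position angles.
    inv_freq = 1.0 / (theta ** (torch.arange(0, rope_dim, 2, device=device).float() / rope_dim))
    positions = torch.arange(seq_len, device=device).float()
    angles = torch.outer(positions, inv_freq)        # (seq_len, rope_dim / 2)
    angles = torch.cat((angles, angles), dim=-1)     # (seq_len, rope_dim)
    return angles.cos(), angles.sin()

def rotate_half(x):
    # Swap the two halves of the feature dimension and negate the second half.
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)

def apply_rope(q, k, cos, sin):
    # q, k: (batch, n_head, seq_len, rope_dim); cos/sin broadcast over batch and heads.
    q_rot = (q * cos) + (rotate_half(q) * sin)
    k_rot = (k * cos) + (rotate_half(k) * sin)
    return q_rot, k_rot
```

In the full model, only the first `rope_dim = int(head_dim * rope_pct)` features of each head would be rotated this way; any remaining features pass through unchanged, and the `cos`/`sin` tables would be cached rather than rebuilt on every forward pass.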
### 3. Attention Mechanism: Multi-Head Self-Attention (MHSA)

* **Standard MHSA**: The model uses multi-head self-attention as the core mechanism for information aggregation.
* **Configurable Implementations**: The `NGen4Config` allows specifying different attention implementations via the `attn_implementation` parameter:
  * `"eager"`: The standard PyTorch implementation. Provides a clear reference but can be less memory/compute efficient.
  * `"sdpa"` (Scaled Dot Product Attention): Leverages PyTorch 2.0's built-in optimized `F.scaled_dot_product_attention`. Generally faster and more memory-efficient than eager.
  * `"flash_attention_2"`: If the `flash-attn` library is installed and PyTorch >= 2.0, this option can be used for significant speedups and memory savings, especially for longer sequences and on compatible hardware.
* **Projections**: Linear projections (`c_attn`) produce the Query (Q), Key (K), and Value (V) tensors from the input hidden states. An output projection (`c_proj`) is applied after attention.
* **Dropout**: Dropout is applied to attention weights (`attn_pdrop`) and residual connections (`resid_pdrop`).

### 4. Transformer Blocks (`NGen4Block`)

Each NGen4 block follows a standard Pre-LayerNormalization (Pre-LN) structure:

1. **Layer Normalization (`ln_1`)**: Applied to the input hidden states.
2. **Multi-Head Self-Attention (`attn`)**: As described above.
3. **Residual Connection**: The output of the attention module is added to the input of `ln_1`.
4. **Layer Normalization (`ln_2`)**: Applied to the output of the first residual connection.
5. **Feed-Forward Network (MLP) (`mlp`)**:
   * Two linear layers with an activation function in between.
   * The intermediate size (`n_inner`) defaults to `4 * n_embd`.
   * The activation function is configurable via `activation_function` (e.g., `"gelu_new"`).
6. **Residual Connection**: The output of the MLP is added to the input of `ln_2`.

### 5. Embeddings and Output Layer

* **Token Embeddings (`wte`)**: A standard learnable embedding layer maps input token IDs to dense vectors (`n_embd` dimensions).
* **LM Head (`lm_head`)**: A linear layer maps the final hidden states from the transformer blocks to logits over the vocabulary (`vocab_size`).
* **Weight Tying**: The weights of the token embedding layer (`wte.weight`) and the LM head (`lm_head.weight`) are typically tied. This is declared in the model via `_tied_weights_keys = ["lm_head.weight"]` and in the configuration via `tie_word_embeddings=True`. This practice reduces parameters and can improve performance.

### 6. Key Configuration Parameters

The architecture is defined by parameters in `NGen4Config`, including (see the example after this list):

* `vocab_size`: Size of the vocabulary.
* `n_positions`: Maximum sequence length the model can process (context window).
* `n_embd`: Dimensionality of the token embeddings and hidden states.
* `n_layer`: Number of NGen4 transformer blocks.
* `n_head`: Number of attention heads.
* `n_inner`: Dimensionality of the intermediate layer in the MLP.
* `activation_function`: Activation function for the MLP.
* Dropout rates: `resid_pdrop`, `embd_pdrop`, `attn_pdrop`.
* `layer_norm_epsilon`: Epsilon for LayerNorm stability.
* RoPE parameters: `use_rope`, `rope_theta`, `rope_scaling`, `rope_pct`.
* `attn_implementation`: Choice of attention backend.
* `tie_word_embeddings`: Whether to tie input and output embeddings.
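As a concrete illustration of how these parameters fit together, the sketch below instantiates a configuration roughly matching the ~50M parameter Wikitext-2 example described later under Training Procedure. The keyword names mirror the parameter list above but should be verified against `NGen4Config` in `ngen4.py`; the `vocab_size`, dropout, and epsilon values are assumptions, not values taken from this card.

```python
from ngen4 import NGen4Config, NGen4ForCausalLM  # ngen4.py must be on the Python path

# Roughly the ~50M parameter Wikitext-2 example configuration.
config = NGen4Config(
    vocab_size=50257,              # assumed: GPT-2 tokenizer vocabulary
    n_positions=512,
    n_embd=512,
    n_layer=8,
    n_head=8,
    n_inner=4 * 512,               # defaults to 4 * n_embd
    activation_function="gelu_new",
    resid_pdrop=0.1,               # assumed dropout rates
    embd_pdrop=0.1,
    attn_pdrop=0.1,
    layer_norm_epsilon=1e-5,       # assumed
    use_rope=True,
    rope_theta=10000.0,
    rope_pct=1.0,
    rope_scaling=None,
    attn_implementation="eager",   # or "sdpa" / "flash_attention_2"
    tie_word_embeddings=True,
)

model = NGen4ForCausalLM(config)
print(f"Parameters: {sum(p.numel() for p in model.parameters()) / 1e6:.1f}M")
```

The printed parameter count depends on the assumed `vocab_size`; with tied embeddings the LM head shares the token embedding weights and adds no additional parameters.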
### 7. Gradient Checkpointing

The model supports gradient checkpointing, which reduces memory usage during training by recomputing activations in the backward pass, trading a small amount of extra compute for lower VRAM usage. This is enabled through the Hugging Face `Trainer` or by directly setting the `gradient_checkpointing` attribute on the `NGen4Model` instance.

## Training Data

*(This section should be updated based on the specific model instance)*

* **Initial Pre-training (Example for a 50M parameter model):**
  * **Dataset:** Wikitext-2 (`wikitext`, `wikitext-2-raw-v1` configuration).
  * **Preprocessing:** Text was tokenized using a GPT-2 tokenizer and then grouped into blocks of the training sequence length (e.g., 512 tokens).
* **Planned Further Pre-training:**
  * The model architecture is designed for larger datasets. Future work includes further pre-training on more extensive and diverse corpora, such as subsets of RedPajama (e.g., the arXiv subset), to enhance general language understanding and generation capabilities.

## Training Procedure

*(This section should be updated based on the specific model instance)*

* **Framework:** Trained using PyTorch and the Hugging Face `Trainer` API (a sketch of such a setup follows this list).
* **Example Configuration (for Wikitext-2, ~50M model):**
  * `n_embd`: 512
  * `n_layer`: 8
  * `n_head`: 8
  * `n_positions`: 512 (also `block_size` for training)
  * `attn_implementation`: "eager" (can be set to "flash_attention_2" or "sdpa" if supported)
* **Optimizer:** AdamW.
* **Learning Rate:** E.g., 5e-5 with a linear scheduler and warmup.
* **Batch Size:** Effective batch size achieved through `per_device_train_batch_size` and `gradient_accumulation_steps`.
* **Epochs:** E.g., 3 epochs for Wikitext-2.
* **Mixed Precision (FP16):** Enabled to reduce memory usage and potentially speed up training.
* **Gradient Checkpointing:** Utilized during training.
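For orientation, here is a hedged sketch of what a `Trainer` run along these lines could look like, reusing the `model` built in the configuration sketch earlier. The data preparation follows the Training Data description (GPT-2 tokenizer, 512-token blocks); the batch size, warmup, and output directory values are illustrative placeholders rather than settings confirmed by this card.

```python
from datasets import load_dataset
from transformers import AutoTokenizer, Trainer, TrainingArguments

block_size = 512

# Wikitext-2 with a GPT-2 tokenizer, as described under Training Data.
raw = load_dataset("wikitext", "wikitext-2-raw-v1")
tokenizer = AutoTokenizer.from_pretrained("gpt2")

def tokenize(batch):
    return tokenizer(batch["text"])

def group_texts(examples):
    # Concatenate all token sequences, then split into fixed-size blocks with labels = inputs.
    concatenated = {k: sum(examples[k], []) for k in examples.keys()}
    total = (len(concatenated["input_ids"]) // block_size) * block_size
    result = {
        k: [t[i : i + block_size] for i in range(0, total, block_size)]
        for k, t in concatenated.items()
    }
    result["labels"] = result["input_ids"].copy()
    return result

tokenized = raw.map(tokenize, batched=True, remove_columns=raw["train"].column_names)
lm_datasets = tokenized.map(group_texts, batched=True)

training_args = TrainingArguments(
    output_dir="./ngen4_wikitext2_50M",    # illustrative output directory
    num_train_epochs=3,
    per_device_train_batch_size=8,         # illustrative; adjust to your hardware
    gradient_accumulation_steps=4,         # effective batch = 8 * 4 per device
    learning_rate=5e-5,
    lr_scheduler_type="linear",
    warmup_steps=500,                      # illustrative warmup
    fp16=True,                             # mixed precision (requires a CUDA device)
    gradient_checkpointing=True,           # trade extra compute for lower VRAM
    logging_steps=100,
)

trainer = Trainer(
    model=model,                           # NGen4ForCausalLM instance (see config sketch above)
    args=training_args,
    train_dataset=lm_datasets["train"],
    eval_dataset=lm_datasets["validation"],
)
trainer.train()
```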
## Evaluation Results

*(This section should be updated based on the specific model instance. Provide perplexity scores, loss, etc., on relevant evaluation datasets.)*

* **Example (Wikitext-2, ~50M model):**
  * Validation Perplexity: `[TODO: Add perplexity, e.g., 1.7781]`
  * Validation Loss: `[TODO: Add eval loss, e.g., 0.5755]`

## How to Use

### Prerequisites

Ensure you have PyTorch, Transformers, and the `ngen4.py` model definition file.

```bash
pip install torch transformers
```

If using Flash Attention 2, install it separately:

```bash
pip install flash-attn --no-build-isolation
```

### Loading the Model and Tokenizer

```python
import torch
from transformers import AutoTokenizer
from ngen4 import NGen4ForCausalLM, NGen4Config  # Ensure ngen4.py is in your Python path

model_path = "[path_to_your_saved_ngen4_model_directory]"  # e.g., "./ngen4_wikitext2_50M_tied_v2"

# Load tokenizer
tokenizer = AutoTokenizer.from_pretrained(model_path)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token  # Common practice for GPT-like models

# Load model
# The NGen4Config will be loaded automatically from the model_path if config.json exists.
# If you need to override config parameters, you can load NGen4Config first, modify it,
# and then pass it to from_pretrained.
model = NGen4ForCausalLM.from_pretrained(model_path)

# Set device (e.g., CUDA if available)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
model.eval()  # Set to evaluation mode for inference
```

### Text Generation Example

```python
prompt = "Once upon a time in a land far away"
input_ids = tokenizer.encode(prompt, return_tensors="pt").to(device)
attention_mask = torch.ones_like(input_ids)  # Explicitly create attention mask

with torch.no_grad():
    output_sequences = model.generate(
        input_ids=input_ids,
        attention_mask=attention_mask,
        max_length=150,
        temperature=0.7,
        top_k=50,
        top_p=0.95,
        do_sample=True,
        num_return_sequences=1,
        pad_token_id=tokenizer.eos_token_id,
    )

generated_text = tokenizer.decode(output_sequences[0], skip_special_tokens=True)
print(generated_text)
```

*(For more detailed sampling options, see the `sample_ngen4.py` script if provided in the model repository.)*

## Intended Use & Limitations

* **Intended Use:** This model is intended for text generation, capable of tasks like completing prompts, creative writing, and potentially (with further fine-tuning) summarization, question answering, etc. The base pre-trained model is primarily for next-token prediction based on its training data.
* **Limitations:**
  * The model's knowledge is limited to its training data.
  * It may generate factually incorrect, biased, or nonsensical text.
  * Performance on specific downstream tasks will heavily depend on the quality and nature of further fine-tuning.
  * The current ~50M parameter model trained on Wikitext-2 is a demonstration model and will have limited capabilities compared to larger models trained on more diverse datasets.

## Future Work

* **Further Pre-training:** Scale up pre-training using larger and more diverse datasets such as subsets of RedPajama (e.g., arXiv, Books, CommonCrawl).
* **Instruction Fine-tuning:** Fine-tune the pre-trained NGen4 model on instruction-following datasets (e.g., Alpaca, Dolly, OpenOrca) to improve its ability to follow commands and engage in conversational interactions.
* **RoPE Scaling:** Implement and evaluate RoPE scaling techniques (NTK-aware, YaRN) to further enhance long-context performance.
* **Evaluation:** Conduct comprehensive evaluations on a wider range of downstream NLP benchmarks.

---

*Model architecture and training scripts developed by TNSA AI.*