---
license: apache-2.0
tags:
- code-generation
- t5
- lora
- peft
- transformers
library_name: peft
base_model: t5-small
datasets:
- nvidia/OpenCodeReasoning
model-index:
- name: T5-Small with LoRA on OpenCodeReasoning
  results:
  - task:
      type: text2text-generation
      name: Code Generation
    dataset:
      name: OpenCodeReasoning
      type: nvidia/OpenCodeReasoning
    metrics:
    - name: Loss
      type: loss
      value: 4.69
---

# T5-Small with LoRA on OpenCodeReasoning

This is a LoRA fine-tuned version of `t5-small`, trained with [PEFT](https://github.com/huggingface/peft) on a subset of NVIDIA's [OpenCodeReasoning](https://huggingface.co/datasets/nvidia/OpenCodeReasoning) dataset. An improved version will be uploaded soon.
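
For context on how such an adapter is attached, here is a minimal PEFT setup sketch. The rank, alpha, dropout, and target modules below are illustrative assumptions, not the published configuration of this checkpoint (the actual values live in its `adapter_config.json`):

```python
from transformers import AutoModelForSeq2SeqLM
from peft import LoraConfig, TaskType, get_peft_model

base = AutoModelForSeq2SeqLM.from_pretrained("t5-small")

# Illustrative LoRA hyperparameters; not the exact values used for this checkpoint.
lora_config = LoraConfig(
    task_type=TaskType.SEQ_2_SEQ_LM,
    r=8,
    lora_alpha=16,
    lora_dropout=0.1,
    target_modules=["q", "v"],  # T5 attention query/value projections
)

model = get_peft_model(base, lora_config)
model.print_trainable_parameters()  # only the low-rank adapter matrices train
```

Because only the adapter matrices are saved, the resulting checkpoint is far smaller than a full copy of `t5-small`.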

## Loss Curve

| Step | Train Loss | Val Loss |
|------|------------|----------|
| 50   | 8.63       | 8.17     |
| 100  | 6.04       | 5.35     |
| 150  | 5.31       | 4.90     |
| 200  | 5.19       | 4.71     |
| 250  | 4.94       | 4.59     |
| 300  | 4.95       | 4.51     |
| 350  | 4.79       | 4.46     |
| 400  | 4.89       | 4.42     |
| 450  | 4.69       | 4.40     |

Final Train Loss: **4.69**

Final Eval Loss: **4.40**

## Notes

- Trained on a subset of OpenCodeReasoning due to Colab memory limits.
- Load the adapter with `PeftModel` on top of the `t5-small` base, as shown under Example Usage; a sketch for merging the adapter into the base weights follows this list.
- Metric: loss only (BLEU was skipped because of the dataset's output structure).
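
If you would rather ship a standalone checkpoint with no PEFT dependency at inference time, the adapter can be folded into the base weights first. A minimal sketch (the output directory name is arbitrary):

```python
from transformers import AutoModelForSeq2SeqLM
from peft import PeftModel

base = AutoModelForSeq2SeqLM.from_pretrained("t5-small")
model = PeftModel.from_pretrained(base, "ShahzebKhoso/t5-small-opencode-lora")

# Fold the low-rank updates into the base weights and drop the PEFT wrappers.
merged = model.merge_and_unload()
merged.save_pretrained("t5-small-opencode-merged")  # arbitrary output path
```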

## License

Apache 2.0

## Example Usage

```python
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
from peft import PeftModel, PeftConfig

# Resolve the base model from the adapter config, then attach the LoRA weights.
config = PeftConfig.from_pretrained("ShahzebKhoso/t5-small-opencode-lora")
base_model = AutoModelForSeq2SeqLM.from_pretrained(config.base_model_name_or_path)
model = PeftModel.from_pretrained(base_model, "ShahzebKhoso/t5-small-opencode-lora")
tokenizer = AutoTokenizer.from_pretrained("ShahzebKhoso/t5-small-opencode-lora")

# Prompt with the "generate code:" task prefix, matching the training format.
inputs = tokenizer("generate code: write a function to reverse a string", return_tensors="pt")
outputs = model.generate(**inputs, max_new_tokens=128)  # the default generation length is too short for code
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```