# CodeT5-small Terminal Describer ONNX
This repository contains the ONNX (FP32 and INT8 quantized) versions of the fine-tuned CodeT5-small model for terminal command description. The base PyTorch model was trained on a combined dataset derived from NL2Bash, TLDR Pages, and NL2SH-ALFA.
For details on the training process, evaluation results, and performance metrics of the PyTorch model, please refer to the main model repository: [Mitchins/codet5-small-terminal-describer](https://huggingface.co/Mitchins/codet5-small-terminal-describer).
## Model Structure
This repository is structured to provide both FP32 and INT8 quantized ONNX models, along with all necessary tokenizer and configuration files in the root for easy loading.
- Root directory: contains `config.json`, `generation_config.json`, the tokenizer files (`vocab.json`, `merges.txt`, `tokenizer_config.json`, `special_tokens_map.json`, `added_tokens.json`, `spiece.model`), and this `README.md`.
- `fp32/` directory: contains the FP32 ONNX models (`encoder_model.onnx`, `decoder_model.onnx`, `decoder_with_past_model.onnx`).
- `int8/` directory: contains the INT8 quantized ONNX models (`encoder_model.onnx` and `decoder_model.onnx`).
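To fetch all of these files locally, you can use the `huggingface-cli download` command shown in the snippet below, or the equivalent Python API. Here is a minimal sketch using `huggingface_hub` (assuming the package is installed; the local directory name is illustrative):

```python
from huggingface_hub import snapshot_download

# Download the full repository (configs, tokenizer files, fp32/ and int8/ models)
# into a local directory and return its path.
local_dir = snapshot_download(
    repo_id="Mitchins/codet5-small-terminal-describer-ONNX",
    local_dir="./codet5-small-terminal-describer-ONNX",
)
print(f"Model files downloaded to: {local_dir}")
```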
## Usage

### Python Inference Example (ONNX Runtime)
To run inference with `onnxruntime`, use the Python snippet below. It loads the encoder and decoder sessions and runs a greedy decoding loop, switching to the `decoder_with_past` model after the first step so that the key/value cache is reused.
```python
from transformers import AutoTokenizer
import onnxruntime
import numpy as np
import os
# --- Configuration ---
# Path to the directory containing the ONNX models and tokenizer files
# Make sure to download the model files from this repository first.
# Example:
# huggingface-cli download Mitchins/codet5-small-terminal-describer-ONNX --local-dir ./codet5-small-terminal-describer-ONNX
model_dir = "." # Current directory if downloaded locally
# --- Load Tokenizer and Config ---
tokenizer = AutoTokenizer.from_pretrained(model_dir)
# --- Load ONNX Sessions (FP32 example) ---
encoder_session = onnxruntime.InferenceSession(os.path.join(model_dir, 'fp32/encoder_model.onnx'))
decoder_session = onnxruntime.InferenceSession(os.path.join(model_dir, 'fp32/decoder_model.onnx'))
decoder_with_past_session = onnxruntime.InferenceSession(os.path.join(model_dir, 'fp32/decoder_with_past_model.onnx'))
# --- Inference Function ---
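# Note on the ONNX I/O layout assumed below (the standard layout produced by the
# Optimum T5 export): decoder_model.onnx returns
#   [logits,
#    present.0.decoder.key, present.0.decoder.value,
#    present.0.encoder.key, present.0.encoder.value, ...]
# i.e. four cache tensors per layer, while decoder_with_past_model.onnx returns only
#   [logits, present.0.decoder.key, present.0.decoder.value, ...]
# because the cross-attention (encoder) cache never changes during decoding.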
def generate_description_onnx(
    command,
    max_length=50,
    current_encoder_session=encoder_session,
    current_decoder_session=decoder_session,
    current_decoder_with_past_session=decoder_with_past_session,
):
    input_text = f'describe: {command}'
    input_ids = tokenizer(input_text, return_tensors='np').input_ids
    attention_mask = np.ones(input_ids.shape, dtype=np.int64)

    # 1. Encode the input once; the encoder hidden states are reused at every decoding step.
    encoder_outputs = current_encoder_session.run(None, {
        "input_ids": input_ids,
        "attention_mask": attention_mask
    })
    encoder_hidden_states = encoder_outputs[0]

    # 2. Initialize the decoder input. T5-style models use pad_token_id as the decoder start token.
    decoder_input_ids = np.array([[tokenizer.pad_token_id]], dtype=np.int64)
    generated_tokens = []
    past_decoder_key_values = None
    past_encoder_key_values = None
    # 3. Greedy decoding loop.
    for _ in range(max_length):
        if past_decoder_key_values is None:
            # First step: run the full decoder, which also returns the initial KV cache.
            decoder_outputs = current_decoder_session.run(None, {
                "input_ids": decoder_input_ids,
                "encoder_hidden_states": encoder_hidden_states,
                "encoder_attention_mask": attention_mask
            })
            logits = decoder_outputs[0]
            # Collect the present key/value tensors: four per layer (decoder key/value
            # for self-attention, encoder key/value for cross-attention).
            past_decoder_key_values = []
            past_encoder_key_values = []
            for i in range(1, len(decoder_outputs), 4):
                past_decoder_key_values.append(decoder_outputs[i])      # present.X.decoder.key
                past_decoder_key_values.append(decoder_outputs[i + 1])  # present.X.decoder.value
                past_encoder_key_values.append(decoder_outputs[i + 2])  # present.X.encoder.key
                past_encoder_key_values.append(decoder_outputs[i + 3])  # present.X.encoder.value
        else:
            # Subsequent steps: feed only the last generated token plus the cached KV tensors.
            decoder_inputs = {
                "input_ids": decoder_input_ids[:, -1:],
                "encoder_attention_mask": attention_mask  # constant across steps
            }
            # CodeT5-small has 6 decoder layers.
            for i in range(6):
                decoder_inputs[f"past_key_values.{i}.decoder.key"] = past_decoder_key_values[i * 2]
                decoder_inputs[f"past_key_values.{i}.decoder.value"] = past_decoder_key_values[i * 2 + 1]
                decoder_inputs[f"past_key_values.{i}.encoder.key"] = past_encoder_key_values[i * 2]
                decoder_inputs[f"past_key_values.{i}.encoder.value"] = past_encoder_key_values[i * 2 + 1]
            decoder_outputs = current_decoder_with_past_session.run(None, decoder_inputs)
            logits = decoder_outputs[0]
            # The with-past decoder returns only the updated self-attention cache
            # (two tensors per layer); the cross-attention cache is left untouched.
            new_past_decoder_key_values = []
            for i in range(1, len(decoder_outputs), 2):
                new_past_decoder_key_values.append(decoder_outputs[i])      # present.X.decoder.key
                new_past_decoder_key_values.append(decoder_outputs[i + 1])  # present.X.decoder.value
            past_decoder_key_values = new_past_decoder_key_values

        # Greedy selection of the next token; stop on end-of-sequence.
        next_token_logits = logits[:, -1, :]
        next_token = np.argmax(next_token_logits, axis=-1)
        if next_token.item() == tokenizer.eos_token_id:
            break
        generated_tokens.append(next_token.item())
        decoder_input_ids = np.concatenate([decoder_input_ids, next_token.reshape(1, 1)], axis=-1)

    description = tokenizer.decode(generated_tokens, skip_special_tokens=True)
    return description
# --- Example Usage ---
command_input = "ls -l"
description = generate_description_onnx(command_input)
print(f"Command: {command_input}")
print(f"Description: {description}")
# Example with INT8 models (uncomment to use). Note that int8/ contains no
# decoder_with_past_model.onnx, so cached decoding steps fall back to the FP32 session above.
# encoder_session_int8 = onnxruntime.InferenceSession(os.path.join(model_dir, 'int8/encoder_model.onnx'))
# decoder_session_int8 = onnxruntime.InferenceSession(os.path.join(model_dir, 'int8/decoder_model.onnx'))
# description_int8 = generate_description_onnx(command_input, current_encoder_session=encoder_session_int8, current_decoder_session=decoder_session_int8)
# print(f"Description (INT8): {description_int8}")
```