---
license: apache-2.0
base_model:
- BAAI/bge-code-v1
tags:
- bge
- embedding
- code
- onnx
---
This repository contains the ONNX (Open Neural Network Exchange) version of the powerful [BAAI/bge-code-v1](https://huggingface.co/BAAI/bge-code-v1) model, optimized for high-performance inference. The model generates embeddings for code snippets and, thanks to ONNX Runtime, can be used in a wide variety of environments.

## Original Model

This model is a conversion of [BAAI/bge-code-v1](https://huggingface.co/BAAI/bge-code-v1). All credit for the training and architecture goes to the original authors at the Beijing Academy of Artificial Intelligence (BAAI).

## How to Use

Below are examples of how to use this ONNX model in Python, in Rust, and with Docker for a production-ready API.

### Usage with Python

You can use the `onnxruntime`, `huggingface_hub`, and `tokenizers` libraries to download and run this model.

**1. Install dependencies:**

```bash
pip install onnxruntime huggingface_hub tokenizers numpy
```
**2. Python code:**

```python
import os

import numpy as np
import onnxruntime as ort
from huggingface_hub import snapshot_download
from tokenizers import Tokenizer


class BgeCodeOnnx:
    def __init__(self, repo_id="babybirdprd/bge-code-v1-onnx"):
        # Download all files from the Hub and get the local directory path
        snapshot_dir = snapshot_download(repo_id=repo_id)

        # Load the tokenizer and the ONNX session
        model_path = os.path.join(snapshot_dir, "model.onnx")
        tokenizer_path = os.path.join(snapshot_dir, "tokenizer.json")

        self.tokenizer = Tokenizer.from_file(tokenizer_path)
        self.session = ort.InferenceSession(model_path)

        # Get the expected input names from the model
        self.input_names = [inp.name for inp in self.session.get_inputs()]
        print(f"Model initialized. Expects inputs: {self.input_names}")

    def embed(self, sentences):
        # Tokenize the input sentences, padding each batch to a fixed length of 512
        self.tokenizer.enable_padding(pad_id=0, pad_token="[PAD]", length=512)
        encoded_input = self.tokenizer.encode_batch(sentences)

        # Build the input dictionary for ONNX Runtime
        ort_inputs = {}
        if "input_ids" in self.input_names:
            ort_inputs["input_ids"] = np.array([e.ids for e in encoded_input], dtype=np.int64)
        if "attention_mask" in self.input_names:
            ort_inputs["attention_mask"] = np.array([e.attention_mask for e in encoded_input], dtype=np.int64)

        # Run inference
        ort_outputs = self.session.run(None, ort_inputs)
        last_hidden_state = ort_outputs[0]

        # Perform pooling (take the first token's embedding)
        pooled_embeddings = last_hidden_state[:, 0]

        # L2-normalize the embeddings
        norms = np.linalg.norm(pooled_embeddings, axis=1, keepdims=True)
        normalized_embeddings = pooled_embeddings / norms

        return normalized_embeddings


# --- Example Usage ---
if __name__ == "__main__":
    model = BgeCodeOnnx()

    code_snippets = [
        "fn main() { let x = 5; }",
        "struct User { id: u64, name: String }",
    ]

    embeddings = model.embed(code_snippets)

    print("\nEmbeddings generated successfully!")
    for i, snippet in enumerate(code_snippets):
        print(f"\nInput: '{snippet}'")
        print(f"Embedding (first 5 dims): {embeddings[i][:5]}")
```
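Because the returned embeddings are L2-normalized, the cosine similarity of two snippets is simply their dot product. A minimal sketch, reusing the `embeddings` array from the example above:

```python
import numpy as np

# Cosine similarity of L2-normalized vectors reduces to a plain dot product.
similarity = float(np.dot(embeddings[0], embeddings[1]))
print(f"Cosine similarity between the two snippets: {similarity:.4f}")
```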
### Usage with Rust

You can use the [`ort`](https://crates.io/crates/ort) crate to run this model natively for high performance.

**1. Add dependencies to `Cargo.toml`:**

```toml
[dependencies]
anyhow = "1.0"
ndarray = "0.15"
# ort 2.0 is still a release candidate; its API changes between RCs,
# so pin the exact version the code below was written against.
ort = "2.0.0-rc.2"
tokenizers = "0.19.1"
hf-hub = "0.3.3"
```
**2. Rust code (`src/main.rs`):**

The example below targets the `ort` 2.0 release-candidate API (`Session::builder`, `commit_from_file`, `try_extract_tensor`); if you pin a different `ort` version, adjust the session and tensor calls accordingly.

```rust
use anyhow::Result;
use hf_hub::api::sync::Api;
use ndarray::{s, Array2, ArrayView2, Axis, Ix3};
use ort::Session;
use tokenizers::Tokenizer;

/// L2-normalize each row of a 2-D array.
fn normalize_l2(x: ArrayView2<f32>) -> Array2<f32> {
    let norms = x.map_axis(Axis(1), |row| row.dot(&row).sqrt());
    &x / &norms.insert_axis(Axis(1))
}

fn main() -> Result<()> {
    println!("Downloading model files...");
    let api = Api::new()?;
    let repo = api.model("babybirdprd/bge-code-v1-onnx".to_string());
    let model_path = repo.get("model.onnx")?;
    let tokenizer_path = repo.get("tokenizer.json")?;

    println!("\nLoading tokenizer and creating ORT session...");
    let mut tokenizer = Tokenizer::from_file(&tokenizer_path).unwrap();
    let session = Session::builder()?.commit_from_file(&model_path)?;
    println!("✅ Session created successfully.");

    let sentences = vec![
        "fn main() { let x = 5; }",
        "struct User { id: u64, name: String }",
    ];

    // Pad every sequence in the batch to the longest one
    tokenizer.with_padding(Some(tokenizers::PaddingParams {
        strategy: tokenizers::PaddingStrategy::BatchLongest,
        ..Default::default()
    }));
    let tokenized_input = tokenizer.encode_batch(sentences.clone(), true).unwrap();

    let batch_size = sentences.len();
    let sequence_length = tokenized_input[0].get_ids().len();

    // Flatten the token ids and attention masks into (batch, seq_len) arrays of i64
    let input_ids: Vec<i64> = tokenized_input
        .iter()
        .flat_map(|enc| enc.get_ids().iter().map(|&id| id as i64))
        .collect();
    let attention_mask: Vec<i64> = tokenized_input
        .iter()
        .flat_map(|enc| enc.get_attention_mask().iter().map(|&m| m as i64))
        .collect();
    let input_ids = Array2::from_shape_vec((batch_size, sequence_length), input_ids)?;
    let attention_mask = Array2::from_shape_vec((batch_size, sequence_length), attention_mask)?;

    println!("\nRunning inference...");
    // This model only requires 'input_ids' and 'attention_mask'
    let outputs = session.run(ort::inputs![
        "input_ids" => input_ids.view(),
        "attention_mask" => attention_mask.view(),
    ]?)?;

    println!("Post-processing embeddings...");
    // The first output is the last hidden state: (batch, seq_len, hidden)
    let last_hidden_state = outputs[0]
        .try_extract_tensor::<f32>()?
        .into_dimensionality::<Ix3>()?;

    // Pool by taking the first token's embedding, then L2-normalize
    let pooled_embeddings = last_hidden_state.slice(s![.., 0, ..]);
    let normalized_embeddings = normalize_l2(pooled_embeddings);

    println!("\n🎉 Rust ORT Test Complete! 🎉");
    for (i, sentence) in sentences.iter().enumerate() {
        let embedding_slice = normalized_embeddings.slice(s![i, 0..5]).to_vec();
        println!("\nInput: '{}'", sentence);
        println!("Embedding (first 5 dims): {:?}", embedding_slice);
    }
    Ok(())
}
```
### Usage with Docker (text-embeddings-inference)

You can serve this model as a high-throughput, OpenAI-compatible API using the [text-embeddings-inference](https://github.com/huggingface/text-embeddings-inference) container.

**1. Run the Docker container:**

The command below downloads the Docker image and the model, then starts the server.

```bash
# Define a volume to store the model
export MODEL_DIR=$HOME/bge-code-onnx-model

# Run the container, mounting the volume
docker run --pull always -p 8080:80 -v $MODEL_DIR:/data \
  ghcr.io/huggingface/text-embeddings-inference:main-onnx \
  --model-id babybirdprd/bge-code-v1-onnx
```

**2. Call the API:**

Once the server is running, you can call it with any HTTP client:

```bash
curl http://localhost:8080/embed \
  -X POST \
  -H "Content-Type: application/json" \
  -d '{"inputs": "fn main() { let message = \"Hello, ONNX!\"; }"}'
```
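The same endpoint can also be called from Python. A minimal sketch using the `requests` library, mirroring the curl call above (it assumes the container is running locally on port 8080):

```python
import requests

# Send one code snippet to the /embed endpoint of the running container
response = requests.post(
    "http://localhost:8080/embed",
    json={"inputs": 'fn main() { let message = "Hello, ONNX!"; }'},
    timeout=30,
)
response.raise_for_status()

# The server returns a list of embedding vectors, one per input
embeddings = response.json()
print(f"Got {len(embeddings)} embedding(s); first has {len(embeddings[0])} dimensions")
```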