Upload folder using huggingface_hub
This view is limited to 50 files because it contains too many changes.
- .gitattributes +8 -34
- .gitignore +175 -0
- CMD.md +91 -0
- README.md +60 -0
- hf_upload.ipynb +358 -0
- old_scripts/convert_for_unity.py +1024 -0
- old_scripts/convert_single_model.py +492 -0
- old_scripts/convert_to_onnx.py +261 -0
- old_scripts/test_chat.py +402 -0
- onnx_models/bloom_onnx/config.json +32 -0
- onnx_models/bloom_onnx/generation_config.json +7 -0
- onnx_models/bloom_onnx/model.onnx +3 -0
- onnx_models/bloom_onnx/special_tokens_map.json +30 -0
- onnx_models/bloom_onnx/tokenizer.json +3 -0
- onnx_models/bloom_onnx/tokenizer_config.json +48 -0
- onnx_models/bloom_onnx_quantized/config.json +32 -0
- onnx_models/bloom_onnx_quantized/model_quantized.onnx +3 -0
- onnx_models/bloom_onnx_quantized/ort_config.json +33 -0
- onnx_models/bloom_onnx_quantized/special_tokens_map.json +30 -0
- onnx_models/bloom_onnx_quantized/tokenizer.json +3 -0
- onnx_models/bloom_onnx_quantized/tokenizer_config.json +48 -0
- onnx_models/falcon_onnx/config.json +41 -0
- onnx_models/falcon_onnx/generation_config.json +6 -0
- onnx_models/falcon_onnx/merges.txt +0 -0
- onnx_models/falcon_onnx/model.onnx +3 -0
- onnx_models/falcon_onnx/special_tokens_map.json +23 -0
- onnx_models/falcon_onnx/tokenizer.json +0 -0
- onnx_models/falcon_onnx/tokenizer_config.json +20 -0
- onnx_models/falcon_onnx/vocab.json +0 -0
- onnx_models/gpt2_onnx/config.json +41 -0
- onnx_models/gpt2_onnx/generation_config.json +6 -0
- onnx_models/gpt2_onnx/merges.txt +0 -0
- onnx_models/gpt2_onnx/model.onnx +3 -0
- onnx_models/gpt2_onnx/special_tokens_map.json +5 -0
- onnx_models/gpt2_onnx/tokenizer.json +0 -0
- onnx_models/gpt2_onnx/tokenizer_config.json +20 -0
- onnx_models/gpt2_onnx/vocab.json +0 -0
- onnx_models/gpt2_onnx_quantized/config.json +41 -0
- onnx_models/gpt2_onnx_quantized/merges.txt +0 -0
- onnx_models/gpt2_onnx_quantized/model_quantized.onnx +3 -0
- onnx_models/gpt2_onnx_quantized/ort_config.json +33 -0
- onnx_models/gpt2_onnx_quantized/special_tokens_map.json +23 -0
- onnx_models/gpt2_onnx_quantized/tokenizer.json +0 -0
- onnx_models/gpt2_onnx_quantized/tokenizer_config.json +20 -0
- onnx_models/gpt2_onnx_quantized/vocab.json +0 -0
- onnx_models/opt_onnx/config.json +31 -0
- onnx_models/opt_onnx/generation_config.json +7 -0
- onnx_models/opt_onnx/merges.txt +0 -0
- onnx_models/opt_onnx/model.onnx +3 -0
- onnx_models/opt_onnx/special_tokens_map.json +30 -0
.gitattributes
CHANGED
@@ -1,35 +1,9 @@
-*.7z filter=lfs diff=lfs merge=lfs -text
-*.arrow filter=lfs diff=lfs merge=lfs -text
-*.bin filter=lfs diff=lfs merge=lfs -text
-*.bz2 filter=lfs diff=lfs merge=lfs -text
-*.ckpt filter=lfs diff=lfs merge=lfs -text
-*.ftz filter=lfs diff=lfs merge=lfs -text
-*.gz filter=lfs diff=lfs merge=lfs -text
-*.h5 filter=lfs diff=lfs merge=lfs -text
-*.joblib filter=lfs diff=lfs merge=lfs -text
-*.lfs.* filter=lfs diff=lfs merge=lfs -text
-*.mlmodel filter=lfs diff=lfs merge=lfs -text
-*.model filter=lfs diff=lfs merge=lfs -text
-*.msgpack filter=lfs diff=lfs merge=lfs -text
-*.npy filter=lfs diff=lfs merge=lfs -text
-*.npz filter=lfs diff=lfs merge=lfs -text
 *.onnx filter=lfs diff=lfs merge=lfs -text
-*.
-
-
-
-
-
-
-
-*.safetensors filter=lfs diff=lfs merge=lfs -text
-saved_model/**/* filter=lfs diff=lfs merge=lfs -text
-*.tar.* filter=lfs diff=lfs merge=lfs -text
-*.tar filter=lfs diff=lfs merge=lfs -text
-*.tflite filter=lfs diff=lfs merge=lfs -text
-*.tgz filter=lfs diff=lfs merge=lfs -text
-*.wasm filter=lfs diff=lfs merge=lfs -text
-*.xz filter=lfs diff=lfs merge=lfs -text
-*.zip filter=lfs diff=lfs merge=lfs -text
-*.zst filter=lfs diff=lfs merge=lfs -text
-*tfevents* filter=lfs diff=lfs merge=lfs -text
+*.onnx_data filter=lfs diff=lfs merge=lfs -text
+
+onnx_models/bloom_onnx/tokenizer.json filter=lfs diff=lfs merge=lfs -text
+onnx_models/bloom_onnx_quantized/tokenizer.json filter=lfs diff=lfs merge=lfs -text
+onnx_models/qwen_onnx/tokenizer.json filter=lfs diff=lfs merge=lfs -text
+onnx_models/qwen_onnx_quantized/tokenizer.json filter=lfs diff=lfs merge=lfs -text
+onnx_models/tinyllama_onnx/tokenizer.model filter=lfs diff=lfs merge=lfs -text
+onnx_models/tinyllama_onnx_quantized/tokenizer.model filter=lfs diff=lfs merge=lfs -text
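The added rules exist because the Hub stores large files (beyond roughly 10 MB) through Git LFS, and the new tokenizer.json / tokenizer.model files cross that limit. The snippet below is an editor-added illustration, not part of the repository: it lists local files over 10 MB so each one can be checked against a .gitattributes pattern before pushing; the scanned folder path is an assumption.

```python
# Illustrative helper (not part of this repo): find files that will need an LFS
# rule before upload_folder / git push. The scanned path is an assumption.
from pathlib import Path

LFS_THRESHOLD = 10 * 1024 * 1024  # ~10 MB; larger files should go through Git LFS

for path in sorted(Path("onnx_models").rglob("*")):
    if path.is_file() and path.stat().st_size > LFS_THRESHOLD:
        print(f"{path} ({path.stat().st_size / 1e6:.1f} MB) -> add a .gitattributes pattern")
```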
.gitignore
ADDED
@@ -0,0 +1,175 @@
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
.pybuilder/
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version

# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock

# UV
# Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
#uv.lock

# poetry
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
#poetry.lock

# pdm
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
#pdm.lock
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
# in version control.
# https://pdm.fming.dev/latest/usage/project/#working-with-version-control
.pdm.toml
.pdm-python
.pdm-build/

# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

# pytype static type analyzer
.pytype/

# Cython debug symbols
cython_debug/

# PyCharm
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/

# Ruff stuff:
.ruff_cache/

# PyPI configuration file
.pypirc
*.onnx_data
CMD.md
ADDED
@@ -0,0 +1,91 @@
### Qwen-0.5B Model
``` bash
# Step 1: Export for text generation with past KV cache (better for chat)
echo "Exporting Qwen-0.5B..."
optimum-cli export onnx --model Qwen/Qwen1.5-0.5B --task text-generation-with-past onnx_models/qwen_onnx/

# Step 2: Quantize for ARM64 (mobile target) using static INT8 quantization
echo "Quantizing Qwen-0.5B for ARM64 (Static)..."
optimum-cli onnxruntime quantize --onnx_model onnx_models/qwen_onnx/ --arm64 -o onnx_models/qwen_onnx_quantized/
```

-----------------------------------

### TinyLlama-1.1B
``` bash
# Step 1: Export for text generation with past KV cache (better for chat)
echo "Exporting TinyLlama-1.1B..."
optimum-cli export onnx --model TinyLlama/TinyLlama-1.1B-Chat-v1.0 --task text-generation-with-past onnx_models/tinyllama_onnx/

# Step 2: Attempt quantization for ARM64 (static INT8)
# This step can take a very long time or fail; try it, but keep the non-quantized export as a fallback.
echo "Attempting TinyLlama-1.1B quantization for ARM64 (Static)..."
optimum-cli onnxruntime quantize --onnx_model onnx_models/tinyllama_onnx/ --arm64 -o onnx_models/tinyllama_onnx_quantized/
```

-----------------------------------

### Phi-1.5 Model
``` bash
# Step 1: Export for text generation with past KV cache (better for chat)
echo "Exporting Phi-1.5..."
optimum-cli export onnx --model microsoft/phi-1_5 --task text-generation-with-past onnx_models/phi_onnx/

# Step 2: Attempt quantization for ARM64 (static INT8) -- failed for me (requires a lot of memory)
echo "Quantizing Phi-1.5 for ARM64 (Static)..."
optimum-cli onnxruntime quantize --onnx_model onnx_models/phi_onnx/ --arm64 -o onnx_models/phi_onnx_quantized/
```

-----------------------------------

### Falcon-1B Model
``` bash
# Export
echo "Exporting Falcon-1B..."
optimum-cli export onnx --model tiiuae/falcon-rw-1b --task text-generation-with-past onnx_models/falcon_onnx/

# Quantize for ARM64 -- failed for me (requires a lot of memory)
echo "Quantizing Falcon-1B for ARM64..."
optimum-cli onnxruntime quantize --onnx_model onnx_models/falcon_onnx/ --arm64 -o onnx_models/falcon_onnx_quantized/
```

-----------------------------------

### GPT2-Medium Model
``` bash
# Export GPT2-Medium
echo "Exporting GPT2-Medium..."
optimum-cli export onnx --model gpt2-medium --task text-generation-with-past onnx_models/gpt2_onnx/

# Quantize for ARM64
echo "Quantizing GPT2-Medium for ARM64..."
optimum-cli onnxruntime quantize --onnx_model onnx_models/gpt2_onnx/ --arm64 -o onnx_models/gpt2_onnx_quantized/
```

-----------------------------------

### OPT-350M Model
``` bash
# Export OPT-350M
echo "Exporting OPT-350M..."
optimum-cli export onnx --model facebook/opt-350m --task text-generation-with-past onnx_models/opt_onnx/

# Quantize for ARM64
echo "Quantizing OPT-350M for ARM64..."
optimum-cli onnxruntime quantize --onnx_model onnx_models/opt_onnx/ --arm64 -o onnx_models/opt_onnx_quantized/
```

-----------------------------------

### Bloom-560M Model
``` bash
# Export Bloom-560M
echo "Exporting Bloom-560M..."
optimum-cli export onnx --model bigscience/bloom-560m --task text-generation-with-past onnx_models/bloom_onnx/

# Quantize for ARM64
echo "Quantizing Bloom-560M for ARM64..."
optimum-cli onnxruntime quantize --onnx_model onnx_models/bloom_onnx/ --arm64 -o onnx_models/bloom_onnx_quantized/
```

-----------------------------------
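For reference, the same export can also be driven from Python instead of the CLI. The sketch below is an editor-added illustration using `optimum`'s `ORTModelForCausalLM` (the model ID and output directory mirror the Qwen commands above); it is not part of the repository's scripts, and exact argument behavior may vary across optimum versions.

```python
# Sketch: Python equivalent of `optimum-cli export onnx --task text-generation-with-past`,
# plus a quick generation sanity check. Paths and IDs follow the Qwen commands above.
from transformers import AutoTokenizer
from optimum.onnxruntime import ORTModelForCausalLM

model_id = "Qwen/Qwen1.5-0.5B"
out_dir = "onnx_models/qwen_onnx"

model = ORTModelForCausalLM.from_pretrained(model_id, export=True)  # convert to ONNX on load
tokenizer = AutoTokenizer.from_pretrained(model_id)
model.save_pretrained(out_dir)
tokenizer.save_pretrained(out_dir)

# Confirm the exported model still generates text
inputs = tokenizer("User: Hello!\nAssistant:", return_tensors="pt")
output_ids = model.generate(**inputs, max_new_tokens=30)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))
```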
README.md
ADDED
@@ -0,0 +1,60 @@
# 🚀 LLM to ONNX Converter
> Convert small language models to ONNX format with **guaranteed reliability** for RAG and chatbot applications on resource-constrained hardware.

## 📋 Overview
This repository provides scripts to convert small language models to ONNX format and create INT8 quantized versions for efficient deployment on resource-constrained devices. Perfect for mobile applications, Unity game engines, and embedded systems.

## ✅ Tested Models
We've successfully tested the following models with example outputs:

| Model | Size | Quantized | Response Quality | Speed (sec) |
|-------|------|-----------|------------------|-------------|
| Qwen-0.5B | 500M | ✅ | ❌ Poor | 8.37 |
| Qwen-0.5B | 500M | ❌ | ✅ Good | 15.69 |
| TinyLlama-1.1B | 1.1B | ✅ | ❌ Poor | 10.15 |
| TinyLlama-1.1B | 1.1B | ❌ | ✅ Good | 19.23 |
| Phi-1.5 | 1.3B | ❌ | ✅ Good | 15.32 |
| Falcon-RW-1B | 1B | ❌ | ✅ Good | 21.56 |
| GPT2-Medium | 355M | ✅ | ✅ Good | 6.27 |
| GPT2-Medium | 355M | ❌ | ✅ Good | 12.77 |
| OPT-350M | 350M | ✅ | ✅ Good | 4.33 |
| OPT-350M | 350M | ❌ | ✅ Good | 10.42 |
| Bloom-560M | 560M | ✅ | ❌ Poor | 11.93 |
| Bloom-560M | 560M | ❌ | ✅ Good | 34.38 |

## 🌟 Recommendations
Based on our testing:
1. **For best speed + quality:** OPT-350M (quantized) - fastest with good quality
2. **For best overall quality:** Phi-1.5 (non-quantized) - excellent responses
3. **For smallest size:** GPT2-Medium or OPT-350M (quantized) - small with good performance

## 🚩 Key Findings
- Quantization provides ~2x speed improvement
- Smaller models (350-500M) quantize better than larger models (1B+)
- Some architectures (OPT, GPT2) handle quantization better than others

## 📁 Repository Structure
```
onnx_models/
├── bloom_onnx/
├── bloom_onnx_quantized/
├── falcon_onnx/
├── gpt2_onnx/
├── gpt2_onnx_quantized/
├── opt_onnx/
├── opt_onnx_quantized/
├── phi_onnx/
├── qwen_onnx/
├── qwen_onnx_quantized/
├── tinyllama_onnx/
└── tinyllama_onnx_quantized/
```

## 📚 Requirements
- Python 3.8+
- optimum
- onnxruntime
- transformers
- numpy

---------------
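The Speed column above was measured informally. A rough, editor-added sketch of how such a comparison could be reproduced is shown below; the directory names follow the repository structure, while the `file_name` argument and prompt are assumptions, and absolute numbers depend on hardware.

```python
# Sketch: time one generation with the FP32 export vs. the quantized export.
import time
from transformers import AutoTokenizer
from optimum.onnxruntime import ORTModelForCausalLM

def time_generation(model_dir: str, file_name: str = "model.onnx") -> float:
    tokenizer = AutoTokenizer.from_pretrained(model_dir)
    model = ORTModelForCausalLM.from_pretrained(model_dir, file_name=file_name)
    inputs = tokenizer("User: What is RAG?\nAssistant:", return_tensors="pt")
    start = time.perf_counter()
    model.generate(**inputs, max_new_tokens=64)
    return time.perf_counter() - start

print("fp32     :", time_generation("onnx_models/opt_onnx"))
print("quantized:", time_generation("onnx_models/opt_onnx_quantized", "model_quantized.onnx"))
```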
hf_upload.ipynb
ADDED
@@ -0,0 +1,358 @@
Cell 1 (code):
    from huggingface_hub import login
    login()
Output: Hugging Face login widget.

Cell 2 (code):
    from huggingface_hub import upload_folder

    # Path to your offline-models directory
    folder_path = "/home/administrator/offline-rag-model/offline-models"

    # Your Hugging Face repository name
    repo_name = "onnx-models"

    # Upload all files to Hugging Face
    upload_folder(
        folder_path=folder_path,
        repo_id=f"agoor97/{repo_name}",
        repo_type="model",
    )
Outputs: a metadata warning ("empty or missing yaml metadata in repo card") and Jupyter upload-progress widgets for 18 LFS files (model.onnx, model_quantized.onnx, tokenizer.json, and tokenizer.model files ranging from roughly 500 KB to 1.86 GB).

Cell 3 (code): empty.

Notebook metadata: Python 3 "base" kernel, Python 3.12.9, nbformat 4.5.
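The notebook only covers the upload direction. To pull a single converted model back down on another machine, something like the following should work; this is an editor-added sketch in which the repo ID comes from the cell above, while the folder pattern and local directory are assumptions.

```python
# Sketch: download just one model folder from the uploaded repository.
from huggingface_hub import snapshot_download

local_path = snapshot_download(
    repo_id="agoor97/onnx-models",
    allow_patterns=["onnx_models/gpt2_onnx_quantized/*"],  # one model folder only
    local_dir="./offline-models",
)
print("Downloaded to:", local_path)
```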
old_scripts/convert_for_unity.py
ADDED
@@ -0,0 +1,1024 @@
import os
import gc
import sys
import time
import logging
import traceback
import torch
import warnings
import numpy as np
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.generation import GenerationConfig
from tqdm import tqdm
from onnxruntime.quantization import quantize_dynamic, QuantType

# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    datefmt='%Y-%m-%d %H:%M:%S'
)
logger = logging.getLogger(__name__)

# Suppress unhelpful warnings
warnings.filterwarnings("ignore", category=UserWarning)


class GenerationWrapper(torch.nn.Module):
    """
    Wrapper for model export that handles generation properly.
    This ensures the model can be correctly used for text generation.
    """
    def __init__(self, model):
        super().__init__()
        self.model = model
        self.config = model.config

    def forward(self, input_ids, attention_mask=None):
        # Return only the logits to avoid complex structures
        with torch.no_grad():
            try:
                # Standard approach for most models
                outputs = self.model(
                    input_ids=input_ids,
                    attention_mask=attention_mask,
                    use_cache=False,
                    return_dict=True
                )
                return outputs.logits
            except Exception as e:
                logger.warning(f"Standard forward pass failed, trying fallback: {str(e)}")
                # Fallback for models with different API
                outputs = self.model(input_ids=input_ids, attention_mask=attention_mask)
                if hasattr(outputs, 'logits'):
                    return outputs.logits
                elif isinstance(outputs, tuple) and len(outputs) > 0:
                    return outputs[0]  # First element is typically logits
                else:
                    raise ValueError("Could not extract logits from model outputs")

def verify_model_generation(model, tokenizer, device="cpu"):
    """Test model generation capabilities before export"""
    model.eval()

    # Use a chat-like prompt for better testing
    prompt = "User: Hello, how are you today?\nAssistant:"

    logger.info("Testing model generation...")
    inputs = tokenizer(prompt, return_tensors="pt").to(device)

    # Configure generation parameters
    gen_config = GenerationConfig(
        max_length=100,
        do_sample=True,
        temperature=0.7,
        num_return_sequences=1,
    )

    try:
        # Try generation
        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                generation_config=gen_config
            )

        generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
        logger.info(f"Test generation result: {generated_text}")

        if len(generated_text) <= len(prompt):
            logger.warning("Generation output is not longer than input prompt!")

        return True
    except Exception as e:
        logger.error(f"Generation test failed: {str(e)}")
        return False

def test_onnx_model(onnx_path, tokenizer):
    """Verify the ONNX model can be loaded and run"""
    try:
        import onnxruntime as ort

        logger.info("Testing ONNX model inference...")
        session = ort.InferenceSession(onnx_path)

        # Get input and output names
        input_names = [input.name for input in session.get_inputs()]
        output_names = [output.name for output in session.get_outputs()]

        # Create test input
        prompt = "User: Hello, how are you?\nAssistant:"
        inputs = tokenizer(prompt, return_tensors="np")

        # Prepare input dict
        onnx_inputs = {}
        for name in input_names:
            if name == "input_ids" and "input_ids" in inputs:
                onnx_inputs[name] = inputs["input_ids"]
            elif name == "attention_mask" and "attention_mask" in inputs:
                onnx_inputs[name] = inputs["attention_mask"]

        # Run inference
        outputs = session.run(output_names, onnx_inputs)

        # Check output shape
        logits = outputs[0]
        logger.info(f"ONNX model output shape: {logits.shape}")

        if logits.shape[0] != 1 or logits.shape[1] != inputs["input_ids"].shape[1]:
            logger.warning("Output shape doesn't match expected dimensions!")

        # Test next token prediction
        next_token_logits = logits[0, -1, :]
        next_token_id = np.argmax(next_token_logits)
        next_token = tokenizer.decode([next_token_id])
        logger.info(f"Next predicted token: '{next_token}'")

        return True
    except Exception as e:
        logger.error(f"ONNX model test failed: {str(e)}")
        return False

def post_process_onnx_for_unity(onnx_path):
    """
    Post-process ONNX model to be compatible with Unity Sentis
    using only core onnx functionality (no onnxsim)
    """
    try:
        import onnx

        logger.info("Post-processing ONNX model for Unity compatibility...")

        # First, create a backup of the original model
        backup_path = onnx_path.replace(".onnx", "_original.onnx")
        import shutil
        shutil.copy(onnx_path, backup_path)
        logger.info(f"Original model backed up to {backup_path}")

        # Load the model
        model = onnx.load(onnx_path)

        # Basic model checks and optimizations
        try:
            # Check model validity
            onnx.checker.check_model(model)
            logger.info("✓ Model structure validated successfully")

            # Apply shape inference
            inferred_model = onnx.shape_inference.infer_shapes(model)
            onnx.save(inferred_model, onnx_path)
            logger.info("✓ Applied shape inference")

        except Exception as e:
            logger.warning(f"Model validation/optimization error (continuing): {str(e)}")

        return True

    except Exception as e:
        logger.warning(f"ONNX post-processing error (skipping): {str(e)}")
        return False

def is_architecture_compatible(model_id):
    """
    Check if the model architecture is expected to be compatible with ONNX opset 11
    """
    model_id_lower = model_id.lower()

    # Models known to work with opset 11
    compatible_architectures = [
        "gpt2", "distilgpt2", "opt-125m", "opt-350m",
        "pythia-70m", "pythia-160m", "rwkv", "gpt-neo"
    ]

    # Models likely requiring higher opsets (usually 14+)
    incompatible_architectures = [
        "llama", "mistral", "mixtral", "tinyllama", "phi-2",
        "gemma", "falcon", "bloom"
    ]

    # Check for compatibility
    for arch in compatible_architectures:
        if arch in model_id_lower:
            return True, 11

    # Check for known incompatible architectures
    for arch in incompatible_architectures:
        if arch in model_id_lower:
            return False, 14

    # For phi-1 models, use opset 14 but mark as potentially compatible
    if "phi-1" in model_id_lower:
        return True, 14

    # Default to opset 14 for unknown architectures
    return False, 14

def setup_chat_template(model_id, tokenizer):
    """
    Setup appropriate chat template based on model architecture
    """
    model_id_lower = model_id.lower()

    # Try to setup chat template if it doesn't have one
    try:
        if not hasattr(tokenizer, "chat_template") or tokenizer.chat_template is None:
            logger.info("Setting up chat template for improved conversations...")

            # Determine chat template based on model
            if "gpt2" in model_id_lower or "pythia" in model_id_lower or "opt" in model_id_lower:
                # Simple template for base models
                chat_template = "{% for message in messages %}\n{% if message['role'] == 'user' %}\nHuman: {{ message['content'] }}\n{% elif message['role'] == 'assistant' %}\nAI: {{ message['content'] }}\n{% endif %}\n{% endfor %}\n{% if add_generation_prompt %}\nAI: {% endif %}"
                tokenizer.chat_template = chat_template
                logger.info("✓ Added simple Human/AI chat template")

            elif "phi" in model_id_lower:
                # Microsoft Phi models template
                chat_template = "{% for message in messages %}\n{% if message['role'] == 'user' %}\nHuman: {{ message['content'] }}\n{% elif message['role'] == 'assistant' %}\nAssistant: {{ message['content'] }}\n{% endif %}\n{% endfor %}\n{% if add_generation_prompt %}\nAssistant: {% endif %}"
                tokenizer.chat_template = chat_template
                logger.info("✓ Added Phi-style Human/Assistant chat template")

            elif "rwkv" in model_id_lower:
                # RWKV template
                chat_template = "{% for message in messages %}\n{% if message['role'] == 'user' %}\nUser: {{ message['content'] }}\n{% elif message['role'] == 'assistant' %}\nBot: {{ message['content'] }}\n{% endif %}\n{% endfor %}\n{% if add_generation_prompt %}\nBot: {% endif %}"
                tokenizer.chat_template = chat_template
                logger.info("✓ Added RWKV-style User/Bot chat template")

    except Exception as e:
        logger.warning(f"Couldn't setup chat template: {str(e)}")
        logger.info("Chat template setup will need to be handled in Unity")

def convert_model(model_id, output_dir="./onnx_models", seq_length=32, quantize=True, force_opset=None):
    """
    Convert a model to ONNX format with focus on Unity compatibility.

    Args:
        model_id: HuggingFace model ID or path
        output_dir: Directory to save the model
        seq_length: Input sequence length for export
        quantize: Whether to quantize the model to INT8
        force_opset: Force a specific ONNX opset version

    Returns:
        bool: Success status
    """
    start_time = time.time()

    # Check model architecture for compatibility
    is_compatible, recommended_opset = is_architecture_compatible(model_id)

    # Use forced opset if provided, otherwise use recommended
    opset_version = force_opset if force_opset is not None else recommended_opset

    # Warn if using a model that might not be compatible with Unity
    if not is_compatible and opset_version < 14:
        logger.warning(f"⚠ Model {model_id} may not be compatible with opset {opset_version}")
        logger.warning(f"⚠ Recommended opset for this model: {recommended_opset}")
        logger.warning(f"⚠ You can force a higher opset with --opset {recommended_opset}")

    logger.info(f"\n{'=' * 60}")
    logger.info(f"Converting {model_id} to ONNX for Unity (opset {opset_version})")
    logger.info(f"{'=' * 60}")

    # Create output directory
    model_name = model_id.split("/")[-1]
    model_dir = os.path.join(output_dir, model_name)
    os.makedirs(model_dir, exist_ok=True)

    try:
        # Step 1: Load tokenizer
        logger.info("Step 1/7: Loading tokenizer...")

        tokenizer = AutoTokenizer.from_pretrained(model_id)
        if tokenizer.pad_token is None and hasattr(tokenizer, 'eos_token'):
            logger.info("Adding pad_token = eos_token")
            tokenizer.pad_token = tokenizer.eos_token

        # Setup chat template for better conversation formatting
        setup_chat_template(model_id, tokenizer)

        # Save tokenizer
        tokenizer.save_pretrained(model_dir)
        logger.info(f"✓ Tokenizer saved to {model_dir}")

        # Step 2: Load model with reliability optimizations
        logger.info("Step 2/7: Loading model...")

        # Clean memory
        gc.collect()
        torch.cuda.empty_cache() if torch.cuda.is_available() else None

        # Determine device
        device = "cuda" if torch.cuda.is_available() else "cpu"

        # Load model with full precision
        try:
            model = AutoModelForCausalLM.from_pretrained(
                model_id,
                torch_dtype=torch.float32,  # Use full precision for reliability
                low_cpu_mem_usage=True,  # Reduce memory usage
                device_map=device  # Use CUDA if available
            )
        except Exception as e:
            logger.warning(f"Standard loading failed, trying with 'trust_remote_code=True': {str(e)}")
            # Some models (like RWKV) need trust_remote_code
            model = AutoModelForCausalLM.from_pretrained(
                model_id,
                torch_dtype=torch.float32,
                low_cpu_mem_usage=True,
                device_map=device,
                trust_remote_code=True
            )

        # Save config
        model.config.save_pretrained(model_dir)
        logger.info(f"✓ Model config saved to {model_dir}")

        # Step 3: Verify model can generate chat responses
        logger.info("Step 3/7: Validating chat capabilities...")

        if not verify_model_generation(model, tokenizer, device):
            logger.warning("⚠ Model chat test didn't complete successfully")
            logger.info("Continuing with export anyway...")

        # Step 4: Export to ONNX
        logger.info(f"Step 4/7: Exporting to ONNX format with opset {opset_version}...")

        # Wrap model with generation-optimized interface
        wrapped_model = GenerationWrapper(model)
        wrapped_model.eval()

        # Clean memory again
        gc.collect()
        torch.cuda.empty_cache() if torch.cuda.is_available() else None

        # Export to ONNX with appropriate opset version
        onnx_path = os.path.join(model_dir, "model.onnx")

        # Create minimal input
        batch_size = 1
        dummy_input = torch.ones(batch_size, seq_length, dtype=torch.long)
        attention_mask = torch.ones(batch_size, seq_length, dtype=torch.long)

        # Move tensors to correct device
        dummy_input = dummy_input.to(device)
        attention_mask = attention_mask.to(device)

        # Export to ONNX with required opset
        with torch.no_grad():
            torch.onnx.export(
                wrapped_model,  # Wrapped model
                (dummy_input, attention_mask),  # Input tensors
                onnx_path,  # Output path
                export_params=True,  # Store weights
                opset_version=opset_version,  # Required opset version
                do_constant_folding=True,  # Optimize constants
                input_names=['input_ids', 'attention_mask'],  # Input names
                output_names=['logits'],  # Output name
                dynamic_axes={  # Dynamic dimensions
                    'input_ids': {0: 'batch_size', 1: 'sequence'},
                    'attention_mask': {0: 'batch_size', 1: 'sequence'},
                    'logits': {0: 'batch_size', 1: 'sequence'}
                }
            )

        # Clean up to save memory
        del model
        del wrapped_model
        gc.collect()
        torch.cuda.empty_cache() if torch.cuda.is_available() else None

        # Verify export success
        if os.path.exists(onnx_path):
            size_mb = os.path.getsize(onnx_path) / (1024 * 1024)
            logger.info(f"✓ ONNX model saved to {onnx_path}")
            logger.info(f"✓ Original size: {size_mb:.2f} MB")

            # Step 5: Post-process the ONNX model for better Unity compatibility
            logger.info("Step 5/7: Post-processing ONNX model for Unity compatibility...")

            # Try to post-process model for Unity
            try:
                post_process_onnx_for_unity(onnx_path)
            except Exception as e:
                logger.warning(f"Post-processing failed (non-critical): {str(e)}")

            # Test ONNX model
            test_onnx_model(onnx_path, tokenizer)

            # Step 6: Quantize the model (optional)
            if quantize:
                logger.info("Step 6/7: Applying INT8 quantization...")
                quant_path = onnx_path.replace(".onnx", "_quantized.onnx")

                try:
                    with tqdm(total=100, desc="Quantizing") as pbar:
                        # Update progress callback
                        def update_progress(x):
                            pbar.update(1)

                        # Apply quantization
                        quantize_dynamic(
                            model_input=onnx_path,
                            model_output=quant_path,
                            per_channel=False,
                            reduce_range=False,
                            weight_type=QuantType.QInt8,
                            optimize_model=True,
                            use_external_data_format=False
                        )

                        pbar.update(100)  # Ensure progress reaches 100%

                    if os.path.exists(quant_path):
                        quant_size = os.path.getsize(quant_path) / (1024 * 1024)
                        logger.info(f"✓ Quantized size: {quant_size:.2f} MB")
                        logger.info(f"✓ Size reduction: {(1 - quant_size/size_mb) * 100:.1f}%")

                        # Test the quantized model
                        test_onnx_model(quant_path, tokenizer)

                        # Rename original as backup
                        backup_path = onnx_path.replace(".onnx", "_fp32.onnx")
                        os.rename(onnx_path, backup_path)

                        # Replace original with quantized
                        os.rename(quant_path, onnx_path)
                        logger.info("✓ Original model preserved as *_fp32.onnx")
                        logger.info("✓ Replaced original with quantized version")
                    else:
                        logger.warning("⚠ Quantized file not created, using original")
                except Exception as e:
                    logger.error(f"⚠ Quantization error: {str(e)}")
                    logger.info("⚠ Using original model without quantization")
            else:
                logger.info("Step 6/7: Skipping quantization as requested")

            # Step 7: Generate Unity integration examples
            logger.info("Step 7/7: Generating Unity integration examples...")

            # Create a Unity integration example
            unity_example_path = os.path.join(model_dir, "unity_integration.cs")
            with open(unity_example_path, 'w') as f:
                f.write("""
using UnityEngine;
using Unity.Sentis;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using System.Threading.Tasks;

public class ONNXChatbot : MonoBehaviour
{
    [SerializeField] private ModelAsset modelAsset;
    [SerializeField] private TextAsset tokenizerVocabJson;
    [SerializeField] private int maxTokens = 50;
    [SerializeField] private float temperature = 0.7f;

    private IWorker worker;
    private Dictionary<string, Tensor> inputs;
    private SimpleTokenizer tokenizer;
    private bool isGenerating = false;

    void Start()
    {
        // Initialize the model
        var model = ModelLoader.Load(modelAsset);
        worker = WorkerFactory.CreateWorker(WorkerFactory.Type.ComputePrecompiled, model);

        // Initialize tokenizer
        tokenizer = new SimpleTokenizer(tokenizerVocabJson.text);

        // Prepare for inference
        inputs = new Dictionary<string, Tensor>();

        Debug.Log("Model and tokenizer initialized successfully.");
    }

    public async Task<string> GenerateResponseAsync(string userMessage)
    {
        if (isGenerating)
        {
            Debug.LogWarning("Already generating a response. Please wait.");
            return "Already generating a response. Please wait.";
        }

        isGenerating = true;

        try
        {
            // Format prompt with chat template
            string prompt = FormatChatPrompt(userMessage);
            Debug.Log($"Formatted prompt: {prompt}");

            // Tokenize input
            var tokenIds = tokenizer.Encode(prompt);
            Debug.Log($"Encoded to {tokenIds.Length} tokens");

            if (tokenIds.Length > 0)
            {
                // Generate response token by token
                StringBuilder responseBuilder = new StringBuilder();
                List<int> currentIds = tokenIds.ToList();

                for (int i = 0; i < maxTokens; i++)
                {
                    // Make sure we don't exceed the model's context window
                    if (currentIds.Count > 1024)
                    {
                        // If too long, keep only the last 1024 tokens
                        currentIds = currentIds.Skip(currentIds.Count - 1024).Take(1024).ToList();
                    }

                    // Create tensors for current sequence
                    using (var inputIdsTensor = new TensorInt(new TensorShape(1, currentIds.Count), currentIds.ToArray()))
                    using (var attentionMaskTensor = new TensorInt(new TensorShape(1, currentIds.Count), Enumerable.Repeat(1, currentIds.Count).ToArray()))
                    {
                        // Run inference
                        inputs.Clear();
                        inputs["input_ids"] = inputIdsTensor;
                        inputs["attention_mask"] = attentionMaskTensor;

                        worker.Execute(inputs);
                        var logits = worker.PeekOutput() as TensorFloat;

                        // Get next token prediction
                        int nextToken = SampleNextToken(logits, currentIds, temperature);

                        // If we hit the end token or a newline after content, stop
                        if (nextToken == tokenizer.EosToken ||
                            (i > 0 && nextToken == tokenizer.NewlineToken))
                        {
                            break;
                        }

                        // Add token to current sequence for next iteration
                        currentIds.Add(nextToken);

                        // Decode the latest token
                        string newToken = tokenizer.Decode(new[] { nextToken });
                        responseBuilder.Append(newToken);

                        // For smoother output, yield every few tokens
                        if (i % 5 == 0)
                        {
                            await Task.Delay(1);
                        }
                    }
                }

                // Return the full response, without the prompt
                string fullResponse = responseBuilder.ToString();
                return CleanResponse(fullResponse);
            }
            else
            {
                Debug.LogError("Tokenization failed: empty token list");
                return "Sorry, I couldn't process that input.";
            }
        }
        catch (System.Exception ex)
        {
            Debug.LogError($"Generation error: {ex.Message}\\n{ex.StackTrace}");
            return "Sorry, an error occurred while generating a response.";
        }
        finally
        {
            isGenerating = false;
        }
    }

    private string FormatChatPrompt(string userMessage)
    {
        // You may need to adjust this template based on your specific model
        return $"User: {userMessage}\\nAssistant:";
    }

    private string CleanResponse(string response)
    {
        // Extract only the Assistant's response
        int assistantPrefix = response.IndexOf("Assistant:");
        if (assistantPrefix >= 0)
        {
            response = response.Substring(assistantPrefix + "Assistant:".Length).Trim();
        }

        // Stop at any "User:" marker if present
        int nextUser = response.IndexOf("User:");
        if (nextUser >= 0)
        {
            response = response.Substring(0, nextUser).Trim();
        }

        return response;
    }

    private int SampleNextToken(TensorFloat logits, List<int> currentInputs, float temp)
    {
        // Get logits for the last position
        int lastPos = currentInputs.Count - 1;
        int vocabSize = logits.shape.channels;

        // Prepare array for logits
        float[] lastLogits = new float[vocabSize];

        // Extract logits for the last token position
        for (int i = 0; i < vocabSize; i++)
        {
            lastLogits[i] = logits[0, lastPos, i];
        }

        // Simple temperature-based sampling
        if (temp <= 0.0f)
        {
            // Greedy sampling (argmax)
            int maxIndex = 0;
            float maxValue = lastLogits[0];

            for (int i = 1; i < vocabSize; i++)
            {
                if (lastLogits[i] > maxValue)
                {
                    maxValue = lastLogits[i];
                    maxIndex = i;
                }
            }

            return maxIndex;
        }
        else
        {
            // Temperature sampling
            // Apply temperature
            for (int i = 0; i < vocabSize; i++)
            {
                lastLogits[i] /= temp;
            }

            // Softmax
            float maxLogit = lastLogits.Max();
            float sum = 0.0f;

            for (int i = 0; i < vocabSize; i++)
            {
                lastLogits[i] = Mathf.Exp(lastLogits[i] - maxLogit);
                sum += lastLogits[i];
            }

            for (int i = 0; i < vocabSize; i++)
            {
                lastLogits[i] /= sum;
            }

            // Sample from distribution
            float random = Random.value;
            float cumulativeProb = 0.0f;

            for (int i = 0; i < vocabSize; i++)
            {
                cumulativeProb += lastLogits[i];
                if (random < cumulativeProb)
                {
                    return i;
                }
            }

            // Fallback to last token if sampling fails
            return vocabSize - 1;
        }
    }

    void OnDestroy()
    {
        worker?.Dispose();
    }
}

// Simple tokenizer implementation for Unity
|
697 |
+
public class SimpleTokenizer
|
698 |
+
{
|
699 |
+
private Dictionary<string, int> vocab;
|
700 |
+
private Dictionary<int, string> reversedVocab;
|
701 |
+
|
702 |
+
public int PadToken { get; private set; }
|
703 |
+
public int EosToken { get; private set; }
|
704 |
+
public int BosToken { get; private set; }
|
705 |
+
public int NewlineToken { get; private set; }
|
706 |
+
|
707 |
+
public SimpleTokenizer(string vocabJson)
|
708 |
+
{
|
709 |
+
// Parse the vocabulary from JSON
|
710 |
+
vocab = new Dictionary<string, int>();
|
711 |
+
|
712 |
+
// Simple JSON parsing (you'll need a proper JSON parser in production)
|
713 |
+
string[] entries = vocabJson.Split(new[] { '\\n', '{', '}', '\"', ':', ',' },
|
714 |
+
System.StringSplitOptions.RemoveEmptyEntries);
|
715 |
+
|
716 |
+
for (int i = 0; i < entries.Length - 1; i += 2)
|
717 |
+
{
|
718 |
+
string token = entries[i].Trim();
|
719 |
+
if (int.TryParse(entries[i + 1].Trim(), out int id))
|
720 |
+
{
|
721 |
+
vocab[token] = id;
|
722 |
+
}
|
723 |
+
}
|
724 |
+
|
725 |
+
// Create reversed vocabulary for decoding
|
726 |
+
reversedVocab = vocab.ToDictionary(kv => kv.Value, kv => kv.Key);
|
727 |
+
|
728 |
+
// Find special tokens
|
729 |
+
SetSpecialTokens();
|
730 |
+
|
731 |
+
Debug.Log($"Tokenizer initialized with {vocab.Count} tokens");
|
732 |
+
}
|
733 |
+
|
734 |
+
private void SetSpecialTokens()
|
735 |
+
{
|
736 |
+
// Try to find standard special tokens
|
737 |
+
PadToken = FindToken(new[] { "<pad>", "[PAD]", "<|endoftext|>" });
|
738 |
+
EosToken = FindToken(new[] { "</s>", "<|endoftext|>", "[EOS]", "<eos>" });
|
739 |
+
BosToken = FindToken(new[] { "<s>", "<|startoftext|>", "[BOS]", "<bos>" });
|
740 |
+
|
741 |
+
// Find newline token
|
742 |
+
foreach (var entry in vocab)
|
743 |
+
{
|
744 |
+
if (entry.Key == "\\n" || entry.Key == "<\\n>" || entry.Key == "\\n")
|
745 |
+
{
|
746 |
+
NewlineToken = entry.Value;
|
747 |
+
break;
|
748 |
+
}
|
749 |
+
}
|
750 |
+
|
751 |
+
Debug.Log($"Special tokens - PAD: {PadToken}, EOS: {EosToken}, BOS: {BosToken}, NEWLINE: {NewlineToken}");
|
752 |
+
}
|
753 |
+
|
754 |
+
private int FindToken(string[] candidates)
|
755 |
+
{
|
756 |
+
foreach (var candidate in candidates)
|
757 |
+
{
|
758 |
+
if (vocab.TryGetValue(candidate, out int id))
|
759 |
+
{
|
760 |
+
return id;
|
761 |
+
}
|
762 |
+
}
|
763 |
+
|
764 |
+
// Return -1 if not found
|
765 |
+
return -1;
|
766 |
+
}
|
767 |
+
|
768 |
+
public int[] Encode(string text)
|
769 |
+
{
|
770 |
+
// Simple character-level tokenization
|
771 |
+
// In production, use a proper BPE/WordPiece tokenizer implementation
|
772 |
+
List<int> tokens = new List<int>();
|
773 |
+
StringBuilder currentToken = new StringBuilder();
|
774 |
+
|
775 |
+
// Add BOS token if available
|
776 |
+
if (BosToken != -1)
|
777 |
+
{
|
778 |
+
tokens.Add(BosToken);
|
779 |
+
}
|
780 |
+
|
781 |
+
// Very simple tokenization - in production, this would implement
|
782 |
+
// the specific tokenization algorithm for your model
|
783 |
+
foreach (char c in text)
|
784 |
+
{
|
785 |
+
currentToken.Append(c);
|
786 |
+
string current = currentToken.ToString();
|
787 |
+
|
788 |
+
if (vocab.TryGetValue(current, out int id))
|
789 |
+
{
|
790 |
+
tokens.Add(id);
|
791 |
+
currentToken.Clear();
|
792 |
+
}
|
793 |
+
else if (currentToken.Length > 10)
|
794 |
+
{
|
795 |
+
// If token is too long, add unknown token and reset
|
796 |
+
tokens.Add(vocab.ContainsKey("<unk>") ? vocab["<unk>"] : 0);
|
797 |
+
currentToken.Clear();
|
798 |
+
currentToken.Append(c);
|
799 |
+
}
|
800 |
+
}
|
801 |
+
|
802 |
+
// Handle any remaining text
|
803 |
+
if (currentToken.Length > 0)
|
804 |
+
{
|
805 |
+
tokens.Add(vocab.ContainsKey("<unk>") ? vocab["<unk>"] : 0);
|
806 |
+
}
|
807 |
+
|
808 |
+
return tokens.ToArray();
|
809 |
+
}
|
810 |
+
|
811 |
+
public string Decode(int[] ids)
|
812 |
+
{
|
813 |
+
StringBuilder result = new StringBuilder();
|
814 |
+
|
815 |
+
foreach (int id in ids)
|
816 |
+
{
|
817 |
+
if (reversedVocab.TryGetValue(id, out string token))
|
818 |
+
{
|
819 |
+
// Some tokenizers use special prefixes like "Ġ" for spaces
|
820 |
+
string processedToken = token
|
821 |
+
.Replace("Ġ", " ")
|
822 |
+
.Replace("Ċ", "\n")
|
823 |
+
.Replace("▁", " ");
|
824 |
+
|
825 |
+
result.Append(processedToken);
|
826 |
+
}
|
827 |
+
}
|
828 |
+
|
829 |
+
return result.ToString();
|
830 |
+
}
|
831 |
+
}
|
832 |
+
""")
|
833 |
+
|
834 |
+
# Calculate elapsed time
|
835 |
+
end_time = time.time()
|
836 |
+
duration = end_time - start_time
|
837 |
+
logger.info(f"✓ Conversion completed in {duration:.2f} seconds")
|
838 |
+
logger.info(f"✓ Final model size: {os.path.getsize(onnx_path) / (1024 * 1024):.2f} MB")
|
839 |
+
|
840 |
+
# Create a Python example usage file
|
841 |
+
example_path = os.path.join(model_dir, "example_usage.py")
|
842 |
+
with open(example_path, 'w') as f:
|
843 |
+
f.write("""
|
844 |
+
import onnxruntime as ort
|
845 |
+
from transformers import AutoTokenizer
|
846 |
+
import numpy as np
|
847 |
+
|
848 |
+
# Load tokenizer and model
|
849 |
+
tokenizer = AutoTokenizer.from_pretrained("./") # Path to model directory
|
850 |
+
session = ort.InferenceSession("./model.onnx")
|
851 |
+
|
852 |
+
def generate_response(user_message, max_length=50):
|
853 |
+
# Format as a chat message
|
854 |
+
prompt = f"User: {user_message}\\nAssistant:"
|
855 |
+
inputs = tokenizer(prompt, return_tensors="np")
|
856 |
+
|
857 |
+
input_ids = inputs["input_ids"]
|
858 |
+
attention_mask = inputs["attention_mask"]
|
859 |
+
|
860 |
+
# Simple auto-regressive generation loop
|
861 |
+
for _ in range(max_length):
|
862 |
+
# Run inference for a single step
|
863 |
+
outputs = session.run(
|
864 |
+
["logits"],
|
865 |
+
{
|
866 |
+
"input_ids": input_ids,
|
867 |
+
"attention_mask": attention_mask
|
868 |
+
}
|
869 |
+
)
|
870 |
+
|
871 |
+
# Get next token prediction from logits
|
872 |
+
logits = outputs[0]
|
873 |
+
next_token_logits = logits[0, -1, :]
|
874 |
+
|
875 |
+
# Apply temperature sampling
|
876 |
+
temperature = 0.7
|
877 |
+
next_token_logits = next_token_logits / temperature
|
878 |
+
|
879 |
+
# Apply softmax to get probabilities
|
880 |
+
exp_logits = np.exp(next_token_logits - np.max(next_token_logits))
|
881 |
+
probs = exp_logits / np.sum(exp_logits)
|
882 |
+
|
883 |
+
# Sample from the distribution
|
884 |
+
next_token_id = np.random.choice(probs.shape[0], p=probs)
|
885 |
+
|
886 |
+
# Stop if we hit the end of sequence token
|
887 |
+
if next_token_id == tokenizer.eos_token_id:
|
888 |
+
break
|
889 |
+
|
890 |
+
# Append new token to the input_ids
|
891 |
+
input_ids = np.concatenate([input_ids, [[next_token_id]]], axis=1)
|
892 |
+
attention_mask = np.concatenate([attention_mask, [[1]]], axis=1)
|
893 |
+
|
894 |
+
# Decode the entire response
|
895 |
+
response = tokenizer.decode(input_ids[0], skip_special_tokens=True)
|
896 |
+
|
897 |
+
# Extract only the assistant's response
|
898 |
+
if "Assistant:" in response:
|
899 |
+
response = response.split("Assistant:")[-1].strip()
|
900 |
+
|
901 |
+
return response
|
902 |
+
|
903 |
+
# Example usage
|
904 |
+
while True:
|
905 |
+
user_input = input("You: ")
|
906 |
+
if user_input.lower() in ['exit', 'quit']:
|
907 |
+
break
|
908 |
+
response = generate_response(user_input)
|
909 |
+
print(f"Assistant: {response}")
|
910 |
+
""")
|
911 |
+
|
912 |
+
logger.info(f"✓ Example usage saved to {example_path}")
|
913 |
+
logger.info(f"✓ Unity integration example saved to {unity_example_path}")
|
914 |
+
return True
|
915 |
+
|
916 |
+
else:
|
917 |
+
logger.error(f"× ONNX file not created at {onnx_path}")
|
918 |
+
return False
|
919 |
+
|
920 |
+
except Exception as e:
|
921 |
+
logger.error(f"�� Error converting model: {str(e)}")
|
922 |
+
logger.error(traceback.format_exc())
|
923 |
+
return False
|
924 |
+
|
925 |
+
if __name__ == "__main__":
|
926 |
+
# Parse command line arguments
|
927 |
+
parser_available = False
|
928 |
+
try:
|
929 |
+
import argparse
|
930 |
+
parser = argparse.ArgumentParser(description="Convert HuggingFace models to ONNX for Unity")
|
931 |
+
parser.add_argument("model_id", type=str, help="HuggingFace model ID or path")
|
932 |
+
parser.add_argument("--output_dir", "-o", type=str, default="./onnx_models",
|
933 |
+
help="Output directory for the converted model")
|
934 |
+
parser.add_argument("--seq_length", "-s", type=int, default=32,
|
935 |
+
help="Sequence length for model export")
|
936 |
+
parser.add_argument("--no_quantize", action="store_true",
|
937 |
+
help="Skip INT8 quantization step")
|
938 |
+
parser.add_argument("--opset", "-op", type=int, default=None,
|
939 |
+
help="Force a specific ONNX opset version")
|
940 |
+
|
941 |
+
args = parser.parse_args()
|
942 |
+
parser_available = True
|
943 |
+
|
944 |
+
model_id = args.model_id
|
945 |
+
output_dir = args.output_dir
|
946 |
+
seq_length = args.seq_length
|
947 |
+
quantize = not args.no_quantize
|
948 |
+
force_opset = args.opset
|
949 |
+
|
950 |
+
except (ImportError, NameError):
|
951 |
+
# Fallback if argparse is not available
|
952 |
+
parser_available = False
|
953 |
+
|
954 |
+
if not parser_available:
|
955 |
+
if len(sys.argv) < 2:
|
956 |
+
print("Usage: python unity_compatible_converter.py MODEL_ID [OUTPUT_DIR] [SEQ_LENGTH] [--no-quantize] [--opset]")
|
957 |
+
print("Example: python unity_compatible_converter.py distilgpt2 ./onnx_models 32")
|
958 |
+
print("\nRecommended chat models for Unity:")
|
959 |
+
print(" - distilgpt2 (smallest, opset 11)")
|
960 |
+
print(" - EleutherAI/pythia-70m (better quality, opset 11)")
|
961 |
+
print(" - microsoft/phi-1 (high quality, opset 14)")
|
962 |
+
print(" - TinyLlama/TinyLlama-1.1B-Chat-v1.0 (chat-tuned, opset 14)")
|
963 |
+
sys.exit(1)
|
964 |
+
|
965 |
+
model_id = sys.argv[1]
|
966 |
+
output_dir = sys.argv[2] if len(sys.argv) > 2 else "./onnx_models"
|
967 |
+
seq_length = int(sys.argv[3]) if len(sys.argv) > 3 else 32
|
968 |
+
quantize = "--no-quantize" not in sys.argv and "--no_quantize" not in sys.argv
|
969 |
+
force_opset = None
|
970 |
+
|
971 |
+
# Check for opset flag
|
972 |
+
for i, arg in enumerate(sys.argv):
|
973 |
+
if arg == "--opset" and i + 1 < len(sys.argv):
|
974 |
+
force_opset = int(sys.argv[i + 1])
|
975 |
+
|
976 |
+
# Check model architecture for automatic opset recommendation
|
977 |
+
is_compatible, recommended_opset = is_architecture_compatible(model_id)
|
978 |
+
|
979 |
+
# Print header
|
980 |
+
logger.info("\nUNITY-COMPATIBLE ONNX CONVERTER")
|
981 |
+
logger.info("===============================")
|
982 |
+
logger.info(f"Model: {model_id}")
|
983 |
+
logger.info(f"Output directory: {output_dir}")
|
984 |
+
logger.info(f"Sequence length: {seq_length}")
|
985 |
+
|
986 |
+
if force_opset is not None:
|
987 |
+
logger.info(f"ONNX opset version: {force_opset} (forced)")
|
988 |
+
else:
|
989 |
+
logger.info(f"Recommended ONNX opset: {recommended_opset}")
|
990 |
+
logger.info(f"Architecture compatible with opset 11: {'Yes' if is_compatible else 'No'}")
|
991 |
+
|
992 |
+
logger.info(f"Quantization: {'Enabled' if quantize else 'Disabled'}")
|
993 |
+
|
994 |
+
# Create output directory
|
995 |
+
os.makedirs(output_dir, exist_ok=True)
|
996 |
+
|
997 |
+
# Convert the model
|
998 |
+
success = convert_model(model_id, output_dir, seq_length, quantize, force_opset)
|
999 |
+
|
1000 |
+
if success:
|
1001 |
+
logger.info("\n" + "=" * 60)
|
1002 |
+
logger.info("CONVERSION SUCCESSFUL")
|
1003 |
+
logger.info("=" * 60)
|
1004 |
+
logger.info(f"Model: {model_id}")
|
1005 |
+
logger.info(f"Output directory: {os.path.abspath(output_dir)}")
|
1006 |
+
logger.info("The model is ready for Unity integration!")
|
1007 |
+
logger.info("\nNext steps:")
|
1008 |
+
logger.info("1. Import the ONNX model into Unity using the Sentis package")
|
1009 |
+
logger.info("2. Use the unity_integration.cs file as a starting point")
|
1010 |
+
logger.info("3. For tokenization in Unity, implement the SimpleTokenizer class")
|
1011 |
+
else:
|
1012 |
+
logger.info("\n" + "=" * 60)
|
1013 |
+
logger.info("CONVERSION FAILED")
|
1014 |
+
logger.info("=" * 60)
|
1015 |
+
logger.info("Please try one of the recommended models that work well with Unity:")
|
1016 |
+
|
1017 |
+
if is_compatible:
|
1018 |
+
logger.info("Compatible with Unity (opset 11):")
|
1019 |
+
logger.info(" - distilgpt2")
|
1020 |
+
logger.info(" - EleutherAI/pythia-70m")
|
1021 |
+
|
1022 |
+
logger.info("Advanced models (require opset 14):")
|
1023 |
+
logger.info(" - microsoft/phi-1 --opset 14")
|
1024 |
+
logger.info(" - TinyLlama/TinyLlama-1.1B-Chat-v1.0 --opset 14")
|
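For scripted batch runs the converter can also be driven from Python rather than the command line. A minimal sketch, assuming this file is importable as unity_compatible_converter (adjust the import to the actual file name) and using the positional convert_model(model_id, output_dir, seq_length, quantize, force_opset) call shown above:

# Hypothetical batch driver; the module name and model list are assumptions.
from unity_compatible_converter import convert_model

MODELS = [
    ("distilgpt2", 11),                          # Unity-friendly, opset 11
    ("TinyLlama/TinyLlama-1.1B-Chat-v1.0", 14),  # chat-tuned, needs opset 14
]

for model_id, opset in MODELS:
    ok = convert_model(model_id, "./onnx_models", 32, True, opset)
    print(f"{model_id}: {'converted' if ok else 'failed'}")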
old_scripts/convert_single_model.py
ADDED
@@ -0,0 +1,492 @@
1 |
+
import os
|
2 |
+
import gc
|
3 |
+
import sys
|
4 |
+
import time
|
5 |
+
import logging
|
6 |
+
import traceback
|
7 |
+
import torch
|
8 |
+
import warnings
|
9 |
+
import numpy as np
|
10 |
+
from transformers import AutoModelForCausalLM, AutoTokenizer
|
11 |
+
from transformers.generation import GenerationConfig
|
12 |
+
from tqdm import tqdm
|
13 |
+
import onnx
|
14 |
+
from onnxruntime.quantization import quantize_dynamic, QuantType
|
15 |
+
|
16 |
+
# Configure logging
|
17 |
+
logging.basicConfig(
|
18 |
+
level=logging.INFO,
|
19 |
+
format='%(asctime)s - %(levelname)s - %(message)s',
|
20 |
+
datefmt='%Y-%m-%d %H:%M:%S'
|
21 |
+
)
|
22 |
+
logger = logging.getLogger(__name__)
|
23 |
+
|
24 |
+
# Suppress unhelpful warnings
|
25 |
+
warnings.filterwarnings("ignore", category=UserWarning, message=".*The shape of the input dimension.*")
|
26 |
+
warnings.filterwarnings("ignore", category=UserWarning, message=".*Converting a tensor to a Python.*")
|
27 |
+
warnings.filterwarnings("ignore", category=UserWarning, message=".*The model does not use GenerationMixin.*")
|
28 |
+
|
29 |
+
|
30 |
+
class GenerationWrapper(torch.nn.Module):
|
31 |
+
"""
|
32 |
+
Wrapper for model export that handles generation properly.
|
33 |
+
This ensures the model can be correctly used for text generation.
|
34 |
+
"""
|
35 |
+
def __init__(self, model):
|
36 |
+
super().__init__()
|
37 |
+
self.model = model
|
38 |
+
self.config = model.config
|
39 |
+
|
40 |
+
def forward(self, input_ids, attention_mask=None):
|
41 |
+
# Return only the logits to avoid complex structures
|
42 |
+
with torch.no_grad():
|
43 |
+
try:
|
44 |
+
# Standard approach for most models
|
45 |
+
outputs = self.model(
|
46 |
+
input_ids=input_ids,
|
47 |
+
attention_mask=attention_mask,
|
48 |
+
use_cache=False,
|
49 |
+
return_dict=True
|
50 |
+
)
|
51 |
+
return outputs.logits
|
52 |
+
except Exception as e:
|
53 |
+
logger.warning(f"Standard forward pass failed, trying fallback: {str(e)}")
|
54 |
+
# Fallback for models with different API
|
55 |
+
outputs = self.model(input_ids=input_ids, attention_mask=attention_mask)
|
56 |
+
if hasattr(outputs, 'logits'):
|
57 |
+
return outputs.logits
|
58 |
+
elif isinstance(outputs, tuple) and len(outputs) > 0:
|
59 |
+
return outputs[0] # First element is typically logits
|
60 |
+
else:
|
61 |
+
raise ValueError("Could not extract logits from model outputs")
|
62 |
+
|
63 |
+
|
64 |
+
def verify_model_generation(model, tokenizer, device="cpu"):
|
65 |
+
"""Test model generation capabilities before export"""
|
66 |
+
model.eval()
|
67 |
+
prompt = "Hello, how are you today? I am"
|
68 |
+
|
69 |
+
logger.info("Testing model generation...")
|
70 |
+
inputs = tokenizer(prompt, return_tensors="pt").to(device)
|
71 |
+
|
72 |
+
# Configure generation parameters
|
73 |
+
gen_config = GenerationConfig(
|
74 |
+
max_length=30,
|
75 |
+
do_sample=True,
|
76 |
+
temperature=0.7,
|
77 |
+
num_return_sequences=1,
|
78 |
+
)
|
79 |
+
|
80 |
+
try:
|
81 |
+
# Try generation
|
82 |
+
with torch.no_grad():
|
83 |
+
outputs = model.generate(
|
84 |
+
**inputs,
|
85 |
+
generation_config=gen_config
|
86 |
+
)
|
87 |
+
|
88 |
+
generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
|
89 |
+
logger.info(f"Test generation result: {generated_text}")
|
90 |
+
|
91 |
+
if len(generated_text) <= len(prompt):
|
92 |
+
logger.warning("Generation output is not longer than input prompt!")
|
93 |
+
|
94 |
+
return True
|
95 |
+
except Exception as e:
|
96 |
+
logger.error(f"Generation test failed: {str(e)}")
|
97 |
+
return False
|
98 |
+
|
99 |
+
|
100 |
+
def test_onnx_model(onnx_path, tokenizer):
|
101 |
+
"""Verify the ONNX model can be loaded and run"""
|
102 |
+
try:
|
103 |
+
import onnxruntime as ort
|
104 |
+
|
105 |
+
logger.info("Testing ONNX model inference...")
|
106 |
+
session = ort.InferenceSession(onnx_path)
|
107 |
+
|
108 |
+
# Get input and output names
|
109 |
+
input_names = [input.name for input in session.get_inputs()]
|
110 |
+
output_names = [output.name for output in session.get_outputs()]
|
111 |
+
|
112 |
+
# Create test input
|
113 |
+
prompt = "Hello, how are you?"
|
114 |
+
inputs = tokenizer(prompt, return_tensors="np")
|
115 |
+
|
116 |
+
# Prepare input dict
|
117 |
+
onnx_inputs = {}
|
118 |
+
for name in input_names:
|
119 |
+
if name == "input_ids" and "input_ids" in inputs:
|
120 |
+
onnx_inputs[name] = inputs["input_ids"]
|
121 |
+
elif name == "attention_mask" and "attention_mask" in inputs:
|
122 |
+
onnx_inputs[name] = inputs["attention_mask"]
|
123 |
+
|
124 |
+
# Run inference
|
125 |
+
outputs = session.run(output_names, onnx_inputs)
|
126 |
+
|
127 |
+
# Check output shape
|
128 |
+
logits = outputs[0]
|
129 |
+
logger.info(f"ONNX model output shape: {logits.shape}")
|
130 |
+
|
131 |
+
if logits.shape[0] != 1 or logits.shape[1] != inputs["input_ids"].shape[1]:
|
132 |
+
logger.warning("Output shape doesn't match expected dimensions!")
|
133 |
+
|
134 |
+
# Test next token prediction
|
135 |
+
next_token_logits = logits[0, -1, :]
|
136 |
+
next_token_id = np.argmax(next_token_logits)
|
137 |
+
next_token = tokenizer.decode([next_token_id])
|
138 |
+
logger.info(f"Next predicted token: '{next_token}'")
|
139 |
+
|
140 |
+
return True
|
141 |
+
except Exception as e:
|
142 |
+
logger.error(f"ONNX model test failed: {str(e)}")
|
143 |
+
return False
|
144 |
+
|
145 |
+
|
146 |
+
def optimize_onnx_model(onnx_path):
|
147 |
+
"""Apply ONNX optimizations to improve performance"""
|
148 |
+
try:
|
149 |
+
logger.info("Optimizing ONNX model...")
|
150 |
+
|
151 |
+
# Load the model
|
152 |
+
model = onnx.load(onnx_path)
|
153 |
+
|
154 |
+
# Apply optimizations
|
155 |
+
from onnxruntime.transformers import optimizer
|
156 |
+
|
157 |
+
# Get model type from path
|
158 |
+
model_path = os.path.dirname(onnx_path)
|
159 |
+
model_name = os.path.basename(model_path).lower()
|
160 |
+
|
161 |
+
# Determine model type for optimization
|
162 |
+
if "gpt" in model_name:
|
163 |
+
model_type = "gpt2"
|
164 |
+
elif "opt" in model_name:
|
165 |
+
model_type = "opt"
|
166 |
+
elif "pythia" in model_name:
|
167 |
+
model_type = "gpt_neox"
|
168 |
+
else:
|
169 |
+
model_type = "gpt2" # Default fallback
|
170 |
+
|
171 |
+
logger.info(f"Using optimization profile for model type: {model_type}")
|
172 |
+
|
173 |
+
# Try to optimize the model
|
174 |
+
try:
|
175 |
+
optimized_model = optimizer.optimize_model(
|
176 |
+
onnx_path,
|
177 |
+
model_type=model_type,
|
178 |
+
num_heads=8, # Will be overridden by model's real config
|
179 |
+
hidden_size=768, # Will be overridden by model's real config
|
180 |
+
optimization_options=None
|
181 |
+
)
|
182 |
+
optimized_model.save_model_to_file(onnx_path)
|
183 |
+
logger.info("✓ ONNX model optimized")
|
184 |
+
return True
|
185 |
+
except Exception as e:
|
186 |
+
logger.warning(f"Optimization failed (non-critical): {str(e)}")
|
187 |
+
return False
|
188 |
+
|
189 |
+
except Exception as e:
|
190 |
+
logger.warning(f"ONNX optimization error (skipping): {str(e)}")
|
191 |
+
return False
|
192 |
+
|
193 |
+
|
194 |
+
def convert_model(model_id, output_dir="./onnx_models", seq_length=32, quantize=True):
|
195 |
+
"""
|
196 |
+
Convert a model to ONNX format with focus on reliability for generation.
|
197 |
+
|
198 |
+
Args:
|
199 |
+
model_id: HuggingFace model ID or path
|
200 |
+
output_dir: Directory to save the model
|
201 |
+
seq_length: Input sequence length for export
|
202 |
+
quantize: Whether to quantize the model to INT8
|
203 |
+
|
204 |
+
Returns:
|
205 |
+
bool: Success status
|
206 |
+
"""
|
207 |
+
start_time = time.time()
|
208 |
+
|
209 |
+
logger.info(f"\n{'=' * 60}")
|
210 |
+
logger.info(f"Converting {model_id} to ONNX (optimized for generation)")
|
211 |
+
logger.info(f"{'=' * 60}")
|
212 |
+
|
213 |
+
# Create output directory
|
214 |
+
model_name = model_id.split("/")[-1]
|
215 |
+
model_dir = os.path.join(output_dir, model_name)
|
216 |
+
os.makedirs(model_dir, exist_ok=True)
|
217 |
+
|
218 |
+
try:
|
219 |
+
# Step 1: Load tokenizer
|
220 |
+
logger.info("Step 1/6: Loading tokenizer...")
|
221 |
+
|
222 |
+
tokenizer = AutoTokenizer.from_pretrained(model_id)
|
223 |
+
if tokenizer.pad_token is None and hasattr(tokenizer, 'eos_token'):
|
224 |
+
logger.info("Adding pad_token = eos_token")
|
225 |
+
tokenizer.pad_token = tokenizer.eos_token
|
226 |
+
|
227 |
+
# Save tokenizer
|
228 |
+
tokenizer.save_pretrained(model_dir)
|
229 |
+
logger.info(f"✓ Tokenizer saved to {model_dir}")
|
230 |
+
|
231 |
+
# Step 2: Load model with reliability optimizations
|
232 |
+
logger.info("Step 2/6: Loading model...")
|
233 |
+
|
234 |
+
# Clean memory
|
235 |
+
gc.collect()
|
236 |
+
torch.cuda.empty_cache() if torch.cuda.is_available() else None
|
237 |
+
|
238 |
+
# Determine device
|
239 |
+
device = "cuda" if torch.cuda.is_available() else "cpu"
|
240 |
+
|
241 |
+
# Load model with full precision
|
242 |
+
model = AutoModelForCausalLM.from_pretrained(
|
243 |
+
model_id,
|
244 |
+
torch_dtype=torch.float32, # Use full precision for reliability
|
245 |
+
low_cpu_mem_usage=True, # Reduce memory usage
|
246 |
+
device_map=device # Use CUDA if available
|
247 |
+
)
|
248 |
+
|
249 |
+
# Save config
|
250 |
+
model.config.save_pretrained(model_dir)
|
251 |
+
logger.info(f"✓ Model config saved to {model_dir}")
|
252 |
+
|
253 |
+
# Step 3: Verify model can generate text
|
254 |
+
logger.info("Step 3/6: Validating generation capabilities...")
|
255 |
+
|
256 |
+
if not verify_model_generation(model, tokenizer, device):
|
257 |
+
logger.warning("⚠ Model generation test didn't complete successfully")
|
258 |
+
logger.info("Continuing with export anyway...")
|
259 |
+
|
260 |
+
# Step 4: Wrap and prepare for export
|
261 |
+
logger.info("Step 4/6: Preparing for export...")
|
262 |
+
|
263 |
+
# Wrap model with generation-optimized interface
|
264 |
+
wrapped_model = GenerationWrapper(model)
|
265 |
+
wrapped_model.eval()
|
266 |
+
|
267 |
+
# Clean memory again
|
268 |
+
gc.collect()
|
269 |
+
torch.cuda.empty_cache() if torch.cuda.is_available() else None
|
270 |
+
|
271 |
+
# Step 5: Export to ONNX
|
272 |
+
logger.info("Step 5/6: Exporting to ONNX format...")
|
273 |
+
onnx_path = os.path.join(model_dir, "model.onnx")
|
274 |
+
|
275 |
+
# Create minimal input
|
276 |
+
batch_size = 1
|
277 |
+
dummy_input = torch.ones(batch_size, seq_length, dtype=torch.long)
|
278 |
+
attention_mask = torch.ones(batch_size, seq_length, dtype=torch.long)
|
279 |
+
|
280 |
+
# Move tensors to correct device
|
281 |
+
dummy_input = dummy_input.to(device)
|
282 |
+
attention_mask = attention_mask.to(device)
|
283 |
+
|
284 |
+
# Export to ONNX with required opset for transformer models
|
285 |
+
with torch.no_grad():
|
286 |
+
torch.onnx.export(
|
287 |
+
wrapped_model, # Wrapped model
|
288 |
+
(dummy_input, attention_mask), # Input tensors
|
289 |
+
onnx_path, # Output path
|
290 |
+
export_params=True, # Store weights
|
291 |
+
opset_version=14, # Required for transformer models
|
292 |
+
do_constant_folding=True, # Optimize constants
|
293 |
+
input_names=['input_ids', 'attention_mask'], # Input names
|
294 |
+
output_names=['logits'], # Output name
|
295 |
+
dynamic_axes={ # Dynamic dimensions
|
296 |
+
'input_ids': {0: 'batch_size', 1: 'sequence'},
|
297 |
+
'attention_mask': {0: 'batch_size', 1: 'sequence'},
|
298 |
+
'logits': {0: 'batch_size', 1: 'sequence'}
|
299 |
+
}
|
300 |
+
)
|
301 |
+
|
302 |
+
# Clean up to save memory
|
303 |
+
del model
|
304 |
+
del wrapped_model
|
305 |
+
gc.collect()
|
306 |
+
torch.cuda.empty_cache() if torch.cuda.is_available() else None
|
307 |
+
|
308 |
+
# Verify export success
|
309 |
+
if os.path.exists(onnx_path):
|
310 |
+
size_mb = os.path.getsize(onnx_path) / (1024 * 1024)
|
311 |
+
logger.info(f"✓ ONNX model saved to {onnx_path}")
|
312 |
+
logger.info(f"✓ Original size: {size_mb:.2f} MB")
|
313 |
+
|
314 |
+
# Test ONNX model
|
315 |
+
test_onnx_model(onnx_path, tokenizer)
|
316 |
+
|
317 |
+
# Optimize the ONNX model
|
318 |
+
optimize_onnx_model(onnx_path)
|
319 |
+
|
320 |
+
# Step 6: Quantize the model (optional)
|
321 |
+
if quantize:
|
322 |
+
logger.info("Step 6/6: Applying INT8 quantization...")
|
323 |
+
quant_path = onnx_path.replace(".onnx", "_quantized.onnx")
|
324 |
+
|
325 |
+
try:
|
326 |
+
with tqdm(total=100, desc="Quantizing") as pbar:
|
327 |
+
# Update progress callback
|
328 |
+
def update_progress(x):
|
329 |
+
pbar.update(1)
|
330 |
+
|
331 |
+
quantize_dynamic(
|
332 |
+
model_input=onnx_path,
|
333 |
+
model_output=quant_path,
|
334 |
+
per_channel=False,
|
335 |
+
reduce_range=False,
|
336 |
+
weight_type=QuantType.QInt8,
|
337 |
+
optimize_model=True,
|
338 |
+
use_external_data_format=False
|
339 |
+
)
|
340 |
+
|
341 |
+
pbar.update(100) # Ensure progress reaches 100%
|
342 |
+
|
343 |
+
if os.path.exists(quant_path):
|
344 |
+
quant_size = os.path.getsize(quant_path) / (1024 * 1024)
|
345 |
+
logger.info(f"✓ Quantized size: {quant_size:.2f} MB")
|
346 |
+
logger.info(f"✓ Size reduction: {(1 - quant_size/size_mb) * 100:.1f}%")
|
347 |
+
|
348 |
+
# Test the quantized model
|
349 |
+
test_onnx_model(quant_path, tokenizer)
|
350 |
+
|
351 |
+
# Rename original as backup
|
352 |
+
backup_path = onnx_path.replace(".onnx", "_fp32.onnx")
|
353 |
+
os.rename(onnx_path, backup_path)
|
354 |
+
|
355 |
+
# Replace original with quantized
|
356 |
+
os.rename(quant_path, onnx_path)
|
357 |
+
logger.info("✓ Original model preserved as *_fp32.onnx")
|
358 |
+
logger.info("✓ Replaced original with quantized version")
|
359 |
+
else:
|
360 |
+
logger.warning("⚠ Quantized file not created, using original")
|
361 |
+
except Exception as e:
|
362 |
+
logger.error(f"⚠ Quantization error: {str(e)}")
|
363 |
+
logger.info("⚠ Using original model without quantization")
|
364 |
+
else:
|
365 |
+
logger.info("Step 6/6: Skipping quantization as requested")
|
366 |
+
|
367 |
+
# Calculate elapsed time
|
368 |
+
end_time = time.time()
|
369 |
+
duration = end_time - start_time
|
370 |
+
logger.info(f"✓ Conversion completed in {duration:.2f} seconds")
|
371 |
+
logger.info(f"✓ Final model size: {os.path.getsize(onnx_path) / (1024 * 1024):.2f} MB")
|
372 |
+
|
373 |
+
# Create a simple example usage file
|
374 |
+
example_path = os.path.join(model_dir, "example_usage.py")
|
375 |
+
with open(example_path, 'w') as f:
|
376 |
+
f.write("""
|
377 |
+
import onnxruntime as ort
|
378 |
+
from transformers import AutoTokenizer
|
379 |
+
import numpy as np
|
380 |
+
|
381 |
+
# Load tokenizer and model
|
382 |
+
tokenizer = AutoTokenizer.from_pretrained("./") # Path to model directory
|
383 |
+
session = ort.InferenceSession("./model.onnx")
|
384 |
+
|
385 |
+
# Prepare input
|
386 |
+
prompt = "Hello, how are you?"
|
387 |
+
inputs = tokenizer(prompt, return_tensors="np")
|
388 |
+
|
389 |
+
# Run inference for a single step
|
390 |
+
outputs = session.run(
|
391 |
+
["logits"],
|
392 |
+
{
|
393 |
+
"input_ids": inputs["input_ids"],
|
394 |
+
"attention_mask": inputs["attention_mask"]
|
395 |
+
}
|
396 |
+
)
|
397 |
+
|
398 |
+
# Get next token prediction
|
399 |
+
logits = outputs[0]
|
400 |
+
next_token_id = np.argmax(logits[0, -1, :])
|
401 |
+
next_token = tokenizer.decode([next_token_id])
|
402 |
+
print(f"Next predicted token: {next_token}")
|
403 |
+
|
404 |
+
# For full generation, you'd typically run in a loop, adding tokens one by one
|
405 |
+
""")
|
406 |
+
logger.info(f"✓ Example usage saved to {example_path}")
|
407 |
+
|
408 |
+
return True
|
409 |
+
else:
|
410 |
+
logger.error(f"× ONNX file not created at {onnx_path}")
|
411 |
+
return False
|
412 |
+
|
413 |
+
except Exception as e:
|
414 |
+
logger.error(f"× Error converting model: {str(e)}")
|
415 |
+
logger.error(traceback.format_exc())
|
416 |
+
return False
|
417 |
+
|
418 |
+
|
419 |
+
if __name__ == "__main__":
|
420 |
+
# Parse command line arguments
|
421 |
+
parser_available = False
|
422 |
+
try:
|
423 |
+
import argparse
|
424 |
+
parser = argparse.ArgumentParser(description="Convert HuggingFace models to ONNX for generation")
|
425 |
+
parser.add_argument("model_id", type=str, help="HuggingFace model ID or path")
|
426 |
+
parser.add_argument("--output_dir", "-o", type=str, default="./onnx_models",
|
427 |
+
help="Output directory for the converted model")
|
428 |
+
parser.add_argument("--seq_length", "-s", type=int, default=32,
|
429 |
+
help="Sequence length for model export")
|
430 |
+
parser.add_argument("--no_quantize", action="store_true",
|
431 |
+
help="Skip INT8 quantization step")
|
432 |
+
|
433 |
+
args = parser.parse_args()
|
434 |
+
parser_available = True
|
435 |
+
|
436 |
+
model_id = args.model_id
|
437 |
+
output_dir = args.output_dir
|
438 |
+
seq_length = args.seq_length
|
439 |
+
quantize = not args.no_quantize
|
440 |
+
|
441 |
+
except (ImportError, NameError):
|
442 |
+
# Fallback if argparse is not available
|
443 |
+
parser_available = False
|
444 |
+
|
445 |
+
if not parser_available:
|
446 |
+
if len(sys.argv) < 2:
|
447 |
+
print("Usage: python convert_model.py MODEL_ID [OUTPUT_DIR] [SEQ_LENGTH] [--no-quantize]")
|
448 |
+
print("Example: python convert_model.py facebook/opt-125m ./onnx_models 32")
|
449 |
+
print("\nRecommended models for small hardware:")
|
450 |
+
print(" - facebook/opt-125m")
|
451 |
+
print(" - distilgpt2")
|
452 |
+
print(" - TinyLlama/TinyLlama-1.1B-Chat-v1.0")
|
453 |
+
print(" - EleutherAI/pythia-70m")
|
454 |
+
sys.exit(1)
|
455 |
+
|
456 |
+
model_id = sys.argv[1]
|
457 |
+
output_dir = sys.argv[2] if len(sys.argv) > 2 else "./onnx_models"
|
458 |
+
seq_length = int(sys.argv[3]) if len(sys.argv) > 3 else 32
|
459 |
+
quantize = "--no-quantize" not in sys.argv and "--no_quantize" not in sys.argv
|
460 |
+
|
461 |
+
# Print header
|
462 |
+
logger.info("\nENHANCED ONNX CONVERTER FOR LANGUAGE MODELS")
|
463 |
+
logger.info("============================================")
|
464 |
+
logger.info(f"Model: {model_id}")
|
465 |
+
logger.info(f"Output directory: {output_dir}")
|
466 |
+
logger.info(f"Sequence length: {seq_length}")
|
467 |
+
logger.info(f"Quantization: {'Enabled' if quantize else 'Disabled'}")
|
468 |
+
|
469 |
+
# Create output directory
|
470 |
+
os.makedirs(output_dir, exist_ok=True)
|
471 |
+
|
472 |
+
# Convert the model
|
473 |
+
success = convert_model(model_id, output_dir, seq_length, quantize)
|
474 |
+
|
475 |
+
if success:
|
476 |
+
logger.info("\n" + "=" * 60)
|
477 |
+
logger.info("CONVERSION SUCCESSFUL")
|
478 |
+
logger.info("=" * 60)
|
479 |
+
logger.info(f"Model: {model_id}")
|
480 |
+
logger.info(f"Output directory: {os.path.abspath(output_dir)}")
|
481 |
+
logger.info("The model is ready for generation!")
|
482 |
+
logger.info("\nTo use the model:")
|
483 |
+
logger.info("1. See the example_usage.py file in the model directory")
|
484 |
+
logger.info("2. For chatbot applications, implement token-by-token generation")
|
485 |
+
else:
|
486 |
+
logger.error("\n" + "=" * 60)
|
487 |
+
logger.error("CONVERSION FAILED")
|
488 |
+
logger.error("=" * 60)
|
489 |
+
logger.error("Please try one of the recommended models:")
|
490 |
+
logger.error(" - facebook/opt-125m")
|
491 |
+
logger.error(" - distilgpt2")
|
492 |
+
logger.error(" - EleutherAI/pythia-70m")
|
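When quantization is applied, the script keeps the full-precision weights as model_fp32.onnx next to the quantized model.onnx, so the effect of INT8 quantization on the logits can be spot-checked directly. A minimal sketch; the directory path is an assumption based on the layout produced above:

import numpy as np
import onnxruntime as ort
from transformers import AutoTokenizer

model_dir = "./onnx_models/opt-125m"  # assumed output of converting facebook/opt-125m
tok = AutoTokenizer.from_pretrained(model_dir)
enc = tok("Hello, how are you?", return_tensors="np")
feed = {"input_ids": enc["input_ids"], "attention_mask": enc["attention_mask"]}

fp32 = ort.InferenceSession(f"{model_dir}/model_fp32.onnx")
int8 = ort.InferenceSession(f"{model_dir}/model.onnx")

logits_fp32 = fp32.run(["logits"], feed)[0]
logits_int8 = int8.run(["logits"], feed)[0]

# Compare the two variants on the next-token prediction
print("max abs logit diff:", np.abs(logits_fp32 - logits_int8).max())
print("same argmax:", int(np.argmax(logits_fp32[0, -1])) == int(np.argmax(logits_int8[0, -1])))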
old_scripts/convert_to_onnx.py
ADDED
@@ -0,0 +1,261 @@
1 |
+
import os
|
2 |
+
import gc
|
3 |
+
import sys
|
4 |
+
import time
|
5 |
+
import logging
|
6 |
+
import traceback
|
7 |
+
import torch
|
8 |
+
import warnings
|
9 |
+
from transformers import AutoModelForCausalLM, AutoTokenizer
|
10 |
+
from onnxruntime.quantization import quantize_dynamic, QuantType
|
11 |
+
|
12 |
+
# Configure logging
|
13 |
+
logging.basicConfig(
|
14 |
+
level=logging.INFO,
|
15 |
+
format='%(asctime)s - %(levelname)s - %(message)s',
|
16 |
+
datefmt='%Y-%m-%d %H:%M:%S',
|
17 |
+
handlers=[logging.StreamHandler(sys.stdout)]
|
18 |
+
)
|
19 |
+
logger = logging.getLogger(__name__)
|
20 |
+
|
21 |
+
# Suppress specific warnings
|
22 |
+
warnings.filterwarnings("ignore", category=UserWarning, message=".*The shape of the input dimension.*")
|
23 |
+
warnings.filterwarnings("ignore", category=UserWarning, message=".*Converting a tensor to a Python.*")
|
24 |
+
|
25 |
+
# Models that are known to work well with ONNX conversion
|
26 |
+
RELIABLE_MODELS = [
|
27 |
+
{
|
28 |
+
"id": "facebook/opt-350m",
|
29 |
+
"description": "Well-balanced model (350M) for RAG and chatbots"
|
30 |
+
},
|
31 |
+
{
|
32 |
+
"id": "gpt2",
|
33 |
+
"description": "Very reliable model (124M) with excellent ONNX compatibility"
|
34 |
+
},
|
35 |
+
{
|
36 |
+
"id": "distilgpt2",
|
37 |
+
"description": "Lightweight (82M) model with good performance"
|
38 |
+
}
|
39 |
+
]
|
40 |
+
|
41 |
+
class ModelWrapper(torch.nn.Module):
|
42 |
+
"""
|
43 |
+
Wrapper to handle ONNX export compatibility issues.
|
44 |
+
This wrapper specifically:
|
45 |
+
1. Bypasses cache handling
|
46 |
+
2. Simplifies the forward pass to avoid dynamic operations
|
47 |
+
"""
|
48 |
+
def __init__(self, model):
|
49 |
+
super().__init__()
|
50 |
+
self.model = model
|
51 |
+
|
52 |
+
def forward(self, input_ids):
|
53 |
+
# Force no cache, no gradient, and no special features
|
54 |
+
with torch.no_grad():
|
55 |
+
return self.model(input_ids=input_ids, use_cache=False, return_dict=False)[0]
|
56 |
+
|
57 |
+
def convert_model(model_id, output_dir, quantize=True):
|
58 |
+
"""Convert a model to ONNX format with maximum compatibility."""
|
59 |
+
start_time = time.time()
|
60 |
+
|
61 |
+
logger.info(f"\n{'=' * 60}")
|
62 |
+
logger.info(f"Converting {model_id} to ONNX")
|
63 |
+
logger.info(f"{'=' * 60}")
|
64 |
+
|
65 |
+
# Create output directory
|
66 |
+
model_name = model_id.split("/")[-1]
|
67 |
+
model_dir = os.path.join(output_dir, model_name)
|
68 |
+
os.makedirs(model_dir, exist_ok=True)
|
69 |
+
|
70 |
+
try:
|
71 |
+
# Step 1: Load tokenizer
|
72 |
+
logger.info("Step 1/5: Loading tokenizer...")
|
73 |
+
|
74 |
+
tokenizer = AutoTokenizer.from_pretrained(model_id)
|
75 |
+
|
76 |
+
# Handle missing pad token
|
77 |
+
if tokenizer.pad_token is None and hasattr(tokenizer, 'eos_token'):
|
78 |
+
logger.info("Adding pad_token = eos_token")
|
79 |
+
tokenizer.pad_token = tokenizer.eos_token
|
80 |
+
|
81 |
+
# Save tokenizer
|
82 |
+
tokenizer.save_pretrained(model_dir)
|
83 |
+
logger.info(f"✓ Tokenizer saved to {model_dir}")
|
84 |
+
|
85 |
+
# Step 2: Load model with memory optimizations
|
86 |
+
logger.info("Step 2/5: Loading model with memory optimizations...")
|
87 |
+
|
88 |
+
# Clean memory before loading
|
89 |
+
gc.collect()
|
90 |
+
torch.cuda.empty_cache() if torch.cuda.is_available() else None
|
91 |
+
|
92 |
+
# Load model with optimizations
|
93 |
+
model = AutoModelForCausalLM.from_pretrained(
|
94 |
+
model_id,
|
95 |
+
torch_dtype=torch.float16, # Use half precision
|
96 |
+
low_cpu_mem_usage=True # Reduce memory usage
|
97 |
+
)
|
98 |
+
|
99 |
+
# Save config for reference
|
100 |
+
model.config.save_pretrained(model_dir)
|
101 |
+
logger.info(f"✓ Model config saved to {model_dir}")
|
102 |
+
|
103 |
+
# Step 3: Prepare for export
|
104 |
+
logger.info("Step 3/5: Preparing for export...")
|
105 |
+
|
106 |
+
# Wrap model to avoid tracing issues
|
107 |
+
wrapped_model = ModelWrapper(model)
|
108 |
+
wrapped_model.eval() # Set to evaluation mode
|
109 |
+
|
110 |
+
# Clean memory again
|
111 |
+
gc.collect()
|
112 |
+
torch.cuda.empty_cache() if torch.cuda.is_available() else None
|
113 |
+
|
114 |
+
# Step 4: Export to ONNX
|
115 |
+
logger.info("Step 4/5: Exporting to ONNX format...")
|
116 |
+
onnx_path = os.path.join(model_dir, "model.onnx")
|
117 |
+
|
118 |
+
# Create dummy input
|
119 |
+
batch_size = 1
|
120 |
+
seq_length = 8 # Small sequence length to reduce memory
|
121 |
+
dummy_input = torch.ones(batch_size, seq_length, dtype=torch.long)
|
122 |
+
|
123 |
+
# Export to ONNX format with new opset version
|
124 |
+
torch.onnx.export(
|
125 |
+
wrapped_model, # Use wrapped model
|
126 |
+
dummy_input, # Model input
|
127 |
+
onnx_path, # Output path
|
128 |
+
export_params=True, # Store model weights
|
129 |
+
opset_version=14, # ONNX opset version (changed from 13 to 14)
|
130 |
+
do_constant_folding=True, # Optimize constants
|
131 |
+
input_names=['input_ids'], # Input names
|
132 |
+
output_names=['logits'], # Output names
|
133 |
+
dynamic_axes={
|
134 |
+
'input_ids': {0: 'batch_size', 1: 'sequence'},
|
135 |
+
'logits': {0: 'batch_size', 1: 'sequence'}
|
136 |
+
}
|
137 |
+
)
|
138 |
+
|
139 |
+
# Clean up to save memory
|
140 |
+
del model
|
141 |
+
del wrapped_model
|
142 |
+
gc.collect()
|
143 |
+
torch.cuda.empty_cache() if torch.cuda.is_available() else None
|
144 |
+
|
145 |
+
# Verify export was successful
|
146 |
+
if os.path.exists(onnx_path):
|
147 |
+
size_mb = os.path.getsize(onnx_path) / (1024 * 1024)
|
148 |
+
logger.info(f"✓ ONNX model saved to {onnx_path}")
|
149 |
+
logger.info(f"✓ Original size: {size_mb:.2f} MB")
|
150 |
+
|
151 |
+
# Step 5: Quantize
|
152 |
+
if quantize:
|
153 |
+
logger.info("Step 5/5: Applying int8 quantization...")
|
154 |
+
quant_path = onnx_path.replace(".onnx", "_quantized.onnx")
|
155 |
+
|
156 |
+
try:
|
157 |
+
quantize_dynamic(
|
158 |
+
model_input=onnx_path,
|
159 |
+
model_output=quant_path,
|
160 |
+
per_channel=False,
|
161 |
+
reduce_range=False,
|
162 |
+
weight_type=QuantType.QInt8
|
163 |
+
)
|
164 |
+
|
165 |
+
if os.path.exists(quant_path):
|
166 |
+
quant_size = os.path.getsize(quant_path) / (1024 * 1024)
|
167 |
+
logger.info(f"✓ Quantized size: {quant_size:.2f} MB")
|
168 |
+
logger.info(f"✓ Size reduction: {(1 - quant_size/size_mb) * 100:.1f}%")
|
169 |
+
|
170 |
+
# Replace original with quantized to save space
|
171 |
+
os.replace(quant_path, onnx_path)
|
172 |
+
logger.info("✓ Replaced original with quantized version")
|
173 |
+
else:
|
174 |
+
logger.warning("⚠ Quantized file not created, using original")
|
175 |
+
except Exception as e:
|
176 |
+
logger.error(f"⚠ Quantization error: {str(e)}")
|
177 |
+
logger.info("⚠ Using original model without quantization")
|
178 |
+
else:
|
179 |
+
logger.info("Step 5/5: Skipping quantization (not requested)")
|
180 |
+
|
181 |
+
# Calculate elapsed time
|
182 |
+
end_time = time.time()
|
183 |
+
duration = end_time - start_time
|
184 |
+
logger.info(f"✓ Conversion completed in {duration:.2f} seconds")
|
185 |
+
|
186 |
+
return {
|
187 |
+
"success": True,
|
188 |
+
"model_id": model_id,
|
189 |
+
"size_mb": os.path.getsize(onnx_path) / (1024 * 1024),
|
190 |
+
"duration_seconds": duration,
|
191 |
+
"output_dir": model_dir
|
192 |
+
}
|
193 |
+
else:
|
194 |
+
logger.error(f"× ONNX file not created at {onnx_path}")
|
195 |
+
return {
|
196 |
+
"success": False,
|
197 |
+
"model_id": model_id,
|
198 |
+
"error": "ONNX file not created"
|
199 |
+
}
|
200 |
+
|
201 |
+
except Exception as e:
|
202 |
+
logger.error(f"× Error converting model: {str(e)}")
|
203 |
+
logger.error(traceback.format_exc())
|
204 |
+
|
205 |
+
return {
|
206 |
+
"success": False,
|
207 |
+
"model_id": model_id,
|
208 |
+
"error": str(e)
|
209 |
+
}
|
210 |
+
|
211 |
+
def main():
|
212 |
+
"""Convert all reliable models."""
|
213 |
+
# Print header
|
214 |
+
logger.info("\nGUARANTEED ONNX CONVERTER")
|
215 |
+
logger.info("======================")
|
216 |
+
logger.info("Using reliable models with proven ONNX compatibility")
|
217 |
+
|
218 |
+
# Create output directory
|
219 |
+
output_dir = "./onnx_models"
|
220 |
+
os.makedirs(output_dir, exist_ok=True)
|
221 |
+
|
222 |
+
# Check if specific model ID provided as argument
|
223 |
+
if len(sys.argv) > 1:
|
224 |
+
model_id = sys.argv[1]
|
225 |
+
logger.info(f"Converting single model: {model_id}")
|
226 |
+
convert_model(model_id, output_dir)
|
227 |
+
return
|
228 |
+
|
229 |
+
# Convert all reliable models
|
230 |
+
results = []
|
231 |
+
for model_info in RELIABLE_MODELS:
|
232 |
+
model_id = model_info["id"]
|
233 |
+
logger.info(f"Processing model: {model_id}")
|
234 |
+
logger.info(f"Description: {model_info['description']}")
|
235 |
+
|
236 |
+
result = convert_model(model_id, output_dir)
|
237 |
+
results.append(result)
|
238 |
+
|
239 |
+
# Print summary
|
240 |
+
logger.info("\n" + "=" * 60)
|
241 |
+
logger.info("CONVERSION SUMMARY")
|
242 |
+
logger.info("=" * 60)
|
243 |
+
|
244 |
+
success_count = 0
|
245 |
+
for result in results:
|
246 |
+
if result.get("success", False):
|
247 |
+
success_count += 1
|
248 |
+
size_info = f" - Size: {result.get('size_mb', 0):.2f} MB"
|
249 |
+
time_info = f" - Time: {result.get('duration_seconds', 0):.2f}s"
|
250 |
+
logger.info(f"✓ SUCCESS: {result['model_id']}{size_info}{time_info}")
|
251 |
+
else:
|
252 |
+
logger.info(f"× FAILED: {result['model_id']} - Error: {result.get('error', 'Unknown error')}")
|
253 |
+
|
254 |
+
logger.info(f"\nSuccessfully converted {success_count}/{len(RELIABLE_MODELS)} models")
|
255 |
+
logger.info(f"Models saved to: {os.path.abspath(output_dir)}")
|
256 |
+
|
257 |
+
if success_count > 0:
|
258 |
+
logger.info("\nThe models are ready for RAG and chatbot applications!")
|
259 |
+
|
260 |
+
if __name__ == "__main__":
|
261 |
+
main()
|
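After a batch run it can help to smoke-test every converted directory in one pass. A short sketch, assuming each subdirectory of ./onnx_models contains a model.onnx plus its tokenizer as saved by the script; the attention_mask input is fed only if the exported graph declares it:

import os
import numpy as np
import onnxruntime as ort
from transformers import AutoTokenizer

root = "./onnx_models"
for name in sorted(os.listdir(root)):
    model_dir = os.path.join(root, name)
    onnx_path = os.path.join(model_dir, "model.onnx")
    if not os.path.isfile(onnx_path):
        continue
    tok = AutoTokenizer.from_pretrained(model_dir)
    sess = ort.InferenceSession(onnx_path)
    enc = tok("Hello", return_tensors="np")
    # Feed only the inputs this particular export actually declares
    feed = {i.name: enc[i.name] for i in sess.get_inputs() if i.name in enc}
    logits = sess.run(None, feed)[0]
    next_tok = tok.decode([int(np.argmax(logits[0, -1]))])
    print(f"{name}: ok, next token = {next_tok!r}")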
old_scripts/test_chat.py
ADDED
@@ -0,0 +1,402 @@
1 |
+
import os
|
2 |
+
import sys
|
3 |
+
import time
|
4 |
+
import argparse
|
5 |
+
import logging
|
6 |
+
import numpy as np
|
7 |
+
import onnxruntime as ort
|
8 |
+
from transformers import AutoTokenizer
|
9 |
+
from tqdm import tqdm
|
10 |
+
|
11 |
+
# Configure logging
|
12 |
+
logging.basicConfig(
|
13 |
+
level=logging.INFO,
|
14 |
+
format='%(asctime)s - %(levelname)s - %(message)s',
|
15 |
+
datefmt='%Y-%m-%d %H:%M:%S'
|
16 |
+
)
|
17 |
+
logger = logging.getLogger(__name__)
|
18 |
+
|
19 |
+
class ONNXGenerationChatbot:
|
20 |
+
def __init__(self, model_path, max_length=100):
|
21 |
+
"""
|
22 |
+
Initialize the ONNX chatbot for text generation.
|
23 |
+
|
24 |
+
Args:
|
25 |
+
model_path: Path to the directory containing the ONNX model and tokenizer
|
26 |
+
max_length: Maximum sequence length for generation
|
27 |
+
"""
|
28 |
+
# Set up model paths
|
29 |
+
self.model_dir = model_path
|
30 |
+
self.onnx_path = os.path.join(self.model_dir, "model.onnx")
|
31 |
+
self.fp32_path = os.path.join(self.model_dir, "model_fp32.onnx")
|
32 |
+
|
33 |
+
# Check for model files
|
34 |
+
if not os.path.exists(self.onnx_path):
|
35 |
+
raise FileNotFoundError(f"ONNX model not found at {self.onnx_path}")
|
36 |
+
|
37 |
+
# Get model name for prompt formatting
|
38 |
+
self.model_name = os.path.basename(os.path.normpath(model_path))
|
39 |
+
logger.info(f"Using model: {self.model_name}")
|
40 |
+
|
41 |
+
# Load tokenizer
|
42 |
+
logger.info(f"Loading tokenizer from {self.model_dir}...")
|
43 |
+
self.tokenizer = AutoTokenizer.from_pretrained(self.model_dir, local_files_only=True)
|
44 |
+
|
45 |
+
# Ensure tokenizer has necessary tokens
|
46 |
+
if self.tokenizer.pad_token is None and hasattr(self.tokenizer, 'eos_token'):
|
47 |
+
self.tokenizer.pad_token = self.tokenizer.eos_token
|
48 |
+
|
49 |
+
# Create optimized session
|
50 |
+
logger.info(f"Loading ONNX model from {self.onnx_path}...")
|
51 |
+
self.session_options = ort.SessionOptions()
|
52 |
+
self.session_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
|
53 |
+
self.session_options.intra_op_num_threads = 4 # Adjust based on your CPU
|
54 |
+
|
55 |
+
# Create session with appropriate providers
|
56 |
+
providers = ['CPUExecutionProvider']
|
57 |
+
if 'CUDAExecutionProvider' in ort.get_available_providers():
|
58 |
+
logger.info("CUDA is available! Using GPU acceleration.")
|
59 |
+
providers.insert(0, 'CUDAExecutionProvider')
|
60 |
+
|
61 |
+
self.session = ort.InferenceSession(
|
62 |
+
self.onnx_path,
|
63 |
+
sess_options=self.session_options,
|
64 |
+
providers=providers
|
65 |
+
)
|
66 |
+
|
67 |
+
# Get input and output names from the model
|
68 |
+
self.input_names = [input.name for input in self.session.get_inputs()]
|
69 |
+
self.output_names = [output.name for output in self.session.get_outputs()]
|
70 |
+
|
71 |
+
logger.info(f"Model inputs: {self.input_names}")
|
72 |
+
logger.info(f"Model outputs: {self.output_names}")
|
73 |
+
|
74 |
+
# Settings
|
75 |
+
self.max_length = max_length
|
76 |
+
self.stop_tokens = [self.tokenizer.eos_token_id] if self.tokenizer.eos_token_id is not None else []
|
77 |
+
|
78 |
+
# Try to add common stop tokens if they exist in the vocabulary
|
79 |
+
stop_words = ["<|endoftext|>", "</s>", "<|end|>"]
|
80 |
+
for word in stop_words:
|
81 |
+
try:
|
82 |
+
token_id = self.tokenizer.convert_tokens_to_ids(word)
|
83 |
+
if token_id not in self.stop_tokens and token_id != self.tokenizer.unk_token_id:
|
84 |
+
self.stop_tokens.append(token_id)
|
85 |
+
except Exception:
|
86 |
+
pass
|
87 |
+
|
88 |
+
logger.info(f"Using stop tokens: {self.stop_tokens}")
|
89 |
+
|
90 |
+
# Conversation history for context
|
91 |
+
self.conversation_history = []
|
92 |
+
|
93 |
+
def get_prompt_template(self):
|
94 |
+
"""
|
95 |
+
Get the appropriate prompt template based on the model type.
|
96 |
+
"""
|
97 |
+
if "opt" in self.model_name.lower():
|
98 |
+
return "Human: {}\nAssistant:"
|
99 |
+
elif "pythia" in self.model_name.lower():
|
100 |
+
return "USER: {}\nASSISTANT:"
|
101 |
+
elif "llama" in self.model_name.lower() or "alpaca" in self.model_name.lower():
|
102 |
+
return "### Human: {}\n### Assistant:"
|
103 |
+
elif "gpt2" in self.model_name.lower() or "distilgpt2" in self.model_name.lower():
|
104 |
+
return "User: {}\nBot:"
|
105 |
+
else:
|
106 |
+
return "Question: {}\nAnswer:"
|
107 |
+
|
108 |
+
def format_prompt_with_history(self, user_message):
|
109 |
+
"""
|
110 |
+
Format the prompt with conversation history for better context.
|
111 |
+
"""
|
112 |
+
template = self.get_prompt_template()
|
113 |
+
parts = template.split("{}")
|
114 |
+
prefix = parts[0]
|
115 |
+
suffix = parts[1] if len(parts) > 1 else ""
|
116 |
+
|
117 |
+
# Include history if available (up to 3 turns)
|
118 |
+
formatted_prompt = ""
|
119 |
+
for i, (user, bot) in enumerate(self.conversation_history[-3:]):
|
120 |
+
formatted_prompt += f"{prefix}{user}{suffix} {bot}\n\n"
|
121 |
+
|
122 |
+
# Add current user message
|
123 |
+
formatted_prompt += f"{prefix}{user_message}{suffix}"
|
124 |
+
|
125 |
+
return formatted_prompt
|
126 |
+
|
127 |
+
def run_inference_step(self, input_ids, attention_mask=None):
|
128 |
+
"""
|
129 |
+
Run a single inference step with the ONNX model.
|
130 |
+
|
131 |
+
Args:
|
132 |
+
input_ids: Token IDs of the input sequence
|
133 |
+
attention_mask: Attention mask for the input sequence
|
134 |
+
|
135 |
+
Returns:
|
136 |
+
numpy array: Logits for the next token prediction
|
137 |
+
"""
|
138 |
+
# Prepare model inputs
|
139 |
+
model_inputs = {}
|
140 |
+
for name in self.input_names:
|
141 |
+
if name == "input_ids":
|
142 |
+
model_inputs[name] = input_ids
|
143 |
+
elif name == "attention_mask" and attention_mask is not None:
|
144 |
+
model_inputs[name] = attention_mask
|
145 |
+
|
146 |
+
# Run inference
|
147 |
+
outputs = self.session.run(self.output_names, model_inputs)
|
148 |
+
|
149 |
+
# Return logits (assumes first output is logits)
|
150 |
+
return outputs[0]
|
151 |
+
|
152 |
+
def generate_text(self, prompt, max_new_tokens=50, temperature=0.7, top_k=50, top_p=0.9,
|
153 |
+
repetition_penalty=1.1, do_sample=True, show_progress=True):
|
154 |
+
"""
|
155 |
+
Generate text using the ONNX model.
|
156 |
+
|
157 |
+
Args:
|
158 |
+
prompt: Text prompt to generate from
|
159 |
+
max_new_tokens: Maximum number of tokens to generate
|
160 |
+
temperature: Temperature for sampling (higher = more random)
|
161 |
+
top_k: Number of highest probability tokens to keep for sampling
|
162 |
+
top_p: Cumulative probability threshold for nucleus sampling
|
163 |
+
repetition_penalty: Penalty for repeating tokens
|
164 |
+
do_sample: Whether to sample from the distribution or use greedy decoding
|
165 |
+
show_progress: Whether to show a progress bar during generation
|
166 |
+
|
167 |
+
Returns:
|
168 |
+
str: Generated text
|
169 |
+
"""
|
170 |
+
# Encode the prompt
|
171 |
+
encoded = self.tokenizer(prompt, return_tensors="np")
|
172 |
+
input_ids = encoded["input_ids"]
|
173 |
+
attention_mask = encoded["attention_mask"]
|
174 |
+
|
175 |
+
# Track input tokens for repetition penalty
|
176 |
+
prev_tokens = input_ids[0].tolist()
|
177 |
+
|
178 |
+
# Setup progress bar if requested
|
179 |
+
progress = tqdm(total=max_new_tokens, desc="Generating") if show_progress else None
|
180 |
+
|
181 |
+
# Generate tokens auto-regressively
|
182 |
+
for _ in range(max_new_tokens):
|
183 |
+
# Run inference to get next token logits
|
184 |
+
logits = self.run_inference_step(input_ids, attention_mask)
|
185 |
+
|
186 |
+
# Get logits for the last token
|
187 |
+
next_token_logits = logits[0, -1, :]
|
188 |
+
|
189 |
+
# Apply temperature scaling
|
190 |
+
if temperature > 0:
|
191 |
+
next_token_logits = next_token_logits / max(temperature, 1e-8)
|
192 |
+
|
193 |
+
# Apply repetition penalty
|
194 |
+
if repetition_penalty > 1.0:
|
195 |
+
for prev_token in set(prev_tokens[-10:]): # Only consider recent tokens
|
196 |
+
if prev_token < len(next_token_logits):
|
197 |
+
next_token_logits[prev_token] /= repetition_penalty
|
198 |
+
|
199 |
+
# Apply top-k filtering
|
200 |
+
if top_k > 0:
|
201 |
+
indices_to_remove = np.argsort(next_token_logits)[:-top_k]
|
202 |
+
next_token_logits[indices_to_remove] = -float('inf')
|
203 |
+
|
204 |
+
# Apply top-p (nucleus) filtering
|
205 |
+
if 0 < top_p < 1.0:
|
206 |
+
sorted_logits = np.sort(next_token_logits)[::-1]
|
207 |
+
sorted_indices = np.argsort(next_token_logits)[::-1]
|
208 |
+
cumulative_probs = np.cumsum(np.exp(sorted_logits) / np.sum(np.exp(sorted_logits)))
|
209 |
+
|
210 |
+
# Remove tokens with cumulative probability above the threshold
|
211 |
+
sorted_indices_to_remove = sorted_indices[cumulative_probs > top_p]
|
212 |
+
next_token_logits[sorted_indices_to_remove] = -float('inf')
|
213 |
+
|
214 |
+
# Sample from the filtered distribution or use greedy decoding
|
215 |
+
if do_sample:
|
216 |
+
# Apply softmax to get probabilities
|
217 |
+
probs = np.exp(next_token_logits - np.max(next_token_logits))
|
218 |
+
probs = probs / np.sum(probs)
|
219 |
+
|
220 |
+
# Handle NaNs
|
221 |
+
if np.isnan(probs).any():
|
222 |
+
next_token_id = np.argmax(next_token_logits)
|
223 |
+
else:
|
224 |
+
try:
|
225 |
+
# Sample from the distribution
|
226 |
+
next_token_id = np.random.choice(len(probs), p=probs)
|
227 |
+
except:
|
228 |
+
# Fallback to greedy if sampling fails
|
229 |
+
next_token_id = np.argmax(next_token_logits)
|
230 |
+
else:
|
231 |
+
# Greedy decoding - take highest probability token
|
232 |
+
next_token_id = np.argmax(next_token_logits)
|
233 |
+
|
234 |
+
# Add the chosen token to the input
|
235 |
+
next_token = np.array([[next_token_id]])
|
236 |
+
input_ids = np.concatenate([input_ids, next_token], axis=1)
|
237 |
+
|
238 |
+
# Update attention mask
|
239 |
+
attention_mask = np.ones((1, input_ids.shape[1]), dtype=np.int64)
|
240 |
+
|
241 |
+
# Add token to history for repetition penalty
|
242 |
+
prev_tokens.append(int(next_token_id))
|
243 |
+
|
244 |
+
# Update progress bar if active
|
245 |
+
if progress is not None:
|
246 |
+
progress.update(1)
|
247 |
+
|
248 |
+
# Check for stop tokens or end of text
|
249 |
+
if next_token_id in self.stop_tokens:
|
250 |
+
break
|
251 |
+
|
252 |
+
# Also stop if we exceed max length
|
253 |
+
if input_ids.shape[1] >= self.max_length:
|
254 |
+
break
|
255 |
+
|
256 |
+
# Close progress bar if used
|
257 |
+
if progress is not None:
|
258 |
+
progress.close()
|
259 |
+
|
260 |
+
# Decode the full sequence
|
261 |
+
generated_text = self.tokenizer.decode(input_ids[0], skip_special_tokens=True)
|
262 |
+
return generated_text
|
263 |
+
|
264 |
+
def extract_assistant_response(self, full_text, prompt):
|
265 |
+
"""
|
266 |
+
Extract just the assistant's response from the full generated text.
|
267 |
+
|
268 |
+
Args:
|
269 |
+
full_text: Full generated text including prompt
|
270 |
+
prompt: The original prompt
|
271 |
+
|
272 |
+
Returns:
|
273 |
+
str: Just the assistant's response
|
274 |
+
"""
|
275 |
+
# Try to extract based on the prompt format
|
276 |
+
template = self.get_prompt_template()
|
277 |
+
response_start_marker = template.split("{}")[-1]
|
278 |
+
|
279 |
+
# If the prompt is in the text, extract everything after it
|
280 |
+
if prompt in full_text:
|
281 |
+
after_prompt = full_text[len(prompt):]
|
282 |
+
|
283 |
+
# Handle additional newlines or spaces at the beginning
|
284 |
+
return after_prompt.lstrip()
|
285 |
+
|
286 |
+
# If the response marker is in the text, extract everything after it
|
287 |
+
if response_start_marker.strip() in full_text:
|
288 |
+
parts = full_text.split(response_start_marker.strip(), 1)
|
289 |
+
if len(parts) > 1:
|
290 |
+
return parts[1].strip()
|
291 |
+
|
292 |
+
# Fallback: return everything after the last line of the prompt
|
293 |
+
prompt_last_line = prompt.strip().split('\n')[-1]
|
294 |
+
if prompt_last_line in full_text:
|
295 |
+
parts = full_text.split(prompt_last_line, 1)
|
296 |
+
if len(parts) > 1:
|
297 |
+
return parts[1].strip()
|
298 |
+
|
299 |
+
# Last resort: return the whole thing
|
300 |
+
return full_text
|
301 |
+
|
302 |
+
def chat(self, temperature=0.7, max_new_tokens=100):
|
303 |
+
"""
|
304 |
+
Run an interactive chat session with the model.
|
305 |
+
|
306 |
+
Args:
|
307 |
+
temperature: Temperature for text generation
|
308 |
+
max_new_tokens: Maximum number of tokens to generate per response
|
309 |
+
"""
|
310 |
+
print("\n===== ONNX Generation Chatbot =====")
|
311 |
+
print(f"Model: {self.model_name}")
|
312 |
+
print(f"Type 'exit' to end the conversation")
|
313 |
+
print(f"Type 'reset' to clear conversation history")
|
314 |
+
|
315 |
+
while True:
|
316 |
+
# Get user input
|
317 |
+
user_input = input("\nYou: ")
|
318 |
+
|
319 |
+
# Check for exit command
|
320 |
+
if user_input.lower() in ["exit", "quit", "bye"]:
|
321 |
+
print("Goodbye!")
|
322 |
+
break
|
323 |
+
|
324 |
+
# Check for reset command
|
325 |
+
if user_input.lower() == "reset":
|
326 |
+
self.conversation_history = []
|
327 |
+
print("Conversation history cleared.")
|
328 |
+
continue
|
329 |
+
|
330 |
+
# Create prompt with history
|
331 |
+
prompt = self.format_prompt_with_history(user_input)
|
332 |
+
print("\nGenerating response...")
|
333 |
+
|
334 |
+
# Generate text
|
335 |
+
try:
|
336 |
+
start_time = time.time()
|
337 |
+
full_text = self.generate_text(
|
338 |
+
prompt,
|
339 |
+
max_new_tokens=max_new_tokens,
|
340 |
+
temperature=temperature,
|
341 |
+
show_progress=True
|
342 |
+
)
|
343 |
+
|
344 |
+
# Extract just the assistant's response
|
345 |
+
response = self.extract_assistant_response(full_text, prompt)
|
346 |
+
|
347 |
+
# Clean up any trailing incomplete sentences
|
348 |
+
if response and len(response) > 0:
|
349 |
+
# Try to end at a sentence boundary if possible
|
350 |
+
sentence_end = max(
|
351 |
+
response.rfind('.'),
|
352 |
+
response.rfind('!'),
|
353 |
+
response.rfind('?')
|
354 |
+
)
|
355 |
+
if sentence_end > len(response) * 0.5: # Only trim if we're not losing too much
|
356 |
+
response = response[:sentence_end+1]
|
357 |
+
|
358 |
+
# Calculate generation time
|
359 |
+
gen_time = time.time() - start_time
|
360 |
+
gen_speed = max_new_tokens / gen_time if gen_time > 0 else 0
|
361 |
+
|
362 |
+
# Print the response
|
363 |
+
print(f"\nBot: {response}")
|
364 |
+
print(f"\n[Generated {len(response)} chars in {gen_time:.2f}s ({gen_speed:.1f} tokens/sec)]")
|
365 |
+
|
366 |
+
# Add to conversation history
|
367 |
+
self.conversation_history.append((user_input, response))
|
368 |
+
|
369 |
+
# Keep history at a reasonable size
|
370 |
+
if len(self.conversation_history) > 10:
|
371 |
+
self.conversation_history = self.conversation_history[-10:]
|
372 |
+
|
373 |
+
except Exception as e:
|
374 |
+
logger.error(f"Error generating response: {str(e)}")
|
375 |
+
print("\nBot: I encountered an error while generating a response. Let's try again.")
|
376 |
+
|
377 |
+
|
378 |
+
def main():
|
379 |
+
"""Run the ONNX chatbot with command line arguments."""
|
380 |
+
parser = argparse.ArgumentParser(description="Interactive ONNX Chatbot")
|
381 |
+
parser.add_argument("--model", type=str, required=True,
|
382 |
+
help="Path to the ONNX model directory")
|
383 |
+
parser.add_argument("--temperature", type=float, default=0.7,
|
384 |
+
help="Temperature for text generation (default: 0.7)")
|
385 |
+
parser.add_argument("--max_tokens", type=int, default=100,
|
386 |
+
help="Maximum tokens to generate per response (default: 100)")
|
387 |
+
|
388 |
+
args = parser.parse_args()
|
389 |
+
|
390 |
+
try:
|
391 |
+
# Create and run the chatbot
|
392 |
+
chatbot = ONNXGenerationChatbot(args.model)
|
393 |
+
chatbot.chat(temperature=args.temperature, max_new_tokens=args.max_tokens)
|
394 |
+
except KeyboardInterrupt:
|
395 |
+
print("\nExiting chatbot. Goodbye!")
|
396 |
+
except Exception as e:
|
397 |
+
logger.error(f"Error: {str(e)}")
|
398 |
+
sys.exit(1)
|
399 |
+
|
400 |
+
|
401 |
+
if __name__ == "__main__":
|
402 |
+
main()
|
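The chat script above is driven from the command line; a typical invocation (assuming one of the exported model directories added below, such as onnx_models/gpt2_onnx) would be:

    python old_scripts/test_chat.py --model onnx_models/gpt2_onnx --temperature 0.7 --max_tokens 100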
onnx_models/bloom_onnx/config.json
ADDED
@@ -0,0 +1,32 @@
{
  "_attn_implementation_autoset": true,
  "_name_or_path": "bigscience/bloom-560m",
  "apply_residual_connection_post_layernorm": false,
  "architectures": [
    "BloomForCausalLM"
  ],
  "attention_dropout": 0.0,
  "attention_softmax_in_fp32": true,
  "bias_dropout_fusion": true,
  "bos_token_id": 1,
  "eos_token_id": 2,
  "hidden_dropout": 0.0,
  "hidden_size": 1024,
  "initializer_range": 0.02,
  "layer_norm_epsilon": 1e-05,
  "masked_softmax_fusion": true,
  "model_type": "bloom",
  "n_head": 16,
  "n_inner": null,
  "n_layer": 24,
  "offset_alibi": 100,
  "pad_token_id": 3,
  "pretraining_tp": 1,
  "skip_bias_add": true,
  "skip_bias_add_qkv": false,
  "slow_but_exact": false,
  "transformers_version": "4.48.3",
  "unk_token_id": 0,
  "use_cache": true,
  "vocab_size": 250880
}
onnx_models/bloom_onnx/generation_config.json
ADDED
@@ -0,0 +1,7 @@
{
  "_from_model_config": true,
  "bos_token_id": 1,
  "eos_token_id": 2,
  "pad_token_id": 3,
  "transformers_version": "4.48.3"
}
onnx_models/bloom_onnx/model.onnx
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:268cdaf473da19cc5cb7f1c0eef597e3719dc88524ebc4e78b51268cfcdb8d28
size 798372
onnx_models/bloom_onnx/special_tokens_map.json
ADDED
@@ -0,0 +1,30 @@
{
  "bos_token": {
    "content": "<s>",
    "lstrip": false,
    "normalized": false,
    "rstrip": false,
    "single_word": false
  },
  "eos_token": {
    "content": "</s>",
    "lstrip": false,
    "normalized": false,
    "rstrip": false,
    "single_word": false
  },
  "pad_token": {
    "content": "<pad>",
    "lstrip": false,
    "normalized": false,
    "rstrip": false,
    "single_word": false
  },
  "unk_token": {
    "content": "<unk>",
    "lstrip": false,
    "normalized": false,
    "rstrip": false,
    "single_word": false
  }
}
onnx_models/bloom_onnx/tokenizer.json
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:d963066d6adae5034a1dc114c3ac444512de09928cf14ed4562ba94d9a440e66
size 21763085
onnx_models/bloom_onnx/tokenizer_config.json
ADDED
@@ -0,0 +1,48 @@
{
  "add_prefix_space": false,
  "added_tokens_decoder": {
    "0": {
      "content": "<unk>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "1": {
      "content": "<s>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "2": {
      "content": "</s>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "3": {
      "content": "<pad>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    }
  },
  "bos_token": "<s>",
  "clean_up_tokenization_spaces": false,
  "eos_token": "</s>",
  "extra_special_tokens": {},
  "merges_file": null,
  "model_max_length": 1000000000000000019884624838656,
  "pad_token": "<pad>",
  "padding_side": "left",
  "tokenizer_class": "BloomTokenizer",
  "unk_token": "<unk>",
  "vocab_file": null
}
onnx_models/bloom_onnx_quantized/config.json
ADDED
@@ -0,0 +1,32 @@
{
  "_attn_implementation_autoset": true,
  "_name_or_path": "onnx_models/bloom_onnx",
  "apply_residual_connection_post_layernorm": false,
  "architectures": [
    "BloomForCausalLM"
  ],
  "attention_dropout": 0.0,
  "attention_softmax_in_fp32": true,
  "bias_dropout_fusion": true,
  "bos_token_id": 1,
  "eos_token_id": 2,
  "hidden_dropout": 0.0,
  "hidden_size": 1024,
  "initializer_range": 0.02,
  "layer_norm_epsilon": 1e-05,
  "masked_softmax_fusion": true,
  "model_type": "bloom",
  "n_head": 16,
  "n_inner": null,
  "n_layer": 24,
  "offset_alibi": 100,
  "pad_token_id": 3,
  "pretraining_tp": 1,
  "skip_bias_add": true,
  "skip_bias_add_qkv": false,
  "slow_but_exact": false,
  "transformers_version": "4.48.3",
  "unk_token_id": 0,
  "use_cache": true,
  "vocab_size": 250880
}
onnx_models/bloom_onnx_quantized/model_quantized.onnx
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:179e57ab6bb5a39b3feef242d2d569aa321f8f1461ec5247c1bd980444b07419
size 561463713
onnx_models/bloom_onnx_quantized/ort_config.json
ADDED
@@ -0,0 +1,33 @@
{
  "one_external_file": true,
  "opset": null,
  "optimization": {},
  "quantization": {
    "activations_dtype": "QUInt8",
    "activations_symmetric": false,
    "format": "QOperator",
    "is_static": false,
    "mode": "IntegerOps",
    "nodes_to_exclude": [],
    "nodes_to_quantize": [],
    "operators_to_quantize": [
      "Conv",
      "MatMul",
      "Attention",
      "LSTM",
      "Gather",
      "Transpose",
      "EmbedLayerNormalization"
    ],
    "per_channel": false,
    "qdq_add_pair_to_weight": false,
    "qdq_dedicated_pair": false,
    "qdq_op_type_per_channel_support_to_axis": {
      "MatMul": 1
    },
    "reduce_range": false,
    "weights_dtype": "QInt8",
    "weights_symmetric": true
  },
  "use_external_data_format": false
}
onnx_models/bloom_onnx_quantized/special_tokens_map.json
ADDED
@@ -0,0 +1,30 @@
{
  "bos_token": {
    "content": "<s>",
    "lstrip": false,
    "normalized": false,
    "rstrip": false,
    "single_word": false
  },
  "eos_token": {
    "content": "</s>",
    "lstrip": false,
    "normalized": false,
    "rstrip": false,
    "single_word": false
  },
  "pad_token": {
    "content": "<pad>",
    "lstrip": false,
    "normalized": false,
    "rstrip": false,
    "single_word": false
  },
  "unk_token": {
    "content": "<unk>",
    "lstrip": false,
    "normalized": false,
    "rstrip": false,
    "single_word": false
  }
}
onnx_models/bloom_onnx_quantized/tokenizer.json
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:d963066d6adae5034a1dc114c3ac444512de09928cf14ed4562ba94d9a440e66
size 21763085
onnx_models/bloom_onnx_quantized/tokenizer_config.json
ADDED
@@ -0,0 +1,48 @@
{
  "add_prefix_space": false,
  "added_tokens_decoder": {
    "0": {
      "content": "<unk>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "1": {
      "content": "<s>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "2": {
      "content": "</s>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "3": {
      "content": "<pad>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    }
  },
  "bos_token": "<s>",
  "clean_up_tokenization_spaces": false,
  "eos_token": "</s>",
  "extra_special_tokens": {},
  "merges_file": null,
  "model_max_length": 1000000000000000019884624838656,
  "pad_token": "<pad>",
  "padding_side": "left",
  "tokenizer_class": "BloomTokenizer",
  "unk_token": "<unk>",
  "vocab_file": null
}
onnx_models/falcon_onnx/config.json
ADDED
@@ -0,0 +1,41 @@
{
  "_attn_implementation_autoset": true,
  "_name_or_path": "tiiuae/falcon-rw-1b",
  "activation": "gelu",
  "alibi": true,
  "apply_residual_connection_post_layernorm": false,
  "architectures": [
    "FalconForCausalLM"
  ],
  "attention_dropout": 0.0,
  "auto_map": {
    "AutoConfig": "tiiuae/falcon-rw-1b--configuration_falcon.FalconConfig",
    "AutoModel": "tiiuae/falcon-rw-1b--modeling_falcon.FalconModel",
    "AutoModelForCausalLM": "tiiuae/falcon-rw-1b--modeling_falcon.FalconForCausalLM",
    "AutoModelForQuestionAnswering": "tiiuae/falcon-rw-1b--modeling_falcon.FalconForQuestionAnswering",
    "AutoModelForSequenceClassification": "tiiuae/falcon-rw-1b--modeling_falcon.FalconForSequenceClassification",
    "AutoModelForTokenClassification": "tiiuae/falcon-rw-1b--modeling_falcon.FalconForTokenClassification"
  },
  "bias": true,
  "bos_token_id": 1,
  "eos_token_id": 2,
  "ffn_hidden_size": 8192,
  "hidden_dropout": 0.0,
  "hidden_size": 2048,
  "initializer_range": 0.02,
  "layer_norm_epsilon": 1e-05,
  "max_position_embeddings": 2048,
  "model_type": "falcon",
  "multi_query": false,
  "new_decoder_architecture": false,
  "num_attention_heads": 32,
  "num_hidden_layers": 24,
  "num_kv_heads": 32,
  "num_ln_in_parallel_attn": null,
  "parallel_attn": false,
  "rope_scaling": null,
  "rope_theta": 10000.0,
  "transformers_version": "4.48.3",
  "use_cache": true,
  "vocab_size": 50304
}
onnx_models/falcon_onnx/generation_config.json
ADDED
@@ -0,0 +1,6 @@
{
  "_from_model_config": true,
  "bos_token_id": 1,
  "eos_token_id": 2,
  "transformers_version": "4.48.3"
}
onnx_models/falcon_onnx/merges.txt
ADDED
The diff for this file is too large to render.
See raw diff
onnx_models/falcon_onnx/model.onnx
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:c26d8b62a099f87043745987be680556c9a7c0324944af78d58b7f8559b73c17
size 655121
onnx_models/falcon_onnx/special_tokens_map.json
ADDED
@@ -0,0 +1,23 @@
{
  "bos_token": {
    "content": "<|endoftext|>",
    "lstrip": false,
    "normalized": false,
    "rstrip": false,
    "single_word": false
  },
  "eos_token": {
    "content": "<|endoftext|>",
    "lstrip": false,
    "normalized": false,
    "rstrip": false,
    "single_word": false
  },
  "unk_token": {
    "content": "<|endoftext|>",
    "lstrip": false,
    "normalized": false,
    "rstrip": false,
    "single_word": false
  }
}
onnx_models/falcon_onnx/tokenizer.json
ADDED
The diff for this file is too large to render.
See raw diff
onnx_models/falcon_onnx/tokenizer_config.json
ADDED
@@ -0,0 +1,20 @@
{
  "add_prefix_space": false,
  "added_tokens_decoder": {
    "50256": {
      "content": "<|endoftext|>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    }
  },
  "bos_token": "<|endoftext|>",
  "clean_up_tokenization_spaces": true,
  "eos_token": "<|endoftext|>",
  "extra_special_tokens": {},
  "model_max_length": 1024,
  "tokenizer_class": "GPT2Tokenizer",
  "unk_token": "<|endoftext|>"
}
onnx_models/falcon_onnx/vocab.json
ADDED
The diff for this file is too large to render.
See raw diff
onnx_models/gpt2_onnx/config.json
ADDED
@@ -0,0 +1,41 @@
{
  "_attn_implementation_autoset": true,
  "_name_or_path": "gpt2-medium",
  "activation_function": "gelu_new",
  "architectures": [
    "GPT2LMHeadModel"
  ],
  "attn_pdrop": 0.1,
  "bos_token_id": 50256,
  "embd_pdrop": 0.1,
  "eos_token_id": 50256,
  "initializer_range": 0.02,
  "layer_norm_epsilon": 1e-05,
  "model_type": "gpt2",
  "n_ctx": 1024,
  "n_embd": 1024,
  "n_head": 16,
  "n_inner": null,
  "n_layer": 24,
  "n_positions": 1024,
  "n_special": 0,
  "predict_special_tokens": true,
  "reorder_and_upcast_attn": false,
  "resid_pdrop": 0.1,
  "scale_attn_by_inverse_layer_idx": false,
  "scale_attn_weights": true,
  "summary_activation": null,
  "summary_first_dropout": 0.1,
  "summary_proj_to_labels": true,
  "summary_type": "cls_index",
  "summary_use_proj": true,
  "task_specific_params": {
    "text-generation": {
      "do_sample": true,
      "max_length": 50
    }
  },
  "transformers_version": "4.48.3",
  "use_cache": true,
  "vocab_size": 50257
}
onnx_models/gpt2_onnx/generation_config.json
ADDED
@@ -0,0 +1,6 @@
{
  "_from_model_config": true,
  "bos_token_id": 50256,
  "eos_token_id": 50256,
  "transformers_version": "4.48.3"
}
onnx_models/gpt2_onnx/merges.txt
ADDED
The diff for this file is too large to render.
See raw diff
onnx_models/gpt2_onnx/model.onnx
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:d7c266101cd0fb1a383a3006369c88384d5f76081ec1b9a4e76ff4b7bc15ffe6
size 1420150742
onnx_models/gpt2_onnx/special_tokens_map.json
ADDED
@@ -0,0 +1,5 @@
{
  "bos_token": "<|endoftext|>",
  "eos_token": "<|endoftext|>",
  "unk_token": "<|endoftext|>"
}
onnx_models/gpt2_onnx/tokenizer.json
ADDED
The diff for this file is too large to render.
See raw diff
onnx_models/gpt2_onnx/tokenizer_config.json
ADDED
@@ -0,0 +1,20 @@
{
  "add_prefix_space": false,
  "added_tokens_decoder": {
    "50256": {
      "content": "<|endoftext|>",
      "lstrip": false,
      "normalized": true,
      "rstrip": false,
      "single_word": false,
      "special": true
    }
  },
  "bos_token": "<|endoftext|>",
  "clean_up_tokenization_spaces": false,
  "eos_token": "<|endoftext|>",
  "extra_special_tokens": {},
  "model_max_length": 1024,
  "tokenizer_class": "GPT2Tokenizer",
  "unk_token": "<|endoftext|>"
}
onnx_models/gpt2_onnx/vocab.json
ADDED
The diff for this file is too large to render.
See raw diff
onnx_models/gpt2_onnx_quantized/config.json
ADDED
@@ -0,0 +1,41 @@
{
  "_attn_implementation_autoset": true,
  "_name_or_path": "onnx_models/gpt2_onnx",
  "activation_function": "gelu_new",
  "architectures": [
    "GPT2LMHeadModel"
  ],
  "attn_pdrop": 0.1,
  "bos_token_id": 50256,
  "embd_pdrop": 0.1,
  "eos_token_id": 50256,
  "initializer_range": 0.02,
  "layer_norm_epsilon": 1e-05,
  "model_type": "gpt2",
  "n_ctx": 1024,
  "n_embd": 1024,
  "n_head": 16,
  "n_inner": null,
  "n_layer": 24,
  "n_positions": 1024,
  "n_special": 0,
  "predict_special_tokens": true,
  "reorder_and_upcast_attn": false,
  "resid_pdrop": 0.1,
  "scale_attn_by_inverse_layer_idx": false,
  "scale_attn_weights": true,
  "summary_activation": null,
  "summary_first_dropout": 0.1,
  "summary_proj_to_labels": true,
  "summary_type": "cls_index",
  "summary_use_proj": true,
  "task_specific_params": {
    "text-generation": {
      "do_sample": true,
      "max_length": 50
    }
  },
  "transformers_version": "4.48.3",
  "use_cache": true,
  "vocab_size": 50257
}
onnx_models/gpt2_onnx_quantized/merges.txt
ADDED
The diff for this file is too large to render.
See raw diff
onnx_models/gpt2_onnx_quantized/model_quantized.onnx
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:b051bd6632d5039e281be589c65e56abc86861340f41d044f49f496afe35aa07
size 357201134
onnx_models/gpt2_onnx_quantized/ort_config.json
ADDED
@@ -0,0 +1,33 @@
{
  "one_external_file": true,
  "opset": null,
  "optimization": {},
  "quantization": {
    "activations_dtype": "QUInt8",
    "activations_symmetric": false,
    "format": "QOperator",
    "is_static": false,
    "mode": "IntegerOps",
    "nodes_to_exclude": [],
    "nodes_to_quantize": [],
    "operators_to_quantize": [
      "Conv",
      "MatMul",
      "Attention",
      "LSTM",
      "Gather",
      "Transpose",
      "EmbedLayerNormalization"
    ],
    "per_channel": false,
    "qdq_add_pair_to_weight": false,
    "qdq_dedicated_pair": false,
    "qdq_op_type_per_channel_support_to_axis": {
      "MatMul": 1
    },
    "reduce_range": false,
    "weights_dtype": "QInt8",
    "weights_symmetric": true
  },
  "use_external_data_format": false
}
onnx_models/gpt2_onnx_quantized/special_tokens_map.json
ADDED
@@ -0,0 +1,23 @@
{
  "bos_token": {
    "content": "<|endoftext|>",
    "lstrip": false,
    "normalized": true,
    "rstrip": false,
    "single_word": false
  },
  "eos_token": {
    "content": "<|endoftext|>",
    "lstrip": false,
    "normalized": true,
    "rstrip": false,
    "single_word": false
  },
  "unk_token": {
    "content": "<|endoftext|>",
    "lstrip": false,
    "normalized": true,
    "rstrip": false,
    "single_word": false
  }
}
onnx_models/gpt2_onnx_quantized/tokenizer.json
ADDED
The diff for this file is too large to render.
See raw diff
onnx_models/gpt2_onnx_quantized/tokenizer_config.json
ADDED
@@ -0,0 +1,20 @@
{
  "add_prefix_space": false,
  "added_tokens_decoder": {
    "50256": {
      "content": "<|endoftext|>",
      "lstrip": false,
      "normalized": true,
      "rstrip": false,
      "single_word": false,
      "special": true
    }
  },
  "bos_token": "<|endoftext|>",
  "clean_up_tokenization_spaces": false,
  "eos_token": "<|endoftext|>",
  "extra_special_tokens": {},
  "model_max_length": 1024,
  "tokenizer_class": "GPT2Tokenizer",
  "unk_token": "<|endoftext|>"
}
onnx_models/gpt2_onnx_quantized/vocab.json
ADDED
The diff for this file is too large to render.
See raw diff
onnx_models/opt_onnx/config.json
ADDED
@@ -0,0 +1,31 @@
{
  "_attn_implementation_autoset": true,
  "_name_or_path": "facebook/opt-350m",
  "_remove_final_layer_norm": false,
  "activation_dropout": 0.0,
  "activation_function": "relu",
  "architectures": [
    "OPTForCausalLM"
  ],
  "attention_dropout": 0.0,
  "bos_token_id": 2,
  "do_layer_norm_before": false,
  "dropout": 0.1,
  "enable_bias": true,
  "eos_token_id": 2,
  "ffn_dim": 4096,
  "hidden_size": 1024,
  "init_std": 0.02,
  "layer_norm_elementwise_affine": true,
  "layerdrop": 0.0,
  "max_position_embeddings": 2048,
  "model_type": "opt",
  "num_attention_heads": 16,
  "num_hidden_layers": 24,
  "pad_token_id": 1,
  "prefix": "</s>",
  "transformers_version": "4.48.3",
  "use_cache": true,
  "vocab_size": 50272,
  "word_embed_proj_dim": 512
}
onnx_models/opt_onnx/generation_config.json
ADDED
@@ -0,0 +1,7 @@
{
  "_from_model_config": true,
  "bos_token_id": 2,
  "eos_token_id": 2,
  "pad_token_id": 1,
  "transformers_version": "4.48.3"
}
onnx_models/opt_onnx/merges.txt
ADDED
The diff for this file is too large to render.
See raw diff
onnx_models/opt_onnx/model.onnx
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:33dbbbda8ead8a71ec8ad090902faadf3b292e64f4641087efe17996e9b85aa9
size 1325122848
onnx_models/opt_onnx/special_tokens_map.json
ADDED
@@ -0,0 +1,30 @@
{
  "bos_token": {
    "content": "</s>",
    "lstrip": false,
    "normalized": true,
    "rstrip": false,
    "single_word": false
  },
  "eos_token": {
    "content": "</s>",
    "lstrip": false,
    "normalized": true,
    "rstrip": false,
    "single_word": false
  },
  "pad_token": {
    "content": "<pad>",
    "lstrip": false,
    "normalized": true,
    "rstrip": false,
    "single_word": false
  },
  "unk_token": {
    "content": "</s>",
    "lstrip": false,
    "normalized": true,
    "rstrip": false,
    "single_word": false
  }
}