agoor97 committed
Commit 16ffc97 · verified · 1 Parent(s): 4c59cd9

Upload folder using huggingface_hub

This view is limited to 50 files because it contains too many changes.
Files changed (50)
  1. .gitattributes +8 -34
  2. .gitignore +175 -0
  3. CMD.md +91 -0
  4. README.md +60 -0
  5. hf_upload.ipynb +358 -0
  6. old_scripts/convert_for_unity.py +1024 -0
  7. old_scripts/convert_single_model.py +492 -0
  8. old_scripts/convert_to_onnx.py +261 -0
  9. old_scripts/test_chat.py +402 -0
  10. onnx_models/bloom_onnx/config.json +32 -0
  11. onnx_models/bloom_onnx/generation_config.json +7 -0
  12. onnx_models/bloom_onnx/model.onnx +3 -0
  13. onnx_models/bloom_onnx/special_tokens_map.json +30 -0
  14. onnx_models/bloom_onnx/tokenizer.json +3 -0
  15. onnx_models/bloom_onnx/tokenizer_config.json +48 -0
  16. onnx_models/bloom_onnx_quantized/config.json +32 -0
  17. onnx_models/bloom_onnx_quantized/model_quantized.onnx +3 -0
  18. onnx_models/bloom_onnx_quantized/ort_config.json +33 -0
  19. onnx_models/bloom_onnx_quantized/special_tokens_map.json +30 -0
  20. onnx_models/bloom_onnx_quantized/tokenizer.json +3 -0
  21. onnx_models/bloom_onnx_quantized/tokenizer_config.json +48 -0
  22. onnx_models/falcon_onnx/config.json +41 -0
  23. onnx_models/falcon_onnx/generation_config.json +6 -0
  24. onnx_models/falcon_onnx/merges.txt +0 -0
  25. onnx_models/falcon_onnx/model.onnx +3 -0
  26. onnx_models/falcon_onnx/special_tokens_map.json +23 -0
  27. onnx_models/falcon_onnx/tokenizer.json +0 -0
  28. onnx_models/falcon_onnx/tokenizer_config.json +20 -0
  29. onnx_models/falcon_onnx/vocab.json +0 -0
  30. onnx_models/gpt2_onnx/config.json +41 -0
  31. onnx_models/gpt2_onnx/generation_config.json +6 -0
  32. onnx_models/gpt2_onnx/merges.txt +0 -0
  33. onnx_models/gpt2_onnx/model.onnx +3 -0
  34. onnx_models/gpt2_onnx/special_tokens_map.json +5 -0
  35. onnx_models/gpt2_onnx/tokenizer.json +0 -0
  36. onnx_models/gpt2_onnx/tokenizer_config.json +20 -0
  37. onnx_models/gpt2_onnx/vocab.json +0 -0
  38. onnx_models/gpt2_onnx_quantized/config.json +41 -0
  39. onnx_models/gpt2_onnx_quantized/merges.txt +0 -0
  40. onnx_models/gpt2_onnx_quantized/model_quantized.onnx +3 -0
  41. onnx_models/gpt2_onnx_quantized/ort_config.json +33 -0
  42. onnx_models/gpt2_onnx_quantized/special_tokens_map.json +23 -0
  43. onnx_models/gpt2_onnx_quantized/tokenizer.json +0 -0
  44. onnx_models/gpt2_onnx_quantized/tokenizer_config.json +20 -0
  45. onnx_models/gpt2_onnx_quantized/vocab.json +0 -0
  46. onnx_models/opt_onnx/config.json +31 -0
  47. onnx_models/opt_onnx/generation_config.json +7 -0
  48. onnx_models/opt_onnx/merges.txt +0 -0
  49. onnx_models/opt_onnx/model.onnx +3 -0
  50. onnx_models/opt_onnx/special_tokens_map.json +30 -0
.gitattributes CHANGED
@@ -1,35 +1,9 @@
- *.7z filter=lfs diff=lfs merge=lfs -text
- *.arrow filter=lfs diff=lfs merge=lfs -text
- *.bin filter=lfs diff=lfs merge=lfs -text
- *.bz2 filter=lfs diff=lfs merge=lfs -text
- *.ckpt filter=lfs diff=lfs merge=lfs -text
- *.ftz filter=lfs diff=lfs merge=lfs -text
- *.gz filter=lfs diff=lfs merge=lfs -text
- *.h5 filter=lfs diff=lfs merge=lfs -text
- *.joblib filter=lfs diff=lfs merge=lfs -text
- *.lfs.* filter=lfs diff=lfs merge=lfs -text
- *.mlmodel filter=lfs diff=lfs merge=lfs -text
- *.model filter=lfs diff=lfs merge=lfs -text
- *.msgpack filter=lfs diff=lfs merge=lfs -text
- *.npy filter=lfs diff=lfs merge=lfs -text
- *.npz filter=lfs diff=lfs merge=lfs -text
  *.onnx filter=lfs diff=lfs merge=lfs -text
- *.ot filter=lfs diff=lfs merge=lfs -text
- *.parquet filter=lfs diff=lfs merge=lfs -text
- *.pb filter=lfs diff=lfs merge=lfs -text
- *.pickle filter=lfs diff=lfs merge=lfs -text
- *.pkl filter=lfs diff=lfs merge=lfs -text
- *.pt filter=lfs diff=lfs merge=lfs -text
- *.pth filter=lfs diff=lfs merge=lfs -text
- *.rar filter=lfs diff=lfs merge=lfs -text
- *.safetensors filter=lfs diff=lfs merge=lfs -text
- saved_model/**/* filter=lfs diff=lfs merge=lfs -text
- *.tar.* filter=lfs diff=lfs merge=lfs -text
- *.tar filter=lfs diff=lfs merge=lfs -text
- *.tflite filter=lfs diff=lfs merge=lfs -text
- *.tgz filter=lfs diff=lfs merge=lfs -text
- *.wasm filter=lfs diff=lfs merge=lfs -text
- *.xz filter=lfs diff=lfs merge=lfs -text
- *.zip filter=lfs diff=lfs merge=lfs -text
- *.zst filter=lfs diff=lfs merge=lfs -text
- *tfevents* filter=lfs diff=lfs merge=lfs -text
+ *.onnx_data filter=lfs diff=lfs merge=lfs -text
+
+ onnx_models/bloom_onnx/tokenizer.json filter=lfs diff=lfs merge=lfs -text
+ onnx_models/bloom_onnx_quantized/tokenizer.json filter=lfs diff=lfs merge=lfs -text
+ onnx_models/qwen_onnx/tokenizer.json filter=lfs diff=lfs merge=lfs -text
+ onnx_models/qwen_onnx_quantized/tokenizer.json filter=lfs diff=lfs merge=lfs -text
+ onnx_models/tinyllama_onnx/tokenizer.model filter=lfs diff=lfs merge=lfs -text
+ onnx_models/tinyllama_onnx_quantized/tokenizer.model filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,175 @@
+ # Byte-compiled / optimized / DLL files
+ __pycache__/
+ *.py[cod]
+ *$py.class
+
+ # C extensions
+ *.so
+
+ # Distribution / packaging
+ .Python
+ build/
+ develop-eggs/
+ dist/
+ downloads/
+ eggs/
+ .eggs/
+ lib/
+ lib64/
+ parts/
+ sdist/
+ var/
+ wheels/
+ share/python-wheels/
+ *.egg-info/
+ .installed.cfg
+ *.egg
+ MANIFEST
+
+ # PyInstaller
+ # Usually these files are written by a python script from a template
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
+ *.manifest
+ *.spec
+
+ # Installer logs
+ pip-log.txt
+ pip-delete-this-directory.txt
+
+ # Unit test / coverage reports
+ htmlcov/
+ .tox/
+ .nox/
+ .coverage
+ .coverage.*
+ .cache
+ nosetests.xml
+ coverage.xml
+ *.cover
+ *.py,cover
+ .hypothesis/
+ .pytest_cache/
+ cover/
+
+ # Translations
+ *.mo
+ *.pot
+
+ # Django stuff:
+ *.log
+ local_settings.py
+ db.sqlite3
+ db.sqlite3-journal
+
+ # Flask stuff:
+ instance/
+ .webassets-cache
+
+ # Scrapy stuff:
+ .scrapy
+
+ # Sphinx documentation
+ docs/_build/
+
+ # PyBuilder
+ .pybuilder/
+ target/
+
+ # Jupyter Notebook
+ .ipynb_checkpoints
+
+ # IPython
+ profile_default/
+ ipython_config.py
+
+ # pyenv
+ # For a library or package, you might want to ignore these files since the code is
+ # intended to run in multiple environments; otherwise, check them in:
+ # .python-version
+
+ # pipenv
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
+ # install all needed dependencies.
+ #Pipfile.lock
+
+ # UV
+ # Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control.
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
+ # commonly ignored for libraries.
+ #uv.lock
+
+ # poetry
+ # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
+ # commonly ignored for libraries.
+ # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
+ #poetry.lock
+
+ # pdm
+ # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
+ #pdm.lock
+ # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
+ # in version control.
+ # https://pdm.fming.dev/latest/usage/project/#working-with-version-control
+ .pdm.toml
+ .pdm-python
+ .pdm-build/
+
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
+ __pypackages__/
+
+ # Celery stuff
+ celerybeat-schedule
+ celerybeat.pid
+
+ # SageMath parsed files
+ *.sage.py
+
+ # Environments
+ .env
+ .venv
+ env/
+ venv/
+ ENV/
+ env.bak/
+ venv.bak/
+
+ # Spyder project settings
+ .spyderproject
+ .spyproject
+
+ # Rope project settings
+ .ropeproject
+
+ # mkdocs documentation
+ /site
+
+ # mypy
+ .mypy_cache/
+ .dmypy.json
+ dmypy.json
+
+ # Pyre type checker
+ .pyre/
+
+ # pytype static type analyzer
+ .pytype/
+
+ # Cython debug symbols
+ cython_debug/
+
+ # PyCharm
+ # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
+ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
+ # and can be added to the global gitignore or merged into this file. For a more nuclear
+ # option (not recommended) you can uncomment the following to ignore the entire idea folder.
+ #.idea/
+
+ # Ruff stuff:
+ .ruff_cache/
+
+ # PyPI configuration file
+ .pypirc
+ *.onnx_data
CMD.md ADDED
@@ -0,0 +1,91 @@
+ ### Qwen-0.5B Model
+ ``` bash
+ # Step 1: Export for text generation with past KV cache (better for chat)
+ echo "Exporting Qwen-0.5B..."
+ optimum-cli export onnx --model Qwen/Qwen1.5-0.5B --task text-generation-with-past onnx_models/qwen_onnx/
+
+ # Step 2: Quantize for ARM64 (mobile target) using static INT8 quantization
+ echo "Quantizing Qwen-0.5B for ARM64 (Static)..."
+ optimum-cli onnxruntime quantize --onnx_model onnx_models/qwen_onnx/ --arm64 -o onnx_models/qwen_onnx_quantized/
+ ```
+
+ -----------------------------------
+
+ ### TinyLlama-1.1B Model
+ ``` bash
+ # Step 1: Export for text generation with past KV cache (better for chat)
+ echo "Exporting TinyLlama-1.1B..."
+ optimum-cli export onnx --model TinyLlama/TinyLlama-1.1B-Chat-v1.0 --task text-generation-with-past onnx_models/tinyllama_onnx/
+
+ # Step 2: Attempt quantization for ARM64 (static INT8)
+ # Note: this step can take very long or fail on low-memory machines; keep the non-quantized export as a fallback.
+ echo "Attempting TinyLlama-1.1B quantization for ARM64 (Static)..."
+ optimum-cli onnxruntime quantize --onnx_model onnx_models/tinyllama_onnx/ --arm64 -o onnx_models/tinyllama_onnx_quantized/
+ ```
+
+ -----------------------------------
+
+ ### Phi-1.5 Model
+ ``` bash
+ # Step 1: Export for text generation with past KV cache (better for chat)
+ echo "Exporting Phi-1.5..."
+ optimum-cli export onnx --model microsoft/phi-1_5 --task text-generation-with-past onnx_models/phi_onnx/
+
+ # Step 2: Attempt quantization for ARM64 (static INT8) -- failed locally (requires a lot of memory)
+ echo "Quantizing Phi-1.5 for ARM64 (Static)..."
+ optimum-cli onnxruntime quantize --onnx_model onnx_models/phi_onnx/ --arm64 -o onnx_models/phi_onnx_quantized/
+ ```
+
+ -----------------------------------
+
+ ### Falcon-1B Model
+ ``` bash
+ # Export
+ echo "Exporting Falcon-1B..."
+ optimum-cli export onnx --model tiiuae/falcon-rw-1b --task text-generation-with-past onnx_models/falcon_onnx/
+
+ # Quantize for ARM64 -- failed locally (requires a lot of memory)
+ echo "Quantizing Falcon-1B for ARM64..."
+ optimum-cli onnxruntime quantize --onnx_model onnx_models/falcon_onnx/ --arm64 -o onnx_models/falcon_onnx_quantized/
+ ```
+
+ -----------------------------------
+
+ ### GPT2-Medium Model
+ ``` bash
+ # Export GPT2-Medium
+ echo "Exporting GPT2-Medium..."
+ optimum-cli export onnx --model gpt2-medium --task text-generation-with-past onnx_models/gpt2_onnx/
+
+ # Quantize for ARM64
+ echo "Quantizing GPT2-Medium for ARM64..."
+ optimum-cli onnxruntime quantize --onnx_model onnx_models/gpt2_onnx/ --arm64 -o onnx_models/gpt2_onnx_quantized/
+ ```
+
+ -----------------------------------
+
+ ### OPT-350M Model
+ ``` bash
+ # Export OPT-350M
+ echo "Exporting OPT-350M..."
+ optimum-cli export onnx --model facebook/opt-350m --task text-generation-with-past onnx_models/opt_onnx/
+
+ # Quantize for ARM64
+ echo "Quantizing OPT-350M for ARM64..."
+ optimum-cli onnxruntime quantize --onnx_model onnx_models/opt_onnx/ --arm64 -o onnx_models/opt_onnx_quantized/
+ ```
+
+ -----------------------------------
+
+ ### Bloom-560M Model
+ ``` bash
+ # Export Bloom-560M
+ echo "Exporting Bloom-560M..."
+ optimum-cli export onnx --model bigscience/bloom-560m --task text-generation-with-past onnx_models/bloom_onnx/
+
+ # Quantize for ARM64
+ echo "Quantizing Bloom-560M for ARM64..."
+ optimum-cli onnxruntime quantize --onnx_model onnx_models/bloom_onnx/ --arm64 -o onnx_models/bloom_onnx_quantized/
+ ```
+
+ -----------------------------------
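A quick way to sanity-check any of the exported folders above is to load them back through Optimum's ONNX Runtime wrapper and generate a few tokens. This is a minimal sketch, not part of the committed scripts; it assumes the `onnx_models/gpt2_onnx/` directory produced by the GPT2-Medium commands, and depending on your `optimum` version you may need to point `file_name` at the exported `.onnx` file explicitly.

```python
# Hedged sanity check: reload an exported model and generate a short reply.
from optimum.onnxruntime import ORTModelForCausalLM
from transformers import AutoTokenizer

model_dir = "onnx_models/gpt2_onnx"  # any export directory from the commands above

tokenizer = AutoTokenizer.from_pretrained(model_dir)
model = ORTModelForCausalLM.from_pretrained(model_dir)  # add file_name="model.onnx" if your optimum version needs it

prompt = "User: Hello, how are you?\nAssistant:"
inputs = tokenizer(prompt, return_tensors="pt")
output_ids = model.generate(**inputs, max_new_tokens=30)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))
```

If this prints a plausible continuation, the export itself is sound and any quality issues are more likely to come from quantization than from the conversion step.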
README.md ADDED
@@ -0,0 +1,60 @@
+ # 🚀 LLM to ONNX Converter
+ > Convert small language models to ONNX format with **guaranteed reliability** for RAG and chatbot applications on resource-constrained hardware.
+
+ ## 📋 Overview
+ This repository provides scripts to convert small language models to ONNX format and create INT8 quantized versions for efficient deployment on resource-constrained devices. Perfect for mobile applications, Unity game engines, and embedded systems.
+
+ ## ✅ Tested Models
+ We've successfully tested the following models with example outputs:
+
+ | Model | Size | Quantized | Response Quality | Speed (sec) |
+ |-------|------|-----------|------------------|-------------|
+ | Qwen-0.5B | 500M | ✅ | ❌ Poor | 8.37 |
+ | Qwen-0.5B | 500M | ❌ | ✅ Good | 15.69 |
+ | TinyLlama-1.1B | 1.1B | ✅ | ❌ Poor | 10.15 |
+ | TinyLlama-1.1B | 1.1B | ❌ | ✅ Good | 19.23 |
+ | Phi-1.5 | 1.3B | ❌ | ✅ Good | 15.32 |
+ | Falcon-RW-1B | 1B | ❌ | ✅ Good | 21.56 |
+ | GPT2-Medium | 355M | ✅ | ✅ Good | 6.27 |
+ | GPT2-Medium | 355M | ❌ | ✅ Good | 12.77 |
+ | OPT-350M | 350M | ✅ | ✅ Good | 4.33 |
+ | OPT-350M | 350M | ❌ | ✅ Good | 10.42 |
+ | Bloom-560M | 560M | ✅ | ❌ Poor | 11.93 |
+ | Bloom-560M | 560M | ❌ | ✅ Good | 34.38 |
+
+ ## 🌟 Recommendations
+ Based on our testing:
+ 1. **For best speed + quality:** OPT-350M (quantized) - fastest with good quality
+ 2. **For best overall quality:** Phi-1.5 (non-quantized) - excellent responses
+ 3. **For smallest size:** GPT2-Medium or OPT-350M (quantized) - small with good performance
+
+ ## 🚩 Key Findings
+ - Quantization provides ~2x speed improvement
+ - Smaller models (350-500M) quantize better than larger models (1B+)
+ - Some architectures (OPT, GPT2) handle quantization better than others
+
+ ## 📁 Repository Structure
+ ```
+ onnx_models/
+ ├── bloom_onnx/
+ ├── bloom_onnx_quantized/
+ ├── falcon_onnx/
+ ├── gpt2_onnx/
+ ├── gpt2_onnx_quantized/
+ ├── opt_onnx/
+ ├── opt_onnx_quantized/
+ ├── phi_onnx/
+ ├── qwen_onnx/
+ ├── qwen_onnx_quantized/
+ ├── tinyllama_onnx/
+ └── tinyllama_onnx_quantized/
+ ```
+
+ ## 📚 Requirements
+ - Python 3.8+
+ - optimum
+ - onnxruntime
+ - transformers
+ - numpy
+
+ ---------------
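As a rough way to reproduce the speed column above, the quantized and non-quantized exports can be timed side by side. This is a hedged sketch under assumptions: it uses Optimum's `ORTModelForCausalLM`, assumes the folder layout shown in the repository structure, and the absolute numbers will differ from the table depending on hardware and generation settings.

```python
# Hypothetical timing comparison: quantized vs. non-quantized OPT-350M export.
import time
from optimum.onnxruntime import ORTModelForCausalLM
from transformers import AutoTokenizer

def time_generation(model_dir: str, file_name: str) -> float:
    """Load an exported model and time a 50-token greedy generation."""
    tokenizer = AutoTokenizer.from_pretrained(model_dir)
    model = ORTModelForCausalLM.from_pretrained(model_dir, file_name=file_name)
    inputs = tokenizer("User: What is RAG?\nAssistant:", return_tensors="pt")
    start = time.perf_counter()
    model.generate(**inputs, max_new_tokens=50, do_sample=False)
    return time.perf_counter() - start

for folder, fname in [("onnx_models/opt_onnx", "model.onnx"),
                      ("onnx_models/opt_onnx_quantized", "model_quantized.onnx")]:
    print(f"{folder}: {time_generation(folder, fname):.2f} s")
```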
hf_upload.ipynb ADDED
@@ -0,0 +1,358 @@
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": 1,
6
+ "id": "22fbff0c",
7
+ "metadata": {},
8
+ "outputs": [
9
+ {
10
+ "data": {
11
+ "application/vnd.jupyter.widget-view+json": {
12
+ "model_id": "a10a920b0f9749058ee8dd5ce613705a",
13
+ "version_major": 2,
14
+ "version_minor": 0
15
+ },
16
+ "text/plain": [
17
+ "VBox(children=(HTML(value='<center> <img\\nsrc=https://huggingface.co/front/assets/huggingface_logo-noborder.sv…"
18
+ ]
19
+ },
20
+ "metadata": {},
21
+ "output_type": "display_data"
22
+ }
23
+ ],
24
+ "source": [
25
+ "from huggingface_hub import login\n",
26
+ "login()"
27
+ ]
28
+ },
29
+ {
30
+ "cell_type": "code",
31
+ "execution_count": null,
32
+ "id": "25711ffa",
33
+ "metadata": {},
34
+ "outputs": [
35
+ {
36
+ "name": "stderr",
37
+ "output_type": "stream",
38
+ "text": [
39
+ "/home/administrator/miniconda/lib/python3.12/site-packages/huggingface_hub/hf_api.py:9561: UserWarning: Warnings while validating metadata in README.md:\n",
40
+ "- empty or missing yaml metadata in repo card\n",
41
+ " warnings.warn(f\"Warnings while validating metadata in README.md:\\n{message}\")\n"
42
+ ]
43
+ },
44
+ {
45
+ "data": {
46
+ "application/vnd.jupyter.widget-view+json": {
47
+ "model_id": "09c1ecf794c84518803cca555425306a",
48
+ "version_major": 2,
49
+ "version_minor": 0
50
+ },
51
+ "text/plain": [
52
+ "model.onnx: 0%| | 0.00/798k [00:00<?, ?B/s]"
53
+ ]
54
+ },
55
+ "metadata": {},
56
+ "output_type": "display_data"
57
+ },
58
+ {
59
+ "data": {
60
+ "application/vnd.jupyter.widget-view+json": {
61
+ "model_id": "32f7fccea8414cac920db0d57af0e7d5",
62
+ "version_major": 2,
63
+ "version_minor": 0
64
+ },
65
+ "text/plain": [
66
+ "tokenizer.json: 0%| | 0.00/21.8M [00:00<?, ?B/s]"
67
+ ]
68
+ },
69
+ "metadata": {},
70
+ "output_type": "display_data"
71
+ },
72
+ {
73
+ "data": {
74
+ "application/vnd.jupyter.widget-view+json": {
75
+ "model_id": "ca28ba13ab1040d7bdf3c44430b5a266",
76
+ "version_major": 2,
77
+ "version_minor": 0
78
+ },
79
+ "text/plain": [
80
+ "model.onnx: 0%| | 0.00/655k [00:00<?, ?B/s]"
81
+ ]
82
+ },
83
+ "metadata": {},
84
+ "output_type": "display_data"
85
+ },
86
+ {
87
+ "data": {
88
+ "application/vnd.jupyter.widget-view+json": {
89
+ "model_id": "dc762c21c9a64897b160c8bc0a745942",
90
+ "version_major": 2,
91
+ "version_minor": 0
92
+ },
93
+ "text/plain": [
94
+ "Upload 18 LFS files: 0%| | 0/18 [00:00<?, ?it/s]"
95
+ ]
96
+ },
97
+ "metadata": {},
98
+ "output_type": "display_data"
99
+ },
100
+ {
101
+ "data": {
102
+ "application/vnd.jupyter.widget-view+json": {
103
+ "model_id": "2c80e14a6ec54618bda83086bd6ab6d3",
104
+ "version_major": 2,
105
+ "version_minor": 0
106
+ },
107
+ "text/plain": [
108
+ "model_quantized.onnx: 0%| | 0.00/561M [00:00<?, ?B/s]"
109
+ ]
110
+ },
111
+ "metadata": {},
112
+ "output_type": "display_data"
113
+ },
114
+ {
115
+ "data": {
116
+ "application/vnd.jupyter.widget-view+json": {
117
+ "model_id": "94a1fd0e86f74aa99cdce9b12d4b9bc3",
118
+ "version_major": 2,
119
+ "version_minor": 0
120
+ },
121
+ "text/plain": [
122
+ "tokenizer.json: 0%| | 0.00/21.8M [00:00<?, ?B/s]"
123
+ ]
124
+ },
125
+ "metadata": {},
126
+ "output_type": "display_data"
127
+ },
128
+ {
129
+ "data": {
130
+ "application/vnd.jupyter.widget-view+json": {
131
+ "model_id": "f4b966da10894d32815492fae6e891d1",
132
+ "version_major": 2,
133
+ "version_minor": 0
134
+ },
135
+ "text/plain": [
136
+ "model.onnx: 0%| | 0.00/1.42G [00:00<?, ?B/s]"
137
+ ]
138
+ },
139
+ "metadata": {},
140
+ "output_type": "display_data"
141
+ },
142
+ {
143
+ "data": {
144
+ "application/vnd.jupyter.widget-view+json": {
145
+ "model_id": "28bf8c5b822c40639b24d588e968dadd",
146
+ "version_major": 2,
147
+ "version_minor": 0
148
+ },
149
+ "text/plain": [
150
+ "model_quantized.onnx: 0%| | 0.00/357M [00:00<?, ?B/s]"
151
+ ]
152
+ },
153
+ "metadata": {},
154
+ "output_type": "display_data"
155
+ },
156
+ {
157
+ "data": {
158
+ "application/vnd.jupyter.widget-view+json": {
159
+ "model_id": "483803bf88fc4e41ae696925d39ae6e2",
160
+ "version_major": 2,
161
+ "version_minor": 0
162
+ },
163
+ "text/plain": [
164
+ "model.onnx: 0%| | 0.00/1.33G [00:00<?, ?B/s]"
165
+ ]
166
+ },
167
+ "metadata": {},
168
+ "output_type": "display_data"
169
+ },
170
+ {
171
+ "data": {
172
+ "application/vnd.jupyter.widget-view+json": {
173
+ "model_id": "55056aef07564537b9d8f6b6e2d9d87c",
174
+ "version_major": 2,
175
+ "version_minor": 0
176
+ },
177
+ "text/plain": [
178
+ "model_quantized.onnx: 0%| | 0.00/333M [00:00<?, ?B/s]"
179
+ ]
180
+ },
181
+ "metadata": {},
182
+ "output_type": "display_data"
183
+ },
184
+ {
185
+ "data": {
186
+ "application/vnd.jupyter.widget-view+json": {
187
+ "model_id": "0442e7fbdea74b89950176cef30930dc",
188
+ "version_major": 2,
189
+ "version_minor": 0
190
+ },
191
+ "text/plain": [
192
+ "model.onnx: 0%| | 0.00/814k [00:00<?, ?B/s]"
193
+ ]
194
+ },
195
+ "metadata": {},
196
+ "output_type": "display_data"
197
+ },
198
+ {
199
+ "data": {
200
+ "application/vnd.jupyter.widget-view+json": {
201
+ "model_id": "f46c2c40bc2c4686920bd0deb3259df6",
202
+ "version_major": 2,
203
+ "version_minor": 0
204
+ },
205
+ "text/plain": [
206
+ "model.onnx: 0%| | 0.00/1.86G [00:00<?, ?B/s]"
207
+ ]
208
+ },
209
+ "metadata": {},
210
+ "output_type": "display_data"
211
+ },
212
+ {
213
+ "data": {
214
+ "application/vnd.jupyter.widget-view+json": {
215
+ "model_id": "bb2d49b750ef4399ae2defea1fb1593d",
216
+ "version_major": 2,
217
+ "version_minor": 0
218
+ },
219
+ "text/plain": [
220
+ "tokenizer.json: 0%| | 0.00/11.4M [00:00<?, ?B/s]"
221
+ ]
222
+ },
223
+ "metadata": {},
224
+ "output_type": "display_data"
225
+ },
226
+ {
227
+ "data": {
228
+ "application/vnd.jupyter.widget-view+json": {
229
+ "model_id": "418c0351956b4a75854237f5ec3077c1",
230
+ "version_major": 2,
231
+ "version_minor": 0
232
+ },
233
+ "text/plain": [
234
+ "model_quantized.onnx: 0%| | 0.00/466M [00:00<?, ?B/s]"
235
+ ]
236
+ },
237
+ "metadata": {},
238
+ "output_type": "display_data"
239
+ },
240
+ {
241
+ "data": {
242
+ "application/vnd.jupyter.widget-view+json": {
243
+ "model_id": "bd17923e98b741e6bc159b25d2a25717",
244
+ "version_major": 2,
245
+ "version_minor": 0
246
+ },
247
+ "text/plain": [
248
+ "tokenizer.json: 0%| | 0.00/11.4M [00:00<?, ?B/s]"
249
+ ]
250
+ },
251
+ "metadata": {},
252
+ "output_type": "display_data"
253
+ },
254
+ {
255
+ "data": {
256
+ "application/vnd.jupyter.widget-view+json": {
257
+ "model_id": "93080e4239d84969a69f2c2f86b658b3",
258
+ "version_major": 2,
259
+ "version_minor": 0
260
+ },
261
+ "text/plain": [
262
+ "model.onnx: 0%| | 0.00/987k [00:00<?, ?B/s]"
263
+ ]
264
+ },
265
+ "metadata": {},
266
+ "output_type": "display_data"
267
+ },
268
+ {
269
+ "data": {
270
+ "application/vnd.jupyter.widget-view+json": {
271
+ "model_id": "ea2fda33732745e48f7e1a6e1dae5ecc",
272
+ "version_major": 2,
273
+ "version_minor": 0
274
+ },
275
+ "text/plain": [
276
+ "tokenizer.model: 0%| | 0.00/500k [00:00<?, ?B/s]"
277
+ ]
278
+ },
279
+ "metadata": {},
280
+ "output_type": "display_data"
281
+ },
282
+ {
283
+ "data": {
284
+ "application/vnd.jupyter.widget-view+json": {
285
+ "model_id": "851fad297317474495592a8db14aabf5",
286
+ "version_major": 2,
287
+ "version_minor": 0
288
+ },
289
+ "text/plain": [
290
+ "model_quantized.onnx: 0%| | 0.00/1.10G [00:00<?, ?B/s]"
291
+ ]
292
+ },
293
+ "metadata": {},
294
+ "output_type": "display_data"
295
+ },
296
+ {
297
+ "data": {
298
+ "application/vnd.jupyter.widget-view+json": {
299
+ "model_id": "18d5836cab4f440799db945c1af7cfeb",
300
+ "version_major": 2,
301
+ "version_minor": 0
302
+ },
303
+ "text/plain": [
304
+ "tokenizer.model: 0%| | 0.00/500k [00:00<?, ?B/s]"
305
+ ]
306
+ },
307
+ "metadata": {},
308
+ "output_type": "display_data"
309
+ }
310
+ ],
311
+ "source": [
312
+ "from huggingface_hub import upload_folder\n",
313
+ "\n",
314
+ "# Path to your offline-models directory\n",
315
+ "folder_path = \"/home/administrator/offline-rag-model/offline-models\"\n",
316
+ "\n",
317
+ "# Your Hugging Face repository name\n",
318
+ "repo_name = \"onnx-models\"\n",
319
+ "\n",
320
+ "# Upload all files to Hugging Face\n",
321
+ "upload_folder(\n",
322
+ " folder_path=folder_path,\n",
323
+ " repo_id=f\"agoor97/{repo_name}\",\n",
324
+ " repo_type=\"model\",\n",
325
+ ")"
326
+ ]
327
+ },
328
+ {
329
+ "cell_type": "code",
330
+ "execution_count": null,
331
+ "id": "67055a1c",
332
+ "metadata": {},
333
+ "outputs": [],
334
+ "source": []
335
+ }
336
+ ],
337
+ "metadata": {
338
+ "kernelspec": {
339
+ "display_name": "base",
340
+ "language": "python",
341
+ "name": "python3"
342
+ },
343
+ "language_info": {
344
+ "codemirror_mode": {
345
+ "name": "ipython",
346
+ "version": 3
347
+ },
348
+ "file_extension": ".py",
349
+ "mimetype": "text/x-python",
350
+ "name": "python",
351
+ "nbconvert_exporter": "python",
352
+ "pygments_lexer": "ipython3",
353
+ "version": "3.12.9"
354
+ }
355
+ },
356
+ "nbformat": 4,
357
+ "nbformat_minor": 5
358
+ }
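The notebook only covers the upload side. On the consuming end (a Unity build step, CI job, or another machine) the same repository can be pulled back with `huggingface_hub`. A minimal sketch, assuming the `agoor97/onnx-models` repo id used in the upload cell:

```python
# Download a single exported model folder from the Hub (counterpart to upload_folder above).
from huggingface_hub import snapshot_download

local_path = snapshot_download(
    repo_id="agoor97/onnx-models",                         # repo targeted by the upload cell
    repo_type="model",
    allow_patterns=["onnx_models/opt_onnx_quantized/*"],   # fetch just one model directory
    local_dir="./downloaded_models",
)
print("Model files available under:", local_path)
```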
old_scripts/convert_for_unity.py ADDED
@@ -0,0 +1,1024 @@
1
+ import os
2
+ import gc
3
+ import sys
4
+ import time
5
+ import logging
6
+ import traceback
7
+ import torch
8
+ import warnings
9
+ import numpy as np
10
+ from transformers import AutoModelForCausalLM, AutoTokenizer
11
+ from transformers.generation import GenerationConfig
12
+ from tqdm import tqdm
13
+ from onnxruntime.quantization import quantize_dynamic, QuantType
14
+
15
+ # Configure logging
16
+ logging.basicConfig(
17
+ level=logging.INFO,
18
+ format='%(asctime)s - %(levelname)s - %(message)s',
19
+ datefmt='%Y-%m-%d %H:%M:%S'
20
+ )
21
+ logger = logging.getLogger(__name__)
22
+
23
+ # Suppress unhelpful warnings
24
+ warnings.filterwarnings("ignore", category=UserWarning)
25
+
26
+
27
+ class GenerationWrapper(torch.nn.Module):
28
+ """
29
+ Wrapper for model export that handles generation properly.
30
+ This ensures the model can be correctly used for text generation.
31
+ """
32
+ def __init__(self, model):
33
+ super().__init__()
34
+ self.model = model
35
+ self.config = model.config
36
+
37
+ def forward(self, input_ids, attention_mask=None):
38
+ # Return only the logits to avoid complex structures
39
+ with torch.no_grad():
40
+ try:
41
+ # Standard approach for most models
42
+ outputs = self.model(
43
+ input_ids=input_ids,
44
+ attention_mask=attention_mask,
45
+ use_cache=False,
46
+ return_dict=True
47
+ )
48
+ return outputs.logits
49
+ except Exception as e:
50
+ logger.warning(f"Standard forward pass failed, trying fallback: {str(e)}")
51
+ # Fallback for models with different API
52
+ outputs = self.model(input_ids=input_ids, attention_mask=attention_mask)
53
+ if hasattr(outputs, 'logits'):
54
+ return outputs.logits
55
+ elif isinstance(outputs, tuple) and len(outputs) > 0:
56
+ return outputs[0] # First element is typically logits
57
+ else:
58
+ raise ValueError("Could not extract logits from model outputs")
59
+
60
+ def verify_model_generation(model, tokenizer, device="cpu"):
61
+ """Test model generation capabilities before export"""
62
+ model.eval()
63
+
64
+ # Use a chat-like prompt for better testing
65
+ prompt = "User: Hello, how are you today?\nAssistant:"
66
+
67
+ logger.info("Testing model generation...")
68
+ inputs = tokenizer(prompt, return_tensors="pt").to(device)
69
+
70
+ # Configure generation parameters
71
+ gen_config = GenerationConfig(
72
+ max_length=100,
73
+ do_sample=True,
74
+ temperature=0.7,
75
+ num_return_sequences=1,
76
+ )
77
+
78
+ try:
79
+ # Try generation
80
+ with torch.no_grad():
81
+ outputs = model.generate(
82
+ **inputs,
83
+ generation_config=gen_config
84
+ )
85
+
86
+ generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
87
+ logger.info(f"Test generation result: {generated_text}")
88
+
89
+ if len(generated_text) <= len(prompt):
90
+ logger.warning("Generation output is not longer than input prompt!")
91
+
92
+ return True
93
+ except Exception as e:
94
+ logger.error(f"Generation test failed: {str(e)}")
95
+ return False
96
+
97
+ def test_onnx_model(onnx_path, tokenizer):
98
+ """Verify the ONNX model can be loaded and run"""
99
+ try:
100
+ import onnxruntime as ort
101
+
102
+ logger.info("Testing ONNX model inference...")
103
+ session = ort.InferenceSession(onnx_path)
104
+
105
+ # Get input and output names
106
+ input_names = [input.name for input in session.get_inputs()]
107
+ output_names = [output.name for output in session.get_outputs()]
108
+
109
+ # Create test input
110
+ prompt = "User: Hello, how are you?\nAssistant:"
111
+ inputs = tokenizer(prompt, return_tensors="np")
112
+
113
+ # Prepare input dict
114
+ onnx_inputs = {}
115
+ for name in input_names:
116
+ if name == "input_ids" and "input_ids" in inputs:
117
+ onnx_inputs[name] = inputs["input_ids"]
118
+ elif name == "attention_mask" and "attention_mask" in inputs:
119
+ onnx_inputs[name] = inputs["attention_mask"]
120
+
121
+ # Run inference
122
+ outputs = session.run(output_names, onnx_inputs)
123
+
124
+ # Check output shape
125
+ logits = outputs[0]
126
+ logger.info(f"ONNX model output shape: {logits.shape}")
127
+
128
+ if logits.shape[0] != 1 or logits.shape[1] != inputs["input_ids"].shape[1]:
129
+ logger.warning("Output shape doesn't match expected dimensions!")
130
+
131
+ # Test next token prediction
132
+ next_token_logits = logits[0, -1, :]
133
+ next_token_id = np.argmax(next_token_logits)
134
+ next_token = tokenizer.decode([next_token_id])
135
+ logger.info(f"Next predicted token: '{next_token}'")
136
+
137
+ return True
138
+ except Exception as e:
139
+ logger.error(f"ONNX model test failed: {str(e)}")
140
+ return False
141
+
142
+ def post_process_onnx_for_unity(onnx_path):
143
+ """
144
+ Post-process ONNX model to be compatible with Unity Sentis
145
+ using only core onnx functionality (no onnxsim)
146
+ """
147
+ try:
148
+ import onnx
149
+
150
+ logger.info("Post-processing ONNX model for Unity compatibility...")
151
+
152
+ # First, create a backup of the original model
153
+ backup_path = onnx_path.replace(".onnx", "_original.onnx")
154
+ import shutil
155
+ shutil.copy(onnx_path, backup_path)
156
+ logger.info(f"Original model backed up to {backup_path}")
157
+
158
+ # Load the model
159
+ model = onnx.load(onnx_path)
160
+
161
+ # Basic model checks and optimizations
162
+ try:
163
+ # Check model validity
164
+ onnx.checker.check_model(model)
165
+ logger.info("✓ Model structure validated successfully")
166
+
167
+ # Apply shape inference
168
+ inferred_model = onnx.shape_inference.infer_shapes(model)
169
+ onnx.save(inferred_model, onnx_path)
170
+ logger.info("✓ Applied shape inference")
171
+
172
+ except Exception as e:
173
+ logger.warning(f"Model validation/optimization error (continuing): {str(e)}")
174
+
175
+ return True
176
+
177
+ except Exception as e:
178
+ logger.warning(f"ONNX post-processing error (skipping): {str(e)}")
179
+ return False
180
+
181
+ def is_architecture_compatible(model_id):
182
+ """
183
+ Check if the model architecture is expected to be compatible with ONNX opset 11
184
+ """
185
+ model_id_lower = model_id.lower()
186
+
187
+ # Models known to work with opset 11
188
+ compatible_architectures = [
189
+ "gpt2", "distilgpt2", "opt-125m", "opt-350m",
190
+ "pythia-70m", "pythia-160m", "rwkv", "gpt-neo"
191
+ ]
192
+
193
+ # Models likely requiring higher opsets (usually 14+)
194
+ incompatible_architectures = [
195
+ "llama", "mistral", "mixtral", "tinyllama", "phi-2",
196
+ "gemma", "falcon", "bloom"
197
+ ]
198
+
199
+ # Check for compatibility
200
+ for arch in compatible_architectures:
201
+ if arch in model_id_lower:
202
+ return True, 11
203
+
204
+ # Check for known incompatible architectures
205
+ for arch in incompatible_architectures:
206
+ if arch in model_id_lower:
207
+ return False, 14
208
+
209
+ # For phi-1 models, use opset 14 but mark as potentially compatible
210
+ if "phi-1" in model_id_lower:
211
+ return True, 14
212
+
213
+ # Default to opset 14 for unknown architectures
214
+ return False, 14
215
+
216
+ def setup_chat_template(model_id, tokenizer):
217
+ """
218
+ Setup appropriate chat template based on model architecture
219
+ """
220
+ model_id_lower = model_id.lower()
221
+
222
+ # Try to setup chat template if it doesn't have one
223
+ try:
224
+ if not hasattr(tokenizer, "chat_template") or tokenizer.chat_template is None:
225
+ logger.info("Setting up chat template for improved conversations...")
226
+
227
+ # Determine chat template based on model
228
+ if "gpt2" in model_id_lower or "pythia" in model_id_lower or "opt" in model_id_lower:
229
+ # Simple template for base models
230
+ chat_template = "{% for message in messages %}\n{% if message['role'] == 'user' %}\nHuman: {{ message['content'] }}\n{% elif message['role'] == 'assistant' %}\nAI: {{ message['content'] }}\n{% endif %}\n{% endfor %}\n{% if add_generation_prompt %}\nAI: {% endif %}"
231
+ tokenizer.chat_template = chat_template
232
+ logger.info("✓ Added simple Human/AI chat template")
233
+
234
+ elif "phi" in model_id_lower:
235
+ # Microsoft Phi models template
236
+ chat_template = "{% for message in messages %}\n{% if message['role'] == 'user' %}\nHuman: {{ message['content'] }}\n{% elif message['role'] == 'assistant' %}\nAssistant: {{ message['content'] }}\n{% endif %}\n{% endfor %}\n{% if add_generation_prompt %}\nAssistant: {% endif %}"
237
+ tokenizer.chat_template = chat_template
238
+ logger.info("✓ Added Phi-style Human/Assistant chat template")
239
+
240
+ elif "rwkv" in model_id_lower:
241
+ # RWKV template
242
+ chat_template = "{% for message in messages %}\n{% if message['role'] == 'user' %}\nUser: {{ message['content'] }}\n{% elif message['role'] == 'assistant' %}\nBot: {{ message['content'] }}\n{% endif %}\n{% endfor %}\n{% if add_generation_prompt %}\nBot: {% endif %}"
243
+ tokenizer.chat_template = chat_template
244
+ logger.info("✓ Added RWKV-style User/Bot chat template")
245
+
246
+ except Exception as e:
247
+ logger.warning(f"Couldn't setup chat template: {str(e)}")
248
+ logger.info("Chat template setup will need to be handled in Unity")
249
+
250
+ def convert_model(model_id, output_dir="./onnx_models", seq_length=32, quantize=True, force_opset=None):
251
+ """
252
+ Convert a model to ONNX format with focus on Unity compatibility.
253
+
254
+ Args:
255
+ model_id: HuggingFace model ID or path
256
+ output_dir: Directory to save the model
257
+ seq_length: Input sequence length for export
258
+ quantize: Whether to quantize the model to INT8
259
+ force_opset: Force a specific ONNX opset version
260
+
261
+ Returns:
262
+ bool: Success status
263
+ """
264
+ start_time = time.time()
265
+
266
+ # Check model architecture for compatibility
267
+ is_compatible, recommended_opset = is_architecture_compatible(model_id)
268
+
269
+ # Use forced opset if provided, otherwise use recommended
270
+ opset_version = force_opset if force_opset is not None else recommended_opset
271
+
272
+ # Warn if using a model that might not be compatible with Unity
273
+ if not is_compatible and opset_version < 14:
274
+ logger.warning(f"⚠ Model {model_id} may not be compatible with opset {opset_version}")
275
+ logger.warning(f"⚠ Recommended opset for this model: {recommended_opset}")
276
+ logger.warning(f"⚠ You can force a higher opset with --opset {recommended_opset}")
277
+
278
+ logger.info(f"\n{'=' * 60}")
279
+ logger.info(f"Converting {model_id} to ONNX for Unity (opset {opset_version})")
280
+ logger.info(f"{'=' * 60}")
281
+
282
+ # Create output directory
283
+ model_name = model_id.split("/")[-1]
284
+ model_dir = os.path.join(output_dir, model_name)
285
+ os.makedirs(model_dir, exist_ok=True)
286
+
287
+ try:
288
+ # Step 1: Load tokenizer
289
+ logger.info("Step 1/7: Loading tokenizer...")
290
+
291
+ tokenizer = AutoTokenizer.from_pretrained(model_id)
292
+ if tokenizer.pad_token is None and hasattr(tokenizer, 'eos_token'):
293
+ logger.info("Adding pad_token = eos_token")
294
+ tokenizer.pad_token = tokenizer.eos_token
295
+
296
+ # Setup chat template for better conversation formatting
297
+ setup_chat_template(model_id, tokenizer)
298
+
299
+ # Save tokenizer
300
+ tokenizer.save_pretrained(model_dir)
301
+ logger.info(f"✓ Tokenizer saved to {model_dir}")
302
+
303
+ # Step 2: Load model with reliability optimizations
304
+ logger.info("Step 2/7: Loading model...")
305
+
306
+ # Clean memory
307
+ gc.collect()
308
+ torch.cuda.empty_cache() if torch.cuda.is_available() else None
309
+
310
+ # Determine device
311
+ device = "cuda" if torch.cuda.is_available() else "cpu"
312
+
313
+ # Load model with full precision
314
+ try:
315
+ model = AutoModelForCausalLM.from_pretrained(
316
+ model_id,
317
+ torch_dtype=torch.float32, # Use full precision for reliability
318
+ low_cpu_mem_usage=True, # Reduce memory usage
319
+ device_map=device # Use CUDA if available
320
+ )
321
+ except Exception as e:
322
+ logger.warning(f"Standard loading failed, trying with 'trust_remote_code=True': {str(e)}")
323
+ # Some models (like RWKV) need trust_remote_code
324
+ model = AutoModelForCausalLM.from_pretrained(
325
+ model_id,
326
+ torch_dtype=torch.float32,
327
+ low_cpu_mem_usage=True,
328
+ device_map=device,
329
+ trust_remote_code=True
330
+ )
331
+
332
+ # Save config
333
+ model.config.save_pretrained(model_dir)
334
+ logger.info(f"✓ Model config saved to {model_dir}")
335
+
336
+ # Step 3: Verify model can generate chat responses
337
+ logger.info("Step 3/7: Validating chat capabilities...")
338
+
339
+ if not verify_model_generation(model, tokenizer, device):
340
+ logger.warning("⚠ Model chat test didn't complete successfully")
341
+ logger.info("Continuing with export anyway...")
342
+
343
+ # Step 4: Export to ONNX
344
+ logger.info(f"Step 4/7: Exporting to ONNX format with opset {opset_version}...")
345
+
346
+ # Wrap model with generation-optimized interface
347
+ wrapped_model = GenerationWrapper(model)
348
+ wrapped_model.eval()
349
+
350
+ # Clean memory again
351
+ gc.collect()
352
+ torch.cuda.empty_cache() if torch.cuda.is_available() else None
353
+
354
+ # Export to ONNX with appropriate opset version
355
+ onnx_path = os.path.join(model_dir, "model.onnx")
356
+
357
+ # Create minimal input
358
+ batch_size = 1
359
+ dummy_input = torch.ones(batch_size, seq_length, dtype=torch.long)
360
+ attention_mask = torch.ones(batch_size, seq_length, dtype=torch.long)
361
+
362
+ # Move tensors to correct device
363
+ dummy_input = dummy_input.to(device)
364
+ attention_mask = attention_mask.to(device)
365
+
366
+ # Export to ONNX with required opset
367
+ with torch.no_grad():
368
+ torch.onnx.export(
369
+ wrapped_model, # Wrapped model
370
+ (dummy_input, attention_mask), # Input tensors
371
+ onnx_path, # Output path
372
+ export_params=True, # Store weights
373
+ opset_version=opset_version, # Required opset version
374
+ do_constant_folding=True, # Optimize constants
375
+ input_names=['input_ids', 'attention_mask'], # Input names
376
+ output_names=['logits'], # Output name
377
+ dynamic_axes={ # Dynamic dimensions
378
+ 'input_ids': {0: 'batch_size', 1: 'sequence'},
379
+ 'attention_mask': {0: 'batch_size', 1: 'sequence'},
380
+ 'logits': {0: 'batch_size', 1: 'sequence'}
381
+ }
382
+ )
383
+
384
+ # Clean up to save memory
385
+ del model
386
+ del wrapped_model
387
+ gc.collect()
388
+ torch.cuda.empty_cache() if torch.cuda.is_available() else None
389
+
390
+ # Verify export success
391
+ if os.path.exists(onnx_path):
392
+ size_mb = os.path.getsize(onnx_path) / (1024 * 1024)
393
+ logger.info(f"✓ ONNX model saved to {onnx_path}")
394
+ logger.info(f"✓ Original size: {size_mb:.2f} MB")
395
+
396
+ # Step 5: Post-process the ONNX model for better Unity compatibility
397
+ logger.info("Step 5/7: Post-processing ONNX model for Unity compatibility...")
398
+
399
+ # Try to post-process model for Unity
400
+ try:
401
+ post_process_onnx_for_unity(onnx_path)
402
+ except Exception as e:
403
+ logger.warning(f"Post-processing failed (non-critical): {str(e)}")
404
+
405
+ # Test ONNX model
406
+ test_onnx_model(onnx_path, tokenizer)
407
+
408
+ # Step 6: Quantize the model (optional)
409
+ if quantize:
410
+ logger.info("Step 6/7: Applying INT8 quantization...")
411
+ quant_path = onnx_path.replace(".onnx", "_quantized.onnx")
412
+
413
+ try:
414
+ with tqdm(total=100, desc="Quantizing") as pbar:
415
+ # Update progress callback
416
+ def update_progress(x):
417
+ pbar.update(1)
418
+
419
+ # Apply quantization
420
+ quantize_dynamic(
421
+ model_input=onnx_path,
422
+ model_output=quant_path,
423
+ per_channel=False,
424
+ reduce_range=False,
425
+ weight_type=QuantType.QInt8,
426
+ optimize_model=True,
427
+ use_external_data_format=False
428
+ )
429
+
430
+ pbar.update(100) # Ensure progress reaches 100%
431
+
432
+ if os.path.exists(quant_path):
433
+ quant_size = os.path.getsize(quant_path) / (1024 * 1024)
434
+ logger.info(f"✓ Quantized size: {quant_size:.2f} MB")
435
+ logger.info(f"✓ Size reduction: {(1 - quant_size/size_mb) * 100:.1f}%")
436
+
437
+ # Test the quantized model
438
+ test_onnx_model(quant_path, tokenizer)
439
+
440
+ # Rename original as backup
441
+ backup_path = onnx_path.replace(".onnx", "_fp32.onnx")
442
+ os.rename(onnx_path, backup_path)
443
+
444
+ # Replace original with quantized
445
+ os.rename(quant_path, onnx_path)
446
+ logger.info("✓ Original model preserved as *_fp32.onnx")
447
+ logger.info("✓ Replaced original with quantized version")
448
+ else:
449
+ logger.warning("⚠ Quantized file not created, using original")
450
+ except Exception as e:
451
+ logger.error(f"⚠ Quantization error: {str(e)}")
452
+ logger.info("⚠ Using original model without quantization")
453
+ else:
454
+ logger.info("Step 6/7: Skipping quantization as requested")
455
+
456
+ # Step 7: Generate Unity integration examples
457
+ logger.info("Step 7/7: Generating Unity integration examples...")
458
+
459
+ # Create a Unity integration example
460
+ unity_example_path = os.path.join(model_dir, "unity_integration.cs")
461
+ with open(unity_example_path, 'w') as f:
462
+ f.write("""
463
+ using UnityEngine;
464
+ using Unity.Sentis;
465
+ using System.Collections.Generic;
466
+ using System.Linq;
467
+ using System.Text;
468
+ using System.Threading.Tasks;
469
+
470
+ public class ONNXChatbot : MonoBehaviour
471
+ {
472
+ [SerializeField] private ModelAsset modelAsset;
473
+ [SerializeField] private TextAsset tokenizerVocabJson;
474
+ [SerializeField] private int maxTokens = 50;
475
+ [SerializeField] private float temperature = 0.7f;
476
+
477
+ private IWorker worker;
478
+ private Dictionary<string, Tensor> inputs;
479
+ private SimpleTokenizer tokenizer;
480
+ private bool isGenerating = false;
481
+
482
+ void Start()
483
+ {
484
+ // Initialize the model
485
+ var model = ModelLoader.Load(modelAsset);
486
+ worker = WorkerFactory.CreateWorker(WorkerFactory.Type.ComputePrecompiled, model);
487
+
488
+ // Initialize tokenizer
489
+ tokenizer = new SimpleTokenizer(tokenizerVocabJson.text);
490
+
491
+ // Prepare for inference
492
+ inputs = new Dictionary<string, Tensor>();
493
+
494
+ Debug.Log("Model and tokenizer initialized successfully.");
495
+ }
496
+
497
+ public async Task<string> GenerateResponseAsync(string userMessage)
498
+ {
499
+ if (isGenerating)
500
+ {
501
+ Debug.LogWarning("Already generating a response. Please wait.");
502
+ return "Already generating a response. Please wait.";
503
+ }
504
+
505
+ isGenerating = true;
506
+
507
+ try
508
+ {
509
+ // Format prompt with chat template
510
+ string prompt = FormatChatPrompt(userMessage);
511
+ Debug.Log($"Formatted prompt: {prompt}");
512
+
513
+ // Tokenize input
514
+ var tokenIds = tokenizer.Encode(prompt);
515
+ Debug.Log($"Encoded to {tokenIds.Length} tokens");
516
+
517
+ if (tokenIds.Length > 0)
518
+ {
519
+ // Generate response token by token
520
+ StringBuilder responseBuilder = new StringBuilder();
521
+ List<int> currentIds = tokenIds.ToList();
522
+
523
+ for (int i = 0; i < maxTokens; i++)
524
+ {
525
+ // Make sure we don't exceed the model's context window
526
+ if (currentIds.Count > 1024)
527
+ {
528
+ // If too long, keep only the last 1024 tokens
529
+ currentIds = currentIds.Skip(currentIds.Count - 1024).Take(1024).ToList();
530
+ }
531
+
532
+ // Create tensors for current sequence
533
+ using (var inputIdsTensor = new TensorInt(new TensorShape(1, currentIds.Count), currentIds.ToArray()))
534
+ using (var attentionMaskTensor = new TensorInt(new TensorShape(1, currentIds.Count), Enumerable.Repeat(1, currentIds.Count).ToArray()))
535
+ {
536
+ // Run inference
537
+ inputs.Clear();
538
+ inputs["input_ids"] = inputIdsTensor;
539
+ inputs["attention_mask"] = attentionMaskTensor;
540
+
541
+ worker.Execute(inputs);
542
+ var logits = worker.PeekOutput() as TensorFloat;
543
+
544
+ // Get next token prediction
545
+ int nextToken = SampleNextToken(logits, currentIds, temperature);
546
+
547
+ // If we hit the end token or a newline after content, stop
548
+ if (nextToken == tokenizer.EosToken ||
549
+ (i > 0 && nextToken == tokenizer.NewlineToken))
550
+ {
551
+ break;
552
+ }
553
+
554
+ // Add token to current sequence for next iteration
555
+ currentIds.Add(nextToken);
556
+
557
+ // Decode the latest token
558
+ string newToken = tokenizer.Decode(new[] { nextToken });
559
+ responseBuilder.Append(newToken);
560
+
561
+ // For smoother output, yield every few tokens
562
+ if (i % 5 == 0)
563
+ {
564
+ await Task.Delay(1);
565
+ }
566
+ }
567
+ }
568
+
569
+ // Return the full response, without the prompt
570
+ string fullResponse = responseBuilder.ToString();
571
+ return CleanResponse(fullResponse);
572
+ }
573
+ else
574
+ {
575
+ Debug.LogError("Tokenization failed: empty token list");
576
+ return "Sorry, I couldn't process that input.";
577
+ }
578
+ }
579
+ catch (System.Exception ex)
580
+ {
581
+ Debug.LogError($"Generation error: {ex.Message}\\n{ex.StackTrace}");
582
+ return "Sorry, an error occurred while generating a response.";
583
+ }
584
+ finally
585
+ {
586
+ isGenerating = false;
587
+ }
588
+ }
589
+
590
+ private string FormatChatPrompt(string userMessage)
591
+ {
592
+ // You may need to adjust this template based on your specific model
593
+ return $"User: {userMessage}\\nAssistant:";
594
+ }
595
+
596
+ private string CleanResponse(string response)
597
+ {
598
+ // Extract only the Assistant's response
599
+ int assistantPrefix = response.IndexOf("Assistant:");
600
+ if (assistantPrefix >= 0)
601
+ {
602
+ response = response.Substring(assistantPrefix + "Assistant:".Length).Trim();
603
+ }
604
+
605
+ // Stop at any "User:" marker if present
606
+ int nextUser = response.IndexOf("User:");
607
+ if (nextUser >= 0)
608
+ {
609
+ response = response.Substring(0, nextUser).Trim();
610
+ }
611
+
612
+ return response;
613
+ }
614
+
615
+ private int SampleNextToken(TensorFloat logits, List<int> currentInputs, float temp)
616
+ {
617
+ // Get logits for the last position
618
+ int lastPos = currentInputs.Count - 1;
619
+ int vocabSize = logits.shape.channels;
620
+
621
+ // Prepare array for logits
622
+ float[] lastLogits = new float[vocabSize];
623
+
624
+ // Extract logits for the last token position
625
+ for (int i = 0; i < vocabSize; i++)
626
+ {
627
+ lastLogits[i] = logits[0, lastPos, i];
628
+ }
629
+
630
+ // Simple temperature-based sampling
631
+ if (temp <= 0.0f)
632
+ {
633
+ // Greedy sampling (argmax)
634
+ int maxIndex = 0;
635
+ float maxValue = lastLogits[0];
636
+
637
+ for (int i = 1; i < vocabSize; i++)
638
+ {
639
+ if (lastLogits[i] > maxValue)
640
+ {
641
+ maxValue = lastLogits[i];
642
+ maxIndex = i;
643
+ }
644
+ }
645
+
646
+ return maxIndex;
647
+ }
648
+ else
649
+ {
650
+ // Temperature sampling
651
+ // Apply temperature
652
+ for (int i = 0; i < vocabSize; i++)
653
+ {
654
+ lastLogits[i] /= temp;
655
+ }
656
+
657
+ // Softmax
658
+ float maxLogit = lastLogits.Max();
659
+ float sum = 0.0f;
660
+
661
+ for (int i = 0; i < vocabSize; i++)
662
+ {
663
+ lastLogits[i] = Mathf.Exp(lastLogits[i] - maxLogit);
664
+ sum += lastLogits[i];
665
+ }
666
+
667
+ for (int i = 0; i < vocabSize; i++)
668
+ {
669
+ lastLogits[i] /= sum;
670
+ }
671
+
672
+ // Sample from distribution
673
+ float random = Random.value;
674
+ float cumulativeProb = 0.0f;
675
+
676
+ for (int i = 0; i < vocabSize; i++)
677
+ {
678
+ cumulativeProb += lastLogits[i];
679
+ if (random < cumulativeProb)
680
+ {
681
+ return i;
682
+ }
683
+ }
684
+
685
+ // Fallback to last token if sampling fails
686
+ return vocabSize - 1;
687
+ }
688
+ }
689
+
690
+ void OnDestroy()
691
+ {
692
+ worker?.Dispose();
693
+ }
694
+ }
695
+
696
+ // Simple tokenizer implementation for Unity
697
+ public class SimpleTokenizer
698
+ {
699
+ private Dictionary<string, int> vocab;
700
+ private Dictionary<int, string> reversedVocab;
701
+
702
+ public int PadToken { get; private set; }
703
+ public int EosToken { get; private set; }
704
+ public int BosToken { get; private set; }
705
+ public int NewlineToken { get; private set; }
706
+
707
+ public SimpleTokenizer(string vocabJson)
708
+ {
709
+ // Parse the vocabulary from JSON
710
+ vocab = new Dictionary<string, int>();
711
+
712
+ // Simple JSON parsing (you'll need a proper JSON parser in production)
713
+ string[] entries = vocabJson.Split(new[] { '\\n', '{', '}', '\"', ':', ',' },
714
+ System.StringSplitOptions.RemoveEmptyEntries);
715
+
716
+ for (int i = 0; i < entries.Length - 1; i += 2)
717
+ {
718
+ string token = entries[i].Trim();
719
+ if (int.TryParse(entries[i + 1].Trim(), out int id))
720
+ {
721
+ vocab[token] = id;
722
+ }
723
+ }
724
+
725
+ // Create reversed vocabulary for decoding
726
+ reversedVocab = vocab.ToDictionary(kv => kv.Value, kv => kv.Key);
727
+
728
+ // Find special tokens
729
+ SetSpecialTokens();
730
+
731
+ Debug.Log($"Tokenizer initialized with {vocab.Count} tokens");
732
+ }
733
+
734
+ private void SetSpecialTokens()
735
+ {
736
+ // Try to find standard special tokens
737
+ PadToken = FindToken(new[] { "<pad>", "[PAD]", "<|endoftext|>" });
738
+ EosToken = FindToken(new[] { "</s>", "<|endoftext|>", "[EOS]", "<eos>" });
739
+ BosToken = FindToken(new[] { "<s>", "<|startoftext|>", "[BOS]", "<bos>" });
740
+
741
+ // Find newline token
742
+ foreach (var entry in vocab)
743
+ {
744
+ if (entry.Key == "\\n" || entry.Key == "<\\n>" || entry.Key == "\\n")
745
+ {
746
+ NewlineToken = entry.Value;
747
+ break;
748
+ }
749
+ }
750
+
751
+ Debug.Log($"Special tokens - PAD: {PadToken}, EOS: {EosToken}, BOS: {BosToken}, NEWLINE: {NewlineToken}");
752
+ }
753
+
754
+ private int FindToken(string[] candidates)
755
+ {
756
+ foreach (var candidate in candidates)
757
+ {
758
+ if (vocab.TryGetValue(candidate, out int id))
759
+ {
760
+ return id;
761
+ }
762
+ }
763
+
764
+ // Return -1 if not found
765
+ return -1;
766
+ }
767
+
768
+ public int[] Encode(string text)
769
+ {
770
+ // Simple character-level tokenization
771
+ // In production, use a proper BPE/WordPiece tokenizer implementation
772
+ List<int> tokens = new List<int>();
773
+ StringBuilder currentToken = new StringBuilder();
774
+
775
+ // Add BOS token if available
776
+ if (BosToken != -1)
777
+ {
778
+ tokens.Add(BosToken);
779
+ }
780
+
781
+ // Very simple tokenization - in production, this would implement
782
+ // the specific tokenization algorithm for your model
783
+ foreach (char c in text)
784
+ {
785
+ currentToken.Append(c);
786
+ string current = currentToken.ToString();
787
+
788
+ if (vocab.TryGetValue(current, out int id))
789
+ {
790
+ tokens.Add(id);
791
+ currentToken.Clear();
792
+ }
793
+ else if (currentToken.Length > 10)
794
+ {
795
+ // If token is too long, add unknown token and reset
796
+ tokens.Add(vocab.ContainsKey("<unk>") ? vocab["<unk>"] : 0);
797
+ currentToken.Clear();
798
+ currentToken.Append(c);
799
+ }
800
+ }
801
+
802
+ // Handle any remaining text
803
+ if (currentToken.Length > 0)
804
+ {
805
+ tokens.Add(vocab.ContainsKey("<unk>") ? vocab["<unk>"] : 0);
806
+ }
807
+
808
+ return tokens.ToArray();
809
+ }
810
+
811
+ public string Decode(int[] ids)
812
+ {
813
+ StringBuilder result = new StringBuilder();
814
+
815
+ foreach (int id in ids)
816
+ {
817
+ if (reversedVocab.TryGetValue(id, out string token))
818
+ {
819
+ // Some tokenizers use special prefixes like "Ġ" for spaces
820
+ string processedToken = token
821
+ .Replace("Ġ", " ")
822
+ .Replace("Ċ", "\n")
823
+ .Replace("▁", " ");
824
+
825
+ result.Append(processedToken);
826
+ }
827
+ }
828
+
829
+ return result.ToString();
830
+ }
831
+ }
832
+ """)
833
+
834
+ # Calculate elapsed time
835
+ end_time = time.time()
836
+ duration = end_time - start_time
837
+ logger.info(f"✓ Conversion completed in {duration:.2f} seconds")
838
+ logger.info(f"✓ Final model size: {os.path.getsize(onnx_path) / (1024 * 1024):.2f} MB")
839
+
840
+ # Create a Python example usage file
841
+ example_path = os.path.join(model_dir, "example_usage.py")
842
+ with open(example_path, 'w') as f:
843
+ f.write("""
844
+ import onnxruntime as ort
845
+ from transformers import AutoTokenizer
846
+ import numpy as np
847
+
848
+ # Load tokenizer and model
849
+ tokenizer = AutoTokenizer.from_pretrained("./") # Path to model directory
850
+ session = ort.InferenceSession("./model.onnx")
851
+
852
+ def generate_response(user_message, max_length=50):
853
+ # Format as a chat message
854
+ prompt = f"User: {user_message}\\nAssistant:"
855
+ inputs = tokenizer(prompt, return_tensors="np")
856
+
857
+ input_ids = inputs["input_ids"]
858
+ attention_mask = inputs["attention_mask"]
859
+
860
+ # Simple auto-regressive generation loop
861
+ for _ in range(max_length):
862
+ # Run inference for a single step
863
+ outputs = session.run(
864
+ ["logits"],
865
+ {
866
+ "input_ids": input_ids,
867
+ "attention_mask": attention_mask
868
+ }
869
+ )
870
+
871
+ # Get next token prediction from logits
872
+ logits = outputs[0]
873
+ next_token_logits = logits[0, -1, :]
874
+
875
+ # Apply temperature sampling
876
+ temperature = 0.7
877
+ next_token_logits = next_token_logits / temperature
878
+
879
+ # Apply softmax to get probabilities
880
+ exp_logits = np.exp(next_token_logits - np.max(next_token_logits))
881
+ probs = exp_logits / np.sum(exp_logits)
882
+
883
+ # Sample from the distribution
884
+ next_token_id = np.random.choice(probs.shape[0], p=probs)
885
+
886
+ # Stop if we hit the end of sequence token
887
+ if next_token_id == tokenizer.eos_token_id:
888
+ break
889
+
890
+ # Append new token to the input_ids
891
+ input_ids = np.concatenate([input_ids, [[next_token_id]]], axis=1)
892
+ attention_mask = np.concatenate([attention_mask, [[1]]], axis=1)
893
+
894
+ # Decode the entire response
895
+ response = tokenizer.decode(input_ids[0], skip_special_tokens=True)
896
+
897
+ # Extract only the assistant's response
898
+ if "Assistant:" in response:
899
+ response = response.split("Assistant:")[-1].strip()
900
+
901
+ return response
902
+
903
+ # Example usage
904
+ while True:
905
+ user_input = input("You: ")
906
+ if user_input.lower() in ['exit', 'quit']:
907
+ break
908
+ response = generate_response(user_input)
909
+ print(f"Assistant: {response}")
910
+ """)
911
+
912
+ logger.info(f"✓ Example usage saved to {example_path}")
913
+ logger.info(f"✓ Unity integration example saved to {unity_example_path}")
914
+ return True
915
+
916
+ else:
917
+ logger.error(f"× ONNX file not created at {onnx_path}")
918
+ return False
919
+
920
+ except Exception as e:
921
+ logger.error(f"× Error converting model: {str(e)}")
922
+ logger.error(traceback.format_exc())
923
+ return False
924
+
925
+ if __name__ == "__main__":
926
+ # Parse command line arguments
927
+ parser_available = False
928
+ try:
929
+ import argparse
930
+ parser = argparse.ArgumentParser(description="Convert HuggingFace models to ONNX for Unity")
931
+ parser.add_argument("model_id", type=str, help="HuggingFace model ID or path")
932
+ parser.add_argument("--output_dir", "-o", type=str, default="./onnx_models",
933
+ help="Output directory for the converted model")
934
+ parser.add_argument("--seq_length", "-s", type=int, default=32,
935
+ help="Sequence length for model export")
936
+ parser.add_argument("--no_quantize", action="store_true",
937
+ help="Skip INT8 quantization step")
938
+ parser.add_argument("--opset", "-op", type=int, default=None,
939
+ help="Force a specific ONNX opset version")
940
+
941
+ args = parser.parse_args()
942
+ parser_available = True
943
+
944
+ model_id = args.model_id
945
+ output_dir = args.output_dir
946
+ seq_length = args.seq_length
947
+ quantize = not args.no_quantize
948
+ force_opset = args.opset
949
+
950
+ except (ImportError, NameError):
951
+ # Fallback if argparse is not available
952
+ parser_available = False
953
+
954
+ if not parser_available:
955
+ if len(sys.argv) < 2:
956
+ print("Usage: python unity_compatible_converter.py MODEL_ID [OUTPUT_DIR] [SEQ_LENGTH] [--no-quantize] [--opset N]")
957
+ print("Example: python unity_compatible_converter.py distilgpt2 ./onnx_models 32")
958
+ print("\nRecommended chat models for Unity:")
959
+ print(" - distilgpt2 (smallest, opset 11)")
960
+ print(" - EleutherAI/pythia-70m (better quality, opset 11)")
961
+ print(" - microsoft/phi-1 (high quality, opset 14)")
962
+ print(" - TinyLlama/TinyLlama-1.1B-Chat-v1.0 (chat-tuned, opset 14)")
963
+ sys.exit(1)
964
+
965
+ model_id = sys.argv[1]
966
+ output_dir = sys.argv[2] if len(sys.argv) > 2 else "./onnx_models"
967
+ seq_length = int(sys.argv[3]) if len(sys.argv) > 3 else 32
968
+ quantize = "--no-quantize" not in sys.argv and "--no_quantize" not in sys.argv
969
+ force_opset = None
970
+
971
+ # Check for opset flag
972
+ for i, arg in enumerate(sys.argv):
973
+ if arg == "--opset" and i + 1 < len(sys.argv):
974
+ force_opset = int(sys.argv[i + 1])
975
+
976
+ # Check model architecture for automatic opset recommendation
977
+ is_compatible, recommended_opset = is_architecture_compatible(model_id)
978
+
979
+ # Print header
980
+ logger.info("\nUNITY-COMPATIBLE ONNX CONVERTER")
981
+ logger.info("===============================")
982
+ logger.info(f"Model: {model_id}")
983
+ logger.info(f"Output directory: {output_dir}")
984
+ logger.info(f"Sequence length: {seq_length}")
985
+
986
+ if force_opset is not None:
987
+ logger.info(f"ONNX opset version: {force_opset} (forced)")
988
+ else:
989
+ logger.info(f"Recommended ONNX opset: {recommended_opset}")
990
+ logger.info(f"Architecture compatible with opset 11: {'Yes' if is_compatible else 'No'}")
991
+
992
+ logger.info(f"Quantization: {'Enabled' if quantize else 'Disabled'}")
993
+
994
+ # Create output directory
995
+ os.makedirs(output_dir, exist_ok=True)
996
+
997
+ # Convert the model
998
+ success = convert_model(model_id, output_dir, seq_length, quantize, force_opset)
999
+
1000
+ if success:
1001
+ logger.info("\n" + "=" * 60)
1002
+ logger.info("CONVERSION SUCCESSFUL")
1003
+ logger.info("=" * 60)
1004
+ logger.info(f"Model: {model_id}")
1005
+ logger.info(f"Output directory: {os.path.abspath(output_dir)}")
1006
+ logger.info("The model is ready for Unity integration!")
1007
+ logger.info("\nNext steps:")
1008
+ logger.info("1. Import the ONNX model into Unity using the Sentis package")
1009
+ logger.info("2. Use the unity_integration.cs file as a starting point")
1010
+ logger.info("3. For tokenization in Unity, implement the SimpleTokenizer class")
1011
+ else:
1012
+ logger.info("\n" + "=" * 60)
1013
+ logger.info("CONVERSION FAILED")
1014
+ logger.info("=" * 60)
1015
+ logger.info("Please try one of the recommended models that work well with Unity:")
1016
+
1017
+ if is_compatible:
1018
+ logger.info("Compatible with Unity (opset 11):")
1019
+ logger.info(" - distilgpt2")
1020
+ logger.info(" - EleutherAI/pythia-70m")
1021
+
1022
+ logger.info("Advanced models (require opset 14):")
1023
+ logger.info(" - microsoft/phi-1 --opset 14")
1024
+ logger.info(" - TinyLlama/TinyLlama-1.1B-Chat-v1.0 --opset 14")
old_scripts/convert_single_model.py ADDED
@@ -0,0 +1,492 @@
1
+ import os
2
+ import gc
3
+ import sys
4
+ import time
5
+ import logging
6
+ import traceback
7
+ import torch
8
+ import warnings
9
+ import numpy as np
10
+ from transformers import AutoModelForCausalLM, AutoTokenizer
11
+ from transformers.generation import GenerationConfig
12
+ from tqdm import tqdm
13
+ import onnx
14
+ from onnxruntime.quantization import quantize_dynamic, QuantType
15
+
16
+ # Configure logging
17
+ logging.basicConfig(
18
+ level=logging.INFO,
19
+ format='%(asctime)s - %(levelname)s - %(message)s',
20
+ datefmt='%Y-%m-%d %H:%M:%S'
21
+ )
22
+ logger = logging.getLogger(__name__)
23
+
24
+ # Suppress unhelpful warnings
25
+ warnings.filterwarnings("ignore", category=UserWarning, message=".*The shape of the input dimension.*")
26
+ warnings.filterwarnings("ignore", category=UserWarning, message=".*Converting a tensor to a Python.*")
27
+ warnings.filterwarnings("ignore", category=UserWarning, message=".*The model does not use GenerationMixin.*")
28
+
29
+
30
+ class GenerationWrapper(torch.nn.Module):
31
+ """
32
+ Wrapper for model export that handles generation properly.
33
+ This ensures the model can be correctly used for text generation.
34
+ """
35
+ def __init__(self, model):
36
+ super().__init__()
37
+ self.model = model
38
+ self.config = model.config
39
+
40
+ def forward(self, input_ids, attention_mask=None):
41
+ # Return only the logits to avoid complex structures
42
+ with torch.no_grad():
43
+ try:
44
+ # Standard approach for most models
45
+ outputs = self.model(
46
+ input_ids=input_ids,
47
+ attention_mask=attention_mask,
48
+ use_cache=False,
49
+ return_dict=True
50
+ )
51
+ return outputs.logits
52
+ except Exception as e:
53
+ logger.warning(f"Standard forward pass failed, trying fallback: {str(e)}")
54
+ # Fallback for models with different API
55
+ outputs = self.model(input_ids=input_ids, attention_mask=attention_mask)
56
+ if hasattr(outputs, 'logits'):
57
+ return outputs.logits
58
+ elif isinstance(outputs, tuple) and len(outputs) > 0:
59
+ return outputs[0] # First element is typically logits
60
+ else:
61
+ raise ValueError("Could not extract logits from model outputs")
62
+
63
+
64
+ def verify_model_generation(model, tokenizer, device="cpu"):
65
+ """Test model generation capabilities before export"""
66
+ model.eval()
67
+ prompt = "Hello, how are you today? I am"
68
+
69
+ logger.info("Testing model generation...")
70
+ inputs = tokenizer(prompt, return_tensors="pt").to(device)
71
+
72
+ # Configure generation parameters
73
+ gen_config = GenerationConfig(
74
+ max_length=30,
75
+ do_sample=True,
76
+ temperature=0.7,
77
+ num_return_sequences=1,
78
+ )
79
+
80
+ try:
81
+ # Try generation
82
+ with torch.no_grad():
83
+ outputs = model.generate(
84
+ **inputs,
85
+ generation_config=gen_config
86
+ )
87
+
88
+ generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
89
+ logger.info(f"Test generation result: {generated_text}")
90
+
91
+ if len(generated_text) <= len(prompt):
92
+ logger.warning("Generation output is not longer than input prompt!")
93
+
94
+ return True
95
+ except Exception as e:
96
+ logger.error(f"Generation test failed: {str(e)}")
97
+ return False
98
+
99
+
100
+ def test_onnx_model(onnx_path, tokenizer):
101
+ """Verify the ONNX model can be loaded and run"""
102
+ try:
103
+ import onnxruntime as ort
104
+
105
+ logger.info("Testing ONNX model inference...")
106
+ session = ort.InferenceSession(onnx_path)
107
+
108
+ # Get input and output names
109
+ input_names = [input.name for input in session.get_inputs()]
110
+ output_names = [output.name for output in session.get_outputs()]
111
+
112
+ # Create test input
113
+ prompt = "Hello, how are you?"
114
+ inputs = tokenizer(prompt, return_tensors="np")
115
+
116
+ # Prepare input dict
117
+ onnx_inputs = {}
118
+ for name in input_names:
119
+ if name == "input_ids" and "input_ids" in inputs:
120
+ onnx_inputs[name] = inputs["input_ids"]
121
+ elif name == "attention_mask" and "attention_mask" in inputs:
122
+ onnx_inputs[name] = inputs["attention_mask"]
123
+
124
+ # Run inference
125
+ outputs = session.run(output_names, onnx_inputs)
126
+
127
+ # Check output shape
128
+ logits = outputs[0]
129
+ logger.info(f"ONNX model output shape: {logits.shape}")
130
+
131
+ if logits.shape[0] != 1 or logits.shape[1] != inputs["input_ids"].shape[1]:
132
+ logger.warning("Output shape doesn't match expected dimensions!")
133
+
134
+ # Test next token prediction
135
+ next_token_logits = logits[0, -1, :]
136
+ next_token_id = np.argmax(next_token_logits)
137
+ next_token = tokenizer.decode([next_token_id])
138
+ logger.info(f"Next predicted token: '{next_token}'")
139
+
140
+ return True
141
+ except Exception as e:
142
+ logger.error(f"ONNX model test failed: {str(e)}")
143
+ return False
144
+
145
+
146
+ def optimize_onnx_model(onnx_path):
147
+ """Apply ONNX optimizations to improve performance"""
148
+ try:
149
+ logger.info("Optimizing ONNX model...")
150
+
151
+ # Load the model
152
+ model = onnx.load(onnx_path)
153
+
154
+ # Apply optimizations
155
+ from onnxruntime.transformers import optimizer
156
+
157
+ # Get model type from path
158
+ model_path = os.path.dirname(onnx_path)
159
+ model_name = os.path.basename(model_path).lower()
160
+
161
+ # Determine model type for optimization
162
+ if "gpt" in model_name:
163
+ model_type = "gpt2"
164
+ elif "opt" in model_name:
165
+ model_type = "opt"
166
+ elif "pythia" in model_name:
167
+ model_type = "gpt_neox"
168
+ else:
169
+ model_type = "gpt2" # Default fallback
170
+
171
+ logger.info(f"Using optimization profile for model type: {model_type}")
172
+
173
+ # Try to optimize the model
174
+ try:
175
+ optimized_model = optimizer.optimize_model(
176
+ onnx_path,
177
+ model_type=model_type,
178
+ num_heads=8, # Will be overridden by model's real config
179
+ hidden_size=768, # Will be overridden by model's real config
180
+ optimization_options=None
181
+ )
182
+ optimized_model.save_model_to_file(onnx_path)
183
+ logger.info("✓ ONNX model optimized")
184
+ return True
185
+ except Exception as e:
186
+ logger.warning(f"Optimization failed (non-critical): {str(e)}")
187
+ return False
188
+
189
+ except Exception as e:
190
+ logger.warning(f"ONNX optimization error (skipping): {str(e)}")
191
+ return False
192
+
193
+
194
+ def convert_model(model_id, output_dir="./onnx_models", seq_length=32, quantize=True):
195
+ """
196
+ Convert a model to ONNX format with focus on reliability for generation.
197
+
198
+ Args:
199
+ model_id: HuggingFace model ID or path
200
+ output_dir: Directory to save the model
201
+ seq_length: Input sequence length for export
202
+ quantize: Whether to quantize the model to INT8
203
+
204
+ Returns:
205
+ bool: Success status
206
+ """
207
+ start_time = time.time()
208
+
209
+ logger.info(f"\n{'=' * 60}")
210
+ logger.info(f"Converting {model_id} to ONNX (optimized for generation)")
211
+ logger.info(f"{'=' * 60}")
212
+
213
+ # Create output directory
214
+ model_name = model_id.split("/")[-1]
215
+ model_dir = os.path.join(output_dir, model_name)
216
+ os.makedirs(model_dir, exist_ok=True)
217
+
218
+ try:
219
+ # Step 1: Load tokenizer
220
+ logger.info("Step 1/6: Loading tokenizer...")
221
+
222
+ tokenizer = AutoTokenizer.from_pretrained(model_id)
223
+ if tokenizer.pad_token is None and hasattr(tokenizer, 'eos_token'):
224
+ logger.info("Adding pad_token = eos_token")
225
+ tokenizer.pad_token = tokenizer.eos_token
226
+
227
+ # Save tokenizer
228
+ tokenizer.save_pretrained(model_dir)
229
+ logger.info(f"✓ Tokenizer saved to {model_dir}")
230
+
231
+ # Step 2: Load model with reliability optimizations
232
+ logger.info("Step 2/6: Loading model...")
233
+
234
+ # Clean memory
235
+ gc.collect()
236
+ torch.cuda.empty_cache() if torch.cuda.is_available() else None
237
+
238
+ # Determine device
239
+ device = "cuda" if torch.cuda.is_available() else "cpu"
240
+
241
+ # Load model with full precision
242
+ model = AutoModelForCausalLM.from_pretrained(
243
+ model_id,
244
+ torch_dtype=torch.float32, # Use full precision for reliability
245
+ low_cpu_mem_usage=True, # Reduce memory usage
246
+ device_map=device # Use CUDA if available
247
+ )
248
+
249
+ # Save config
250
+ model.config.save_pretrained(model_dir)
251
+ logger.info(f"✓ Model config saved to {model_dir}")
252
+
253
+ # Step 3: Verify model can generate text
254
+ logger.info("Step 3/6: Validating generation capabilities...")
255
+
256
+ if not verify_model_generation(model, tokenizer, device):
257
+ logger.warning("⚠ Model generation test didn't complete successfully")
258
+ logger.info("Continuing with export anyway...")
259
+
260
+ # Step 4: Wrap and prepare for export
261
+ logger.info("Step 4/6: Preparing for export...")
262
+
263
+ # Wrap model with generation-optimized interface
264
+ wrapped_model = GenerationWrapper(model)
265
+ wrapped_model.eval()
266
+
267
+ # Clean memory again
268
+ gc.collect()
269
+ torch.cuda.empty_cache() if torch.cuda.is_available() else None
270
+
271
+ # Step 5: Export to ONNX
272
+ logger.info("Step 5/6: Exporting to ONNX format...")
273
+ onnx_path = os.path.join(model_dir, "model.onnx")
274
+
275
+ # Create minimal input
276
+ batch_size = 1
277
+ dummy_input = torch.ones(batch_size, seq_length, dtype=torch.long)
278
+ attention_mask = torch.ones(batch_size, seq_length, dtype=torch.long)
279
+
280
+ # Move tensors to correct device
281
+ dummy_input = dummy_input.to(device)
282
+ attention_mask = attention_mask.to(device)
283
+
284
+ # Export to ONNX with required opset for transformer models
285
+ with torch.no_grad():
286
+ torch.onnx.export(
287
+ wrapped_model, # Wrapped model
288
+ (dummy_input, attention_mask), # Input tensors
289
+ onnx_path, # Output path
290
+ export_params=True, # Store weights
291
+ opset_version=14, # Required for transformer models
292
+ do_constant_folding=True, # Optimize constants
293
+ input_names=['input_ids', 'attention_mask'], # Input names
294
+ output_names=['logits'], # Output name
295
+ dynamic_axes={ # Dynamic dimensions
296
+ 'input_ids': {0: 'batch_size', 1: 'sequence'},
297
+ 'attention_mask': {0: 'batch_size', 1: 'sequence'},
298
+ 'logits': {0: 'batch_size', 1: 'sequence'}
299
+ }
300
+ )
301
+
302
+ # Clean up to save memory
303
+ del model
304
+ del wrapped_model
305
+ gc.collect()
306
+ torch.cuda.empty_cache() if torch.cuda.is_available() else None
307
+
308
+ # Verify export success
309
+ if os.path.exists(onnx_path):
310
+ size_mb = os.path.getsize(onnx_path) / (1024 * 1024)
311
+ logger.info(f"✓ ONNX model saved to {onnx_path}")
312
+ logger.info(f"✓ Original size: {size_mb:.2f} MB")
313
+
314
+ # Test ONNX model
315
+ test_onnx_model(onnx_path, tokenizer)
316
+
317
+ # Optimize the ONNX model
318
+ optimize_onnx_model(onnx_path)
319
+
320
+ # Step 6: Quantize the model (optional)
321
+ if quantize:
322
+ logger.info("Step 6/6: Applying INT8 quantization...")
323
+ quant_path = onnx_path.replace(".onnx", "_quantized.onnx")
324
+
325
+ try:
326
+ with tqdm(total=100, desc="Quantizing") as pbar:
327
+ # Update progress callback
328
+ def update_progress(x):
329
+ pbar.update(1)
330
+
331
+ quantize_dynamic(
332
+ model_input=onnx_path,
333
+ model_output=quant_path,
334
+ per_channel=False,
335
+ reduce_range=False,
336
+ weight_type=QuantType.QInt8,
337
+ optimize_model=True,
338
+ use_external_data_format=False
339
+ )
340
+
341
+ pbar.update(100) # Ensure progress reaches 100%
342
+
343
+ if os.path.exists(quant_path):
344
+ quant_size = os.path.getsize(quant_path) / (1024 * 1024)
345
+ logger.info(f"✓ Quantized size: {quant_size:.2f} MB")
346
+ logger.info(f"✓ Size reduction: {(1 - quant_size/size_mb) * 100:.1f}%")
347
+
348
+ # Test the quantized model
349
+ test_onnx_model(quant_path, tokenizer)
350
+
351
+ # Rename original as backup
352
+ backup_path = onnx_path.replace(".onnx", "_fp32.onnx")
353
+ os.rename(onnx_path, backup_path)
354
+
355
+ # Replace original with quantized
356
+ os.rename(quant_path, onnx_path)
357
+ logger.info("✓ Original model preserved as *_fp32.onnx")
358
+ logger.info("✓ Replaced original with quantized version")
359
+ else:
360
+ logger.warning("⚠ Quantized file not created, using original")
361
+ except Exception as e:
362
+ logger.error(f"⚠ Quantization error: {str(e)}")
363
+ logger.info("⚠ Using original model without quantization")
364
+ else:
365
+ logger.info("Step 6/6: Skipping quantization as requested")
366
+
367
+ # Calculate elapsed time
368
+ end_time = time.time()
369
+ duration = end_time - start_time
370
+ logger.info(f"✓ Conversion completed in {duration:.2f} seconds")
371
+ logger.info(f"✓ Final model size: {os.path.getsize(onnx_path) / (1024 * 1024):.2f} MB")
372
+
373
+ # Create a simple example usage file
374
+ example_path = os.path.join(model_dir, "example_usage.py")
375
+ with open(example_path, 'w') as f:
376
+ f.write("""
377
+ import onnxruntime as ort
378
+ from transformers import AutoTokenizer
379
+ import numpy as np
380
+
381
+ # Load tokenizer and model
382
+ tokenizer = AutoTokenizer.from_pretrained("./") # Path to model directory
383
+ session = ort.InferenceSession("./model.onnx")
384
+
385
+ # Prepare input
386
+ prompt = "Hello, how are you?"
387
+ inputs = tokenizer(prompt, return_tensors="np")
388
+
389
+ # Run inference for a single step
390
+ outputs = session.run(
391
+ ["logits"],
392
+ {
393
+ "input_ids": inputs["input_ids"],
394
+ "attention_mask": inputs["attention_mask"]
395
+ }
396
+ )
397
+
398
+ # Get next token prediction
399
+ logits = outputs[0]
400
+ next_token_id = np.argmax(logits[0, -1, :])
401
+ next_token = tokenizer.decode([next_token_id])
402
+ print(f"Next predicted token: {next_token}")
403
+
404
+ # For full generation, you'd typically run in a loop, adding tokens one by one
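+ # A minimal greedy-loop sketch (commented out; variable names follow the example above):
+ # input_ids, attention_mask = inputs["input_ids"], inputs["attention_mask"]
+ # for _ in range(50):
+ #     step_logits = session.run(["logits"], {"input_ids": input_ids, "attention_mask": attention_mask})[0]
+ #     next_id = int(np.argmax(step_logits[0, -1, :]))
+ #     if next_id == tokenizer.eos_token_id:
+ #         break
+ #     input_ids = np.concatenate([input_ids, [[next_id]]], axis=1)
+ #     attention_mask = np.concatenate([attention_mask, [[1]]], axis=1)
+ # print(tokenizer.decode(input_ids[0], skip_special_tokens=True))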
405
+ """)
406
+ logger.info(f"✓ Example usage saved to {example_path}")
407
+
408
+ return True
409
+ else:
410
+ logger.error(f"× ONNX file not created at {onnx_path}")
411
+ return False
412
+
413
+ except Exception as e:
414
+ logger.error(f"× Error converting model: {str(e)}")
415
+ logger.error(traceback.format_exc())
416
+ return False
417
+
418
+
419
+ if __name__ == "__main__":
420
+ # Parse command line arguments
421
+ parser_available = False
422
+ try:
423
+ import argparse
424
+ parser = argparse.ArgumentParser(description="Convert HuggingFace models to ONNX for generation")
425
+ parser.add_argument("model_id", type=str, help="HuggingFace model ID or path")
426
+ parser.add_argument("--output_dir", "-o", type=str, default="./onnx_models",
427
+ help="Output directory for the converted model")
428
+ parser.add_argument("--seq_length", "-s", type=int, default=32,
429
+ help="Sequence length for model export")
430
+ parser.add_argument("--no_quantize", action="store_true",
431
+ help="Skip INT8 quantization step")
432
+
433
+ args = parser.parse_args()
434
+ parser_available = True
435
+
436
+ model_id = args.model_id
437
+ output_dir = args.output_dir
438
+ seq_length = args.seq_length
439
+ quantize = not args.no_quantize
440
+
441
+ except (ImportError, NameError):
442
+ # Fallback if argparse is not available
443
+ parser_available = False
444
+
445
+ if not parser_available:
446
+ if len(sys.argv) < 2:
447
+ print("Usage: python convert_model.py MODEL_ID [OUTPUT_DIR] [SEQ_LENGTH] [--no-quantize]")
448
+ print("Example: python convert_model.py facebook/opt-125m ./onnx_models 32")
449
+ print("\nRecommended models for small hardware:")
450
+ print(" - facebook/opt-125m")
451
+ print(" - distilgpt2")
452
+ print(" - TinyLlama/TinyLlama-1.1B-Chat-v1.0")
453
+ print(" - EleutherAI/pythia-70m")
454
+ sys.exit(1)
455
+
456
+ model_id = sys.argv[1]
457
+ output_dir = sys.argv[2] if len(sys.argv) > 2 else "./onnx_models"
458
+ seq_length = int(sys.argv[3]) if len(sys.argv) > 3 else 32
459
+ quantize = "--no-quantize" not in sys.argv and "--no_quantize" not in sys.argv
460
+
461
+ # Print header
462
+ logger.info("\nENHANCED ONNX CONVERTER FOR LANGUAGE MODELS")
463
+ logger.info("============================================")
464
+ logger.info(f"Model: {model_id}")
465
+ logger.info(f"Output directory: {output_dir}")
466
+ logger.info(f"Sequence length: {seq_length}")
467
+ logger.info(f"Quantization: {'Enabled' if quantize else 'Disabled'}")
468
+
469
+ # Create output directory
470
+ os.makedirs(output_dir, exist_ok=True)
471
+
472
+ # Convert the model
473
+ success = convert_model(model_id, output_dir, seq_length, quantize)
474
+
475
+ if success:
476
+ logger.info("\n" + "=" * 60)
477
+ logger.info("CONVERSION SUCCESSFUL")
478
+ logger.info("=" * 60)
479
+ logger.info(f"Model: {model_id}")
480
+ logger.info(f"Output directory: {os.path.abspath(output_dir)}")
481
+ logger.info("The model is ready for generation!")
482
+ logger.info("\nTo use the model:")
483
+ logger.info("1. See the example_usage.py file in the model directory")
484
+ logger.info("2. For chatbot applications, implement token-by-token generation")
485
+ else:
486
+ logger.error("\n" + "=" * 60)
487
+ logger.error("CONVERSION FAILED")
488
+ logger.error("=" * 60)
489
+ logger.error("Please try one of the recommended models:")
490
+ logger.error(" - facebook/opt-125m")
491
+ logger.error(" - distilgpt2")
492
+ logger.error(" - EleutherAI/pythia-70m")
old_scripts/convert_to_onnx.py ADDED
@@ -0,0 +1,261 @@
1
+ import os
2
+ import gc
3
+ import sys
4
+ import time
5
+ import logging
6
+ import traceback
7
+ import torch
8
+ import warnings
9
+ from transformers import AutoModelForCausalLM, AutoTokenizer
10
+ from onnxruntime.quantization import quantize_dynamic, QuantType
11
+
12
+ # Configure logging
13
+ logging.basicConfig(
14
+ level=logging.INFO,
15
+ format='%(asctime)s - %(levelname)s - %(message)s',
16
+ datefmt='%Y-%m-%d %H:%M:%S',
17
+ handlers=[logging.StreamHandler(sys.stdout)]
18
+ )
19
+ logger = logging.getLogger(__name__)
20
+
21
+ # Suppress specific warnings
22
+ warnings.filterwarnings("ignore", category=UserWarning, message=".*The shape of the input dimension.*")
23
+ warnings.filterwarnings("ignore", category=UserWarning, message=".*Converting a tensor to a Python.*")
24
+
25
+ # Models that are known to work well with ONNX conversion
26
+ RELIABLE_MODELS = [
27
+ {
28
+ "id": "facebook/opt-350m",
29
+ "description": "Well-balanced model (350M) for RAG and chatbots"
30
+ },
31
+ {
32
+ "id": "gpt2",
33
+ "description": "Very reliable model (124M) with excellent ONNX compatibility"
34
+ },
35
+ {
36
+ "id": "distilgpt2",
37
+ "description": "Lightweight (82M) model with good performance"
38
+ }
39
+ ]
40
+
41
+ class ModelWrapper(torch.nn.Module):
42
+ """
43
+ Wrapper to handle ONNX export compatibility issues.
44
+ This wrapper specifically:
45
+ 1. Bypasses cache handling
46
+ 2. Simplifies the forward pass to avoid dynamic operations
47
+ """
48
+ def __init__(self, model):
49
+ super().__init__()
50
+ self.model = model
51
+
52
+ def forward(self, input_ids):
53
+ # Force no cache, no gradient, and no special features
54
+ with torch.no_grad():
55
+ return self.model(input_ids=input_ids, use_cache=False, return_dict=False)[0]
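+ # With return_dict=False the model returns a plain tuple; index 0 is the logits tensor.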
56
+
57
+ def convert_model(model_id, output_dir, quantize=True):
58
+ """Convert a model to ONNX format with maximum compatibility."""
59
+ start_time = time.time()
60
+
61
+ logger.info(f"\n{'=' * 60}")
62
+ logger.info(f"Converting {model_id} to ONNX")
63
+ logger.info(f"{'=' * 60}")
64
+
65
+ # Create output directory
66
+ model_name = model_id.split("/")[-1]
67
+ model_dir = os.path.join(output_dir, model_name)
68
+ os.makedirs(model_dir, exist_ok=True)
69
+
70
+ try:
71
+ # Step 1: Load tokenizer
72
+ logger.info("Step 1/5: Loading tokenizer...")
73
+
74
+ tokenizer = AutoTokenizer.from_pretrained(model_id)
75
+
76
+ # Handle missing pad token
77
+ if tokenizer.pad_token is None and hasattr(tokenizer, 'eos_token'):
78
+ logger.info("Adding pad_token = eos_token")
79
+ tokenizer.pad_token = tokenizer.eos_token
80
+
81
+ # Save tokenizer
82
+ tokenizer.save_pretrained(model_dir)
83
+ logger.info(f"✓ Tokenizer saved to {model_dir}")
84
+
85
+ # Step 2: Load model with memory optimizations
86
+ logger.info("Step 2/5: Loading model with memory optimizations...")
87
+
88
+ # Clean memory before loading
89
+ gc.collect()
90
+ torch.cuda.empty_cache() if torch.cuda.is_available() else None
91
+
92
+ # Load model with optimizations
93
+ model = AutoModelForCausalLM.from_pretrained(
94
+ model_id,
95
+ torch_dtype=torch.float16, # Use half precision
96
+ low_cpu_mem_usage=True # Reduce memory usage
97
+ )
98
+
99
+ # Save config for reference
100
+ model.config.save_pretrained(model_dir)
101
+ logger.info(f"✓ Model config saved to {model_dir}")
102
+
103
+ # Step 3: Prepare for export
104
+ logger.info("Step 3/5: Preparing for export...")
105
+
106
+ # Wrap model to avoid tracing issues
107
+ wrapped_model = ModelWrapper(model)
108
+ wrapped_model.eval() # Set to evaluation mode
109
+
110
+ # Clean memory again
111
+ gc.collect()
112
+ torch.cuda.empty_cache() if torch.cuda.is_available() else None
113
+
114
+ # Step 4: Export to ONNX
115
+ logger.info("Step 4/5: Exporting to ONNX format...")
116
+ onnx_path = os.path.join(model_dir, "model.onnx")
117
+
118
+ # Create dummy input
119
+ batch_size = 1
120
+ seq_length = 8 # Small sequence length to reduce memory
121
+ dummy_input = torch.ones(batch_size, seq_length, dtype=torch.long)
122
+
123
+ # Export to ONNX format with new opset version
124
+ torch.onnx.export(
125
+ wrapped_model, # Use wrapped model
126
+ dummy_input, # Model input
127
+ onnx_path, # Output path
128
+ export_params=True, # Store model weights
129
+ opset_version=14, # ONNX opset version (changed from 13 to 14)
130
+ do_constant_folding=True, # Optimize constants
131
+ input_names=['input_ids'], # Input names
132
+ output_names=['logits'], # Output names
133
+ dynamic_axes={
134
+ 'input_ids': {0: 'batch_size', 1: 'sequence'},
135
+ 'logits': {0: 'batch_size', 1: 'sequence'}
136
+ }
137
+ )
138
+
139
+ # Clean up to save memory
140
+ del model
141
+ del wrapped_model
142
+ gc.collect()
143
+ torch.cuda.empty_cache() if torch.cuda.is_available() else None
144
+
145
+ # Verify export was successful
146
+ if os.path.exists(onnx_path):
147
+ size_mb = os.path.getsize(onnx_path) / (1024 * 1024)
148
+ logger.info(f"✓ ONNX model saved to {onnx_path}")
149
+ logger.info(f"✓ Original size: {size_mb:.2f} MB")
150
+
151
+ # Step 5: Quantize
152
+ if quantize:
153
+ logger.info("Step 5/5: Applying int8 quantization...")
154
+ quant_path = onnx_path.replace(".onnx", "_quantized.onnx")
155
+
156
+ try:
157
+ quantize_dynamic(
158
+ model_input=onnx_path,
159
+ model_output=quant_path,
160
+ per_channel=False,
161
+ reduce_range=False,
162
+ weight_type=QuantType.QInt8
163
+ )
164
+
165
+ if os.path.exists(quant_path):
166
+ quant_size = os.path.getsize(quant_path) / (1024 * 1024)
167
+ logger.info(f"✓ Quantized size: {quant_size:.2f} MB")
168
+ logger.info(f"✓ Size reduction: {(1 - quant_size/size_mb) * 100:.1f}%")
169
+
170
+ # Replace original with quantized to save space
171
+ os.replace(quant_path, onnx_path)
172
+ logger.info("✓ Replaced original with quantized version")
173
+ else:
174
+ logger.warning("⚠ Quantized file not created, using original")
175
+ except Exception as e:
176
+ logger.error(f"⚠ Quantization error: {str(e)}")
177
+ logger.info("⚠ Using original model without quantization")
178
+ else:
179
+ logger.info("Step 5/5: Skipping quantization (not requested)")
180
+
181
+ # Calculate elapsed time
182
+ end_time = time.time()
183
+ duration = end_time - start_time
184
+ logger.info(f"✓ Conversion completed in {duration:.2f} seconds")
185
+
186
+ return {
187
+ "success": True,
188
+ "model_id": model_id,
189
+ "size_mb": os.path.getsize(onnx_path) / (1024 * 1024),
190
+ "duration_seconds": duration,
191
+ "output_dir": model_dir
192
+ }
193
+ else:
194
+ logger.error(f"× ONNX file not created at {onnx_path}")
195
+ return {
196
+ "success": False,
197
+ "model_id": model_id,
198
+ "error": "ONNX file not created"
199
+ }
200
+
201
+ except Exception as e:
202
+ logger.error(f"× Error converting model: {str(e)}")
203
+ logger.error(traceback.format_exc())
204
+
205
+ return {
206
+ "success": False,
207
+ "model_id": model_id,
208
+ "error": str(e)
209
+ }
210
+
211
+ def main():
212
+ """Convert all reliable models."""
213
+ # Print header
214
+ logger.info("\nGUARANTEED ONNX CONVERTER")
215
+ logger.info("======================")
216
+ logger.info("Using reliable models with proven ONNX compatibility")
217
+
218
+ # Create output directory
219
+ output_dir = "./onnx_models"
220
+ os.makedirs(output_dir, exist_ok=True)
221
+
222
+ # Check if specific model ID provided as argument
223
+ if len(sys.argv) > 1:
224
+ model_id = sys.argv[1]
225
+ logger.info(f"Converting single model: {model_id}")
226
+ convert_model(model_id, output_dir)
227
+ return
228
+
229
+ # Convert all reliable models
230
+ results = []
231
+ for model_info in RELIABLE_MODELS:
232
+ model_id = model_info["id"]
233
+ logger.info(f"Processing model: {model_id}")
234
+ logger.info(f"Description: {model_info['description']}")
235
+
236
+ result = convert_model(model_id, output_dir)
237
+ results.append(result)
238
+
239
+ # Print summary
240
+ logger.info("\n" + "=" * 60)
241
+ logger.info("CONVERSION SUMMARY")
242
+ logger.info("=" * 60)
243
+
244
+ success_count = 0
245
+ for result in results:
246
+ if result.get("success", False):
247
+ success_count += 1
248
+ size_info = f" - Size: {result.get('size_mb', 0):.2f} MB"
249
+ time_info = f" - Time: {result.get('duration_seconds', 0):.2f}s"
250
+ logger.info(f"✓ SUCCESS: {result['model_id']}{size_info}{time_info}")
251
+ else:
252
+ logger.info(f"× FAILED: {result['model_id']} - Error: {result.get('error', 'Unknown error')}")
253
+
254
+ logger.info(f"\nSuccessfully converted {success_count}/{len(RELIABLE_MODELS)} models")
255
+ logger.info(f"Models saved to: {os.path.abspath(output_dir)}")
256
+
257
+ if success_count > 0:
258
+ logger.info("\nThe models are ready for RAG and chatbot applications!")
259
+
260
+ if __name__ == "__main__":
261
+ main()
old_scripts/test_chat.py ADDED
@@ -0,0 +1,402 @@
1
+ import os
2
+ import sys
3
+ import time
4
+ import argparse
5
+ import logging
6
+ import numpy as np
7
+ import onnxruntime as ort
8
+ from transformers import AutoTokenizer
9
+ from tqdm import tqdm
10
+
11
+ # Configure logging
12
+ logging.basicConfig(
13
+ level=logging.INFO,
14
+ format='%(asctime)s - %(levelname)s - %(message)s',
15
+ datefmt='%Y-%m-%d %H:%M:%S'
16
+ )
17
+ logger = logging.getLogger(__name__)
18
+
19
+ class ONNXGenerationChatbot:
20
+ def __init__(self, model_path, max_length=100):
21
+ """
22
+ Initialize the ONNX chatbot for text generation.
23
+
24
+ Args:
25
+ model_path: Path to the directory containing the ONNX model and tokenizer
26
+ max_length: Maximum sequence length for generation
27
+ """
28
+ # Set up model paths
29
+ self.model_dir = model_path
30
+ self.onnx_path = os.path.join(self.model_dir, "model.onnx")
31
+ self.fp32_path = os.path.join(self.model_dir, "model_fp32.onnx")
32
+
33
+ # Check for model files
34
+ if not os.path.exists(self.onnx_path):
35
+ raise FileNotFoundError(f"ONNX model not found at {self.onnx_path}")
36
+
37
+ # Get model name for prompt formatting
38
+ self.model_name = os.path.basename(os.path.normpath(model_path))
39
+ logger.info(f"Using model: {self.model_name}")
40
+
41
+ # Load tokenizer
42
+ logger.info(f"Loading tokenizer from {self.model_dir}...")
43
+ self.tokenizer = AutoTokenizer.from_pretrained(self.model_dir, local_files_only=True)
44
+
45
+ # Ensure tokenizer has necessary tokens
46
+ if self.tokenizer.pad_token is None and hasattr(self.tokenizer, 'eos_token'):
47
+ self.tokenizer.pad_token = self.tokenizer.eos_token
48
+
49
+ # Create optimized session
50
+ logger.info(f"Loading ONNX model from {self.onnx_path}...")
51
+ self.session_options = ort.SessionOptions()
52
+ self.session_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
53
+ self.session_options.intra_op_num_threads = 4 # Adjust based on your CPU
54
+
55
+ # Create session with appropriate providers
56
+ providers = ['CPUExecutionProvider']
57
+ if 'CUDAExecutionProvider' in ort.get_available_providers():
58
+ logger.info("CUDA is available! Using GPU acceleration.")
59
+ providers.insert(0, 'CUDAExecutionProvider')
60
+
61
+ self.session = ort.InferenceSession(
62
+ self.onnx_path,
63
+ sess_options=self.session_options,
64
+ providers=providers
65
+ )
66
+
67
+ # Get input and output names from the model
68
+ self.input_names = [input.name for input in self.session.get_inputs()]
69
+ self.output_names = [output.name for output in self.session.get_outputs()]
70
+
71
+ logger.info(f"Model inputs: {self.input_names}")
72
+ logger.info(f"Model outputs: {self.output_names}")
73
+
74
+ # Settings
75
+ self.max_length = max_length
76
+ self.stop_tokens = [self.tokenizer.eos_token_id] if self.tokenizer.eos_token_id is not None else []
77
+
78
+ # Try to add common stop tokens if they exist in the vocabulary
79
+ stop_words = ["<|endoftext|>", "</s>", "<|end|>"]
80
+ for word in stop_words:
81
+ try:
82
+ token_id = self.tokenizer.convert_tokens_to_ids(word)
83
+ if token_id not in self.stop_tokens and token_id != self.tokenizer.unk_token_id:
84
+ self.stop_tokens.append(token_id)
85
+ except Exception:
86
+ pass
87
+
88
+ logger.info(f"Using stop tokens: {self.stop_tokens}")
89
+
90
+ # Conversation history for context
91
+ self.conversation_history = []
92
+
93
+ def get_prompt_template(self):
94
+ """
95
+ Get the appropriate prompt template based on the model type.
96
+ """
97
+ if "opt" in self.model_name.lower():
98
+ return "Human: {}\nAssistant:"
99
+ elif "pythia" in self.model_name.lower():
100
+ return "USER: {}\nASSISTANT:"
101
+ elif "llama" in self.model_name.lower() or "alpaca" in self.model_name.lower():
102
+ return "### Human: {}\n### Assistant:"
103
+ elif "gpt2" in self.model_name.lower() or "distilgpt2" in self.model_name.lower():
104
+ return "User: {}\nBot:"
105
+ else:
106
+ return "Question: {}\nAnswer:"
107
+
108
+ def format_prompt_with_history(self, user_message):
109
+ """
110
+ Format the prompt with conversation history for better context.
111
+ """
112
+ template = self.get_prompt_template()
113
+ parts = template.split("{}")
114
+ prefix = parts[0]
115
+ suffix = parts[1] if len(parts) > 1 else ""
116
+
117
+ # Include history if available (up to 3 turns)
118
+ formatted_prompt = ""
119
+ for i, (user, bot) in enumerate(self.conversation_history[-3:]):
120
+ formatted_prompt += f"{prefix}{user}{suffix} {bot}\n\n"
121
+
122
+ # Add current user message
123
+ formatted_prompt += f"{prefix}{user_message}{suffix}"
124
+
125
+ return formatted_prompt
126
+
127
+ def run_inference_step(self, input_ids, attention_mask=None):
128
+ """
129
+ Run a single inference step with the ONNX model.
130
+
131
+ Args:
132
+ input_ids: Token IDs of the input sequence
133
+ attention_mask: Attention mask for the input sequence
134
+
135
+ Returns:
136
+ numpy array: Logits for the next token prediction
137
+ """
138
+ # Prepare model inputs
139
+ model_inputs = {}
140
+ for name in self.input_names:
141
+ if name == "input_ids":
142
+ model_inputs[name] = input_ids
143
+ elif name == "attention_mask" and attention_mask is not None:
144
+ model_inputs[name] = attention_mask
145
+
146
+ # Run inference
147
+ outputs = self.session.run(self.output_names, model_inputs)
148
+
149
+ # Return logits (assumes first output is logits)
150
+ return outputs[0]
151
+
152
+ def generate_text(self, prompt, max_new_tokens=50, temperature=0.7, top_k=50, top_p=0.9,
153
+ repetition_penalty=1.1, do_sample=True, show_progress=True):
154
+ """
155
+ Generate text using the ONNX model.
156
+
157
+ Args:
158
+ prompt: Text prompt to generate from
159
+ max_new_tokens: Maximum number of tokens to generate
160
+ temperature: Temperature for sampling (higher = more random)
161
+ top_k: Number of highest probability tokens to keep for sampling
162
+ top_p: Cumulative probability threshold for nucleus sampling
163
+ repetition_penalty: Penalty for repeating tokens
164
+ do_sample: Whether to sample from the distribution or use greedy decoding
165
+ show_progress: Whether to show a progress bar during generation
166
+
167
+ Returns:
168
+ str: Generated text
169
+ """
170
+ # Encode the prompt
171
+ encoded = self.tokenizer(prompt, return_tensors="np")
172
+ input_ids = encoded["input_ids"]
173
+ attention_mask = encoded["attention_mask"]
174
+
175
+ # Track input tokens for repetition penalty
176
+ prev_tokens = input_ids[0].tolist()
177
+
178
+ # Setup progress bar if requested
179
+ progress = tqdm(total=max_new_tokens, desc="Generating") if show_progress else None
180
+
181
+ # Generate tokens auto-regressively
182
+ for _ in range(max_new_tokens):
183
+ # Run inference to get next token logits
184
+ logits = self.run_inference_step(input_ids, attention_mask)
185
+
186
+ # Get logits for the last token
187
+ next_token_logits = logits[0, -1, :]
188
+
189
+ # Apply temperature scaling
190
+ if temperature > 0:
191
+ next_token_logits = next_token_logits / max(temperature, 1e-8)
192
+
193
+ # Apply repetition penalty
194
+ if repetition_penalty > 1.0:
195
+ for prev_token in set(prev_tokens[-10:]): # Only consider recent tokens
196
+ if prev_token < len(next_token_logits):
197
+ next_token_logits[prev_token] /= repetition_penalty
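+ # Note: this divides every repeated token's logit uniformly; the CTRL-style penalty
+ # divides positive logits and multiplies negative ones, since dividing a negative
+ # logit by a value > 1 actually makes that token more likely.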
198
+
199
+ # Apply top-k filtering
200
+ if top_k > 0:
201
+ indices_to_remove = np.argsort(next_token_logits)[:-top_k]
202
+ next_token_logits[indices_to_remove] = -float('inf')
203
+
204
+ # Apply top-p (nucleus) filtering
205
+ if 0 < top_p < 1.0:
206
+ sorted_logits = np.sort(next_token_logits)[::-1]
207
+ sorted_indices = np.argsort(next_token_logits)[::-1]
208
+ cumulative_probs = np.cumsum(np.exp(sorted_logits) / np.sum(np.exp(sorted_logits)))
209
+
210
+ # Remove tokens with cumulative probability above the threshold
211
+ sorted_indices_to_remove = sorted_indices[cumulative_probs > top_p]
212
+ next_token_logits[sorted_indices_to_remove] = -float('inf')
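+ # Note: the usual nucleus-sampling implementation shifts this mask one position to the
+ # right so the single most likely token is always kept; without the shift, a very peaked
+ # distribution can mask every candidate and fall through to the NaN check below.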
213
+
214
+ # Sample from the filtered distribution or use greedy decoding
215
+ if do_sample:
216
+ # Apply softmax to get probabilities
217
+ probs = np.exp(next_token_logits - np.max(next_token_logits))
218
+ probs = probs / np.sum(probs)
219
+
220
+ # Handle NaNs
221
+ if np.isnan(probs).any():
222
+ next_token_id = np.argmax(next_token_logits)
223
+ else:
224
+ try:
225
+ # Sample from the distribution
226
+ next_token_id = np.random.choice(len(probs), p=probs)
227
+ except:
228
+ # Fallback to greedy if sampling fails
229
+ next_token_id = np.argmax(next_token_logits)
230
+ else:
231
+ # Greedy decoding - take highest probability token
232
+ next_token_id = np.argmax(next_token_logits)
233
+
234
+ # Add the chosen token to the input
235
+ next_token = np.array([[next_token_id]])
236
+ input_ids = np.concatenate([input_ids, next_token], axis=1)
237
+
238
+ # Update attention mask
239
+ attention_mask = np.ones((1, input_ids.shape[1]), dtype=np.int64)
240
+
241
+ # Add token to history for repetition penalty
242
+ prev_tokens.append(int(next_token_id))
243
+
244
+ # Update progress bar if active
245
+ if progress is not None:
246
+ progress.update(1)
247
+
248
+ # Check for stop tokens or end of text
249
+ if next_token_id in self.stop_tokens:
250
+ break
251
+
252
+ # Also stop if we exceed max length
253
+ if input_ids.shape[1] >= self.max_length:
254
+ break
255
+
256
+ # Close progress bar if used
257
+ if progress is not None:
258
+ progress.close()
259
+
260
+ # Decode the full sequence
261
+ generated_text = self.tokenizer.decode(input_ids[0], skip_special_tokens=True)
262
+ return generated_text
263
+
264
+ def extract_assistant_response(self, full_text, prompt):
265
+ """
266
+ Extract just the assistant's response from the full generated text.
267
+
268
+ Args:
269
+ full_text: Full generated text including prompt
270
+ prompt: The original prompt
271
+
272
+ Returns:
273
+ str: Just the assistant's response
274
+ """
275
+ # Try to extract based on the prompt format
276
+ template = self.get_prompt_template()
277
+ response_start_marker = template.split("{}")[-1]
278
+
279
+ # If the prompt is in the text, extract everything after it
280
+ if prompt in full_text:
281
+ after_prompt = full_text[len(prompt):]
282
+
283
+ # Handle additional newlines or spaces at the beginning
284
+ return after_prompt.lstrip()
285
+
286
+ # If the response marker is in the text, extract everything after it
287
+ if response_start_marker.strip() in full_text:
288
+ parts = full_text.split(response_start_marker.strip(), 1)
289
+ if len(parts) > 1:
290
+ return parts[1].strip()
291
+
292
+ # Fallback: return everything after the last line of the prompt
293
+ prompt_last_line = prompt.strip().split('\n')[-1]
294
+ if prompt_last_line in full_text:
295
+ parts = full_text.split(prompt_last_line, 1)
296
+ if len(parts) > 1:
297
+ return parts[1].strip()
298
+
299
+ # Last resort: return the whole thing
300
+ return full_text
301
+
302
+ def chat(self, temperature=0.7, max_new_tokens=100):
303
+ """
304
+ Run an interactive chat session with the model.
305
+
306
+ Args:
307
+ temperature: Temperature for text generation
308
+ max_new_tokens: Maximum number of tokens to generate per response
309
+ """
310
+ print("\n===== ONNX Generation Chatbot =====")
311
+ print(f"Model: {self.model_name}")
312
+ print(f"Type 'exit' to end the conversation")
313
+ print(f"Type 'reset' to clear conversation history")
314
+
315
+ while True:
316
+ # Get user input
317
+ user_input = input("\nYou: ")
318
+
319
+ # Check for exit command
320
+ if user_input.lower() in ["exit", "quit", "bye"]:
321
+ print("Goodbye!")
322
+ break
323
+
324
+ # Check for reset command
325
+ if user_input.lower() == "reset":
326
+ self.conversation_history = []
327
+ print("Conversation history cleared.")
328
+ continue
329
+
330
+ # Create prompt with history
331
+ prompt = self.format_prompt_with_history(user_input)
332
+ print("\nGenerating response...")
333
+
334
+ # Generate text
335
+ try:
336
+ start_time = time.time()
337
+ full_text = self.generate_text(
338
+ prompt,
339
+ max_new_tokens=max_new_tokens,
340
+ temperature=temperature,
341
+ show_progress=True
342
+ )
343
+
344
+ # Extract just the assistant's response
345
+ response = self.extract_assistant_response(full_text, prompt)
346
+
347
+ # Clean up any trailing incomplete sentences
348
+ if response and len(response) > 0:
349
+ # Try to end at a sentence boundary if possible
350
+ sentence_end = max(
351
+ response.rfind('.'),
352
+ response.rfind('!'),
353
+ response.rfind('?')
354
+ )
355
+ if sentence_end > len(response) * 0.5: # Only trim if we're not losing too much
356
+ response = response[:sentence_end+1]
357
+
358
+ # Calculate generation time
359
+ gen_time = time.time() - start_time
360
+ gen_speed = max_new_tokens / gen_time if gen_time > 0 else 0
361
+
362
+ # Print the response
363
+ print(f"\nBot: {response}")
364
+ print(f"\n[Generated {len(response)} chars in {gen_time:.2f}s ({gen_speed:.1f} tokens/sec)]")
365
+
366
+ # Add to conversation history
367
+ self.conversation_history.append((user_input, response))
368
+
369
+ # Keep history at a reasonable size
370
+ if len(self.conversation_history) > 10:
371
+ self.conversation_history = self.conversation_history[-10:]
372
+
373
+ except Exception as e:
374
+ logger.error(f"Error generating response: {str(e)}")
375
+ print("\nBot: I encountered an error while generating a response. Let's try again.")
376
+
377
+
378
+ def main():
379
+ """Run the ONNX chatbot with command line arguments."""
380
+ parser = argparse.ArgumentParser(description="Interactive ONNX Chatbot")
381
+ parser.add_argument("--model", type=str, required=True,
382
+ help="Path to the ONNX model directory")
383
+ parser.add_argument("--temperature", type=float, default=0.7,
384
+ help="Temperature for text generation (default: 0.7)")
385
+ parser.add_argument("--max_tokens", type=int, default=100,
386
+ help="Maximum tokens to generate per response (default: 100)")
387
+
388
+ args = parser.parse_args()
389
+
390
+ try:
391
+ # Create and run the chatbot
392
+ chatbot = ONNXGenerationChatbot(args.model)
393
+ chatbot.chat(temperature=args.temperature, max_new_tokens=args.max_tokens)
394
+ except KeyboardInterrupt:
395
+ print("\nExiting chatbot. Goodbye!")
396
+ except Exception as e:
397
+ logger.error(f"Error: {str(e)}")
398
+ sys.exit(1)
399
+
400
+
401
+ if __name__ == "__main__":
402
+ main()
onnx_models/bloom_onnx/config.json ADDED
@@ -0,0 +1,32 @@
1
+ {
2
+ "_attn_implementation_autoset": true,
3
+ "_name_or_path": "bigscience/bloom-560m",
4
+ "apply_residual_connection_post_layernorm": false,
5
+ "architectures": [
6
+ "BloomForCausalLM"
7
+ ],
8
+ "attention_dropout": 0.0,
9
+ "attention_softmax_in_fp32": true,
10
+ "bias_dropout_fusion": true,
11
+ "bos_token_id": 1,
12
+ "eos_token_id": 2,
13
+ "hidden_dropout": 0.0,
14
+ "hidden_size": 1024,
15
+ "initializer_range": 0.02,
16
+ "layer_norm_epsilon": 1e-05,
17
+ "masked_softmax_fusion": true,
18
+ "model_type": "bloom",
19
+ "n_head": 16,
20
+ "n_inner": null,
21
+ "n_layer": 24,
22
+ "offset_alibi": 100,
23
+ "pad_token_id": 3,
24
+ "pretraining_tp": 1,
25
+ "skip_bias_add": true,
26
+ "skip_bias_add_qkv": false,
27
+ "slow_but_exact": false,
28
+ "transformers_version": "4.48.3",
29
+ "unk_token_id": 0,
30
+ "use_cache": true,
31
+ "vocab_size": 250880
32
+ }
onnx_models/bloom_onnx/generation_config.json ADDED
@@ -0,0 +1,7 @@
1
+ {
2
+ "_from_model_config": true,
3
+ "bos_token_id": 1,
4
+ "eos_token_id": 2,
5
+ "pad_token_id": 3,
6
+ "transformers_version": "4.48.3"
7
+ }
onnx_models/bloom_onnx/model.onnx ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:268cdaf473da19cc5cb7f1c0eef597e3719dc88524ebc4e78b51268cfcdb8d28
3
+ size 798372
onnx_models/bloom_onnx/special_tokens_map.json ADDED
@@ -0,0 +1,30 @@
1
+ {
2
+ "bos_token": {
3
+ "content": "<s>",
4
+ "lstrip": false,
5
+ "normalized": false,
6
+ "rstrip": false,
7
+ "single_word": false
8
+ },
9
+ "eos_token": {
10
+ "content": "</s>",
11
+ "lstrip": false,
12
+ "normalized": false,
13
+ "rstrip": false,
14
+ "single_word": false
15
+ },
16
+ "pad_token": {
17
+ "content": "<pad>",
18
+ "lstrip": false,
19
+ "normalized": false,
20
+ "rstrip": false,
21
+ "single_word": false
22
+ },
23
+ "unk_token": {
24
+ "content": "<unk>",
25
+ "lstrip": false,
26
+ "normalized": false,
27
+ "rstrip": false,
28
+ "single_word": false
29
+ }
30
+ }
onnx_models/bloom_onnx/tokenizer.json ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d963066d6adae5034a1dc114c3ac444512de09928cf14ed4562ba94d9a440e66
3
+ size 21763085
onnx_models/bloom_onnx/tokenizer_config.json ADDED
@@ -0,0 +1,48 @@
1
+ {
2
+ "add_prefix_space": false,
3
+ "added_tokens_decoder": {
4
+ "0": {
5
+ "content": "<unk>",
6
+ "lstrip": false,
7
+ "normalized": false,
8
+ "rstrip": false,
9
+ "single_word": false,
10
+ "special": true
11
+ },
12
+ "1": {
13
+ "content": "<s>",
14
+ "lstrip": false,
15
+ "normalized": false,
16
+ "rstrip": false,
17
+ "single_word": false,
18
+ "special": true
19
+ },
20
+ "2": {
21
+ "content": "</s>",
22
+ "lstrip": false,
23
+ "normalized": false,
24
+ "rstrip": false,
25
+ "single_word": false,
26
+ "special": true
27
+ },
28
+ "3": {
29
+ "content": "<pad>",
30
+ "lstrip": false,
31
+ "normalized": false,
32
+ "rstrip": false,
33
+ "single_word": false,
34
+ "special": true
35
+ }
36
+ },
37
+ "bos_token": "<s>",
38
+ "clean_up_tokenization_spaces": false,
39
+ "eos_token": "</s>",
40
+ "extra_special_tokens": {},
41
+ "merges_file": null,
42
+ "model_max_length": 1000000000000000019884624838656,
43
+ "pad_token": "<pad>",
44
+ "padding_side": "left",
45
+ "tokenizer_class": "BloomTokenizer",
46
+ "unk_token": "<unk>",
47
+ "vocab_file": null
48
+ }
onnx_models/bloom_onnx_quantized/config.json ADDED
@@ -0,0 +1,32 @@
1
+ {
2
+ "_attn_implementation_autoset": true,
3
+ "_name_or_path": "onnx_models/bloom_onnx",
4
+ "apply_residual_connection_post_layernorm": false,
5
+ "architectures": [
6
+ "BloomForCausalLM"
7
+ ],
8
+ "attention_dropout": 0.0,
9
+ "attention_softmax_in_fp32": true,
10
+ "bias_dropout_fusion": true,
11
+ "bos_token_id": 1,
12
+ "eos_token_id": 2,
13
+ "hidden_dropout": 0.0,
14
+ "hidden_size": 1024,
15
+ "initializer_range": 0.02,
16
+ "layer_norm_epsilon": 1e-05,
17
+ "masked_softmax_fusion": true,
18
+ "model_type": "bloom",
19
+ "n_head": 16,
20
+ "n_inner": null,
21
+ "n_layer": 24,
22
+ "offset_alibi": 100,
23
+ "pad_token_id": 3,
24
+ "pretraining_tp": 1,
25
+ "skip_bias_add": true,
26
+ "skip_bias_add_qkv": false,
27
+ "slow_but_exact": false,
28
+ "transformers_version": "4.48.3",
29
+ "unk_token_id": 0,
30
+ "use_cache": true,
31
+ "vocab_size": 250880
32
+ }
onnx_models/bloom_onnx_quantized/model_quantized.onnx ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:179e57ab6bb5a39b3feef242d2d569aa321f8f1461ec5247c1bd980444b07419
3
+ size 561463713
onnx_models/bloom_onnx_quantized/ort_config.json ADDED
@@ -0,0 +1,33 @@
1
+ {
2
+ "one_external_file": true,
3
+ "opset": null,
4
+ "optimization": {},
5
+ "quantization": {
6
+ "activations_dtype": "QUInt8",
7
+ "activations_symmetric": false,
8
+ "format": "QOperator",
9
+ "is_static": false,
10
+ "mode": "IntegerOps",
11
+ "nodes_to_exclude": [],
12
+ "nodes_to_quantize": [],
13
+ "operators_to_quantize": [
14
+ "Conv",
15
+ "MatMul",
16
+ "Attention",
17
+ "LSTM",
18
+ "Gather",
19
+ "Transpose",
20
+ "EmbedLayerNormalization"
21
+ ],
22
+ "per_channel": false,
23
+ "qdq_add_pair_to_weight": false,
24
+ "qdq_dedicated_pair": false,
25
+ "qdq_op_type_per_channel_support_to_axis": {
26
+ "MatMul": 1
27
+ },
28
+ "reduce_range": false,
29
+ "weights_dtype": "QInt8",
30
+ "weights_symmetric": true
31
+ },
32
+ "use_external_data_format": false
33
+ }
onnx_models/bloom_onnx_quantized/special_tokens_map.json ADDED
@@ -0,0 +1,30 @@
1
+ {
2
+ "bos_token": {
3
+ "content": "<s>",
4
+ "lstrip": false,
5
+ "normalized": false,
6
+ "rstrip": false,
7
+ "single_word": false
8
+ },
9
+ "eos_token": {
10
+ "content": "</s>",
11
+ "lstrip": false,
12
+ "normalized": false,
13
+ "rstrip": false,
14
+ "single_word": false
15
+ },
16
+ "pad_token": {
17
+ "content": "<pad>",
18
+ "lstrip": false,
19
+ "normalized": false,
20
+ "rstrip": false,
21
+ "single_word": false
22
+ },
23
+ "unk_token": {
24
+ "content": "<unk>",
25
+ "lstrip": false,
26
+ "normalized": false,
27
+ "rstrip": false,
28
+ "single_word": false
29
+ }
30
+ }
onnx_models/bloom_onnx_quantized/tokenizer.json ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d963066d6adae5034a1dc114c3ac444512de09928cf14ed4562ba94d9a440e66
3
+ size 21763085
onnx_models/bloom_onnx_quantized/tokenizer_config.json ADDED
@@ -0,0 +1,48 @@
1
+ {
2
+ "add_prefix_space": false,
3
+ "added_tokens_decoder": {
4
+ "0": {
5
+ "content": "<unk>",
6
+ "lstrip": false,
7
+ "normalized": false,
8
+ "rstrip": false,
9
+ "single_word": false,
10
+ "special": true
11
+ },
12
+ "1": {
13
+ "content": "<s>",
14
+ "lstrip": false,
15
+ "normalized": false,
16
+ "rstrip": false,
17
+ "single_word": false,
18
+ "special": true
19
+ },
20
+ "2": {
21
+ "content": "</s>",
22
+ "lstrip": false,
23
+ "normalized": false,
24
+ "rstrip": false,
25
+ "single_word": false,
26
+ "special": true
27
+ },
28
+ "3": {
29
+ "content": "<pad>",
30
+ "lstrip": false,
31
+ "normalized": false,
32
+ "rstrip": false,
33
+ "single_word": false,
34
+ "special": true
35
+ }
36
+ },
37
+ "bos_token": "<s>",
38
+ "clean_up_tokenization_spaces": false,
39
+ "eos_token": "</s>",
40
+ "extra_special_tokens": {},
41
+ "merges_file": null,
42
+ "model_max_length": 1000000000000000019884624838656,
43
+ "pad_token": "<pad>",
44
+ "padding_side": "left",
45
+ "tokenizer_class": "BloomTokenizer",
46
+ "unk_token": "<unk>",
47
+ "vocab_file": null
48
+ }
onnx_models/falcon_onnx/config.json ADDED
@@ -0,0 +1,41 @@
+ {
+ "_attn_implementation_autoset": true,
+ "_name_or_path": "tiiuae/falcon-rw-1b",
+ "activation": "gelu",
+ "alibi": true,
+ "apply_residual_connection_post_layernorm": false,
+ "architectures": [
+ "FalconForCausalLM"
+ ],
+ "attention_dropout": 0.0,
+ "auto_map": {
+ "AutoConfig": "tiiuae/falcon-rw-1b--configuration_falcon.FalconConfig",
+ "AutoModel": "tiiuae/falcon-rw-1b--modeling_falcon.FalconModel",
+ "AutoModelForCausalLM": "tiiuae/falcon-rw-1b--modeling_falcon.FalconForCausalLM",
+ "AutoModelForQuestionAnswering": "tiiuae/falcon-rw-1b--modeling_falcon.FalconForQuestionAnswering",
+ "AutoModelForSequenceClassification": "tiiuae/falcon-rw-1b--modeling_falcon.FalconForSequenceClassification",
+ "AutoModelForTokenClassification": "tiiuae/falcon-rw-1b--modeling_falcon.FalconForTokenClassification"
+ },
+ "bias": true,
+ "bos_token_id": 1,
+ "eos_token_id": 2,
+ "ffn_hidden_size": 8192,
+ "hidden_dropout": 0.0,
+ "hidden_size": 2048,
+ "initializer_range": 0.02,
+ "layer_norm_epsilon": 1e-05,
+ "max_position_embeddings": 2048,
+ "model_type": "falcon",
+ "multi_query": false,
+ "new_decoder_architecture": false,
+ "num_attention_heads": 32,
+ "num_hidden_layers": 24,
+ "num_kv_heads": 32,
+ "num_ln_in_parallel_attn": null,
+ "parallel_attn": false,
+ "rope_scaling": null,
+ "rope_theta": 10000.0,
+ "transformers_version": "4.48.3",
+ "use_cache": true,
+ "vocab_size": 50304
+ }
onnx_models/falcon_onnx/generation_config.json ADDED
@@ -0,0 +1,6 @@
+ {
+ "_from_model_config": true,
+ "bos_token_id": 1,
+ "eos_token_id": 2,
+ "transformers_version": "4.48.3"
+ }
onnx_models/falcon_onnx/merges.txt ADDED
The diff for this file is too large to render. See raw diff
 
onnx_models/falcon_onnx/model.onnx ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:c26d8b62a099f87043745987be680556c9a7c0324944af78d58b7f8559b73c17
+ size 655121
onnx_models/falcon_onnx/special_tokens_map.json ADDED
@@ -0,0 +1,23 @@
+ {
+ "bos_token": {
+ "content": "<|endoftext|>",
+ "lstrip": false,
+ "normalized": false,
+ "rstrip": false,
+ "single_word": false
+ },
+ "eos_token": {
+ "content": "<|endoftext|>",
+ "lstrip": false,
+ "normalized": false,
+ "rstrip": false,
+ "single_word": false
+ },
+ "unk_token": {
+ "content": "<|endoftext|>",
+ "lstrip": false,
+ "normalized": false,
+ "rstrip": false,
+ "single_word": false
+ }
+ }
onnx_models/falcon_onnx/tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
 
onnx_models/falcon_onnx/tokenizer_config.json ADDED
@@ -0,0 +1,20 @@
+ {
+ "add_prefix_space": false,
+ "added_tokens_decoder": {
+ "50256": {
+ "content": "<|endoftext|>",
+ "lstrip": false,
+ "normalized": false,
+ "rstrip": false,
+ "single_word": false,
+ "special": true
+ }
+ },
+ "bos_token": "<|endoftext|>",
+ "clean_up_tokenization_spaces": true,
+ "eos_token": "<|endoftext|>",
+ "extra_special_tokens": {},
+ "model_max_length": 1024,
+ "tokenizer_class": "GPT2Tokenizer",
+ "unk_token": "<|endoftext|>"
+ }
onnx_models/falcon_onnx/vocab.json ADDED
The diff for this file is too large to render. See raw diff
 
onnx_models/gpt2_onnx/config.json ADDED
@@ -0,0 +1,41 @@
+ {
+ "_attn_implementation_autoset": true,
+ "_name_or_path": "gpt2-medium",
+ "activation_function": "gelu_new",
+ "architectures": [
+ "GPT2LMHeadModel"
+ ],
+ "attn_pdrop": 0.1,
+ "bos_token_id": 50256,
+ "embd_pdrop": 0.1,
+ "eos_token_id": 50256,
+ "initializer_range": 0.02,
+ "layer_norm_epsilon": 1e-05,
+ "model_type": "gpt2",
+ "n_ctx": 1024,
+ "n_embd": 1024,
+ "n_head": 16,
+ "n_inner": null,
+ "n_layer": 24,
+ "n_positions": 1024,
+ "n_special": 0,
+ "predict_special_tokens": true,
+ "reorder_and_upcast_attn": false,
+ "resid_pdrop": 0.1,
+ "scale_attn_by_inverse_layer_idx": false,
+ "scale_attn_weights": true,
+ "summary_activation": null,
+ "summary_first_dropout": 0.1,
+ "summary_proj_to_labels": true,
+ "summary_type": "cls_index",
+ "summary_use_proj": true,
+ "task_specific_params": {
+ "text-generation": {
+ "do_sample": true,
+ "max_length": 50
+ }
+ },
+ "transformers_version": "4.48.3",
+ "use_cache": true,
+ "vocab_size": 50257
+ }
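Note: the config above records that this folder was exported from the gpt2-medium checkpoint (24 layers, 1024-dim embeddings). The sketch below shows one plausible way the onnx_models/gpt2_onnx folder could have been produced with Hugging Face Optimum; the export call and options are assumptions for illustration, not the authors' verified command.

```python
# Hypothetical export sketch for an ONNX folder like onnx_models/gpt2_onnx.
from optimum.onnxruntime import ORTModelForCausalLM
from transformers import AutoTokenizer

# Export the PyTorch checkpoint to ONNX on the fly.
model = ORTModelForCausalLM.from_pretrained("gpt2-medium", export=True)
tokenizer = AutoTokenizer.from_pretrained("gpt2-medium")

# Writes model.onnx, config.json, generation_config.json ...
model.save_pretrained("onnx_models/gpt2_onnx")
# ... and the tokenizer files (vocab.json, merges.txt, tokenizer.json, etc.) shown below.
tokenizer.save_pretrained("onnx_models/gpt2_onnx")
```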
onnx_models/gpt2_onnx/generation_config.json ADDED
@@ -0,0 +1,6 @@
+ {
+ "_from_model_config": true,
+ "bos_token_id": 50256,
+ "eos_token_id": 50256,
+ "transformers_version": "4.48.3"
+ }
onnx_models/gpt2_onnx/merges.txt ADDED
The diff for this file is too large to render. See raw diff
 
onnx_models/gpt2_onnx/model.onnx ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:d7c266101cd0fb1a383a3006369c88384d5f76081ec1b9a4e76ff4b7bc15ffe6
+ size 1420150742
onnx_models/gpt2_onnx/special_tokens_map.json ADDED
@@ -0,0 +1,5 @@
+ {
+ "bos_token": "<|endoftext|>",
+ "eos_token": "<|endoftext|>",
+ "unk_token": "<|endoftext|>"
+ }
onnx_models/gpt2_onnx/tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
 
onnx_models/gpt2_onnx/tokenizer_config.json ADDED
@@ -0,0 +1,20 @@
+ {
+ "add_prefix_space": false,
+ "added_tokens_decoder": {
+ "50256": {
+ "content": "<|endoftext|>",
+ "lstrip": false,
+ "normalized": true,
+ "rstrip": false,
+ "single_word": false,
+ "special": true
+ }
+ },
+ "bos_token": "<|endoftext|>",
+ "clean_up_tokenization_spaces": false,
+ "eos_token": "<|endoftext|>",
+ "extra_special_tokens": {},
+ "model_max_length": 1024,
+ "tokenizer_class": "GPT2Tokenizer",
+ "unk_token": "<|endoftext|>"
+ }
onnx_models/gpt2_onnx/vocab.json ADDED
The diff for this file is too large to render. See raw diff
 
onnx_models/gpt2_onnx_quantized/config.json ADDED
@@ -0,0 +1,41 @@
+ {
+ "_attn_implementation_autoset": true,
+ "_name_or_path": "onnx_models/gpt2_onnx",
+ "activation_function": "gelu_new",
+ "architectures": [
+ "GPT2LMHeadModel"
+ ],
+ "attn_pdrop": 0.1,
+ "bos_token_id": 50256,
+ "embd_pdrop": 0.1,
+ "eos_token_id": 50256,
+ "initializer_range": 0.02,
+ "layer_norm_epsilon": 1e-05,
+ "model_type": "gpt2",
+ "n_ctx": 1024,
+ "n_embd": 1024,
+ "n_head": 16,
+ "n_inner": null,
+ "n_layer": 24,
+ "n_positions": 1024,
+ "n_special": 0,
+ "predict_special_tokens": true,
+ "reorder_and_upcast_attn": false,
+ "resid_pdrop": 0.1,
+ "scale_attn_by_inverse_layer_idx": false,
+ "scale_attn_weights": true,
+ "summary_activation": null,
+ "summary_first_dropout": 0.1,
+ "summary_proj_to_labels": true,
+ "summary_type": "cls_index",
+ "summary_use_proj": true,
+ "task_specific_params": {
+ "text-generation": {
+ "do_sample": true,
+ "max_length": 50
+ }
+ },
+ "transformers_version": "4.48.3",
+ "use_cache": true,
+ "vocab_size": 50257
+ }
onnx_models/gpt2_onnx_quantized/merges.txt ADDED
The diff for this file is too large to render. See raw diff
 
onnx_models/gpt2_onnx_quantized/model_quantized.onnx ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:b051bd6632d5039e281be589c65e56abc86861340f41d044f49f496afe35aa07
+ size 357201134
onnx_models/gpt2_onnx_quantized/ort_config.json ADDED
@@ -0,0 +1,33 @@
+ {
+ "one_external_file": true,
+ "opset": null,
+ "optimization": {},
+ "quantization": {
+ "activations_dtype": "QUInt8",
+ "activations_symmetric": false,
+ "format": "QOperator",
+ "is_static": false,
+ "mode": "IntegerOps",
+ "nodes_to_exclude": [],
+ "nodes_to_quantize": [],
+ "operators_to_quantize": [
+ "Conv",
+ "MatMul",
+ "Attention",
+ "LSTM",
+ "Gather",
+ "Transpose",
+ "EmbedLayerNormalization"
+ ],
+ "per_channel": false,
+ "qdq_add_pair_to_weight": false,
+ "qdq_dedicated_pair": false,
+ "qdq_op_type_per_channel_support_to_axis": {
+ "MatMul": 1
+ },
+ "reduce_range": false,
+ "weights_dtype": "QInt8",
+ "weights_symmetric": true
+ },
+ "use_external_data_format": false
+ }
onnx_models/gpt2_onnx_quantized/special_tokens_map.json ADDED
@@ -0,0 +1,23 @@
+ {
+ "bos_token": {
+ "content": "<|endoftext|>",
+ "lstrip": false,
+ "normalized": true,
+ "rstrip": false,
+ "single_word": false
+ },
+ "eos_token": {
+ "content": "<|endoftext|>",
+ "lstrip": false,
+ "normalized": true,
+ "rstrip": false,
+ "single_word": false
+ },
+ "unk_token": {
+ "content": "<|endoftext|>",
+ "lstrip": false,
+ "normalized": true,
+ "rstrip": false,
+ "single_word": false
+ }
+ }
onnx_models/gpt2_onnx_quantized/tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
 
onnx_models/gpt2_onnx_quantized/tokenizer_config.json ADDED
@@ -0,0 +1,20 @@
+ {
+ "add_prefix_space": false,
+ "added_tokens_decoder": {
+ "50256": {
+ "content": "<|endoftext|>",
+ "lstrip": false,
+ "normalized": true,
+ "rstrip": false,
+ "single_word": false,
+ "special": true
+ }
+ },
+ "bos_token": "<|endoftext|>",
+ "clean_up_tokenization_spaces": false,
+ "eos_token": "<|endoftext|>",
+ "extra_special_tokens": {},
+ "model_max_length": 1024,
+ "tokenizer_class": "GPT2Tokenizer",
+ "unk_token": "<|endoftext|>"
+ }
onnx_models/gpt2_onnx_quantized/vocab.json ADDED
The diff for this file is too large to render. See raw diff
 
onnx_models/opt_onnx/config.json ADDED
@@ -0,0 +1,31 @@
+ {
+ "_attn_implementation_autoset": true,
+ "_name_or_path": "facebook/opt-350m",
+ "_remove_final_layer_norm": false,
+ "activation_dropout": 0.0,
+ "activation_function": "relu",
+ "architectures": [
+ "OPTForCausalLM"
+ ],
+ "attention_dropout": 0.0,
+ "bos_token_id": 2,
+ "do_layer_norm_before": false,
+ "dropout": 0.1,
+ "enable_bias": true,
+ "eos_token_id": 2,
+ "ffn_dim": 4096,
+ "hidden_size": 1024,
+ "init_std": 0.02,
+ "layer_norm_elementwise_affine": true,
+ "layerdrop": 0.0,
+ "max_position_embeddings": 2048,
+ "model_type": "opt",
+ "num_attention_heads": 16,
+ "num_hidden_layers": 24,
+ "pad_token_id": 1,
+ "prefix": "</s>",
+ "transformers_version": "4.48.3",
+ "use_cache": true,
+ "vocab_size": 50272,
+ "word_embed_proj_dim": 512
+ }
onnx_models/opt_onnx/generation_config.json ADDED
@@ -0,0 +1,7 @@
+ {
+ "_from_model_config": true,
+ "bos_token_id": 2,
+ "eos_token_id": 2,
+ "pad_token_id": 1,
+ "transformers_version": "4.48.3"
+ }
onnx_models/opt_onnx/merges.txt ADDED
The diff for this file is too large to render. See raw diff
 
onnx_models/opt_onnx/model.onnx ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:33dbbbda8ead8a71ec8ad090902faadf3b292e64f4641087efe17996e9b85aa9
+ size 1325122848
onnx_models/opt_onnx/special_tokens_map.json ADDED
@@ -0,0 +1,30 @@
+ {
+ "bos_token": {
+ "content": "</s>",
+ "lstrip": false,
+ "normalized": true,
+ "rstrip": false,
+ "single_word": false
+ },
+ "eos_token": {
+ "content": "</s>",
+ "lstrip": false,
+ "normalized": true,
+ "rstrip": false,
+ "single_word": false
+ },
+ "pad_token": {
+ "content": "<pad>",
+ "lstrip": false,
+ "normalized": true,
+ "rstrip": false,
+ "single_word": false
+ },
+ "unk_token": {
+ "content": "</s>",
+ "lstrip": false,
+ "normalized": true,
+ "rstrip": false,
+ "single_word": false
+ }
+ }
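Note: each exported folder in this commit (full-precision model.onnx or quantized model_quantized.onnx, plus config and tokenizer files) can be loaded for CPU inference with ONNX Runtime. The sketch below is one plausible way to do that via Hugging Face Optimum; the folder path comes from this repo, while the generation settings are illustrative assumptions only.

```python
# Hypothetical usage sketch: load an exported folder and generate text on CPU.
from optimum.onnxruntime import ORTModelForCausalLM
from transformers import AutoTokenizer

model_dir = "onnx_models/opt_onnx"
# For a quantized folder, pass file_name="model_quantized.onnx" instead.
model = ORTModelForCausalLM.from_pretrained(model_dir, file_name="model.onnx")
tokenizer = AutoTokenizer.from_pretrained(model_dir)

inputs = tokenizer("Hello, how are you?", return_tensors="pt")
output_ids = model.generate(**inputs, max_new_tokens=30, do_sample=True, top_p=0.9)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))
```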