raksa-the-wildcats Claude committed on
Commit 21ba99f · 1 Parent(s): ee78b3d

Update Scholar Express with comprehensive Gradio app


- Add main app.py with DOLPHIN and Gemma 3n integration
- Include PDF processing, chat, and voice features
- Clean up demo files and deployment configs
- Update requirements for Hugging Face Spaces

🤖 Generated with [Claude Code](https://claude.ai/code)

Co-Authored-By: Claude <[email protected]>

This view is limited to 50 files because it contains too many changes. See raw diff.
Files changed (50)
  1. .DS_Store +0 -0
  2. gradio_add_voice.py → app.py +13 -23
  3. assets/demo.gif +0 -3
  4. assets/dolphin.png +0 -3
  5. assets/framework.png +0 -3
  6. chat.py +0 -5
  7. demo/.DS_Store +0 -0
  8. demo/element_imgs/.DS_Store +0 -0
  9. demo/element_imgs/block_formula.jpeg +0 -3
  10. demo/element_imgs/line_formula.jpeg +0 -3
  11. demo/element_imgs/markdown/.DS_Store +0 -0
  12. demo/element_imgs/markdown/table_1.md +0 -2
  13. demo/element_imgs/para_1.jpg +0 -3
  14. demo/element_imgs/para_2.jpg +0 -3
  15. demo/element_imgs/para_3.jpeg +0 -3
  16. demo/element_imgs/recognition_json/table_1.json +0 -6
  17. demo/element_imgs/table_1.jpeg +0 -3
  18. demo/element_imgs/table_2.jpeg +0 -3
  19. demo/page_imgs/.DS_Store +0 -0
  20. demo/page_imgs/markdown/.DS_Store +0 -0
  21. demo/page_imgs/markdown/figures/.DS_Store +0 -0
  22. demo/page_imgs/markdown/figures/test_page3_figure_000.png +0 -3
  23. demo/page_imgs/markdown/test_page3.md +0 -22
  24. demo/page_imgs/page_1.jpeg +0 -3
  25. demo/page_imgs/page_2.jpeg +0 -3
  26. demo/page_imgs/page_3.jpeg +0 -3
  27. demo/page_imgs/page_4.png +0 -3
  28. demo/page_imgs/page_5.jpg +0 -3
  29. demo/page_imgs/page_6.pdf +0 -0
  30. demo/page_imgs/page_7.jpeg +0 -3
  31. demo/page_imgs/recognition_json/page_1.json +0 -178
  32. demo/page_imgs/recognition_json/test_page.json +0 -47
  33. demo/page_imgs/recognition_json/test_page2.json +0 -102
  34. demo/page_imgs/recognition_json/test_page3.json +0 -124
  35. demo/page_imgs/test_page2.jpeg +0 -3
  36. demo/page_imgs/test_page3.jpeg +0 -3
  37. demo_element.py +0 -129
  38. demo_element_hf.py +0 -5
  39. demo_page.py +0 -247
  40. demo_page_hf.py +0 -5
  41. deployment/ReadMe.md +0 -12
  42. deployment/tensorrt_llm/ReadMe.md +0 -89
  43. deployment/tensorrt_llm/api_client.py +0 -100
  44. deployment/tensorrt_llm/api_server.py +0 -112
  45. deployment/tensorrt_llm/convert/__init__.py +0 -0
  46. deployment/tensorrt_llm/convert/build_visual_engine.py +0 -14
  47. deployment/tensorrt_llm/convert/convert_checkpoint.py +0 -1528
  48. deployment/tensorrt_llm/convert/helper.py +0 -95
  49. deployment/tensorrt_llm/convert_dolphin.sh +0 -47
  50. deployment/tensorrt_llm/dolphin_runner.py +0 -220
.DS_Store CHANGED
Binary files a/.DS_Store and b/.DS_Store differ
 
gradio_add_voice.py → app.py RENAMED
@@ -1,9 +1,3 @@
- """
- DOLPHIN PDF Document AI - Local Gemma 3n Version
- Optimized for powerful GPU deployment with local models
- Features: AI-generated alt text for accessibility using local Gemma 3n
- """
-
  import gradio as gr
  import json
  import markdown
@@ -51,14 +45,18 @@ try:
  except ImportError:
      pass

- # Warm up voice model if available
- if VOICE_DEPENDENCIES_AVAILABLE and voice_model:
+ # Initialize voice model early to avoid NameError
+ voice_model = None
+ if VOICE_DEPENDENCIES_AVAILABLE:
      try:
+         print("Loading voice model...")
+         voice_model = Gemma3nInference(device='cuda' if torch.cuda.is_available() else 'cpu')
          print("Warming up voice model...")
          voice_model.warm_up()
-         print("✅ Voice model warmed up successfully")
+         print("✅ Voice model loaded and warmed up successfully")
      except Exception as e:
-         print(f"⚠️ Voice model warm-up failed: {e}")
+         print(f"⚠️ Voice model initialization failed: {e}")
+         voice_model = None


  class DOLPHIN:
@@ -520,16 +518,7 @@ OUT_RATE = 24000
  OUT_SAMPLE_WIDTH = 2
  OUT_CHUNK = 20 * 4096

- # Initialize voice inference model if available
- voice_model = None
- if VOICE_DEPENDENCIES_AVAILABLE:
-     try:
-         print("Loading voice model for Talk with Gemma...")
-         voice_model = Gemma3nInference(device='cuda' if torch.cuda.is_available() else 'cpu')
-         print("✅ Voice model loaded successfully")
-     except Exception as e:
-         print(f"❌ Error loading voice model: {e}")
-         VOICE_DEPENDENCIES_AVAILABLE = False
+ # Voice model already initialized earlier in the file

  @dataclass
  class VoiceAppState:
@@ -643,8 +632,8 @@ def generate_voice_response(state: VoiceAppState):
      audio_array = audio_array.reshape((-1, 2))

      # Update conversation history
-     state.conversation.append({"role": "user", "content": {"path": temp_audio_path, "mime_type": "audio/wav"}})
-     state.conversation.append({"role": "assistant", "content": {"text": text_response}})
+     state.conversation.append({"role": "user", "content": f"[Audio message]"})
+     state.conversation.append({"role": "assistant", "content": text_response})

      return (audio_segment.frame_rate, audio_array), VoiceAppState(conversation=state.conversation)

@@ -697,7 +686,7 @@ def create_embeddings(chunks):
  def retrieve_relevant_chunks(question, chunks, embeddings, top_k=3):
      """Retrieve most relevant chunks for a question"""
      if embedding_model is None or embeddings is None:
-         return chunks[:3]  # Fallback to first 3 chunks
+         return chunks[:3]

      try:
          question_embedding = embedding_model.encode([question], show_progress_bar=False)
@@ -892,6 +881,7 @@ with gr.Blocks(
      chatbot = gr.Chatbot(
          value=[],
          height=500,
+         type='messages',
          elem_classes="chatbot-container",
          placeholder="Your conversation will appear here once you process a document..."
      )
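Background on the `type='messages'` and conversation-history changes in the diff above: Gradio's messages-format Chatbot expects a flat list of `{"role": ..., "content": <string>}` entries, which is why the history now appends plain strings instead of nested path/text dicts. A minimal, self-contained sketch of that pattern (illustrative only, not code from app.py; the `respond` handler and the textbox wiring are made up for the example):

```python
import gradio as gr

def respond(message, history):
    # With type="messages", `history` is a list of {"role", "content"} dicts
    # whose "content" values are plain strings.
    history = history + [
        {"role": "user", "content": message},
        {"role": "assistant", "content": f"You said: {message}"},
    ]
    return history, ""  # updated history, cleared textbox

with gr.Blocks() as demo:
    chatbot = gr.Chatbot(value=[], height=500, type="messages")
    box = gr.Textbox(placeholder="Ask a question about the document...")
    box.submit(respond, inputs=[box, chatbot], outputs=[chatbot, box])

if __name__ == "__main__":
    demo.launch()
```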
assets/demo.gif DELETED

Git LFS Details

  • SHA256: 003bcda91af8e23c007d6d1c5e23bee177c5735c7ba914b9ee33670829d59a2c
  • Pointer size: 132 Bytes
  • Size of remote file: 3.23 MB
assets/dolphin.png DELETED

Git LFS Details

  • SHA256: 3f462bb6eaf6cf9ba02caa04966ec354e1352f2cb1ac3e03ead082a0ba725170
  • Pointer size: 130 Bytes
  • Size of remote file: 83.3 kB
assets/framework.png DELETED

Git LFS Details

  • SHA256: f23f47c5ec092369a0707fa6e82ec4dd03ed10044b00ef10aff5f7c89570187e
  • Pointer size: 132 Bytes
  • Size of remote file: 2 MB
chat.py CHANGED
@@ -1,8 +1,3 @@
- """
- Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
- SPDX-License-Identifier: MIT
- """
-
  import os
  import warnings
  from collections import OrderedDict
demo/.DS_Store DELETED
Binary file (6.15 kB)
 
demo/element_imgs/.DS_Store DELETED
Binary file (6.15 kB)
 
demo/element_imgs/block_formula.jpeg DELETED

Git LFS Details

  • SHA256: 5dc9c328d058816ef31d878a0d42f0751606afd3b77854057910a81451dae1b4
  • Pointer size: 130 Bytes
  • Size of remote file: 92.5 kB
demo/element_imgs/line_formula.jpeg DELETED

Git LFS Details

  • SHA256: 65e2be8cc82c609364e1f921cacb822213f0ca2eafd86f5721b6f0499ceb8712
  • Pointer size: 130 Bytes
  • Size of remote file: 55.3 kB
demo/element_imgs/markdown/.DS_Store DELETED
Binary file (6.15 kB)
 
demo/element_imgs/markdown/table_1.md DELETED
@@ -1,2 +0,0 @@
1
- <table><tr><td></td><td></td><td>100-class (top-1 acc.)</td><td>1000-class (top-1 acc.)</td></tr><tr><td colspan="2">4096-d (float)</td><td>77.1 ± 1.5</td><td>65.0</td></tr><tr><td rowspan="3">1024 bits</td><td>BP</td><td>72.9 ± 1.3</td><td>58.1</td></tr><tr><td>CBE</td><td>73.0 ± 1.3</td><td>59.2</td></tr><tr><td>SP</td><td>73.8 ± 1.3</td><td>60.1</td></tr><tr><td rowspan="4">4096 bits</td><td>threshold [1]</td><td>73.5 ± 1.4</td><td>59.1</td></tr><tr><td>BP</td><td>76.0 ± 1.5</td><td>63.2</td></tr><tr><td>CBE</td><td>75.9 ± 1.4</td><td>63.0</td></tr><tr><td>SP</td><td>76.3 ± 1.5</td><td>63.3</td></tr><tr><td>8192 bits</td><td>SP</td><td>76.8 ± 1.4</td><td>64.2</td></tr><tr><td>16384 bits</td><td>SP</td><td>77.1 ± 1.6</td><td>64.5</td></tr></table>
2
-
demo/element_imgs/para_1.jpg DELETED

Git LFS Details

  • SHA256: 68308a404e8e4c111f5cc1568e7f4b74f1f0c08ad4485e2ad9e78869f79a556b
  • Pointer size: 130 Bytes
  • Size of remote file: 18.7 kB
demo/element_imgs/para_2.jpg DELETED

Git LFS Details

  • SHA256: 8d9eda1c71490b76ac5d3ef33f436fb6e6db4ca3b625d5d74f35c3b248949c56
  • Pointer size: 130 Bytes
  • Size of remote file: 69.8 kB
demo/element_imgs/para_3.jpeg DELETED

Git LFS Details

  • SHA256: b372541d80263c5508b8b85ccf847123874efdb4c25473845fbf042f2d9cc5a9
  • Pointer size: 130 Bytes
  • Size of remote file: 84 kB
demo/element_imgs/recognition_json/table_1.json DELETED
@@ -1,6 +0,0 @@
1
- [
2
- {
3
- "label": "tab",
4
- "text": "<table><tr><td></td><td></td><td>100-class (top-1 acc.)</td><td>1000-class (top-1 acc.)</td></tr><tr><td colspan=\"2\">4096-d (float)</td><td>77.1 ± 1.5</td><td>65.0</td></tr><tr><td rowspan=\"3\">1024 bits</td><td>BP</td><td>72.9 ± 1.3</td><td>58.1</td></tr><tr><td>CBE</td><td>73.0 ± 1.3</td><td>59.2</td></tr><tr><td>SP</td><td>73.8 ± 1.3</td><td>60.1</td></tr><tr><td rowspan=\"4\">4096 bits</td><td>threshold [1]</td><td>73.5 ± 1.4</td><td>59.1</td></tr><tr><td>BP</td><td>76.0 ± 1.5</td><td>63.2</td></tr><tr><td>CBE</td><td>75.9 ± 1.4</td><td>63.0</td></tr><tr><td>SP</td><td>76.3 ± 1.5</td><td>63.3</td></tr><tr><td>8192 bits</td><td>SP</td><td>76.8 ± 1.4</td><td>64.2</td></tr><tr><td>16384 bits</td><td>SP</td><td>77.1 ± 1.6</td><td>64.5</td></tr></table>"
5
- }
6
- ]
demo/element_imgs/table_1.jpeg DELETED

Git LFS Details

  • SHA256: 1ccce9dab1a1b537ae502183f461ad3331a2b9eeb8574790e6ec43ca54f24e2c
  • Pointer size: 131 Bytes
  • Size of remote file: 183 kB
demo/element_imgs/table_2.jpeg DELETED

Git LFS Details

  • SHA256: 3fdc67f4bb8afee58ff4ee84412581deb771cf26f5fe9eead742108700e9650e
  • Pointer size: 131 Bytes
  • Size of remote file: 406 kB
demo/page_imgs/.DS_Store DELETED
Binary file (8.2 kB)
 
demo/page_imgs/markdown/.DS_Store DELETED
Binary file (6.15 kB)
 
demo/page_imgs/markdown/figures/.DS_Store DELETED
Binary file (6.15 kB)
 
demo/page_imgs/markdown/figures/test_page3_figure_000.png DELETED

Git LFS Details

  • SHA256: eba97bcb2eefbc653f4b5db7572799a9674b8fd39e5f14d261c33e1916a9f009
  • Pointer size: 130 Bytes
  • Size of remote file: 63.4 kB
demo/page_imgs/markdown/test_page3.md DELETED
@@ -1,22 +0,0 @@
1
- ![Figure](figures/test_page3_figure_000.png)
2
-
3
- Figure 2: (left) Scaled Dot-Product Attention. (right) Multi-Head Attention consists of several attention layers running in parallel.
4
-
5
- query with all keys, divide each by $\sqrt{d_k}$ , and apply a softmax function to obtain the weights on the values.
6
-
7
- In practice, we compute the attention function on a set of queries simultaneously, packed together into a matrix $Q$ . The keys and values are also packed together into matrices $K$ and $V$ . We compute the matrix of outputs as: $$ \\ \text{Attention}(Q, K, V) = \mathrm{softmax}(\frac{QK^T}{\sqrt{d_k}})V \\ $$
8
-
9
- The two most commonly used attention functions are additive attention [2] , and dot-product (multiplicative) attention. Dot-product attention is identical to our algorithm, except for the scaling factor of $\frac{1}{\sqrt{d_k}}$ . Additive attention computes the compatibility function using a feed-forward network with a single hidden layer. While the two are similar in theoretical complexity, dot-product attention is much faster and more space-efficient in practice, since it can be implemented using highly optimized matrix multiplication code.
10
-
11
- While for small values of $d_k$ the two mechanisms perform similarly, additive attention outperforms dot product attention without scaling for larger values of $d_k$ [ 3 ] . We suspect that for large values of $d_k$ , the dot products grow large in magnitude, pushing the softmax function into regions where it has extremely small gradients 4 To counteract this effect, we scale the dot products by $\frac{1}{\sqrt{d_k}}$ .
12
-
13
- 3.2.2 Multi-Head Attention
14
-
15
- Instead of performing a single attention function with $d_{\text{model}}$ -dimensional keys, values and queries, we found it beneficial to linearly project the queries, keys and values $h$ times with different, learned linear projections to $d_k$ , $d_k$ and $d_v$ dimensions, respectively. On each of these projected versions of queries, keys and values we then perform the attention function in parallel, yielding $d_v$ -dimensional output values. These are concatenated and once again projected, resulting in the final values, as depicted in Figure 2 .
16
-
17
- Multi­head attention allows the model to jointly attend to information from different representation subspaces at different positions. With a single attention head, averaging inhibits this.
18
-
19
- ${ }^{4}$ To illustrate why the dot products get large, assume that the components of $q$ and $k$ are independent random variables with mean 0 and variance 1 . Then their dot product, $q \cdot k=\sum_{i=1}^{d_{k}} q_{i} k_{i}$, has mean 0 and variance $d_{k}$.
20
-
21
- 4
22
-
demo/page_imgs/page_1.jpeg DELETED

Git LFS Details

  • SHA256: aba4e06f5debeb14a59a193818f6787aa06f17f4cb21c0e483d8267f5397b627
  • Pointer size: 132 Bytes
  • Size of remote file: 1.52 MB
demo/page_imgs/page_2.jpeg DELETED

Git LFS Details

  • SHA256: 25e08746f10d4472d80659869eb73a477ad665d7aaaa850e70aae1bd6076d826
  • Pointer size: 132 Bytes
  • Size of remote file: 1.47 MB
demo/page_imgs/page_3.jpeg DELETED

Git LFS Details

  • SHA256: fe6e35a3c888c77ec36cf48cb762556e489e288d30a457a353ac6bba6fab9251
  • Pointer size: 131 Bytes
  • Size of remote file: 449 kB
demo/page_imgs/page_4.png DELETED

Git LFS Details

  • SHA256: 497cdabe38a4db8318284c0f8963304a876ceceebb796059903703834e4713ed
  • Pointer size: 131 Bytes
  • Size of remote file: 372 kB
demo/page_imgs/page_5.jpg DELETED

Git LFS Details

  • SHA256: 17cdc261fcd7eb8db4a0bdfb56dc2b1f77c8890956f8451f810695e115f6f894
  • Pointer size: 131 Bytes
  • Size of remote file: 641 kB
demo/page_imgs/page_6.pdf DELETED
The diff for this file is too large to render. See raw diff
 
demo/page_imgs/page_7.jpeg DELETED

Git LFS Details

  • SHA256: 19bb9afdb859e905e017fc3d3bac6da0490093811820529f285a20e8d70609f2
  • Pointer size: 132 Bytes
  • Size of remote file: 1.27 MB
demo/page_imgs/recognition_json/page_1.json DELETED
@@ -1,178 +0,0 @@
1
- [
2
- {
3
- "label": "title",
4
- "bbox": [
5
- 271,
6
- 188,
7
- 1194,
8
- 221
9
- ],
10
- "text": "LLaMA: Open and Efficient Foundation Language Models",
11
- "reading_order": 0
12
- },
13
- {
14
- "label": "author",
15
- "bbox": [
16
- 313,
17
- 272,
18
- 1154,
19
- 317
20
- ],
21
- "text": "Hugo Touvron; Thibaut Lavril*, Gautier Izacard*, Xavier Martinet",
22
- "reading_order": 1
23
- },
24
- {
25
- "label": "para",
26
- "bbox": [
27
- 269,
28
- 317,
29
- 1201,
30
- 425
31
- ],
32
- "text": "Marie-Anne Lachaux, Timothee Lacroix, Baptiste Rozière, Naman Goyal\nEric Hambro, Faisal Azhar, Aurelien Rodriguez, Armand Joulin\nEdouard Grave*Guillaume Lample*",
33
- "reading_order": 2
34
- },
35
- {
36
- "label": "para",
37
- "bbox": [
38
- 685,
39
- 440,
40
- 795,
41
- 482
42
- ],
43
- "text": "Meta AI",
44
- "reading_order": 3
45
- },
46
- {
47
- "label": "sec",
48
- "bbox": [
49
- 376,
50
- 524,
51
- 502,
52
- 565
53
- ],
54
- "text": "\\begin{abstract}",
55
- "reading_order": 4
56
- },
57
- {
58
- "label": "para",
59
- "bbox": [
60
- 209,
61
- 586,
62
- 675,
63
- 946
64
- ],
65
- "text": "We introduce LLaMA, a collection of founda-\ntion language models ranging from 7B to 65B\nparameters. We train our models on trillions\nof tokens, and show that it is possible to train\nstate-of-the-art models using publicly avail-\nable datasets exclusively, without resorting\nto proprietary and inaccessible datasets. In\nparticular, LLaMA-13B outperforms GPT-3\n(175B) on most benchmarks, and LLaMA-\n65B is competitive with the best models,\nChinchilla-70B and PaLM-540B. We release\nall our models to the research community $^1$ .",
66
- "reading_order": 5
67
- },
68
- {
69
- "label": "sec",
70
- "bbox": [
71
- 167,
72
- 964,
73
- 376,
74
- 1006
75
- ],
76
- "text": "1 Introduction",
77
- "reading_order": 6
78
- },
79
- {
80
- "label": "para",
81
- "bbox": [
82
- 167,
83
- 1027,
84
- 718,
85
- 1498
86
- ],
87
- "text": "Large Languages Models (LLMs) trained on mas-\nsive corpora of texts have shown their ability to per-\nform new tasks from textual instructions or from a\nfew examples ( Brown et al. , 2020 ) . These few-shot\nproperties first appeared when scaling models to a\nsufficient size ( Kaplan et al. , 2020 ) , resulting in a\nline of work that focuses on further scaling these\nmodels ( Chowdhery et al. , 2022 ; Rae et al. , 2021 ) .\nThese efforts are based on the assumption that\nmore parameters will lead to better performance.\nHowever, recent work from Hoffmann et al. ( 2022 )\nshows that, for a given compute budget, the best\nperformances are not achieved by the largest mod-\nels, but by smaller models trained on more data.",
88
- "reading_order": 7
89
- },
90
- {
91
- "label": "para",
92
- "bbox": [
93
- 167,
94
- 1506,
95
- 717,
96
- 1844
97
- ],
98
- "text": "The objective of the scaling laws from Hoff-\nmann et al. ( 2022 ) is to determine how to best\nscale the dataset and model sizes for a particular\ntraining compute budget. However, this objective\ndisregards the inference budget, which becomes\ncritical when serving a language model at scale.\nIn this context, given a target level of performance,\nthe preferred model is not the fastest to train but the\nfastest at inference, and although it may be cheaper\nto train a large model to reach a certain level of",
99
- "reading_order": 8
100
- },
101
- {
102
- "label": "para",
103
- "bbox": [
104
- 753,
105
- 539,
106
- 1304,
107
- 734
108
- ],
109
- "text": "performance, a smaller one trained longer will\nultimately be cheaper at inference. For instance,\nalthough Hoffmann et al. ( 2022 ) recommends\ntraining a 10B model on 200B tokens, we find\nthat the performance of a 7B model continues to\nimprove even after 1T tokens.",
110
- "reading_order": 9
111
- },
112
- {
113
- "label": "para",
114
- "bbox": [
115
- 753,
116
- 769,
117
- 1305,
118
- 1236
119
- ],
120
- "text": "The focus of this work is to train a series of\nlanguage models that achieve the best possible per-\nformance at various inference budgets, by training\non more tokens than what is typically used. The\nresulting models, called LLaMA , ranges from 7B\nto 65B parameters with competitive performance\ncompared to the best existing LLMs. For instance,\nLLaMA-13B outperforms GPT-3 on most bench-\nmarks, despite being 10 $\\times$ smaller. We believe that\nthis model will help democratize the access and\nstudy of LLMs, since it can be run on a single GPU.\nAt the higher-end of the scale, our 65B-parameter\nmodel is also competitive with the best large lan-\nguage models such as Chinchilla or PaLM-540B.",
121
- "reading_order": 10
122
- },
123
- {
124
- "label": "para",
125
- "bbox": [
126
- 753,
127
- 1257,
128
- 1305,
129
- 1601
130
- ],
131
- "text": "Unlike Chinchilla, PaLM, or GPT-3, we only\nuse publicly available data, making our work com-\npatible with open-sourcing, while most existing\nmodels rely on data which is either not publicly\navailable or undocumented (e.g. “ Books – 2TB ” or\n“ Social media conversations ” ). There exist some\nexceptions, notably OPT ( Zhang et al. , 2022 ) ,\nGPT-NeoX ( Black et al. , 2022 ) , BLOOM ( Scao\net al. , 2022 ) and GLM ( Zeng et al. , 2022 ) , but none\nthat are competitive with PaLM-62B or Chinchilla.",
132
- "reading_order": 11
133
- },
134
- {
135
- "label": "para",
136
- "bbox": [
137
- 753,
138
- 1634,
139
- 1304,
140
- 1933
141
- ],
142
- "text": "In the rest of this paper, we present an overview\nof the modifications we made to the transformer\narchitecture ( Vaswani et al. , 2017 ) , as well as our\ntraining method. We then report the performance of\nour models and compare with others LLMs on a set\nof standard benchmarks. Finally, we expose some\nof the biases and toxicity encoded in our models,\nusing some of the most recent benchmarks from\nthe responsible AI community.",
143
- "reading_order": 12
144
- },
145
- {
146
- "label": "fnote",
147
- "bbox": [
148
- 167,
149
- 1844,
150
- 712,
151
- 1907
152
- ],
153
- "text": "* Equal contribution.\nCorrespondence:\n{htouvron\nthibautlav,gizacard,egrave,glample}@meta.com",
154
- "reading_order": 13
155
- },
156
- {
157
- "label": "fnote",
158
- "bbox": [
159
- 209,
160
- 1907,
161
- 632,
162
- 1931
163
- ],
164
- "text": "https://github.com/facebookresearch/llama",
165
- "reading_order": 14
166
- },
167
- {
168
- "label": "watermark",
169
- "bbox": [
170
- 20,
171
- 649,
172
- 83,
173
- 1530
174
- ],
175
- "text": "arXiv:2302.1397lvl [cs.CL] 27 Feb 2023",
176
- "reading_order": 15
177
- }
178
- ]
demo/page_imgs/recognition_json/test_page.json DELETED
@@ -1,47 +0,0 @@
1
- [
2
- {
3
- "label": "header",
4
- "bbox": [
5
- 291,
6
- 90,
7
- 675,
8
- 120
9
- ],
10
- "text": "Scaled Dot-Product Attention",
11
- "reading_order": 0
12
- },
13
- {
14
- "label": "fig",
15
- "text": "![Figure](figures/test_page_figure_001.png)",
16
- "figure_path": "figures/test_page_figure_001.png",
17
- "bbox": [
18
- 1274,
19
- 105,
20
- 1536,
21
- 627
22
- ],
23
- "reading_order": 1
24
- },
25
- {
26
- "label": "cap",
27
- "bbox": [
28
- 168,
29
- 719,
30
- 1413,
31
- 789
32
- ],
33
- "text": "Figure 2: (left) Scaled Dot-Product Attention. (right) Multi-Head Attention consists of several\nattention layers running in parallel.",
34
- "reading_order": 2
35
- },
36
- {
37
- "label": "para",
38
- "bbox": [
39
- 168,
40
- 858,
41
- 1413,
42
- 934
43
- ],
44
- "text": "query with all keys, divide each by $\\sqrt{d_{k}}$, and apply a softmax function to obtain the weights on the\nvalues.",
45
- "reading_order": 3
46
- }
47
- ]
demo/page_imgs/recognition_json/test_page2.json DELETED
@@ -1,102 +0,0 @@
1
- [
2
- {
3
- "label": "fig",
4
- "text": "![Figure](figures/test_page2_figure_000.png)",
5
- "figure_path": "figures/test_page2_figure_000.png",
6
- "bbox": [
7
- 394,
8
- 117,
9
- 897,
10
- 837
11
- ],
12
- "reading_order": 0
13
- },
14
- {
15
- "label": "cap",
16
- "bbox": [
17
- 445,
18
- 852,
19
- 856,
20
- 873
21
- ],
22
- "text": "Figure 1: The Transformer - model architecture",
23
- "reading_order": 1
24
- },
25
- {
26
- "label": "para",
27
- "bbox": [
28
- 218,
29
- 920,
30
- 1086,
31
- 1044
32
- ],
33
- "text": "wise fully connected feed-forward network. We employ a residual connection [ 10 ] around each of\nthe two sub-layers, followed by layer normalization [ 1 ] . That is, the output of each sub-layer is\n$\\mathrm{LayerNorm}(x+\\mathrm{Sublayer}(x))$ , where $\\mathrm{Sublayer}(x)$ is the function implemented by the sub-layer\nitself. To facilitate these residual connections, all sub-layers in the model, as well as the embedding\nlayers, produce outputs of dimension $d_{\\text{model}}=512$ .",
34
- "reading_order": 2
35
- },
36
- {
37
- "label": "para",
38
- "bbox": [
39
- 218,
40
- 1071,
41
- 1085,
42
- 1244
43
- ],
44
- "text": "The The decoder is also composed of a stack of $N=6$ identical layers. In addition to the two\nsub-layers in each encoder layer, the decoder inserts a third sub-layer, which performs multi-head\nattention over the output of the encoder stack. Similar to the encoder, we employ residual connections\naround each of the sub-layers, followed by layer normalization. We also modify the self-attention\nsub-layer in the decoder stack to prevent positions from attending to subsequent positions. This\nmasking, combined with fact that the output embeddings are offset by one position, ensures that the\npredictions for position $i$ can depend only on the known outputs at positions less than $i$ .",
45
- "reading_order": 3
46
- },
47
- {
48
- "label": "sub_sec",
49
- "bbox": [
50
- 226,
51
- 1283,
52
- 344,
53
- 1305
54
- ],
55
- "text": "3.2 Attention",
56
- "reading_order": 4
57
- },
58
- {
59
- "label": "para",
60
- "bbox": [
61
- 218,
62
- 1322,
63
- 1087,
64
- 1422
65
- ],
66
- "text": "An attention function can be described as mapping a query and a set of key-value pairs to an output,\nwhere the query, keys, values, and output are all vectors. The output is computed as a weighted sum\nof the values, where the weight assigned to each value is computed by a compatibility function of the\nquery with the corresponding key.",
67
- "reading_order": 5
68
- },
69
- {
70
- "label": "sub_sub_sec",
71
- "bbox": [
72
- 218,
73
- 1456,
74
- 562,
75
- 1474
76
- ],
77
- "text": "3.2.1 Scaled Dot-Product Attention",
78
- "reading_order": 6
79
- },
80
- {
81
- "label": "para",
82
- "bbox": [
83
- 218,
84
- 1498,
85
- 1085,
86
- 1546
87
- ],
88
- "text": "We call our particular attention \"Scaled Dot-Product Attention\" (Figure 2 ). The input consists of\nqueries and keys of dimension $d_k$ , and values of dimension $d_v$ . We compute the dot products of the",
89
- "reading_order": 7
90
- },
91
- {
92
- "label": "foot",
93
- "bbox": [
94
- 646,
95
- 1590,
96
- 662,
97
- 1607
98
- ],
99
- "text": "3",
100
- "reading_order": 8
101
- }
102
- ]
demo/page_imgs/recognition_json/test_page3.json DELETED
@@ -1,124 +0,0 @@
1
- [
2
- {
3
- "label": "fig",
4
- "text": "![Figure](figures/test_page3_figure_000.png)",
5
- "figure_path": "figures/test_page3_figure_000.png",
6
- "bbox": [
7
- 331,
8
- 134,
9
- 984,
10
- 489
11
- ],
12
- "reading_order": 0
13
- },
14
- {
15
- "label": "cap",
16
- "bbox": [
17
- 198,
18
- 554,
19
- 1065,
20
- 603
21
- ],
22
- "text": "Figure 2: (left) Scaled Dot-Product Attention. (right) Multi-Head Attention consists of several\nattention layers running in parallel.",
23
- "reading_order": 1
24
- },
25
- {
26
- "label": "para",
27
- "bbox": [
28
- 198,
29
- 652,
30
- 1065,
31
- 701
32
- ],
33
- "text": "query with all keys, divide each by $\\sqrt{d_k}$ , and apply a softmax function to obtain the weights on the\nvalues.",
34
- "reading_order": 2
35
- },
36
- {
37
- "label": "para",
38
- "bbox": [
39
- 198,
40
- 715,
41
- 1065,
42
- 881
43
- ],
44
- "text": "In practice, we compute the attention function on a set of queries simultaneously, packed together\ninto a matrix $Q$ . The keys and values are also packed together into matrices $K$ and $V$ . We compute\nthe matrix of outputs as:\n\\[\n \\text{Attention}(Q, K, V) = \\mathrm{softmax}(\\frac{QK^T}{\\sqrt{d_k}})V\n\\]",
45
- "reading_order": 3
46
- },
47
- {
48
- "label": "para",
49
- "bbox": [
50
- 198,
51
- 913,
52
- 1068,
53
- 1060
54
- ],
55
- "text": "The two most commonly used attention functions are additive attention [2] , and dot-product (multi-\nplicative) attention. Dot-product attention is identical to our algorithm, except for the scaling factor\nof $\\frac{1}{\\sqrt{d_k}}$ . Additive attention computes the compatibility function using a feed-forward network with\na single hidden layer. While the two are similar in theoretical complexity, dot-product attention is\nmuch faster and more space-efficient in practice, since it can be implemented using highly optimized\nmatrix multiplication code.",
56
- "reading_order": 4
57
- },
58
- {
59
- "label": "para",
60
- "bbox": [
61
- 198,
62
- 1074,
63
- 1066,
64
- 1175
65
- ],
66
- "text": "While for small values of $d_k$ the two mechanisms perform similarly, additive attention outperforms\ndot product attention without scaling for larger values of $d_k$ [ 3 ] . We suspect that for large values of\n$d_k$ , the dot products grow large in magnitude, pushing the softmax function into regions where it has\nextremely small gradients 4 To counteract this effect, we scale the dot products by $\\frac{1}{\\sqrt{d_k}}$ .",
67
- "reading_order": 5
68
- },
69
- {
70
- "label": "sub_sub_sec",
71
- "bbox": [
72
- 198,
73
- 1207,
74
- 467,
75
- 1225
76
- ],
77
- "text": "3.2.2 Multi-Head Attention",
78
- "reading_order": 6
79
- },
80
- {
81
- "label": "para",
82
- "bbox": [
83
- 198,
84
- 1253,
85
- 1067,
86
- 1395
87
- ],
88
- "text": "Instead of performing a single attention function with $d_{\\text{model}}$ -dimensional keys, values and queries,\nwe found it beneficial to linearly project the queries, keys and values $h$ times with different, learned\nlinear projections to $d_k$ , $d_k$ and $d_v$ dimensions, respectively. On each of these projected versions of\nqueries, keys and values we then perform the attention function in parallel, yielding $d_v$ -dimensional\noutput values. These are concatenated and once again projected, resulting in the final values, as\ndepicted in Figure 2 .",
89
- "reading_order": 7
90
- },
91
- {
92
- "label": "para",
93
- "bbox": [
94
- 198,
95
- 1403,
96
- 1065,
97
- 1453
98
- ],
99
- "text": "Multi­head attention allows the model to jointly attend to information from different representation\nsubspaces at different positions. With a single attention head, averaging inhibits this.",
100
- "reading_order": 8
101
- },
102
- {
103
- "label": "fnote",
104
- "bbox": [
105
- 198,
106
- 1485,
107
- 1065,
108
- 1535
109
- ],
110
- "text": "${ }^{4}$ To illustrate why the dot products get large, assume that the components of $q$ and $k$ are independent random\nvariables with mean 0 and variance 1 . Then their dot product, $q \\cdot k=\\sum_{i=1}^{d_{k}} q_{i} k_{i}$, has mean 0 and variance $d_{k}$.",
111
- "reading_order": 9
112
- },
113
- {
114
- "label": "foot",
115
- "bbox": [
116
- 625,
117
- 1578,
118
- 641,
119
- 1599
120
- ],
121
- "text": "4",
122
- "reading_order": 10
123
- }
124
- ]
demo/page_imgs/test_page2.jpeg DELETED

Git LFS Details

  • SHA256: 2bbda18d9f6ab0279f80718b15d66e1e444279b24a55a23b872f70a382060ac1
  • Pointer size: 131 Bytes
  • Size of remote file: 366 kB
demo/page_imgs/test_page3.jpeg DELETED

Git LFS Details

  • SHA256: f5a5beda63acd2046fc4c7f39e4aa63e70db723936d71488e5819ab106f90ec0
  • Pointer size: 131 Bytes
  • Size of remote file: 358 kB
demo_element.py DELETED
@@ -1,129 +0,0 @@
1
- """
2
- Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
3
- SPDX-License-Identifier: MIT
4
- """
5
-
6
- import argparse
7
- import glob
8
- import os
9
-
10
- from omegaconf import OmegaConf
11
- from PIL import Image
12
-
13
- from chat import DOLPHIN
14
- from utils.utils import *
15
-
16
-
17
- def process_element(image_path, model, element_type, save_dir=None):
18
- """Process a single element image (text, table, formula)
19
-
20
- Args:
21
- image_path: Path to the element image
22
- model: DOLPHIN model instance
23
- element_type: Type of element ('text', 'table', 'formula')
24
- save_dir: Directory to save results (default: same as input directory)
25
-
26
- Returns:
27
- Parsed content of the element and recognition results
28
- """
29
- # Load and prepare image
30
- pil_image = Image.open(image_path).convert("RGB")
31
- pil_image = crop_margin(pil_image)
32
-
33
- # Select appropriate prompt based on element type
34
- if element_type == "table":
35
- prompt = "Parse the table in the image."
36
- label = "tab"
37
- elif element_type == "formula":
38
- prompt = "Read text in the image."
39
- label = "formula"
40
- else: # Default to text
41
- prompt = "Read text in the image."
42
- label = "text"
43
-
44
- # Process the element
45
- result = model.chat(prompt, pil_image)
46
-
47
- # Create recognition result in the same format as the document parser
48
- recognition_result = [
49
- {
50
- "label": label,
51
- "text": result.strip(),
52
- }
53
- ]
54
-
55
- # Save results if save_dir is provided
56
- if save_dir:
57
- save_outputs(recognition_result, image_path, save_dir)
58
- print(f"Results saved to {save_dir}")
59
-
60
- return result, recognition_result
61
-
62
-
63
- def main():
64
- parser = argparse.ArgumentParser(description="Element-level processing using DOLPHIN model")
65
- parser.add_argument("--config", default="./config/Dolphin.yaml", help="Path to configuration file")
66
- parser.add_argument("--input_path", type=str, required=True, help="Path to input image or directory of images")
67
- parser.add_argument(
68
- "--element_type",
69
- type=str,
70
- choices=["text", "table", "formula"],
71
- default="text",
72
- help="Type of element to process (text, table, formula)",
73
- )
74
- parser.add_argument(
75
- "--save_dir",
76
- type=str,
77
- default=None,
78
- help="Directory to save parsing results (default: same as input directory)",
79
- )
80
- parser.add_argument("--print_results", action="store_true", help="Print recognition results to console")
81
- args = parser.parse_args()
82
-
83
- # Load Model
84
- config = OmegaConf.load(args.config)
85
- model = DOLPHIN(config)
86
-
87
- # Set save directory
88
- save_dir = args.save_dir or (
89
- args.input_path if os.path.isdir(args.input_path) else os.path.dirname(args.input_path)
90
- )
91
- setup_output_dirs(save_dir)
92
-
93
- # Collect Images
94
- if os.path.isdir(args.input_path):
95
- image_files = []
96
- for ext in [".jpg", ".jpeg", ".png", ".JPG", ".JPEG", ".PNG"]:
97
- image_files.extend(glob.glob(os.path.join(args.input_path, f"*{ext}")))
98
- image_files = sorted(image_files)
99
- else:
100
- if not os.path.exists(args.input_path):
101
- raise FileNotFoundError(f"Input path {args.input_path} does not exist")
102
- image_files = [args.input_path]
103
-
104
- total_samples = len(image_files)
105
- print(f"\nTotal samples to process: {total_samples}")
106
-
107
- # Process images one by one
108
- for image_path in image_files:
109
- print(f"\nProcessing {image_path}")
110
- try:
111
- result, recognition_result = process_element(
112
- image_path=image_path,
113
- model=model,
114
- element_type=args.element_type,
115
- save_dir=save_dir,
116
- )
117
-
118
- if args.print_results:
119
- print("\nRecognition result:")
120
- print(result)
121
- print("-" * 40)
122
-
123
- except Exception as e:
124
- print(f"Error processing {image_path}: {str(e)}")
125
- continue
126
-
127
-
128
- if __name__ == "__main__":
129
- main()
demo_element_hf.py CHANGED
@@ -1,8 +1,3 @@
- """
- Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
- SPDX-License-Identifier: MIT
- """
-
  import argparse
  import glob
  import os
demo_page.py DELETED
@@ -1,247 +0,0 @@
1
- """
2
- Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
3
- SPDX-License-Identifier: MIT
4
- """
5
-
6
- import argparse
7
- import glob
8
- import os
9
-
10
- import cv2
11
- from omegaconf import OmegaConf
12
- from PIL import Image
13
-
14
- from chat import DOLPHIN
15
- from utils.utils import *
16
-
17
-
18
- def process_document(document_path, model, save_dir, max_batch_size):
19
- """Parse documents - Handles both images and PDFs"""
20
- file_ext = os.path.splitext(document_path)[1].lower()
21
-
22
- if file_ext == '.pdf':
23
- # Process PDF file
24
- # Convert PDF to images
25
- images = convert_pdf_to_images(document_path)
26
- if not images:
27
- raise Exception(f"Failed to convert PDF {document_path} to images")
28
-
29
- all_results = []
30
-
31
- # Process each page
32
- for page_idx, pil_image in enumerate(images):
33
- print(f"Processing page {page_idx + 1}/{len(images)}")
34
-
35
- # Generate output name for this page
36
- base_name = os.path.splitext(os.path.basename(document_path))[0]
37
- page_name = f"{base_name}_page_{page_idx + 1:03d}"
38
-
39
- # Process this page (don't save individual page results)
40
- json_path, recognition_results = process_single_image(
41
- pil_image, model, save_dir, page_name, max_batch_size, save_individual=False
42
- )
43
-
44
- # Add page information to results
45
- page_results = {
46
- "page_number": page_idx + 1,
47
- "elements": recognition_results
48
- }
49
- all_results.append(page_results)
50
-
51
- # Save combined results for multi-page PDF
52
- combined_json_path = save_combined_pdf_results(all_results, document_path, save_dir)
53
-
54
- return combined_json_path, all_results
55
-
56
- else:
57
- # Process regular image file
58
- pil_image = Image.open(document_path).convert("RGB")
59
- base_name = os.path.splitext(os.path.basename(document_path))[0]
60
- return process_single_image(pil_image, model, save_dir, base_name, max_batch_size)
61
-
62
-
63
- def process_single_image(image, model, save_dir, image_name, max_batch_size, save_individual=True):
64
- """Process a single image (either from file or converted from PDF page)
65
-
66
- Args:
67
- image: PIL Image object
68
- model: DOLPHIN model instance
69
- save_dir: Directory to save results
70
- image_name: Name for the output file
71
- max_batch_size: Maximum batch size for processing
72
- save_individual: Whether to save individual results (False for PDF pages)
73
-
74
- Returns:
75
- Tuple of (json_path, recognition_results)
76
- """
77
- # Stage 1: Page-level layout and reading order parsing
78
- layout_output = model.chat("Parse the reading order of this document.", image)
79
-
80
- # Stage 2: Element-level content parsing
81
- padded_image, dims = prepare_image(image)
82
- recognition_results = process_elements(layout_output, padded_image, dims, model, max_batch_size, save_dir, image_name)
83
-
84
- # Save outputs only if requested (skip for PDF pages)
85
- json_path = None
86
- if save_individual:
87
- # Create a dummy image path for save_outputs function
88
- dummy_image_path = f"{image_name}.jpg" # Extension doesn't matter, only basename is used
89
- json_path = save_outputs(recognition_results, dummy_image_path, save_dir)
90
-
91
- return json_path, recognition_results
92
-
93
-
94
- def process_elements(layout_results, padded_image, dims, model, max_batch_size, save_dir=None, image_name=None):
95
- """Parse all document elements with parallel decoding"""
96
- layout_results = parse_layout_string(layout_results)
97
-
98
- text_table_elements = [] # Elements that need processing
99
- figure_results = [] # Figure elements (no processing needed)
100
- previous_box = None
101
- reading_order = 0
102
-
103
- # Collect elements for processing
104
- for bbox, label in layout_results:
105
- try:
106
- # Adjust coordinates
107
- x1, y1, x2, y2, orig_x1, orig_y1, orig_x2, orig_y2, previous_box = process_coordinates(
108
- bbox, padded_image, dims, previous_box
109
- )
110
-
111
- # Crop and parse element
112
- cropped = padded_image[y1:y2, x1:x2]
113
- if cropped.size > 0 and cropped.shape[0] > 3 and cropped.shape[1] > 3:
114
- if label == "fig":
115
- pil_crop = Image.fromarray(cv2.cvtColor(cropped, cv2.COLOR_BGR2RGB))
116
-
117
- figure_filename = save_figure_to_local(pil_crop, save_dir, image_name, reading_order)
118
-
119
- # For figure regions, store relative path instead of base64
120
- figure_results.append(
121
- {
122
- "label": label,
123
- "text": f"![Figure](figures/{figure_filename})",
124
- "figure_path": f"figures/{figure_filename}",
125
- "bbox": [orig_x1, orig_y1, orig_x2, orig_y2],
126
- "reading_order": reading_order,
127
- }
128
- )
129
- else:
130
- # For text or table regions, prepare for parsing
131
- pil_crop = Image.fromarray(cv2.cvtColor(cropped, cv2.COLOR_BGR2RGB))
132
- prompt = "Parse the table in the image." if label == "tab" else "Read text in the image."
133
- text_table_elements.append(
134
- {
135
- "crop": pil_crop,
136
- "prompt": prompt,
137
- "label": label,
138
- "bbox": [orig_x1, orig_y1, orig_x2, orig_y2],
139
- "reading_order": reading_order,
140
- }
141
- )
142
-
143
- reading_order += 1
144
-
145
- except Exception as e:
146
- print(f"Error processing bbox with label {label}: {str(e)}")
147
- continue
148
-
149
- # Parse text/table elements in parallel
150
- recognition_results = figure_results
151
- if text_table_elements:
152
- crops_list = [elem["crop"] for elem in text_table_elements]
153
- prompts_list = [elem["prompt"] for elem in text_table_elements]
154
-
155
- # Inference in batch
156
- batch_results = model.chat(prompts_list, crops_list, max_batch_size=max_batch_size)
157
-
158
- # Add batch results to recognition_results
159
- for i, result in enumerate(batch_results):
160
- elem = text_table_elements[i]
161
- recognition_results.append(
162
- {
163
- "label": elem["label"],
164
- "bbox": elem["bbox"],
165
- "text": result.strip(),
166
- "reading_order": elem["reading_order"],
167
- }
168
- )
169
-
170
- # Sort elements by reading order
171
- recognition_results.sort(key=lambda x: x.get("reading_order", 0))
172
-
173
- return recognition_results
174
-
175
-
176
- def main():
177
- parser = argparse.ArgumentParser(description="Document parsing based on DOLPHIN")
178
- parser.add_argument("--config", default="./config/Dolphin.yaml", help="Path to configuration file")
179
- parser.add_argument("--input_path", type=str, default="./demo", help="Path to input image/PDF or directory of files")
180
- parser.add_argument(
181
- "--save_dir",
182
- type=str,
183
- default=None,
184
- help="Directory to save parsing results (default: same as input directory)",
185
- )
186
- parser.add_argument(
187
- "--max_batch_size",
188
- type=int,
189
- default=4,
190
- help="Maximum number of document elements to parse in a single batch (default: 4)",
191
- )
192
- args = parser.parse_args()
193
-
194
- # Load Model
195
- config = OmegaConf.load(args.config)
196
- model = DOLPHIN(config)
197
-
198
- # Collect Document Files (images and PDFs)
199
- if os.path.isdir(args.input_path):
200
- # Support both image and PDF files
201
- file_extensions = [".jpg", ".jpeg", ".png", ".JPG", ".JPEG", ".PNG", ".pdf", ".PDF"]
202
-
203
- document_files = []
204
- for ext in file_extensions:
205
- document_files.extend(glob.glob(os.path.join(args.input_path, f"*{ext}")))
206
- document_files = sorted(document_files)
207
- else:
208
- if not os.path.exists(args.input_path):
209
- raise FileNotFoundError(f"Input path {args.input_path} does not exist")
210
-
211
- # Check if it's a supported file type
212
- file_ext = os.path.splitext(args.input_path)[1].lower()
213
- supported_exts = ['.jpg', '.jpeg', '.png', '.pdf']
214
-
215
- if file_ext not in supported_exts:
216
- raise ValueError(f"Unsupported file type: {file_ext}. Supported types: {supported_exts}")
217
-
218
- document_files = [args.input_path]
219
-
220
- save_dir = args.save_dir or (
221
- args.input_path if os.path.isdir(args.input_path) else os.path.dirname(args.input_path)
222
- )
223
- setup_output_dirs(save_dir)
224
-
225
- total_samples = len(document_files)
226
- print(f"\nTotal files to process: {total_samples}")
227
-
228
- # Process All Document Files
229
- for file_path in document_files:
230
- print(f"\nProcessing {file_path}")
231
- try:
232
- json_path, recognition_results = process_document(
233
- document_path=file_path,
234
- model=model,
235
- save_dir=save_dir,
236
- max_batch_size=args.max_batch_size,
237
- )
238
-
239
- print(f"Processing completed. Results saved to {save_dir}")
240
-
241
- except Exception as e:
242
- print(f"Error processing {file_path}: {str(e)}")
243
- continue
244
-
245
-
246
- if __name__ == "__main__":
247
- main()
demo_page_hf.py CHANGED
@@ -1,8 +1,3 @@
- """
- Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
- SPDX-License-Identifier: MIT
- """
-
  import argparse
  import glob
  import os
deployment/ReadMe.md DELETED
@@ -1,12 +0,0 @@
- <h1 align="center">
- 🚀 Dolphin Inference/Serving
- </h1>
-
- ## vLLM
- > [Doc](./vllm/ReadMe.md)
-
- ## TensorRT-LLM
- > [Doc](./tensorrt_llm/ReadMe.md)
-
- ## Others
-
deployment/tensorrt_llm/ReadMe.md DELETED
@@ -1,89 +0,0 @@
1
- <h1 align="center">
2
- 🚀 Dolphin TensorRT-LLM Demo
3
- </h1>
4
-
5
- ## ✅ Introduction
6
- The Dolphin model employs a **Swin Encoder + MBart Decoder** architecture. In the HuggingFace Transformers [Config](https://huggingface.co/ByteDance/Dolphin/blob/main/config.json),
7
- its architectures field is specified as "VisionEncoderDecoderModel". **Dolphin**, **[Nougat](https://huggingface.co/docs/transformers/model_doc/nougat)**, and **[Donut](https://huggingface.co/docs/transformers/model_doc/donut)** share the same model architecture. TensorRT-LLM has already supported the Nougat model.
8
- Following Nougat's conversion script, we have successfully implemented Dolphin on TensorRT-LLM.
9
-
10
- **Note:** [prompt_ids](./dolphin_runner.py#L120) MUST be of **int32** type, otherwise TensorRT-LLM will produce incorrect results.
11
-
12
- ## 🛠️ Installation
13
- > We only test TensorRT-LLM 0.18.1 on Linux.
14
-
15
- https://nvidia.github.io/TensorRT-LLM/0.18.1/installation/linux.html
16
-
17
-
18
- ## ⚡ Offline Inference
19
- ```
20
- export MODEL_NAME="Dolphin"
21
-
22
- # predict elements reading order
23
- python run_dolphin.py \
24
- --batch_size 1 \
25
- --hf_model_dir tmp/hf_models/${MODEL_NAME} \
26
- --visual_engine_dir tmp/trt_engines/${MODEL_NAME}/vision_encoder \
27
- --llm_engine_dir tmp/trt_engines/${MODEL_NAME}/1-gpu/bfloat16 \
28
- --max_new_tokens 4096 \
29
- --repetition_penalty 1.0 \
30
- --input_text "Parse the reading order of this document." \
31
- --image_path "../../demo/page_imgs/page_1.jpeg"
32
-
33
- # recognize text/latex
34
- python run_dolphin.py \
35
- --batch_size 1 \
36
- --hf_model_dir tmp/hf_models/${MODEL_NAME} \
37
- --visual_engine_dir tmp/trt_engines/${MODEL_NAME}/vision_encoder \
38
- --llm_engine_dir tmp/trt_engines/${MODEL_NAME}/1-gpu/bfloat16 \
39
- --max_new_tokens 4096 \
40
- --repetition_penalty 1.0 \
41
- --input_text "Read text in the image." \
42
- --image_path "../../demo/element_imgs/block_formula.jpeg"
43
-
44
-
45
- python run_dolphin.py \
46
- --batch_size 1 \
47
- --hf_model_dir tmp/hf_models/${MODEL_NAME} \
48
- --visual_engine_dir tmp/trt_engines/${MODEL_NAME}/vision_encoder \
49
- --llm_engine_dir tmp/trt_engines/${MODEL_NAME}/1-gpu/bfloat16 \
50
- --max_new_tokens 4096 \
51
- --repetition_penalty 1.0 \
52
- --input_text "Read text in the image." \
53
- --image_path "../../demo/element_imgs/para_1.jpg"
54
-
55
- # recognize table
56
- python run_dolphin.py \
57
- --batch_size 1 \
58
- --hf_model_dir tmp/hf_models/${MODEL_NAME} \
59
- --visual_engine_dir tmp/trt_engines/${MODEL_NAME}/vision_encoder \
60
- --llm_engine_dir tmp/trt_engines/${MODEL_NAME}/1-gpu/bfloat16 \
61
- --max_new_tokens 4096 \
62
- --repetition_penalty 1.0 \
63
- --input_text "Parse the table in the image." \
64
- --image_path "../../demo/element_imgs/table_1.jpeg"
65
- ```
66
-
67
-
68
- ## ⚡ Online Inference
69
- ```
70
- # 1. Start Api Server
71
- export MODEL_NAME="Dolphin"
72
-
73
- python api_server.py \
74
- --hf_model_dir tmp/hf_models/${MODEL_NAME} \
75
- --visual_engine_dir tmp/trt_engines/${MODEL_NAME}/vision_encoder \
76
- --llm_engine_dir tmp/trt_engines/${MODEL_NAME}/1-gpu/bfloat16 \
77
- --max_batch_size 16
78
-
79
- # 2. Predict
80
- # predict elements reading order
81
- python deployment/tensorrt_llm/api_client.py --image_path ./demo/page_imgs/page_1.jpeg --prompt "Parse the reading order of this document."
82
-
83
- # recognize text/latex
84
- python deployment/tensorrt_llm/api_client.py --image_path ./demo/element_imgs/block_formula.jpeg --prompt "Read text in the image."
85
- python deployment/tensorrt_llm/api_client.py --image_path ./demo/element_imgs/para_1.jpg --prompt "Read text in the image."
86
-
87
- # recognize table
88
- python deployment/tensorrt_llm/api_client.py --image_path ./demo/element_imgs/table_1.jpeg --prompt "Parse the table in the image."
89
- ```
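Two details in the deleted ReadMe above are worth a concrete illustration: Dolphin loads through the generic VisionEncoderDecoderModel class, and the prompt ids handed to TensorRT-LLM must be int32 rather than the tokenizer's default int64. A hedged sketch of both points (not the dolphin_runner.py code; it assumes the ByteDance/Dolphin repo ships the usual processor and tokenizer configs):

```python
import numpy as np
from transformers import AutoProcessor, VisionEncoderDecoderModel

# The checkpoint's `architectures` field is "VisionEncoderDecoderModel",
# so the generic class is enough to load the Swin encoder + MBart decoder.
model = VisionEncoderDecoderModel.from_pretrained("ByteDance/Dolphin")
processor = AutoProcessor.from_pretrained("ByteDance/Dolphin")

# Tokenize one of the prompts used in the commands above.
prompt_ids = processor.tokenizer(
    "Read text in the image.", add_special_tokens=False, return_tensors="np"
).input_ids

# TensorRT-LLM expects int32 prompt ids; leaving them as int64 silently
# produces incorrect output, per the note at the top of the ReadMe.
prompt_ids = prompt_ids.astype(np.int32)
print(prompt_ids.dtype)  # int32
```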
deployment/tensorrt_llm/api_client.py DELETED
@@ -1,100 +0,0 @@
1
- # SPDX-License-Identifier: Apache-2.0
2
- # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
- """Example Python client for `vllm.entrypoints.api_server`
4
- Start the demo server:
5
- python -m vllm.entrypoints.api_server --model <model_name>
6
-
7
- NOTE: The API server is used only for demonstration and simple performance
8
- benchmarks. It is not intended for production use.
9
- For production use, we recommend `vllm serve` and the OpenAI client API.
10
- """
11
-
12
- import argparse
13
- import base64
14
- import json
15
- from argparse import Namespace
16
- from collections.abc import Iterable
17
-
18
- import requests
19
-
20
-
21
- def clear_line(n: int = 1) -> None:
22
- LINE_UP = "\033[1A"
23
- LINE_CLEAR = "\x1b[2K"
24
- for _ in range(n):
25
- print(LINE_UP, end=LINE_CLEAR, flush=True)
26
-
27
-
28
- def encode_image_base64(image_path: str) -> str:
29
- """Encode local image to base64 format."""
30
-
31
- with open(image_path, "rb") as f:
32
- image_data = f.read()
33
- result = base64.b64encode(image_data).decode("utf-8")
34
-
35
- return result
36
-
37
-
38
- def post_http_request(
39
- prompt: str, image_path: str, api_url: str, stream: bool = False
40
- ) -> requests.Response:
41
- headers = {"User-Agent": "Test Client"}
42
- pload = {
43
- "prompt": prompt,
44
- "image_base64": encode_image_base64(image_path),
45
- }
46
- response = requests.post(api_url, headers=headers, json=pload, stream=stream)
47
- return response
48
-
49
-
50
- def get_streaming_response(response: requests.Response) -> Iterable[list[str]]:
51
- for chunk in response.iter_lines(
52
- chunk_size=8192, decode_unicode=False, delimiter=b"\n"
53
- ):
54
- if chunk:
55
- data = json.loads(chunk.decode("utf-8"))
56
- output = data["text"]
57
- yield output
58
-
59
-
60
- def get_response(response: requests.Response) -> list[str]:
61
- data = json.loads(response.content)
62
- output = data["text"]
63
- return output
64
-
65
-
66
- def parse_args():
67
- parser = argparse.ArgumentParser()
68
- parser.add_argument("--host", type=str, default="localhost")
69
- parser.add_argument("--port", type=int, default=8000)
70
- parser.add_argument("--prompt", type=str, default="Parse the reading order of this document.")
71
- parser.add_argument("--image_path", type=str, default="./demo/page_imgs/page_1.jpeg")
72
- parser.add_argument("--stream", action="store_true")
73
- return parser.parse_args()
74
-
75
-
76
- def main(args: Namespace):
77
- prompt = args.prompt
78
- image_path = args.image_path
79
- api_url = f"http://{args.host}:{args.port}/generate"
80
- stream = args.stream
81
-
82
- print(f"Prompt: {prompt!r}\n", flush=True)
83
- response = post_http_request(prompt, image_path, api_url, stream)
84
-
85
- if stream:
86
- num_printed_lines = 0
87
- for h in get_streaming_response(response):
88
- clear_line(num_printed_lines)
89
- num_printed_lines = 0
90
- for i, line in enumerate(h):
91
- num_printed_lines += 1
92
- print(f"Response {i}: {line!r}", flush=True)
93
- else:
94
- output = get_response(response)
95
- print(f"Response: {output!r}", flush=True)
96
-
97
-
98
- if __name__ == "__main__":
99
- args = parse_args()
100
- main(args)
deployment/tensorrt_llm/api_server.py DELETED
@@ -1,112 +0,0 @@
1
- # copied from: https://github.com/NVIDIA/TensorRT-LLM/blob/v0.18.1/examples/apps/fastapi_server.py
2
-
3
- #!/usr/bin/env python
4
- import asyncio
5
- import base64
6
- import io
7
- import logging
8
- import signal
9
- from http import HTTPStatus
10
- from PIL import Image
11
- from typing import Optional
12
-
13
- import click
14
- import uvicorn
15
- from fastapi import FastAPI, Request
16
- from fastapi.responses import JSONResponse, Response
17
-
18
- from tensorrt_llm.executor import CppExecutorError, RequestError
19
- from dolphin_runner import DolphinRunner, InferenceConfig
20
-
21
- TIMEOUT_KEEP_ALIVE = 5 # seconds.
22
-
23
-
24
- async def decode_image(image_base64: str) -> Image.Image:
25
- image_data = base64.b64decode(image_base64)
26
- image = Image.open(io.BytesIO(image_data))
27
- return image
28
-
29
-
30
- class LlmServer:
31
- def __init__(self, runner: DolphinRunner):
32
- self.runner = runner
33
- self.app = FastAPI()
34
- self.register_routes()
35
-
36
- def register_routes(self):
37
- self.app.add_api_route("/health", self.health, methods=["GET"])
38
- self.app.add_api_route("/generate", self.generate, methods=["POST"])
39
-
40
- async def health(self) -> Response:
41
- return Response(status_code=200)
42
-
43
- async def generate(self, request: Request) -> Response:
44
- """ Generate completion for the request.
45
-
46
- The request should be a JSON object with the following fields:
47
- - prompt: the prompt to use for the generation.
48
- - image_base64: the image to use for the generation.
49
- """
50
- request_dict = await request.json()
51
-
52
- prompt = request_dict.pop("prompt", "")
53
- logging.info(f"request prompt: {prompt}")
54
- image_base64 = request_dict.pop("image_base64", "")
55
- image = await decode_image(image_base64)
56
-
57
- try:
58
- output_texts = self.runner.run([prompt], [image], 4024)
59
- output_texts = [texts[0] for texts in output_texts]
60
- return JSONResponse({"text": output_texts[0]})
61
- except RequestError as e:
62
- return JSONResponse(content=str(e),
63
- status_code=HTTPStatus.BAD_REQUEST)
64
- except CppExecutorError:
65
- # If internal executor error is raised, shutdown the server
66
- signal.raise_signal(signal.SIGINT)
67
-
68
- async def __call__(self, host, port):
69
- config = uvicorn.Config(self.app,
70
- host=host,
71
- port=port,
72
- log_level="info",
73
- timeout_keep_alive=TIMEOUT_KEEP_ALIVE)
74
- await uvicorn.Server(config).serve()
75
-
76
-
77
- @click.command()
78
- @click.option("--hf_model_dir", type=str, required=True)
79
- @click.option("--visual_engine_dir", type=str, required=True)
80
- @click.option("--llm_engine_dir", type=str, required=True)
81
- @click.option("--max_batch_size", type=int, default=16)
82
- @click.option("--max_new_tokens", type=int, default=4024)
83
- @click.option("--host", type=str, default=None)
84
- @click.option("--port", type=int, default=8000)
85
- def entrypoint(hf_model_dir: str,
86
- visual_engine_dir: str,
87
- llm_engine_dir: str,
88
- max_batch_size: int,
89
- max_new_tokens: int,
90
- host: Optional[str] = None,
91
- port: int = 8000):
92
- host = host or "0.0.0.0"
93
- port = port or 8000
94
- logging.info(f"Starting server at {host}:{port}")
95
-
96
- config = InferenceConfig(
97
- max_new_tokens=max_new_tokens,
98
- batch_size=max_batch_size,
99
- log_level="info",
100
- hf_model_dir=hf_model_dir,
101
- visual_engine_dir=visual_engine_dir,
102
- llm_engine_dir=llm_engine_dir,
103
- )
104
-
105
- dolphin_runner = DolphinRunner(config)
106
- server = LlmServer(runner=dolphin_runner)
107
-
108
- asyncio.run(server(host, port))
109
-
110
-
111
- if __name__ == "__main__":
112
- entrypoint()
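For reference, a minimal sketch of the same wiring the deleted click entrypoint performs, expressed as a plain script. It assumes the (also deleted) dolphin_runner.py module is importable and that the TensorRT-LLM engines have already been built; the directory paths are placeholders:

```python
# Illustrative only — mirrors the deleted entrypoint rather than replacing it.
import asyncio

from dolphin_runner import DolphinRunner, InferenceConfig
from api_server import LlmServer  # the class defined in the deleted file above

config = InferenceConfig(
    max_new_tokens=4024,                  # default the CLI exposed via --max_new_tokens
    batch_size=16,                        # --max_batch_size default
    log_level="info",
    hf_model_dir="./hf_model",            # placeholder
    visual_engine_dir="./vision_engine",  # placeholder
    llm_engine_dir="./llm_engine",        # placeholder
)

server = LlmServer(runner=DolphinRunner(config))
asyncio.run(server("0.0.0.0", 8000))      # serves /health and /generate as above
```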
deployment/tensorrt_llm/convert/__init__.py DELETED
File without changes
deployment/tensorrt_llm/convert/build_visual_engine.py DELETED
@@ -1,14 +0,0 @@
1
- # copied from: https://github.com/NVIDIA/TensorRT-LLM/blob/v0.18.2/examples/multimodal/build_visual_engine.py
2
-
3
- import argparse
4
-
5
- from tensorrt_llm.tools.multimodal_builder import (VisionEngineBuilder,
6
- add_multimodal_arguments)
7
-
8
- if __name__ == '__main__':
9
- parser = argparse.ArgumentParser()
10
- parser = add_multimodal_arguments(parser)
11
- args = parser.parse_args()
12
-
13
- builder = VisionEngineBuilder(args)
14
- builder.build()
deployment/tensorrt_llm/convert/convert_checkpoint.py DELETED
@@ -1,1528 +0,0 @@
1
- # copied from: https://github.com/NVIDIA/TensorRT-LLM/blob/v0.18.1/examples/enc_dec/convert_checkpoint.py
2
-
3
- import argparse
4
- import configparser
5
- import copy
6
- import json
7
- import logging
8
- import os
9
- import types
10
- from ast import literal_eval
11
- from datetime import datetime
12
- from pathlib import Path
13
-
14
- import safetensors
15
- from helper import convert_weight_to_dtype, fuse_qkv_one_layer, reshape, split
16
- from transformers import (AutoModelForSeq2SeqLM, Blip2ForConditionalGeneration,
17
- MBartForConditionalGeneration,
18
- Pix2StructForConditionalGeneration,
19
- T5ForConditionalGeneration, VisionEncoderDecoderModel)
20
-
21
- from tensorrt_llm.functional import (LayerNormPositionType, LayerNormType,
22
- MLPType)
23
- from tensorrt_llm.models import PretrainedConfig
24
-
25
- dir_path = os.path.dirname(os.path.realpath(__file__))
26
- LOGGER = logging.getLogger(__name__)
27
-
28
- layernorm_type_map = {i.name: i.value for i in LayerNormType}
29
- layernorm_position_map = {i.name: i.value for i in LayerNormPositionType}
30
- mlp_type_map = {i.name: i.value for i in MLPType}
31
-
32
-
33
- def copy_args_to_component_config(component_config, args):
34
- for arg in vars(args):
35
- setattr(component_config, arg, getattr(args, arg))
36
- return component_config
37
-
38
-
39
- def parse_t5_config(args, hf_model):
40
- config = configparser.ConfigParser()
41
-
42
- config["encoder"] = {}
43
- for key, val in hf_model.encoder.config.to_dict().items():
44
- config["encoder"][key] = f"{val}"
45
-
46
- # manually set q_scaling to offset attention scaling's effect.
47
- # TODO: modify kernels to control whether to disable attention scaling
48
- def get_offset_q_scaling(config):
49
- scaling = 1 / config.head_size**.5
50
- return scaling
51
-
52
- config["decoder"] = {}
53
- for key, val in hf_model.decoder.config.to_dict().items():
54
- config["decoder"][key] = f"{val}"
55
-
56
- config["structure"] = dict()
57
- config["structure"]["t5_with_bias"] = "false"
58
- config["structure"]["use_gated_activation"] = str(
59
- hf_model.encoder.config.is_gated_act)
60
- config["structure"]["position_embedding_type"] = "relative"
61
- config["structure"]["model_type"] = args.model_type
62
-
63
- def parse_t5_config_by_component(config, component, args):
64
- component_config = types.SimpleNamespace()
65
- component_config = copy_args_to_component_config(component_config, args)
66
- component_config.n_head = config.getint(component, 'num_heads')
67
- component_config.head_size = config.getint(component, 'd_kv')
68
- component_config.hidden_size = config.getint(component, 'd_model')
69
- component_config.ffn_hidden_size = config.getint(component, 'd_ff')
70
- component_config.vocab_size = config.getint(component, 'vocab_size')
71
- component_config.n_positions = config.getint(component,
72
- 'n_positions',
73
- fallback=512)
74
- component_config.has_position_embedding = config.getboolean(
75
- component, 'has_position_embedding',
76
- fallback=False) # TODO: hardcoded here
77
-
78
- component_config.has_token_type_embedding = config.getboolean(
79
- component, 'has_token_type_embedding', fallback=False)
80
- component_config.has_embedding_layernorm = config.getboolean(
81
- component, 'has_embedding_layernorm', fallback=False)
82
- component_config.has_embedding_scale = config.getboolean(
83
- component, 'has_embedding_scale', fallback=False)
84
- component_config.q_scaling = get_offset_q_scaling(component_config)
85
- component_config.has_attention_qkvo_bias = config.getboolean(
86
- component, 'has_attention_qkvo_bias',
87
- fallback=False) # TODO: hardcoded here
88
- component_config.has_mlp_bias = config.getboolean(component,
89
- 'has_mlp_bias',
90
- fallback=False)
91
- component_config.has_model_final_layernorm = config.getboolean(
92
- component, 'has_model_final_layernorm', fallback=True)
93
- component_config.layernorm_eps = config.getfloat(
94
- component, 'layer_norm_epsilon')
95
- component_config.layernorm_position = layernorm_position_map[config.get(
96
- component, 'layernorm_position',
97
- fallback='pre_layernorm')] # TODO: hardcoded here
98
- component_config.layernorm_type = layernorm_type_map[config.get(
99
- component, 'layernorm_type', fallback='RmsNorm')]
100
- component_config.hidden_act = config.get(component, 'dense_act_fn')
101
- component_config.gated_act = config.getboolean(component,
102
- 'is_gated_act')
103
- component_config.mlp_type = mlp_type_map['GatedMLP' if component_config.
104
- gated_act else 'MLP']
105
- component_config.num_buckets = config.getint(
106
- component, 'relative_attention_num_buckets')
107
- component_config.max_distance = config.getint(
108
- component, 'relative_attention_max_distance')
109
- component_config.position_embedding_type = config.get(
110
- 'structure', 'position_embedding_type')
111
- component_config.logits_dtype = config.get(component,
112
- 'logits_dtype',
113
- fallback='float32')
114
-
115
- if component == 'encoder':
116
- component_config.n_layer = config.getint(component, 'num_layers')
117
-
118
- component_config.relative_attention = config.get(
119
- 'structure', 'position_embedding_type') == 'relative'
120
-
121
- elif component == 'decoder':
122
- component_config.n_layer = config.getint(component,
123
- 'num_decoder_layers')
124
- component_config.has_lm_head_bias = config.getboolean(
125
- component, # TODO: T5 with bias
126
- 'has_lm_head_bias',
127
- fallback=False)
128
- component_config.relative_attention = config.getboolean(
129
- component, 'relative_attention', fallback=True)
130
- component_config.rescale_before_lm_head = config.getboolean(
131
- component, 'tie_word_embeddings'
132
- ) # default is True (for T5), but False for Flan-T5
133
- component_config.encoder_hidden_size = config.getint(
134
- 'encoder', 'd_model')
135
- component_config.encoder_num_heads = config.getint(
136
- 'encoder', 'num_heads')
137
- component_config.encoder_head_size = config.getint(
138
- 'encoder', 'd_kv')
139
- component_config.decoder_start_token_id = config.getint(
140
- 'decoder', 'decoder_start_token_id')
141
- component_config.eos_token_id = config.getint(
142
- 'decoder', 'eos_token_id')
143
- bos_token_id = config.get('decoder', 'bos_token_id')
144
- # T5 does not have bos_token_id
145
- component_config.bos_token_id = int(
146
- bos_token_id) if bos_token_id != "None" else None
147
- component_config.pad_token_id = config.getint(
148
- 'decoder', 'pad_token_id')
149
-
150
- else:
151
- assert False, 'Unsupported component!'
152
-
153
- return component_config
154
-
155
- encoder_config = parse_t5_config_by_component(config, "encoder", args)
156
- decoder_config = parse_t5_config_by_component(config, "decoder", args)
157
-
158
- return encoder_config, decoder_config
159
-
160
-
161
- def convert_t5_weights_to_tllm_safetensors(config, component, params):
162
- weights = {}
163
-
164
- mapping = config.mapping
165
-
166
- convert_weight_to_dtype(params, config.dtype)
167
- hidden_size = config.hidden_size
168
- ffn_hidden_size = config.intermediate_size
169
- num_layers = config.num_hidden_layers
170
- n_head = config.num_attention_heads
171
- head_size = config.head_size
172
- attention_hidden_size = n_head * head_size # head size * num_heads not necessarily equals hidden_dim, such as Flan-T5
173
-
174
- hf_param_prefix = f'{component}'
175
- trtllm_layer_name = f'{component}_layers'
176
- trtllm_attn_layer_name = 'attention' if component == 'encoder' else 'self_attention'
177
- trtllm_attn_layernorm_name = 'self_attention_layernorm' if component == 'decoder' else 'attention_layernorm'
178
- hf_component_idx = 1 if component == 'encoder' else 2
179
-
180
- def get_attn_module_name(component, block, layer, attn_type):
181
- return f'{component}.block.{int(block)}.layer.{int(layer)}.{attn_type}'
182
-
183
- weights['embedding.vocab_embedding.weight'] = reshape(
184
- params['shared.weight'].clone(), None)
185
-
186
- layers_range = mapping.pp_layers(num_layers)
187
- for layer_idx in layers_range:
188
- local_layer_idx = layer_idx - layers_range[0]
189
- trtllm_layer_name_prefix = f'{trtllm_layer_name}.{local_layer_idx}'
190
- hf_layer_name_prefix = f'{hf_param_prefix}.block.{layer_idx}'
191
-
192
- hidden_layer_name_split = {
193
- f'{hf_layer_name_prefix}.layer.0.SelfAttention.o.weight': {
194
- "name":
195
- f'{trtllm_layer_name_prefix}.{trtllm_attn_layer_name}.dense.weight',
196
- "shape":
197
- (hidden_size, attention_hidden_size // mapping.tp_size),
198
- "split_dim": -1
199
- },
200
- f'{hf_layer_name_prefix}.layer.{hf_component_idx}.DenseReluDense.wo.weight':
201
- {
202
- "name": f'{trtllm_layer_name_prefix}.mlp.proj.weight',
203
- "shape": (hidden_size, ffn_hidden_size // mapping.tp_size),
204
- "split_dim": -1
205
- },
206
- f'{hf_layer_name_prefix}.layer.{hf_component_idx}.DenseReluDense.wi.weight':
207
- {
208
- "name": f'{trtllm_layer_name_prefix}.mlp.fc.weight',
209
- "shape": (ffn_hidden_size // mapping.tp_size, hidden_size),
210
- "split_dim": 0
211
- },
212
- f'{hf_layer_name_prefix}.layer.{hf_component_idx}.DenseReluDense.wi_0.weight':
213
- {
214
- "name": f'{trtllm_layer_name_prefix}.mlp.fc.weight',
215
- "shape": (ffn_hidden_size // mapping.tp_size, hidden_size),
216
- "split_dim": 0
217
- },
218
- }
219
-
220
- hidden_layer_name_no_split = {
221
- f'{hf_layer_name_prefix}.layer.0.layer_norm.weight': {
222
- "name":
223
- f'{trtllm_layer_name_prefix}.{trtllm_attn_layernorm_name}.weight',
224
- "shape": None
225
- },
226
- f'{hf_layer_name_prefix}.layer.{hf_component_idx}.layer_norm.weight':
227
- {
228
- "name": f'{trtllm_layer_name_prefix}.mlp_layernorm.weight',
229
- "shape": None
230
- },
231
- }
232
-
233
- if config.gated_act:
234
- hidden_layer_name_split.update({
235
- f'{hf_layer_name_prefix}.layer.{hf_component_idx}.DenseReluDense.wi2.weight':
236
- {
237
- "name": f'{trtllm_layer_name_prefix}.mlp.gate.weight',
238
- "shape": (ffn_hidden_size // mapping.tp_size, hidden_size),
239
- "split_dim": 0
240
- },
241
- f'{hf_layer_name_prefix}.layer.{hf_component_idx}.DenseReluDense.wi_1.weight':
242
- {
243
- "name": f'{trtllm_layer_name_prefix}.mlp.gate.weight',
244
- "shape": (ffn_hidden_size // mapping.tp_size, hidden_size),
245
- "split_dim": 0
246
- },
247
- })
248
-
249
- if component == 'decoder':
250
- hidden_layer_name_split.update({
251
- f'{hf_layer_name_prefix}.layer.1.EncDecAttention.o.weight': {
252
- "name":
253
- f'{trtllm_layer_name_prefix}.cross_attention.dense.weight',
254
- "shape":
255
- (hidden_size, attention_hidden_size // mapping.tp_size),
256
- "split_dim": -1
257
- },
258
- })
259
- hidden_layer_name_no_split.update({
260
- f'{hf_layer_name_prefix}.layer.1.layer_norm.weight': {
261
- "name":
262
- f'{trtllm_layer_name_prefix}.cross_attention_layernorm.weight',
263
- "shape": None
264
- },
265
- })
266
- self_attn_module_name = get_attn_module_name(
267
- component, layer_idx, "1", 'EncDecAttention')
268
- weights.update(
269
- fuse_qkv_one_layer(
270
- params, self_attn_module_name,
271
- f'{trtllm_layer_name_prefix}.cross_attention',
272
- mapping.tp_size, mapping.tp_rank, config.model_type,
273
- (attention_hidden_size * 3 // mapping.tp_size, hidden_size),
274
- None))
275
-
276
- self_attn_module_name = get_attn_module_name(component, layer_idx, "0",
277
- 'SelfAttention')
278
- weights.update(
279
- fuse_qkv_one_layer(
280
- params, self_attn_module_name,
281
- f'{trtllm_layer_name_prefix}.{trtllm_attn_layer_name}',
282
- mapping.tp_size, mapping.tp_rank, config.model_type,
283
- (attention_hidden_size * 3 // mapping.tp_size, hidden_size),
284
- None))
285
-
286
- weights[
287
- f'{trtllm_layer_name_prefix}.{trtllm_attn_layer_name}.rel_attn_table'] = reshape(
288
- split(
289
- params[
290
- f'{component}.block.0.layer.0.SelfAttention.relative_attention_bias.weight']
291
- .T, mapping.tp_size, mapping.tp_rank, 0),
292
- (n_head // mapping.tp_size, config.num_buckets))
293
-
294
- for hf_weight_name, weight_info in hidden_layer_name_split.items():
295
- if hf_weight_name in params.keys():
296
- weights[weight_info["name"]] = reshape(
297
- split(params[hf_weight_name],
298
- mapping.tp_size,
299
- mapping.tp_rank,
300
- dim=weight_info["split_dim"]), weight_info["shape"])
301
- for hf_weight_name, weight_info in hidden_layer_name_no_split.items():
302
- if hf_weight_name in params.keys():
303
- weights[weight_info["name"]] = reshape(
304
- params[hf_weight_name].clone(), shape=weight_info["shape"])
305
-
306
- weights['final_layernorm.weight'] = reshape(
307
- params[f'{component}.final_layer_norm.weight'].clone(), None)
308
-
309
- if component == 'decoder':
310
- weights['lm_head.weight'] = reshape(
311
- split(params['lm_head.weight'],
312
- mapping.tp_size,
313
- mapping.tp_rank,
314
- dim=0), (config.vocab_size // mapping.tp_size, hidden_size))
315
- if not config.use_implicit_relative_attention:
316
- weights['rel_attn_table'] = reshape(
317
- split(
318
- params[
319
- f'{component}.block.0.layer.0.SelfAttention.relative_attention_bias.weight']
320
- .T, mapping.tp_size, mapping.tp_rank, 0),
321
- (n_head // mapping.tp_size, config.num_buckets))
322
-
323
- return weights
324
-
325
-
326
- convert_blip2_weights_to_tllm_safetensors = convert_t5_weights_to_tllm_safetensors # func alias
327
-
328
-
329
- def parse_nmt_config(args, model):
330
- config = configparser.ConfigParser()
331
- fairseq_config = vars(model.cfg.model) # Namespace --> dict
332
-
333
- config['encoder'] = dict()
334
- for key, val in fairseq_config.items():
335
- config["encoder"][key] = f"{val}"
336
- config["encoder"]["q_scaling"] = '1'
337
- # NMT has final layernorm for pre-norm model architecture.
338
- config['encoder']['has_model_final_layernorm'] = config['encoder'][
339
- 'encoder_normalize_before']
340
- config['encoder']['vocab_size'] = str(len(model.src_dict)) # fairseq naming
341
-
342
- config['decoder'] = dict()
343
- for key, val in fairseq_config.items():
344
- config["decoder"][key] = f"{val}"
345
- config["decoder"]["q_scaling"] = '1'
346
- config["decoder"]["rescale_before_lm_head"] = 'false'
347
- config['decoder']['has_model_final_layernorm'] = str(
348
- config['decoder'].getboolean('decoder_normalize_before', False)
349
- and not config['decoder'].getboolean('no_decoder_final_norm', False))
350
- config['decoder']['vocab_size'] = str(len(model.tgt_dict)) # fairseq naming
351
-
352
- config["structure"] = dict()
353
- config["structure"]["t5_with_bias"] = "true"
354
- config["structure"]["use_gated_activation"] = "false"
355
- config["structure"][
356
- "position_embedding_type"] = "learned_absolute" # "sinusoid"
357
- config["structure"]["model_type"] = args.model_type
358
-
359
- def parse_nmt_config_by_component(config, component, args):
360
- assert component in ('encoder', 'decoder'), 'Unsupported component!'
361
- component_config = types.SimpleNamespace()
362
- component_config = copy_args_to_component_config(component_config, args)
363
- component_config.n_layer = config.getint(component,
364
- f'{component}_layers')
365
- component_config.n_head = config.getint(component,
366
- f'{component}_attention_heads')
367
- component_config.hidden_size = config.getint(
368
- component, f'{component}_embed_dim') # fairseq naming
369
- component_config.head_size = config.getint(
370
- component,
371
- 'd_kv',
372
- fallback=component_config.hidden_size // component_config.n_head)
373
- component_config.ffn_hidden_size = config.getint(
374
- component, f'{component}_ffn_embed_dim') # fairseq naming
375
- component_config.vocab_size = config.getint(component, 'vocab_size')
376
- component_config.n_positions = config.getint(
377
- component, 'max_source_positions') # fairseq naming
378
- component_config.has_position_embedding = not config.getboolean(
379
- component, 'no_token_positional_embeddings',
380
- fallback=False) # fairseq naming
381
- component_config.has_token_type_embedding = config.getboolean(
382
- component, 'has_token_type_embedding', fallback=False)
383
- component_config.has_embedding_layernorm = config.getboolean(
384
- component, 'layernorm_embedding', fallback=True) # fairseq naming
385
- component_config.has_embedding_scale = not config.getboolean(
386
- component, 'no_scale_embedding') # fairseq naming
387
- component_config.q_scaling = config.getfloat(component,
388
- 'q_scaling',
389
- fallback=1.0)
390
- component_config.has_attention_qkvo_bias = config.getboolean(
391
- 'structure', 't5_with_bias', fallback=True)
392
- component_config.has_mlp_bias = config.getboolean('structure',
393
- 't5_with_bias',
394
- fallback=True)
395
- component_config.has_model_final_layernorm = config.getboolean(
396
- component, 'has_model_final_layernorm')
397
- component_config.layernorm_eps = config.getfloat(
398
- component, 'layer_norm_epsilon', fallback=1e-5) # fairseq naming
399
-
400
- normalize_before = config.getboolean(
401
- component, f'{component}_normalize_before') # fairseq naming
402
- component_config.layernorm_position = layernorm_position_map[
403
- 'pre_layernorm' if normalize_before else 'post_layernorm']
404
-
405
- component_config.layernorm_type = layernorm_type_map[config.get(
406
- component, 'layernorm_type', fallback='LayerNorm')]
407
- component_config.hidden_act = config.get(
408
- component, 'activation_fn') # fairseq naming
409
- component_config.gated_act = config.getboolean(component,
410
- 'is_gated_act',
411
- fallback=False)
412
- component_config.mlp_type = mlp_type_map['GatedMLP' if component_config.
413
- gated_act else 'MLP']
414
- component_config.relative_attention = config.get(
415
- 'structure', 'position_embedding_type') == 'relative'
416
-
417
- component_config.num_buckets = config.getint(
418
- component, 'relative_attention_num_buckets', fallback=0)
419
- component_config.max_distance = config.getint(
420
- component, 'relative_attention_max_distance', fallback=0)
421
- component_config.position_embedding_type = config.get(
422
- 'structure', 'position_embedding_type')
423
- component_config.logits_dtype = config.get(component,
424
- 'logits_dtype',
425
- fallback='float32')
426
- if component == 'decoder':
427
- component_config.rescale_before_lm_head = config.getboolean(
428
- component, 'rescale_before_lm_head')
429
-
430
- component_config.encoder_hidden_size = config.getint(
431
- 'encoder', 'encoder_embed_dim') # fairseq naming
432
- component_config.encoder_num_heads = config.getint(
433
- 'encoder', 'encoder_attention_heads')
434
- component_config.encoder_head_size = config.getint(
435
- 'encoder',
436
- 'd_kv',
437
- fallback=component_config.encoder_hidden_size //
438
- component_config.encoder_num_heads)
439
- component_config.decoder_start_token_id = None
440
- component_config.eos_token_id = None
441
- component_config.bos_token_id = None
442
- component_config.pad_token_id = None
443
-
444
- return component_config
445
-
446
- encoder_config = parse_nmt_config_by_component(config, "encoder", args)
447
- decoder_config = parse_nmt_config_by_component(config, "decoder", args)
448
-
449
- return encoder_config, decoder_config
450
-
451
-
452
- def convert_nmt_weights_to_tllm_safetensors(config, component, params,
453
- sin_pos_embedding):
454
- weights = {}
455
-
456
- mapping = config.mapping
457
-
458
- hidden_size = config.hidden_size
459
-
460
- convert_weight_to_dtype(params, config.dtype)
461
- ffn_hidden_size = config.intermediate_size
462
- vocab_size = config.vocab_size
463
-
464
- hf_param_prefix = f'models.0.{component}'
465
- trtllm_layer_name = f'{component}_layers'
466
- trtllm_attn_layer_name = 'attention' if component == 'encoder' else 'self_attention'
467
- trtllm_attn_layernorm_name = 'self_attention_layernorm' if component == 'decoder' else 'attention_layernorm'
468
-
469
- hidden_layer_name_split = {
470
- 'self_attn.out_proj.weight': {
471
- "name": f'{trtllm_attn_layer_name}.dense.weight',
472
- "shape": (hidden_size, hidden_size // mapping.tp_size),
473
- "split_dim": -1
474
- },
475
- 'fc1.weight': {
476
- "name": 'mlp.fc.weight',
477
- "shape": (ffn_hidden_size // mapping.tp_size, hidden_size),
478
- "split_dim": 0
479
- },
480
- 'fc1.bias': {
481
- "name": 'mlp.fc.bias',
482
- "shape": (ffn_hidden_size // mapping.tp_size),
483
- "split_dim": 0
484
- },
485
- 'fc2.weight': {
486
- "name": 'mlp.proj.weight',
487
- "shape": (hidden_size, ffn_hidden_size // mapping.tp_size),
488
- "split_dim": -1
489
- },
490
- }
491
-
492
- hidden_layer_name_no_split = {
493
- 'self_attn.out_proj.bias': {
494
- "name": f'{trtllm_attn_layer_name}.dense.bias',
495
- "shape": (hidden_size)
496
- },
497
- 'self_attn_layer_norm.weight': {
498
- "name": f'{trtllm_attn_layernorm_name}.weight',
499
- "shape": None
500
- },
501
- 'self_attn_layer_norm.bias': {
502
- "name": f'{trtllm_attn_layernorm_name}.bias',
503
- "shape": None
504
- },
505
- 'fc2.bias': {
506
- "name": 'mlp.proj.bias',
507
- "shape": (hidden_size)
508
- },
509
- 'final_layer_norm.weight': {
510
- "name": 'mlp_layernorm.weight',
511
- "shape": None
512
- },
513
- 'final_layer_norm.bias': {
514
- "name": 'mlp_layernorm.bias',
515
- "shape": None
516
- },
517
- }
518
-
519
- if component == "decoder":
520
- hidden_layer_name_split.update({
521
- 'encoder_attn.out_proj.weight': {
522
- "name": 'cross_attention.dense.weight',
523
- "shape": (hidden_size, hidden_size // mapping.tp_size),
524
- "split_dim": -1
525
- },
526
- })
527
- hidden_layer_name_no_split.update({
528
- 'encoder_attn.out_proj.bias': {
529
- "name": 'cross_attention.dense.bias',
530
- "shape": (hidden_size)
531
- },
532
- 'encoder_attn_layer_norm.weight': {
533
- "name": 'cross_attention_layernorm.weight',
534
- "shape": None,
535
- },
536
- 'encoder_attn_layer_norm.bias': {
537
- "name": 'cross_attention_layernorm.bias',
538
- "shape": None
539
- },
540
- })
541
-
542
- def get_attn_module_name(component, layer, attn_type):
543
- return f'models.0.{component}.layers.{int(layer)}.{attn_type}'
544
-
545
- weights["embedding.vocab_embedding.weight"] = reshape(
546
- params[f'{hf_param_prefix}.embed_tokens.weight'].clone(),
547
- (vocab_size, -1))
548
- weights["embedding.position_embedding.weight"] = reshape(
549
- sin_pos_embedding, (config.max_position_embeddings, hidden_size))
550
-
551
- num_layers = config.num_hidden_layers
552
-
553
- layers_range = mapping.pp_layers(num_layers)
554
- for layer_idx in layers_range:
555
- local_layer_idx = layer_idx - layers_range[0]
556
- hf_layer_name_prefix = f'{hf_param_prefix}.layers.{layer_idx}'
557
- trtllm_layer_name_prefix = f'{trtllm_layer_name}.{local_layer_idx}'
558
-
559
- for hf_weight_name, weight_info in hidden_layer_name_split.items():
560
- weights[
561
- f'{trtllm_layer_name_prefix}.{weight_info["name"]}'] = reshape(
562
- split(params[f'{hf_layer_name_prefix}.{hf_weight_name}'],
563
- mapping.tp_size,
564
- mapping.tp_rank,
565
- dim=weight_info["split_dim"]), weight_info["shape"])
566
-
567
- for hf_weight_name, weight_info in hidden_layer_name_no_split.items():
568
- trtllm_layer_fullname = f'{trtllm_layer_name_prefix}.{weight_info["name"]}'
569
- hf_layer_fullname = f'{hf_layer_name_prefix}.{hf_weight_name}'
570
- weights[trtllm_layer_fullname] = reshape(
571
- params[hf_layer_fullname].clone(), shape=weight_info["shape"])
572
-
573
- self_attn_module_name = get_attn_module_name(component, layer_idx,
574
- 'self_attn')
575
- weights.update(
576
- fuse_qkv_one_layer(
577
- params, self_attn_module_name,
578
- f'{trtllm_layer_name_prefix}.{trtllm_attn_layer_name}',
579
- mapping.tp_size, mapping.tp_rank, config.model_type,
580
- (hidden_size * 3 // mapping.tp_size, hidden_size),
581
- (hidden_size * 3 // mapping.tp_size)))
582
- if component == 'decoder':
583
- cross_attn_module_name = get_attn_module_name(
584
- component, layer_idx, 'encoder_attn')
585
- weights.update(
586
- fuse_qkv_one_layer(
587
- params, cross_attn_module_name,
588
- f'{trtllm_layer_name_prefix}.cross_attention',
589
- mapping.tp_size, mapping.tp_rank, config.model_type,
590
- (hidden_size * 3 // mapping.tp_size, hidden_size),
591
- (hidden_size * 3 // mapping.tp_size)))
592
-
593
- if component == 'decoder':
594
- weights['lm_head.weight'] = reshape(
595
- split(params[f'{hf_param_prefix}.output_projection.weight'],
596
- mapping.tp_size,
597
- mapping.tp_rank,
598
- dim=0), (config.vocab_size // mapping.tp_size, hidden_size))
599
-
600
- if config.has_model_final_layernorm:
601
- weights['final_layernorm.weight'] = params[
602
- f'{hf_param_prefix}.layer_norm.weight'].clone()
603
- weights['final_layernorm.bias'] = params[
604
- f'{hf_param_prefix}.layer_norm.bias'].clone()
605
-
606
- return weights
607
-
608
-
609
- def parse_bart_config(args, hf_model):
610
-
611
- config = configparser.ConfigParser()
612
-
613
- config['decoder'] = dict()
614
- for key, val in hf_model.model.decoder.config.to_dict().items():
615
- config["decoder"][key] = f"{val}"
616
- config["decoder"]["q_scaling"] = '1'
617
- config["decoder"]["rescale_before_lm_head"] = str(False)
618
- config['decoder']['has_model_final_layernorm'] = str(
619
- args.nougat or isinstance(hf_model, MBartForConditionalGeneration))
620
-
621
- if args.nougat:
622
- # These flags are true for mbart decoders, but missing in HF config
623
- config['decoder']['normalize_before'] = str(True)
624
- config['decoder']['normalize_embeddings'] = str(True)
625
-
626
- config['encoder'] = dict()
627
- # Init few encoder configs, needed by build, from decoder config
628
- encoder_config_keys = [
629
- "encoder_ffn_dim", "encoder_layers", "encoder_attention_heads",
630
- "encoder_layerdrop", "d_model"
631
- ]
632
- for key in encoder_config_keys:
633
- config['encoder'][key] = config['decoder'][key]
634
- else:
635
- config['encoder'] = dict()
636
- for key, val in hf_model.model.encoder.config.to_dict().items():
637
- config["encoder"][key] = f"{val}"
638
- config["encoder"]["q_scaling"] = '1'
639
-
640
- # mBART has final layernorm, BART does not
641
- config['encoder']['has_model_final_layernorm'] = str(
642
- isinstance(hf_model, MBartForConditionalGeneration))
643
-
644
- config["structure"] = dict()
645
- config["structure"]["t5_with_bias"] = "true"
646
- config["structure"]["use_gated_activation"] = "false"
647
- config["structure"]["position_embedding_type"] = "learned_absolute"
648
- config["structure"]["model_type"] = args.model_type
649
-
650
- def parse_bart_config_by_component(config, component, args):
651
- assert component in ('encoder', 'decoder'), 'Unsupported component!'
652
- component_config = types.SimpleNamespace()
653
- component_config = copy_args_to_component_config(component_config, args)
654
- component_config.n_layer = config.getint(component,
655
- f'{component}_layers')
656
- component_config.n_head = config.getint(component,
657
- f'{component}_attention_heads')
658
- component_config.hidden_size = config.getint(component, 'd_model')
659
- component_config.head_size = config.getint(
660
- component,
661
- 'd_kv',
662
- fallback=component_config.hidden_size // component_config.n_head)
663
- component_config.ffn_hidden_size = config.getint(
664
- component, f'{component}_ffn_dim')
665
- component_config.vocab_size = config.getint(component, 'vocab_size')
666
- component_config.n_positions = config.getint(component,
667
- 'max_position_embeddings')
668
- component_config.has_position_embedding = config.getboolean(
669
- component, 'has_position_embedding',
670
- fallback=True) # TODO: hardcoded here
671
- component_config.has_token_type_embedding = config.getboolean(
672
- component, 'has_token_type_embedding', fallback=False)
673
- component_config.has_embedding_layernorm = config.getboolean(
674
- component, 'has_embedding_layernorm', fallback=True)
675
- component_config.has_embedding_scale = config.getboolean(
676
- component, 'scale_embedding')
677
- component_config.q_scaling = config.getfloat(component,
678
- 'q_scaling',
679
- fallback=1.0)
680
- component_config.has_attention_qkvo_bias = config.getboolean(
681
- 'structure', 't5_with_bias', fallback=True)
682
- component_config.has_mlp_bias = config.getboolean('structure',
683
- 't5_with_bias',
684
- fallback=True)
685
- component_config.has_model_final_layernorm = config.getboolean(
686
- component, 'has_model_final_layernorm')
687
- component_config.layernorm_eps = config.getfloat(component,
688
- 'layer_norm_epsilon',
689
- fallback=False)
690
-
691
- normalize_before = config.getboolean(component, 'normalize_before')
692
- component_config.layernorm_position = layernorm_position_map[
693
- 'pre_layernorm' if normalize_before else 'post_layernorm']
694
-
695
- component_config.layernorm_type = layernorm_type_map[config.get(
696
- component, 'layernorm_type', fallback='LayerNorm')]
697
- component_config.hidden_act = config.get(component,
698
- 'activation_function')
699
- component_config.gated_act = config.getboolean(component,
700
- 'is_gated_act',
701
- fallback=False)
702
- component_config.mlp_type = mlp_type_map['GatedMLP' if component_config.
703
- gated_act else 'MLP']
704
- component_config.relative_attention = config.get(
705
- 'structure', 'position_embedding_type') == 'relative'
706
-
707
- component_config.num_buckets = config.getint(
708
- component, 'relative_attention_num_buckets', fallback=0)
709
- component_config.max_distance = config.getint(
710
- component, 'relative_attention_max_distance', fallback=0)
711
- component_config.max_lora_rank = config.getint(component,
712
- 'max_lora_rank',
713
- fallback=0)
714
- component_config.lora_target_modules = literal_eval(
715
- config.get(component, 'lora_target_modules', fallback="[]"))
716
- component_config.hf_modules_to_trtllm_modules = literal_eval(
717
- config.get(component, 'hf_modules_to_trtllm_modules',
718
- fallback="{}"))
719
- component_config.trtllm_modules_to_hf_modules = literal_eval(
720
- config.get(component, 'trtllm_modules_to_hf_modules',
721
- fallback="{}"))
722
- component_config.logits_dtype = config.get(component,
723
- 'logits_dtype',
724
- fallback='float32')
725
- component_config.position_embedding_type = config.get(
726
- 'structure', 'position_embedding_type')
727
-
728
- if component == 'decoder':
729
- component_config.rescale_before_lm_head = config.getboolean(
730
- component, 'rescale_before_lm_head')
731
-
732
- component_config.encoder_hidden_size = config.getint(
733
- 'encoder', 'd_model')
734
- component_config.encoder_num_heads = config.getint(
735
- 'encoder', 'encoder_attention_heads')
736
- component_config.encoder_head_size = config.getint(
737
- 'encoder',
738
- 'd_kv',
739
- fallback=component_config.encoder_hidden_size //
740
- component_config.encoder_num_heads)
741
-
742
- # nougat has decoder_start_token_id = None, special handling
743
- decoder_start_token_id = config.get('decoder',
744
- 'decoder_start_token_id')
745
- component_config.decoder_start_token_id = int(
746
- decoder_start_token_id
747
- ) if decoder_start_token_id != "None" else None
748
- component_config.eos_token_id = config.getint(
749
- 'decoder', 'eos_token_id')
750
- component_config.bos_token_id = config.getint(
751
- 'decoder', 'bos_token_id')
752
- component_config.pad_token_id = config.getint(
753
- 'decoder', 'pad_token_id')
754
-
755
- return component_config
756
-
757
- encoder_config = None
758
- if not args.nougat:
759
- encoder_config = parse_bart_config_by_component(config, "encoder", args)
760
- decoder_config = parse_bart_config_by_component(config, "decoder", args)
761
-
762
- return encoder_config, decoder_config
763
-
764
-
765
- def convert_bart_weights_to_tllm_safetensors(config, component, params):
766
- weights = {}
767
-
768
- mapping = config.mapping
769
-
770
- hidden_size = config.hidden_size
771
-
772
- convert_weight_to_dtype(params, config.dtype)
773
- ffn_hidden_size = config.intermediate_size
774
- vocab_size = config.vocab_size
775
-
776
- hf_param_prefix = f'model.{component}'
777
- trtllm_layer_name = f'{component}_layers'
778
- trtllm_attn_layer_name = 'attention' if component == 'encoder' else 'self_attention'
779
- trtllm_attn_layernorm_name = 'self_attention_layernorm' if component == 'decoder' else 'attention_layernorm'
780
- embedding_layer_names = {
781
- 'embed_tokens.weight': {
782
- "name": 'embedding.vocab_embedding.weight',
783
- "shape": (vocab_size, -1)
784
- },
785
- 'embed_positions.weight': {
786
- "name": 'embedding.position_embedding.weight',
787
- "shape": (config.max_position_embeddings, hidden_size)
788
- },
789
- 'layernorm_embedding.weight': {
790
- "name": 'embedding.embedding_layernorm.weight',
791
- "shape": None
792
- },
793
- 'layernorm_embedding.bias': {
794
- "name": 'embedding.embedding_layernorm.bias',
795
- "shape": None
796
- },
797
- }
798
-
799
- hidden_layer_name_split = {
800
- 'self_attn.out_proj.weight': {
801
- "name": f'{trtllm_attn_layer_name}.dense.weight',
802
- "shape": (hidden_size, hidden_size // mapping.tp_size),
803
- "split_dim": -1
804
- },
805
- 'fc1.weight': {
806
- "name": 'mlp.fc.weight',
807
- "shape": (ffn_hidden_size // mapping.tp_size, hidden_size),
808
- "split_dim": 0
809
- },
810
- 'fc1.bias': {
811
- "name": 'mlp.fc.bias',
812
- "shape": (ffn_hidden_size // mapping.tp_size),
813
- "split_dim": 0
814
- },
815
- 'fc2.weight': {
816
- "name": 'mlp.proj.weight',
817
- "shape": (hidden_size, ffn_hidden_size // mapping.tp_size),
818
- "split_dim": -1
819
- },
820
- }
821
-
822
- hidden_layer_name_no_split = {
823
- 'self_attn.out_proj.bias': {
824
- "name": f'{trtllm_attn_layer_name}.dense.bias',
825
- "shape": (hidden_size)
826
- },
827
- 'self_attn_layer_norm.weight': {
828
- "name": f'{trtllm_attn_layernorm_name}.weight',
829
- "shape": None
830
- },
831
- 'self_attn_layer_norm.bias': {
832
- "name": f'{trtllm_attn_layernorm_name}.bias',
833
- "shape": None
834
- },
835
- 'fc2.bias': {
836
- "name": 'mlp.proj.bias',
837
- "shape": (hidden_size)
838
- },
839
- 'final_layer_norm.weight': {
840
- "name": 'mlp_layernorm.weight',
841
- "shape": None
842
- },
843
- 'final_layer_norm.bias': {
844
- "name": 'mlp_layernorm.bias',
845
- "shape": None
846
- },
847
- }
848
-
849
- if config.model_type == 'mbart':
850
- hidden_layer_name_split['layer_norm.weight'] = {
851
- "name": 'final_layernorm.weight',
852
- "shape": None,
853
- "split_dim": 0
854
- }
855
- hidden_layer_name_no_split['layer_norm.bias'] = {
856
- "name": 'final_layernorm.bias',
857
- "shape": None,
858
- "split_dim": 0
859
- }
860
-
861
- if component == "decoder":
862
- hidden_layer_name_split.update({
863
- 'encoder_attn.out_proj.weight': {
864
- "name": 'cross_attention.dense.weight',
865
- "shape": (hidden_size, hidden_size // mapping.tp_size),
866
- "split_dim": -1
867
- }
868
- })
869
- hidden_layer_name_no_split.update({
870
- 'encoder_attn.out_proj.bias': {
871
- "name": 'cross_attention.dense.bias',
872
- "shape": (hidden_size)
873
- },
874
- 'encoder_attn_layer_norm.weight': {
875
- "name": 'cross_attention_layernorm.weight',
876
- "shape": None
877
- },
878
- 'encoder_attn_layer_norm.bias': {
879
- "name": 'cross_attention_layernorm.bias',
880
- "shape": None
881
- },
882
- })
883
-
884
- def get_attn_module_name(component, layer, attn_type):
885
- return f'model.{component}.layers.{int(layer)}.{attn_type}'
886
-
887
- for hf_weight_name, weight_info in embedding_layer_names.items():
888
- if 'position' in hf_weight_name:
889
- weights[weight_info["name"]] = params[
890
- f'{hf_param_prefix}.{hf_weight_name}'][2:].clone()
891
- else:
892
- weights[weight_info["name"]] = params[
893
- f'{hf_param_prefix}.{hf_weight_name}'].clone()
894
- weights[weight_info["name"]] = reshape(weights[weight_info["name"]],
895
- weight_info["shape"])
896
-
897
- num_layers = config.num_hidden_layers
898
-
899
- layers_range = mapping.pp_layers(num_layers)
900
- for layer_idx in layers_range:
901
- local_layer_idx = layer_idx - layers_range[0]
902
- hf_layer_name_prefix = f'{hf_param_prefix}.layers.{layer_idx}'
903
- trtllm_layer_name_prefix = f'{trtllm_layer_name}.{local_layer_idx}'
904
-
905
- for hf_weight_name, weight_info in hidden_layer_name_split.items():
906
- weights[
907
- f'{trtllm_layer_name_prefix}.{weight_info["name"]}'] = reshape(
908
- split(params[f'{hf_layer_name_prefix}.{hf_weight_name}'],
909
- mapping.tp_size,
910
- mapping.tp_rank,
911
- dim=weight_info["split_dim"]), weight_info["shape"])
912
-
913
- for hf_weight_name, weight_info in hidden_layer_name_no_split.items():
914
- trtllm_layer_fullname = f'{trtllm_layer_name_prefix}.{weight_info["name"]}'
915
- hf_layer_fullname = f'{hf_layer_name_prefix}.{hf_weight_name}'
916
- weights[trtllm_layer_fullname] = reshape(
917
- params[hf_layer_fullname].clone(), shape=weight_info["shape"])
918
-
919
- self_attn_module_name = get_attn_module_name(component, layer_idx,
920
- 'self_attn')
921
- weights.update(
922
- fuse_qkv_one_layer(
923
- params, self_attn_module_name,
924
- f'{trtllm_layer_name_prefix}.{trtllm_attn_layer_name}',
925
- mapping.tp_size, mapping.tp_rank, config.model_type,
926
- (hidden_size * 3 // mapping.tp_size, hidden_size),
927
- (hidden_size * 3 // mapping.tp_size)))
928
- if component == 'decoder':
929
- cross_attn_module_name = get_attn_module_name(
930
- component, layer_idx, 'encoder_attn')
931
- weights.update(
932
- fuse_qkv_one_layer(
933
- params, cross_attn_module_name,
934
- f'{trtllm_layer_name_prefix}.cross_attention',
935
- mapping.tp_size, mapping.tp_rank, config.model_type,
936
- (hidden_size * 3 // mapping.tp_size, hidden_size),
937
- (hidden_size * 3 // mapping.tp_size)))
938
-
939
- if component == 'decoder':
940
- weights['lm_head.weight'] = reshape(
941
- split(params['lm_head.weight'],
942
- mapping.tp_size,
943
- mapping.tp_rank,
944
- dim=0), (config.vocab_size // mapping.tp_size, hidden_size))
945
-
946
- if config.has_model_final_layernorm:
947
- weights['final_layernorm.weight'] = params[
948
- f'{hf_param_prefix}.layer_norm.weight'].clone()
949
- weights['final_layernorm.bias'] = params[
950
- f'{hf_param_prefix}.layer_norm.bias'].clone()
951
-
952
- return weights
953
-
954
-
955
- def parse_pix2struct_config(args, hf_model):
956
- # manually set q_scaling to offset attention scaling's effect.
957
- # TODO: modify kernels to control whether to disable attention scaling
958
- config = configparser.ConfigParser()
959
-
960
- def get_offset_q_scaling(config) -> str:
961
- d_model = config.hidden_size
962
- num_heads = config.num_heads
963
- head_size = d_model / num_heads
964
- scaling = 1 / head_size**.5
965
- return str(scaling)
966
-
967
- config["decoder"] = {}
968
- for key, val in hf_model.decoder.config.to_dict().items():
969
- config["decoder"][key] = f"{val}"
970
-
971
- config["decoder"]["q_scaling"] = get_offset_q_scaling(
972
- hf_model.decoder.config)
973
-
974
- config["structure"] = dict()
975
- config["structure"]["pix2struct_with_bias"] = "false"
976
- config["structure"]["use_gated_activation"] = "false"
977
- config["structure"]["position_embedding_type"] = "relative"
978
- config["structure"]["model_type"] = args.model_type
979
-
980
- def parse_pix2struct_config_by_component(config, component, args):
981
- if component == 'decoder':
982
- args.n_layer = config.getint(component, 'num_layers')
983
- args.n_head = config.getint(component, 'num_heads')
984
- args.head_size = config.getint(component, 'd_kv')
985
- args.hidden_size = config.getint(component, 'hidden_size')
986
- args.ffn_hidden_size = config.getint(component, 'd_ff')
987
- args.vocab_size = config.getint(component, 'vocab_size')
988
- args.n_positions = config.getint(component,
989
- 'n_positions',
990
- fallback=512)
991
- args.has_position_embedding = config.getboolean(
992
- component, 'has_position_embedding',
993
- fallback=False) # TODO: hardcoded here
994
- args.has_token_type_embedding = config.getboolean(
995
- component, 'has_token_type_embedding', fallback=False)
996
- args.has_embedding_layernorm = config.getboolean(
997
- component, 'has_embedding_layernorm', fallback=False)
998
- args.has_embedding_scale = config.getboolean(component,
999
- 'has_embedding_scale',
1000
- fallback=False)
1001
- args.q_scaling = config.getfloat(component,
1002
- 'q_scaling',
1003
- fallback=1.0)
1004
- args.has_attention_qkvo_bias = config.getboolean(
1005
- component, 'has_attention_qkvo_bias', fallback=False)
1006
- args.has_mlp_bias = config.getboolean(component,
1007
- 'has_mlp_bias',
1008
- fallback=False)
1009
- args.has_model_final_layernorm = config.getboolean(
1010
- component, 'has_model_final_layernorm', fallback=True)
1011
- args.layernorm_eps = config.getfloat(component,
1012
- 'layer_norm_epsilon')
1013
- args.layernorm_position = layernorm_position_map[config.get(
1014
- component, 'layernorm_position',
1015
- fallback='pre_layernorm')] # TODO: hardcoded here
1016
- args.layernorm_type = layernorm_type_map[config.get(
1017
- component, 'layernorm_type', fallback='RmsNorm')]
1018
- args.hidden_act = config.get(component, 'dense_act_fn')
1019
- args.gated_act = True
1020
- args.mlp_type = mlp_type_map['GatedMLP' if args.
1021
- gated_act else 'MLP']
1022
- args.has_lm_head_bias = config.getboolean(
1023
- component, # TODO: T5 with bias
1024
- 'has_lm_head_bias',
1025
- fallback=False)
1026
- args.relative_attention = config.getboolean(component,
1027
- 'relative_attention',
1028
- fallback=True)
1029
- args.num_buckets = config.getint(component,
1030
- 'relative_attention_num_buckets')
1031
- args.max_distance = config.getint(
1032
- component, 'relative_attention_max_distance')
1033
- args.logits_dtype = config.get(component,
1034
- 'logits_dtype',
1035
- fallback='float32')
1036
- args.rescale_before_lm_head = config.getboolean(
1037
- component, 'tie_word_embeddings'
1038
- ) # default is True (for T5), but False for Flan-T5
1039
- args.encoder_hidden_size = config.getint('decoder', 'hidden_size')
1040
- args.encoder_num_heads = config.getint('decoder', 'num_heads')
1041
- args.encoder_head_size = config.getint('decoder', 'd_kv')
1042
- args.position_embedding_type = config.get(
1043
- 'structure', 'position_embedding_type')
1044
- args.decoder_start_token_id = config.getint(
1045
- 'decoder', 'decoder_start_token_id')
1046
- args.eos_token_id = config.getint('decoder', 'eos_token_id')
1047
- bos_token_id = config.get('decoder', 'bos_token_id')
1048
- # pix2struct does not have bos_token_id
1049
- args.bos_token_id = int(
1050
- bos_token_id) if bos_token_id != "None" else None
1051
- args.pad_token_id = config.getint('decoder', 'pad_token_id')
1052
-
1053
- else:
1054
- assert False, 'Unsupported component!'
1055
- return args
1056
-
1057
- decoder_args = parse_pix2struct_config_by_component(config, "decoder", args)
1058
- return None, decoder_args
1059
-
1060
-
1061
- def convert_pix2struct_weights_to_tllm_safetensors(config, component, params):
1062
- weights = {}
1063
-
1064
- mapping = config.mapping
1065
-
1066
- convert_weight_to_dtype(params, config.dtype)
1067
- hidden_size = config.hidden_size
1068
- ffn_hidden_size = config.intermediate_size
1069
- num_layers = config.num_hidden_layers
1070
- n_head = config.num_attention_heads
1071
- head_size = config.head_size
1072
- attention_hidden_size = n_head * head_size # head size * num_heads not necessarily equals hidden_dim, such as Flan-T5
1073
-
1074
- hf_param_prefix = f'{component}'
1075
- trtllm_layer_name = f'{component}_layers'
1076
- trtllm_attn_layer_name = 'self_attention'
1077
- trtllm_attn_layernorm_name = 'self_attention_layernorm'
1078
-
1079
- def get_attn_module_name(component, layer, attn_type):
1080
- return f'{component}.layer.{int(layer)}.{attn_type}.attention'
1081
-
1082
- weights['embedding.vocab_embedding.weight'] = reshape(
1083
- params[f'{hf_param_prefix}.embed_tokens.weight'].clone(), None)
1084
-
1085
- layers_range = mapping.pp_layers(num_layers)
1086
- for layer_idx in layers_range:
1087
- local_layer_idx = layer_idx - layers_range[0]
1088
- trtllm_layer_name_prefix = f'{trtllm_layer_name}.{local_layer_idx}'
1089
- hf_layer_name_prefix = f'{hf_param_prefix}.layer.{layer_idx}'
1090
-
1091
- hidden_layer_name_split = {
1092
- f'{hf_layer_name_prefix}.self_attention.attention.output.weight': {
1093
- "name":
1094
- f'{trtllm_layer_name_prefix}.{trtllm_attn_layer_name}.dense.weight',
1095
- "shape":
1096
- (hidden_size, attention_hidden_size // mapping.tp_size),
1097
- "split_dim": -1
1098
- },
1099
- f'{hf_layer_name_prefix}.mlp.DenseReluDense.wo.weight': {
1100
- "name": f'{trtllm_layer_name_prefix}.mlp.proj.weight',
1101
- "shape": (hidden_size, ffn_hidden_size // mapping.tp_size),
1102
- "split_dim": -1
1103
- },
1104
- f'{hf_layer_name_prefix}.mlp.DenseReluDense.wi_0.weight': {
1105
- "name": f'{trtllm_layer_name_prefix}.mlp.fc.weight',
1106
- "shape": (ffn_hidden_size // mapping.tp_size, hidden_size),
1107
- "split_dim": 0
1108
- },
1109
- }
1110
-
1111
- hidden_layer_name_no_split = {
1112
- f'{hf_layer_name_prefix}.self_attention.layer_norm.weight': {
1113
- "name":
1114
- f'{trtllm_layer_name_prefix}.{trtllm_attn_layernorm_name}.weight',
1115
- "shape": None
1116
- },
1117
- f'{hf_layer_name_prefix}.mlp.layer_norm.weight': {
1118
- "name": f'{trtllm_layer_name_prefix}.mlp_layernorm.weight',
1119
- "shape": None
1120
- },
1121
- }
1122
-
1123
- if config.gated_act:
1124
- hidden_layer_name_split.update({
1125
- f'{hf_layer_name_prefix}.mlp.DenseReluDense.wi_1.weight': {
1126
- "name": f'{trtllm_layer_name_prefix}.mlp.gate.weight',
1127
- "shape": (ffn_hidden_size // mapping.tp_size, hidden_size),
1128
- "split_dim": 0
1129
- },
1130
- })
1131
-
1132
- hidden_layer_name_split.update({
1133
- f'{hf_layer_name_prefix}.encoder_decoder_attention.attention.output.weight':
1134
- {
1135
- "name":
1136
- f'{trtllm_layer_name_prefix}.cross_attention.dense.weight',
1137
- "shape":
1138
- (hidden_size, attention_hidden_size // mapping.tp_size),
1139
- "split_dim": -1
1140
- },
1141
- })
1142
- hidden_layer_name_no_split.update({
1143
- f'{hf_layer_name_prefix}.encoder_decoder_attention.layer_norm.weight':
1144
- {
1145
- "name":
1146
- f'{trtllm_layer_name_prefix}.cross_attention_layernorm.weight',
1147
- "shape": None
1148
- },
1149
- })
1150
- self_attn_module_name = get_attn_module_name(
1151
- component, layer_idx, 'encoder_decoder_attention')
1152
- weights.update(
1153
- fuse_qkv_one_layer(
1154
- params, self_attn_module_name,
1155
- f'{trtllm_layer_name_prefix}.cross_attention', mapping.tp_size,
1156
- mapping.tp_rank, config.model_type,
1157
- (attention_hidden_size * 3 // mapping.tp_size, hidden_size),
1158
- None))
1159
-
1160
- self_attn_module_name = get_attn_module_name(component, layer_idx,
1161
- 'self_attention')
1162
- weights.update(
1163
- fuse_qkv_one_layer(
1164
- params, self_attn_module_name,
1165
- f'{trtllm_layer_name_prefix}.{trtllm_attn_layer_name}',
1166
- mapping.tp_size, mapping.tp_rank, config.model_type,
1167
- (attention_hidden_size * 3 // mapping.tp_size, hidden_size),
1168
- None))
1169
-
1170
- weights[
1171
- f'{trtllm_layer_name_prefix}.{trtllm_attn_layer_name}.rel_attn_table'] = reshape(
1172
- split(
1173
- params[
1174
- f'{component}.layer.0.self_attention.attention.relative_attention_bias.weight']
1175
- .T, mapping.tp_size, mapping.tp_rank, 0),
1176
- (n_head // mapping.tp_size, config.num_buckets))
1177
-
1178
- for hf_weight_name, weight_info in hidden_layer_name_split.items():
1179
- if hf_weight_name in params.keys():
1180
- weights[weight_info["name"]] = reshape(
1181
- split(params[hf_weight_name],
1182
- mapping.tp_size,
1183
- mapping.tp_rank,
1184
- dim=weight_info["split_dim"]), weight_info["shape"])
1185
- for hf_weight_name, weight_info in hidden_layer_name_no_split.items():
1186
- if hf_weight_name in params.keys():
1187
- weights[weight_info["name"]] = reshape(
1188
- params[hf_weight_name].clone(), shape=weight_info["shape"])
1189
-
1190
- weights[f'final_layernorm.weight'] = reshape(
1191
- params[f'{component}.final_layer_norm.weight'].clone(), None)
1192
-
1193
- weights['lm_head.weight'] = reshape(
1194
- split(params[f'{component}.lm_head.weight'],
1195
- mapping.tp_size,
1196
- mapping.tp_rank,
1197
- dim=0), (config.vocab_size // mapping.tp_size, hidden_size))
1198
- if not config.use_implicit_relative_attention:
1199
- weights[f'rel_attn_table'] = reshape(
1200
- split(
1201
- params[
1202
- f'{component}.layer.0.self_attention.attention.relative_attention_bias.weight']
1203
- .T, mapping.tp_size, mapping.tp_rank, 0),
1204
- (n_head // mapping.tp_size, config.num_buckets))
1205
-
1206
- return weights
1207
-
1208
-
1209
- def get_model(args):
1210
- if args.model_type == "t5":
1211
- model = T5ForConditionalGeneration.from_pretrained(args.model_dir)
1212
- elif args.model_type == "nmt":
1213
- from fairseq.models.transformer import TransformerModel
1214
- model = TransformerModel.from_pretrained(args.model_dir)
1215
- elif args.model_type == "bart":
1216
- if args.nougat:
1217
- model = VisionEncoderDecoderModel.from_pretrained(args.model_dir)
1218
- model = model.get_decoder()
1219
- else:
1220
- model = AutoModelForSeq2SeqLM.from_pretrained(args.model_dir)
1221
- elif args.model_type == "pix2struct":
1222
- model = Pix2StructForConditionalGeneration.from_pretrained(
1223
- args.model_dir)
1224
- elif args.model_type == "blip2":
1225
- model = Blip2ForConditionalGeneration.from_pretrained(
1226
- args.model_dir).language_model
1227
- return model
1228
-
1229
-
1230
- def convert_checkpoint(args):
1231
-
1232
- model = get_model(args)
1233
-
1234
- saved_dir = Path(args.output_dir)
1235
- saved_dir.mkdir(parents=True, exist_ok=True)
1236
-
1237
- encoder_saved_dir = saved_dir / "encoder"
1238
- encoder_saved_dir.mkdir(parents=True, exist_ok=True)
1239
- decoder_saved_dir = saved_dir / "decoder"
1240
- decoder_saved_dir.mkdir(parents=True, exist_ok=True)
1241
-
1242
- world_size = args.tp_size * args.pp_size
1243
-
1244
- kv_cache_quant_algo = None
1245
- quant_algo = None
1246
-
1247
- model_type = args.model_type if args.model_type != "blip2" else "t5"
1248
- encoder_config, decoder_config = globals()[f'parse_{model_type}_config'](
1249
- args, model)
1250
-
1251
- additional_settings = ["gated_act"]
1252
-
1253
- if not args.nougat and args.model_type != "pix2struct":
1254
- tllm_encoder_config = {
1255
- 'architecture': "EncoderModel",
1256
- 'dtype': args.dtype,
1257
- 'logits_dtype': encoder_config.logits_dtype,
1258
- 'num_hidden_layers': encoder_config.n_layer,
1259
- 'num_attention_heads': encoder_config.n_head,
1260
- 'hidden_size': encoder_config.hidden_size,
1261
- 'norm_epsilon': encoder_config.layernorm_eps,
1262
- 'vocab_size': encoder_config.vocab_size,
1263
- 'position_embedding_type': encoder_config.position_embedding_type,
1264
- 'hidden_act': encoder_config.hidden_act,
1265
- 'quantization': {
1266
- 'quant_algo': quant_algo,
1267
- 'kv_cache_quant_algo': kv_cache_quant_algo,
1268
- },
1269
- 'mapping': {
1270
- 'world_size': world_size,
1271
- 'tp_size': args.tp_size,
1272
- 'pp_size': args.pp_size,
1273
- },
1274
- 'use_parallel_embedding': args.use_parallel_embedding,
1275
- 'embedding_sharding_dim': args.embedding_sharding_dim,
1276
- 'max_position_embeddings': encoder_config.n_positions,
1277
- 'num_key_value_heads': encoder_config.n_head,
1278
- 'head_size': encoder_config.head_size,
1279
- 'has_position_embedding': encoder_config.has_position_embedding,
1280
-         'layernorm_type': encoder_config.layernorm_type,
-         'has_attention_qkvo_bias': encoder_config.has_attention_qkvo_bias,
-         'has_mlp_bias': encoder_config.has_mlp_bias,
-         'has_model_final_layernorm':
-         encoder_config.has_model_final_layernorm,
-         'has_embedding_layernorm': encoder_config.has_embedding_layernorm,
-         'has_embedding_scale': encoder_config.has_embedding_scale,
-         'intermediate_size': encoder_config.ffn_hidden_size,
-         'q_scaling': encoder_config.q_scaling,
-         'layernorm_position': encoder_config.layernorm_position,
-         'mlp_type': encoder_config.mlp_type,
-         'relative_attention': encoder_config.relative_attention,
-         'max_distance': encoder_config.max_distance,
-         'num_buckets': encoder_config.num_buckets,
-         'model_type': encoder_config.model_type,
-     }
-
-     for additional_setting in additional_settings:
-         if hasattr(encoder_config, additional_setting):
-             tllm_encoder_config.update({
-                 additional_setting:
-                 getattr(encoder_config, additional_setting)
-             })
-
-     with (encoder_saved_dir / "config.json").open('w') as f:
-         json.dump(tllm_encoder_config, f, indent=4)
-
-     encoder_convert_args = dict(params=model.state_dict(),
-                                 component="encoder")
-     tllm_decoder_config = {
-         'architecture': "DecoderModel",
-         'dtype': args.dtype,
-         'logits_dtype': decoder_config.logits_dtype,
-         'num_hidden_layers': decoder_config.n_layer,
-         'num_attention_heads': decoder_config.n_head,
-         'hidden_size': decoder_config.hidden_size,
-         'norm_epsilon': decoder_config.layernorm_eps,
-         'vocab_size': decoder_config.vocab_size,
-         'position_embedding_type': decoder_config.position_embedding_type,
-         'hidden_act': decoder_config.hidden_act,
-         'quantization': {
-             'quant_algo': quant_algo,
-             'kv_cache_quant_algo': kv_cache_quant_algo,
-         },
-         'mapping': {
-             'world_size': world_size,
-             'tp_size': args.tp_size,
-             'pp_size': args.pp_size,
-         },
-         'use_parallel_embedding': args.use_parallel_embedding,
-         'embedding_sharding_dim': args.embedding_sharding_dim,
-         'max_position_embeddings': decoder_config.n_positions,
-         'head_size': decoder_config.head_size,
-         'has_position_embedding': decoder_config.has_position_embedding,
-         'layernorm_type': decoder_config.layernorm_type,
-         'has_attention_qkvo_bias': decoder_config.has_attention_qkvo_bias,
-         'has_mlp_bias': decoder_config.has_mlp_bias,
-         'has_model_final_layernorm': decoder_config.has_model_final_layernorm,
-         'has_embedding_layernorm': decoder_config.has_embedding_layernorm,
-         'has_embedding_scale': decoder_config.has_embedding_scale,
-         'intermediate_size': decoder_config.ffn_hidden_size,
-         'q_scaling': decoder_config.q_scaling,
-         'layernorm_position': decoder_config.layernorm_position,
-         'mlp_type': decoder_config.mlp_type,
-         'relative_attention': decoder_config.relative_attention,
-         'max_distance': decoder_config.max_distance,
-         'num_buckets': decoder_config.num_buckets,
-         'model_type': decoder_config.model_type,
-         'rescale_before_lm_head': decoder_config.rescale_before_lm_head,
-         'encoder_hidden_size': decoder_config.encoder_hidden_size,
-         'encoder_num_heads': decoder_config.encoder_num_heads,
-         'encoder_head_size': decoder_config.encoder_head_size,
-         'skip_cross_kv': args.skip_cross_kv,
-         'use_implicit_relative_attention': args.use_implicit_relative_attention,
-         'decoder_start_token_id': decoder_config.decoder_start_token_id,
-         'eos_token_id': decoder_config.eos_token_id,
-         'bos_token_id': decoder_config.bos_token_id,
-         'pad_token_id': decoder_config.pad_token_id,
-     }
-     for additional_setting in additional_settings:
-         if hasattr(decoder_config, additional_setting):
-             tllm_decoder_config.update({
-                 additional_setting:
-                 getattr(decoder_config, additional_setting)
-             })
-
-     with (decoder_saved_dir / "config.json").open('w') as f:
-         json.dump(tllm_decoder_config, f, indent=4)
-
-     decoder_convert_args = dict(params=model.state_dict(), component="decoder")
-
-     if args.model_type == "nmt":
-         fairseq_config = vars(model.cfg.model)  # Namespace --> dict
-         num_embeddings = fairseq_config['max_source_positions']
-         embedding_dim = fairseq_config['encoder_embed_dim']
-         padding_idx = model.models[0].encoder.embed_tokens.padding_idx  # 1
-
-         sin_pos_embedding = model.models[
-             0].encoder.embed_positions.get_embedding(
-                 padding_idx + 1 + num_embeddings,
-                 embedding_dim,
-                 padding_idx=padding_idx)  # [2 + num_embeddings, embed_dim]
-         sin_pos_embedding = sin_pos_embedding[2:, :]  # remove offset embeddings
-
-         encoder_convert_args["sin_pos_embedding"] = sin_pos_embedding
-         decoder_convert_args["sin_pos_embedding"] = sin_pos_embedding
-
-     if args.workers == 1:
-         if not args.nougat and args.model_type != "pix2struct":
-             convert(0, world_size, args, tllm_encoder_config,
-                     encoder_convert_args, encoder_saved_dir)
-         convert(0, world_size, args, tllm_decoder_config, decoder_convert_args,
-                 decoder_saved_dir)
-     else:
-         if args.workers > world_size:
-             args.workers = world_size
-         LOGGER.info(f'Convert checkpoint using {args.workers} workers.')
-         import torch.multiprocessing as mp
-         if not args.nougat and args.model_type != "pix2struct":
-             mp.spawn(convert,
-                      nprocs=args.workers,
-                      args=(world_size, args, tllm_encoder_config,
-                            encoder_convert_args, encoder_saved_dir))
-         mp.spawn(convert,
-                  nprocs=args.workers,
-                  args=(world_size, args, tllm_decoder_config,
-                        decoder_convert_args, decoder_saved_dir))
-
-
- def convert(worker_rank, world_size, args, model_config, convert_args,
-             saved_dir):
-     for rank in range(worker_rank, world_size, args.workers):
-         rank_config = copy.deepcopy(PretrainedConfig.from_dict(model_config))
-         rank_config.set_rank(rank)
-         weights = globals(
-         )[f'convert_{rank_config.model_type}_weights_to_tllm_safetensors'](
-             config=rank_config, **convert_args)
-         safetensors.torch.save_file(weights,
-                                     f'{saved_dir}/rank{rank}.safetensors')
-
-
- if __name__ == "__main__":
-     parser = argparse.ArgumentParser(
-         formatter_class=argparse.RawTextHelpFormatter)
-     parser.add_argument(
-         '--model_type',
-         type=str,
-         default='t5',
-         choices=['t5', 'nmt', 'bart', 'pix2struct', 'blip2'],
-         help=
-         'Multimodal type when this script is used for multimodal conversion.')
-
-     parser.add_argument('--tp_size',
-                         type=int,
-                         default=1,
-                         help='N-way tensor parallelism size')
-     parser.add_argument('--pp_size',
-                         type=int,
-                         default=1,
-                         help='N-way pipeline parallelism size')
-     parser.add_argument("--model_dir",
-                         "-i",
-                         type=str,
-                         help="Path to the framework checkpoint file",
-                         required=True)
-     parser.add_argument("--output_dir",
-                         "-o",
-                         type=str,
-                         help="Path to the converted TRT-LLM model weight file",
-                         required=True)
-     parser.add_argument(
-         "--workers",
-         type=int,
-         help="How many workers to spawn for conversion (default: 4)",
-         default=4)
-     parser.add_argument("--nougat",
-                         action="store_true",
-                         help="Model which uses vision encoder + mbart decoder")
-     parser.add_argument("--verbose",
-                         action="store_true",
-                         help="Provide verbose messages")
-     parser.add_argument(
-         '--use_parallel_embedding',
-         action="store_true",
-         default=False,
-         help=
-         'By default embedding parallelism is disabled. By setting this flag, embedding parallelism is enabled'
-     )
-     parser.add_argument(
-         '--embedding_sharding_dim',
-         type=int,
-         default=0,
-         choices=[0, 1],
-         help=
-         'By default the embedding lookup table is sharded along vocab dimension (embedding_sharding_dim=0). '
-         'To shard it along hidden dimension, set embedding_sharding_dim=1'
-         'Note: embedding sharding is only enabled when embedding_sharding_dim = 0'
-     )
-     parser.add_argument(
-         '--use_weight_only',
-         default=False,
-         action="store_true",
-         help='Quantize weights for the various GEMMs to INT4/INT8.'
-         'See --weight_only_precision to set the precision')
-     parser.add_argument(
-         '--weight_only_precision',
-         const='int8',
-         type=str,
-         nargs='?',
-         default='int8',
-         choices=['int8', 'int4'],
-         help=
-         'Define the precision for the weights when using weight-only quantization.'
-         'You must also use --use_weight_only for that argument to have an impact.'
-     )
-     parser.add_argument(
-         '--dtype',
-         type=str,
-         default='float16',
-         choices=['float16', 'float32', 'bfloat16'],
-         help=
-         'Target inference dtype. Weights and Computation will be in this dtype, no matter what original dtype the weight checkpoint has.'
-     )
-     parser.add_argument(
-         '--skip_cross_kv',
-         action='store_true',
-         help=
-         'Skip redundant cross qkv computation by using TensorRT IfConditional switch (experimental).'
-     )
-     parser.add_argument(
-         '--use_implicit_relative_attention',
-         action='store_true',
-         help=
-         'Compute relative attention bias on the fly instead of pre-compute a relative attention bias table.'
-     )
-     args = parser.parse_args()
-     log_format = "%(asctime)s %(name)s [%(levelname)s] %(message)s"
-     logging.basicConfig(level=logging.DEBUG if args.verbose else logging.INFO,
-                         format=log_format)
-     LOGGER.info("\n=============== Argument ===============")
-     for key in vars(args):
-         LOGGER.info(f"{key}: {vars(args)[key]}")
-     LOGGER.info("========================================")
-
-     start_time = datetime.now()
-     convert_checkpoint(args)
-     stop_time = datetime.now()
-     run_time = (stop_time - start_time)
-     LOGGER.info("Spend {} (h:m:s) to convert the model".format(run_time))
 
deployment/tensorrt_llm/convert/helper.py DELETED
@@ -1,95 +0,0 @@
- # copied from: https://github.com/NVIDIA/TensorRT-LLM/blob/v0.18.1/examples/enc_dec/helper.py
-
- import typing
- from typing import Union
-
- import numpy as np
- import torch  # pytype: disable=import-error
-
- from tensorrt_llm._utils import str_dtype_to_torch
-
-
- def split(v: Union[np.ndarray, torch.Tensor],
-           tp_size: int,
-           tp_rank: int,
-           dim=0):
-     if tp_size == 1:
-         if isinstance(v, np.ndarray):
-             return np.ascontiguousarray(v.copy())
-         else:
-             return v.clone().detach()
-     assert len(v.shape) > 1 or dim == 0
-     if isinstance(v, np.ndarray):
-         return np.ascontiguousarray(
-             np.split(v, tp_size, axis=dim)[tp_rank].copy())
-     else:
-         assert v.shape[dim] % tp_size == 0, \
-             'Unable to split: shape={v.shape} (dim={dim}) tp_size={tp_size}.'
-         split_size = v.shape[dim] // tp_size
-         return v.split(split_size, dim=dim)[tp_rank].clone().detach()
-
-
- def reshape(v: torch.Tensor, shape=None):
-     if shape is None:
-         return v.contiguous()
-     else:
-         return v.reshape(shape).contiguous()
-
-
- def fuse_qkv_one_layer(params, attn_module_name, trtllm_layer_name, tp_size,
-                        tp_rank, model_type, weight_shape, bias_shape):
-
-     qkv_module_names = get_qkv_module_name(model_type)
-
-     weight = {}
-
-     # fuse weights of q, k, v
-     q_w = params[f'{attn_module_name}.{qkv_module_names["q"]}.weight']
-     k_w = params[f'{attn_module_name}.{qkv_module_names["k"]}.weight']
-     v_w = params[f'{attn_module_name}.{qkv_module_names["v"]}.weight']
-
-     # fuse qkv weight
-     shape = q_w.shape  # (do, din)
-     qkv_w = torch.cat([q_w, k_w, v_w],
-                       dim=0).reshape([3, shape[0], shape[1]])  # (3, do, din)
-     qkv_w = split(qkv_w, tp_size, tp_rank, dim=1)
-     weight[f'{trtllm_layer_name}.qkv.weight'] = reshape(qkv_w,
-                                                         shape=weight_shape)
-
-     # fuse qkv biases if present
-     if f'{attn_module_name}.{qkv_module_names["q"]}.bias' in params.keys(
-     ) and params[f'{attn_module_name}.{qkv_module_names["q"]}.bias'] is not None:
-         q_b = params[f'{attn_module_name}.{qkv_module_names["q"]}.bias']
-         k_b = params[f'{attn_module_name}.{qkv_module_names["k"]}.bias']
-         v_b = params[f'{attn_module_name}.{qkv_module_names["v"]}.bias']
-         shape = q_b.shape[0]  # (do,)
-         qkv_b = torch.cat([q_b, k_b, v_b], dim=0).reshape([3, shape])  # (3, do)
-         qkv_b = split(qkv_b, tp_size, tp_rank, dim=1)
-         weight[f'{trtllm_layer_name}.qkv.bias'] = reshape(qkv_b,
-                                                           shape=bias_shape)
-     return weight
-
-
- def get_qkv_module_name(model_type):
-     if model_type in ["t5", "blip2"]:
-         q = "q"
-         k = "k"
-         v = "v"
-     elif model_type == "bart" or model_type == "nmt":
-         q = "q_proj"
-         k = "k_proj"
-         v = "v_proj"
-     elif model_type == "pix2struct":
-         q = "query"
-         k = "key"
-         v = "value"
-     return {"q": q, "k": k, "v": v}
-
-
- def convert_weight_to_dtype(params: typing.Dict[str, torch.Tensor],
-                             dtype: typing.Optional[np.dtype] = None):
-     if dtype is not None:
-         assert isinstance(dtype,
-                           str), f"dtype must be str, but get type {type(dtype)}"
-         for name in params.keys():
-             params[name] = params[name].to(str_dtype_to_torch(dtype))
 
deployment/tensorrt_llm/convert_dolphin.sh DELETED
@@ -1,47 +0,0 @@
- #!/usr/bin/env bash
- set -ex
-
- ############################################################################################
- # Reference: https://github.com/NVIDIA/TensorRT-LLM/tree/v0.18.2/examples/multimodal#nougat
- ############################################################################################
-
- export LD_LIBRARY_PATH=/usr/local/lib/python3.10/site-packages/tensorrt_libs/:/usr/local/lib/python3.10/site-packages/nvidia/cudnn/lib/:$LD_LIBRARY_PATH
-
- # 1. Download Huggingface weights
- export MODEL_NAME="Dolphin"
- git clone https://huggingface.co/Bytedance/${MODEL_NAME} tmp/hf_models/${MODEL_NAME}
-
-
- export MAX_BATCH_SIZE=16
- export MAX_SEQ_LEN=4096
- export MAX_INPUT_LEN=10
- export MAX_ENCODER_INPUT_LEN=784
-
- # 2. Convert Huggingface weights into TRT-LLM checkpoints and build TRT engines using scripts in examples/enc_dec
- python ./convert/convert_checkpoint.py --model_type bart \
-     --model_dir tmp/hf_models/${MODEL_NAME} \
-     --output_dir tmp/trt_models/${MODEL_NAME}/bfloat16 \
-     --tp_size 1 \
-     --pp_size 1 \
-     --dtype bfloat16 \
-     --nougat
-
-
- trtllm-build --checkpoint_dir tmp/trt_models/${MODEL_NAME}/bfloat16/decoder \
-     --output_dir tmp/trt_engines/${MODEL_NAME}/1-gpu/bfloat16/decoder \
-     --paged_kv_cache disable \
-     --moe_plugin disable \
-     --gemm_plugin bfloat16 \
-     --bert_attention_plugin bfloat16 \
-     --gpt_attention_plugin bfloat16 \
-     --remove_input_padding enable \
-     --max_beam_width 1 \
-     --max_batch_size ${MAX_BATCH_SIZE} \
-     --max_seq_len ${MAX_SEQ_LEN} \
-     --max_input_len ${MAX_INPUT_LEN} \
-     --max_encoder_input_len $((${MAX_BATCH_SIZE} * ${MAX_ENCODER_INPUT_LEN}))  # MAX_BATCH_SIZE (max_batch_size) * MAX_ENCODER_INPUT_LEN (num_visual_features)
-
- # 3. Generate TensorRT engines for visual components and combine everything into final pipeline.
- python ./convert/build_visual_engine.py --model_type nougat \
-     --model_path tmp/hf_models/${MODEL_NAME} \
-     --max_batch_size ${MAX_BATCH_SIZE}
 
deployment/tensorrt_llm/dolphin_runner.py DELETED
@@ -1,220 +0,0 @@
- """
- Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
- SPDX-License-Identifier: MIT
- """
-
- import json
- import os
- from typing import Optional
-
- import tensorrt_llm
- import tensorrt_llm.profiler as profiler
- import torch
- from PIL import Image
- from pydantic import BaseModel, Field
- from tensorrt_llm import logger
- from tensorrt_llm import mpi_rank
- from tensorrt_llm.runtime import MultimodalModelRunner
- from transformers import AutoTokenizer, DonutProcessor
-
-
- class InferenceConfig(BaseModel):
-     max_new_tokens: int = Field(128, description="Maximum new tokens to generate")
-     batch_size: int = Field(1, description="Batch size for inference")
-     log_level: str = Field("info", description="Logging level")
-     visual_engine_dir: Optional[str] = Field(None, description="Directory for visual engine files")
-     visual_engine_name: str = Field("model.engine", description="Visual engine filename")
-     llm_engine_dir: Optional[str] = Field(None, description="Directory for LLM engine files")
-     hf_model_dir: Optional[str] = Field(None, description="Hugging Face model directory")
-     input_text: Optional[str] = Field(None, description="Input text for inference")
-     num_beams: int = Field(1, description="Number of beams for beam search")
-     top_k: int = Field(1, description="Top-k sampling value")
-     top_p: float = Field(0.0, description="Top-p (nucleus) sampling value")
-     temperature: float = Field(1.0, description="Sampling temperature")
-     repetition_penalty: float = Field(1.0, description="Repetition penalty factor")
-     run_profiling: bool = Field(False, description="Enable profiling mode")
-     profiling_iterations: int = Field(20, description="Number of profiling iterations")
-     check_accuracy: bool = Field(False, description="Enable accuracy checking")
-     video_path: Optional[str] = Field(None, description="Path to input video file")
-     video_num_frames: Optional[int] = Field(None, description="Number of video frames to process")
-     image_path: Optional[str] = Field(None, description="Path to input image file")
-     path_sep: str = Field(",", description="Path separator character")
-     prompt_sep: str = Field(",", description="Prompt separator character")
-     enable_context_fmha_fp32_acc: Optional[bool] = Field(
-         None,
-         description="Enable FP32 accumulation for context FMHA"
-     )
-     enable_chunked_context: bool = Field(False, description="Enable chunked context processing")
-     use_py_session: bool = Field(False, description="Use Python session instead of C++")
-     kv_cache_free_gpu_memory_fraction: float = Field(
-         0.9,
-         description="Fraction of GPU memory free for KV cache",
-         ge=0.0, le=1.0
-     )
-     cross_kv_cache_fraction: float = Field(
-         0.5,
-         description="Fraction of cross-attention KV cache",
-         ge=0.0, le=1.0
-     )
-     multi_block_mode: bool = Field(True, description="Enable multi-block processing mode")
-
-
- class DolphinRunner(MultimodalModelRunner):
-     def __init__(self, args):
-         self.args = args
-
-         self.runtime_rank = mpi_rank()
-         device_id = self.runtime_rank % torch.cuda.device_count()
-         torch.cuda.set_device(device_id)
-         self.device = "cuda:%d" % (device_id)
-
-         self.stream = torch.cuda.Stream(torch.cuda.current_device())
-         torch.cuda.set_stream(self.stream)
-
-         # parse model type from visual engine config
-         with open(os.path.join(self.args.visual_engine_dir, "config.json"),
-                   "r") as f:
-             config = json.load(f)
-         self.model_type = config['builder_config']['model_type']
-         self.vision_precision = config['builder_config']['precision']
-         self.decoder_llm = not (
-             't5' in self.model_type
-             or self.model_type in ['nougat', 'pix2struct']
-         )  # BLIP2-T5, pix2struct and Nougat are using encoder-decoder models as LLMs
-
-         if self.model_type == "mllama":
-             self.vision_input_names = [
-                 "pixel_values",
-                 "aspect_ratio_ids",
-                 "aspect_ratio_mask",
-             ]
-             self.vision_output_names = [
-                 "output",
-             ]
-         else:
-             self.vision_input_names = ["input"]
-             self.vision_output_names = ["output"]
-
-         self.use_py_session = True
-
-         self.init_image_encoder()
-         self.init_tokenizer()
-         self.init_processor()
-         self.init_llm()
-
-     def init_tokenizer(self):
-         assert self.model_type == 'nougat'
-         self.tokenizer = AutoTokenizer.from_pretrained(self.args.hf_model_dir)
-         self.tokenizer.padding_side = "right"
-
-     def init_processor(self):
-         assert self.model_type == 'nougat'
-         self.processor = DonutProcessor.from_pretrained(self.args.hf_model_dir, use_fast=True)
-
-     def run(self, input_texts, input_images, max_new_tokens):
-         prompts = [f"<s>{text.strip()} <Answer/>" for text in input_texts]
-         images = self.processor(input_images, return_tensors="pt")['pixel_values'].to("cuda")
-         prompt_ids = self.tokenizer(prompts, add_special_tokens=False, return_tensors="pt").input_ids.to("cuda")
-
-         # 🚨🚨🚨 Important! If the type of prompt_ids is not int32, the output will be wrong. 🚨🚨🚨
-         prompt_ids = prompt_ids.to(torch.int32)
-
-         logger.info("---------------------------------------------------------")
-         logger.info(f"images size: {images.size()}")
-         logger.info(f"prompt_ids: {prompt_ids}, size: {prompt_ids.size()}, dtype: {prompt_ids.dtype}")
-         logger.info("---------------------------------------------------------")
-
-         output_texts = self.generate(input_texts,
-                                      [None] * len(input_texts),
-                                      images,
-                                      prompt_ids,
-                                      max_new_tokens,
-                                      warmup=False,
-                                      )
-
-         return output_texts
-
-     def generate(self,
-                  pre_prompt,
-                  post_prompt,
-                  image,
-                  decoder_input_ids,
-                  max_new_tokens,
-                  warmup=False,
-                  other_vision_inputs={},
-                  other_decoder_inputs={}):
-         if not warmup:
-             profiler.start("Generate")
-         input_ids, input_lengths, ptuning_args, visual_features = self.preprocess(
-             warmup, pre_prompt, post_prompt, image, other_vision_inputs)
-
-         if warmup: return None
-
-         # use prompt tuning to pass multimodal features
-         # model.generate() expects the following params (see layers/embedding.py):
-         # args[0]: prompt embedding table, [batch_size, multimodal_len, hidden_size], later flattened to [batch_size * multimodal_len, hidden_size]
-         # args[1]: prompt task ids, [batch_size]. in multimodal case, arange(batch_size), i.e. in VILA batching mode 2, each image is treated separately in the batch instead of concated together (although the prompt embedding table has to be concated)
-         # args[2]: prompt task vocab size, [1]. assuming all table has the same length, which in multimodal case equals to multimodal_len
-         profiler.start("LLM")
-         if self.model_type in ['nougat', 'pix2struct']:
-             # Trim encoder input_ids to match visual features shape
-             ids_shape = (min(self.args.batch_size, len(pre_prompt)), visual_features.shape[1])
-             if self.model_type == 'nougat':
-                 input_ids = torch.zeros(ids_shape, dtype=torch.int32)
-             elif self.model_type == 'pix2struct':
-                 input_ids = torch.ones(ids_shape, dtype=torch.int32)
-
-             output_ids = self.model.generate(
-                 input_ids,
-                 decoder_input_ids,
-                 max_new_tokens,
-                 num_beams=self.args.num_beams,
-                 bos_token_id=self.tokenizer.bos_token_id,
-                 pad_token_id=self.tokenizer.pad_token_id,
-                 eos_token_id=self.tokenizer.eos_token_id,
-                 debug_mode=False,
-                 prompt_embedding_table=ptuning_args[0],
-                 prompt_tasks=ptuning_args[1],
-                 prompt_vocab_size=ptuning_args[2],
-             )
-         profiler.stop("LLM")
-
-         if mpi_rank() == 0:
-             # Extract a list of tensors of shape beam_width x output_ids.
-             output_beams_list = [
-                 self.tokenizer.batch_decode(
-                     output_ids[batch_idx, :, decoder_input_ids.shape[1]:],
-                     skip_special_tokens=False) for batch_idx in range(
-                         min(self.args.batch_size, decoder_input_ids.shape[0]))
-             ]
-
-             stripped_text = [[
-                 output_beams_list[batch_idx][beam_idx].replace("</s>", "").replace("<pad>", "").strip()
-                 for beam_idx in range(self.args.num_beams)
-             ] for batch_idx in range(
-                 min(self.args.batch_size, decoder_input_ids.shape[0]))]
-             profiler.stop("Generate")
-             return stripped_text
-         else:
-             profiler.stop("Generate")
-             return None
-
-
- if __name__ == "__main__":
-     config = InferenceConfig(
-         max_new_tokens=4024,
-         batch_size=16,
-         log_level="info",
-         hf_model_dir=f"./tmp/hf_models/Dolphin",
-         visual_engine_dir=f"./tmp/trt_engines/Dolphin/vision_encoder",
-         llm_engine_dir=f"./tmp/trt_engines/Dolphin/1-gpu/bfloat16",
-     )
-
-     model = DolphinRunner(config)
-
-     image_path = "../../demo/page_imgs/page_1.jpeg"
-     prompt = "Parse the reading order of this document."
-     image = Image.open(image_path).convert("RGB")
-     output_texts = model.run([prompt], [image], 4024)
-     output_texts = [texts[0] for texts in output_texts]
-     print(output_texts)