yxchng commited on
Commit
c2a24ff
1 Parent(s): 729db41

Add application file

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. elia/.gitattributes +34 -0
  2. elia/LICENSE +674 -0
  3. elia/README.md +222 -0
  4. elia/__pycache__/args.cpython-37.pyc +0 -0
  5. elia/__pycache__/args.cpython-38.pyc +0 -0
  6. elia/__pycache__/transforms.cpython-37.pyc +0 -0
  7. elia/__pycache__/transforms.cpython-38.pyc +0 -0
  8. elia/__pycache__/utils.cpython-37.pyc +0 -0
  9. elia/__pycache__/utils.cpython-38.pyc +0 -0
  10. elia/app.py +310 -0
  11. elia/args.py +74 -0
  12. elia/bert/__pycache__/activations.cpython-37.pyc +0 -0
  13. elia/bert/__pycache__/activations.cpython-38.pyc +0 -0
  14. elia/bert/__pycache__/configuration_bert.cpython-37.pyc +0 -0
  15. elia/bert/__pycache__/configuration_bert.cpython-38.pyc +0 -0
  16. elia/bert/__pycache__/configuration_utils.cpython-37.pyc +0 -0
  17. elia/bert/__pycache__/configuration_utils.cpython-38.pyc +0 -0
  18. elia/bert/__pycache__/file_utils.cpython-37.pyc +0 -0
  19. elia/bert/__pycache__/file_utils.cpython-38.pyc +0 -0
  20. elia/bert/__pycache__/generation_utils.cpython-37.pyc +0 -0
  21. elia/bert/__pycache__/generation_utils.cpython-38.pyc +0 -0
  22. elia/bert/__pycache__/modeling_bert.cpython-37.pyc +0 -0
  23. elia/bert/__pycache__/modeling_bert.cpython-38.pyc +0 -0
  24. elia/bert/__pycache__/modeling_utils.cpython-37.pyc +0 -0
  25. elia/bert/__pycache__/modeling_utils.cpython-38.pyc +0 -0
  26. elia/bert/__pycache__/multimodal_bert.cpython-37.pyc +0 -0
  27. elia/bert/__pycache__/multimodal_bert.cpython-38.pyc +0 -0
  28. elia/bert/__pycache__/tokenization_bert.cpython-37.pyc +0 -0
  29. elia/bert/__pycache__/tokenization_bert.cpython-38.pyc +0 -0
  30. elia/bert/__pycache__/tokenization_utils.cpython-37.pyc +0 -0
  31. elia/bert/__pycache__/tokenization_utils.cpython-38.pyc +0 -0
  32. elia/bert/__pycache__/tokenization_utils_base.cpython-37.pyc +0 -0
  33. elia/bert/__pycache__/tokenization_utils_base.cpython-38.pyc +0 -0
  34. elia/bert/activations.py +56 -0
  35. elia/bert/configuration_bert.py +143 -0
  36. elia/bert/configuration_utils.py +408 -0
  37. elia/bert/file_utils.py +808 -0
  38. elia/bert/generation_utils.py +993 -0
  39. elia/bert/modeling_bert.py +1569 -0
  40. elia/bert/modeling_utils.py +1268 -0
  41. elia/bert/multimodal_bert.py +277 -0
  42. elia/bert/tokenization_bert.py +545 -0
  43. elia/bert/tokenization_utils.py +723 -0
  44. elia/bert/tokenization_utils_base.py +0 -0
  45. elia/demo_inference.py +295 -0
  46. elia/requirements.txt +14 -0
  47. elia/test_elia.py +312 -0
  48. elia/test_lavt.py +139 -0
  49. elia/train_elia.py +812 -0
  50. elia/train_lavt.py +444 -0
elia/.gitattributes ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ *.7z filter=lfs diff=lfs merge=lfs -text
2
+ *.arrow filter=lfs diff=lfs merge=lfs -text
3
+ *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
5
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
6
+ *.ftz filter=lfs diff=lfs merge=lfs -text
7
+ *.gz filter=lfs diff=lfs merge=lfs -text
8
+ *.h5 filter=lfs diff=lfs merge=lfs -text
9
+ *.joblib filter=lfs diff=lfs merge=lfs -text
10
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
+ *.model filter=lfs diff=lfs merge=lfs -text
13
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
14
+ *.npy filter=lfs diff=lfs merge=lfs -text
15
+ *.npz filter=lfs diff=lfs merge=lfs -text
16
+ *.onnx filter=lfs diff=lfs merge=lfs -text
17
+ *.ot filter=lfs diff=lfs merge=lfs -text
18
+ *.parquet filter=lfs diff=lfs merge=lfs -text
19
+ *.pb filter=lfs diff=lfs merge=lfs -text
20
+ *.pickle filter=lfs diff=lfs merge=lfs -text
21
+ *.pkl filter=lfs diff=lfs merge=lfs -text
22
+ *.pt filter=lfs diff=lfs merge=lfs -text
23
+ *.pth filter=lfs diff=lfs merge=lfs -text
24
+ *.rar filter=lfs diff=lfs merge=lfs -text
25
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
26
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
28
+ *.tflite filter=lfs diff=lfs merge=lfs -text
29
+ *.tgz filter=lfs diff=lfs merge=lfs -text
30
+ *.wasm filter=lfs diff=lfs merge=lfs -text
31
+ *.xz filter=lfs diff=lfs merge=lfs -text
32
+ *.zip filter=lfs diff=lfs merge=lfs -text
33
+ *.zst filter=lfs diff=lfs merge=lfs -text
34
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
elia/LICENSE ADDED
@@ -0,0 +1,674 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ GNU GENERAL PUBLIC LICENSE
2
+ Version 3, 29 June 2007
3
+
4
+ Copyright (C) 2007 Free Software Foundation, Inc. <https://fsf.org/>
5
+ Everyone is permitted to copy and distribute verbatim copies
6
+ of this license document, but changing it is not allowed.
7
+
8
+ Preamble
9
+
10
+ The GNU General Public License is a free, copyleft license for
11
+ software and other kinds of works.
12
+
13
+ The licenses for most software and other practical works are designed
14
+ to take away your freedom to share and change the works. By contrast,
15
+ the GNU General Public License is intended to guarantee your freedom to
16
+ share and change all versions of a program--to make sure it remains free
17
+ software for all its users. We, the Free Software Foundation, use the
18
+ GNU General Public License for most of our software; it applies also to
19
+ any other work released this way by its authors. You can apply it to
20
+ your programs, too.
21
+
22
+ When we speak of free software, we are referring to freedom, not
23
+ price. Our General Public Licenses are designed to make sure that you
24
+ have the freedom to distribute copies of free software (and charge for
25
+ them if you wish), that you receive source code or can get it if you
26
+ want it, that you can change the software or use pieces of it in new
27
+ free programs, and that you know you can do these things.
28
+
29
+ To protect your rights, we need to prevent others from denying you
30
+ these rights or asking you to surrender the rights. Therefore, you have
31
+ certain responsibilities if you distribute copies of the software, or if
32
+ you modify it: responsibilities to respect the freedom of others.
33
+
34
+ For example, if you distribute copies of such a program, whether
35
+ gratis or for a fee, you must pass on to the recipients the same
36
+ freedoms that you received. You must make sure that they, too, receive
37
+ or can get the source code. And you must show them these terms so they
38
+ know their rights.
39
+
40
+ Developers that use the GNU GPL protect your rights with two steps:
41
+ (1) assert copyright on the software, and (2) offer you this License
42
+ giving you legal permission to copy, distribute and/or modify it.
43
+
44
+ For the developers' and authors' protection, the GPL clearly explains
45
+ that there is no warranty for this free software. For both users' and
46
+ authors' sake, the GPL requires that modified versions be marked as
47
+ changed, so that their problems will not be attributed erroneously to
48
+ authors of previous versions.
49
+
50
+ Some devices are designed to deny users access to install or run
51
+ modified versions of the software inside them, although the manufacturer
52
+ can do so. This is fundamentally incompatible with the aim of
53
+ protecting users' freedom to change the software. The systematic
54
+ pattern of such abuse occurs in the area of products for individuals to
55
+ use, which is precisely where it is most unacceptable. Therefore, we
56
+ have designed this version of the GPL to prohibit the practice for those
57
+ products. If such problems arise substantially in other domains, we
58
+ stand ready to extend this provision to those domains in future versions
59
+ of the GPL, as needed to protect the freedom of users.
60
+
61
+ Finally, every program is threatened constantly by software patents.
62
+ States should not allow patents to restrict development and use of
63
+ software on general-purpose computers, but in those that do, we wish to
64
+ avoid the special danger that patents applied to a free program could
65
+ make it effectively proprietary. To prevent this, the GPL assures that
66
+ patents cannot be used to render the program non-free.
67
+
68
+ The precise terms and conditions for copying, distribution and
69
+ modification follow.
70
+
71
+ TERMS AND CONDITIONS
72
+
73
+ 0. Definitions.
74
+
75
+ "This License" refers to version 3 of the GNU General Public License.
76
+
77
+ "Copyright" also means copyright-like laws that apply to other kinds of
78
+ works, such as semiconductor masks.
79
+
80
+ "The Program" refers to any copyrightable work licensed under this
81
+ License. Each licensee is addressed as "you". "Licensees" and
82
+ "recipients" may be individuals or organizations.
83
+
84
+ To "modify" a work means to copy from or adapt all or part of the work
85
+ in a fashion requiring copyright permission, other than the making of an
86
+ exact copy. The resulting work is called a "modified version" of the
87
+ earlier work or a work "based on" the earlier work.
88
+
89
+ A "covered work" means either the unmodified Program or a work based
90
+ on the Program.
91
+
92
+ To "propagate" a work means to do anything with it that, without
93
+ permission, would make you directly or secondarily liable for
94
+ infringement under applicable copyright law, except executing it on a
95
+ computer or modifying a private copy. Propagation includes copying,
96
+ distribution (with or without modification), making available to the
97
+ public, and in some countries other activities as well.
98
+
99
+ To "convey" a work means any kind of propagation that enables other
100
+ parties to make or receive copies. Mere interaction with a user through
101
+ a computer network, with no transfer of a copy, is not conveying.
102
+
103
+ An interactive user interface displays "Appropriate Legal Notices"
104
+ to the extent that it includes a convenient and prominently visible
105
+ feature that (1) displays an appropriate copyright notice, and (2)
106
+ tells the user that there is no warranty for the work (except to the
107
+ extent that warranties are provided), that licensees may convey the
108
+ work under this License, and how to view a copy of this License. If
109
+ the interface presents a list of user commands or options, such as a
110
+ menu, a prominent item in the list meets this criterion.
111
+
112
+ 1. Source Code.
113
+
114
+ The "source code" for a work means the preferred form of the work
115
+ for making modifications to it. "Object code" means any non-source
116
+ form of a work.
117
+
118
+ A "Standard Interface" means an interface that either is an official
119
+ standard defined by a recognized standards body, or, in the case of
120
+ interfaces specified for a particular programming language, one that
121
+ is widely used among developers working in that language.
122
+
123
+ The "System Libraries" of an executable work include anything, other
124
+ than the work as a whole, that (a) is included in the normal form of
125
+ packaging a Major Component, but which is not part of that Major
126
+ Component, and (b) serves only to enable use of the work with that
127
+ Major Component, or to implement a Standard Interface for which an
128
+ implementation is available to the public in source code form. A
129
+ "Major Component", in this context, means a major essential component
130
+ (kernel, window system, and so on) of the specific operating system
131
+ (if any) on which the executable work runs, or a compiler used to
132
+ produce the work, or an object code interpreter used to run it.
133
+
134
+ The "Corresponding Source" for a work in object code form means all
135
+ the source code needed to generate, install, and (for an executable
136
+ work) run the object code and to modify the work, including scripts to
137
+ control those activities. However, it does not include the work's
138
+ System Libraries, or general-purpose tools or generally available free
139
+ programs which are used unmodified in performing those activities but
140
+ which are not part of the work. For example, Corresponding Source
141
+ includes interface definition files associated with source files for
142
+ the work, and the source code for shared libraries and dynamically
143
+ linked subprograms that the work is specifically designed to require,
144
+ such as by intimate data communication or control flow between those
145
+ subprograms and other parts of the work.
146
+
147
+ The Corresponding Source need not include anything that users
148
+ can regenerate automatically from other parts of the Corresponding
149
+ Source.
150
+
151
+ The Corresponding Source for a work in source code form is that
152
+ same work.
153
+
154
+ 2. Basic Permissions.
155
+
156
+ All rights granted under this License are granted for the term of
157
+ copyright on the Program, and are irrevocable provided the stated
158
+ conditions are met. This License explicitly affirms your unlimited
159
+ permission to run the unmodified Program. The output from running a
160
+ covered work is covered by this License only if the output, given its
161
+ content, constitutes a covered work. This License acknowledges your
162
+ rights of fair use or other equivalent, as provided by copyright law.
163
+
164
+ You may make, run and propagate covered works that you do not
165
+ convey, without conditions so long as your license otherwise remains
166
+ in force. You may convey covered works to others for the sole purpose
167
+ of having them make modifications exclusively for you, or provide you
168
+ with facilities for running those works, provided that you comply with
169
+ the terms of this License in conveying all material for which you do
170
+ not control copyright. Those thus making or running the covered works
171
+ for you must do so exclusively on your behalf, under your direction
172
+ and control, on terms that prohibit them from making any copies of
173
+ your copyrighted material outside their relationship with you.
174
+
175
+ Conveying under any other circumstances is permitted solely under
176
+ the conditions stated below. Sublicensing is not allowed; section 10
177
+ makes it unnecessary.
178
+
179
+ 3. Protecting Users' Legal Rights From Anti-Circumvention Law.
180
+
181
+ No covered work shall be deemed part of an effective technological
182
+ measure under any applicable law fulfilling obligations under article
183
+ 11 of the WIPO copyright treaty adopted on 20 December 1996, or
184
+ similar laws prohibiting or restricting circumvention of such
185
+ measures.
186
+
187
+ When you convey a covered work, you waive any legal power to forbid
188
+ circumvention of technological measures to the extent such circumvention
189
+ is effected by exercising rights under this License with respect to
190
+ the covered work, and you disclaim any intention to limit operation or
191
+ modification of the work as a means of enforcing, against the work's
192
+ users, your or third parties' legal rights to forbid circumvention of
193
+ technological measures.
194
+
195
+ 4. Conveying Verbatim Copies.
196
+
197
+ You may convey verbatim copies of the Program's source code as you
198
+ receive it, in any medium, provided that you conspicuously and
199
+ appropriately publish on each copy an appropriate copyright notice;
200
+ keep intact all notices stating that this License and any
201
+ non-permissive terms added in accord with section 7 apply to the code;
202
+ keep intact all notices of the absence of any warranty; and give all
203
+ recipients a copy of this License along with the Program.
204
+
205
+ You may charge any price or no price for each copy that you convey,
206
+ and you may offer support or warranty protection for a fee.
207
+
208
+ 5. Conveying Modified Source Versions.
209
+
210
+ You may convey a work based on the Program, or the modifications to
211
+ produce it from the Program, in the form of source code under the
212
+ terms of section 4, provided that you also meet all of these conditions:
213
+
214
+ a) The work must carry prominent notices stating that you modified
215
+ it, and giving a relevant date.
216
+
217
+ b) The work must carry prominent notices stating that it is
218
+ released under this License and any conditions added under section
219
+ 7. This requirement modifies the requirement in section 4 to
220
+ "keep intact all notices".
221
+
222
+ c) You must license the entire work, as a whole, under this
223
+ License to anyone who comes into possession of a copy. This
224
+ License will therefore apply, along with any applicable section 7
225
+ additional terms, to the whole of the work, and all its parts,
226
+ regardless of how they are packaged. This License gives no
227
+ permission to license the work in any other way, but it does not
228
+ invalidate such permission if you have separately received it.
229
+
230
+ d) If the work has interactive user interfaces, each must display
231
+ Appropriate Legal Notices; however, if the Program has interactive
232
+ interfaces that do not display Appropriate Legal Notices, your
233
+ work need not make them do so.
234
+
235
+ A compilation of a covered work with other separate and independent
236
+ works, which are not by their nature extensions of the covered work,
237
+ and which are not combined with it such as to form a larger program,
238
+ in or on a volume of a storage or distribution medium, is called an
239
+ "aggregate" if the compilation and its resulting copyright are not
240
+ used to limit the access or legal rights of the compilation's users
241
+ beyond what the individual works permit. Inclusion of a covered work
242
+ in an aggregate does not cause this License to apply to the other
243
+ parts of the aggregate.
244
+
245
+ 6. Conveying Non-Source Forms.
246
+
247
+ You may convey a covered work in object code form under the terms
248
+ of sections 4 and 5, provided that you also convey the
249
+ machine-readable Corresponding Source under the terms of this License,
250
+ in one of these ways:
251
+
252
+ a) Convey the object code in, or embodied in, a physical product
253
+ (including a physical distribution medium), accompanied by the
254
+ Corresponding Source fixed on a durable physical medium
255
+ customarily used for software interchange.
256
+
257
+ b) Convey the object code in, or embodied in, a physical product
258
+ (including a physical distribution medium), accompanied by a
259
+ written offer, valid for at least three years and valid for as
260
+ long as you offer spare parts or customer support for that product
261
+ model, to give anyone who possesses the object code either (1) a
262
+ copy of the Corresponding Source for all the software in the
263
+ product that is covered by this License, on a durable physical
264
+ medium customarily used for software interchange, for a price no
265
+ more than your reasonable cost of physically performing this
266
+ conveying of source, or (2) access to copy the
267
+ Corresponding Source from a network server at no charge.
268
+
269
+ c) Convey individual copies of the object code with a copy of the
270
+ written offer to provide the Corresponding Source. This
271
+ alternative is allowed only occasionally and noncommercially, and
272
+ only if you received the object code with such an offer, in accord
273
+ with subsection 6b.
274
+
275
+ d) Convey the object code by offering access from a designated
276
+ place (gratis or for a charge), and offer equivalent access to the
277
+ Corresponding Source in the same way through the same place at no
278
+ further charge. You need not require recipients to copy the
279
+ Corresponding Source along with the object code. If the place to
280
+ copy the object code is a network server, the Corresponding Source
281
+ may be on a different server (operated by you or a third party)
282
+ that supports equivalent copying facilities, provided you maintain
283
+ clear directions next to the object code saying where to find the
284
+ Corresponding Source. Regardless of what server hosts the
285
+ Corresponding Source, you remain obligated to ensure that it is
286
+ available for as long as needed to satisfy these requirements.
287
+
288
+ e) Convey the object code using peer-to-peer transmission, provided
289
+ you inform other peers where the object code and Corresponding
290
+ Source of the work are being offered to the general public at no
291
+ charge under subsection 6d.
292
+
293
+ A separable portion of the object code, whose source code is excluded
294
+ from the Corresponding Source as a System Library, need not be
295
+ included in conveying the object code work.
296
+
297
+ A "User Product" is either (1) a "consumer product", which means any
298
+ tangible personal property which is normally used for personal, family,
299
+ or household purposes, or (2) anything designed or sold for incorporation
300
+ into a dwelling. In determining whether a product is a consumer product,
301
+ doubtful cases shall be resolved in favor of coverage. For a particular
302
+ product received by a particular user, "normally used" refers to a
303
+ typical or common use of that class of product, regardless of the status
304
+ of the particular user or of the way in which the particular user
305
+ actually uses, or expects or is expected to use, the product. A product
306
+ is a consumer product regardless of whether the product has substantial
307
+ commercial, industrial or non-consumer uses, unless such uses represent
308
+ the only significant mode of use of the product.
309
+
310
+ "Installation Information" for a User Product means any methods,
311
+ procedures, authorization keys, or other information required to install
312
+ and execute modified versions of a covered work in that User Product from
313
+ a modified version of its Corresponding Source. The information must
314
+ suffice to ensure that the continued functioning of the modified object
315
+ code is in no case prevented or interfered with solely because
316
+ modification has been made.
317
+
318
+ If you convey an object code work under this section in, or with, or
319
+ specifically for use in, a User Product, and the conveying occurs as
320
+ part of a transaction in which the right of possession and use of the
321
+ User Product is transferred to the recipient in perpetuity or for a
322
+ fixed term (regardless of how the transaction is characterized), the
323
+ Corresponding Source conveyed under this section must be accompanied
324
+ by the Installation Information. But this requirement does not apply
325
+ if neither you nor any third party retains the ability to install
326
+ modified object code on the User Product (for example, the work has
327
+ been installed in ROM).
328
+
329
+ The requirement to provide Installation Information does not include a
330
+ requirement to continue to provide support service, warranty, or updates
331
+ for a work that has been modified or installed by the recipient, or for
332
+ the User Product in which it has been modified or installed. Access to a
333
+ network may be denied when the modification itself materially and
334
+ adversely affects the operation of the network or violates the rules and
335
+ protocols for communication across the network.
336
+
337
+ Corresponding Source conveyed, and Installation Information provided,
338
+ in accord with this section must be in a format that is publicly
339
+ documented (and with an implementation available to the public in
340
+ source code form), and must require no special password or key for
341
+ unpacking, reading or copying.
342
+
343
+ 7. Additional Terms.
344
+
345
+ "Additional permissions" are terms that supplement the terms of this
346
+ License by making exceptions from one or more of its conditions.
347
+ Additional permissions that are applicable to the entire Program shall
348
+ be treated as though they were included in this License, to the extent
349
+ that they are valid under applicable law. If additional permissions
350
+ apply only to part of the Program, that part may be used separately
351
+ under those permissions, but the entire Program remains governed by
352
+ this License without regard to the additional permissions.
353
+
354
+ When you convey a copy of a covered work, you may at your option
355
+ remove any additional permissions from that copy, or from any part of
356
+ it. (Additional permissions may be written to require their own
357
+ removal in certain cases when you modify the work.) You may place
358
+ additional permissions on material, added by you to a covered work,
359
+ for which you have or can give appropriate copyright permission.
360
+
361
+ Notwithstanding any other provision of this License, for material you
362
+ add to a covered work, you may (if authorized by the copyright holders of
363
+ that material) supplement the terms of this License with terms:
364
+
365
+ a) Disclaiming warranty or limiting liability differently from the
366
+ terms of sections 15 and 16 of this License; or
367
+
368
+ b) Requiring preservation of specified reasonable legal notices or
369
+ author attributions in that material or in the Appropriate Legal
370
+ Notices displayed by works containing it; or
371
+
372
+ c) Prohibiting misrepresentation of the origin of that material, or
373
+ requiring that modified versions of such material be marked in
374
+ reasonable ways as different from the original version; or
375
+
376
+ d) Limiting the use for publicity purposes of names of licensors or
377
+ authors of the material; or
378
+
379
+ e) Declining to grant rights under trademark law for use of some
380
+ trade names, trademarks, or service marks; or
381
+
382
+ f) Requiring indemnification of licensors and authors of that
383
+ material by anyone who conveys the material (or modified versions of
384
+ it) with contractual assumptions of liability to the recipient, for
385
+ any liability that these contractual assumptions directly impose on
386
+ those licensors and authors.
387
+
388
+ All other non-permissive additional terms are considered "further
389
+ restrictions" within the meaning of section 10. If the Program as you
390
+ received it, or any part of it, contains a notice stating that it is
391
+ governed by this License along with a term that is a further
392
+ restriction, you may remove that term. If a license document contains
393
+ a further restriction but permits relicensing or conveying under this
394
+ License, you may add to a covered work material governed by the terms
395
+ of that license document, provided that the further restriction does
396
+ not survive such relicensing or conveying.
397
+
398
+ If you add terms to a covered work in accord with this section, you
399
+ must place, in the relevant source files, a statement of the
400
+ additional terms that apply to those files, or a notice indicating
401
+ where to find the applicable terms.
402
+
403
+ Additional terms, permissive or non-permissive, may be stated in the
404
+ form of a separately written license, or stated as exceptions;
405
+ the above requirements apply either way.
406
+
407
+ 8. Termination.
408
+
409
+ You may not propagate or modify a covered work except as expressly
410
+ provided under this License. Any attempt otherwise to propagate or
411
+ modify it is void, and will automatically terminate your rights under
412
+ this License (including any patent licenses granted under the third
413
+ paragraph of section 11).
414
+
415
+ However, if you cease all violation of this License, then your
416
+ license from a particular copyright holder is reinstated (a)
417
+ provisionally, unless and until the copyright holder explicitly and
418
+ finally terminates your license, and (b) permanently, if the copyright
419
+ holder fails to notify you of the violation by some reasonable means
420
+ prior to 60 days after the cessation.
421
+
422
+ Moreover, your license from a particular copyright holder is
423
+ reinstated permanently if the copyright holder notifies you of the
424
+ violation by some reasonable means, this is the first time you have
425
+ received notice of violation of this License (for any work) from that
426
+ copyright holder, and you cure the violation prior to 30 days after
427
+ your receipt of the notice.
428
+
429
+ Termination of your rights under this section does not terminate the
430
+ licenses of parties who have received copies or rights from you under
431
+ this License. If your rights have been terminated and not permanently
432
+ reinstated, you do not qualify to receive new licenses for the same
433
+ material under section 10.
434
+
435
+ 9. Acceptance Not Required for Having Copies.
436
+
437
+ You are not required to accept this License in order to receive or
438
+ run a copy of the Program. Ancillary propagation of a covered work
439
+ occurring solely as a consequence of using peer-to-peer transmission
440
+ to receive a copy likewise does not require acceptance. However,
441
+ nothing other than this License grants you permission to propagate or
442
+ modify any covered work. These actions infringe copyright if you do
443
+ not accept this License. Therefore, by modifying or propagating a
444
+ covered work, you indicate your acceptance of this License to do so.
445
+
446
+ 10. Automatic Licensing of Downstream Recipients.
447
+
448
+ Each time you convey a covered work, the recipient automatically
449
+ receives a license from the original licensors, to run, modify and
450
+ propagate that work, subject to this License. You are not responsible
451
+ for enforcing compliance by third parties with this License.
452
+
453
+ An "entity transaction" is a transaction transferring control of an
454
+ organization, or substantially all assets of one, or subdividing an
455
+ organization, or merging organizations. If propagation of a covered
456
+ work results from an entity transaction, each party to that
457
+ transaction who receives a copy of the work also receives whatever
458
+ licenses to the work the party's predecessor in interest had or could
459
+ give under the previous paragraph, plus a right to possession of the
460
+ Corresponding Source of the work from the predecessor in interest, if
461
+ the predecessor has it or can get it with reasonable efforts.
462
+
463
+ You may not impose any further restrictions on the exercise of the
464
+ rights granted or affirmed under this License. For example, you may
465
+ not impose a license fee, royalty, or other charge for exercise of
466
+ rights granted under this License, and you may not initiate litigation
467
+ (including a cross-claim or counterclaim in a lawsuit) alleging that
468
+ any patent claim is infringed by making, using, selling, offering for
469
+ sale, or importing the Program or any portion of it.
470
+
471
+ 11. Patents.
472
+
473
+ A "contributor" is a copyright holder who authorizes use under this
474
+ License of the Program or a work on which the Program is based. The
475
+ work thus licensed is called the contributor's "contributor version".
476
+
477
+ A contributor's "essential patent claims" are all patent claims
478
+ owned or controlled by the contributor, whether already acquired or
479
+ hereafter acquired, that would be infringed by some manner, permitted
480
+ by this License, of making, using, or selling its contributor version,
481
+ but do not include claims that would be infringed only as a
482
+ consequence of further modification of the contributor version. For
483
+ purposes of this definition, "control" includes the right to grant
484
+ patent sublicenses in a manner consistent with the requirements of
485
+ this License.
486
+
487
+ Each contributor grants you a non-exclusive, worldwide, royalty-free
488
+ patent license under the contributor's essential patent claims, to
489
+ make, use, sell, offer for sale, import and otherwise run, modify and
490
+ propagate the contents of its contributor version.
491
+
492
+ In the following three paragraphs, a "patent license" is any express
493
+ agreement or commitment, however denominated, not to enforce a patent
494
+ (such as an express permission to practice a patent or covenant not to
495
+ sue for patent infringement). To "grant" such a patent license to a
496
+ party means to make such an agreement or commitment not to enforce a
497
+ patent against the party.
498
+
499
+ If you convey a covered work, knowingly relying on a patent license,
500
+ and the Corresponding Source of the work is not available for anyone
501
+ to copy, free of charge and under the terms of this License, through a
502
+ publicly available network server or other readily accessible means,
503
+ then you must either (1) cause the Corresponding Source to be so
504
+ available, or (2) arrange to deprive yourself of the benefit of the
505
+ patent license for this particular work, or (3) arrange, in a manner
506
+ consistent with the requirements of this License, to extend the patent
507
+ license to downstream recipients. "Knowingly relying" means you have
508
+ actual knowledge that, but for the patent license, your conveying the
509
+ covered work in a country, or your recipient's use of the covered work
510
+ in a country, would infringe one or more identifiable patents in that
511
+ country that you have reason to believe are valid.
512
+
513
+ If, pursuant to or in connection with a single transaction or
514
+ arrangement, you convey, or propagate by procuring conveyance of, a
515
+ covered work, and grant a patent license to some of the parties
516
+ receiving the covered work authorizing them to use, propagate, modify
517
+ or convey a specific copy of the covered work, then the patent license
518
+ you grant is automatically extended to all recipients of the covered
519
+ work and works based on it.
520
+
521
+ A patent license is "discriminatory" if it does not include within
522
+ the scope of its coverage, prohibits the exercise of, or is
523
+ conditioned on the non-exercise of one or more of the rights that are
524
+ specifically granted under this License. You may not convey a covered
525
+ work if you are a party to an arrangement with a third party that is
526
+ in the business of distributing software, under which you make payment
527
+ to the third party based on the extent of your activity of conveying
528
+ the work, and under which the third party grants, to any of the
529
+ parties who would receive the covered work from you, a discriminatory
530
+ patent license (a) in connection with copies of the covered work
531
+ conveyed by you (or copies made from those copies), or (b) primarily
532
+ for and in connection with specific products or compilations that
533
+ contain the covered work, unless you entered into that arrangement,
534
+ or that patent license was granted, prior to 28 March 2007.
535
+
536
+ Nothing in this License shall be construed as excluding or limiting
537
+ any implied license or other defenses to infringement that may
538
+ otherwise be available to you under applicable patent law.
539
+
540
+ 12. No Surrender of Others' Freedom.
541
+
542
+ If conditions are imposed on you (whether by court order, agreement or
543
+ otherwise) that contradict the conditions of this License, they do not
544
+ excuse you from the conditions of this License. If you cannot convey a
545
+ covered work so as to satisfy simultaneously your obligations under this
546
+ License and any other pertinent obligations, then as a consequence you may
547
+ not convey it at all. For example, if you agree to terms that obligate you
548
+ to collect a royalty for further conveying from those to whom you convey
549
+ the Program, the only way you could satisfy both those terms and this
550
+ License would be to refrain entirely from conveying the Program.
551
+
552
+ 13. Use with the GNU Affero General Public License.
553
+
554
+ Notwithstanding any other provision of this License, you have
555
+ permission to link or combine any covered work with a work licensed
556
+ under version 3 of the GNU Affero General Public License into a single
557
+ combined work, and to convey the resulting work. The terms of this
558
+ License will continue to apply to the part which is the covered work,
559
+ but the special requirements of the GNU Affero General Public License,
560
+ section 13, concerning interaction through a network will apply to the
561
+ combination as such.
562
+
563
+ 14. Revised Versions of this License.
564
+
565
+ The Free Software Foundation may publish revised and/or new versions of
566
+ the GNU General Public License from time to time. Such new versions will
567
+ be similar in spirit to the present version, but may differ in detail to
568
+ address new problems or concerns.
569
+
570
+ Each version is given a distinguishing version number. If the
571
+ Program specifies that a certain numbered version of the GNU General
572
+ Public License "or any later version" applies to it, you have the
573
+ option of following the terms and conditions either of that numbered
574
+ version or of any later version published by the Free Software
575
+ Foundation. If the Program does not specify a version number of the
576
+ GNU General Public License, you may choose any version ever published
577
+ by the Free Software Foundation.
578
+
579
+ If the Program specifies that a proxy can decide which future
580
+ versions of the GNU General Public License can be used, that proxy's
581
+ public statement of acceptance of a version permanently authorizes you
582
+ to choose that version for the Program.
583
+
584
+ Later license versions may give you additional or different
585
+ permissions. However, no additional obligations are imposed on any
586
+ author or copyright holder as a result of your choosing to follow a
587
+ later version.
588
+
589
+ 15. Disclaimer of Warranty.
590
+
591
+ THERE IS NO WARRANTY FOR THE PROGRAM, TO THE EXTENT PERMITTED BY
592
+ APPLICABLE LAW. EXCEPT WHEN OTHERWISE STATED IN WRITING THE COPYRIGHT
593
+ HOLDERS AND/OR OTHER PARTIES PROVIDE THE PROGRAM "AS IS" WITHOUT WARRANTY
594
+ OF ANY KIND, EITHER EXPRESSED OR IMPLIED, INCLUDING, BUT NOT LIMITED TO,
595
+ THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
596
+ PURPOSE. THE ENTIRE RISK AS TO THE QUALITY AND PERFORMANCE OF THE PROGRAM
597
+ IS WITH YOU. SHOULD THE PROGRAM PROVE DEFECTIVE, YOU ASSUME THE COST OF
598
+ ALL NECESSARY SERVICING, REPAIR OR CORRECTION.
599
+
600
+ 16. Limitation of Liability.
601
+
602
+ IN NO EVENT UNLESS REQUIRED BY APPLICABLE LAW OR AGREED TO IN WRITING
603
+ WILL ANY COPYRIGHT HOLDER, OR ANY OTHER PARTY WHO MODIFIES AND/OR CONVEYS
604
+ THE PROGRAM AS PERMITTED ABOVE, BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY
605
+ GENERAL, SPECIAL, INCIDENTAL OR CONSEQUENTIAL DAMAGES ARISING OUT OF THE
606
+ USE OR INABILITY TO USE THE PROGRAM (INCLUDING BUT NOT LIMITED TO LOSS OF
607
+ DATA OR DATA BEING RENDERED INACCURATE OR LOSSES SUSTAINED BY YOU OR THIRD
608
+ PARTIES OR A FAILURE OF THE PROGRAM TO OPERATE WITH ANY OTHER PROGRAMS),
609
+ EVEN IF SUCH HOLDER OR OTHER PARTY HAS BEEN ADVISED OF THE POSSIBILITY OF
610
+ SUCH DAMAGES.
611
+
612
+ 17. Interpretation of Sections 15 and 16.
613
+
614
+ If the disclaimer of warranty and limitation of liability provided
615
+ above cannot be given local legal effect according to their terms,
616
+ reviewing courts shall apply local law that most closely approximates
617
+ an absolute waiver of all civil liability in connection with the
618
+ Program, unless a warranty or assumption of liability accompanies a
619
+ copy of the Program in return for a fee.
620
+
621
+ END OF TERMS AND CONDITIONS
622
+
623
+ How to Apply These Terms to Your New Programs
624
+
625
+ If you develop a new program, and you want it to be of the greatest
626
+ possible use to the public, the best way to achieve this is to make it
627
+ free software which everyone can redistribute and change under these terms.
628
+
629
+ To do so, attach the following notices to the program. It is safest
630
+ to attach them to the start of each source file to most effectively
631
+ state the exclusion of warranty; and each file should have at least
632
+ the "copyright" line and a pointer to where the full notice is found.
633
+
634
+ <one line to give the program's name and a brief idea of what it does.>
635
+ Copyright (C) <year> <name of author>
636
+
637
+ This program is free software: you can redistribute it and/or modify
638
+ it under the terms of the GNU General Public License as published by
639
+ the Free Software Foundation, either version 3 of the License, or
640
+ (at your option) any later version.
641
+
642
+ This program is distributed in the hope that it will be useful,
643
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
644
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
645
+ GNU General Public License for more details.
646
+
647
+ You should have received a copy of the GNU General Public License
648
+ along with this program. If not, see <https://www.gnu.org/licenses/>.
649
+
650
+ Also add information on how to contact you by electronic and paper mail.
651
+
652
+ If the program does terminal interaction, make it output a short
653
+ notice like this when it starts in an interactive mode:
654
+
655
+ <program> Copyright (C) <year> <name of author>
656
+ This program comes with ABSOLUTELY NO WARRANTY; for details type `show w'.
657
+ This is free software, and you are welcome to redistribute it
658
+ under certain conditions; type `show c' for details.
659
+
660
+ The hypothetical commands `show w' and `show c' should show the appropriate
661
+ parts of the General Public License. Of course, your program's commands
662
+ might be different; for a GUI interface, you would use an "about box".
663
+
664
+ You should also get your employer (if you work as a programmer) or school,
665
+ if any, to sign a "copyright disclaimer" for the program, if necessary.
666
+ For more information on this, and how to apply and follow the GNU GPL, see
667
+ <https://www.gnu.org/licenses/>.
668
+
669
+ The GNU General Public License does not permit incorporating your program
670
+ into proprietary programs. If your program is a subroutine library, you
671
+ may consider it more useful to permit linking proprietary applications with
672
+ the library. If this is what you want to do, use the GNU Lesser General
673
+ Public License instead of this License. But first, please read
674
+ <https://www.gnu.org/licenses/why-not-lgpl.html>.
elia/README.md ADDED
@@ -0,0 +1,222 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # LAVT: Language-Aware Vision Transformer for Referring Image Segmentation
2
+ Welcome to the official repository for the method presented in
3
+ "LAVT: Language-Aware Vision Transformer for Referring Image Segmentation."
4
+
5
+
6
+ ![Pipeline Image](pipeline.jpg)
7
+
8
+ Code in this repository is written using [PyTorch](https://pytorch.org/) and is organized in the following way (assuming the working directory is the root directory of this repository):
9
+ * `./lib` contains files implementing the main network.
10
+ * Inside `./lib`, `_utils.py` defines the highest-level model, which incorporates the backbone network
11
+ defined in `backbone.py` and the simple mask decoder defined in `mask_predictor.py`.
12
+ `segmentation.py` provides the model interface and initialization functions.
13
+ * `./bert` contains files migrated from [Hugging Face Transformers v3.0.2](https://huggingface.co/transformers/v3.0.2/quicktour.html),
14
+ which implement the BERT language model.
15
+ We used Transformers v3.0.2 during development but it had a bug that would appear when using `DistributedDataParallel`.
16
+ Therefore we maintain a copy of the relevant source files in this repository.
17
+ This way, the bug is fixed and code in this repository is self-contained.
18
+ * `./train.py` is invoked to train the model.
19
+ * `./test.py` is invoked to run inference on the evaluation subsets after training.
20
+ * `./refer` contains data pre-processing code and is also where data should be placed, including the images and all annotations.
21
+ It is cloned from [refer](https://github.com/lichengunc/refer).
22
+ * `./data/dataset_refer_bert.py` is where the dataset class is defined.
23
+ * `./utils.py` defines functions that track training statistics and setup
24
+ functions for `DistributedDataParallel`.
25
+
26
+
27
+ ## Updates
28
+ **June 21<sup>st</sup>, 2022**. Uploaded the training logs and trained
29
+ model weights of lavt_one.
30
+
31
+ **June 9<sup>th</sup>, 2022**.
32
+ Added a more efficient implementation of LAVT.
33
+ * To train this new model, specify `--model` as `lavt_one`
34
+ (and `lavt` is still valid for specifying the old model).
35
+ The rest of the configuration stays unchanged.
36
+ * The difference between this version and the previous one
37
+ is that the language model has been moved inside the overall model,
38
+ so that `DistributedDataParallel` needs to be applied only once.
39
+ Applying it twice (on the standalone language model and the main branch)
40
+ as done in the old implementation led to low GPU utility,
41
+ which prevented scaling up training speed with more GPUs.
42
+ We recommend training this model on 8 GPUs
43
+ (and same as before with batch size 32).
44
+
45
+ ## Setting Up
46
+ ### Preliminaries
47
+ The code has been verified to work with PyTorch v1.7.1 and Python 3.7.
48
+ 1. Clone this repository.
49
+ 2. Change directory to root of this repository.
50
+ ### Package Dependencies
51
+ 1. Create a new Conda environment with Python 3.7 then activate it:
52
+ ```shell
53
+ conda create -n lavt python==3.7
54
+ conda activate lavt
55
+ ```
56
+
57
+ 2. Install PyTorch v1.7.1 with a CUDA version that works on your cluster/machine (CUDA 10.2 is used in this example):
58
+ ```shell
59
+ conda install pytorch==1.7.1 torchvision==0.8.2 torchaudio==0.7.2 cudatoolkit=10.2 -c pytorch
60
+ ```
61
+
62
+ 3. Install the packages in `requirements.txt` via `pip`:
63
+ ```shell
64
+ pip install -r requirements.txt
65
+ ```
66
+
67
+ ### Datasets
68
+ 1. Follow instructions in the `./refer` directory to set up subdirectories
69
+ and download annotations.
70
+ This directory is a git clone (minus two data files that we do not need)
71
+ from the [refer](https://github.com/lichengunc/refer) public API.
72
+
73
+ 2. Download images from [COCO](https://cocodataset.org/#download).
74
+ Please use the first downloading link *2014 Train images [83K/13GB]*, and extract
75
+ the downloaded `train_2014.zip` file to `./refer/data/images/mscoco/images`.
76
+
77
+ ### The Initialization Weights for Training
78
+ 1. Create the `./pretrained_weights` directory where we will be storing the weights.
79
+ ```shell
80
+ mkdir ./pretrained_weights
81
+ ```
82
+ 2. Download [pre-trained classification weights of
83
+ the Swin Transformer](https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window12_384_22k.pth),
84
+ and put the `pth` file in `./pretrained_weights`.
85
+ These weights are needed for training to initialize the model.
86
+
87
+ ### Trained Weights of LAVT for Testing
88
+ 1. Create the `./checkpoints` directory where we will be storing the weights.
89
+ ```shell
90
+ mkdir ./checkpoints
91
+ ```
92
+ 2. Download LAVT model weights (which are stored on Google Drive) using links below and put them in `./checkpoints`.
93
+
94
+ | [RefCOCO](https://drive.google.com/file/d/13D-OeEOijV8KTC3BkFP-gOJymc6DLwVT/view?usp=sharing) | [RefCOCO+](https://drive.google.com/file/d/1B8Q44ZWsc8Pva2xD_M-KFh7-LgzeH2-2/view?usp=sharing) | [G-Ref (UMD)](https://drive.google.com/file/d/1BjUnPVpALurkGl7RXXvQiAHhA-gQYKvK/view?usp=sharing) | [G-Ref (Google)](https://drive.google.com/file/d/1weiw5UjbPfo3tCBPfB8tu6xFXCUG16yS/view?usp=sharing) |
95
+ |---|---|---|---|
96
+
97
+ 3. Model weights and training logs of the new lavt_one implementation are below.
98
+
99
+ | RefCOCO | RefCOCO+ | G-Ref (UMD) | G-Ref (Google) |
100
+ |:-----:|:-----:|:-----:|:-----:|
101
+ |[log](https://drive.google.com/file/d/1YIojIHqe3bxxsWOltifa2U9jH67hPHLM/view?usp=sharing) &#124; [weights](https://drive.google.com/file/d/1xFMEXr6AGU97Ypj1yr8oo00uObbeIQvJ/view?usp=sharing)|[log](https://drive.google.com/file/d/1Z34T4gEnWlvcSUQya7txOuM0zdLK7MRT/view?usp=sharing) &#124; [weights](https://drive.google.com/file/d/1HS8ZnGaiPJr-OmoUn4-4LVnVtD_zHY6w/view?usp=sharing)|[log](https://drive.google.com/file/d/14VAgahngOV8NA6noLZCqDoqaUrlW14v8/view?usp=sharing) &#124; [weights](https://drive.google.com/file/d/14g8NzgZn6HzC6tP_bsQuWmh5LnOcovsE/view?usp=sharing)|[log](https://drive.google.com/file/d/1JBXfmlwemWSvs92Rky0TlHcVuuLpt4Da/view?usp=sharing) &#124; [weights](https://drive.google.com/file/d/1IJeahFVLgKxu_BVmWacZs3oUzgTCeWcz/view?usp=sharing)|
102
+
103
+ * The Prec@K, overall IoU and mean IoU numbers in the training logs will differ
104
+ from the final results obtained by running `test.py`,
105
+ because only one out of multiple annotated expressions is
106
+ randomly selected and evaluated for each object during training.
107
+ But these numbers give a good idea about the test performance.
108
+ The two should be fairly close.
109
+
110
+
111
+ ## Training
112
+ We use `DistributedDataParallel` from PyTorch.
113
+ The released `lavt` weights were trained using 4 x 32G V100 cards (max mem on each card was about 26G).
114
+ The released `lavt_one` weights were trained using 8 x 32G V100 cards (max mem on each card was about 13G).
115
+ Using more cards was to accelerate training.
116
+ To run on 4 GPUs (with IDs 0, 1, 2, and 3) on a single node:
117
+ ```shell
118
+ mkdir ./models
119
+
120
+ mkdir ./models/refcoco
121
+ CUDA_VISIBLE_DEVICES=0,1,2,3 python -m torch.distributed.launch --nproc_per_node 4 --master_port 12345 train.py --model lavt --dataset refcoco --model_id refcoco --batch-size 8 --lr 0.00005 --wd 1e-2 --swin_type base --pretrained_swin_weights ./pretrained_weights/swin_base_patch4_window12_384_22k.pth --epochs 40 --img_size 480 2>&1 | tee ./models/refcoco/output
122
+
123
+ mkdir ./models/refcoco+
124
+ CUDA_VISIBLE_DEVICES=0,1,2,3 python -m torch.distributed.launch --nproc_per_node 4 --master_port 12345 train.py --model lavt --dataset refcoco+ --model_id refcoco+ --batch-size 8 --lr 0.00005 --wd 1e-2 --swin_type base --pretrained_swin_weights ./pretrained_weights/swin_base_patch4_window12_384_22k.pth --epochs 40 --img_size 480 2>&1 | tee ./models/refcoco+/output
125
+
126
+ mkdir ./models/gref_umd
127
+ CUDA_VISIBLE_DEVICES=0,1,2,3 python -m torch.distributed.launch --nproc_per_node 4 --master_port 12345 train.py --model lavt --dataset refcocog --splitBy umd --model_id gref_umd --batch-size 8 --lr 0.00005 --wd 1e-2 --swin_type base --pretrained_swin_weights ./pretrained_weights/swin_base_patch4_window12_384_22k.pth --epochs 40 --img_size 480 2>&1 | tee ./models/gref_umd/output
128
+
129
+ mkdir ./models/gref_google
130
+ CUDA_VISIBLE_DEVICES=0,1,2,3 python -m torch.distributed.launch --nproc_per_node 4 --master_port 12345 train.py --model lavt --dataset refcocog --splitBy google --model_id gref_google --batch-size 8 --lr 0.00005 --wd 1e-2 --swin_type base --pretrained_swin_weights ./pretrained_weights/swin_base_patch4_window12_384_22k.pth --epochs 40 --img_size 480 2>&1 | tee ./models/gref_google/output
131
+ ```
132
+ * *--model* is a pre-defined model name. Options include `lavt` and `lavt_one`. See [Updates](#updates).
133
+ * *--dataset* is the dataset name. One can choose from `refcoco`, `refcoco+`, and `refcocog`.
134
+ * *--splitBy* needs to be specified if and only if the dataset is G-Ref (which is also called RefCOCOg).
135
+ `umd` identifies the UMD partition and `google` identifies the Google partition.
136
+ * *--model_id* is the model name one should define oneself (*e.g.*, customize it to contain training/model configurations, dataset information, experiment IDs, *etc*.).
137
+ It is used in two ways: Training log will be saved as `./models/[args.model_id]/output` and the best checkpoint will be saved as `./checkpoints/model_best_[args.model_id].pth`.
138
+ * *--swin_type* specifies the version of the Swin Transformer.
139
+ One can choose from `tiny`, `small`, `base`, and `large`. The default is `base`.
140
+ * *--pretrained_swin_weights* specifies the path to pre-trained Swin Transformer weights used for model initialization.
141
+ * Note that currently we need to manually create the `./models/[args.model_id]` directory via `mkdir` before running `train.py`.
142
+ This is because we use `tee` to redirect `stdout` and `stderr` to `./models/[args.model_id]/output` for logging.
143
+ This is a nuisance and should be resolved in the future, *i.e.*, using a proper logger or a bash script for initiating training.
144
+
145
+ ## Testing
146
+ For RefCOCO/RefCOCO+, run one of
147
+ ```shell
148
+ python test.py --model lavt --swin_type base --dataset refcoco --split val --resume ./checkpoints/refcoco.pth --workers 4 --ddp_trained_weights --window12 --img_size 480
149
+ python test.py --model lavt --swin_type base --dataset refcoco+ --split val --resume ./checkpoints/refcoco+.pth --workers 4 --ddp_trained_weights --window12 --img_size 480
150
+ ```
151
+ * *--split* is the subset to evaluate, and one can choose from `val`, `testA`, and `testB`.
152
+ * *--resume* is the path to the weights of a trained model.
153
+
154
+ For G-Ref (UMD)/G-Ref (Google), run one of
155
+ ```shell
156
+ python test.py --model lavt --swin_type base --dataset refcocog --splitBy umd --split val --resume ./checkpoints/gref_umd.pth --workers 4 --ddp_trained_weights --window12 --img_size 480
157
+ python test.py --model lavt --swin_type base --dataset refcocog --splitBy google --split val --resume ./checkpoints/gref_google.pth --workers 4 --ddp_trained_weights --window12 --img_size 480
158
+ ```
159
+ * *--splitBy* specifies the partition to evaluate.
160
+ One can choose from `umd` or `google`.
161
+ * *--split* is the subset (according to the specified partition) to evaluate, and one can choose from `val` and `test` for the UMD partition, and only `val` for the Google partition..
162
+ * *--resume* is the path to the weights of a trained model.
163
+
164
+ ## Results
165
+ The complete test results of the released LAVT models are summarized as follows:
166
+
167
168
+ |:---------------:|:-----:|:-----:|:-----:|:-----:|:-----:|:-----------:|:--------:|
169
+ | RefCOCO val | 84.46 | 80.90 | 75.28 | 64.71 | 34.30 | 72.73 | 74.46 |
170
+ | RefCOCO test A | 88.07 | 85.17 | 79.90 | 68.52 | 35.69 | 75.82 | 76.89 |
171
+ | RefCOCO test B | 79.12 | 74.94 | 69.17 | 59.37 | 34.45 | 68.79 | 70.94 |
172
+ | RefCOCO+ val | 74.44 | 70.91 | 65.58 | 56.34 | 30.23 | 62.14 | 65.81 |
173
+ | RefCOCO+ test A | 80.68 | 77.96 | 72.90 | 62.21 | 32.36 | 68.38 | 70.97 |
174
+ | RefCOCO+ test B | 65.66 | 61.85 | 55.94 | 47.56 | 27.24 | 55.10 | 59.23 |
175
+ | G-Ref val (UMD) | 70.81 | 65.28 | 58.60 | 47.49 | 22.73 | 61.24 | 63.34 |
176
+ | G-Ref test (UMD)| 71.54 | 66.38 | 59.00 | 48.21 | 23.10 | 62.09 | 63.62 |
177
+ |G-Ref val (Goog.)| 71.16 | 67.21 | 61.76 | 51.98 | 27.30 | 60.50 | 63.66 |
178
+
179
+ We have validated LAVT on RefCOCO with multiple runs.
180
+ The overall IoU on the val set generally lies in the range of 72.73±0.5%.
181
+
182
+
183
+ ## Demo: Try LAVT on Your Own Image-text Pairs!
184
+ One can run inference on a custom image-text pair
185
+ and visualize the result by running the script `./demo_inference.py`.
186
+ Choose your photos and expressions and have fun.
187
+
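+ For example, assuming the image path, expression, and checkpoint path are set inside the script (as they are in `app.py`), a run is simply:
+ ```shell
+ python demo_inference.py
+ ```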
188
+
189
+ ## Citing LAVT
190
+ ```
191
+ @inproceedings{yang2022lavt,
192
+ title={LAVT: Language-Aware Vision Transformer for Referring Image Segmentation},
193
+ author={Yang, Zhao and Wang, Jiaqi and Tang, Yansong and Chen, Kai and Zhao, Hengshuang and Torr, Philip HS},
194
+ booktitle={CVPR},
195
+ year={2022}
196
+ }
197
+ ```
198
+
199
+
200
+ ## Contributing
201
+ We appreciate all contributions.
202
+ It helps the project if you could
203
+ - report issues you are facing,
204
+ - give a :+1: on issues reported by others that are relevant to you,
205
+ - answer issues reported by others for which you have found solutions,
206
+ - and implement helpful new features or improve the code otherwise with pull requests.
207
+
208
+ ## Acknowledgements
209
+ Code in this repository is built upon several public repositories.
210
+ Specifically,
211
+ * data pre-processing leverages the [refer](https://github.com/lichengunc/refer) repository,
212
+ * the backbone model is implemented based on code from [Swin Transformer for Semantic Segmentation](https://github.com/SwinTransformer/Swin-Transformer-Semantic-Segmentation),
213
+ * the training and testing pipelines are adapted from [RefVOS](https://github.com/miriambellver/refvos),
214
+ * and the implementation of the BERT model (files in the `bert` directory) is from [Hugging Face Transformers v3.0.2](https://github.com/huggingface/transformers/tree/v3.0.2)
215
+ (we migrated over the relevant code to fix a bug and simplify the installation process).
216
+
217
+ Some of these repositories in turn adapt code from [OpenMMLab](https://github.com/open-mmlab) and [TorchVision](https://github.com/pytorch/vision).
218
+ We'd like to thank the authors/organizations of these repositories for open sourcing their projects.
219
+
220
+
221
+ ## License
222
+ GNU GPLv3
elia/__pycache__/args.cpython-37.pyc ADDED
Binary file (3.67 kB). View file
 
elia/__pycache__/args.cpython-38.pyc ADDED
Binary file (3.71 kB). View file
 
elia/__pycache__/transforms.cpython-37.pyc ADDED
Binary file (5.02 kB). View file
 
elia/__pycache__/transforms.cpython-38.pyc ADDED
Binary file (4.99 kB). View file
 
elia/__pycache__/utils.cpython-37.pyc ADDED
Binary file (7.12 kB). View file
 
elia/__pycache__/utils.cpython-38.pyc ADDED
Binary file (7.23 kB). View file
 
elia/app.py ADDED
@@ -0,0 +1,310 @@
1
+ import gradio as gr
2
+
3
+ image_path = './image001.png'
4
+ sentence = 'spoon on the dish'
5
+ weights = './checkpoints/model_best_refcoco_0508.pth'
6
+ device = 'cpu'
7
+
8
+ # pre-process the input image
9
+ from PIL import Image
10
+ import torchvision.transforms as T
11
+ import numpy as np
12
+ import datetime
13
+ import os
14
+ import time
15
+
16
+ import torch
17
+ import torch.utils.data
18
+ from torch import nn
19
+
20
+ from bert.multimodal_bert import MultiModalBert
21
+ import torchvision
22
+
23
+ from lib import multimodal_segmentation_ppm
24
+ #import transforms as T
25
+ import utils
26
+
27
+ import numpy as np
28
+ from PIL import Image
29
+ import torch.nn.functional as F
30
+
31
+ from modeling.MaskFormerModel import MaskFormerHead
32
+ from addict import Dict
33
+ #from bert.modeling_bert import BertLMPredictionHead, BertEncoder
34
+ import cv2
35
+ import textwrap
36
+
37
+ class WrapperModel(nn.Module):
38
+ def __init__(self, image_model, language_model, classifier) :
39
+ super(WrapperModel, self).__init__()
40
+ self.image_model = image_model
41
+ self.language_model = language_model
42
+ self.classifier = classifier
43
+
44
+ config = Dict({
45
+ "architectures": [
46
+ "BertForMaskedLM"
47
+ ],
48
+ "attention_probs_dropout_prob": 0.1,
49
+ "gradient_checkpointing": False,
50
+ "hidden_act": "gelu",
51
+ "hidden_dropout_prob": 0.1,
52
+ "hidden_size": 512,
53
+ "initializer_range": 0.02,
54
+ "intermediate_size": 3072,
55
+ "layer_norm_eps": 1e-12,
56
+ #"max_position_embeddings": 16+20,
57
+ "model_type": "bert",
58
+ "num_attention_heads": 8,
59
+ "num_hidden_layers": 8,
60
+ "pad_token_id": 0,
61
+ "position_embedding_type": "absolute",
62
+ "transformers_version": "4.6.0.dev0",
63
+ "type_vocab_size": 2,
64
+ "use_cache": True,
65
+ "vocab_size": 30522
66
+ })
67
+
68
+
69
+
70
+ def _get_binary_mask(self, target):
71
+ # return a binary (one-hot) mask for each class
72
+ y, x = target.size()
73
+ target_onehot = torch.zeros(self.num_classes + 1, y, x)
74
+ target_onehot = target_onehot.scatter(dim=0, index=target.unsqueeze(0), value=1)
75
+ return target_onehot[1:]
76
+
77
+ def semantic_inference(self, mask_cls, mask_pred):
78
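+ # fuse per-query class probabilities with per-query mask predictions into a per-class map: (B, Q, C) x (B, Q, H, W) -> (B, C, H, W)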
+ mask_cls = F.softmax(mask_cls, dim=1)[...,1:]
79
+ mask_pred = mask_pred.sigmoid()
80
+ semseg = torch.einsum("bqc,bqhw->bchw", mask_cls, mask_pred)
81
+ return semseg
82
+
83
+ def forward(self, image, sentences, attentions):
84
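+ # Interleave the Swin image encoder and the BERT language encoder stage by stage, fusing them with the PWAM modules after each stage; the four fused (residual) feature maps are then fed to the MaskFormer-style classifier head.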
+ print(image.sum(), sentences.sum(), attentions.sum())
85
+ input_shape = image.shape[-2:]
86
+ l_mask = attentions.unsqueeze(dim=-1)
87
+
88
+ i0, Wh, Ww = self.image_model.forward_stem(image)
89
+ l0, extended_attention_mask = self.language_model.forward_stem(sentences, attentions)
90
+
91
+ i1 = self.image_model.forward_stage1(i0, Wh, Ww)
92
+ l1 = self.language_model.forward_stage1(l0, extended_attention_mask)
93
+ i1_residual, H, W, i1_temp, Wh, Ww = self.image_model.forward_pwam1(i1, Wh, Ww, l1, l_mask)
94
+ l1_residual, l1 = self.language_model.forward_pwam1(i1, l1, extended_attention_mask)
95
+ i1 = i1_temp
96
+
97
+ i2 = self.image_model.forward_stage2(i1, Wh, Ww)
98
+ l2 = self.language_model.forward_stage2(l1, extended_attention_mask)
99
+ i2_residual, H, W, i2_temp, Wh, Ww = self.image_model.forward_pwam2(i2, Wh, Ww, l2, l_mask)
100
+ l2_residual, l2 = self.language_model.forward_pwam2(i2, l2, extended_attention_mask)
101
+ i2 = i2_temp
102
+
103
+ i3 = self.image_model.forward_stage3(i2, Wh, Ww)
104
+ l3 = self.language_model.forward_stage3(l2, extended_attention_mask)
105
+ i3_residual, H, W, i3_temp, Wh, Ww = self.image_model.forward_pwam3(i3, Wh, Ww, l3, l_mask)
106
+ l3_residual, l3 = self.language_model.forward_pwam3(i3, l3, extended_attention_mask)
107
+ i3 = i3_temp
108
+
109
+ i4 = self.image_model.forward_stage4(i3, Wh, Ww)
110
+ l4 = self.language_model.forward_stage4(l3, extended_attention_mask)
111
+ i4_residual, H, W, i4_temp, Wh, Ww = self.image_model.forward_pwam4(i4, Wh, Ww, l4, l_mask)
112
+ l4_residual, l4 = self.language_model.forward_pwam4(i4, l4, extended_attention_mask)
113
+ i4 = i4_temp
114
+
115
+ #i1_residual, i2_residual, i3_residual, i4_residual = features
116
+ #x = self.classifier(i4_residual, i3_residual, i2_residual, i1_residual)
117
+ #x = F.interpolate(x, size=input_shape, mode='bilinear', align_corners=True)
118
+ outputs = {}
119
+ outputs['s1'] = i1_residual
120
+ outputs['s2'] = i2_residual
121
+ outputs['s3'] = i3_residual
122
+ outputs['s4'] = i4_residual
123
+
124
+ predictions = self.classifier(outputs)
125
+ return predictions
126
+
127
+ #img = Image.open(image_path).convert("RGB")
128
+
129
+ # pre-process the raw sentence
130
+ from bert.tokenization_bert import BertTokenizer
131
+ import torch
132
+
133
+ # initialize model and load weights
134
+ #from bert.modeling_bert import BertModel
135
+ #from lib import segmentation
136
+
137
+ # construct a mini args class; like from a config file
138
+
139
+
140
+ class args:
141
+ swin_type = 'base'
142
+ window12 = True
143
+ mha = ''
144
+ fusion_drop = 0.0
145
+
146
+
147
+ #single_model = segmentation.__dict__['lavt'](pretrained='', args=args)
148
+ single_model = multimodal_segmentation_ppm.__dict__['lavt'](pretrained='',args=args)
149
+ single_model.to(device)
150
+ model_class = MultiModalBert
151
+ single_bert_model = model_class.from_pretrained('bert-base-uncased', embed_dim=single_model.backbone.embed_dim)
152
+ single_bert_model.pooler = None
153
+
154
+ input_shape = dict()
155
+ input_shape['s1'] = Dict({'channel': 128, 'stride': 4})
156
+ input_shape['s2'] = Dict({'channel': 256, 'stride': 8})
157
+ input_shape['s3'] = Dict({'channel': 512, 'stride': 16})
158
+ input_shape['s4'] = Dict({'channel': 1024, 'stride': 32})
159
+
160
+
161
+
162
+ cfg = Dict()
163
+ cfg.MODEL.SEM_SEG_HEAD.COMMON_STRIDE = 4
164
+ cfg.MODEL.MASK_FORMER.DROPOUT = 0.0
165
+ cfg.MODEL.MASK_FORMER.NHEADS = 8
166
+ cfg.MODEL.SEM_SEG_HEAD.TRANSFORMER_ENC_LAYERS = 4
167
+ cfg.MODEL.SEM_SEG_HEAD.CONVS_DIM = 256
168
+ cfg.MODEL.SEM_SEG_HEAD.MASK_DIM = 256
169
+ cfg.MODEL.SEM_SEG_HEAD.DEFORMABLE_TRANSFORMER_ENCODER_IN_FEATURES = ["s1", "s2", "s3", "s4"]
170
+
171
+ cfg.MODEL.SEM_SEG_HEAD.NUM_CLASSES = 1
172
+ cfg.MODEL.MASK_FORMER.HIDDEN_DIM = 256
173
+ cfg.MODEL.MASK_FORMER.NUM_OBJECT_QUERIES = 1
174
+ cfg.MODEL.MASK_FORMER.DIM_FEEDFORWARD = 2048
175
+ cfg.MODEL.MASK_FORMER.DEC_LAYERS = 10
176
+ cfg.MODEL.MASK_FORMER.PRE_NORM = False
177
+
178
+
179
+ maskformer_head = MaskFormerHead(cfg, input_shape)
180
+
181
+
182
+ model = WrapperModel(single_model.backbone, single_bert_model, maskformer_head)
183
+
184
+
185
+
186
+ checkpoint = torch.load(weights, map_location='cpu')
187
+
188
+ model.load_state_dict(checkpoint['model'], strict=False)
189
+ model.to(device)
190
+ model.eval()
191
+ #single_bert_model.load_state_dict(checkpoint['bert_model'])
192
+ #single_model.load_state_dict(checkpoint['model'])
193
+ #model = single_model.to(device)
194
+ #bert_model = single_bert_model.to(device)
195
+
196
+
197
+ # inference
198
+ #import torch.nn.functional as F
199
+ #last_hidden_states = bert_model(padded_sent_toks, attention_mask=attention_mask)[0]
200
+ #embedding = last_hidden_states.permute(0, 2, 1)
201
+ #output = model(img, embedding, l_mask=attention_mask.unsqueeze(-1))
202
+ #output = output.argmax(1, keepdim=True) # (1, 1, 480, 480)
203
+ #output = F.interpolate(output.float(), (original_h, original_w)) # 'nearest'; resize to the original image size
204
+ #output = output.squeeze() # (orig_h, orig_w)
205
+ #output = output.cpu().data.numpy() # (orig_h, orig_w)
206
+
207
+ #output = pred_masks[0]
208
+
209
+ #output = output.cpu()
210
+
211
+
212
+
213
+ #print(output.shape)
214
+ #output_mask = output.argmax(1).data.numpy()
215
+ #output = (output > 0.5).data.cpu().numpy()
216
+
217
+
218
+ # show/save results
219
+ def overlay_davis(image, mask, colors=[[0, 0, 0], [255, 0, 0]], cscale=1, alpha=0.4):
220
+ from scipy.ndimage.morphology import binary_dilation
221
+
222
+ colors = np.reshape(colors, (-1, 3))
223
+ colors = np.atleast_2d(colors) * cscale
224
+
225
+ im_overlay = image.copy()
226
+ object_ids = np.unique(mask)
227
+
228
+ for object_id in object_ids[1:]:
229
+ # Overlay color on binary mask
230
+ foreground = image*alpha + np.ones(image.shape)*(1-alpha) * np.array(colors[object_id])
231
+ binary_mask = mask == object_id
232
+
233
+ # Compose image
234
+ im_overlay[binary_mask] = foreground[binary_mask]
235
+
236
+ # countours = skimage.morphology.binary.binary_dilation(binary_mask) - binary_mask
237
+ countours = binary_dilation(binary_mask) ^ binary_mask
238
+ # countours = cv2.dilate(binary_mask, cv2.getStructuringElement(cv2.MORPH_CROSS,(3,3))) - binary_mask
239
+ im_overlay[countours, :] = 0
240
+
241
+ return im_overlay.astype(image.dtype)
242
+
243
+
244
+ def run_model(img, sentence):
245
+
246
+ #img = Image.open(image_path).convert("RGB")
247
+ img = Image.fromarray(img)
248
+ img = img.convert("RGB")
249
+ #print(img.shape)
250
+ img_ndarray = np.array(img) # (orig_h, orig_w, 3); for visualization
251
+ original_w, original_h = img.size # PIL .size returns width first and height second
252
+
253
+ image_transforms = T.Compose(
254
+ [
255
+ T.Resize((480, 480)),
256
+ T.ToTensor(),
257
+ T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
258
+ ]
259
+ )
260
+
261
+ img = image_transforms(img).unsqueeze(0) # (1, 3, 480, 480)
262
+ img = img.to(device) # for inference (input)
263
+
264
+ tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
265
+ sentence_tokenized = tokenizer.encode(text=sentence, add_special_tokens=True)
266
+ sentence_tokenized = sentence_tokenized[:20] # if the sentence is longer than 20, then this truncates it to 20 words
267
+ # pad the tokenized sentence
268
+ padded_sent_toks = [0] * 20
269
+ padded_sent_toks[:len(sentence_tokenized)] = sentence_tokenized
270
+ # create a sentence token mask: 1 for real words; 0 for padded tokens
271
+ attention_mask = [0] * 20
272
+ attention_mask[:len(sentence_tokenized)] = [1]*len(sentence_tokenized)
273
+ # convert lists to tensors
274
+ padded_sent_toks = torch.tensor(padded_sent_toks).unsqueeze(0) # (1, 20)
275
+ attention_mask = torch.tensor(attention_mask).unsqueeze(0) # (1, 20)
276
+ padded_sent_toks = padded_sent_toks.to(device) # for inference (input)
277
+ attention_mask = attention_mask.to(device) # for inference (input)
278
+
279
+ output = model(img, padded_sent_toks, attention_mask)[0]
280
+ #print(output[0].keys())
281
+ #print(output[1].shape)
282
+ mask_cls_results = output["pred_logits"]
283
+ mask_pred_results = output["pred_masks"]
284
+
285
+ target_shape = img_ndarray.shape[:2]
286
+ #print(target_shape, mask_pred_results.shape)
287
+ mask_pred_results = F.interpolate(mask_pred_results, size=(480,480), mode='bilinear', align_corners=True)
288
+
289
+ pred_masks = model.semantic_inference(mask_cls_results, mask_pred_results)
290
+
291
+ output = torch.nn.functional.interpolate(pred_masks, target_shape)
292
+ output = (output > 0.5).data.cpu().numpy()
293
+
294
+ output = output.astype(np.uint8) # (orig_h, orig_w), np.uint8
295
+ # Overlay the mask on the image
296
+ print(img_ndarray.shape, output.shape)
297
+ visualization = overlay_davis(img_ndarray, output[0][0]) # red
298
+ visualization = Image.fromarray(visualization)
299
+ # show the visualization
300
+ #visualization.show()
301
+ # Save the visualization
302
+ #visualization.save('./demo/spoon_on_the_dish.jpg')
303
+ return visualization
304
+
305
+
306
+
307
+
308
+ demo = gr.Interface(run_model, inputs=[gr.Image(), "text"], outputs=["image"])
309
+ demo.launch()
310
+
elia/args.py ADDED
@@ -0,0 +1,74 @@
1
+ import argparse
2
+
3
+
4
+ def get_parser():
5
+ parser = argparse.ArgumentParser(description='LAVT training and testing')
6
+ parser.add_argument('--amsgrad', action='store_true',
7
+ help='if true, set amsgrad to True in an Adam or AdamW optimizer.')
8
+ parser.add_argument('-b', '--batch-size', default=8, type=int)
9
+ parser.add_argument('--bert_tokenizer', default='bert-base-uncased', help='BERT tokenizer')
10
+ parser.add_argument('--ck_bert', default='bert-base-uncased', help='pre-trained BERT weights')
11
+ parser.add_argument('--dataset', default='refcoco', help='refcoco, refcoco+, or refcocog')
12
+ parser.add_argument('--ddp_trained_weights', action='store_true',
13
+ help='Only needs to be specified when testing: '
14
+ 'whether the weights to be loaded are from a DDP-trained model')
15
+ parser.add_argument('--device', default='cuda:0', help='device') # only used when testing on a single machine
16
+ parser.add_argument('--epochs', default=40, type=int, metavar='N', help='number of total epochs to run')
17
+ parser.add_argument('--fusion_drop', default=0.0, type=float, help='dropout rate for PWAMs')
18
+ parser.add_argument('--img_size', default=480, type=int, help='input image size')
19
+ parser.add_argument("--local_rank", type=int, help='local rank for DistributedDataParallel')
20
+ parser.add_argument('--lr', default=0.00005, type=float, help='the initial learning rate')
21
+ parser.add_argument('--mha', default='', help='If specified, should be in the format of a-b-c-d, e.g., 4-4-4-4,'
22
+ 'where a, b, c, and d refer to the numbers of heads in stage-1,'
23
+ 'stage-2, stage-3, and stage-4 PWAMs')
24
+ parser.add_argument('--model', default='lavt', help='model: lavt, lavt_one')
25
+ parser.add_argument('--model_id', default='lavt', help='name to identify the model')
26
+ parser.add_argument('--output-dir', default='./checkpoints/', help='path where to save checkpoint weights')
27
+ parser.add_argument('--pin_mem', action='store_true',
28
+ help='If true, pin memory when using the data loader.')
29
+ parser.add_argument('--pretrained_swin_weights', default='',
30
+ help='path to pre-trained Swin backbone weights')
31
+ parser.add_argument('--print-freq', default=10, type=int, help='print frequency')
32
+ parser.add_argument('--refer_data_root', default='./refer/data/', help='REFER dataset root directory')
33
+ parser.add_argument('--resume', default='auto', help='resume from checkpoint')
34
+ parser.add_argument('--split', default='test', help='only used when testing')
35
+ parser.add_argument('--splitBy', default='unc', help='change to umd or google when the dataset is G-Ref (RefCOCOg)')
36
+ parser.add_argument('--swin_type', default='base',
37
+ help='tiny, small, base, or large variants of the Swin Transformer')
38
+ parser.add_argument('--wd', '--weight-decay', default=1e-2, type=float, metavar='W', help='weight decay',
39
+ dest='weight_decay')
40
+ parser.add_argument('--window12', action='store_true',
41
+ help='Only needs to be specified when testing; '
42
+ 'when training, the window size is inferred from the pre-trained weights file name '
43
+ '(containing \'window12\'). Initialize Swin with window size 12 instead of the default 7.')
44
+ parser.add_argument('-j', '--workers', default=16, type=int, metavar='N', help='number of data loading workers')
45
+ parser.add_argument('--seed', default=0, type=int)
46
+ parser.add_argument('--max_ckpt', default=2, type=int)
47
+ parser.add_argument('--num_object_queries', default=1, type=int)
48
+ parser.add_argument('--no_object_weight', default=0.0, type=float)
49
+ parser.add_argument('--class_weight', default=2.0, type=float)
50
+ parser.add_argument('--dice_weight', default=2.0, type=float)
51
+ parser.add_argument('--mask_weight', default=2.0, type=float)
52
+ parser.add_argument('--train_num_points', default=12544, type=int)
53
+ parser.add_argument('--dim_feedforward', default=2048, type=int)
54
+ parser.add_argument('--dec_layers', default=10, type=int)
55
+ parser.add_argument('--transformer_enc_layers', default=4, type=int)
56
+
57
+ parser.add_argument('--plic_pos_weight', default=0.5, type=float)
58
+ parser.add_argument('--plic_neg_weight', default=0.5, type=float)
59
+ parser.add_argument('--plic_lang_weight', default=0.5, type=float)
60
+ parser.add_argument('--plic_pos_alpha', default=0.0, type=float)
61
+ parser.add_argument('--plic_neg_alpha', default=0.0, type=float)
62
+ parser.add_argument('--plic_lang_alpha', default=0.0, type=float)
63
+ parser.add_argument('--plic_pos_temp', default=0.2, type=float)
64
+ parser.add_argument('--plic_neg_temp', default=0.2, type=float)
65
+ parser.add_argument('--plic_lang_temp', default=0.2, type=float)
66
+ parser.add_argument('--smlm_weight', default=1.0, type=float)
67
+ parser.add_argument('--vis_dir', default='./vis_dir')
68
+
69
+ return parser
70
+
71
+
72
+ if __name__ == "__main__":
73
+ parser = get_parser()
74
+ args_dict = parser.parse_args()
elia/bert/__pycache__/activations.cpython-37.pyc ADDED
Binary file (1.94 kB). View file
 
elia/bert/__pycache__/activations.cpython-38.pyc ADDED
Binary file (1.94 kB). View file
 
elia/bert/__pycache__/configuration_bert.cpython-37.pyc ADDED
Binary file (7.86 kB). View file
 
elia/bert/__pycache__/configuration_bert.cpython-38.pyc ADDED
Binary file (7.87 kB). View file
 
elia/bert/__pycache__/configuration_utils.cpython-37.pyc ADDED
Binary file (16.2 kB). View file
 
elia/bert/__pycache__/configuration_utils.cpython-38.pyc ADDED
Binary file (16.3 kB). View file
 
elia/bert/__pycache__/file_utils.cpython-37.pyc ADDED
Binary file (24.4 kB). View file
 
elia/bert/__pycache__/file_utils.cpython-38.pyc ADDED
Binary file (24.5 kB). View file
 
elia/bert/__pycache__/generation_utils.cpython-37.pyc ADDED
Binary file (27.9 kB). View file
 
elia/bert/__pycache__/generation_utils.cpython-38.pyc ADDED
Binary file (28.2 kB). View file
 
elia/bert/__pycache__/modeling_bert.cpython-37.pyc ADDED
Binary file (56.4 kB). View file
 
elia/bert/__pycache__/modeling_bert.cpython-38.pyc ADDED
Binary file (55.3 kB). View file
 
elia/bert/__pycache__/modeling_utils.cpython-37.pyc ADDED
Binary file (48 kB). View file
 
elia/bert/__pycache__/modeling_utils.cpython-38.pyc ADDED
Binary file (48 kB). View file
 
elia/bert/__pycache__/multimodal_bert.cpython-37.pyc ADDED
Binary file (6.23 kB). View file
 
elia/bert/__pycache__/multimodal_bert.cpython-38.pyc ADDED
Binary file (6.04 kB). View file
 
elia/bert/__pycache__/tokenization_bert.cpython-37.pyc ADDED
Binary file (19.1 kB). View file
 
elia/bert/__pycache__/tokenization_bert.cpython-38.pyc ADDED
Binary file (19.3 kB). View file
 
elia/bert/__pycache__/tokenization_utils.cpython-37.pyc ADDED
Binary file (24.7 kB). View file
 
elia/bert/__pycache__/tokenization_utils.cpython-38.pyc ADDED
Binary file (24.9 kB). View file
 
elia/bert/__pycache__/tokenization_utils_base.cpython-37.pyc ADDED
Binary file (81.7 kB). View file
 
elia/bert/__pycache__/tokenization_utils_base.cpython-38.pyc ADDED
Binary file (82.4 kB). View file
 
elia/bert/activations.py ADDED
@@ -0,0 +1,56 @@
1
+ import logging
2
+ import math
3
+
4
+ import torch
5
+ import torch.nn.functional as F
6
+
7
+
8
+ logger = logging.getLogger(__name__)
9
+
10
+
11
+ def swish(x):
12
+ return x * torch.sigmoid(x)
13
+
14
+
15
+ def _gelu_python(x):
16
+ """ Original Implementation of the gelu activation function in Google Bert repo when initially created.
17
+ For information: OpenAI GPT's gelu is slightly different (and gives slightly different results):
18
+ 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))
19
+ This is now written in C in torch.nn.functional
20
+ Also see https://arxiv.org/abs/1606.08415
21
+ """
22
+ return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))
23
+
24
+
25
+ def gelu_new(x):
26
+ """ Implementation of the gelu activation function currently in Google Bert repo (identical to OpenAI GPT).
27
+ Also see https://arxiv.org/abs/1606.08415
28
+ """
29
+ return 0.5 * x * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * torch.pow(x, 3.0))))
30
+
31
+
32
+ if torch.__version__ < "1.4.0":
33
+ gelu = _gelu_python
34
+ else:
35
+ gelu = F.gelu
36
+
37
+
38
+ def gelu_fast(x):
39
+ return 0.5 * x * (1.0 + torch.tanh(x * 0.7978845608 * (1.0 + 0.044715 * x * x)))
40
+
41
+
42
+ ACT2FN = {
43
+ "relu": F.relu,
44
+ "swish": swish,
45
+ "gelu": gelu,
46
+ "tanh": torch.tanh,
47
+ "gelu_new": gelu_new,
48
+ "gelu_fast": gelu_fast,
49
+ }
50
+
51
+
52
+ def get_activation(activation_string):
53
+ if activation_string in ACT2FN:
54
+ return ACT2FN[activation_string]
55
+ else:
56
+ raise KeyError("function {} not found in ACT2FN mapping {}".format(activation_string, list(ACT2FN.keys())))
elia/bert/configuration_bert.py ADDED
@@ -0,0 +1,143 @@
1
+ # coding=utf-8
2
+ # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
3
+ # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ """ BERT model configuration """
17
+
18
+
19
+ import logging
20
+
21
+ from .configuration_utils import PretrainedConfig
22
+
23
+
24
+ logger = logging.getLogger(__name__)
25
+
26
+ BERT_PRETRAINED_CONFIG_ARCHIVE_MAP = {
27
+ "bert-base-uncased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-config.json",
28
+ "bert-large-uncased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-config.json",
29
+ "bert-base-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-config.json",
30
+ "bert-large-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-config.json",
31
+ "bert-base-multilingual-uncased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-uncased-config.json",
32
+ "bert-base-multilingual-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-cased-config.json",
33
+ "bert-base-chinese": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-chinese-config.json",
34
+ "bert-base-german-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-german-cased-config.json",
35
+ "bert-large-uncased-whole-word-masking": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-whole-word-masking-config.json",
36
+ "bert-large-cased-whole-word-masking": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-whole-word-masking-config.json",
37
+ "bert-large-uncased-whole-word-masking-finetuned-squad": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-whole-word-masking-finetuned-squad-config.json",
38
+ "bert-large-cased-whole-word-masking-finetuned-squad": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-whole-word-masking-finetuned-squad-config.json",
39
+ "bert-base-cased-finetuned-mrpc": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-finetuned-mrpc-config.json",
40
+ "bert-base-german-dbmdz-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-german-dbmdz-cased-config.json",
41
+ "bert-base-german-dbmdz-uncased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-german-dbmdz-uncased-config.json",
42
+ "cl-tohoku/bert-base-japanese": "https://s3.amazonaws.com/models.huggingface.co/bert/cl-tohoku/bert-base-japanese/config.json",
43
+ "cl-tohoku/bert-base-japanese-whole-word-masking": "https://s3.amazonaws.com/models.huggingface.co/bert/cl-tohoku/bert-base-japanese-whole-word-masking/config.json",
44
+ "cl-tohoku/bert-base-japanese-char": "https://s3.amazonaws.com/models.huggingface.co/bert/cl-tohoku/bert-base-japanese-char/config.json",
45
+ "cl-tohoku/bert-base-japanese-char-whole-word-masking": "https://s3.amazonaws.com/models.huggingface.co/bert/cl-tohoku/bert-base-japanese-char-whole-word-masking/config.json",
46
+ "TurkuNLP/bert-base-finnish-cased-v1": "https://s3.amazonaws.com/models.huggingface.co/bert/TurkuNLP/bert-base-finnish-cased-v1/config.json",
47
+ "TurkuNLP/bert-base-finnish-uncased-v1": "https://s3.amazonaws.com/models.huggingface.co/bert/TurkuNLP/bert-base-finnish-uncased-v1/config.json",
48
+ "wietsedv/bert-base-dutch-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/wietsedv/bert-base-dutch-cased/config.json",
49
+ # See all BERT models at https://huggingface.co/models?filter=bert
50
+ }
51
+
52
+
53
+ class BertConfig(PretrainedConfig):
54
+ r"""
55
+ This is the configuration class to store the configuration of a :class:`~transformers.BertModel`.
56
+ It is used to instantiate an BERT model according to the specified arguments, defining the model
57
+ architecture. Instantiating a configuration with the defaults will yield a similar configuration to that of
58
+ the BERT `bert-base-uncased <https://huggingface.co/bert-base-uncased>`__ architecture.
59
+
60
+ Configuration objects inherit from :class:`~transformers.PretrainedConfig` and can be used
61
+ to control the model outputs. Read the documentation from :class:`~transformers.PretrainedConfig`
62
+ for more information.
63
+
64
+
65
+ Args:
66
+ vocab_size (:obj:`int`, optional, defaults to 30522):
67
+ Vocabulary size of the BERT model. Defines the different tokens that
68
+ can be represented by the `inputs_ids` passed to the forward method of :class:`~transformers.BertModel`.
69
+ hidden_size (:obj:`int`, optional, defaults to 768):
70
+ Dimensionality of the encoder layers and the pooler layer.
71
+ num_hidden_layers (:obj:`int`, optional, defaults to 12):
72
+ Number of hidden layers in the Transformer encoder.
73
+ num_attention_heads (:obj:`int`, optional, defaults to 12):
74
+ Number of attention heads for each attention layer in the Transformer encoder.
75
+ intermediate_size (:obj:`int`, optional, defaults to 3072):
76
+ Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
77
+ hidden_act (:obj:`str` or :obj:`function`, optional, defaults to "gelu"):
78
+ The non-linear activation function (function or string) in the encoder and pooler.
79
+ If string, "gelu", "relu", "swish" and "gelu_new" are supported.
80
+ hidden_dropout_prob (:obj:`float`, optional, defaults to 0.1):
81
+ The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
82
+ attention_probs_dropout_prob (:obj:`float`, optional, defaults to 0.1):
83
+ The dropout ratio for the attention probabilities.
84
+ max_position_embeddings (:obj:`int`, optional, defaults to 512):
85
+ The maximum sequence length that this model might ever be used with.
86
+ Typically set this to something large just in case (e.g., 512 or 1024 or 2048).
87
+ type_vocab_size (:obj:`int`, optional, defaults to 2):
88
+ The vocabulary size of the `token_type_ids` passed into :class:`~transformers.BertModel`.
89
+ initializer_range (:obj:`float`, optional, defaults to 0.02):
90
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
91
+ layer_norm_eps (:obj:`float`, optional, defaults to 1e-12):
92
+ The epsilon used by the layer normalization layers.
93
+ gradient_checkpointing (:obj:`bool`, optional, defaults to False):
94
+ If True, use gradient checkpointing to save memory at the expense of slower backward pass.
95
+
96
+ Example::
97
+
98
+ >>> from transformers import BertModel, BertConfig
99
+
100
+ >>> # Initializing a BERT bert-base-uncased style configuration
101
+ >>> configuration = BertConfig()
102
+
103
+ >>> # Initializing a model from the bert-base-uncased style configuration
104
+ >>> model = BertModel(configuration)
105
+
106
+ >>> # Accessing the model configuration
107
+ >>> configuration = model.config
108
+ """
109
+ model_type = "bert"
110
+
111
+ def __init__(
112
+ self,
113
+ vocab_size=30522,
114
+ hidden_size=768,
115
+ num_hidden_layers=12,
116
+ num_attention_heads=12,
117
+ intermediate_size=3072,
118
+ hidden_act="gelu",
119
+ hidden_dropout_prob=0.1,
120
+ attention_probs_dropout_prob=0.1,
121
+ max_position_embeddings=512,
122
+ type_vocab_size=2,
123
+ initializer_range=0.02,
124
+ layer_norm_eps=1e-12,
125
+ pad_token_id=0,
126
+ gradient_checkpointing=False,
127
+ **kwargs
128
+ ):
129
+ super().__init__(pad_token_id=pad_token_id, **kwargs)
130
+
131
+ self.vocab_size = vocab_size
132
+ self.hidden_size = hidden_size
133
+ self.num_hidden_layers = num_hidden_layers
134
+ self.num_attention_heads = num_attention_heads
135
+ self.hidden_act = hidden_act
136
+ self.intermediate_size = intermediate_size
137
+ self.hidden_dropout_prob = hidden_dropout_prob
138
+ self.attention_probs_dropout_prob = attention_probs_dropout_prob
139
+ self.max_position_embeddings = max_position_embeddings
140
+ self.type_vocab_size = type_vocab_size
141
+ self.initializer_range = initializer_range
142
+ self.layer_norm_eps = layer_norm_eps
143
+ self.gradient_checkpointing = gradient_checkpointing
elia/bert/configuration_utils.py ADDED
@@ -0,0 +1,408 @@
1
+ # coding=utf-8
2
+ # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
3
+ # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ """ Configuration base class and utilities."""
17
+
18
+
19
+ import copy
20
+ import json
21
+ import logging
22
+ import os
23
+ from typing import Dict, Tuple
24
+
25
+ from .file_utils import CONFIG_NAME, cached_path, hf_bucket_url, is_remote_url
26
+
27
+
28
+ logger = logging.getLogger(__name__)
29
+
30
+
31
+ class PretrainedConfig(object):
32
+ r""" Base class for all configuration classes.
33
+ Handles a few parameters common to all models' configurations as well as methods for loading/downloading/saving configurations.
34
+
35
+ Note:
36
+ A configuration file can be loaded and saved to disk. Loading the configuration file and using this file to initialize a model does **not** load the model weights.
37
+ It only affects the model's configuration.
38
+
39
+ Class attributes (overridden by derived classes):
40
+ - ``model_type``: a string that identifies the model type, that we serialize into the JSON file, and that we use to recreate the correct object in :class:`~transformers.AutoConfig`.
41
+
42
+ Args:
43
+ finetuning_task (:obj:`string` or :obj:`None`, `optional`, defaults to :obj:`None`):
44
+ Name of the task used to fine-tune the model. This can be used when converting from an original (TensorFlow or PyTorch) checkpoint.
45
+ num_labels (:obj:`int`, `optional`, defaults to `2`):
46
+ Number of classes to use when the model is a classification model (sequences/tokens)
47
+ output_hidden_states (:obj:`bool`, `optional`, defaults to :obj:`False`):
48
+ Should the model returns all hidden-states.
49
+ output_attentions (:obj:`bool`, `optional`, defaults to :obj:`False`):
50
+ Should the model returns all attentions.
51
+ torchscript (:obj:`bool`, `optional`, defaults to :obj:`False`):
52
+ Is the model used with Torchscript (for PyTorch models).
53
+ """
54
+ model_type: str = ""
55
+
56
+ def __init__(self, **kwargs):
57
+ # Attributes with defaults
58
+ self.output_hidden_states = kwargs.pop("output_hidden_states", False)
59
+ self.output_attentions = kwargs.pop("output_attentions", False)
60
+ self.use_cache = kwargs.pop("use_cache", True) # Not used by all models
61
+ self.torchscript = kwargs.pop("torchscript", False) # Only used by PyTorch models
62
+ self.use_bfloat16 = kwargs.pop("use_bfloat16", False)
63
+ self.pruned_heads = kwargs.pop("pruned_heads", {})
64
+
65
+ # Is decoder is used in encoder-decoder models to differentiate encoder from decoder
66
+ self.is_encoder_decoder = kwargs.pop("is_encoder_decoder", False)
67
+ self.is_decoder = kwargs.pop("is_decoder", False)
68
+
69
+ # Parameters for sequence generation
70
+ self.max_length = kwargs.pop("max_length", 20)
71
+ self.min_length = kwargs.pop("min_length", 0)
72
+ self.do_sample = kwargs.pop("do_sample", False)
73
+ self.early_stopping = kwargs.pop("early_stopping", False)
74
+ self.num_beams = kwargs.pop("num_beams", 1)
75
+ self.temperature = kwargs.pop("temperature", 1.0)
76
+ self.top_k = kwargs.pop("top_k", 50)
77
+ self.top_p = kwargs.pop("top_p", 1.0)
78
+ self.repetition_penalty = kwargs.pop("repetition_penalty", 1.0)
79
+ self.length_penalty = kwargs.pop("length_penalty", 1.0)
80
+ self.no_repeat_ngram_size = kwargs.pop("no_repeat_ngram_size", 0)
81
+ self.bad_words_ids = kwargs.pop("bad_words_ids", None)
82
+ self.num_return_sequences = kwargs.pop("num_return_sequences", 1)
83
+
84
+ # Fine-tuning task arguments
85
+ self.architectures = kwargs.pop("architectures", None)
86
+ self.finetuning_task = kwargs.pop("finetuning_task", None)
87
+ self.id2label = kwargs.pop("id2label", None)
88
+ self.label2id = kwargs.pop("label2id", None)
89
+ if self.id2label is not None:
90
+ kwargs.pop("num_labels", None)
91
+ self.id2label = dict((int(key), value) for key, value in self.id2label.items())
92
+ # Keys are always strings in JSON so convert ids to int here.
93
+ else:
94
+ self.num_labels = kwargs.pop("num_labels", 2)
95
+
96
+ # Tokenizer arguments TODO: eventually tokenizer and models should share the same config
97
+ self.prefix = kwargs.pop("prefix", None)
98
+ self.bos_token_id = kwargs.pop("bos_token_id", None)
99
+ self.pad_token_id = kwargs.pop("pad_token_id", None)
100
+ self.eos_token_id = kwargs.pop("eos_token_id", None)
101
+ self.decoder_start_token_id = kwargs.pop("decoder_start_token_id", None)
102
+
103
+ # task specific arguments
104
+ self.task_specific_params = kwargs.pop("task_specific_params", None)
105
+
106
+ # TPU arguments
107
+ self.xla_device = kwargs.pop("xla_device", None)
108
+
109
+ # Additional attributes without default values
110
+ for key, value in kwargs.items():
111
+ try:
112
+ setattr(self, key, value)
113
+ except AttributeError as err:
114
+ logger.error("Can't set {} with value {} for {}".format(key, value, self))
115
+ raise err
116
+
117
+ @property
118
+ def num_labels(self):
119
+ return len(self.id2label)
120
+
121
+ @num_labels.setter
122
+ def num_labels(self, num_labels):
123
+ self.id2label = {i: "LABEL_{}".format(i) for i in range(num_labels)}
124
+ self.label2id = dict(zip(self.id2label.values(), self.id2label.keys()))
125
+
126
+ def save_pretrained(self, save_directory):
127
+ """
128
+ Save a configuration object to the directory `save_directory`, so that it
129
+ can be re-loaded using the :func:`~transformers.PretrainedConfig.from_pretrained` class method.
130
+
131
+ Args:
132
+ save_directory (:obj:`string`):
133
+ Directory where the configuration JSON file will be saved.
134
+ """
135
+ if os.path.isfile(save_directory):
136
+ raise AssertionError("Provided path ({}) should be a directory, not a file".format(save_directory))
137
+ os.makedirs(save_directory, exist_ok=True)
138
+ # If we save using the predefined names, we can load using `from_pretrained`
139
+ output_config_file = os.path.join(save_directory, CONFIG_NAME)
140
+
141
+ self.to_json_file(output_config_file, use_diff=True)
142
+ logger.info("Configuration saved in {}".format(output_config_file))
143
+
144
+ @classmethod
145
+ def from_pretrained(cls, pretrained_model_name_or_path, **kwargs) -> "PretrainedConfig":
146
+ r"""
147
+
148
+ Instantiate a :class:`~transformers.PretrainedConfig` (or a derived class) from a pre-trained model configuration.
149
+
150
+ Args:
151
+ pretrained_model_name_or_path (:obj:`string`):
152
+ either:
153
+ - a string with the `shortcut name` of a pre-trained model configuration to load from cache or
154
+ download, e.g.: ``bert-base-uncased``.
155
+ - a string with the `identifier name` of a pre-trained model configuration that was user-uploaded to
156
+ our S3, e.g.: ``dbmdz/bert-base-german-cased``.
157
+ - a path to a `directory` containing a configuration file saved using the
158
+ :func:`~transformers.PretrainedConfig.save_pretrained` method, e.g.: ``./my_model_directory/``.
159
+ - a path or url to a saved configuration JSON `file`, e.g.:
160
+ ``./my_model_directory/configuration.json``.
161
+ cache_dir (:obj:`string`, `optional`):
162
+ Path to a directory in which a downloaded pre-trained model
163
+ configuration should be cached if the standard cache should not be used.
164
+ kwargs (:obj:`Dict[str, any]`, `optional`):
165
+ The values in kwargs of any keys which are configuration attributes will be used to override the loaded
166
+ values. Behavior concerning key/value pairs whose keys are *not* configuration attributes is
167
+ controlled by the `return_unused_kwargs` keyword parameter.
168
+ force_download (:obj:`bool`, `optional`, defaults to :obj:`False`):
169
+ Force to (re-)download the model weights and configuration files and override the cached versions if they exist.
170
+ resume_download (:obj:`bool`, `optional`, defaults to :obj:`False`):
171
+ Do not delete incompletely recieved file. Attempt to resume the download if such a file exists.
172
+ proxies (:obj:`Dict`, `optional`):
173
+ A dictionary of proxy servers to use by protocol or endpoint, e.g.:
174
+ :obj:`{'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}.`
175
+ The proxies are used on each request.
176
+ return_unused_kwargs: (`optional`) bool:
177
+ If False, then this function returns just the final configuration object.
178
+ If True, then this functions returns a :obj:`Tuple(config, unused_kwargs)` where `unused_kwargs` is a
179
+ dictionary consisting of the key/value pairs whose keys are not configuration attributes: ie the part
180
+ of kwargs which has not been used to update `config` and is otherwise ignored.
181
+
182
+ Returns:
183
+ :class:`PretrainedConfig`: An instance of a configuration object
184
+
185
+ Examples::
186
+
187
+ # We can't instantiate directly the base class `PretrainedConfig` so let's show the examples on a
188
+ # derived class: BertConfig
189
+ config = BertConfig.from_pretrained('bert-base-uncased') # Download configuration from S3 and cache.
190
+ config = BertConfig.from_pretrained('./test/saved_model/') # E.g. config (or model) was saved using `save_pretrained('./test/saved_model/')`
191
+ config = BertConfig.from_pretrained('./test/saved_model/my_configuration.json')
192
+ config = BertConfig.from_pretrained('bert-base-uncased', output_attention=True, foo=False)
193
+ assert config.output_attention == True
194
+ config, unused_kwargs = BertConfig.from_pretrained('bert-base-uncased', output_attention=True,
195
+ foo=False, return_unused_kwargs=True)
196
+ assert config.output_attention == True
197
+ assert unused_kwargs == {'foo': False}
198
+
199
+ """
200
+ config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)
201
+ return cls.from_dict(config_dict, **kwargs)
202
+
203
+ @classmethod
204
+ def get_config_dict(cls, pretrained_model_name_or_path: str, **kwargs) -> Tuple[Dict, Dict]:
205
+ """
206
+ From a `pretrained_model_name_or_path`, resolve to a dictionary of parameters, to be used
207
+ for instantiating a Config using `from_dict`.
208
+
209
+ Parameters:
210
+ pretrained_model_name_or_path (:obj:`string`):
211
+ The identifier of the pre-trained checkpoint from which we want the dictionary of parameters.
212
+
213
+ Returns:
214
+ :obj:`Tuple[Dict, Dict]`: The dictionary that will be used to instantiate the configuration object.
215
+
216
+ """
217
+ cache_dir = kwargs.pop("cache_dir", None)
218
+ force_download = kwargs.pop("force_download", False)
219
+ resume_download = kwargs.pop("resume_download", False)
220
+ proxies = kwargs.pop("proxies", None)
221
+ local_files_only = kwargs.pop("local_files_only", False)
222
+
223
+ if os.path.isdir(pretrained_model_name_or_path):
224
+ config_file = os.path.join(pretrained_model_name_or_path, CONFIG_NAME)
225
+ elif os.path.isfile(pretrained_model_name_or_path) or is_remote_url(pretrained_model_name_or_path):
226
+ config_file = pretrained_model_name_or_path
227
+ else:
228
+ config_file = hf_bucket_url(pretrained_model_name_or_path, filename=CONFIG_NAME, use_cdn=False)
229
+
230
+ try:
231
+ # Load from URL or cache if already cached
232
+ resolved_config_file = cached_path(
233
+ config_file,
234
+ cache_dir=cache_dir,
235
+ force_download=force_download,
236
+ proxies=proxies,
237
+ resume_download=resume_download,
238
+ local_files_only=local_files_only,
239
+ )
240
+ # Load config dict
241
+ if resolved_config_file is None:
242
+ raise EnvironmentError
243
+ config_dict = cls._dict_from_json_file(resolved_config_file)
244
+
245
+ except EnvironmentError:
246
+ msg = (
247
+ f"Can't load config for '{pretrained_model_name_or_path}'. Make sure that:\n\n"
248
+ f"- '{pretrained_model_name_or_path}' is a correct model identifier listed on 'https://huggingface.co/models'\n\n"
249
+ f"- or '{pretrained_model_name_or_path}' is the correct path to a directory containing a {CONFIG_NAME} file\n\n"
250
+ )
251
+ raise EnvironmentError(msg)
252
+
253
+ except json.JSONDecodeError:
254
+ msg = (
255
+ "Couldn't reach server at '{}' to download configuration file or "
256
+ "configuration file is not a valid JSON file. "
257
+ "Please check network or file content here: {}.".format(config_file, resolved_config_file)
258
+ )
259
+ raise EnvironmentError(msg)
260
+
261
+ if resolved_config_file == config_file:
262
+ logger.info("loading configuration file {}".format(config_file))
263
+ else:
264
+ logger.info("loading configuration file {} from cache at {}".format(config_file, resolved_config_file))
265
+
266
+ return config_dict, kwargs
267
+
268
+ @classmethod
269
+ def from_dict(cls, config_dict: Dict, **kwargs) -> "PretrainedConfig":
270
+ """
271
+ Constructs a `Config` from a Python dictionary of parameters.
272
+
273
+ Args:
274
+ config_dict (:obj:`Dict[str, any]`):
275
+ Dictionary that will be used to instantiate the configuration object. Such a dictionary can be retrieved
276
+ from a pre-trained checkpoint by leveraging the :func:`~transformers.PretrainedConfig.get_config_dict`
277
+ method.
278
+ kwargs (:obj:`Dict[str, any]`):
279
+ Additional parameters from which to initialize the configuration object.
280
+
281
+ Returns:
282
+ :class:`PretrainedConfig`: An instance of a configuration object
283
+ """
284
+ return_unused_kwargs = kwargs.pop("return_unused_kwargs", False)
285
+
286
+ config = cls(**config_dict)
287
+
288
+ if hasattr(config, "pruned_heads"):
289
+ config.pruned_heads = dict((int(key), value) for key, value in config.pruned_heads.items())
290
+
291
+ # Update config with kwargs if needed
292
+ to_remove = []
293
+ for key, value in kwargs.items():
294
+ if hasattr(config, key):
295
+ setattr(config, key, value)
296
+ to_remove.append(key)
297
+ for key in to_remove:
298
+ kwargs.pop(key, None)
299
+
300
+ logger.info("Model config %s", str(config))
301
+ if return_unused_kwargs:
302
+ return config, kwargs
303
+ else:
304
+ return config
305
+
306
+ @classmethod
307
+ def from_json_file(cls, json_file: str) -> "PretrainedConfig":
308
+ """
309
+ Constructs a `Config` from the path to a json file of parameters.
310
+
311
+ Args:
312
+ json_file (:obj:`string`):
313
+ Path to the JSON file containing the parameters.
314
+
315
+ Returns:
316
+ :class:`PretrainedConfig`: An instance of a configuration object
317
+
318
+ """
319
+ config_dict = cls._dict_from_json_file(json_file)
320
+ return cls(**config_dict)
321
+
322
+ @classmethod
323
+ def _dict_from_json_file(cls, json_file: str):
324
+ with open(json_file, "r", encoding="utf-8") as reader:
325
+ text = reader.read()
326
+ return json.loads(text)
327
+
328
+ def __eq__(self, other):
329
+ return self.__dict__ == other.__dict__
330
+
331
+ def __repr__(self):
332
+ return "{} {}".format(self.__class__.__name__, self.to_json_string())
333
+
334
+ def to_diff_dict(self):
335
+ """
336
+ Removes all attributes from config which correspond to the default
337
+ config attributes for better readability and serializes to a Python
338
+ dictionary.
339
+
340
+ Returns:
341
+ :obj:`Dict[str, any]`: Dictionary of all the attributes that make up this configuration instance,
342
+ """
343
+ config_dict = self.to_dict()
344
+
345
+ # get the default config dict
346
+ default_config_dict = PretrainedConfig().to_dict()
347
+
348
+ serializable_config_dict = {}
349
+
350
+ # only serialize values that differ from the default config
351
+ for key, value in config_dict.items():
352
+ if key not in default_config_dict or value != default_config_dict[key]:
353
+ serializable_config_dict[key] = value
354
+
355
+ return serializable_config_dict
356
+
357
+ def to_dict(self):
358
+ """
359
+ Serializes this instance to a Python dictionary.
360
+
361
+ Returns:
362
+ :obj:`Dict[str, any]`: Dictionary of all the attributes that make up this configuration instance,
363
+ """
364
+ output = copy.deepcopy(self.__dict__)
365
+ if hasattr(self.__class__, "model_type"):
366
+ output["model_type"] = self.__class__.model_type
367
+ return output
368
+
369
+ def to_json_string(self, use_diff=True):
370
+ """
371
+ Serializes this instance to a JSON string.
372
+
373
+ Args:
374
+ use_diff (:obj:`bool`):
375
+ If set to True, only the difference between the config instance and the default PretrainedConfig() is serialized to JSON string.
376
+
377
+ Returns:
378
+ :obj:`string`: String containing all the attributes that make up this configuration instance in JSON format.
379
+ """
380
+ if use_diff is True:
381
+ config_dict = self.to_diff_dict()
382
+ else:
383
+ config_dict = self.to_dict()
384
+ return json.dumps(config_dict, indent=2, sort_keys=True) + "\n"
385
+
386
+ def to_json_file(self, json_file_path, use_diff=True):
387
+ """
388
+ Save this instance to a json file.
389
+
390
+ Args:
391
+ json_file_path (:obj:`string`):
392
+ Path to the JSON file in which this configuration instance's parameters will be saved.
393
+ use_diff (:obj:`bool`):
394
+ If set to True, only the difference between the config instance and the default PretrainedConfig() is serialized to JSON file.
395
+ """
396
+ with open(json_file_path, "w", encoding="utf-8") as writer:
397
+ writer.write(self.to_json_string(use_diff=use_diff))
398
+
399
+ def update(self, config_dict: Dict):
400
+ """
401
+ Updates attributes of this class
402
+ with attributes from `config_dict`.
403
+
404
+ Args:
405
+ :obj:`Dict[str, any]`: Dictionary of attributes that shall be updated for this class.
406
+ """
407
+ for key, value in config_dict.items():
408
+ setattr(self, key, value)
elia/bert/file_utils.py ADDED
@@ -0,0 +1,808 @@
1
+ """
2
+ Utilities for working with the local dataset cache.
3
+ This file is adapted from the AllenNLP library at https://github.com/allenai/allennlp
4
+ Copyright by the AllenNLP authors.
5
+ """
6
+
7
+ import fnmatch
8
+ import json
9
+ import logging
10
+ import os
11
+ import shutil
12
+ import sys
13
+ import tarfile
14
+ import tempfile
15
+ from contextlib import contextmanager
16
+ from functools import partial, wraps
17
+ from hashlib import sha256
18
+ from pathlib import Path
19
+ from typing import Dict, Optional, Union
20
+ from urllib.parse import urlparse
21
+ from zipfile import ZipFile, is_zipfile
22
+
23
+ import requests
24
+ from filelock import FileLock
25
+ from tqdm.auto import tqdm
26
+
27
+ #from . import __version__
28
+ __version__ = "3.0.2"
29
+
30
+ logger = logging.getLogger(__name__) # pylint: disable=invalid-name
31
+
32
+ try:
33
+ USE_TF = os.environ.get("USE_TF", "AUTO").upper()
34
+ USE_TORCH = os.environ.get("USE_TORCH", "AUTO").upper()
35
+ if USE_TORCH in ("1", "ON", "YES", "AUTO") and USE_TF not in ("1", "ON", "YES"):
36
+ import torch
37
+
38
+ _torch_available = True # pylint: disable=invalid-name
39
+ logger.info("PyTorch version {} available.".format(torch.__version__))
40
+ else:
41
+ logger.info("Disabling PyTorch because USE_TF is set")
42
+ _torch_available = False
43
+ except ImportError:
44
+ _torch_available = False # pylint: disable=invalid-name
45
+
46
+ try:
47
+ USE_TF = os.environ.get("USE_TF", "AUTO").upper()
48
+ USE_TORCH = os.environ.get("USE_TORCH", "AUTO").upper()
49
+
50
+ if USE_TF in ("1", "ON", "YES", "AUTO") and USE_TORCH not in ("1", "ON", "YES"):
51
+ import tensorflow as tf
52
+
53
+ assert hasattr(tf, "__version__") and int(tf.__version__[0]) >= 2
54
+ _tf_available = True # pylint: disable=invalid-name
55
+ logger.info("TensorFlow version {} available.".format(tf.__version__))
56
+ else:
57
+ logger.info("Disabling Tensorflow because USE_TORCH is set")
58
+ _tf_available = False
59
+ except (ImportError, AssertionError):
60
+ _tf_available = False # pylint: disable=invalid-name
61
+
62
+
63
+ try:
64
+ from torch.hub import _get_torch_home
65
+
66
+ torch_cache_home = _get_torch_home()
67
+ except ImportError:
68
+ torch_cache_home = os.path.expanduser(
69
+ os.getenv("TORCH_HOME", os.path.join(os.getenv("XDG_CACHE_HOME", "~/.cache"), "torch"))
70
+ )
71
+
72
+
73
+ try:
74
+ import torch_xla.core.xla_model as xm # noqa: F401
75
+
76
+ if _torch_available:
77
+ _torch_tpu_available = True # pylint: disable=
78
+ else:
79
+ _torch_tpu_available = False
80
+ except ImportError:
81
+ _torch_tpu_available = False
82
+
83
+
84
+ try:
85
+ import psutil # noqa: F401
86
+
87
+ _psutil_available = True
88
+
89
+ except ImportError:
90
+ _psutil_available = False
91
+
92
+
93
+ try:
94
+ import py3nvml # noqa: F401
95
+
96
+ _py3nvml_available = True
97
+
98
+ except ImportError:
99
+ _py3nvml_available = False
100
+
101
+
102
+ try:
103
+ from apex import amp # noqa: F401
104
+
105
+ _has_apex = True
106
+ except ImportError:
107
+ _has_apex = False
108
+
109
+ default_cache_path = os.path.join(torch_cache_home, "transformers")
110
+
111
+
112
+ PYTORCH_PRETRAINED_BERT_CACHE = os.getenv("PYTORCH_PRETRAINED_BERT_CACHE", default_cache_path)
113
+ PYTORCH_TRANSFORMERS_CACHE = os.getenv("PYTORCH_TRANSFORMERS_CACHE", PYTORCH_PRETRAINED_BERT_CACHE)
114
+ TRANSFORMERS_CACHE = os.getenv("TRANSFORMERS_CACHE", PYTORCH_TRANSFORMERS_CACHE)
115
+
116
+ WEIGHTS_NAME = "pytorch_model.bin"
117
+ TF2_WEIGHTS_NAME = "tf_model.h5"
118
+ TF_WEIGHTS_NAME = "model.ckpt"
119
+ CONFIG_NAME = "config.json"
120
+ MODEL_CARD_NAME = "modelcard.json"
121
+
122
+
123
+ MULTIPLE_CHOICE_DUMMY_INPUTS = [[[0], [1]], [[0], [1]]]
124
+ DUMMY_INPUTS = [[7, 6, 0, 0, 1], [1, 2, 3, 0, 0], [0, 0, 0, 4, 5]]
125
+ DUMMY_MASK = [[1, 1, 1, 1, 1], [1, 1, 1, 0, 0], [0, 0, 0, 1, 1]]
126
+
127
+ S3_BUCKET_PREFIX = "https://s3.amazonaws.com/models.huggingface.co/bert"
128
+ CLOUDFRONT_DISTRIB_PREFIX = "https://cdn.huggingface.co"
129
+
130
+
131
+ def is_torch_available():
132
+ return _torch_available
133
+
134
+
135
+ def is_tf_available():
136
+ return _tf_available
137
+
138
+
139
+ def is_torch_tpu_available():
140
+ return _torch_tpu_available
141
+
142
+
143
+ def is_psutil_available():
144
+ return _psutil_available
145
+
146
+
147
+ def is_py3nvml_available():
148
+ return _py3nvml_available
149
+
150
+
151
+ def is_apex_available():
152
+ return _has_apex
153
+
154
+
155
+ def add_start_docstrings(*docstr):
156
+ def docstring_decorator(fn):
157
+ fn.__doc__ = "".join(docstr) + (fn.__doc__ if fn.__doc__ is not None else "")
158
+ return fn
159
+
160
+ return docstring_decorator
161
+
162
+
163
+ def add_start_docstrings_to_callable(*docstr):
164
+ def docstring_decorator(fn):
165
+ class_name = ":class:`~transformers.{}`".format(fn.__qualname__.split(".")[0])
166
+ intro = " The {} forward method, overrides the :func:`__call__` special method.".format(class_name)
167
+ note = r"""
168
+
169
+ .. note::
170
+ Although the recipe for the forward pass needs to be defined within
171
+ this function, one should call the :class:`Module` instance afterwards
172
+ instead of this since the former takes care of running the
173
+ pre and post processing steps while the latter silently ignores them.
174
+ """
175
+ fn.__doc__ = intro + note + "".join(docstr) + (fn.__doc__ if fn.__doc__ is not None else "")
176
+ return fn
177
+
178
+ return docstring_decorator
179
+
180
+
181
+ def add_end_docstrings(*docstr):
182
+ def docstring_decorator(fn):
183
+ fn.__doc__ = fn.__doc__ + "".join(docstr)
184
+ return fn
185
+
186
+ return docstring_decorator
187
+
188
+
189
+ PT_TOKEN_CLASSIFICATION_SAMPLE = r"""
190
+ Example::
191
+
192
+ >>> from transformers import {tokenizer_class}, {model_class}
193
+ >>> import torch
194
+
195
+ >>> tokenizer = {tokenizer_class}.from_pretrained('{checkpoint}')
196
+ >>> model = {model_class}.from_pretrained('{checkpoint}')
197
+
198
+ >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
199
+ >>> labels = torch.tensor([1] * inputs["input_ids"].size(1)).unsqueeze(0) # Batch size 1
200
+
201
+ >>> outputs = model(**inputs, labels=labels)
202
+ >>> loss, scores = outputs[:2]
203
+ """
204
+
205
+ PT_QUESTION_ANSWERING_SAMPLE = r"""
206
+ Example::
207
+
208
+ >>> from transformers import {tokenizer_class}, {model_class}
209
+ >>> import torch
210
+
211
+ >>> tokenizer = {tokenizer_class}.from_pretrained('{checkpoint}')
212
+ >>> model = {model_class}.from_pretrained('{checkpoint}')
213
+
214
+ >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
215
+ >>> start_positions = torch.tensor([1])
216
+ >>> end_positions = torch.tensor([3])
217
+
218
+ >>> outputs = model(**inputs, start_positions=start_positions, end_positions=end_positions)
219
+ >>> loss, start_scores, end_scores = outputs[:3]
220
+ """
221
+
222
+ PT_SEQUENCE_CLASSIFICATION_SAMPLE = r"""
223
+ Example::
224
+
225
+ >>> from transformers import {tokenizer_class}, {model_class}
226
+ >>> import torch
227
+
228
+ >>> tokenizer = {tokenizer_class}.from_pretrained('{checkpoint}')
229
+ >>> model = {model_class}.from_pretrained('{checkpoint}')
230
+
231
+ >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
232
+ >>> labels = torch.tensor([1]).unsqueeze(0) # Batch size 1
233
+ >>> outputs = model(**inputs, labels=labels)
234
+ >>> loss, logits = outputs[:2]
235
+ """
236
+
237
+ PT_MASKED_LM_SAMPLE = r"""
238
+ Example::
239
+
240
+ >>> from transformers import {tokenizer_class}, {model_class}
241
+ >>> import torch
242
+
243
+ >>> tokenizer = {tokenizer_class}.from_pretrained('{checkpoint}')
244
+ >>> model = {model_class}.from_pretrained('{checkpoint}')
245
+
246
+ >>> input_ids = tokenizer("Hello, my dog is cute", return_tensors="pt")["input_ids"]
247
+
248
+ >>> outputs = model(input_ids, labels=input_ids)
249
+ >>> loss, prediction_scores = outputs[:2]
250
+ """
251
+
252
+ PT_BASE_MODEL_SAMPLE = r"""
253
+ Example::
254
+
255
+ >>> from transformers import {tokenizer_class}, {model_class}
256
+ >>> import torch
257
+
258
+ >>> tokenizer = {tokenizer_class}.from_pretrained('{checkpoint}')
259
+ >>> model = {model_class}.from_pretrained('{checkpoint}')
260
+
261
+ >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
262
+ >>> outputs = model(**inputs)
263
+
264
+ >>> last_hidden_states = outputs[0] # The last hidden-state is the first element of the output tuple
265
+ """
266
+
267
+ PT_MULTIPLE_CHOICE_SAMPLE = r"""
268
+ Example::
269
+
270
+ >>> from transformers import {tokenizer_class}, {model_class}
271
+ >>> import torch
272
+
273
+ >>> tokenizer = {tokenizer_class}.from_pretrained('{checkpoint}')
274
+ >>> model = {model_class}.from_pretrained('{checkpoint}')
275
+
276
+ >>> prompt = "In Italy, pizza served in formal settings, such as at a restaurant, is presented unsliced."
277
+ >>> choice0 = "It is eaten with a fork and a knife."
278
+ >>> choice1 = "It is eaten while held in the hand."
279
+ >>> labels = torch.tensor(0).unsqueeze(0) # choice0 is correct (according to Wikipedia ;)), batch size 1
280
+
281
+ >>> encoding = tokenizer([[prompt, prompt], [choice0, choice1]], return_tensors='pt', padding=True)
282
+ >>> outputs = model(**{{k: v.unsqueeze(0) for k,v in encoding.items()}}, labels=labels) # batch size is 1
283
+
284
+ >>> # the linear classifier still needs to be trained
285
+ >>> loss, logits = outputs[:2]
286
+ """
287
+
288
+ PT_CAUSAL_LM_SAMPLE = r"""
289
+ Example::
290
+
291
+ >>> import torch
292
+ >>> from transformers import {tokenizer_class}, {model_class}
293
+
294
+ >>> tokenizer = {tokenizer_class}.from_pretrained('{checkpoint}')
295
+ >>> model = {model_class}.from_pretrained('{checkpoint}')
296
+
297
+ >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
298
+ >>> outputs = model(**inputs, labels=inputs["input_ids"])
299
+ >>> loss, logits = outputs[:2]
300
+ """
301
+
302
+ TF_TOKEN_CLASSIFICATION_SAMPLE = r"""
303
+ Example::
304
+
305
+ >>> from transformers import {tokenizer_class}, {model_class}
306
+ >>> import tensorflow as tf
307
+
308
+ >>> tokenizer = {tokenizer_class}.from_pretrained('{checkpoint}')
309
+ >>> model = {model_class}.from_pretrained('{checkpoint}')
310
+
311
+ >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="tf")
312
+ >>> input_ids = inputs["input_ids"]
313
+ >>> inputs["labels"] = tf.reshape(tf.constant([1] * tf.size(input_ids).numpy()), (-1, tf.size(input_ids))) # Batch size 1
314
+
315
+ >>> outputs = model(inputs)
316
+ >>> loss, scores = outputs[:2]
317
+ """
318
+
319
+ TF_QUESTION_ANSWERING_SAMPLE = r"""
320
+ Example::
321
+
322
+ >>> from transformers import {tokenizer_class}, {model_class}
323
+ >>> import tensorflow as tf
324
+
325
+ >>> tokenizer = {tokenizer_class}.from_pretrained('{checkpoint}')
326
+ >>> model = {model_class}.from_pretrained('{checkpoint}')
327
+
328
+ >>> question, text = "Who was Jim Henson?", "Jim Henson was a nice puppet"
329
+ >>> input_dict = tokenizer(question, text, return_tensors='tf')
330
+ >>> start_scores, end_scores = model(input_dict)
331
+
332
+ >>> all_tokens = tokenizer.convert_ids_to_tokens(input_dict["input_ids"].numpy()[0])
333
+ >>> answer = ' '.join(all_tokens[tf.math.argmax(start_scores, 1)[0] : tf.math.argmax(end_scores, 1)[0]+1])
334
+ """
335
+
336
+ TF_SEQUENCE_CLASSIFICATION_SAMPLE = r"""
337
+ Example::
338
+
339
+ >>> from transformers import {tokenizer_class}, {model_class}
340
+ >>> import tensorflow as tf
341
+
342
+ >>> tokenizer = {tokenizer_class}.from_pretrained('{checkpoint}')
343
+ >>> model = {model_class}.from_pretrained('{checkpoint}')
344
+
345
+ >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="tf")
346
+ >>> inputs["labels"] = tf.reshape(tf.constant(1), (-1, 1)) # Batch size 1
347
+
348
+ >>> outputs = model(inputs)
349
+ >>> loss, logits = outputs[:2]
350
+ """
351
+
352
+ TF_MASKED_LM_SAMPLE = r"""
353
+ Example::
354
+ >>> from transformers import {tokenizer_class}, {model_class}
355
+ >>> import tensorflow as tf
356
+
357
+ >>> tokenizer = {tokenizer_class}.from_pretrained('{checkpoint}')
358
+ >>> model = {model_class}.from_pretrained('{checkpoint}')
359
+
360
+ >>> input_ids = tf.constant(tokenizer.encode("Hello, my dog is cute", add_special_tokens=True))[None, :] # Batch size 1
361
+
362
+ >>> outputs = model(input_ids)
363
+ >>> prediction_scores = outputs[0]
364
+ """
365
+
366
+ TF_BASE_MODEL_SAMPLE = r"""
367
+ Example::
368
+
369
+ >>> from transformers import {tokenizer_class}, {model_class}
370
+ >>> import tensorflow as tf
371
+
372
+ >>> tokenizer = {tokenizer_class}.from_pretrained('{checkpoint}')
373
+ >>> model = {model_class}.from_pretrained('{checkpoint}')
374
+
375
+ >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="tf")
376
+ >>> outputs = model(inputs)
377
+
378
+ >>> last_hidden_states = outputs[0] # The last hidden-state is the first element of the output tuple
379
+ """
380
+
381
+ TF_MULTIPLE_CHOICE_SAMPLE = r"""
382
+ Example::
383
+
384
+ >>> from transformers import {tokenizer_class}, {model_class}
385
+ >>> import tensorflow as tf
386
+
387
+ >>> tokenizer = {tokenizer_class}.from_pretrained('{checkpoint}')
388
+ >>> model = {model_class}.from_pretrained('{checkpoint}')
389
+
390
+ >>> prompt = "In Italy, pizza served in formal settings, such as at a restaurant, is presented unsliced."
391
+ >>> choice0 = "It is eaten with a fork and a knife."
392
+ >>> choice1 = "It is eaten while held in the hand."
393
+
394
+ >>> encoding = tokenizer([[prompt, prompt], [choice0, choice1]], return_tensors='tf', padding=True)
395
+ >>> inputs = {{k: tf.expand_dims(v, 0) for k, v in encoding.items()}}
396
+ >>> outputs = model(inputs) # batch size is 1
397
+
398
+ >>> # the linear classifier still needs to be trained
399
+ >>> logits = outputs[0]
400
+ """
401
+
402
+ TF_CAUSAL_LM_SAMPLE = r"""
403
+ Example::
404
+
405
+ >>> from transformers import {tokenizer_class}, {model_class}
406
+ >>> import tensorflow as tf
407
+
408
+ >>> tokenizer = {tokenizer_class}.from_pretrained('{checkpoint}')
409
+ >>> model = {model_class}.from_pretrained('{checkpoint}')
410
+
411
+ >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="tf")
412
+ >>> outputs = model(inputs)
413
+ >>> logits = outputs[0]
414
+ """
415
+
416
+
417
+ def add_code_sample_docstrings(*docstr, tokenizer_class=None, checkpoint=None):
418
+ def docstring_decorator(fn):
419
+ model_class = fn.__qualname__.split(".")[0]
420
+ is_tf_class = model_class[:2] == "TF"
421
+
422
+ if "SequenceClassification" in model_class:
423
+ code_sample = TF_SEQUENCE_CLASSIFICATION_SAMPLE if is_tf_class else PT_SEQUENCE_CLASSIFICATION_SAMPLE
424
+ elif "QuestionAnswering" in model_class:
425
+ code_sample = TF_QUESTION_ANSWERING_SAMPLE if is_tf_class else PT_QUESTION_ANSWERING_SAMPLE
426
+ elif "TokenClassification" in model_class:
427
+ code_sample = TF_TOKEN_CLASSIFICATION_SAMPLE if is_tf_class else PT_TOKEN_CLASSIFICATION_SAMPLE
428
+ elif "MultipleChoice" in model_class:
429
+ code_sample = TF_MULTIPLE_CHOICE_SAMPLE if is_tf_class else PT_MULTIPLE_CHOICE_SAMPLE
430
+ elif "MaskedLM" in model_class:
431
+ code_sample = TF_MASKED_LM_SAMPLE if is_tf_class else PT_MASKED_LM_SAMPLE
432
+ elif "LMHead" in model_class:
433
+ code_sample = TF_CAUSAL_LM_SAMPLE if is_tf_class else PT_CAUSAL_LM_SAMPLE
434
+ elif "Model" in model_class:
435
+ code_sample = TF_BASE_MODEL_SAMPLE if is_tf_class else PT_BASE_MODEL_SAMPLE
436
+ else:
437
+ raise ValueError(f"Docstring can't be built for model {model_class}")
438
+
439
+ built_doc = code_sample.format(model_class=model_class, tokenizer_class=tokenizer_class, checkpoint=checkpoint)
440
+ fn.__doc__ = (fn.__doc__ or "") + "".join(docstr) + built_doc
441
+ return fn
442
+
443
+ return docstring_decorator
444
+
445
+
446
+ def is_remote_url(url_or_filename):
447
+ parsed = urlparse(url_or_filename)
448
+ return parsed.scheme in ("http", "https")
449
+
450
+
451
+ def hf_bucket_url(model_id: str, filename: str, use_cdn=True) -> str:
452
+ """
453
+ Resolve a model identifier, and a file name, to a HF-hosted url
454
+ on either S3 or Cloudfront (a Content Delivery Network, or CDN).
455
+
456
+ Cloudfront is replicated over the globe so downloads are way faster
457
+ for the end user (and it also lowers our bandwidth costs). However, it
458
+ is more aggressively cached by default, so may not always reflect the
459
+ latest changes to the underlying file (default TTL is 24 hours).
460
+
461
+ In terms of client-side caching from this library, even though
462
+ Cloudfront relays the ETags from S3, using one or the other
463
+ (or switching from one to the other) will affect caching: cached files
464
+ are not shared between the two because the cached file's name contains
465
+ a hash of the url.
466
+ """
467
+ endpoint = CLOUDFRONT_DISTRIB_PREFIX if use_cdn else S3_BUCKET_PREFIX
468
+ legacy_format = "/" not in model_id
469
+ if legacy_format:
470
+ return f"{endpoint}/{model_id}-{filename}"
471
+ else:
472
+ return f"{endpoint}/{model_id}/{filename}"
473
+
474
+
475
+ def url_to_filename(url, etag=None):
476
+ """
477
+ Convert `url` into a hashed filename in a repeatable way.
478
+ If `etag` is specified, append its hash to the url's, delimited
479
+ by a period.
480
+ If the url ends with .h5 (Keras HDF5 weights), '.h5' is added to the name
481
+ so that TF 2.0 can identify it as a HDF5 file
482
+ (see https://github.com/tensorflow/tensorflow/blob/00fad90125b18b80fe054de1055770cfb8fe4ba3/tensorflow/python/keras/engine/network.py#L1380)
483
+ """
484
+ url_bytes = url.encode("utf-8")
485
+ url_hash = sha256(url_bytes)
486
+ filename = url_hash.hexdigest()
487
+
488
+ if etag:
489
+ etag_bytes = etag.encode("utf-8")
490
+ etag_hash = sha256(etag_bytes)
491
+ filename += "." + etag_hash.hexdigest()
492
+
493
+ if url.endswith(".h5"):
494
+ filename += ".h5"
495
+
496
+ return filename
497
+
498
+
499
+ def filename_to_url(filename, cache_dir=None):
500
+ """
501
+ Return the url and etag (which may be ``None``) stored for `filename`.
502
+ Raise ``EnvironmentError`` if `filename` or its stored metadata do not exist.
503
+ """
504
+ if cache_dir is None:
505
+ cache_dir = TRANSFORMERS_CACHE
506
+ if isinstance(cache_dir, Path):
507
+ cache_dir = str(cache_dir)
508
+
509
+ cache_path = os.path.join(cache_dir, filename)
510
+ if not os.path.exists(cache_path):
511
+ raise EnvironmentError("file {} not found".format(cache_path))
512
+
513
+ meta_path = cache_path + ".json"
514
+ if not os.path.exists(meta_path):
515
+ raise EnvironmentError("file {} not found".format(meta_path))
516
+
517
+ with open(meta_path, encoding="utf-8") as meta_file:
518
+ metadata = json.load(meta_file)
519
+ url = metadata["url"]
520
+ etag = metadata["etag"]
521
+
522
+ return url, etag
523
+
524
+
525
+ def cached_path(
526
+ url_or_filename,
527
+ cache_dir=None,
528
+ force_download=False,
529
+ proxies=None,
530
+ resume_download=False,
531
+ user_agent: Union[Dict, str, None] = None,
532
+ extract_compressed_file=False,
533
+ force_extract=False,
534
+ local_files_only=False,
535
+ ) -> Optional[str]:
536
+ """
537
+ Given something that might be a URL (or might be a local path),
538
+ determine which. If it's a URL, download the file and cache it, and
539
+ return the path to the cached file. If it's already a local path,
540
+ make sure the file exists and then return the path.
541
+ Args:
542
+ cache_dir: specify a cache directory to save the file to (overwrite the default cache dir).
543
+ force_download: if True, re-download the file even if it's already cached in the cache dir.
544
+ resume_download: if True, resume the download if an incompletely received file is found.
545
+ user_agent: Optional string or dict that will be appended to the user-agent on remote requests.
546
+ extract_compressed_file: if True and the path points to a zip or tar file, extract the compressed
547
+ file into a folder alongside the archive.
548
+ force_extract: if True when extract_compressed_file is True and the archive was already extracted,
549
+ re-extract the archive and override the folder where it was extracted.
550
+
551
+ Return:
552
+ None in case of non-recoverable file (non-existent or inaccessible url + no cache on disk).
553
+ Local path (string) otherwise
554
+ """
555
+ if cache_dir is None:
556
+ cache_dir = TRANSFORMERS_CACHE
557
+ if isinstance(url_or_filename, Path):
558
+ url_or_filename = str(url_or_filename)
559
+ if isinstance(cache_dir, Path):
560
+ cache_dir = str(cache_dir)
561
+
562
+ if is_remote_url(url_or_filename):
563
+ # URL, so get it from the cache (downloading if necessary)
564
+ output_path = get_from_cache(
565
+ url_or_filename,
566
+ cache_dir=cache_dir,
567
+ force_download=force_download,
568
+ proxies=proxies,
569
+ resume_download=resume_download,
570
+ user_agent=user_agent,
571
+ local_files_only=local_files_only,
572
+ )
573
+ elif os.path.exists(url_or_filename):
574
+ # File, and it exists.
575
+ output_path = url_or_filename
576
+ elif urlparse(url_or_filename).scheme == "":
577
+ # File, but it doesn't exist.
578
+ raise EnvironmentError("file {} not found".format(url_or_filename))
579
+ else:
580
+ # Something unknown
581
+ raise ValueError("unable to parse {} as a URL or as a local path".format(url_or_filename))
582
+
583
+ if extract_compressed_file:
584
+ if not is_zipfile(output_path) and not tarfile.is_tarfile(output_path):
585
+ return output_path
586
+
587
+ # Path where we extract compressed archives
588
+ # We avoid '.' in dir name and add "-extracted" at the end: "./model.zip" => "./model-zip-extracted/"
589
+ output_dir, output_file = os.path.split(output_path)
590
+ output_extract_dir_name = output_file.replace(".", "-") + "-extracted"
591
+ output_path_extracted = os.path.join(output_dir, output_extract_dir_name)
592
+
593
+ if os.path.isdir(output_path_extracted) and os.listdir(output_path_extracted) and not force_extract:
594
+ return output_path_extracted
595
+
596
+ # Prevent parallel extractions
597
+ lock_path = output_path + ".lock"
598
+ with FileLock(lock_path):
599
+ shutil.rmtree(output_path_extracted, ignore_errors=True)
600
+ os.makedirs(output_path_extracted)
601
+ if is_zipfile(output_path):
602
+ with ZipFile(output_path, "r") as zip_file:
603
+ zip_file.extractall(output_path_extracted)
604
+ zip_file.close()
605
+ elif tarfile.is_tarfile(output_path):
606
+ tar_file = tarfile.open(output_path)
607
+ tar_file.extractall(output_path_extracted)
608
+ tar_file.close()
609
+ else:
610
+ raise EnvironmentError("Archive format of {} could not be identified".format(output_path))
611
+
612
+ return output_path_extracted
613
+
614
+ return output_path
615
+
616
+
617
+ def http_get(url, temp_file, proxies=None, resume_size=0, user_agent: Union[Dict, str, None] = None):
618
+ ua = "transformers/{}; python/{}".format(__version__, sys.version.split()[0])
619
+ if is_torch_available():
620
+ ua += "; torch/{}".format(torch.__version__)
621
+ if is_tf_available():
622
+ ua += "; tensorflow/{}".format(tf.__version__)
623
+ if isinstance(user_agent, dict):
624
+ ua += "; " + "; ".join("{}/{}".format(k, v) for k, v in user_agent.items())
625
+ elif isinstance(user_agent, str):
626
+ ua += "; " + user_agent
627
+ headers = {"user-agent": ua}
628
+ if resume_size > 0:
629
+ headers["Range"] = "bytes=%d-" % (resume_size,)
630
+ response = requests.get(url, stream=True, proxies=proxies, headers=headers)
631
+ if response.status_code == 416: # Range not satisfiable
632
+ return
633
+ content_length = response.headers.get("Content-Length")
634
+ total = resume_size + int(content_length) if content_length is not None else None
635
+ progress = tqdm(
636
+ unit="B",
637
+ unit_scale=True,
638
+ total=total,
639
+ initial=resume_size,
640
+ desc="Downloading",
641
+ disable=bool(logger.getEffectiveLevel() == logging.NOTSET),
642
+ )
643
+ for chunk in response.iter_content(chunk_size=1024):
644
+ if chunk: # filter out keep-alive new chunks
645
+ progress.update(len(chunk))
646
+ temp_file.write(chunk)
647
+ progress.close()
648
+
649
+
650
+ def get_from_cache(
651
+ url,
652
+ cache_dir=None,
653
+ force_download=False,
654
+ proxies=None,
655
+ etag_timeout=10,
656
+ resume_download=False,
657
+ user_agent: Union[Dict, str, None] = None,
658
+ local_files_only=False,
659
+ ) -> Optional[str]:
660
+ """
661
+ Given a URL, look for the corresponding file in the local cache.
662
+ If it's not there, download it. Then return the path to the cached file.
663
+
664
+ Return:
665
+ None in case of non-recoverable file (non-existent or inaccessible url + no cache on disk).
666
+ Local path (string) otherwise
667
+ """
668
+ if cache_dir is None:
669
+ cache_dir = TRANSFORMERS_CACHE
670
+ if isinstance(cache_dir, Path):
671
+ cache_dir = str(cache_dir)
672
+
673
+ os.makedirs(cache_dir, exist_ok=True)
674
+
675
+ etag = None
676
+ if not local_files_only:
677
+ try:
678
+ response = requests.head(url, allow_redirects=True, proxies=proxies, timeout=etag_timeout)
679
+ if response.status_code == 200:
680
+ etag = response.headers.get("ETag")
681
+ except (EnvironmentError, requests.exceptions.Timeout):
682
+ # etag is already None
683
+ pass
684
+
685
+ filename = url_to_filename(url, etag)
686
+
687
+ # get cache path to put the file
688
+ cache_path = os.path.join(cache_dir, filename)
689
+
690
+ # etag is None = we don't have a connection, or url doesn't exist, or is otherwise inaccessible.
691
+ # try to get the last downloaded one
692
+ if etag is None:
693
+ if os.path.exists(cache_path):
694
+ return cache_path
695
+ else:
696
+ matching_files = [
697
+ file
698
+ for file in fnmatch.filter(os.listdir(cache_dir), filename + ".*")
699
+ if not file.endswith(".json") and not file.endswith(".lock")
700
+ ]
701
+ if len(matching_files) > 0:
702
+ return os.path.join(cache_dir, matching_files[-1])
703
+ else:
704
+ # If files cannot be found and local_files_only=True,
705
+ # the models might've been found if local_files_only=False
706
+ # Notify the user about that
707
+ if local_files_only:
708
+ raise ValueError(
709
+ "Cannot find the requested files in the cached path and outgoing traffic has been"
710
+ " disabled. To enable model look-ups and downloads online, set 'local_files_only'"
711
+ " to False."
712
+ )
713
+ return None
714
+
715
+ # From now on, etag is not None.
716
+ if os.path.exists(cache_path) and not force_download:
717
+ return cache_path
718
+
719
+ # Prevent parallel downloads of the same file with a lock.
720
+ lock_path = cache_path + ".lock"
721
+ with FileLock(lock_path):
722
+
723
+ # If the download just completed while the lock was activated.
724
+ if os.path.exists(cache_path) and not force_download:
725
+ # Even if returning early like here, the lock will be released.
726
+ return cache_path
727
+
728
+ if resume_download:
729
+ incomplete_path = cache_path + ".incomplete"
730
+
731
+ @contextmanager
732
+ def _resumable_file_manager():
733
+ with open(incomplete_path, "a+b") as f:
734
+ yield f
735
+
736
+ temp_file_manager = _resumable_file_manager
737
+ if os.path.exists(incomplete_path):
738
+ resume_size = os.stat(incomplete_path).st_size
739
+ else:
740
+ resume_size = 0
741
+ else:
742
+ temp_file_manager = partial(tempfile.NamedTemporaryFile, dir=cache_dir, delete=False)
743
+ resume_size = 0
744
+
745
+ # Download to temporary file, then copy to cache dir once finished.
746
+ # Otherwise you get corrupt cache entries if the download gets interrupted.
747
+ with temp_file_manager() as temp_file:
748
+ logger.info("%s not found in cache or force_download set to True, downloading to %s", url, temp_file.name)
749
+
750
+ http_get(url, temp_file, proxies=proxies, resume_size=resume_size, user_agent=user_agent)
751
+
752
+ logger.info("storing %s in cache at %s", url, cache_path)
753
+ os.replace(temp_file.name, cache_path)
754
+
755
+ logger.info("creating metadata file for %s", cache_path)
756
+ meta = {"url": url, "etag": etag}
757
+ meta_path = cache_path + ".json"
758
+ with open(meta_path, "w") as meta_file:
759
+ json.dump(meta, meta_file)
760
+
761
+ return cache_path
762
+
763
+
764
+ class cached_property(property):
765
+ """
766
+ Descriptor that mimics @property but caches output in member variable.
767
+
768
+ From tensorflow_datasets
769
+
770
+ Built-in in functools from Python 3.8.
771
+ """
772
+
773
+ def __get__(self, obj, objtype=None):
774
+ # See docs.python.org/3/howto/descriptor.html#properties
775
+ if obj is None:
776
+ return self
777
+ if self.fget is None:
778
+ raise AttributeError("unreadable attribute")
779
+ attr = "__cached_" + self.fget.__name__
780
+ cached = getattr(obj, attr, None)
781
+ if cached is None:
782
+ cached = self.fget(obj)
783
+ setattr(obj, attr, cached)
784
+ return cached
785
+
786
+
787
+ def torch_required(func):
788
+ # We chose a different decorator name than in the tests so it's clear they are not the same.
789
+ @wraps(func)
790
+ def wrapper(*args, **kwargs):
791
+ if is_torch_available():
792
+ return func(*args, **kwargs)
793
+ else:
794
+ raise ImportError(f"Method `{func.__name__}` requires PyTorch.")
795
+
796
+ return wrapper
797
+
798
+
799
+ def tf_required(func):
800
+ # We chose a different decorator name than in the tests so it's clear they are not the same.
801
+ @wraps(func)
802
+ def wrapper(*args, **kwargs):
803
+ if is_tf_available():
804
+ return func(*args, **kwargs)
805
+ else:
806
+ raise ImportError(f"Method `{func.__name__}` requires TF.")
807
+
808
+ return wrapper
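
For orientation, here is a minimal sketch (not part of the committed file) of how the helpers above compose: `hf_bucket_url` builds the remote URL for a checkpoint file, and `cached_path` downloads it into `TRANSFORMERS_CACHE` (or returns the already-cached copy). The import path and model id below are assumptions for illustration only; whether the legacy S3/CDN endpoints still serve these files is outside the scope of this diff.

```python
# Illustrative sketch only -- not part of the committed file_utils.py above.
# Assumes the repo root is on PYTHONPATH so that `bert.file_utils` resolves to
# elia/bert/file_utils.py, and uses 'bert-base-uncased' as a placeholder model id.
from bert.file_utils import CONFIG_NAME, WEIGHTS_NAME, cached_path, hf_bucket_url

model_id = "bert-base-uncased"

# Build the remote URLs (legacy flat layout, since the id contains no "/").
config_url = hf_bucket_url(model_id, CONFIG_NAME)                  # .../bert-base-uncased-config.json
weights_url = hf_bucket_url(model_id, WEIGHTS_NAME, use_cdn=True)  # served from the CDN endpoint

# Download on the first call (guarded by a per-file lock), then reuse the cache.
local_config = cached_path(config_url)
local_weights = cached_path(weights_url, resume_download=True)

print(local_config)
print(local_weights)
```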
elia/bert/generation_utils.py ADDED
@@ -0,0 +1,993 @@
1
+ # coding=utf-8
2
+ # Copyright 2018 The Google AI Language Team Authors, Facebook AI Research authors and The HuggingFace Inc. team.
3
+ # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+
17
+ import logging
18
+ from typing import Iterable, Optional, Tuple
19
+
20
+ import torch
21
+ from torch import Tensor
22
+ from torch.nn import functional as F
23
+
24
+
25
+ logger = logging.getLogger(__name__)
26
+
27
+
28
+ class GenerationMixin:
29
+ """
30
+ A class containing all of the functions supporting generation, to be used as a mixin in PreTrainedModel.
31
+ """
32
+
33
+ def prepare_inputs_for_generation(self, input_ids, **kwargs):
34
+ return {"input_ids": input_ids}
35
+
36
+ def adjust_logits_during_generation(self, logits, **kwargs):
37
+ return logits
38
+
39
+ def _use_cache(self, outputs, use_cache):
40
+ """During generation, decide whether to pass the `past` variable to the next forward pass."""
41
+ if len(outputs) <= 1 or use_cache is False:
42
+ return False
43
+ if hasattr(self.config, "mem_len") and self.config.mem_len == 0:
44
+ return False
45
+ return True
46
+
47
+ def enforce_repetition_penalty_(self, lprobs, batch_size, num_beams, prev_output_tokens, repetition_penalty):
48
+ """repetition penalty (from CTRL paper https://arxiv.org/abs/1909.05858). """
49
+ for i in range(batch_size * num_beams):
50
+ for previous_token in set(prev_output_tokens[i].tolist()):
51
+ # if score < 0 then the repetition penalty has to be multiplied to reduce the previous token probability
52
+ if lprobs[i, previous_token] < 0:
53
+ lprobs[i, previous_token] *= repetition_penalty
54
+ else:
55
+ lprobs[i, previous_token] /= repetition_penalty
56
+
57
+ def postprocess_next_token_scores(
58
+ self,
59
+ scores,
60
+ input_ids,
61
+ no_repeat_ngram_size,
62
+ bad_words_ids,
63
+ cur_len,
64
+ min_length,
65
+ max_length,
66
+ eos_token_id,
67
+ repetition_penalty,
68
+ batch_size,
69
+ num_beams,
70
+ ):
71
+ # repetition penalty (from CTRL paper https://arxiv.org/abs/1909.05858)
72
+ if repetition_penalty != 1.0:
73
+ self.enforce_repetition_penalty_(
74
+ scores, batch_size, num_beams, input_ids, repetition_penalty,
75
+ )
76
+
77
+ # set eos token prob to zero if min_length is not reached
78
+ if eos_token_id is not None and cur_len < min_length:
79
+ scores[:, eos_token_id] = -float("inf")
80
+
81
+ if no_repeat_ngram_size > 0:
82
+ # calculate a list of banned tokens to prevent repetitively generating the same ngrams
83
+ num_batch_hypotheses = batch_size * num_beams
84
+ # from fairseq: https://github.com/pytorch/fairseq/blob/a07cb6f40480928c9e0548b737aadd36ee66ac76/fairseq/sequence_generator.py#L345
85
+ banned_batch_tokens = calc_banned_ngram_tokens(
86
+ input_ids, num_batch_hypotheses, no_repeat_ngram_size, cur_len
87
+ )
88
+ for i, banned_tokens in enumerate(banned_batch_tokens):
89
+ scores[i, banned_tokens] = -float("inf")
90
+
91
+ if bad_words_ids is not None:
92
+ # calculate a list of banned tokens according to bad words
93
+ banned_tokens = calc_banned_bad_words_ids(input_ids, bad_words_ids)
94
+
95
+ for i, banned_tokens in enumerate(banned_tokens):
96
+ scores[i, banned_tokens] = -float("inf")
97
+
98
+ return scores
99
+
100
+ @torch.no_grad()
101
+ def generate(
102
+ self,
103
+ input_ids: Optional[torch.LongTensor] = None,
104
+ max_length: Optional[int] = None,
105
+ min_length: Optional[int] = None,
106
+ do_sample: Optional[bool] = None,
107
+ early_stopping: Optional[bool] = None,
108
+ num_beams: Optional[int] = None,
109
+ temperature: Optional[float] = None,
110
+ top_k: Optional[int] = None,
111
+ top_p: Optional[float] = None,
112
+ repetition_penalty: Optional[float] = None,
113
+ bad_words_ids: Optional[Iterable[int]] = None,
114
+ bos_token_id: Optional[int] = None,
115
+ pad_token_id: Optional[int] = None,
116
+ eos_token_id: Optional[int] = None,
117
+ length_penalty: Optional[float] = None,
118
+ no_repeat_ngram_size: Optional[int] = None,
119
+ num_return_sequences: Optional[int] = None,
120
+ attention_mask: Optional[torch.LongTensor] = None,
121
+ decoder_start_token_id: Optional[int] = None,
122
+ use_cache: Optional[bool] = None,
123
+ **model_specific_kwargs
124
+ ) -> torch.LongTensor:
125
+ r""" Generates sequences for models with a LM head. The method currently supports greedy decoding, beam-search decoding, sampling with temperature, sampling with top-k or nucleus sampling.
126
+
127
+ Adapted in part from `Facebook's XLM beam search code`_.
128
+
129
+ .. _`Facebook's XLM beam search code`:
130
+ https://github.com/facebookresearch/XLM/blob/9e6f6814d17be4fe5b15f2e6c43eb2b2d76daeb4/src/model/transformer.py#L529
131
+
132
+
133
+ Parameters:
134
+
135
+ input_ids: (`optional`) `torch.LongTensor` of shape `(batch_size, sequence_length)`
136
+ The sequence used as a prompt for the generation. If `None` the method initializes
137
+ it as a `torch.LongTensor` of shape `(1, 1)` filled with `bos_token_id`.
138
+
139
+ max_length: (`optional`) int
140
+ The max length of the sequence to be generated. Between `min_length` and infinity. Defaults to 20.
141
+
142
+ min_length: (`optional`) int
143
+ The min length of the sequence to be generated. Between 0 and infinity. Defaults to 0.
144
+
145
+ do_sample: (`optional`) bool
146
+ If set to `False` greedy decoding is used. Otherwise sampling is used. Defaults to `False` as defined in `configuration_utils.PretrainedConfig`.
147
+
148
+ early_stopping: (`optional`) bool
149
+ if set to `True` beam search is stopped when at least `num_beams` sentences finished per batch. Defaults to `False` as defined in `configuration_utils.PretrainedConfig`.
150
+
151
+ num_beams: (`optional`) int
152
+ Number of beams for beam search. Must be between 1 and infinity. 1 means no beam search. Defaults to 1.
153
+
154
+ temperature: (`optional`) float
155
+ The value used to modulate the next token probabilities. Must be strictly positive. Defaults to 1.0.
156
+
157
+ top_k: (`optional`) int
158
+ The number of highest probability vocabulary tokens to keep for top-k-filtering. Between 1 and infinity. Defaults to 50.
159
+
160
+ top_p: (`optional`) float
161
+ The cumulative probability threshold for nucleus sampling: only the most probable vocabulary tokens whose cumulative probability is at most `top_p` are kept. Must be between 0 and 1. Defaults to 1.
162
+
163
+ repetition_penalty: (`optional`) float
164
+ The parameter for repetition penalty. Between 1.0 and infinity. 1.0 means no penalty. Defaults to 1.0.
165
+
166
+ pad_token_id: (`optional`) int
167
+ Padding token. Defaults to the model-specific pad_token_id, or None if it does not exist.
168
+
169
+ bos_token_id: (`optional`) int
170
+ BOS token. Defaults to `bos_token_id` as defined in the model's config.
171
+
172
+ eos_token_id: (`optional`) int
173
+ EOS token. Defaults to `eos_token_id` as defined in the model's config.
174
+
175
+ length_penalty: (`optional`) float
176
+ Exponential penalty to the length. Defaults to 1.
177
+
178
+ no_repeat_ngram_size: (`optional`) int
179
+ If set to int > 0, all ngrams of size `no_repeat_ngram_size` can only occur once.
180
+ bad_words_ids: (`optional`) list of lists of int
181
+ `bad_words_ids` contains tokens that are not allowed to be generated. In order to get the tokens of the words that should not appear in the generated text, use `tokenizer.encode(bad_word, add_prefix_space=True)`.
182
+
183
+ num_return_sequences: (`optional`) int
184
+ The number of independently computed returned sequences for each element in the batch. Defaults to 1.
185
+
186
+ attention_mask (`optional`) obj: `torch.LongTensor` of same shape as `input_ids`
187
+ Mask to avoid performing attention on padding token indices.
188
+ Mask values selected in ``[0, 1]``:
189
+ ``1`` for tokens that are NOT MASKED, ``0`` for MASKED tokens.
190
+ Defaults to `None`.
191
+
192
+ `What are attention masks? <../glossary.html#attention-mask>`__
193
+
194
+ decoder_start_token_id=None: (`optional`) int
195
+ If an encoder-decoder model starts decoding with a different token than BOS.
196
+ Defaults to `None` and is changed to `BOS` later.
197
+
198
+ use_cache: (`optional`) bool
199
+ If `use_cache` is True, past key values are used to speed up decoding if applicable to model. Defaults to `True`.
200
+
201
+ model_specific_kwargs: (`optional`) dict
202
+ Additional model specific kwargs will be forwarded to the `forward` function of the model.
203
+
204
+ Return:
205
+
206
+ output: `torch.LongTensor` of shape `(batch_size * num_return_sequences, sequence_length)`
207
+ sequence_length is either equal to max_length or shorter if all batches finished early due to the `eos_token_id`
208
+
209
+ Examples::
210
+
211
+ tokenizer = AutoTokenizer.from_pretrained('distilgpt2') # Initialize tokenizer
212
+ model = AutoModelWithLMHead.from_pretrained('distilgpt2') # Download model and configuration from S3 and cache.
213
+ outputs = model.generate(max_length=40) # do greedy decoding
214
+ print('Generated: {}'.format(tokenizer.decode(outputs[0], skip_special_tokens=True)))
215
+
216
+ tokenizer = AutoTokenizer.from_pretrained('openai-gpt') # Initialize tokenizer
217
+ model = AutoModelWithLMHead.from_pretrained('openai-gpt') # Download model and configuration from S3 and cache.
218
+ input_context = 'The dog'
219
+ input_ids = tokenizer.encode(input_context, return_tensors='pt') # encode input context
220
+ outputs = model.generate(input_ids=input_ids, num_beams=5, num_return_sequences=3, temperature=1.5) # generate 3 independent sequences using beam search decoding (5 beams) with sampling from initial context 'The dog'
221
+ for i in range(3): # 3 output sequences were generated
222
+ print('Generated {}: {}'.format(i, tokenizer.decode(outputs[i], skip_special_tokens=True)))
223
+
224
+ tokenizer = AutoTokenizer.from_pretrained('distilgpt2') # Initialize tokenizer
225
+ model = AutoModelWithLMHead.from_pretrained('distilgpt2') # Download model and configuration from S3 and cache.
226
+ input_context = 'The dog'
227
+ input_ids = tokenizer.encode(input_context, return_tensors='pt') # encode input context
228
+ outputs = model.generate(input_ids=input_ids, max_length=40, temperature=0.7, num_return_sequences=3) # generate 3 sequences by sampling
229
+ for i in range(3): # 3 output sequences were generated
230
+ print('Generated {}: {}'.format(i, tokenizer.decode(outputs[i], skip_special_tokens=True)))
231
+
232
+ tokenizer = AutoTokenizer.from_pretrained('ctrl') # Initialize tokenizer
233
+ model = AutoModelWithLMHead.from_pretrained('ctrl') # Download model and configuration from S3 and cache.
234
+ input_context = 'Legal My neighbor is' # "Legal" is one of the control codes for ctrl
235
+ input_ids = tokenizer.encode(input_context, return_tensors='pt') # encode input context
236
+ outputs = model.generate(input_ids=input_ids, max_length=50, temperature=0.7, repetition_penalty=1.2) # generate sequences
237
+ print('Generated: {}'.format(tokenizer.decode(outputs[0], skip_special_tokens=True)))
238
+
239
+ tokenizer = AutoTokenizer.from_pretrained('gpt2') # Initialize tokenizer
240
+ model = AutoModelWithLMHead.from_pretrained('gpt2') # Download model and configuration from S3 and cache.
241
+ input_context = 'My cute dog' # "Legal" is one of the control codes for ctrl
242
+ bad_words_ids = [tokenizer.encode(bad_word, add_prefix_space=True) for bad_word in ['idiot', 'stupid', 'shut up']]
243
+ input_ids = tokenizer.encode(input_context, return_tensors='pt') # encode input context
244
+ outputs = model.generate(input_ids=input_ids, max_length=100, do_sample=True, bad_words_ids=bad_words_ids) # generate sequences without allowing bad_words to be generated
245
+ """
246
+
247
+ # We cannot generate if the model does not have a LM head
248
+ if self.get_output_embeddings() is None:
249
+ raise AttributeError(
250
+ "You tried to generate sequences with a model that does not have a LM Head."
251
+ "Please use another model class (e.g. `OpenAIGPTLMHeadModel`, `XLNetLMHeadModel`, `GPT2LMHeadModel`, `CTRLLMHeadModel`, `T5WithLMHeadModel`, `TransfoXLLMHeadModel`, `XLMWithLMHeadModel`, `BartForConditionalGeneration` )"
252
+ )
253
+
254
+ max_length = max_length if max_length is not None else self.config.max_length
255
+ min_length = min_length if min_length is not None else self.config.min_length
256
+ do_sample = do_sample if do_sample is not None else self.config.do_sample
257
+ early_stopping = early_stopping if early_stopping is not None else self.config.early_stopping
258
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
259
+ num_beams = num_beams if num_beams is not None else self.config.num_beams
260
+ temperature = temperature if temperature is not None else self.config.temperature
261
+ top_k = top_k if top_k is not None else self.config.top_k
262
+ top_p = top_p if top_p is not None else self.config.top_p
263
+ repetition_penalty = repetition_penalty if repetition_penalty is not None else self.config.repetition_penalty
264
+ bos_token_id = bos_token_id if bos_token_id is not None else self.config.bos_token_id
265
+ pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id
266
+ eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id
267
+ length_penalty = length_penalty if length_penalty is not None else self.config.length_penalty
268
+ no_repeat_ngram_size = (
269
+ no_repeat_ngram_size if no_repeat_ngram_size is not None else self.config.no_repeat_ngram_size
270
+ )
271
+ bad_words_ids = bad_words_ids if bad_words_ids is not None else self.config.bad_words_ids
272
+ num_return_sequences = (
273
+ num_return_sequences if num_return_sequences is not None else self.config.num_return_sequences
274
+ )
275
+ decoder_start_token_id = (
276
+ decoder_start_token_id if decoder_start_token_id is not None else self.config.decoder_start_token_id
277
+ )
278
+
279
+ if input_ids is not None:
280
+ batch_size = input_ids.shape[0] # overridden by the input batch_size
281
+ else:
282
+ batch_size = 1
283
+
284
+ assert isinstance(max_length, int) and max_length > 0, "`max_length` should be a strictly positive integer."
285
+ assert isinstance(min_length, int) and min_length >= 0, "`min_length` should be a positive integer."
286
+ assert isinstance(do_sample, bool), "`do_sample` should be a boolean."
287
+ assert isinstance(early_stopping, bool), "`early_stopping` should be a boolean."
288
+ assert isinstance(use_cache, bool), "`use_cache` should be a boolean."
289
+ assert isinstance(num_beams, int) and num_beams > 0, "`num_beams` should be a strictly positive integer."
290
+ assert temperature > 0, "`temperature` should be strictly positive."
291
+ assert isinstance(top_k, int) and top_k >= 0, "`top_k` should be a positive integer."
292
+ assert 0 <= top_p <= 1, "`top_p` should be between 0 and 1."
293
+ assert repetition_penalty >= 1.0, "`repetition_penalty` should be >= 1."
294
+ assert input_ids is not None or (
295
+ isinstance(bos_token_id, int) and bos_token_id >= 0
296
+ ), "If input_ids is not defined, `bos_token_id` should be a positive integer."
297
+ assert pad_token_id is None or (
298
+ isinstance(pad_token_id, int) and (pad_token_id >= 0)
299
+ ), "`pad_token_id` should be a positive integer."
300
+ assert (eos_token_id is None) or (
301
+ isinstance(eos_token_id, int) and (eos_token_id >= 0)
302
+ ), "`eos_token_id` should be a positive integer."
303
+ assert length_penalty > 0, "`length_penalty` should be strictly positive."
304
+ assert (
305
+ isinstance(no_repeat_ngram_size, int) and no_repeat_ngram_size >= 0
306
+ ), "`no_repeat_ngram_size` should be a positive integer."
307
+ assert (
308
+ isinstance(num_return_sequences, int) and num_return_sequences > 0
309
+ ), "`num_return_sequences` should be a strictly positive integer."
310
+ assert (
311
+ bad_words_ids is None or isinstance(bad_words_ids, list) and isinstance(bad_words_ids[0], list)
312
+ ), "`bad_words_ids` is either `None` or a list of lists of tokens that should not be generated"
313
+
314
+ if input_ids is None:
315
+ assert isinstance(bos_token_id, int) and bos_token_id >= 0, (
316
+ "you should either supply a context to complete as `input_ids` input "
317
+ "or a `bos_token_id` (integer >= 0) as a first token to start the generation."
318
+ )
319
+ input_ids = torch.full(
320
+ (batch_size, 1), bos_token_id, dtype=torch.long, device=next(self.parameters()).device,
321
+ )
322
+ else:
323
+ assert input_ids.dim() == 2, "Input prompt should be of shape (batch_size, sequence length)."
324
+
325
+ # do not allow duplicate outputs when greedy decoding
326
+ if do_sample is False:
327
+ if num_beams == 1:
328
+ # no_beam_search greedy generation conditions
329
+ assert (
330
+ num_return_sequences == 1
331
+ ), "Greedy decoding will always produce the same output for num_beams == 1 and num_return_sequences > 1. Please set num_return_sequences = 1"
332
+
333
+ else:
334
+ # beam_search greedy generation conditions
335
+ assert (
336
+ num_beams >= num_return_sequences
337
+ ), "Greedy beam search decoding cannot return more sequences than it has beams. Please set num_beams >= num_return_sequences"
338
+
339
+ # create attention mask if necessary
340
+ # TODO (PVP): this should later be handled by the forward fn() in each model (see PR 3140)
341
+ if (attention_mask is None) and (pad_token_id is not None) and (pad_token_id in input_ids):
342
+ attention_mask = input_ids.ne(pad_token_id).long()
343
+ elif attention_mask is None:
344
+ attention_mask = input_ids.new_ones(input_ids.shape)
345
+
346
+ # set pad_token_id to eos_token_id if not set. Important that this is done after
347
+ # attention_mask is created
348
+ if pad_token_id is None and eos_token_id is not None:
349
+ logger.warning(
350
+ "Setting `pad_token_id` to {} (first `eos_token_id`) to generate sequence".format(eos_token_id)
351
+ )
352
+ pad_token_id = eos_token_id
353
+
354
+ # current position and vocab size
355
+ if hasattr(self.config, "vocab_size"):
356
+ vocab_size = self.config.vocab_size
357
+ elif (
358
+ self.config.is_encoder_decoder
359
+ and hasattr(self.config, "decoder")
360
+ and hasattr(self.config.decoder, "vocab_size")
361
+ ):
362
+ vocab_size = self.config.decoder.vocab_size
363
+
364
+ # set effective batch size and effective batch multiplier according to do_sample
365
+ if do_sample:
366
+ effective_batch_size = batch_size * num_return_sequences
367
+ effective_batch_mult = num_return_sequences
368
+ else:
369
+ effective_batch_size = batch_size
370
+ effective_batch_mult = 1
371
+
372
+ if self.config.is_encoder_decoder:
373
+ if decoder_start_token_id is None:
374
+ decoder_start_token_id = bos_token_id
375
+
376
+ assert (
377
+ decoder_start_token_id is not None
378
+ ), "decoder_start_token_id or bos_token_id has to be defined for encoder-decoder generation"
379
+ assert hasattr(self, "get_encoder"), "{} should have a 'get_encoder' function defined".format(self)
380
+ assert callable(self.get_encoder), "{} should be a method".format(self.get_encoder)
381
+
382
+ # get encoder and store encoder outputs
383
+ encoder = self.get_encoder()
384
+
385
+ encoder_outputs: tuple = encoder(input_ids, attention_mask=attention_mask)
386
+
387
+ # Expand input ids if num_beams > 1 or num_return_sequences > 1
388
+ if num_return_sequences > 1 or num_beams > 1:
389
+ input_ids_len = input_ids.shape[-1]
390
+ input_ids = input_ids.unsqueeze(1).expand(batch_size, effective_batch_mult * num_beams, input_ids_len)
391
+ attention_mask = attention_mask.unsqueeze(1).expand(
392
+ batch_size, effective_batch_mult * num_beams, input_ids_len
393
+ )
394
+
395
+ input_ids = input_ids.contiguous().view(
396
+ effective_batch_size * num_beams, input_ids_len
397
+ ) # shape: (batch_size * num_return_sequences * num_beams, cur_len)
398
+ attention_mask = attention_mask.contiguous().view(
399
+ effective_batch_size * num_beams, input_ids_len
400
+ ) # shape: (batch_size * num_return_sequences * num_beams, cur_len)
401
+
402
+ if self.config.is_encoder_decoder:
403
+ # create empty decoder_input_ids
404
+ input_ids = torch.full(
405
+ (effective_batch_size * num_beams, 1),
406
+ decoder_start_token_id,
407
+ dtype=torch.long,
408
+ device=next(self.parameters()).device,
409
+ )
410
+ cur_len = 1
411
+
412
+ assert (
413
+ batch_size == encoder_outputs[0].shape[0]
414
+ ), f"expected encoder_outputs[0] to have 1st dimension bs={batch_size}, got {encoder_outputs[0].shape[0]} "
415
+
416
+ # expand batch_idx to assign correct encoder output for expanded input_ids (due to num_beams > 1 and num_return_sequences > 1)
417
+ expanded_batch_idxs = (
418
+ torch.arange(batch_size)
419
+ .view(-1, 1)
420
+ .repeat(1, num_beams * effective_batch_mult)
421
+ .view(-1)
422
+ .to(input_ids.device)
423
+ )
424
+ # expand encoder_outputs
425
+ encoder_outputs = (encoder_outputs[0].index_select(0, expanded_batch_idxs), *encoder_outputs[1:])
426
+
427
+ else:
428
+ encoder_outputs = None
429
+ cur_len = input_ids.shape[-1]
430
+
431
+ assert (
432
+ cur_len < max_length
433
+ ), f"The context has {cur_len} number of tokens, but `max_length` is only {max_length}. Please make sure that `max_length` is bigger than the number of tokens, by setting either `generate(max_length=...,...)` or `config.max_length = ...`"
434
+
435
+ if num_beams > 1:
436
+ output = self._generate_beam_search(
437
+ input_ids,
438
+ cur_len=cur_len,
439
+ max_length=max_length,
440
+ min_length=min_length,
441
+ do_sample=do_sample,
442
+ early_stopping=early_stopping,
443
+ temperature=temperature,
444
+ top_k=top_k,
445
+ top_p=top_p,
446
+ repetition_penalty=repetition_penalty,
447
+ no_repeat_ngram_size=no_repeat_ngram_size,
448
+ bad_words_ids=bad_words_ids,
449
+ pad_token_id=pad_token_id,
450
+ eos_token_id=eos_token_id,
451
+ batch_size=effective_batch_size,
452
+ num_return_sequences=num_return_sequences,
453
+ length_penalty=length_penalty,
454
+ num_beams=num_beams,
455
+ vocab_size=vocab_size,
456
+ encoder_outputs=encoder_outputs,
457
+ attention_mask=attention_mask,
458
+ use_cache=use_cache,
459
+ model_specific_kwargs=model_specific_kwargs,
460
+ )
461
+ else:
462
+ output = self._generate_no_beam_search(
463
+ input_ids,
464
+ cur_len=cur_len,
465
+ max_length=max_length,
466
+ min_length=min_length,
467
+ do_sample=do_sample,
468
+ temperature=temperature,
469
+ top_k=top_k,
470
+ top_p=top_p,
471
+ repetition_penalty=repetition_penalty,
472
+ no_repeat_ngram_size=no_repeat_ngram_size,
473
+ bad_words_ids=bad_words_ids,
474
+ pad_token_id=pad_token_id,
475
+ eos_token_id=eos_token_id,
476
+ batch_size=effective_batch_size,
477
+ encoder_outputs=encoder_outputs,
478
+ attention_mask=attention_mask,
479
+ use_cache=use_cache,
480
+ model_specific_kwargs=model_specific_kwargs,
481
+ )
482
+
483
+ return output
484
+
485
+ def _generate_no_beam_search(
486
+ self,
487
+ input_ids,
488
+ cur_len,
489
+ max_length,
490
+ min_length,
491
+ do_sample,
492
+ temperature,
493
+ top_k,
494
+ top_p,
495
+ repetition_penalty,
496
+ no_repeat_ngram_size,
497
+ bad_words_ids,
498
+ pad_token_id,
499
+ eos_token_id,
500
+ batch_size,
501
+ encoder_outputs,
502
+ attention_mask,
503
+ use_cache,
504
+ model_specific_kwargs,
505
+ ):
506
+ """ Generate sequences for each example without beam search (num_beams == 1).
507
+ All returned sequences are generated independently.
508
+ """
509
+ # length of generated sentences / unfinished sentences
510
+ unfinished_sents = input_ids.new(batch_size).fill_(1)
511
+ sent_lengths = input_ids.new(batch_size).fill_(max_length)
512
+
513
+ past = (encoder_outputs, None) if encoder_outputs is not None else None
514
+
515
+ while cur_len < max_length:
516
+ model_inputs = self.prepare_inputs_for_generation(
517
+ input_ids, past=past, attention_mask=attention_mask, use_cache=use_cache, **model_specific_kwargs
518
+ )
519
+
520
+ outputs = self(**model_inputs)
521
+ next_token_logits = outputs[0][:, -1, :]
522
+
523
+ scores = self.postprocess_next_token_scores(
524
+ scores=next_token_logits,
525
+ input_ids=input_ids,
526
+ no_repeat_ngram_size=no_repeat_ngram_size,
527
+ bad_words_ids=bad_words_ids,
528
+ cur_len=cur_len,
529
+ min_length=min_length,
530
+ max_length=max_length,
531
+ eos_token_id=eos_token_id,
532
+ repetition_penalty=repetition_penalty,
533
+ batch_size=batch_size,
534
+ num_beams=1,
535
+ )
536
+
537
+ # if model has past, then set the past variable to speed up decoding
538
+ if self._use_cache(outputs, use_cache):
539
+ past = outputs[1]
540
+
541
+ if do_sample:
542
+ # Temperature (higher temperature => more likely to sample low probability tokens)
543
+ if temperature != 1.0:
544
+ scores = scores / temperature
545
+ # Top-p/top-k filtering
546
+ next_token_logscores = top_k_top_p_filtering(scores, top_k=top_k, top_p=top_p)
547
+ # Sample
548
+ probs = F.softmax(next_token_logscores, dim=-1)
549
+ next_token = torch.multinomial(probs, num_samples=1).squeeze(1)
550
+ else:
551
+ # Greedy decoding
552
+ next_token = torch.argmax(next_token_logits, dim=-1)
553
+
554
+ # update generations and finished sentences
555
+ if eos_token_id is not None:
556
+ # pad finished sentences if eos_token_id exist
557
+ tokens_to_add = next_token * unfinished_sents + (pad_token_id) * (1 - unfinished_sents)
558
+ else:
559
+ tokens_to_add = next_token
560
+
561
+ # add token and increase length by one
562
+ input_ids = torch.cat([input_ids, tokens_to_add.unsqueeze(-1)], dim=-1)
563
+ cur_len = cur_len + 1
564
+
565
+ if eos_token_id is not None:
566
+ eos_in_sents = tokens_to_add == eos_token_id
567
+ # if sentence is unfinished and the token to add is eos, sent_lengths is filled with current length
568
+ is_sents_unfinished_and_token_to_add_is_eos = unfinished_sents.mul(eos_in_sents.long()).bool()
569
+ sent_lengths.masked_fill_(is_sents_unfinished_and_token_to_add_is_eos, cur_len)
570
+ # unfinished_sents is set to zero if eos in sentence
571
+ unfinished_sents.mul_((~eos_in_sents).long())
572
+
573
+ # stop when there is a </s> in each sentence, or if we exceed the maximum length
574
+ if unfinished_sents.max() == 0:
575
+ break
576
+
577
+ # extend attention_mask for new generated input if only decoder
578
+ if self.config.is_encoder_decoder is False:
579
+ attention_mask = torch.cat(
580
+ [attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1
581
+ )
582
+
583
+ return input_ids
584
+
585
+ def _generate_beam_search(
586
+ self,
587
+ input_ids,
588
+ cur_len,
589
+ max_length,
590
+ min_length,
591
+ do_sample,
592
+ early_stopping,
593
+ temperature,
594
+ top_k,
595
+ top_p,
596
+ repetition_penalty,
597
+ no_repeat_ngram_size,
598
+ bad_words_ids,
599
+ pad_token_id,
600
+ eos_token_id,
601
+ batch_size,
602
+ num_return_sequences,
603
+ length_penalty,
604
+ num_beams,
605
+ vocab_size,
606
+ encoder_outputs,
607
+ attention_mask,
608
+ use_cache,
609
+ model_specific_kwargs,
610
+ ):
611
+ """ Generate sequences for each example with beam search.
612
+ """
613
+
614
+ # generated hypotheses
615
+ generated_hyps = [
616
+ BeamHypotheses(num_beams, max_length, length_penalty, early_stopping=early_stopping)
617
+ for _ in range(batch_size)
618
+ ]
619
+
620
+ # scores for each sentence in the beam
621
+ beam_scores = torch.zeros((batch_size, num_beams), dtype=torch.float, device=input_ids.device)
622
+
623
+ # for greedy decoding, make sure that only tokens of the first beam are considered, to avoid sampling the exact same tokens multiple times
624
+ if do_sample is False:
625
+ beam_scores[:, 1:] = -1e9
626
+ beam_scores = beam_scores.view(-1) # shape (batch_size * num_beams,)
627
+
628
+ # cache compute states
629
+ past = (encoder_outputs, None) if encoder_outputs is not None else None
630
+
631
+ # done sentences
632
+ done = [False for _ in range(batch_size)]
633
+
634
+ while cur_len < max_length:
635
+ model_inputs = self.prepare_inputs_for_generation(
636
+ input_ids, past=past, attention_mask=attention_mask, use_cache=use_cache, **model_specific_kwargs
637
+ )
638
+ outputs = self(**model_inputs) # (batch_size * num_beams, cur_len, vocab_size)
639
+ next_token_logits = outputs[0][:, -1, :] # (batch_size * num_beams, vocab_size)
640
+
641
+ # if model has past, then set the past variable to speed up decoding
642
+ if self._use_cache(outputs, use_cache):
643
+ past = outputs[1]
644
+ if self.config.is_encoder_decoder and do_sample is False:
645
+ # TODO (PVP) still a bit hacky here - there might be a better solution
646
+ next_token_logits = self.adjust_logits_during_generation(
647
+ next_token_logits, cur_len=cur_len, max_length=max_length
648
+ )
649
+
650
+ scores = F.log_softmax(next_token_logits, dim=-1) # (batch_size * num_beams, vocab_size)
651
+
652
+ scores = self.postprocess_next_token_scores(
653
+ scores=scores,
654
+ input_ids=input_ids,
655
+ no_repeat_ngram_size=no_repeat_ngram_size,
656
+ bad_words_ids=bad_words_ids,
657
+ cur_len=cur_len,
658
+ min_length=min_length,
659
+ max_length=max_length,
660
+ eos_token_id=eos_token_id,
661
+ repetition_penalty=repetition_penalty,
662
+ batch_size=batch_size,
663
+ num_beams=num_beams,
664
+ )
665
+
666
+ assert scores.shape == (batch_size * num_beams, vocab_size), "Shapes of scores: {} != {}".format(
667
+ scores.shape, (batch_size * num_beams, vocab_size)
668
+ )
669
+
670
+ if do_sample:
671
+ _scores = scores + beam_scores[:, None].expand_as(scores) # (batch_size * num_beams, vocab_size)
672
+ # Temperature
673
+ if temperature != 1.0:
674
+ _scores = _scores / temperature
675
+ # Top-p/top-k filtering
676
+ _scores = top_k_top_p_filtering(
677
+ _scores, top_k=top_k, top_p=top_p, min_tokens_to_keep=2
678
+ ) # (batch_size * num_beams, vocab_size)
679
+ # re-organize to group the beam together to sample from all beam_idxs
680
+ _scores = _scores.contiguous().view(
681
+ batch_size, num_beams * vocab_size
682
+ ) # (batch_size, num_beams * vocab_size)
683
+
684
+ # Sample 2 next tokens for each beam (so we have some spare tokens and match output of greedy beam search)
685
+ probs = F.softmax(_scores, dim=-1)
686
+ next_tokens = torch.multinomial(probs, num_samples=2 * num_beams) # (batch_size, num_beams * 2)
687
+ # Compute next scores
688
+ next_scores = torch.gather(_scores, -1, next_tokens) # (batch_size, num_beams * 2)
689
+ # sort the sampled vector to make sure that the first num_beams samples are the best
690
+ next_scores, next_scores_indices = torch.sort(next_scores, descending=True, dim=1)
691
+ next_tokens = torch.gather(next_tokens, -1, next_scores_indices) # (batch_size, num_beams * 2)
692
+
693
+ else:
694
+ next_scores = scores + beam_scores[:, None].expand_as(scores) # (batch_size * num_beams, vocab_size)
695
+
696
+ # re-organize to group the beams together (we are keeping the top hypotheses across beams)
697
+ next_scores = next_scores.view(
698
+ batch_size, num_beams * vocab_size
699
+ ) # (batch_size, num_beams * vocab_size)
700
+
701
+ next_scores, next_tokens = torch.topk(next_scores, 2 * num_beams, dim=1, largest=True, sorted=True)
702
+
703
+ assert next_scores.size() == next_tokens.size() == (batch_size, 2 * num_beams)
704
+
705
+ # next batch beam content
706
+ next_batch_beam = []
707
+
708
+ # for each sentence
709
+ for batch_idx in range(batch_size):
710
+
711
+ # if we are done with this sentence, add a pad token
712
+ if done[batch_idx]:
713
+ assert (
714
+ len(generated_hyps[batch_idx]) >= num_beams
715
+ ), "Batch can only be done if at least {} beams have been generated".format(num_beams)
716
+ assert (
717
+ eos_token_id is not None and pad_token_id is not None
718
+ ), "generated beams >= num_beams -> eos_token_id and pad_token have to be defined"
719
+ next_batch_beam.extend([(0, pad_token_id, 0)] * num_beams) # pad the batch
720
+ continue
721
+
722
+ # next sentence beam content, this will get added to next_batch_beam
723
+ next_sent_beam = []
724
+
725
+ # next tokens for this sentence
726
+ for beam_token_rank, (beam_token_id, beam_token_score) in enumerate(
727
+ zip(next_tokens[batch_idx], next_scores[batch_idx])
728
+ ):
729
+ # get beam and token IDs
730
+ beam_id = beam_token_id // vocab_size
731
+ token_id = beam_token_id % vocab_size
732
+
733
+ effective_beam_id = batch_idx * num_beams + beam_id
734
+ # add to generated hypotheses if end of sentence
735
+ if (eos_token_id is not None) and (token_id.item() == eos_token_id):
736
+ # if beam_token does not belong to top num_beams tokens, it should not be added
737
+ is_beam_token_worse_than_top_num_beams = beam_token_rank >= num_beams
738
+ if is_beam_token_worse_than_top_num_beams:
739
+ continue
740
+ generated_hyps[batch_idx].add(
741
+ input_ids[effective_beam_id].clone(), beam_token_score.item(),
742
+ )
743
+ else:
744
+ # add next predicted token since it is not eos_token
745
+ next_sent_beam.append((beam_token_score, token_id, effective_beam_id))
746
+
747
+ # once the beam for next step is full, don't add more tokens to it.
748
+ if len(next_sent_beam) == num_beams:
749
+ break
750
+
751
+ # Check if we are done so that we can save a pad step if all(done)
752
+ done[batch_idx] = done[batch_idx] or generated_hyps[batch_idx].is_done(
753
+ next_scores[batch_idx].max().item(), cur_len
754
+ )
755
+
756
+ # update next beam content
757
+ assert len(next_sent_beam) == num_beams, "Beam should always be full"
758
+ next_batch_beam.extend(next_sent_beam)
759
+ assert len(next_batch_beam) == num_beams * (batch_idx + 1), "We should have added num_beams each step"
760
+
761
+ # stop when we are done with each sentence
762
+ if all(done):
763
+ break
764
+
765
+ # sanity check / prepare next batch
766
+ assert len(next_batch_beam) == batch_size * num_beams
767
+ beam_scores = beam_scores.new([x[0] for x in next_batch_beam])
768
+ beam_tokens = input_ids.new([x[1] for x in next_batch_beam])
769
+ beam_idx = input_ids.new([x[2] for x in next_batch_beam])
770
+
771
+ # re-order batch and update current length
772
+ input_ids = input_ids[beam_idx, :]
773
+ input_ids = torch.cat([input_ids, beam_tokens.unsqueeze(1)], dim=-1)
774
+ cur_len = cur_len + 1
775
+
776
+ # re-order internal states
777
+ if past is not None:
778
+ past = self._reorder_cache(past, beam_idx)
779
+
780
+ # extend attention_mask for new generated input if only decoder
781
+ if self.config.is_encoder_decoder is False:
782
+ attention_mask = torch.cat(
783
+ [attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1
784
+ )
785
+
786
+ # finalize all open beam hypotheses and add to generated hypotheses
787
+ for batch_idx in range(batch_size):
788
+ if done[batch_idx]:
789
+ continue
790
+
791
+ # test that beam scores match previously calculated scores if not eos and batch_idx not done
792
+ if eos_token_id is not None and all(
793
+ (token_id % vocab_size).item() != eos_token_id for token_id in next_tokens[batch_idx]
794
+ ):
795
+ assert torch.all(
796
+ next_scores[batch_idx, :num_beams] == beam_scores.view(batch_size, num_beams)[batch_idx]
797
+ ), "If batch_idx is not done, final next scores: {} have to equal to accumulated beam_scores: {}".format(
798
+ next_scores[:, :num_beams][batch_idx], beam_scores.view(batch_size, num_beams)[batch_idx],
799
+ )
800
+
801
+ # need to add best num_beams hypotheses to generated hyps
802
+ for beam_id in range(num_beams):
803
+ effective_beam_id = batch_idx * num_beams + beam_id
804
+ final_score = beam_scores[effective_beam_id].item()
805
+ final_tokens = input_ids[effective_beam_id]
806
+ generated_hyps[batch_idx].add(final_tokens, final_score)
807
+
808
+ # depending on whether greedy generation is wanted or not define different output_batch_size and output_num_return_sequences_per_batch
809
+ output_batch_size = batch_size if do_sample else batch_size * num_return_sequences
810
+ output_num_return_sequences_per_batch = 1 if do_sample else num_return_sequences
811
+
812
+ # select the best hypotheses
813
+ sent_lengths = input_ids.new(output_batch_size)
814
+ best = []
815
+
816
+ # retrieve best hypotheses
817
+ for i, hypotheses in enumerate(generated_hyps):
818
+ sorted_hyps = sorted(hypotheses.beams, key=lambda x: x[0])
819
+ for j in range(output_num_return_sequences_per_batch):
820
+ effective_batch_idx = output_num_return_sequences_per_batch * i + j
821
+ best_hyp = sorted_hyps.pop()[1]
822
+ sent_lengths[effective_batch_idx] = len(best_hyp)
823
+ best.append(best_hyp)
824
+
825
+ # shorter batches are padded
826
+ if sent_lengths.min().item() != sent_lengths.max().item():
827
+ assert pad_token_id is not None, "`Pad_token_id` has to be defined"
828
+ sent_max_len = min(sent_lengths.max().item() + 1, max_length)
829
+ decoded = input_ids.new(output_batch_size, sent_max_len).fill_(pad_token_id)
830
+
831
+ # fill with hypothesis and eos_token_id if necessary
832
+ for i, hypo in enumerate(best):
833
+ decoded[i, : sent_lengths[i]] = hypo
834
+ if sent_lengths[i] < max_length:
835
+ decoded[i, sent_lengths[i]] = eos_token_id
836
+ else:
837
+ # none of the hypotheses have an eos_token
838
+ assert all(len(hypo) == max_length for hypo in best)
839
+ decoded = torch.stack(best).type(torch.long).to(next(self.parameters()).device)
840
+
841
+ return decoded
842
+
843
+ @staticmethod
844
+ def _reorder_cache(past: Tuple, beam_idx: Tensor) -> Tuple[Tensor]:
845
+ return tuple(layer_past.index_select(1, beam_idx) for layer_past in past)
846
+
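_reorder_cache only has to pick, for every new beam, the cached states of the beam it was expanded from; the beam-flattened batch sits on dim 1 of each cached tensor. A toy illustration with dummy shapes (not real BERT caches):

import torch

layer_past = torch.arange(2 * 6 * 4, dtype=torch.float).view(2, 6, 4)  # dim 1 = batch_size * num_beams
past = (layer_past,)  # one dummy "layer"

beam_idx = torch.tensor([0, 0, 2, 3, 5, 5])  # each new beam says which old beam it continues

reordered = tuple(p.index_select(1, beam_idx) for p in past)
print(reordered[0].shape)  # torch.Size([2, 6, 4]); rows along dim 1 are now in beam order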
847
+
848
+ def calc_banned_ngram_tokens(prev_input_ids: Tensor, num_hypos: int, no_repeat_ngram_size: int, cur_len: int):
849
+ """Copied from fairseq for no_repeat_ngram in beam_search"""
850
+ if cur_len + 1 < no_repeat_ngram_size:
851
+ # return no banned tokens if we haven't generated no_repeat_ngram_size tokens yet
852
+ return [[] for _ in range(num_hypos)]
853
+ generated_ngrams = [{} for _ in range(num_hypos)]
854
+ for idx in range(num_hypos):
855
+ gen_tokens = prev_input_ids[idx].tolist()
856
+ generated_ngram = generated_ngrams[idx]
857
+ for ngram in zip(*[gen_tokens[i:] for i in range(no_repeat_ngram_size)]):
858
+ prev_ngram_tuple = tuple(ngram[:-1])
859
+ generated_ngram[prev_ngram_tuple] = generated_ngram.get(prev_ngram_tuple, []) + [ngram[-1]]
860
+
861
+ def _get_generated_ngrams(hypo_idx):
862
+ # Before decoding the next token, prevent decoding of ngrams that have already appeared
863
+ start_idx = cur_len + 1 - no_repeat_ngram_size
864
+ ngram_idx = tuple(prev_input_ids[hypo_idx, start_idx:cur_len].tolist())
865
+ return generated_ngrams[hypo_idx].get(ngram_idx, [])
866
+
867
+ banned_tokens = [_get_generated_ngrams(hypo_idx) for hypo_idx in range(num_hypos)]
868
+ return banned_tokens
869
+
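A quick worked example of the n-gram ban (the import path is an assumption about how this vendored module is packaged): with no_repeat_ngram_size=2, a hypothesis that already contains the bigram (7, 5) may not produce 5 right after its trailing 7 again.

import torch
from elia.bert.generation_utils import calc_banned_ngram_tokens  # assumed import path

prev_input_ids = torch.tensor([[5, 7, 5, 7]])  # one hypothesis, four tokens generated so far
banned = calc_banned_ngram_tokens(prev_input_ids, num_hypos=1, no_repeat_ngram_size=2, cur_len=4)
print(banned)  # [[5]] -- emitting 5 after the trailing 7 would repeat the bigram (7, 5)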
870
+
871
+ def calc_banned_bad_words_ids(prev_input_ids: Iterable[int], bad_words_ids: Iterable[int]) -> Iterable[int]:
872
+ banned_tokens = []
873
+
874
+ def _tokens_match(prev_tokens, tokens):
875
+ if len(tokens) == 0:
876
+ # if the bad-words sequence is a single token, always ban it
877
+ return True
878
+ if len(tokens) > len(prev_input_ids):
879
+ # if the bad-words sequence is longer than prev_input_ids, it can't match
880
+ return False
881
+
882
+ if prev_tokens[-len(tokens) :] == tokens:
883
+ # if tokens match
884
+ return True
885
+ else:
886
+ return False
887
+
888
+ for prev_input_ids_slice in prev_input_ids:
889
+ banned_tokens_slice = []
890
+
891
+ for banned_token_seq in bad_words_ids:
892
+ assert len(banned_token_seq) > 0, "Banned words token sequences {} cannot have an empty list".format(
893
+ bad_words_ids
894
+ )
895
+
896
+ if _tokens_match(prev_input_ids_slice.tolist(), banned_token_seq[:-1]) is False:
897
+ # if tokens do not match continue
898
+ continue
899
+
900
+ banned_tokens_slice.append(banned_token_seq[-1])
901
+
902
+ banned_tokens.append(banned_tokens_slice)
903
+
904
+ return banned_tokens
905
+
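And the bad-words counterpart: a multi-token bad word is only banned once its prefix has already been generated, while a single-token bad word is banned everywhere (import path again assumed):

import torch
from elia.bert.generation_utils import calc_banned_bad_words_ids  # assumed import path

prev_input_ids = torch.tensor([[1, 2, 3]])  # generated so far
bad_words_ids = [[3, 4], [9]]               # ban the phrase "3 4" and the token 9
banned = calc_banned_bad_words_ids(prev_input_ids, bad_words_ids)
print(banned)  # [[4, 9]] -- 4 because the prefix [3] already matches, 9 unconditionally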
906
+
907
+ def top_k_top_p_filtering(
908
+ logits: Tensor,
909
+ top_k: int = 0,
910
+ top_p: float = 1.0,
911
+ filter_value: float = -float("Inf"),
912
+ min_tokens_to_keep: int = 1,
913
+ ) -> Tensor:
914
+ """ Filter a distribution of logits using top-k and/or nucleus (top-p) filtering
915
+ Args:
916
+ logits: logits distribution shape (batch size, vocabulary size)
917
+ if top_k > 0: keep only top k tokens with highest probability (top-k filtering).
918
+ if top_p < 1.0: keep the top tokens with cumulative probability >= top_p (nucleus filtering).
919
+ Nucleus filtering is described in Holtzman et al. (http://arxiv.org/abs/1904.09751)
920
+ Make sure we keep at least min_tokens_to_keep per batch example in the output
921
+ From: https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317
922
+ """
923
+ if top_k > 0:
924
+ top_k = min(max(top_k, min_tokens_to_keep), logits.size(-1)) # Safety check
925
+ # Remove all tokens with a probability less than the last token of the top-k
926
+ indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
927
+ logits[indices_to_remove] = filter_value
928
+
929
+ if top_p < 1.0:
930
+ sorted_logits, sorted_indices = torch.sort(logits, descending=True)
931
+ cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
932
+
933
+ # Remove tokens with cumulative probability above the threshold (tokens with 0 are kept)
934
+ sorted_indices_to_remove = cumulative_probs > top_p
935
+ if min_tokens_to_keep > 1:
936
+ # Keep at least min_tokens_to_keep (set to min_tokens_to_keep-1 because we add the first one below)
937
+ sorted_indices_to_remove[..., :min_tokens_to_keep] = 0
938
+ # Shift the indices to the right to keep also the first token above the threshold
939
+ sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
940
+ sorted_indices_to_remove[..., 0] = 0
941
+
942
+ # scatter sorted tensors to original indexing
943
+ indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
944
+ logits[indices_to_remove] = filter_value
945
+ return logits
946
+
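Usage sketch for the filter above (assumed import path; note the helper writes -inf into the tensor it is given, hence the clone):

import torch
from elia.bert.generation_utils import top_k_top_p_filtering  # assumed import path

logits = torch.tensor([[4.0, 3.0, 2.0, 1.0, -2.0]])
filtered = top_k_top_p_filtering(logits.clone(), top_k=3, top_p=0.9)
print(filtered)  # only the two largest logits survive; everything else is -inf
probs = torch.softmax(filtered, dim=-1)  # safe to sample from: filtered entries get probability 0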
947
+
948
+ class BeamHypotheses(object):
949
+ def __init__(self, num_beams, max_length, length_penalty, early_stopping):
950
+ """
951
+ Initialize n-best list of hypotheses.
952
+ """
953
+ self.max_length = max_length - 1 # ignoring bos_token
954
+ self.length_penalty = length_penalty
955
+ self.early_stopping = early_stopping
956
+ self.num_beams = num_beams
957
+ self.beams = []
958
+ self.worst_score = 1e9
959
+
960
+ def __len__(self):
961
+ """
962
+ Number of hypotheses in the list.
963
+ """
964
+ return len(self.beams)
965
+
966
+ def add(self, hyp, sum_logprobs):
967
+ """
968
+ Add a new hypothesis to the list.
969
+ """
970
+ score = sum_logprobs / len(hyp) ** self.length_penalty
971
+ if len(self) < self.num_beams or score > self.worst_score:
972
+ self.beams.append((score, hyp))
973
+ if len(self) > self.num_beams:
974
+ sorted_scores = sorted([(s, idx) for idx, (s, _) in enumerate(self.beams)])
975
+ del self.beams[sorted_scores[0][1]]
976
+ self.worst_score = sorted_scores[1][0]
977
+ else:
978
+ self.worst_score = min(score, self.worst_score)
979
+
980
+ def is_done(self, best_sum_logprobs, cur_len):
981
+ """
982
+ If there are enough hypotheses and none of the hypotheses being generated
983
+ can become better than the worst one in the heap, then we are done with this sentence.
984
+ """
985
+
986
+ if len(self) < self.num_beams:
987
+ return False
988
+ elif self.early_stopping:
989
+ return True
990
+ else:
991
+ cur_score = best_sum_logprobs / cur_len ** self.length_penalty
992
+ ret = self.worst_score >= cur_score
993
+ return ret
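BeamHypotheses keeps only the num_beams best finished sequences under the length-penalised score sum_logprobs / len(hyp) ** length_penalty. A toy run (assumed import path, made-up token ids and log-probabilities):

import torch
from elia.bert.generation_utils import BeamHypotheses  # assumed import path

hyps = BeamHypotheses(num_beams=2, max_length=20, length_penalty=1.0, early_stopping=False)
hyps.add(torch.tensor([101, 7592, 102]), sum_logprobs=-2.1)        # score -2.1 / 3 = -0.70
hyps.add(torch.tensor([101, 7592, 2088, 102]), sum_logprobs=-2.4)  # score -2.4 / 4 = -0.60
hyps.add(torch.tensor([101, 102]), sum_logprobs=-3.0)              # score -1.50, worse than the worst kept

print(len(hyps))              # 2 -- the third candidate never makes it in
print(hyps.is_done(-2.0, 5))  # False -- an open beam with sum logprob -2.0 could still beat the worst kept score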
elia/bert/modeling_bert.py ADDED
@@ -0,0 +1,1569 @@
1
+ # coding=utf-8
2
+ # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
3
+ # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ """PyTorch BERT model. """
17
+
18
+
19
+ import logging
20
+ import math
21
+ import os
22
+ import warnings
23
+
24
+ import torch
25
+ import torch.utils.checkpoint
26
+ from torch import nn
27
+ from torch.nn import CrossEntropyLoss, MSELoss
28
+
29
+ from .activations import gelu, gelu_new, swish
30
+ from .configuration_bert import BertConfig
31
+ from .file_utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_callable
32
+ from .modeling_utils import PreTrainedModel, find_pruneable_heads_and_indices, prune_linear_layer
33
+
34
+
35
+ logger = logging.getLogger(__name__)
36
+
37
+ _TOKENIZER_FOR_DOC = "BertTokenizer"
38
+
39
+ BERT_PRETRAINED_MODEL_ARCHIVE_LIST = [
40
+ "bert-base-uncased",
41
+ "bert-large-uncased",
42
+ "bert-base-cased",
43
+ "bert-large-cased",
44
+ "bert-base-multilingual-uncased",
45
+ "bert-base-multilingual-cased",
46
+ "bert-base-chinese",
47
+ "bert-base-german-cased",
48
+ "bert-large-uncased-whole-word-masking",
49
+ "bert-large-cased-whole-word-masking",
50
+ "bert-large-uncased-whole-word-masking-finetuned-squad",
51
+ "bert-large-cased-whole-word-masking-finetuned-squad",
52
+ "bert-base-cased-finetuned-mrpc",
53
+ "bert-base-german-dbmdz-cased",
54
+ "bert-base-german-dbmdz-uncased",
55
+ "cl-tohoku/bert-base-japanese",
56
+ "cl-tohoku/bert-base-japanese-whole-word-masking",
57
+ "cl-tohoku/bert-base-japanese-char",
58
+ "cl-tohoku/bert-base-japanese-char-whole-word-masking",
59
+ "TurkuNLP/bert-base-finnish-cased-v1",
60
+ "TurkuNLP/bert-base-finnish-uncased-v1",
61
+ "wietsedv/bert-base-dutch-cased",
62
+ # See all BERT models at https://huggingface.co/models?filter=bert
63
+ ]
64
+
65
+
66
+ def load_tf_weights_in_bert(model, config, tf_checkpoint_path):
67
+ """ Load tf checkpoints in a pytorch model.
68
+ """
69
+ try:
70
+ import re
71
+ import numpy as np
72
+ import tensorflow as tf
73
+ except ImportError:
74
+ logger.error(
75
+ "Loading a TensorFlow model in PyTorch, requires TensorFlow to be installed. Please see "
76
+ "https://www.tensorflow.org/install/ for installation instructions."
77
+ )
78
+ raise
79
+ tf_path = os.path.abspath(tf_checkpoint_path)
80
+ logger.info("Converting TensorFlow checkpoint from {}".format(tf_path))
81
+ # Load weights from TF model
82
+ init_vars = tf.train.list_variables(tf_path)
83
+ names = []
84
+ arrays = []
85
+ for name, shape in init_vars:
86
+ logger.info("Loading TF weight {} with shape {}".format(name, shape))
87
+ array = tf.train.load_variable(tf_path, name)
88
+ names.append(name)
89
+ arrays.append(array)
90
+
91
+ for name, array in zip(names, arrays):
92
+ name = name.split("/")
93
+ # adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculate m and v
94
+ # which are not required for using pretrained model
95
+ if any(
96
+ n in ["adam_v", "adam_m", "AdamWeightDecayOptimizer", "AdamWeightDecayOptimizer_1", "global_step"]
97
+ for n in name
98
+ ):
99
+ logger.info("Skipping {}".format("/".join(name)))
100
+ continue
101
+ pointer = model
102
+ for m_name in name:
103
+ if re.fullmatch(r"[A-Za-z]+_\d+", m_name):
104
+ scope_names = re.split(r"_(\d+)", m_name)
105
+ else:
106
+ scope_names = [m_name]
107
+ if scope_names[0] == "kernel" or scope_names[0] == "gamma":
108
+ pointer = getattr(pointer, "weight")
109
+ elif scope_names[0] == "output_bias" or scope_names[0] == "beta":
110
+ pointer = getattr(pointer, "bias")
111
+ elif scope_names[0] == "output_weights":
112
+ pointer = getattr(pointer, "weight")
113
+ elif scope_names[0] == "squad":
114
+ pointer = getattr(pointer, "classifier")
115
+ else:
116
+ try:
117
+ pointer = getattr(pointer, scope_names[0])
118
+ except AttributeError:
119
+ logger.info("Skipping {}".format("/".join(name)))
120
+ continue
121
+ if len(scope_names) >= 2:
122
+ num = int(scope_names[1])
123
+ pointer = pointer[num]
124
+ if m_name[-11:] == "_embeddings":
125
+ pointer = getattr(pointer, "weight")
126
+ elif m_name == "kernel":
127
+ array = np.transpose(array)
128
+ try:
129
+ assert pointer.shape == array.shape
130
+ except AssertionError as e:
131
+ e.args += (pointer.shape, array.shape)
132
+ raise
133
+ logger.info("Initialize PyTorch weight {}".format(name))
134
+ pointer.data = torch.from_numpy(array)
135
+ return model
136
+
137
+
138
+ def mish(x):
139
+ return x * torch.tanh(nn.functional.softplus(x))
140
+
141
+
142
+ ACT2FN = {"gelu": gelu, "relu": torch.nn.functional.relu, "swish": swish, "gelu_new": gelu_new, "mish": mish}
143
+
144
+
145
+ BertLayerNorm = torch.nn.LayerNorm
146
+
147
+
148
+ class BertEmbeddings(nn.Module):
149
+ """Construct the embeddings from word, position and token_type embeddings.
150
+ """
151
+
152
+ def __init__(self, config):
153
+ super().__init__()
154
+ self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
155
+ self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
156
+ self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size)
157
+
158
+ # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
159
+ # any TensorFlow checkpoint file
160
+ self.LayerNorm = BertLayerNorm(config.hidden_size, eps=config.layer_norm_eps)
161
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
162
+
163
+ def forward(self, input_ids=None, token_type_ids=None, position_ids=None, inputs_embeds=None):
164
+ if input_ids is not None:
165
+ input_shape = input_ids.size()
166
+ else:
167
+ input_shape = inputs_embeds.size()[:-1]
168
+
169
+ seq_length = input_shape[1]
170
+ device = input_ids.device if input_ids is not None else inputs_embeds.device
171
+ if position_ids is None:
172
+ position_ids = torch.arange(seq_length, dtype=torch.long, device=device)
173
+ position_ids = position_ids.unsqueeze(0).expand(input_shape)
174
+ if token_type_ids is None:
175
+ token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
176
+
177
+ if inputs_embeds is None:
178
+ inputs_embeds = self.word_embeddings(input_ids)
179
+ position_embeddings = self.position_embeddings(position_ids)
180
+ token_type_embeddings = self.token_type_embeddings(token_type_ids)
181
+
182
+ embeddings = inputs_embeds + position_embeddings + token_type_embeddings
183
+ embeddings = self.LayerNorm(embeddings)
184
+ embeddings = self.dropout(embeddings)
185
+ return embeddings
186
+
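The embedding layer is just the sum of three lookups followed by LayerNorm and dropout. A standalone sketch with deliberately tiny, made-up dimensions (not BERT's real sizes):

import torch
from torch import nn

vocab_size, max_pos, type_vocab, hidden = 30, 16, 2, 8
word_emb = nn.Embedding(vocab_size, hidden, padding_idx=0)
pos_emb = nn.Embedding(max_pos, hidden)
type_emb = nn.Embedding(type_vocab, hidden)

input_ids = torch.tensor([[5, 9, 23, 0]])                    # (batch=1, seq_len=4)
position_ids = torch.arange(input_ids.size(1)).unsqueeze(0)  # default: 0..seq_len-1
token_type_ids = torch.zeros_like(input_ids)                 # default: everything is "sentence A"

embeddings = word_emb(input_ids) + pos_emb(position_ids) + type_emb(token_type_ids)
embeddings = nn.LayerNorm(hidden)(embeddings)                # the real module also applies dropout
print(embeddings.shape)  # torch.Size([1, 4, 8])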
187
+
188
+ class BertSelfAttention(nn.Module):
189
+ def __init__(self, config):
190
+ super().__init__()
191
+ if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
192
+ raise ValueError(
193
+ "The hidden size (%d) is not a multiple of the number of attention "
194
+ "heads (%d)" % (config.hidden_size, config.num_attention_heads)
195
+ )
196
+
197
+ self.num_attention_heads = config.num_attention_heads
198
+ self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
199
+ self.all_head_size = self.num_attention_heads * self.attention_head_size
200
+
201
+ self.query = nn.Linear(config.hidden_size, self.all_head_size)
202
+ self.key = nn.Linear(config.hidden_size, self.all_head_size)
203
+ self.value = nn.Linear(config.hidden_size, self.all_head_size)
204
+
205
+ self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
206
+
207
+ def transpose_for_scores(self, x):
208
+ new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
209
+ x = x.view(*new_x_shape)
210
+ return x.permute(0, 2, 1, 3)
211
+
212
+ def forward(
213
+ self,
214
+ hidden_states,
215
+ attention_mask=None,
216
+ head_mask=None,
217
+ encoder_hidden_states=None,
218
+ encoder_attention_mask=None,
219
+ output_attentions=False,
220
+ ):
221
+ mixed_query_layer = self.query(hidden_states)
222
+
223
+ # If this is instantiated as a cross-attention module, the keys
224
+ # and values come from an encoder; the attention mask needs to be
225
+ # such that the encoder's padding tokens are not attended to.
226
+ if encoder_hidden_states is not None:
227
+ mixed_key_layer = self.key(encoder_hidden_states)
228
+ mixed_value_layer = self.value(encoder_hidden_states)
229
+ attention_mask = encoder_attention_mask
230
+ else:
231
+ mixed_key_layer = self.key(hidden_states)
232
+ mixed_value_layer = self.value(hidden_states)
233
+
234
+ query_layer = self.transpose_for_scores(mixed_query_layer)
235
+ key_layer = self.transpose_for_scores(mixed_key_layer)
236
+ value_layer = self.transpose_for_scores(mixed_value_layer)
237
+
238
+ # Take the dot product between "query" and "key" to get the raw attention scores.
239
+ attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
240
+ attention_scores = attention_scores / math.sqrt(self.attention_head_size)
241
+ if attention_mask is not None:
242
+ # Apply the attention mask (precomputed for all layers in the BertModel forward() function)
243
+ attention_scores = attention_scores + attention_mask
244
+
245
+ # Normalize the attention scores to probabilities.
246
+ attention_probs = nn.Softmax(dim=-1)(attention_scores)
247
+
248
+ # This is actually dropping out entire tokens to attend to, which might
249
+ # seem a bit unusual, but is taken from the original Transformer paper.
250
+ attention_probs = self.dropout(attention_probs)
251
+
252
+ # Mask heads if we want to
253
+ if head_mask is not None:
254
+ attention_probs = attention_probs * head_mask
255
+
256
+ context_layer = torch.matmul(attention_probs, value_layer)
257
+
258
+ context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
259
+ new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
260
+ context_layer = context_layer.view(*new_context_layer_shape)
261
+
262
+ outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
263
+ return outputs
264
+
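The shape bookkeeping in transpose_for_scores plus the scaled dot-product is easier to see with random tensors and no learned projections (a sketch, not the module itself):

import math
import torch

batch, seq_len, num_heads, head_size = 2, 5, 4, 16
hidden = num_heads * head_size
x = torch.randn(batch, seq_len, hidden)

def split_heads(t):
    # (batch, seq, hidden) -> (batch, heads, seq, head_size)
    return t.view(batch, seq_len, num_heads, head_size).permute(0, 2, 1, 3)

q, k, v = split_heads(x), split_heads(x), split_heads(x)  # self-attention: queries, keys, values all from x
scores = torch.matmul(q, k.transpose(-1, -2)) / math.sqrt(head_size)
probs = torch.softmax(scores, dim=-1)                     # (batch, heads, seq, seq)
context = torch.matmul(probs, v).permute(0, 2, 1, 3).reshape(batch, seq_len, hidden)
print(context.shape)  # torch.Size([2, 5, 64])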
265
+
266
+ class BertSelfOutput(nn.Module):
267
+ def __init__(self, config):
268
+ super().__init__()
269
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
270
+ self.LayerNorm = BertLayerNorm(config.hidden_size, eps=config.layer_norm_eps)
271
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
272
+
273
+ def forward(self, hidden_states, input_tensor):
274
+ hidden_states = self.dense(hidden_states)
275
+ hidden_states = self.dropout(hidden_states)
276
+ hidden_states = self.LayerNorm(hidden_states + input_tensor)
277
+ return hidden_states
278
+
279
+
280
+ class BertAttention(nn.Module):
281
+ def __init__(self, config):
282
+ super().__init__()
283
+ self.self = BertSelfAttention(config)
284
+ self.output = BertSelfOutput(config)
285
+ self.pruned_heads = set()
286
+
287
+ def prune_heads(self, heads):
288
+ if len(heads) == 0:
289
+ return
290
+ heads, index = find_pruneable_heads_and_indices(
291
+ heads, self.self.num_attention_heads, self.self.attention_head_size, self.pruned_heads
292
+ )
293
+
294
+ # Prune linear layers
295
+ self.self.query = prune_linear_layer(self.self.query, index)
296
+ self.self.key = prune_linear_layer(self.self.key, index)
297
+ self.self.value = prune_linear_layer(self.self.value, index)
298
+ self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)
299
+
300
+ # Update hyper params and store pruned heads
301
+ self.self.num_attention_heads = self.self.num_attention_heads - len(heads)
302
+ self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads
303
+ self.pruned_heads = self.pruned_heads.union(heads)
304
+
305
+ def forward(
306
+ self,
307
+ hidden_states,
308
+ attention_mask=None,
309
+ head_mask=None,
310
+ encoder_hidden_states=None,
311
+ encoder_attention_mask=None,
312
+ output_attentions=False,
313
+ ):
314
+ self_outputs = self.self(
315
+ hidden_states, attention_mask, head_mask, encoder_hidden_states, encoder_attention_mask, output_attentions,
316
+ )
317
+ attention_output = self.output(self_outputs[0], hidden_states)
318
+ outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them
319
+ return outputs
320
+
321
+
322
+ class BertIntermediate(nn.Module):
323
+ def __init__(self, config):
324
+ super().__init__()
325
+ self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
326
+ if isinstance(config.hidden_act, str):
327
+ self.intermediate_act_fn = ACT2FN[config.hidden_act]
328
+ else:
329
+ self.intermediate_act_fn = config.hidden_act
330
+
331
+ def forward(self, hidden_states):
332
+ hidden_states = self.dense(hidden_states)
333
+ hidden_states = self.intermediate_act_fn(hidden_states)
334
+ return hidden_states
335
+
336
+
337
+ class BertOutput(nn.Module):
338
+ def __init__(self, config):
339
+ super().__init__()
340
+ self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
341
+ self.LayerNorm = BertLayerNorm(config.hidden_size, eps=config.layer_norm_eps)
342
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
343
+
344
+ def forward(self, hidden_states, input_tensor):
345
+ hidden_states = self.dense(hidden_states)
346
+ hidden_states = self.dropout(hidden_states)
347
+ hidden_states = self.LayerNorm(hidden_states + input_tensor)
348
+ return hidden_states
349
+
350
+
351
+ class BertLayer(nn.Module):
352
+ def __init__(self, config):
353
+ super().__init__()
354
+ self.attention = BertAttention(config)
355
+ self.is_decoder = config.is_decoder
356
+ if self.is_decoder:
357
+ self.crossattention = BertAttention(config)
358
+ self.intermediate = BertIntermediate(config)
359
+ self.output = BertOutput(config)
360
+
361
+ def forward(
362
+ self,
363
+ hidden_states,
364
+ attention_mask=None,
365
+ head_mask=None,
366
+ encoder_hidden_states=None,
367
+ encoder_attention_mask=None,
368
+ output_attentions=False,
369
+ ):
370
+ self_attention_outputs = self.attention(
371
+ hidden_states, attention_mask, head_mask, output_attentions=output_attentions,
372
+ )
373
+ attention_output = self_attention_outputs[0]
374
+ outputs = self_attention_outputs[1:] # add self attentions if we output attention weights
375
+
376
+ if self.is_decoder and encoder_hidden_states is not None:
377
+ cross_attention_outputs = self.crossattention(
378
+ attention_output,
379
+ attention_mask,
380
+ head_mask,
381
+ encoder_hidden_states,
382
+ encoder_attention_mask,
383
+ output_attentions,
384
+ )
385
+ attention_output = cross_attention_outputs[0]
386
+ outputs = outputs + cross_attention_outputs[1:] # add cross attentions if we output attention weights
387
+
388
+ intermediate_output = self.intermediate(attention_output)
389
+ layer_output = self.output(intermediate_output, attention_output)
390
+ outputs = (layer_output,) + outputs
391
+ return outputs
392
+
393
+
394
+ class BertEncoder(nn.Module):
395
+ def __init__(self, config):
396
+ super().__init__()
397
+ self.config = config
398
+ self.layer = nn.ModuleList([BertLayer(config) for _ in range(config.num_hidden_layers)])
399
+
400
+ def forward(
401
+ self,
402
+ hidden_states,
403
+ attention_mask=None,
404
+ head_mask=None,
405
+ encoder_hidden_states=None,
406
+ encoder_attention_mask=None,
407
+ output_attentions=False,
408
+ output_hidden_states=False,
409
+ ):
410
+ all_hidden_states = ()
411
+ all_attentions = ()
412
+ for i, layer_module in enumerate(self.layer):
413
+ if output_hidden_states:
414
+ all_hidden_states = all_hidden_states + (hidden_states,)
415
+
416
+ if getattr(self.config, "gradient_checkpointing", False):
417
+
418
+ def create_custom_forward(module):
419
+ def custom_forward(*inputs):
420
+ return module(*inputs, output_attentions)
421
+
422
+ return custom_forward
423
+
424
+ layer_outputs = torch.utils.checkpoint.checkpoint(
425
+ create_custom_forward(layer_module),
426
+ hidden_states,
427
+ attention_mask,
428
+ head_mask[i],
429
+ encoder_hidden_states,
430
+ encoder_attention_mask,
431
+ )
432
+ else:
433
+ layer_outputs = layer_module(
434
+ hidden_states,
435
+ attention_mask,
436
+ head_mask[i],
437
+ encoder_hidden_states,
438
+ encoder_attention_mask,
439
+ output_attentions,
440
+ )
441
+ hidden_states = layer_outputs[0]
442
+
443
+ if output_attentions:
444
+ all_attentions = all_attentions + (layer_outputs[1],)
445
+
446
+ # Add last layer
447
+ if output_hidden_states:
448
+ all_hidden_states = all_hidden_states + (hidden_states,)
449
+
450
+ outputs = (hidden_states,)
451
+ if output_hidden_states:
452
+ outputs = outputs + (all_hidden_states,)
453
+ if output_attentions:
454
+ outputs = outputs + (all_attentions,)
455
+ return outputs # last-layer hidden state, (all hidden states), (all attentions)
456
+
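The gradient_checkpointing branch in BertEncoder.forward trades compute for memory: layer activations are not kept, they are recomputed during the backward pass. A minimal standalone illustration with a stand-in layer (not a real BertLayer):

import torch
import torch.utils.checkpoint
from torch import nn

layer = nn.Sequential(nn.Linear(32, 32), nn.GELU(), nn.Linear(32, 32))  # stand-in for one transformer layer
hidden_states = torch.randn(4, 32, requires_grad=True)

out = torch.utils.checkpoint.checkpoint(layer, hidden_states)  # forward is re-run during backward
out.sum().backward()
print(hidden_states.grad.shape)  # torch.Size([4, 32])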
457
+
458
+ class BertPooler(nn.Module):
459
+ def __init__(self, config):
460
+ super().__init__()
461
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
462
+ self.activation = nn.Tanh()
463
+
464
+ def forward(self, hidden_states):
465
+ # We "pool" the model by simply taking the hidden state corresponding
466
+ # to the first token.
467
+ first_token_tensor = hidden_states[:, 0]
468
+ pooled_output = self.dense(first_token_tensor)
469
+ pooled_output = self.activation(pooled_output)
470
+ return pooled_output
471
+
472
+
473
+ class BertPredictionHeadTransform(nn.Module):
474
+ def __init__(self, config):
475
+ super().__init__()
476
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
477
+ if isinstance(config.hidden_act, str):
478
+ self.transform_act_fn = ACT2FN[config.hidden_act]
479
+ else:
480
+ self.transform_act_fn = config.hidden_act
481
+ self.LayerNorm = BertLayerNorm(config.hidden_size, eps=config.layer_norm_eps)
482
+
483
+ def forward(self, hidden_states):
484
+ hidden_states = self.dense(hidden_states)
485
+ hidden_states = self.transform_act_fn(hidden_states)
486
+ hidden_states = self.LayerNorm(hidden_states)
487
+ return hidden_states
488
+
489
+
490
+ class BertLMPredictionHead(nn.Module):
491
+ def __init__(self, config):
492
+ super().__init__()
493
+ self.transform = BertPredictionHeadTransform(config)
494
+
495
+ # The output weights are the same as the input embeddings, but there is
496
+ # an output-only bias for each token.
497
+ self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
498
+
499
+ self.bias = nn.Parameter(torch.zeros(config.vocab_size))
500
+
501
+ # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
502
+ self.decoder.bias = self.bias
503
+
504
+ def forward(self, hidden_states):
505
+ hidden_states = self.transform(hidden_states)
506
+ hidden_states = self.decoder(hidden_states)
507
+ return hidden_states
508
+
509
+
510
+ class BertOnlyMLMHead(nn.Module):
511
+ def __init__(self, config):
512
+ super().__init__()
513
+ self.predictions = BertLMPredictionHead(config)
514
+
515
+ def forward(self, sequence_output):
516
+ prediction_scores = self.predictions(sequence_output)
517
+ return prediction_scores
518
+
519
+
520
+ class BertOnlyNSPHead(nn.Module):
521
+ def __init__(self, config):
522
+ super().__init__()
523
+ self.seq_relationship = nn.Linear(config.hidden_size, 2)
524
+
525
+ def forward(self, pooled_output):
526
+ seq_relationship_score = self.seq_relationship(pooled_output)
527
+ return seq_relationship_score
528
+
529
+
530
+ class BertPreTrainingHeads(nn.Module):
531
+ def __init__(self, config):
532
+ super().__init__()
533
+ self.predictions = BertLMPredictionHead(config)
534
+ self.seq_relationship = nn.Linear(config.hidden_size, 2)
535
+
536
+ def forward(self, sequence_output, pooled_output):
537
+ prediction_scores = self.predictions(sequence_output)
538
+ seq_relationship_score = self.seq_relationship(pooled_output)
539
+ return prediction_scores, seq_relationship_score
540
+
541
+
542
+ class BertPreTrainedModel(PreTrainedModel):
543
+ """ An abstract class to handle weights initialization and
544
+ a simple interface for downloading and loading pretrained models.
545
+ """
546
+
547
+ config_class = BertConfig
548
+ load_tf_weights = load_tf_weights_in_bert
549
+ base_model_prefix = "bert"
550
+
551
+ def _init_weights(self, module):
552
+ """ Initialize the weights """
553
+ if isinstance(module, (nn.Linear, nn.Embedding)):
554
+ # Slightly different from the TF version which uses truncated_normal for initialization
555
+ # cf https://github.com/pytorch/pytorch/pull/5617
556
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
557
+ elif isinstance(module, BertLayerNorm):
558
+ module.bias.data.zero_()
559
+ module.weight.data.fill_(1.0)
560
+ if isinstance(module, nn.Linear) and module.bias is not None:
561
+ module.bias.data.zero_()
562
+
563
+
564
+ BERT_START_DOCSTRING = r"""
565
+ This model is a PyTorch `torch.nn.Module <https://pytorch.org/docs/stable/nn.html#torch.nn.Module>`_ sub-class.
566
+ Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general
567
+ usage and behavior.
568
+
569
+ Parameters:
570
+ config (:class:`~transformers.BertConfig`): Model configuration class with all the parameters of the model.
571
+ Initializing with a config file does not load the weights associated with the model, only the configuration.
572
+ Check out the :meth:`~transformers.PreTrainedModel.from_pretrained` method to load the model weights.
573
+ """
574
+
575
+ BERT_INPUTS_DOCSTRING = r"""
576
+ Args:
577
+ input_ids (:obj:`torch.LongTensor` of shape :obj:`{0}`):
578
+ Indices of input sequence tokens in the vocabulary.
579
+
580
+ Indices can be obtained using :class:`transformers.BertTokenizer`.
581
+ See :func:`transformers.PreTrainedTokenizer.encode` and
582
+ :func:`transformers.PreTrainedTokenizer.__call__` for details.
583
+
584
+ `What are input IDs? <../glossary.html#input-ids>`__
585
+ attention_mask (:obj:`torch.FloatTensor` of shape :obj:`{0}`, `optional`, defaults to :obj:`None`):
586
+ Mask to avoid performing attention on padding token indices.
587
+ Mask values selected in ``[0, 1]``:
588
+ ``1`` for tokens that are NOT MASKED, ``0`` for MASKED tokens.
589
+
590
+ `What are attention masks? <../glossary.html#attention-mask>`__
591
+ token_type_ids (:obj:`torch.LongTensor` of shape :obj:`{0}`, `optional`, defaults to :obj:`None`):
592
+ Segment token indices to indicate first and second portions of the inputs.
593
+ Indices are selected in ``[0, 1]``: ``0`` corresponds to a `sentence A` token, ``1``
594
+ corresponds to a `sentence B` token
595
+
596
+ `What are token type IDs? <../glossary.html#token-type-ids>`_
597
+ position_ids (:obj:`torch.LongTensor` of shape :obj:`{0}`, `optional`, defaults to :obj:`None`):
598
+ Indices of positions of each input sequence tokens in the position embeddings.
599
+ Selected in the range ``[0, config.max_position_embeddings - 1]``.
600
+
601
+ `What are position IDs? <../glossary.html#position-ids>`_
602
+ head_mask (:obj:`torch.FloatTensor` of shape :obj:`(num_heads,)` or :obj:`(num_layers, num_heads)`, `optional`, defaults to :obj:`None`):
603
+ Mask to nullify selected heads of the self-attention modules.
604
+ Mask values selected in ``[0, 1]``:
605
+ :obj:`1` indicates the head is **not masked**, :obj:`0` indicates the head is **masked**.
606
+ inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`, defaults to :obj:`None`):
607
+ Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded representation.
608
+ This is useful if you want more control over how to convert `input_ids` indices into associated vectors
609
+ than the model's internal embedding lookup matrix.
610
+ encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`, defaults to :obj:`None`):
611
+ Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention
612
+ if the model is configured as a decoder.
613
+ encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
614
+ Mask to avoid performing attention on the padding token indices of the encoder input. This mask
615
+ is used in the cross-attention if the model is configured as a decoder.
616
+ Mask values selected in ``[0, 1]``:
617
+ ``1`` for tokens that are NOT MASKED, ``0`` for MASKED tokens.
618
+ output_attentions (:obj:`bool`, `optional`, defaults to :obj:`None`):
619
+ If set to ``True``, the attentions tensors of all attention layers are returned. See ``attentions`` under returned tensors for more detail.
620
+ """
621
+
622
+
623
+ @add_start_docstrings(
624
+ "The bare Bert Model transformer outputting raw hidden-states without any specific head on top.",
625
+ BERT_START_DOCSTRING,
626
+ )
627
+ class BertModel(BertPreTrainedModel):
628
+ """
629
+
630
+ The model can behave as an encoder (with only self-attention) as well
631
+ as a decoder, in which case a layer of cross-attention is added between
632
+ the self-attention layers, following the architecture described in `Attention is all you need`_ by Ashish Vaswani,
633
+ Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin.
634
+
635
+ To behave as a decoder, the model needs to be initialized with the
636
+ :obj:`is_decoder` argument of the configuration set to :obj:`True`; an
637
+ :obj:`encoder_hidden_states` is expected as an input to the forward pass.
638
+
639
+ .. _`Attention is all you need`:
640
+ https://arxiv.org/abs/1706.03762
641
+
642
+ """
643
+
644
+ def __init__(self, config):
645
+ super().__init__(config)
646
+ self.config = config
647
+
648
+ self.embeddings = BertEmbeddings(config)
649
+ self.encoder = BertEncoder(config)
650
+ self.pooler = BertPooler(config)
651
+
652
+ self.init_weights()
653
+
654
+ def get_input_embeddings(self):
655
+ return self.embeddings.word_embeddings
656
+
657
+ def set_input_embeddings(self, value):
658
+ self.embeddings.word_embeddings = value
659
+
660
+ def _prune_heads(self, heads_to_prune):
661
+ """ Prunes heads of the model.
662
+ heads_to_prune: dict of {layer_num: list of heads to prune in this layer}
663
+ See base class PreTrainedModel
664
+ """
665
+ for layer, heads in heads_to_prune.items():
666
+ self.encoder.layer[layer].attention.prune_heads(heads)
667
+
668
+ @add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING.format("(batch_size, sequence_length)"))
669
+ @add_code_sample_docstrings(tokenizer_class=_TOKENIZER_FOR_DOC, checkpoint="bert-base-uncased")
670
+ def forward(
671
+ self,
672
+ input_ids=None,
673
+ attention_mask=None,
674
+ token_type_ids=None,
675
+ position_ids=None,
676
+ head_mask=None,
677
+ inputs_embeds=None,
678
+ encoder_hidden_states=None,
679
+ encoder_attention_mask=None,
680
+ output_attentions=None,
681
+ output_hidden_states=None,
682
+ ):
683
+ r"""
684
+ Return:
685
+ :obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.BertConfig`) and inputs:
686
+ last_hidden_state (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`):
687
+ Sequence of hidden-states at the output of the last layer of the model.
688
+ pooler_output (:obj:`torch.FloatTensor`: of shape :obj:`(batch_size, hidden_size)`):
689
+ Last layer hidden-state of the first token of the sequence (classification token)
690
+ further processed by a Linear layer and a Tanh activation function. The Linear
691
+ layer weights are trained from the next sentence prediction (classification)
692
+ objective during pre-training.
693
+
694
+ This output is usually *not* a good summary
695
+ of the semantic content of the input; you're often better off averaging or pooling
696
+ the sequence of hidden-states for the whole input sequence.
697
+ hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
698
+ Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
699
+ of shape :obj:`(batch_size, sequence_length, hidden_size)`.
700
+
701
+ Hidden-states of the model at the output of each layer plus the initial embedding outputs.
702
+ attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
703
+ Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
704
+ :obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
705
+
706
+ Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
707
+ heads.
708
+ """
709
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
710
+ output_hidden_states = (
711
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
712
+ )
713
+
714
+ if input_ids is not None and inputs_embeds is not None:
715
+ raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
716
+ elif input_ids is not None:
717
+ input_shape = input_ids.size()
718
+ elif inputs_embeds is not None:
719
+ input_shape = inputs_embeds.size()[:-1]
720
+ else:
721
+ raise ValueError("You have to specify either input_ids or inputs_embeds")
722
+
723
+ device = input_ids.device if input_ids is not None else inputs_embeds.device
724
+
725
+ if attention_mask is None:
726
+ attention_mask = torch.ones(input_shape, device=device)
727
+ if token_type_ids is None:
728
+ token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
729
+
730
+ # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
731
+ # ourselves in which case we just need to make it broadcastable to all heads.
732
+ extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape, device)
733
+
734
+ # If a 2D or 3D attention mask is provided for the cross-attention
735
+ # we need to make it broadcastable to [batch_size, num_heads, seq_length, seq_length]
736
+ if self.config.is_decoder and encoder_hidden_states is not None:
737
+ encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
738
+ encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
739
+ if encoder_attention_mask is None:
740
+ encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
741
+ encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
742
+ else:
743
+ encoder_extended_attention_mask = None
744
+
745
+ # Prepare head mask if needed
746
+ # 1.0 in head_mask indicate we keep the head
747
+ # attention_probs has shape bsz x n_heads x N x N
748
+ # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
749
+ # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
750
+ head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
751
+
752
+ embedding_output = self.embeddings(
753
+ input_ids=input_ids, position_ids=position_ids, token_type_ids=token_type_ids, inputs_embeds=inputs_embeds
754
+ )
755
+ encoder_outputs = self.encoder(
756
+ embedding_output,
757
+ attention_mask=extended_attention_mask,
758
+ head_mask=head_mask,
759
+ encoder_hidden_states=encoder_hidden_states,
760
+ encoder_attention_mask=encoder_extended_attention_mask,
761
+ output_attentions=output_attentions,
762
+ output_hidden_states=output_hidden_states,
763
+ )
764
+ sequence_output = encoder_outputs[0]
765
+ pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
766
+
767
+ outputs = (sequence_output, pooled_output,) + encoder_outputs[
768
+ 1:
769
+ ] # add hidden_states and attentions if they are here
770
+ return outputs # sequence_output, pooled_output, (hidden_states), (attentions)
771
+
772
+
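A quick smoke test of the bare model with a tiny random config, so it runs without downloading weights (the elia.bert import paths are an assumption about how this vendored copy is packaged; real checkpoints use the default BertConfig sizes):

import torch
from elia.bert.configuration_bert import BertConfig  # assumed import path
from elia.bert.modeling_bert import BertModel        # assumed import path

config = BertConfig(vocab_size=100, hidden_size=32, num_hidden_layers=2,
                    num_attention_heads=4, intermediate_size=64)
model = BertModel(config)

input_ids = torch.randint(0, 100, (1, 10))
attention_mask = torch.ones(1, 10, dtype=torch.long)
sequence_output, pooled_output = model(input_ids, attention_mask=attention_mask)[:2]
print(sequence_output.shape, pooled_output.shape)  # torch.Size([1, 10, 32]) torch.Size([1, 32])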
773
+ @add_start_docstrings(
774
+ """Bert Model with two heads on top as done during the pre-training: a `masked language modeling` head and
775
+ a `next sentence prediction (classification)` head. """,
776
+ BERT_START_DOCSTRING,
777
+ )
778
+ class BertForPreTraining(BertPreTrainedModel):
779
+ def __init__(self, config):
780
+ super().__init__(config)
781
+
782
+ self.bert = BertModel(config)
783
+ self.cls = BertPreTrainingHeads(config)
784
+
785
+ self.init_weights()
786
+
787
+ def get_output_embeddings(self):
788
+ return self.cls.predictions.decoder
789
+
790
+ @add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING.format("(batch_size, sequence_length)"))
791
+ def forward(
792
+ self,
793
+ input_ids=None,
794
+ attention_mask=None,
795
+ token_type_ids=None,
796
+ position_ids=None,
797
+ head_mask=None,
798
+ inputs_embeds=None,
799
+ labels=None,
800
+ next_sentence_label=None,
801
+ output_attentions=None,
802
+ output_hidden_states=None,
803
+ **kwargs
804
+ ):
805
+ r"""
806
+ labels (``torch.LongTensor`` of shape ``(batch_size, sequence_length)``, `optional`, defaults to :obj:`None`):
807
+ Labels for computing the masked language modeling loss.
808
+ Indices should be in ``[-100, 0, ..., config.vocab_size]`` (see ``input_ids`` docstring)
809
+ Tokens with indices set to ``-100`` are ignored (masked), the loss is only computed for the tokens with labels
810
+ in ``[0, ..., config.vocab_size]``
811
+ next_sentence_label (``torch.LongTensor`` of shape ``(batch_size,)``, `optional`, defaults to :obj:`None`):
812
+ Labels for computing the next sequence prediction (classification) loss. Input should be a sequence pair (see :obj:`input_ids` docstring)
813
+ Indices should be in ``[0, 1]``.
814
+ ``0`` indicates sequence B is a continuation of sequence A,
815
+ ``1`` indicates sequence B is a random sequence.
816
+ kwargs (:obj:`Dict[str, any]`, optional, defaults to `{}`):
817
+ Used to hide legacy arguments that have been deprecated.
818
+
819
+ Returns:
820
+ :obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.BertConfig`) and inputs:
821
+ loss (`optional`, returned when ``labels`` is provided) ``torch.FloatTensor`` of shape ``(1,)``:
822
+ Total loss as the sum of the masked language modeling loss and the next sequence prediction (classification) loss.
823
+ prediction_scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, config.vocab_size)`)
824
+ Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
825
+ seq_relationship_scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, 2)`):
826
+ Prediction scores of the next sequence prediction (classification) head (scores of True/False
827
+ continuation before SoftMax).
828
+ hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
829
+ Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
830
+ of shape :obj:`(batch_size, sequence_length, hidden_size)`.
831
+
832
+ Hidden-states of the model at the output of each layer plus the initial embedding outputs.
833
+ attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
834
+ Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
835
+ :obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
836
+
837
+ Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
838
+ heads.
839
+
840
+
841
+ Examples::
842
+
843
+ >>> from transformers import BertTokenizer, BertForPreTraining
844
+ >>> import torch
845
+
846
+ >>> tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
847
+ >>> model = BertForPreTraining.from_pretrained('bert-base-uncased')
848
+
849
+ >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
850
+ >>> outputs = model(**inputs)
851
+
852
+ >>> prediction_scores, seq_relationship_scores = outputs[:2]
853
+
854
+ """
855
+ if "masked_lm_labels" in kwargs:
856
+ warnings.warn(
857
+ "The `masked_lm_labels` argument is deprecated and will be removed in a future version, use `labels` instead.",
858
+ DeprecationWarning,
859
+ )
860
+ labels = kwargs.pop("masked_lm_labels")
861
+ assert kwargs == {}, f"Unexpected keyword arguments: {list(kwargs.keys())}."
862
+
863
+ outputs = self.bert(
864
+ input_ids,
865
+ attention_mask=attention_mask,
866
+ token_type_ids=token_type_ids,
867
+ position_ids=position_ids,
868
+ head_mask=head_mask,
869
+ inputs_embeds=inputs_embeds,
870
+ output_attentions=output_attentions,
871
+ output_hidden_states=output_hidden_states,
872
+ )
873
+
874
+ sequence_output, pooled_output = outputs[:2]
875
+ prediction_scores, seq_relationship_score = self.cls(sequence_output, pooled_output)
876
+
877
+ outputs = (prediction_scores, seq_relationship_score,) + outputs[
878
+ 2:
879
+ ] # add hidden states and attention if they are here
880
+
881
+ if labels is not None and next_sentence_label is not None:
882
+ loss_fct = CrossEntropyLoss()
883
+ masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
884
+ next_sentence_loss = loss_fct(seq_relationship_score.view(-1, 2), next_sentence_label.view(-1))
885
+ total_loss = masked_lm_loss + next_sentence_loss
886
+ outputs = (total_loss,) + outputs
887
+
888
+ return outputs # (loss), prediction_scores, seq_relationship_score, (hidden_states), (attentions)
889
+
890
+
891
+ @add_start_docstrings(
892
+ """Bert Model with a `language modeling` head on top for CLM fine-tuning. """, BERT_START_DOCSTRING
893
+ )
894
+ class BertLMHeadModel(BertPreTrainedModel):
895
+ def __init__(self, config):
896
+ super().__init__(config)
897
+ assert config.is_decoder, "If you want to use `BertLMHeadModel` as a standalone, add `is_decoder=True`."
898
+
899
+ self.bert = BertModel(config)
900
+ self.cls = BertOnlyMLMHead(config)
901
+
902
+ self.init_weights()
903
+
904
+ def get_output_embeddings(self):
905
+ return self.cls.predictions.decoder
906
+
907
+ @add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING.format("(batch_size, sequence_length)"))
908
+ def forward(
909
+ self,
910
+ input_ids=None,
911
+ attention_mask=None,
912
+ token_type_ids=None,
913
+ position_ids=None,
914
+ head_mask=None,
915
+ inputs_embeds=None,
916
+ labels=None,
917
+ encoder_hidden_states=None,
918
+ encoder_attention_mask=None,
919
+ output_attentions=None,
920
+ output_hidden_states=None,
921
+ **kwargs
922
+ ):
923
+ r"""
924
+ labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
925
+ Labels for computing the left-to-right language modeling loss (next word prediction).
926
+ Indices should be in ``[-100, 0, ..., config.vocab_size]`` (see ``input_ids`` docstring)
927
+ Tokens with indices set to ``-100`` are ignored (masked), the loss is only computed for the tokens with labels
928
+ in ``[0, ..., config.vocab_size]``
929
+ kwargs (:obj:`Dict[str, any]`, optional, defaults to `{}`):
930
+ Used to hide legacy arguments that have been deprecated.
931
+
932
+ Returns:
933
+ :obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.BertConfig`) and inputs:
934
+ ltr_lm_loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when :obj:`labels` is provided):
935
+ Next token prediction loss.
936
+ prediction_scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, config.vocab_size)`)
937
+ Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
938
+ hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
939
+ Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
940
+ of shape :obj:`(batch_size, sequence_length, hidden_size)`.
941
+
942
+ Hidden-states of the model at the output of each layer plus the initial embedding outputs.
943
+ attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
944
+ Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
945
+ :obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
946
+
947
+ Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
948
+ heads.
949
+
950
+ Example::
951
+
952
+ >>> from transformers import BertTokenizer, BertLMHeadModel, BertConfig
953
+ >>> import torch
954
+
955
+ >>> tokenizer = BertTokenizer.from_pretrained('bert-base-cased')
956
+ >>> config = BertConfig.from_pretrained("bert-base-cased")
957
+ >>> config.is_decoder = True
958
+ >>> model = BertLMHeadModel.from_pretrained('bert-base-cased', config=config)
959
+
960
+ >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
961
+ >>> outputs = model(**inputs)
962
+
963
+ >>> prediction_scores = outputs[0]  # Prediction scores are the first element of the output tuple
964
+ """
965
+
966
+ outputs = self.bert(
967
+ input_ids,
968
+ attention_mask=attention_mask,
969
+ token_type_ids=token_type_ids,
970
+ position_ids=position_ids,
971
+ head_mask=head_mask,
972
+ inputs_embeds=inputs_embeds,
973
+ encoder_hidden_states=encoder_hidden_states,
974
+ encoder_attention_mask=encoder_attention_mask,
975
+ output_attentions=output_attentions,
976
+ output_hidden_states=output_hidden_states,
977
+ )
978
+
979
+ sequence_output = outputs[0]
980
+ prediction_scores = self.cls(sequence_output)
981
+
982
+ outputs = (prediction_scores,) + outputs[2:] # Add hidden states and attention if they are here
983
+
984
+ if labels is not None:
985
+ # we are doing next-token prediction; shift prediction scores and input ids by one
986
+ prediction_scores = prediction_scores[:, :-1, :].contiguous()
987
+ labels = labels[:, 1:].contiguous()
988
+ loss_fct = CrossEntropyLoss()
989
+ ltr_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
990
+ outputs = (ltr_lm_loss,) + outputs
991
+
992
+ return outputs # (ltr_lm_loss), prediction_scores, (hidden_states), (attentions)
993
+
994
+ def prepare_inputs_for_generation(self, input_ids, attention_mask=None, **model_kwargs):
995
+ input_shape = input_ids.shape
996
+
997
+ # if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
998
+ if attention_mask is None:
999
+ attention_mask = input_ids.new_ones(input_shape)
1000
+
1001
+ return {"input_ids": input_ids, "attention_mask": attention_mask}
1002
+
1003
+
1004
+ @add_start_docstrings("""Bert Model with a `language modeling` head on top. """, BERT_START_DOCSTRING)
1005
+ class BertForMaskedLM(BertPreTrainedModel):
1006
+ def __init__(self, config):
1007
+ super().__init__(config)
1008
+ assert (
1009
+ not config.is_decoder
1010
+ ), "If you want to use `BertForMaskedLM` make sure `config.is_decoder=False` for bi-directional self-attention."
1011
+
1012
+ self.bert = BertModel(config)
1013
+ self.cls = BertOnlyMLMHead(config)
1014
+
1015
+ self.init_weights()
1016
+
1017
+ def get_output_embeddings(self):
1018
+ return self.cls.predictions.decoder
1019
+
1020
+ @add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING.format("(batch_size, sequence_length)"))
1021
+ @add_code_sample_docstrings(tokenizer_class=_TOKENIZER_FOR_DOC, checkpoint="bert-base-uncased")
1022
+ def forward(
1023
+ self,
1024
+ input_ids=None,
1025
+ attention_mask=None,
1026
+ token_type_ids=None,
1027
+ position_ids=None,
1028
+ head_mask=None,
1029
+ inputs_embeds=None,
1030
+ labels=None,
1031
+ encoder_hidden_states=None,
1032
+ encoder_attention_mask=None,
1033
+ output_attentions=None,
1034
+ output_hidden_states=None,
1035
+ **kwargs
1036
+ ):
1037
+ r"""
1038
+ labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
1039
+ Labels for computing the masked language modeling loss.
1040
+ Indices should be in ``[-100, 0, ..., config.vocab_size]`` (see ``input_ids`` docstring)
1041
+ Tokens with indices set to ``-100`` are ignored (masked), the loss is only computed for the tokens with labels
1042
+ in ``[0, ..., config.vocab_size]``
1043
+ kwargs (:obj:`Dict[str, any]`, optional, defaults to `{}`):
1044
+ Used to hide legacy arguments that have been deprecated.
1045
+
1046
+ Returns:
1047
+ :obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.BertConfig`) and inputs:
1048
+ masked_lm_loss (`optional`, returned when ``labels`` is provided) ``torch.FloatTensor`` of shape ``(1,)``:
1049
+ Masked language modeling loss.
1050
+ prediction_scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, config.vocab_size)`)
1051
+ Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
1052
+ hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
1053
+ Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
1054
+ of shape :obj:`(batch_size, sequence_length, hidden_size)`.
1055
+
1056
+ Hidden-states of the model at the output of each layer plus the initial embedding outputs.
1057
+ attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
1058
+ Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
1059
+ :obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
1060
+
1061
+ Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
1062
+ heads.
1063
+ """
1064
+ if "masked_lm_labels" in kwargs:
1065
+ warnings.warn(
1066
+ "The `masked_lm_labels` argument is deprecated and will be removed in a future version, use `labels` instead.",
1067
+ DeprecationWarning,
1068
+ )
1069
+ labels = kwargs.pop("masked_lm_labels")
1070
+ assert "lm_labels" not in kwargs, "Use `BertWithLMHead` for autoregressive language modeling task."
1071
+ assert kwargs == {}, f"Unexpected keyword arguments: {list(kwargs.keys())}."
1072
+
1073
+ outputs = self.bert(
1074
+ input_ids,
1075
+ attention_mask=attention_mask,
1076
+ token_type_ids=token_type_ids,
1077
+ position_ids=position_ids,
1078
+ head_mask=head_mask,
1079
+ inputs_embeds=inputs_embeds,
1080
+ encoder_hidden_states=encoder_hidden_states,
1081
+ encoder_attention_mask=encoder_attention_mask,
1082
+ output_attentions=output_attentions,
1083
+ output_hidden_states=output_hidden_states,
1084
+ )
1085
+
1086
+ sequence_output = outputs[0]
1087
+ prediction_scores = self.cls(sequence_output)
1088
+
1089
+ outputs = (prediction_scores,) + outputs[2:] # Add hidden states and attention if they are here
1090
+
1091
+ if labels is not None:
1092
+ loss_fct = CrossEntropyLoss() # -100 index = padding token
1093
+ masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
1094
+ outputs = (masked_lm_loss,) + outputs
1095
+
1096
+ return outputs # (masked_lm_loss), prediction_scores, (hidden_states), (attentions)
1097
+
1098
+ def prepare_inputs_for_generation(self, input_ids, attention_mask=None, **model_kwargs):
1099
+ input_shape = input_ids.shape
1100
+ effective_batch_size = input_shape[0]
1101
+
1102
+ # add a dummy token
1103
+ assert self.config.pad_token_id is not None, "The PAD token should be defined for generation"
1104
+ attention_mask = torch.cat([attention_mask, attention_mask.new_zeros((attention_mask.shape[0], 1))], dim=-1)
1105
+ dummy_token = torch.full(
1106
+ (effective_batch_size, 1), self.config.pad_token_id, dtype=torch.long, device=input_ids.device
1107
+ )
1108
+ input_ids = torch.cat([input_ids, dummy_token], dim=1)
1109
+
1110
+ return {"input_ids": input_ids, "attention_mask": attention_mask}
1111
+
1112
+
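# Minimal usage sketch for BertForMaskedLM above (illustrative sentences and checkpoint;
# assumes the tuple-style returns implemented in this module, where the loss is prepended
# to the outputs whenever `labels` is passed).
from transformers import BertTokenizer, BertForMaskedLM

tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
model = BertForMaskedLM.from_pretrained('bert-base-uncased')
inputs = tokenizer("The capital of France is [MASK].", return_tensors="pt")
labels = tokenizer("The capital of France is Paris.", return_tensors="pt")["input_ids"]
masked_lm_loss, prediction_scores = model(**inputs, labels=labels)[:2]
# Recover the prediction at the masked position
mask_index = (inputs["input_ids"][0] == tokenizer.mask_token_id).nonzero(as_tuple=True)[0]
predicted_token = tokenizer.decode(prediction_scores[0, mask_index].argmax(-1))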
1113
+ @add_start_docstrings(
1114
+ """Bert Model with a `next sentence prediction (classification)` head on top. """, BERT_START_DOCSTRING,
1115
+ )
1116
+ class BertForNextSentencePrediction(BertPreTrainedModel):
1117
+ def __init__(self, config):
1118
+ super().__init__(config)
1119
+
1120
+ self.bert = BertModel(config)
1121
+ self.cls = BertOnlyNSPHead(config)
1122
+
1123
+ self.init_weights()
1124
+
1125
+ @add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING.format("(batch_size, sequence_length)"))
1126
+ def forward(
1127
+ self,
1128
+ input_ids=None,
1129
+ attention_mask=None,
1130
+ token_type_ids=None,
1131
+ position_ids=None,
1132
+ head_mask=None,
1133
+ inputs_embeds=None,
1134
+ next_sentence_label=None,
1135
+ output_attentions=None,
1136
+ output_hidden_states=None,
1137
+ ):
1138
+ r"""
1139
+ next_sentence_label (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`):
1140
+ Labels for computing the next sequence prediction (classification) loss. Input should be a sequence pair (see ``input_ids`` docstring)
1141
+ Indices should be in ``[0, 1]``.
1142
+ ``0`` indicates sequence B is a continuation of sequence A,
1143
+ ``1`` indicates sequence B is a random sequence.
1144
+
1145
+ Returns:
1146
+ :obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.BertConfig`) and inputs:
1147
+ loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when :obj:`next_sentence_label` is provided):
1148
+ Next sequence prediction (classification) loss.
1149
+ seq_relationship_scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, 2)`):
1150
+ Prediction scores of the next sequence prediction (classification) head (scores of True/False continuation before SoftMax).
1151
+ hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
1152
+ Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
1153
+ of shape :obj:`(batch_size, sequence_length, hidden_size)`.
1154
+
1155
+ Hidden-states of the model at the output of each layer plus the initial embedding outputs.
1156
+ attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
1157
+ Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
1158
+ :obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
1159
+
1160
+ Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
1161
+ heads.
1162
+
1163
+ Examples::
1164
+
1165
+ >>> from transformers import BertTokenizer, BertForNextSentencePrediction
1166
+ >>> import torch
1167
+
1168
+ >>> tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
1169
+ >>> model = BertForNextSentencePrediction.from_pretrained('bert-base-uncased')
1170
+
1171
+ >>> prompt = "In Italy, pizza served in formal settings, such as at a restaurant, is presented unsliced."
1172
+ >>> next_sentence = "The sky is blue due to the shorter wavelength of blue light."
1173
+ >>> encoding = tokenizer(prompt, next_sentence, return_tensors='pt')
1174
+
1175
+ >>> loss, logits = model(**encoding, next_sentence_label=torch.LongTensor([1]))
1176
+ >>> assert logits[0, 0] < logits[0, 1] # next sentence was random
1177
+ """
1178
+
1179
+ outputs = self.bert(
1180
+ input_ids,
1181
+ attention_mask=attention_mask,
1182
+ token_type_ids=token_type_ids,
1183
+ position_ids=position_ids,
1184
+ head_mask=head_mask,
1185
+ inputs_embeds=inputs_embeds,
1186
+ output_attentions=output_attentions,
1187
+ output_hidden_states=output_hidden_states,
1188
+ )
1189
+
1190
+ pooled_output = outputs[1]
1191
+
1192
+ seq_relationship_score = self.cls(pooled_output)
1193
+
1194
+ outputs = (seq_relationship_score,) + outputs[2:] # add hidden states and attention if they are here
1195
+ if next_sentence_label is not None:
1196
+ loss_fct = CrossEntropyLoss()
1197
+ next_sentence_loss = loss_fct(seq_relationship_score.view(-1, 2), next_sentence_label.view(-1))
1198
+ outputs = (next_sentence_loss,) + outputs
1199
+
1200
+ return outputs # (next_sentence_loss), seq_relationship_score, (hidden_states), (attentions)
1201
+
1202
+
1203
+ @add_start_docstrings(
1204
+ """Bert Model transformer with a sequence classification/regression head on top (a linear layer on top of
1205
+ the pooled output) e.g. for GLUE tasks. """,
1206
+ BERT_START_DOCSTRING,
1207
+ )
1208
+ class BertForSequenceClassification(BertPreTrainedModel):
1209
+ def __init__(self, config):
1210
+ super().__init__(config)
1211
+ self.num_labels = config.num_labels
1212
+
1213
+ self.bert = BertModel(config)
1214
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
1215
+ self.classifier = nn.Linear(config.hidden_size, config.num_labels)
1216
+
1217
+ self.init_weights()
1218
+
1219
+ @add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING.format("(batch_size, sequence_length)"))
1220
+ @add_code_sample_docstrings(tokenizer_class=_TOKENIZER_FOR_DOC, checkpoint="bert-base-uncased")
1221
+ def forward(
1222
+ self,
1223
+ input_ids=None,
1224
+ attention_mask=None,
1225
+ token_type_ids=None,
1226
+ position_ids=None,
1227
+ head_mask=None,
1228
+ inputs_embeds=None,
1229
+ labels=None,
1230
+ output_attentions=None,
1231
+ output_hidden_states=None,
1232
+ ):
1233
+ r"""
1234
+ labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`):
1235
+ Labels for computing the sequence classification/regression loss.
1236
+ Indices should be in :obj:`[0, ..., config.num_labels - 1]`.
1237
+ If :obj:`config.num_labels == 1` a regression loss is computed (Mean-Square loss),
1238
+ If :obj:`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
1239
+
1240
+ Returns:
1241
+ :obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.BertConfig`) and inputs:
1242
+ loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when :obj:`labels` is provided):
1243
+ Classification (or regression if config.num_labels==1) loss.
1244
+ logits (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, config.num_labels)`):
1245
+ Classification (or regression if config.num_labels==1) scores (before SoftMax).
1246
+ hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
1247
+ Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
1248
+ of shape :obj:`(batch_size, sequence_length, hidden_size)`.
1249
+
1250
+ Hidden-states of the model at the output of each layer plus the initial embedding outputs.
1251
+ attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
1252
+ Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
1253
+ :obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
1254
+
1255
+ Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
1256
+ heads.
1257
+ """
1258
+
1259
+ outputs = self.bert(
1260
+ input_ids,
1261
+ attention_mask=attention_mask,
1262
+ token_type_ids=token_type_ids,
1263
+ position_ids=position_ids,
1264
+ head_mask=head_mask,
1265
+ inputs_embeds=inputs_embeds,
1266
+ output_attentions=output_attentions,
1267
+ output_hidden_states=output_hidden_states,
1268
+ )
1269
+
1270
+ pooled_output = outputs[1]
1271
+
1272
+ pooled_output = self.dropout(pooled_output)
1273
+ logits = self.classifier(pooled_output)
1274
+
1275
+ outputs = (logits,) + outputs[2:] # add hidden states and attention if they are here
1276
+
1277
+ if labels is not None:
1278
+ if self.num_labels == 1:
1279
+ # We are doing regression
1280
+ loss_fct = MSELoss()
1281
+ loss = loss_fct(logits.view(-1), labels.view(-1))
1282
+ else:
1283
+ loss_fct = CrossEntropyLoss()
1284
+ loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
1285
+ outputs = (loss,) + outputs
1286
+
1287
+ return outputs # (loss), logits, (hidden_states), (attentions)
1288
+
1289
+
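# Minimal usage sketch for BertForSequenceClassification above (illustrative text and label;
# `num_labels` defaults to 2 in the config, and the loss is prepended to the tuple outputs
# when `labels` is passed).
from transformers import BertTokenizer, BertForSequenceClassification
import torch

tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
model = BertForSequenceClassification.from_pretrained('bert-base-uncased')
inputs = tokenizer("This movie was great!", return_tensors="pt")
labels = torch.tensor([1])                       # one class index per example in the batch
loss, logits = model(**inputs, labels=labels)[:2]
loss.backward()                                  # a typical fine-tuning step; optimizer.step() would follow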
1290
+ @add_start_docstrings(
1291
+ """Bert Model with a multiple choice classification head on top (a linear layer on top of
1292
+ the pooled output and a softmax) e.g. for RocStories/SWAG tasks. """,
1293
+ BERT_START_DOCSTRING,
1294
+ )
1295
+ class BertForMultipleChoice(BertPreTrainedModel):
1296
+ def __init__(self, config):
1297
+ super().__init__(config)
1298
+
1299
+ self.bert = BertModel(config)
1300
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
1301
+ self.classifier = nn.Linear(config.hidden_size, 1)
1302
+
1303
+ self.init_weights()
1304
+
1305
+ @add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING.format("(batch_size, num_choices, sequence_length)"))
1306
+ @add_code_sample_docstrings(tokenizer_class=_TOKENIZER_FOR_DOC, checkpoint="bert-base-uncased")
1307
+ def forward(
1308
+ self,
1309
+ input_ids=None,
1310
+ attention_mask=None,
1311
+ token_type_ids=None,
1312
+ position_ids=None,
1313
+ head_mask=None,
1314
+ inputs_embeds=None,
1315
+ labels=None,
1316
+ output_attentions=None,
1317
+ output_hidden_states=None,
1318
+ ):
1319
+ r"""
1320
+ labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`):
1321
+ Labels for computing the multiple choice classification loss.
1322
+ Indices should be in ``[0, ..., num_choices-1]`` where `num_choices` is the size of the second dimension
1323
+ of the input tensors. (see `input_ids` above)
1324
+
1325
+ Returns:
1326
+ :obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.BertConfig`) and inputs:
1327
+ loss (:obj:`torch.FloatTensor` of shape `(1,)`, `optional`, returned when :obj:`labels` is provided):
1328
+ Classification loss.
1329
+ classification_scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, num_choices)`):
1330
+ `num_choices` is the second dimension of the input tensors. (see `input_ids` above).
1331
+
1332
+ Classification scores (before SoftMax).
1333
+ hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
1334
+ Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
1335
+ of shape :obj:`(batch_size, sequence_length, hidden_size)`.
1336
+
1337
+ Hidden-states of the model at the output of each layer plus the initial embedding outputs.
1338
+ attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
1339
+ Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
1340
+ :obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
1341
+
1342
+ Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
1343
+ heads.
1344
+ """
1345
+ num_choices = input_ids.shape[1] if input_ids is not None else inputs_embeds.shape[1]
1346
+
1347
+ input_ids = input_ids.view(-1, input_ids.size(-1)) if input_ids is not None else None
1348
+ attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None
1349
+ token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) if token_type_ids is not None else None
1350
+ position_ids = position_ids.view(-1, position_ids.size(-1)) if position_ids is not None else None
1351
+ inputs_embeds = (
1352
+ inputs_embeds.view(-1, inputs_embeds.size(-2), inputs_embeds.size(-1))
1353
+ if inputs_embeds is not None
1354
+ else None
1355
+ )
1356
+
1357
+ outputs = self.bert(
1358
+ input_ids,
1359
+ attention_mask=attention_mask,
1360
+ token_type_ids=token_type_ids,
1361
+ position_ids=position_ids,
1362
+ head_mask=head_mask,
1363
+ inputs_embeds=inputs_embeds,
1364
+ output_attentions=output_attentions,
1365
+ output_hidden_states=output_hidden_states,
1366
+ )
1367
+
1368
+ pooled_output = outputs[1]
1369
+
1370
+ pooled_output = self.dropout(pooled_output)
1371
+ logits = self.classifier(pooled_output)
1372
+ reshaped_logits = logits.view(-1, num_choices)
1373
+
1374
+ outputs = (reshaped_logits,) + outputs[2:] # add hidden states and attention if they are here
1375
+
1376
+ if labels is not None:
1377
+ loss_fct = CrossEntropyLoss()
1378
+ loss = loss_fct(reshaped_logits, labels)
1379
+ outputs = (loss,) + outputs
1380
+
1381
+ return outputs # (loss), reshaped_logits, (hidden_states), (attentions)
1382
+
1383
+
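# Minimal usage sketch for BertForMultipleChoice above (illustrative prompt and choices;
# inputs must have shape (batch_size, num_choices, sequence_length), hence the unsqueeze below).
from transformers import BertTokenizer, BertForMultipleChoice
import torch

tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
model = BertForMultipleChoice.from_pretrained('bert-base-uncased')
prompt = "The cat sat on the"
choices = ["mat.", "airplane."]
encoding = tokenizer([prompt, prompt], choices, return_tensors="pt", padding=True)
inputs = {k: v.unsqueeze(0) for k, v in encoding.items()}   # -> (1, num_choices, seq_len)
labels = torch.tensor([0])                                  # the first choice is the correct one
loss, classification_scores = model(**inputs, labels=labels)[:2]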
1384
+ @add_start_docstrings(
1385
+ """Bert Model with a token classification head on top (a linear layer on top of
1386
+ the hidden-states output) e.g. for Named-Entity-Recognition (NER) tasks. """,
1387
+ BERT_START_DOCSTRING,
1388
+ )
1389
+ class BertForTokenClassification(BertPreTrainedModel):
1390
+ def __init__(self, config):
1391
+ super().__init__(config)
1392
+ self.num_labels = config.num_labels
1393
+
1394
+ self.bert = BertModel(config)
1395
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
1396
+ self.classifier = nn.Linear(config.hidden_size, config.num_labels)
1397
+
1398
+ self.init_weights()
1399
+
1400
+ @add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING.format("(batch_size, sequence_length)"))
1401
+ @add_code_sample_docstrings(tokenizer_class=_TOKENIZER_FOR_DOC, checkpoint="bert-base-uncased")
1402
+ def forward(
1403
+ self,
1404
+ input_ids=None,
1405
+ attention_mask=None,
1406
+ token_type_ids=None,
1407
+ position_ids=None,
1408
+ head_mask=None,
1409
+ inputs_embeds=None,
1410
+ labels=None,
1411
+ output_attentions=None,
1412
+ output_hidden_states=None,
1413
+ ):
1414
+ r"""
1415
+ labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
1416
+ Labels for computing the token classification loss.
1417
+ Indices should be in ``[0, ..., config.num_labels - 1]``.
1418
+
1419
+ Returns:
1420
+ :obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.BertConfig`) and inputs:
1421
+ loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when ``labels`` is provided):
1422
+ Classification loss.
1423
+ scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, config.num_labels)`)
1424
+ Classification scores (before SoftMax).
1425
+ hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
1426
+ Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
1427
+ of shape :obj:`(batch_size, sequence_length, hidden_size)`.
1428
+
1429
+ Hidden-states of the model at the output of each layer plus the initial embedding outputs.
1430
+ attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
1431
+ Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
1432
+ :obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
1433
+
1434
+ Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
1435
+ heads.
1436
+ """
1437
+
1438
+ outputs = self.bert(
1439
+ input_ids,
1440
+ attention_mask=attention_mask,
1441
+ token_type_ids=token_type_ids,
1442
+ position_ids=position_ids,
1443
+ head_mask=head_mask,
1444
+ inputs_embeds=inputs_embeds,
1445
+ output_attentions=output_attentions,
1446
+ output_hidden_states=output_hidden_states,
1447
+ )
1448
+
1449
+ sequence_output = outputs[0]
1450
+
1451
+ sequence_output = self.dropout(sequence_output)
1452
+ logits = self.classifier(sequence_output)
1453
+
1454
+ outputs = (logits,) + outputs[2:] # add hidden states and attention if they are here
1455
+ if labels is not None:
1456
+ loss_fct = CrossEntropyLoss()
1457
+ # Only keep active parts of the loss
1458
+ if attention_mask is not None:
1459
+ active_loss = attention_mask.view(-1) == 1
1460
+ active_logits = logits.view(-1, self.num_labels)
1461
+ active_labels = torch.where(
1462
+ active_loss, labels.view(-1), torch.tensor(loss_fct.ignore_index).type_as(labels)
1463
+ )
1464
+ loss = loss_fct(active_logits, active_labels)
1465
+ else:
1466
+ loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
1467
+ outputs = (loss,) + outputs
1468
+
1469
+ return outputs # (loss), scores, (hidden_states), (attentions)
1470
+
1471
+
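# Minimal usage sketch for BertForTokenClassification above (illustrative sentence, dummy
# labels, and an assumed num_labels of 5; one class index per token, and positions labelled
# -100 are ignored by CrossEntropyLoss).
from transformers import BertTokenizer, BertForTokenClassification
import torch

tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
model = BertForTokenClassification.from_pretrained('bert-base-uncased', num_labels=5)
inputs = tokenizer("HuggingFace is based in NYC", return_tensors="pt")
labels = torch.zeros_like(inputs["input_ids"])   # dummy all-"O" labels, one per token
loss, scores = model(**inputs, labels=labels)[:2]
predictions = scores.argmax(-1)                  # (batch_size, sequence_length) predicted label ids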
1472
+ @add_start_docstrings(
1473
+ """Bert Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear
1474
+ layers on top of the hidden-states output to compute `span start logits` and `span end logits`). """,
1475
+ BERT_START_DOCSTRING,
1476
+ )
1477
+ class BertForQuestionAnswering(BertPreTrainedModel):
1478
+ def __init__(self, config):
1479
+ super().__init__(config)
1480
+ self.num_labels = config.num_labels
1481
+
1482
+ self.bert = BertModel(config)
1483
+ self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)
1484
+
1485
+ self.init_weights()
1486
+
1487
+ @add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING.format("(batch_size, sequence_length)"))
1488
+ @add_code_sample_docstrings(tokenizer_class=_TOKENIZER_FOR_DOC, checkpoint="bert-base-uncased")
1489
+ def forward(
1490
+ self,
1491
+ input_ids=None,
1492
+ attention_mask=None,
1493
+ token_type_ids=None,
1494
+ position_ids=None,
1495
+ head_mask=None,
1496
+ inputs_embeds=None,
1497
+ start_positions=None,
1498
+ end_positions=None,
1499
+ output_attentions=None,
1500
+ output_hidden_states=None,
1501
+ ):
1502
+ r"""
1503
+ start_positions (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`):
1504
+ Labels for position (index) of the start of the labelled span for computing the token classification loss.
1505
+ Positions are clamped to the length of the sequence (`sequence_length`).
1506
+ Positions outside of the sequence are not taken into account for computing the loss.
1507
+ end_positions (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`):
1508
+ Labels for position (index) of the end of the labelled span for computing the token classification loss.
1509
+ Positions are clamped to the length of the sequence (`sequence_length`).
1510
+ Positions outside of the sequence are not taken into account for computing the loss.
1511
+
1512
+ Returns:
1513
+ :obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.BertConfig`) and inputs:
1514
+ loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when :obj:`start_positions` and :obj:`end_positions` are provided):
1515
+ Total span extraction loss is the sum of a Cross-Entropy for the start and end positions.
1516
+ start_scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length,)`):
1517
+ Span-start scores (before SoftMax).
1518
+ end_scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length,)`):
1519
+ Span-end scores (before SoftMax).
1520
+ hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
1521
+ Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
1522
+ of shape :obj:`(batch_size, sequence_length, hidden_size)`.
1523
+
1524
+ Hidden-states of the model at the output of each layer plus the initial embedding outputs.
1525
+ attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
1526
+ Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
1527
+ :obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
1528
+
1529
+ Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
1530
+ heads.
1531
+ """
1532
+
1533
+ outputs = self.bert(
1534
+ input_ids,
1535
+ attention_mask=attention_mask,
1536
+ token_type_ids=token_type_ids,
1537
+ position_ids=position_ids,
1538
+ head_mask=head_mask,
1539
+ inputs_embeds=inputs_embeds,
1540
+ output_attentions=output_attentions,
1541
+ output_hidden_states=output_hidden_states,
1542
+ )
1543
+
1544
+ sequence_output = outputs[0]
1545
+
1546
+ logits = self.qa_outputs(sequence_output)
1547
+ start_logits, end_logits = logits.split(1, dim=-1)
1548
+ start_logits = start_logits.squeeze(-1)
1549
+ end_logits = end_logits.squeeze(-1)
1550
+
1551
+ outputs = (start_logits, end_logits,) + outputs[2:]
1552
+ if start_positions is not None and end_positions is not None:
1553
+ # If we are on multi-GPU, splitting adds a dimension, so squeeze it away
1554
+ if len(start_positions.size()) > 1:
1555
+ start_positions = start_positions.squeeze(-1)
1556
+ if len(end_positions.size()) > 1:
1557
+ end_positions = end_positions.squeeze(-1)
1558
+ # sometimes the start/end positions are outside our model inputs, we ignore these terms
1559
+ ignored_index = start_logits.size(1)
1560
+ start_positions.clamp_(0, ignored_index)
1561
+ end_positions.clamp_(0, ignored_index)
1562
+
1563
+ loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
1564
+ start_loss = loss_fct(start_logits, start_positions)
1565
+ end_loss = loss_fct(end_logits, end_positions)
1566
+ total_loss = (start_loss + end_loss) / 2
1567
+ outputs = (total_loss,) + outputs
1568
+
1569
+ return outputs # (loss), start_logits, end_logits, (hidden_states), (attentions)
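The span-extraction head above carries no inline doctest (its example is injected by ``add_code_sample_docstrings``), so a minimal usage sketch follows. The question and context strings are illustrative, and with a plain ``bert-base-uncased`` checkpoint the freshly initialized ``qa_outputs`` layer will produce arbitrary spans until fine-tuned.

from transformers import BertTokenizer, BertForQuestionAnswering

tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
model = BertForQuestionAnswering.from_pretrained('bert-base-uncased')
question, context = "Who wrote the report?", "The report was written by Alice in 2020."
inputs = tokenizer(question, context, return_tensors="pt")
start_logits, end_logits = model(**inputs)[:2]
start, end = int(start_logits.argmax()), int(end_logits.argmax())
answer = tokenizer.decode(inputs["input_ids"][0, start: end + 1])
# Passing start_positions/end_positions instead returns (total_loss, start_logits, end_logits, ...)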
elia/bert/modeling_utils.py ADDED
@@ -0,0 +1,1268 @@
1
+ # coding=utf-8
2
+ # Copyright 2018 The Google AI Language Team Authors, Facebook AI Research authors and The HuggingFace Inc. team.
3
+ # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+
17
+ import inspect
18
+ import logging
19
+ import os
20
+ from typing import Callable, Dict, List, Optional, Tuple
21
+
22
+ import torch
23
+ from torch import Tensor, device, dtype, nn
24
+ from torch.nn import CrossEntropyLoss
25
+ from torch.nn import functional as F
26
+
27
+ from .activations import get_activation
28
+ from .configuration_utils import PretrainedConfig
29
+ from .file_utils import (
30
+ DUMMY_INPUTS,
31
+ TF2_WEIGHTS_NAME,
32
+ TF_WEIGHTS_NAME,
33
+ WEIGHTS_NAME,
34
+ cached_path,
35
+ hf_bucket_url,
36
+ is_remote_url,
37
+ )
38
+ from .generation_utils import GenerationMixin
39
+
40
+
41
+ logger = logging.getLogger(__name__)
42
+
43
+
44
+ try:
45
+ from torch.nn import Identity
46
+ except ImportError:
47
+ # Older PyTorch compatibility
48
+ class Identity(nn.Module):
49
+ r"""A placeholder identity operator that is argument-insensitive.
50
+ """
51
+
52
+ def __init__(self, *args, **kwargs):
53
+ super().__init__()
54
+
55
+ def forward(self, input):
56
+ return input
57
+
58
+
59
+ def find_pruneable_heads_and_indices(
60
+ heads: List, n_heads: int, head_size: int, already_pruned_heads: set
61
+ ) -> Tuple[set, "torch.LongTensor"]:
62
+ mask = torch.ones(n_heads, head_size)
63
+ heads = set(heads) - already_pruned_heads # Convert to set and remove already pruned heads
64
+ for head in heads:
65
+ # Compute how many pruned heads are before the head and move the index accordingly
66
+ head = head - sum(1 if h < head else 0 for h in already_pruned_heads)
67
+ mask[head] = 0
68
+ mask = mask.view(-1).contiguous().eq(1)
69
+ index: torch.LongTensor = torch.arange(len(mask))[mask].long()
70
+ return heads, index
71
+
72
+
73
+ class ModuleUtilsMixin:
74
+ """
75
+ A few utilities for torch.nn.Modules, to be used as a mixin.
76
+ """
77
+
78
+ def num_parameters(self, only_trainable: bool = False) -> int:
79
+ """
80
+ Get number of (optionally, trainable) parameters in the module.
81
+ """
82
+ params = filter(lambda x: x.requires_grad, self.parameters()) if only_trainable else self.parameters()
83
+ return sum(p.numel() for p in params)
84
+
85
+ @staticmethod
86
+ def _hook_rss_memory_pre_forward(module, *args, **kwargs):
87
+ try:
88
+ import psutil
89
+ except (ImportError):
90
+ raise ImportError("You need to install psutil (pip install psutil) to use memory tracing.")
91
+
92
+ process = psutil.Process(os.getpid())
93
+ mem = process.memory_info()
94
+ module.mem_rss_pre_forward = mem.rss
95
+ return None
96
+
97
+ @staticmethod
98
+ def _hook_rss_memory_post_forward(module, *args, **kwargs):
99
+ try:
100
+ import psutil
101
+ except (ImportError):
102
+ raise ImportError("You need to install psutil (pip install psutil) to use memory tracing.")
103
+
104
+ process = psutil.Process(os.getpid())
105
+ mem = process.memory_info()
106
+ module.mem_rss_post_forward = mem.rss
107
+ mem_rss_diff = module.mem_rss_post_forward - module.mem_rss_pre_forward
108
+ module.mem_rss_diff = mem_rss_diff + (module.mem_rss_diff if hasattr(module, "mem_rss_diff") else 0)
109
+ return None
110
+
111
+ def add_memory_hooks(self):
112
+ """ Add a memory hook before and after each sub-module forward pass to record increase in memory consumption.
113
+ Increase in memory consumption is stored in a `mem_rss_diff` attribute for each module and can be reset to zero with `model.reset_memory_hooks_state()`
114
+ """
115
+ for module in self.modules():
116
+ module.register_forward_pre_hook(self._hook_rss_memory_pre_forward)
117
+ module.register_forward_hook(self._hook_rss_memory_post_forward)
118
+ self.reset_memory_hooks_state()
119
+
120
+ def reset_memory_hooks_state(self):
121
+ for module in self.modules():
122
+ module.mem_rss_diff = 0
123
+ module.mem_rss_post_forward = 0
124
+ module.mem_rss_pre_forward = 0
125
+
126
+ @property
127
+ def device(self) -> device:
128
+ """
129
+ Get torch.device from module, assuming that the whole module has one device.
130
+ """
131
+ try:
132
+ return next(self.parameters()).device
133
+ except StopIteration:
134
+ # For nn.DataParallel compatibility in PyTorch 1.5
135
+
136
+ def find_tensor_attributes(module: nn.Module) -> List[Tuple[str, Tensor]]:
137
+ tuples = [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)]
138
+ return tuples
139
+
140
+ gen = self._named_members(get_members_fn=find_tensor_attributes)
141
+ first_tuple = next(gen)
142
+ return first_tuple[1].device
143
+
144
+ @property
145
+ def dtype(self) -> dtype:
146
+ """
147
+ Get torch.dtype from module, assuming that the whole module has one dtype.
148
+ """
149
+ try:
150
+ return next(self.parameters()).dtype
151
+ except StopIteration:
152
+ # For nn.DataParallel compatibility in PyTorch 1.5
153
+
154
+ def find_tensor_attributes(module: nn.Module) -> List[Tuple[str, Tensor]]:
155
+ tuples = [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)]
156
+ return tuples
157
+
158
+ gen = self._named_members(get_members_fn=find_tensor_attributes)
159
+ first_tuple = next(gen)
160
+ return first_tuple[1].dtype
161
+
162
+ def invert_attention_mask(self, encoder_attention_mask: Tensor) -> Tensor:
163
+ """type: torch.Tensor -> torch.Tensor"""
164
+ if encoder_attention_mask.dim() == 3:
165
+ encoder_extended_attention_mask = encoder_attention_mask[:, None, :, :]
166
+ if encoder_attention_mask.dim() == 2:
167
+ encoder_extended_attention_mask = encoder_attention_mask[:, None, None, :]
168
+ # T5 has a mask that can compare sequence ids, we can simulate this here with this transposition
169
+ # Cf. https://github.com/tensorflow/mesh/blob/8d2465e9bc93129b913b5ccc6a59aa97abd96ec6/mesh_tensorflow
170
+ # /transformer/transformer_layers.py#L270
171
+ # encoder_extended_attention_mask = (encoder_extended_attention_mask ==
172
+ # encoder_extended_attention_mask.transpose(-1, -2))
173
+ encoder_extended_attention_mask = encoder_extended_attention_mask.to(dtype=self.dtype) # fp16 compatibility
174
+
175
+ if self.dtype == torch.float16:
176
+ encoder_extended_attention_mask = (1.0 - encoder_extended_attention_mask) * -1e4
177
+ elif self.dtype == torch.float32:
178
+ encoder_extended_attention_mask = (1.0 - encoder_extended_attention_mask) * -1e9
179
+ else:
180
+ raise ValueError(
181
+ "{} not recognized. `dtype` should be set to either `torch.float32` or `torch.float16`".format(
182
+ self.dtype
183
+ )
184
+ )
185
+
186
+ return encoder_extended_attention_mask
187
+
188
+ def get_extended_attention_mask(self, attention_mask: Tensor, input_shape: Tuple, device: device) -> Tensor:
189
+ """Makes broadcastable attention mask and causal mask so that future and maked tokens are ignored.
190
+
191
+ Arguments:
192
+ attention_mask: torch.Tensor with 1 indicating tokens to ATTEND to
193
+ input_shape: tuple, shape of input_ids
194
+ device: torch.device, usually self.device
195
+
196
+ Returns:
197
+ torch.Tensor with dtype of attention_mask.dtype
198
+ """
199
+ # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
200
+ # ourselves in which case we just need to make it broadcastable to all heads.
201
+ if attention_mask.dim() == 3:
202
+ extended_attention_mask = attention_mask[:, None, :, :]
203
+ elif attention_mask.dim() == 2:
204
+ # Provided a padding mask of dimensions [batch_size, seq_length]
205
+ # - if the model is a decoder, apply a causal mask in addition to the padding mask
206
+ # - if the model is an encoder, make the mask broadcastable to [batch_size, num_heads, seq_length, seq_length]
207
+ if self.config.is_decoder:
208
+ batch_size, seq_length = input_shape
209
+ seq_ids = torch.arange(seq_length, device=device)
210
+ causal_mask = seq_ids[None, None, :].repeat(batch_size, seq_length, 1) <= seq_ids[None, :, None]
211
+ # causal and attention masks must have same type with pytorch version < 1.3
212
+ causal_mask = causal_mask.to(attention_mask.dtype)
213
+ extended_attention_mask = causal_mask[:, None, :, :] * attention_mask[:, None, None, :]
214
+ else:
215
+ extended_attention_mask = attention_mask[:, None, None, :]
216
+ else:
217
+ raise ValueError(
218
+ "Wrong shape for input_ids (shape {}) or attention_mask (shape {})".format(
219
+ input_shape, attention_mask.shape
220
+ )
221
+ )
222
+
223
+ # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
224
+ # masked positions, this operation will create a tensor which is 0.0 for
225
+ # positions we want to attend and -10000.0 for masked positions.
226
+ # Since we are adding it to the raw scores before the softmax, this is
227
+ # effectively the same as removing these entirely.
228
+ extended_attention_mask = extended_attention_mask.to(dtype=self.dtype) # fp16 compatibility
229
+ extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
230
+ return extended_attention_mask
231
+
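# Shape sketch for get_extended_attention_mask above (illustrative tensor; encoder branch only):
# a (batch, seq_len) padding mask becomes a broadcastable (batch, 1, 1, seq_len) additive mask
# holding 0.0 at attended positions and -10000.0 at padded ones; decoders additionally fold in
# the causal mask, giving (batch, 1, seq_len, seq_len).
import torch

attention_mask = torch.tensor([[1, 1, 1, 0]])                            # last position is padding
extended = (1.0 - attention_mask[:, None, None, :].float()) * -10000.0   # shape (1, 1, 1, 4)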
232
+ def get_head_mask(self, head_mask: Tensor, num_hidden_layers: int, is_attention_chunked: bool = False) -> Tensor:
233
+ """
234
+ # Prepare head mask if needed
235
+ # 1.0 in head_mask indicate we keep the head
236
+ attention_probs has shape bsz x n_heads x N x N
237
+ Arguments:
238
+ head_mask: torch.Tensor or None: has shape [num_heads] or [num_hidden_layers x num_heads]
239
+ num_hidden_layers: int
240
+ Returns:
241
+ Tensor of shape shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
242
+ or list with [None] for each layer
243
+ """
244
+ if head_mask is not None:
245
+ head_mask = self._convert_head_mask_to_5d(head_mask, num_hidden_layers)
246
+ if is_attention_chunked is True:
247
+ head_mask = head_mask.unsqueeze(-1)
248
+ else:
249
+ head_mask = [None] * num_hidden_layers
250
+
251
+ return head_mask
252
+
253
+ def _convert_head_mask_to_5d(self, head_mask, num_hidden_layers):
254
+ """-> [num_hidden_layers x batch x num_heads x seq_length x seq_length]"""
255
+ if head_mask.dim() == 1:
256
+ head_mask = head_mask.unsqueeze(0).unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
257
+ head_mask = head_mask.expand(num_hidden_layers, -1, -1, -1, -1)
258
+ elif head_mask.dim() == 2:
259
+ head_mask = head_mask.unsqueeze(1).unsqueeze(-1).unsqueeze(-1) # We can specify head_mask for each layer
260
+ assert head_mask.dim() == 5, f"head_mask.dim != 5, instead {head_mask.dim()}"
261
+ head_mask = head_mask.to(dtype=self.dtype) # switch to float if needed + fp16 compatibility
262
+ return head_mask
263
+
264
+
265
+ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin):
266
+ r""" Base class for all models.
267
+
268
+ :class:`~transformers.PreTrainedModel` takes care of storing the configuration of the models and handles methods for loading/downloading/saving models
269
+ as well as a few methods common to all models to (i) resize the input embeddings and (ii) prune heads in the self-attention heads.
270
+
271
+ Class attributes (overridden by derived classes):
272
+ - ``config_class``: a class derived from :class:`~transformers.PretrainedConfig` to use as configuration class for this model architecture.
273
+ - ``load_tf_weights``: a python ``method`` for loading a TensorFlow checkpoint in a PyTorch model, taking as arguments:
274
+
275
+ - ``model``: an instance of the relevant subclass of :class:`~transformers.PreTrainedModel`,
276
+ - ``config``: an instance of the relevant subclass of :class:`~transformers.PretrainedConfig`,
277
+ - ``path``: a path (string) to the TensorFlow checkpoint.
278
+
279
+ - ``base_model_prefix``: a string indicating the attribute associated to the base model in derived classes of the same architecture adding modules on top of the base model.
280
+ """
281
+ config_class = None
282
+ base_model_prefix = ""
283
+
284
+ @property
285
+ def dummy_inputs(self):
286
+ """ Dummy inputs to do a forward pass in the network.
287
+
288
+ Returns:
289
+ torch.Tensor with dummy inputs
290
+ """
291
+ return {"input_ids": torch.tensor(DUMMY_INPUTS)}
292
+
293
+ def __init__(self, config, *inputs, **kwargs):
294
+ super().__init__()
295
+ if not isinstance(config, PretrainedConfig):
296
+ raise ValueError(
297
+ "Parameter config in `{}(config)` should be an instance of class `PretrainedConfig`. "
298
+ "To create a model from a pretrained model use "
299
+ "`model = {}.from_pretrained(PRETRAINED_MODEL_NAME)`".format(
300
+ self.__class__.__name__, self.__class__.__name__
301
+ )
302
+ )
303
+ # Save config in model
304
+ self.config = config
305
+
306
+ @property
307
+ def base_model(self):
308
+ return getattr(self, self.base_model_prefix, self)
309
+
310
+ def get_input_embeddings(self):
311
+ """
312
+ Returns the model's input embeddings.
313
+
314
+ Returns:
315
+ :obj:`nn.Module`:
316
+ A torch module mapping vocabulary to hidden states.
317
+ """
318
+ base_model = getattr(self, self.base_model_prefix, self)
319
+ if base_model is not self:
320
+ return base_model.get_input_embeddings()
321
+ else:
322
+ raise NotImplementedError
323
+
324
+ def set_input_embeddings(self, value: nn.Module):
325
+ """
326
+ Set model's input embeddings
327
+
328
+ Args:
329
+ value (:obj:`nn.Module`):
330
+ A module mapping vocabulary to hidden states.
331
+ """
332
+ base_model = getattr(self, self.base_model_prefix, self)
333
+ if base_model is not self:
334
+ base_model.set_input_embeddings(value)
335
+ else:
336
+ raise NotImplementedError
337
+
338
+ def get_output_embeddings(self):
339
+ """
340
+ Returns the model's output embeddings.
341
+
342
+ Returns:
343
+ :obj:`nn.Module`:
344
+ A torch module mapping hidden states to vocabulary.
345
+ """
346
+ return None # Overwrite for models with output embeddings
347
+
348
+ def tie_weights(self):
349
+ """
350
+ Tie the weights between the input embeddings and the output embeddings.
351
+ If the `torchscript` flag is set in the configuration, can't handle parameter sharing so we are cloning
352
+ the weights instead.
353
+ """
354
+ output_embeddings = self.get_output_embeddings()
355
+ if output_embeddings is not None:
356
+ self._tie_or_clone_weights(output_embeddings, self.get_input_embeddings())
357
+
358
+ def _tie_or_clone_weights(self, output_embeddings, input_embeddings):
359
+ """ Tie or clone module weights depending of whether we are using TorchScript or not
360
+ """
361
+ if self.config.torchscript:
362
+ output_embeddings.weight = nn.Parameter(input_embeddings.weight.clone())
363
+ else:
364
+ output_embeddings.weight = input_embeddings.weight
365
+
366
+ if getattr(output_embeddings, "bias", None) is not None:
367
+ output_embeddings.bias.data = torch.nn.functional.pad(
368
+ output_embeddings.bias.data,
369
+ (0, output_embeddings.weight.shape[0] - output_embeddings.bias.shape[0],),
370
+ "constant",
371
+ 0,
372
+ )
373
+ if hasattr(output_embeddings, "out_features") and hasattr(input_embeddings, "num_embeddings"):
374
+ output_embeddings.out_features = input_embeddings.num_embeddings
375
+
376
+ def resize_token_embeddings(self, new_num_tokens: Optional[int] = None):
377
+ """ Resize input token embeddings matrix of the model if new_num_tokens != config.vocab_size.
378
+ Take care of tying weights embeddings afterwards if the model class has a `tie_weights()` method.
379
+
380
+ Arguments:
381
+
382
+ new_num_tokens: (`optional`) int:
383
+ New number of tokens in the embedding matrix. Increasing the size will add newly initialized vectors at the end. Reducing the size will remove vectors from the end.
384
+ If not provided or None: does nothing and just returns a pointer to the input tokens ``torch.nn.Embedding`` Module of the model.
385
+
386
+ Return: ``torch.nn.Embedding``
387
+ Pointer to the input tokens Embeddings Module of the model
388
+ """
389
+ base_model = getattr(self, self.base_model_prefix, self) # get the base model if needed
390
+ model_embeds = base_model._resize_token_embeddings(new_num_tokens)
391
+ if new_num_tokens is None:
392
+ return model_embeds
393
+
394
+ # Update base model and current model config
395
+ self.config.vocab_size = new_num_tokens
396
+ base_model.vocab_size = new_num_tokens
397
+
398
+ # Tie weights again if needed
399
+ self.tie_weights()
400
+
401
+ return model_embeds
402
+
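# Minimal usage sketch for resize_token_embeddings above (illustrative added token; the new
# rows are freshly initialized and tie_weights() is re-applied so a tied LM head stays in sync).
from transformers import BertTokenizer, BertForMaskedLM

tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
model = BertForMaskedLM.from_pretrained('bert-base-uncased')
tokenizer.add_tokens(["<new_token>"])                         # extend the tokenizer vocabulary
embeddings = model.resize_token_embeddings(len(tokenizer))    # returns the resized nn.Embedding
assert embeddings.num_embeddings == len(tokenizer)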
403
+ def _resize_token_embeddings(self, new_num_tokens):
404
+ old_embeddings = self.get_input_embeddings()
405
+ new_embeddings = self._get_resized_embeddings(old_embeddings, new_num_tokens)
406
+ self.set_input_embeddings(new_embeddings)
407
+ return self.get_input_embeddings()
408
+
409
+ def _get_resized_embeddings(
410
+ self, old_embeddings: torch.nn.Embedding, new_num_tokens: Optional[int] = None
411
+ ) -> torch.nn.Embedding:
412
+ """ Build a resized Embedding Module from a provided token Embedding Module.
413
+ Increasing the size will add newly initialized vectors at the end
414
+ Reducing the size will remove vectors from the end
415
+
416
+ Args:
417
+ old_embeddings: ``torch.nn.Embedding``
418
+ Old embeddings to be resized.
419
+ new_num_tokens: (`optional`) int
420
+ New number of tokens in the embedding matrix.
421
+ Increasing the size will add newly initialized vectors at the end
422
+ Reducing the size will remove vectors from the end
423
+ If not provided or None: return the provided token Embedding Module.
424
+ Return: ``torch.nn.Embedding``
425
+ Pointer to the resized Embedding Module or the old Embedding Module if new_num_tokens is None
426
+ """
427
+ if new_num_tokens is None:
428
+ return old_embeddings
429
+
430
+ old_num_tokens, old_embedding_dim = old_embeddings.weight.size()
431
+ if old_num_tokens == new_num_tokens:
432
+ return old_embeddings
433
+
434
+ # Build new embeddings
435
+ new_embeddings = nn.Embedding(new_num_tokens, old_embedding_dim)
436
+ new_embeddings.to(old_embeddings.weight.device)
437
+
438
+ # initialize all new embeddings (in particular added tokens)
439
+ self._init_weights(new_embeddings)
440
+
441
+ # Copy token embeddings from the previous weights
442
+ num_tokens_to_copy = min(old_num_tokens, new_num_tokens)
443
+ new_embeddings.weight.data[:num_tokens_to_copy, :] = old_embeddings.weight.data[:num_tokens_to_copy, :]
444
+
445
+ return new_embeddings
446
+
447
+ def init_weights(self):
448
+ """ Initialize and prunes weights if needed. """
449
+ # Initialize weights
450
+ self.apply(self._init_weights)
451
+
452
+ # Prune heads if needed
453
+ if self.config.pruned_heads:
454
+ self.prune_heads(self.config.pruned_heads)
455
+
456
+ # Tie weights if needed
457
+ self.tie_weights()
458
+
459
+ def prune_heads(self, heads_to_prune: Dict):
460
+ """ Prunes heads of the base model.
461
+
462
+ Arguments:
463
+
464
+ heads_to_prune: dict with keys being selected layer indices (`int`) and associated values being the list of heads to prune in said layer (list of `int`).
465
+ E.g. {1: [0, 2], 2: [2, 3]} will prune heads 0 and 2 on layer 1 and heads 2 and 3 on layer 2.
466
+ """
467
+ # save new sets of pruned heads as union of previously stored pruned heads and newly pruned heads
468
+ for layer, heads in heads_to_prune.items():
469
+ union_heads = set(self.config.pruned_heads.get(layer, [])) | set(heads)
470
+ self.config.pruned_heads[layer] = list(union_heads) # Unfortunately we have to store it as list for JSON
471
+
472
+ self.base_model._prune_heads(heads_to_prune)
473
+
474
+ def save_pretrained(self, save_directory):
475
+ """ Save a model and its configuration file to a directory, so that it
476
+ can be re-loaded using the :func:`~transformers.PreTrainedModel.from_pretrained` class method.
477
+
478
+ Arguments:
479
+ save_directory: directory to which to save.
480
+ """
481
+ if os.path.isfile(save_directory):
482
+ logger.error("Provided path ({}) should be a directory, not a file".format(save_directory))
483
+ return
484
+ os.makedirs(save_directory, exist_ok=True)
485
+
486
+ # Only save the model itself if we are using distributed training
487
+ model_to_save = self.module if hasattr(self, "module") else self
488
+
489
+ # Attach architecture to the config
490
+ model_to_save.config.architectures = [model_to_save.__class__.__name__]
491
+
492
+ # If we save using the predefined names, we can load using `from_pretrained`
493
+ output_model_file = os.path.join(save_directory, WEIGHTS_NAME)
494
+
495
+ if getattr(self.config, "xla_device", False):
496
+ import torch_xla.core.xla_model as xm
497
+
498
+ if xm.is_master_ordinal():
499
+ # Save configuration file
500
+ model_to_save.config.save_pretrained(save_directory)
501
+ # xm.save takes care of saving only from master
502
+ xm.save(model_to_save.state_dict(), output_model_file)
503
+ else:
504
+ model_to_save.config.save_pretrained(save_directory)
505
+ torch.save(model_to_save.state_dict(), output_model_file)
506
+
507
+ logger.info("Model weights saved in {}".format(output_model_file))
508
+
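A minimal save/reload round trip using the method above (the directory path is illustrative):

    model.save_pretrained('./my_model_directory/')            # writes pytorch_model.bin and config.json
    reloaded = BertModel.from_pretrained('./my_model_directory/')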
509
+ @classmethod
510
+ def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
511
+ r"""Instantiate a pretrained pytorch model from a pre-trained model configuration.
512
+
513
+ The model is set in evaluation mode by default using ``model.eval()`` (Dropout modules are deactivated)
514
+ To train the model, you should first set it back in training mode with ``model.train()``
515
+
516
+ The warning ``Weights from XXX not initialized from pretrained model`` means that the weights of XXX do not come pre-trained with the rest of the model.
517
+ It is up to you to train those weights with a downstream fine-tuning task.
518
+
519
+ The warning ``Weights from XXX not used in YYY`` means that the layer XXX is not used by YYY, therefore those weights are discarded.
520
+
521
+ Parameters:
522
+ pretrained_model_name_or_path: either:
523
+ - a string with the `shortcut name` of a pre-trained model to load from cache or download, e.g.: ``bert-base-uncased``.
524
+ - a string with the `identifier name` of a pre-trained model that was user-uploaded to our S3, e.g.: ``dbmdz/bert-base-german-cased``.
525
+ - a path to a `directory` containing model weights saved using :func:`~transformers.PreTrainedModel.save_pretrained`, e.g.: ``./my_model_directory/``.
526
+ - a path or url to a `tensorflow index checkpoint file` (e.g. `./tf_model/model.ckpt.index`). In this case, ``from_tf`` should be set to True and a configuration object should be provided as ``config`` argument. This loading path is slower than converting the TensorFlow checkpoint in a PyTorch model using the provided conversion scripts and loading the PyTorch model afterwards.
527
+ - None if you are both providing the configuration and state dictionary (resp. with keyword arguments ``config`` and ``state_dict``)
528
+
529
+ model_args: (`optional`) Sequence of positional arguments:
530
+ All remaining positional arguments will be passed to the underlying model's ``__init__`` method
531
+
532
+ config: (`optional`) one of:
533
+ - an instance of a class derived from :class:`~transformers.PretrainedConfig`, or
534
+ - a string valid as input to :func:`~transformers.PretrainedConfig.from_pretrained()`
535
+
536
+ Configuration for the model to use instead of an automatically loaded configuration. Configuration can be automatically loaded when:
537
+ - the model is a model provided by the library (loaded with the ``shortcut-name`` string of a pretrained model), or
538
+ - the model was saved using :func:`~transformers.PreTrainedModel.save_pretrained` and is reloaded by supplying the save directory.
539
+ - the model is loaded by supplying a local directory as ``pretrained_model_name_or_path`` and a configuration JSON file named `config.json` is found in the directory.
540
+
541
+ state_dict: (`optional`) dict:
542
+ an optional state dictionary for the model to use instead of a state dictionary loaded from the saved weights file.
543
+ This option can be used if you want to create a model from a pretrained configuration but load your own weights.
544
+ In this case though, you should check if using :func:`~transformers.PreTrainedModel.save_pretrained` and :func:`~transformers.PreTrainedModel.from_pretrained` is not a simpler option.
545
+
546
+ cache_dir: (`optional`) string:
547
+ Path to a directory in which a downloaded pre-trained model
548
+ configuration should be cached if the standard cache should not be used.
549
+
550
+ force_download: (`optional`) boolean, default False:
551
+ Force (re-)downloading the model weights and configuration files, overriding the cached versions if they exist.
552
+
553
+ resume_download: (`optional`) boolean, default False:
554
+ Do not delete an incompletely received file. Attempt to resume the download if such a file exists.
555
+
556
+ proxies: (`optional`) dict, default None:
557
+ A dictionary of proxy servers to use by protocol or endpoint, e.g.: {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}.
558
+ The proxies are used on each request.
559
+
560
+ output_loading_info: (`optional`) boolean:
561
+ Set to ``True`` to also return a dictionary containing missing keys, unexpected keys and error messages.
562
+
563
+ kwargs: (`optional`) Remaining dictionary of keyword arguments:
564
+ Can be used to update the configuration object (after it has been loaded) and to initialize the model (e.g. ``output_attention=True``). Behaves differently depending on whether a `config` is provided or automatically loaded:
565
+
566
+ - If a configuration is provided with ``config``, ``**kwargs`` will be directly passed to the underlying model's ``__init__`` method (we assume all relevant updates to the configuration have already been done)
567
+ - If a configuration is not provided, ``kwargs`` will be first passed to the configuration class initialization function (:func:`~transformers.PretrainedConfig.from_pretrained`). Each key of ``kwargs`` that corresponds to a configuration attribute will be used to override said attribute with the supplied ``kwargs`` value. Remaining keys that do not correspond to any configuration attribute will be passed to the underlying model's ``__init__`` function.
568
+
569
+ Examples::
570
+
571
+ # For example purposes. Not runnable.
572
+ model = BertModel.from_pretrained('bert-base-uncased') # Download model and configuration from S3 and cache.
573
+ model = BertModel.from_pretrained('./test/saved_model/') # E.g. model was saved using `save_pretrained('./test/saved_model/')`
574
+ model = BertModel.from_pretrained('bert-base-uncased', output_attention=True) # Update configuration during loading
575
+ assert model.config.output_attention == True
576
+ # Loading from a TF checkpoint file instead of a PyTorch model (slower)
577
+ config = BertConfig.from_json_file('./tf_model/my_tf_model_config.json')
578
+ model = BertModel.from_pretrained('./tf_model/my_tf_checkpoint.ckpt.index', from_tf=True, config=config)
579
+
580
+ """
581
+ config = kwargs.pop("config", None)
582
+ state_dict = kwargs.pop("state_dict", None)
583
+ cache_dir = kwargs.pop("cache_dir", None)
584
+ from_tf = kwargs.pop("from_tf", False)
585
+ force_download = kwargs.pop("force_download", False)
586
+ resume_download = kwargs.pop("resume_download", False)
587
+ proxies = kwargs.pop("proxies", None)
588
+ output_loading_info = kwargs.pop("output_loading_info", False)
589
+ local_files_only = kwargs.pop("local_files_only", False)
590
+ use_cdn = kwargs.pop("use_cdn", True)
591
+
592
+ # Load config if we don't provide a configuration
593
+ if not isinstance(config, PretrainedConfig):
594
+ config_path = config if config is not None else pretrained_model_name_or_path
595
+ config, model_kwargs = cls.config_class.from_pretrained(
596
+ config_path,
597
+ *model_args,
598
+ cache_dir=cache_dir,
599
+ return_unused_kwargs=True,
600
+ force_download=force_download,
601
+ resume_download=resume_download,
602
+ proxies=proxies,
603
+ local_files_only=local_files_only,
604
+ **kwargs,
605
+ )
606
+ else:
607
+ model_kwargs = kwargs
608
+
609
+ # Load model
610
+ if pretrained_model_name_or_path is not None:
611
+ if os.path.isdir(pretrained_model_name_or_path):
612
+ if from_tf and os.path.isfile(os.path.join(pretrained_model_name_or_path, TF_WEIGHTS_NAME + ".index")):
613
+ # Load from a TF 1.0 checkpoint
614
+ archive_file = os.path.join(pretrained_model_name_or_path, TF_WEIGHTS_NAME + ".index")
615
+ elif from_tf and os.path.isfile(os.path.join(pretrained_model_name_or_path, TF2_WEIGHTS_NAME)):
616
+ # Load from a TF 2.0 checkpoint
617
+ archive_file = os.path.join(pretrained_model_name_or_path, TF2_WEIGHTS_NAME)
618
+ elif os.path.isfile(os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)):
619
+ # Load from a PyTorch checkpoint
620
+ archive_file = os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)
621
+ else:
622
+ raise EnvironmentError(
623
+ "Error no file named {} found in directory {} or `from_tf` set to False".format(
624
+ [WEIGHTS_NAME, TF2_WEIGHTS_NAME, TF_WEIGHTS_NAME + ".index"],
625
+ pretrained_model_name_or_path,
626
+ )
627
+ )
628
+ elif os.path.isfile(pretrained_model_name_or_path) or is_remote_url(pretrained_model_name_or_path):
629
+ archive_file = pretrained_model_name_or_path
630
+ elif os.path.isfile(pretrained_model_name_or_path + ".index"):
631
+ assert (
632
+ from_tf
633
+ ), "We found a TensorFlow checkpoint at {}, please set from_tf to True to load from this checkpoint".format(
634
+ pretrained_model_name_or_path + ".index"
635
+ )
636
+ archive_file = pretrained_model_name_or_path + ".index"
637
+ else:
638
+ archive_file = hf_bucket_url(
639
+ pretrained_model_name_or_path,
640
+ filename=(TF2_WEIGHTS_NAME if from_tf else WEIGHTS_NAME),
641
+ use_cdn=use_cdn,
642
+ )
643
+
644
+ try:
645
+ # Load from URL or cache if already cached
646
+ resolved_archive_file = cached_path(
647
+ archive_file,
648
+ cache_dir=cache_dir,
649
+ force_download=force_download,
650
+ proxies=proxies,
651
+ resume_download=resume_download,
652
+ local_files_only=local_files_only,
653
+ )
654
+ if resolved_archive_file is None:
655
+ raise EnvironmentError
656
+ except EnvironmentError:
657
+ msg = (
658
+ f"Can't load weights for '{pretrained_model_name_or_path}'. Make sure that:\n\n"
659
+ f"- '{pretrained_model_name_or_path}' is a correct model identifier listed on 'https://huggingface.co/models'\n\n"
660
+ f"- or '{pretrained_model_name_or_path}' is the correct path to a directory containing a file named one of {WEIGHTS_NAME}, {TF2_WEIGHTS_NAME}, {TF_WEIGHTS_NAME}.\n\n"
661
+ )
662
+ raise EnvironmentError(msg)
663
+
664
+ if resolved_archive_file == archive_file:
665
+ logger.info("loading weights file {}".format(archive_file))
666
+ else:
667
+ logger.info("loading weights file {} from cache at {}".format(archive_file, resolved_archive_file))
668
+ else:
669
+ resolved_archive_file = None
670
+
671
+ # Instantiate model.
672
+ model = cls(config, *model_args, **model_kwargs)
673
+
674
+ if state_dict is None and not from_tf:
675
+ try:
676
+ state_dict = torch.load(resolved_archive_file, map_location="cpu")
677
+ except Exception:
678
+ raise OSError(
679
+ "Unable to load weights from pytorch checkpoint file. "
680
+ "If you tried to load a PyTorch model from a TF 2.0 checkpoint, please set from_tf=True. "
681
+ )
682
+
683
+ missing_keys = []
684
+ unexpected_keys = []
685
+ error_msgs = []
686
+
687
+ if from_tf:
688
+ if resolved_archive_file.endswith(".index"):
689
+ # Load from a TensorFlow 1.X checkpoint - provided by original authors
690
+ model = cls.load_tf_weights(model, config, resolved_archive_file[:-6]) # Remove the '.index'
691
+ else:
692
+ # Load from our TensorFlow 2.0 checkpoints
693
+ try:
694
+ from transformers import load_tf2_checkpoint_in_pytorch_model
695
+
696
+ model = load_tf2_checkpoint_in_pytorch_model(model, resolved_archive_file, allow_missing_keys=True)
697
+ except ImportError:
698
+ logger.error(
699
+ "Loading a TensorFlow model in PyTorch, requires both PyTorch and TensorFlow to be installed. Please see "
700
+ "https://pytorch.org/ and https://www.tensorflow.org/install/ for installation instructions."
701
+ )
702
+ raise
703
+ else:
704
+ # Convert old format to new format if needed from a PyTorch state_dict
705
+ old_keys = []
706
+ new_keys = []
707
+ for key in state_dict.keys():
708
+ new_key = None
709
+ if "gamma" in key:
710
+ new_key = key.replace("gamma", "weight")
711
+ if "beta" in key:
712
+ new_key = key.replace("beta", "bias")
713
+ if new_key:
714
+ old_keys.append(key)
715
+ new_keys.append(new_key)
716
+ for old_key, new_key in zip(old_keys, new_keys):
717
+ state_dict[new_key] = state_dict.pop(old_key)
718
+
719
+ # copy state_dict so _load_from_state_dict can modify it
720
+ metadata = getattr(state_dict, "_metadata", None)
721
+ state_dict = state_dict.copy()
722
+ if metadata is not None:
723
+ state_dict._metadata = metadata
724
+
725
+ ##############################################################################################
726
+ # Print out state_dict's contents: keys
727
+ '''
728
+ for key, _ in state_dict.items():
729
+ print(key)
730
+ '''
731
+ ##############################################################################################
732
+
733
+
734
+ # PyTorch's `_load_from_state_dict` does not copy parameters in a module's descendants
735
+ # so we need to apply the function recursively.
736
+ def load(module: nn.Module, prefix=""):
737
+ local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
738
+ module._load_from_state_dict(
739
+ state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs,
740
+ )
741
+ for name, child in module._modules.items():
742
+ if child is not None:
743
+ load(child, prefix + name + ".")
744
+
745
+ # Make sure we are able to load base models as well as derived models (with heads)
746
+ start_prefix = ""
747
+ model_to_load = model
748
+ has_prefix_module = any(s.startswith(cls.base_model_prefix) for s in state_dict.keys())
749
+ if not hasattr(model, cls.base_model_prefix) and has_prefix_module:
750
+ start_prefix = cls.base_model_prefix + "."
751
+ if hasattr(model, cls.base_model_prefix) and not has_prefix_module:
752
+ model_to_load = getattr(model, cls.base_model_prefix)
753
+
754
+ load(model_to_load, prefix=start_prefix)
755
+
756
+ if model.__class__.__name__ != model_to_load.__class__.__name__:
757
+ base_model_state_dict = model_to_load.state_dict().keys()
758
+ head_model_state_dict_without_base_prefix = [
759
+ key.split(cls.base_model_prefix + ".")[-1] for key in model.state_dict().keys()
760
+ ]
761
+
762
+ missing_keys.extend(head_model_state_dict_without_base_prefix - base_model_state_dict)
763
+
764
+ if len(unexpected_keys) > 0:
765
+ logger.warning(
766
+ f"Some weights of the model checkpoint at {pretrained_model_name_or_path} were not used when "
767
+ f"initializing {model.__class__.__name__}: {unexpected_keys}\n"
768
+ f"- This IS expected if you are initializing {model.__class__.__name__} from the checkpoint of a model trained on another task "
769
+ f"or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPretraining model).\n"
770
+ f"- This IS NOT expected if you are initializing {model.__class__.__name__} from the checkpoint of a model that you expect "
771
+ f"to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model)."
772
+ )
773
+ else:
774
+ logger.info(f"All model checkpoint weights were used when initializing {model.__class__.__name__}.\n")
775
+ if len(missing_keys) > 0:
776
+ logger.warning(
777
+ f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at {pretrained_model_name_or_path} "
778
+ f"and are newly initialized: {missing_keys}\n"
779
+ f"You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference."
780
+ )
781
+ else:
782
+ logger.info(
783
+ f"All the weights of {model.__class__.__name__} were initialized from the model checkpoint at {pretrained_model_name_or_path}.\n"
784
+ f"If your task is similar to the task the model of the checkpoint was trained on, "
785
+ f"you can already use {model.__class__.__name__} for predictions without further training."
786
+ )
787
+ if len(error_msgs) > 0:
788
+ raise RuntimeError(
789
+ "Error(s) in loading state_dict for {}:\n\t{}".format(
790
+ model.__class__.__name__, "\n\t".join(error_msgs)
791
+ )
792
+ )
793
+ model.tie_weights() # make sure token embedding weights are still tied if needed
794
+
795
+ # Set model in evaluation mode to deactivate DropOut modules by default
796
+ model.eval()
797
+
798
+ if output_loading_info:
799
+ loading_info = {
800
+ "missing_keys": missing_keys,
801
+ "unexpected_keys": unexpected_keys,
802
+ "error_msgs": error_msgs,
803
+ }
804
+ return model, loading_info
805
+
806
+ if hasattr(config, "xla_device") and config.xla_device:
807
+ import torch_xla.core.xla_model as xm
808
+
809
+ model = xm.send_cpu_data_to_device(model, xm.xla_device())
810
+ model.to(xm.xla_device())
811
+
812
+ return model
813
+
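A short sketch of the output_loading_info option documented above (the checkpoint name is illustrative):

    model, loading_info = BertModel.from_pretrained('bert-base-uncased', output_loading_info=True)
    print(loading_info['missing_keys'], loading_info['unexpected_keys'])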
814
+
815
+ class Conv1D(nn.Module):
816
+ def __init__(self, nf, nx):
817
+ """ Conv1D layer as defined by Radford et al. for OpenAI GPT (and also used in GPT-2)
818
+ Basically works like a Linear layer but the weights are transposed
819
+ """
820
+ super().__init__()
821
+ self.nf = nf
822
+ w = torch.empty(nx, nf)
823
+ nn.init.normal_(w, std=0.02)
824
+ self.weight = nn.Parameter(w)
825
+ self.bias = nn.Parameter(torch.zeros(nf))
826
+
827
+ def forward(self, x):
828
+ size_out = x.size()[:-1] + (self.nf,)
829
+ x = torch.addmm(self.bias, x.view(-1, x.size(-1)), self.weight)
830
+ x = x.view(*size_out)
831
+ return x
832
+
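A minimal shape check for Conv1D, which acts like nn.Linear with a transposed weight matrix:

    conv = Conv1D(nf=12, nx=4)       # weight is stored as (nx, nf) = (4, 12)
    x = torch.randn(2, 7, 4)
    y = conv(x)                      # y.shape == torch.Size([2, 7, 12])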
833
+
834
+ class PoolerStartLogits(nn.Module):
835
+ """ Compute SQuAD start_logits from sequence hidden states. """
836
+
837
+ def __init__(self, config):
838
+ super().__init__()
839
+ self.dense = nn.Linear(config.hidden_size, 1)
840
+
841
+ def forward(self, hidden_states, p_mask=None):
842
+ """ Args:
843
+ **p_mask**: (`optional`) ``torch.FloatTensor`` of shape `(batch_size, seq_len)`
844
+ invalid position mask such as query and special symbols (PAD, SEP, CLS)
845
+ 1.0 means token should be masked.
846
+ """
847
+ x = self.dense(hidden_states).squeeze(-1)
848
+
849
+ if p_mask is not None:
850
+ if next(self.parameters()).dtype == torch.float16:
851
+ x = x * (1 - p_mask) - 65500 * p_mask
852
+ else:
853
+ x = x * (1 - p_mask) - 1e30 * p_mask
854
+
855
+ return x
856
+
857
+
858
+ class PoolerEndLogits(nn.Module):
859
+ """ Compute SQuAD end_logits from sequence hidden states and start token hidden state.
860
+ """
861
+
862
+ def __init__(self, config):
863
+ super().__init__()
864
+ self.dense_0 = nn.Linear(config.hidden_size * 2, config.hidden_size)
865
+ self.activation = nn.Tanh()
866
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
867
+ self.dense_1 = nn.Linear(config.hidden_size, 1)
868
+
869
+ def forward(self, hidden_states, start_states=None, start_positions=None, p_mask=None):
870
+ """ Args:
871
+ One of ``start_states``, ``start_positions`` should be not None.
872
+ If both are set, ``start_positions`` overrides ``start_states``.
873
+
874
+ **start_states**: ``torch.LongTensor`` of shape identical to hidden_states
875
+ hidden states of the first tokens for the labeled span.
876
+ **start_positions**: ``torch.LongTensor`` of shape ``(batch_size,)``
877
+ position of the first token for the labeled span:
878
+ **p_mask**: (`optional`) ``torch.FloatTensor`` of shape ``(batch_size, seq_len)``
879
+ Mask of invalid position such as query and special symbols (PAD, SEP, CLS)
880
+ 1.0 means token should be masked.
881
+ """
882
+ assert (
883
+ start_states is not None or start_positions is not None
884
+ ), "One of start_states, start_positions should be not None"
885
+ if start_positions is not None:
886
+ slen, hsz = hidden_states.shape[-2:]
887
+ start_positions = start_positions[:, None, None].expand(-1, -1, hsz) # shape (bsz, 1, hsz)
888
+ start_states = hidden_states.gather(-2, start_positions) # shape (bsz, 1, hsz)
889
+ start_states = start_states.expand(-1, slen, -1) # shape (bsz, slen, hsz)
890
+
891
+ x = self.dense_0(torch.cat([hidden_states, start_states], dim=-1))
892
+ x = self.activation(x)
893
+ x = self.LayerNorm(x)
894
+ x = self.dense_1(x).squeeze(-1)
895
+
896
+ if p_mask is not None:
897
+ if next(self.parameters()).dtype == torch.float16:
898
+ x = x * (1 - p_mask) - 65500 * p_mask
899
+ else:
900
+ x = x * (1 - p_mask) - 1e30 * p_mask
901
+
902
+ return x
903
+
904
+
905
+ class PoolerAnswerClass(nn.Module):
906
+ """ Compute SQuAD 2.0 answer class from classification and start tokens hidden states. """
907
+
908
+ def __init__(self, config):
909
+ super().__init__()
910
+ self.dense_0 = nn.Linear(config.hidden_size * 2, config.hidden_size)
911
+ self.activation = nn.Tanh()
912
+ self.dense_1 = nn.Linear(config.hidden_size, 1, bias=False)
913
+
914
+ def forward(self, hidden_states, start_states=None, start_positions=None, cls_index=None):
915
+ """
916
+ Args:
917
+ One of ``start_states``, ``start_positions`` should be not None.
918
+ If both are set, ``start_positions`` overrides ``start_states``.
919
+
920
+ **start_states**: ``torch.LongTensor`` of shape identical to ``hidden_states``.
921
+ hidden states of the first tokens for the labeled span.
922
+ **start_positions**: ``torch.LongTensor`` of shape ``(batch_size,)``
923
+ position of the first token for the labeled span.
924
+ **cls_index**: torch.LongTensor of shape ``(batch_size,)``
925
+ position of the CLS token. If None, take the last token.
926
+
927
+ note(Original repo):
928
+ no dependency on end_feature so that we can obtain one single `cls_logits`
929
+ for each sample
930
+ """
931
+ hsz = hidden_states.shape[-1]
932
+ assert (
933
+ start_states is not None or start_positions is not None
934
+ ), "One of start_states, start_positions should be not None"
935
+ if start_positions is not None:
936
+ start_positions = start_positions[:, None, None].expand(-1, -1, hsz) # shape (bsz, 1, hsz)
937
+ start_states = hidden_states.gather(-2, start_positions).squeeze(-2) # shape (bsz, hsz)
938
+
939
+ if cls_index is not None:
940
+ cls_index = cls_index[:, None, None].expand(-1, -1, hsz) # shape (bsz, 1, hsz)
941
+ cls_token_state = hidden_states.gather(-2, cls_index).squeeze(-2) # shape (bsz, hsz)
942
+ else:
943
+ cls_token_state = hidden_states[:, -1, :] # shape (bsz, hsz)
944
+
945
+ x = self.dense_0(torch.cat([start_states, cls_token_state], dim=-1))
946
+ x = self.activation(x)
947
+ x = self.dense_1(x).squeeze(-1)
948
+
949
+ return x
950
+
951
+
952
+ class SQuADHead(nn.Module):
953
+ r""" A SQuAD head inspired by XLNet.
954
+
955
+ Parameters:
956
+ config (:class:`~transformers.XLNetConfig`): Model configuration class with all the parameters of the model.
957
+
958
+ Inputs:
959
+ **hidden_states**: ``torch.FloatTensor`` of shape ``(batch_size, seq_len, hidden_size)``
960
+ hidden states of sequence tokens
961
+ **start_positions**: ``torch.LongTensor`` of shape ``(batch_size,)``
962
+ position of the first token for the labeled span.
963
+ **end_positions**: ``torch.LongTensor`` of shape ``(batch_size,)``
964
+ position of the last token for the labeled span.
965
+ **cls_index**: torch.LongTensor of shape ``(batch_size,)``
966
+ position of the CLS token. If None, take the last token.
967
+ **is_impossible**: ``torch.LongTensor`` of shape ``(batch_size,)``
968
+ Whether the question has a possible answer in the paragraph or not.
969
+ **p_mask**: (`optional`) ``torch.FloatTensor`` of shape ``(batch_size, seq_len)``
970
+ Mask of invalid position such as query and special symbols (PAD, SEP, CLS)
971
+ 1.0 means token should be masked.
972
+
973
+ Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
974
+ **loss**: (`optional`, returned if both ``start_positions`` and ``end_positions`` are provided) ``torch.FloatTensor`` of shape ``(1,)``:
975
+ Classification loss as the sum of start token, end token (and is_impossible if provided) classification losses.
976
+ **start_top_log_probs**: (`optional`, returned if ``start_positions`` or ``end_positions`` is not provided)
977
+ ``torch.FloatTensor`` of shape ``(batch_size, config.start_n_top)``
978
+ Log probabilities for the top config.start_n_top start token possibilities (beam-search).
979
+ **start_top_index**: (`optional`, returned if ``start_positions`` or ``end_positions`` is not provided)
980
+ ``torch.LongTensor`` of shape ``(batch_size, config.start_n_top)``
981
+ Indices for the top config.start_n_top start token possibilities (beam-search).
982
+ **end_top_log_probs**: (`optional`, returned if ``start_positions`` or ``end_positions`` is not provided)
983
+ ``torch.FloatTensor`` of shape ``(batch_size, config.start_n_top * config.end_n_top)``
984
+ Log probabilities for the top ``config.start_n_top * config.end_n_top`` end token possibilities (beam-search).
985
+ **end_top_index**: (`optional`, returned if ``start_positions`` or ``end_positions`` is not provided)
986
+ ``torch.LongTensor`` of shape ``(batch_size, config.start_n_top * config.end_n_top)``
987
+ Indices for the top ``config.start_n_top * config.end_n_top`` end token possibilities (beam-search).
988
+ **cls_logits**: (`optional`, returned if ``start_positions`` or ``end_positions`` is not provided)
989
+ ``torch.FloatTensor`` of shape ``(batch_size,)``
990
+ Log probabilities for the ``is_impossible`` label of the answers.
991
+ """
992
+
993
+ def __init__(self, config):
994
+ super().__init__()
995
+ self.start_n_top = config.start_n_top
996
+ self.end_n_top = config.end_n_top
997
+
998
+ self.start_logits = PoolerStartLogits(config)
999
+ self.end_logits = PoolerEndLogits(config)
1000
+ self.answer_class = PoolerAnswerClass(config)
1001
+
1002
+ def forward(
1003
+ self, hidden_states, start_positions=None, end_positions=None, cls_index=None, is_impossible=None, p_mask=None,
1004
+ ):
1005
+ outputs = ()
1006
+
1007
+ start_logits = self.start_logits(hidden_states, p_mask=p_mask)
1008
+
1009
+ if start_positions is not None and end_positions is not None:
1010
+ # If we are on multi-GPU, let's remove the dimension added by batch splitting
1011
+ for x in (start_positions, end_positions, cls_index, is_impossible):
1012
+ if x is not None and x.dim() > 1:
1013
+ x.squeeze_(-1)
1014
+
1015
+ # during training, compute the end logits based on the ground truth of the start position
1016
+ end_logits = self.end_logits(hidden_states, start_positions=start_positions, p_mask=p_mask)
1017
+
1018
+ loss_fct = CrossEntropyLoss()
1019
+ start_loss = loss_fct(start_logits, start_positions)
1020
+ end_loss = loss_fct(end_logits, end_positions)
1021
+ total_loss = (start_loss + end_loss) / 2
1022
+
1023
+ if cls_index is not None and is_impossible is not None:
1024
+ # Predict answerability from the representation of CLS and START
1025
+ cls_logits = self.answer_class(hidden_states, start_positions=start_positions, cls_index=cls_index)
1026
+ loss_fct_cls = nn.BCEWithLogitsLoss()
1027
+ cls_loss = loss_fct_cls(cls_logits, is_impossible)
1028
+
1029
+ # note(zhiliny): by default multiply the loss by 0.5 so that the scale is comparable to start_loss and end_loss
1030
+ total_loss += cls_loss * 0.5
1031
+
1032
+ outputs = (total_loss,) + outputs
1033
+
1034
+ else:
1035
+ # during inference, compute the end logits based on beam search
1036
+ bsz, slen, hsz = hidden_states.size()
1037
+ start_log_probs = F.softmax(start_logits, dim=-1) # shape (bsz, slen)
1038
+
1039
+ start_top_log_probs, start_top_index = torch.topk(
1040
+ start_log_probs, self.start_n_top, dim=-1
1041
+ ) # shape (bsz, start_n_top)
1042
+ start_top_index_exp = start_top_index.unsqueeze(-1).expand(-1, -1, hsz) # shape (bsz, start_n_top, hsz)
1043
+ start_states = torch.gather(hidden_states, -2, start_top_index_exp) # shape (bsz, start_n_top, hsz)
1044
+ start_states = start_states.unsqueeze(1).expand(-1, slen, -1, -1) # shape (bsz, slen, start_n_top, hsz)
1045
+
1046
+ hidden_states_expanded = hidden_states.unsqueeze(2).expand_as(
1047
+ start_states
1048
+ ) # shape (bsz, slen, start_n_top, hsz)
1049
+ p_mask = p_mask.unsqueeze(-1) if p_mask is not None else None
1050
+ end_logits = self.end_logits(hidden_states_expanded, start_states=start_states, p_mask=p_mask)
1051
+ end_log_probs = F.softmax(end_logits, dim=1) # shape (bsz, slen, start_n_top)
1052
+
1053
+ end_top_log_probs, end_top_index = torch.topk(
1054
+ end_log_probs, self.end_n_top, dim=1
1055
+ ) # shape (bsz, end_n_top, start_n_top)
1056
+ end_top_log_probs = end_top_log_probs.view(-1, self.start_n_top * self.end_n_top)
1057
+ end_top_index = end_top_index.view(-1, self.start_n_top * self.end_n_top)
1058
+
1059
+ start_states = torch.einsum("blh,bl->bh", hidden_states, start_log_probs)
1060
+ cls_logits = self.answer_class(hidden_states, start_states=start_states, cls_index=cls_index)
1061
+
1062
+ outputs = (start_top_log_probs, start_top_index, end_top_log_probs, end_top_index, cls_logits,) + outputs
1063
+
1064
+ # return start_top_log_probs, start_top_index, end_top_log_probs, end_top_index, cls_logits
1065
+ # or (if labels are provided) (total_loss,)
1066
+ return outputs
1067
+
1068
+
1069
+ class SequenceSummary(nn.Module):
1070
+ r""" Compute a single vector summary of a sequence hidden states according to various possibilities:
1071
+ Args of the config class:
1072
+ summary_type:
1073
+ - 'last' => [default] take the last token hidden state (like XLNet)
1074
+ - 'first' => take the first token hidden state (like Bert)
1075
+ - 'mean' => take the mean of all tokens hidden states
1076
+ - 'cls_index' => supply a Tensor of classification token position (GPT/GPT-2)
1077
+ - 'attn' => Not implemented now, use multi-head attention
1078
+ summary_use_proj: Add a projection after the vector extraction
1079
+ summary_proj_to_labels: If True, the projection outputs to config.num_labels classes (otherwise to hidden_size). Default: False.
1080
+ summary_activation: 'tanh' or another string => add an activation to the output, Other => no activation. Default
1081
+ summary_first_dropout: Add a dropout before the projection and activation
1082
+ summary_last_dropout: Add a dropout after the projection and activation
1083
+ """
1084
+
1085
+ def __init__(self, config: PretrainedConfig):
1086
+ super().__init__()
1087
+
1088
+ self.summary_type = getattr(config, "summary_type", "last")
1089
+ if self.summary_type == "attn":
1090
+ # We should use a standard multi-head attention module with absolute positional embedding for that.
1091
+ # Cf. https://github.com/zihangdai/xlnet/blob/master/modeling.py#L253-L276
1092
+ # We can probably just use the multi-head attention module of PyTorch >=1.1.0
1093
+ raise NotImplementedError
1094
+
1095
+ self.summary = Identity()
1096
+ if hasattr(config, "summary_use_proj") and config.summary_use_proj:
1097
+ if hasattr(config, "summary_proj_to_labels") and config.summary_proj_to_labels and config.num_labels > 0:
1098
+ num_classes = config.num_labels
1099
+ else:
1100
+ num_classes = config.hidden_size
1101
+ self.summary = nn.Linear(config.hidden_size, num_classes)
1102
+
1103
+ activation_string = getattr(config, "summary_activation", None)
1104
+ self.activation: Callable = (get_activation(activation_string) if activation_string else Identity())
1105
+
1106
+ self.first_dropout = Identity()
1107
+ if hasattr(config, "summary_first_dropout") and config.summary_first_dropout > 0:
1108
+ self.first_dropout = nn.Dropout(config.summary_first_dropout)
1109
+
1110
+ self.last_dropout = Identity()
1111
+ if hasattr(config, "summary_last_dropout") and config.summary_last_dropout > 0:
1112
+ self.last_dropout = nn.Dropout(config.summary_last_dropout)
1113
+
1114
+ def forward(self, hidden_states, cls_index=None):
1115
+ """ hidden_states: float Tensor in shape [bsz, ..., seq_len, hidden_size], the hidden-states of the last layer.
1116
+ cls_index: [optional] position of the classification token if summary_type == 'cls_index',
1117
+ shape (bsz,) or more generally (bsz, ...) where ... are optional leading dimensions of hidden_states.
1118
+ if summary_type == 'cls_index' and cls_index is None:
1119
+ we take the last token of the sequence as classification token
1120
+ """
1121
+ if self.summary_type == "last":
1122
+ output = hidden_states[:, -1]
1123
+ elif self.summary_type == "first":
1124
+ output = hidden_states[:, 0]
1125
+ elif self.summary_type == "mean":
1126
+ output = hidden_states.mean(dim=1)
1127
+ elif self.summary_type == "cls_index":
1128
+ if cls_index is None:
1129
+ cls_index = torch.full_like(hidden_states[..., :1, :], hidden_states.shape[-2] - 1, dtype=torch.long,)
1130
+ else:
1131
+ cls_index = cls_index.unsqueeze(-1).unsqueeze(-1)
1132
+ cls_index = cls_index.expand((-1,) * (cls_index.dim() - 1) + (hidden_states.size(-1),))
1133
+ # shape of cls_index: (bsz, XX, 1, hidden_size) where XX are optional leading dim of hidden_states
1134
+ output = hidden_states.gather(-2, cls_index).squeeze(-2) # shape (bsz, XX, hidden_size)
1135
+ elif self.summary_type == "attn":
1136
+ raise NotImplementedError
1137
+
1138
+ output = self.first_dropout(output)
1139
+ output = self.summary(output)
1140
+ output = self.activation(output)
1141
+ output = self.last_dropout(output)
1142
+
1143
+ return output
1144
+
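A minimal sketch of SequenceSummary with the default 'last' strategy; the config object here is a bare stand-in carrying only the attributes this module actually reads:

    from types import SimpleNamespace
    cfg = SimpleNamespace(summary_type='last')      # no projection/dropout attributes -> Identity paths
    summary = SequenceSummary(cfg)
    pooled = summary(torch.randn(2, 7, 768))        # hidden state of the last token, shape (2, 768)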
1145
+
1146
+ def prune_linear_layer(layer, index, dim=0):
1147
+ """ Prune a linear layer (a model parameter) to keep only entries in index.
1148
+ Return the pruned layer as a new layer with requires_grad=True.
1149
+ Used to remove heads.
1150
+ """
1151
+ index = index.to(layer.weight.device)
1152
+ W = layer.weight.index_select(dim, index).clone().detach()
1153
+ if layer.bias is not None:
1154
+ if dim == 1:
1155
+ b = layer.bias.clone().detach()
1156
+ else:
1157
+ b = layer.bias[index].clone().detach()
1158
+ new_size = list(layer.weight.size())
1159
+ new_size[dim] = len(index)
1160
+ new_layer = nn.Linear(new_size[1], new_size[0], bias=layer.bias is not None).to(layer.weight.device)
1161
+ new_layer.weight.requires_grad = False
1162
+ new_layer.weight.copy_(W.contiguous())
1163
+ new_layer.weight.requires_grad = True
1164
+ if layer.bias is not None:
1165
+ new_layer.bias.requires_grad = False
1166
+ new_layer.bias.copy_(b.contiguous())
1167
+ new_layer.bias.requires_grad = True
1168
+ return new_layer
1169
+
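A small pruning sketch: keep only output units 0, 2 and 3 of a Linear layer.

    layer = nn.Linear(8, 4)
    index = torch.tensor([0, 2, 3])
    pruned = prune_linear_layer(layer, index, dim=0)  # new weight shape (3, 8), bias length 3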
1170
+
1171
+ def prune_conv1d_layer(layer, index, dim=1):
1172
+ """ Prune a Conv1D layer (a model parameters) to keep only entries in index.
1173
+ A Conv1D works as a Linear layer (see e.g. BERT) but the weights are transposed.
1174
+ Return the pruned layer as a new layer with requires_grad=True.
1175
+ Used to remove heads.
1176
+ """
1177
+ index = index.to(layer.weight.device)
1178
+ W = layer.weight.index_select(dim, index).clone().detach()
1179
+ if dim == 0:
1180
+ b = layer.bias.clone().detach()
1181
+ else:
1182
+ b = layer.bias[index].clone().detach()
1183
+ new_size = list(layer.weight.size())
1184
+ new_size[dim] = len(index)
1185
+ new_layer = Conv1D(new_size[1], new_size[0]).to(layer.weight.device)
1186
+ new_layer.weight.requires_grad = False
1187
+ new_layer.weight.copy_(W.contiguous())
1188
+ new_layer.weight.requires_grad = True
1189
+ new_layer.bias.requires_grad = False
1190
+ new_layer.bias.copy_(b.contiguous())
1191
+ new_layer.bias.requires_grad = True
1192
+ return new_layer
1193
+
1194
+
1195
+ def prune_layer(layer, index, dim=None):
1196
+ """ Prune a Conv1D or nn.Linear layer (a model parameters) to keep only entries in index.
1197
+ Return the pruned layer as a new layer with requires_grad=True.
1198
+ Used to remove heads.
1199
+ """
1200
+ if isinstance(layer, nn.Linear):
1201
+ return prune_linear_layer(layer, index, dim=0 if dim is None else dim)
1202
+ elif isinstance(layer, Conv1D):
1203
+ return prune_conv1d_layer(layer, index, dim=1 if dim is None else dim)
1204
+ else:
1205
+ raise ValueError("Can't prune layer of class {}".format(layer.__class__))
1206
+
1207
+
1208
+ def apply_chunking_to_forward(
1209
+ chunk_size: int, chunk_dim: int, forward_fn: Callable[..., torch.Tensor], *input_tensors
1210
+ ) -> torch.Tensor:
1211
+ """
1212
+ This function chunks the `input_tensors` into smaller input tensor parts of size `chunk_size` over the dimension `chunk_dim`.
1213
+ It then applies a layer `forward_fn` to each chunk independently to save memory.
1214
+ If the `forward_fn` is independent across the `chunk_dim` this function will yield the
1215
+ same result as not applying it.
1216
+
1217
+ Args:
1218
+ chunk_size: int - the chunk size of a chunked tensor. `num_chunks` = `len(input_tensors[0]) / chunk_size`
1219
+ chunk_dim: int - the dimension over which the input_tensors should be chunked
1220
+ forward_fn: fn - the forward fn of the model
1221
+ input_tensors: tuple(torch.Tensor) - the input tensors of `forward_fn` which are chunked
1222
+ Returns:
1223
+ a Tensor with the same shape the forward_fn would have given if applied
1224
+
1225
+
1226
+ Examples::
1227
+
1228
+ # rename the usual forward() fn to forward_chunk()
1229
+ def forward_chunk(self, hidden_states):
1230
+ hidden_states = self.decoder(hidden_states)
1231
+ return hidden_states
1232
+
1233
+ # implement a chunked forward function
1234
+ def forward(self, hidden_states):
1235
+ return apply_chunking_to_forward(self.chunk_size_lm_head, self.seq_len_dim, self.forward_chunk, hidden_states)
1236
+ """
1237
+
1238
+ assert len(input_tensors) > 0, "{} has to be a tuple/list of tensors".format(input_tensors)
1239
+ tensor_shape = input_tensors[0].shape
1240
+ assert all(
1241
+ input_tensor.shape == tensor_shape for input_tensor in input_tensors
1242
+ ), "All input tensors have to be of the same shape"
1243
+
1244
+ # inspect.signature exists since Python 3.5 and is a Python method -> no problem with backward compatibility
1245
+ num_args_in_forward_chunk_fn = len(inspect.signature(forward_fn).parameters)
1246
+ assert num_args_in_forward_chunk_fn == len(
1247
+ input_tensors
1248
+ ), "forward_chunk_fn expects {} arguments, but only {} input tensors are given".format(
1249
+ num_args_in_forward_chunk_fn, len(input_tensors)
1250
+ )
1251
+
1252
+ if chunk_size > 0:
1253
+ assert (
1254
+ input_tensors[0].shape[chunk_dim] % chunk_size == 0
1255
+ ), "The dimension to be chunked {} has to be a multiple of the chunk size {}".format(
1256
+ input_tensors[0].shape[chunk_dim], chunk_size
1257
+ )
1258
+
1259
+ num_chunks = input_tensors[0].shape[chunk_dim] // chunk_size
1260
+
1261
+ # chunk input tensor into tuples
1262
+ input_tensors_chunks = tuple(input_tensor.chunk(num_chunks, dim=chunk_dim) for input_tensor in input_tensors)
1263
+ # apply forward fn to every tuple
1264
+ output_chunks = tuple(forward_fn(*input_tensors_chunk) for input_tensors_chunk in zip(*input_tensors_chunks))
1265
+ # concatenate output at same dimension
1266
+ return torch.cat(output_chunks, dim=chunk_dim)
1267
+
1268
+ return forward_fn(*input_tensors)
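A concrete call sketch for the chunking helper (layer and sizes are illustrative); because nn.Linear is applied independently along the sequence dimension, the chunked result matches the unchunked one:

    layer = nn.Linear(16, 16)
    hidden = torch.randn(1, 6, 16)
    out = apply_chunking_to_forward(2, 1, layer.forward, hidden)  # three chunks of length 2 along dim 1
    assert torch.allclose(out, layer(hidden), atol=1e-6)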
elia/bert/multimodal_bert.py ADDED
@@ -0,0 +1,277 @@
1
+
2
+ from .modeling_bert import BertModel
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+
7
+
8
+
9
+ class MultiModalBert(BertModel):
10
+ def __init__(self, config, embed_dim, pwam_idx=[3,6,9,12], num_heads_fusion=[1,1,1,1], fusion_drop=0.0):
11
+ super().__init__(config)
12
+ self.pwam_idx = pwam_idx
13
+ self.num_heads_fusion = num_heads_fusion
14
+ self.fusion_drop = fusion_drop
15
+
16
+ pwam_dims=[embed_dim * 2** i for i in range(len(pwam_idx))]
17
+ #print(pwam_dims)
18
+ self.pwams = nn.ModuleList()
19
+ self.res_gates = nn.ModuleList()
20
+ self.norms = nn.ModuleList()
21
+ for i in range(0, len(pwam_idx)):
22
+ dim = pwam_dims[i]
23
+ fusion = PWAM(768, # both the visual input and for combining, num of channels
24
+ dim, # v_in
25
+ 768, # l_in
26
+ 768, # key
27
+ 768, # value
28
+ num_heads=num_heads_fusion[i],
29
+ dropout=fusion_drop)
30
+ self.pwams.append(fusion)
31
+
32
+ res_gate = nn.Sequential(
33
+ nn.Linear(768, 768, bias=False),
34
+ nn.ReLU(),
35
+ nn.Linear(768, 768, bias=False),
36
+ nn.Tanh()
37
+ )
38
+ nn.init.zeros_(res_gate[0].weight)
39
+ nn.init.zeros_(res_gate[2].weight)
40
+ self.res_gates.append(res_gate)
41
+
42
+ self.norms.append(nn.LayerNorm(768))
43
+
44
+ def forward_stem(self, input_ids, attention_mask):
45
+ input_shape = input_ids.size()
46
+ token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=input_ids.device)
47
+
48
+ extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape, input_ids.device)
49
+
50
+ embedding_output = self.embeddings(
51
+ input_ids=input_ids, token_type_ids=token_type_ids
52
+ )
53
+ #print(embedding_output.shape, extended_attention_mask.shape, "?>>>")
54
+ return embedding_output, extended_attention_mask
55
+
56
+ def forward_stage1(self, hidden_states, attention_mask):
57
+ for i in range(0, self.pwam_idx[0]):
58
+ layer_module = self.encoder.layer[i]
59
+ layer_outputs = layer_module(
60
+ hidden_states,
61
+ attention_mask,
62
+ )
63
+ hidden_states = layer_outputs[0]
64
+
65
+ return layer_outputs[0]
66
+
67
+ def forward_stage2(self, hidden_states, attention_mask):
68
+ for i in range(self.pwam_idx[0], self.pwam_idx[1]):
69
+ layer_module = self.encoder.layer[i]
70
+ layer_outputs = layer_module(
71
+ hidden_states,
72
+ attention_mask,
73
+ )
74
+ hidden_states = layer_outputs[0]
75
+
76
+ return layer_outputs[0]
77
+
78
+ def forward_stage3(self, hidden_states, attention_mask):
79
+ for i in range(self.pwam_idx[1], self.pwam_idx[2]):
80
+ layer_module = self.encoder.layer[i]
81
+ layer_outputs = layer_module(
82
+ hidden_states,
83
+ attention_mask,
84
+ )
85
+ hidden_states = layer_outputs[0]
86
+
87
+ return layer_outputs[0]
88
+
89
+ def forward_stage4(self, hidden_states, attention_mask):
90
+ for i in range(self.pwam_idx[2], self.pwam_idx[3]):
91
+ layer_module = self.encoder.layer[i]
92
+ layer_outputs = layer_module(
93
+ hidden_states,
94
+ attention_mask,
95
+ )
96
+ hidden_states = layer_outputs[0]
97
+
98
+ return layer_outputs[0]
99
+
100
+ def forward_pwam1(self, x, l, l_mask):
101
+ l_residual = self.pwams[0](x, l, l_mask)
102
+ l = l + (self.res_gates[0](l_residual) * l_residual)
103
+ return self.norms[0](l_residual), l
104
+
105
+ def forward_pwam2(self, x, l, l_mask):
106
+ l_residual = self.pwams[1](x, l, l_mask)
107
+ l = l + (self.res_gates[1](l_residual) * l_residual)
108
+ return self.norms[1](l_residual), l
109
+
110
+ def forward_pwam3(self, x, l, l_mask):
111
+ l_residual = self.pwams[2](x, l, l_mask)
112
+ l = l + (self.res_gates[2](l_residual) * l_residual)
113
+ return self.norms[2](l_residual), l
114
+
115
+ def forward_pwam4(self, x, l, l_mask):
116
+ l_residual = self.pwams[3](x, l, l_mask)
117
+ l = l + (self.res_gates[3](l_residual) * l_residual)
118
+ return self.norms[3](l_residual), l
119
+
120
+ class PWAM(nn.Module):
121
+ def __init__(self, dim, v_in_channels, l_in_channels, key_channels, value_channels, num_heads=0, dropout=0.0):
122
+ super(PWAM, self).__init__()
123
+ # input x shape: (B, H*W, dim)
124
+ #self.vis_project = nn.Sequential(nn.Conv1d(dim, dim, 1, 1), # the init function sets bias to 0 if bias is True
125
+ # nn.GELU(),
126
+ # nn.Dropout(dropout)
127
+ # )
128
+ #self.vis_project = nn.Sequential(nn.Conv1d(dim, dim, 1, 1), # the init function sets bias to 0 if bias is True
129
+ self.vis_project = nn.Sequential(nn.Linear(dim, dim), # the init function sets bias to 0 if bias is True
130
+ nn.GELU(),
131
+ nn.Dropout(dropout)
132
+ )
133
+
134
+ self.image_lang_att = SpatialImageLanguageAttention(v_in_channels, # v_in
135
+ l_in_channels, # l_in
136
+ key_channels, # key
137
+ value_channels, # value
138
+ out_channels=value_channels, # out
139
+ num_heads=num_heads)
140
+
141
+ self.project_mm = nn.Sequential(nn.Conv1d(value_channels, value_channels, 1, 1),
142
+ nn.GELU(),
143
+ nn.Dropout(dropout)
144
+ )
145
+
146
+ def forward(self, x, l, l_mask):
147
+ # input x shape: (B, H*W, dim)
148
+ #print("???", x.shape, l.shape, l_mask.shape)
149
+ #print(self.vis_project)
150
+ #vis = self.vis_project(x.permute(0, 2, 1)) # (B, dim, H*W)
151
+ vis = self.vis_project(l) # (B, dim, H*W)
152
+
153
+ lang = self.image_lang_att(x, l, l_mask) # (B, H*W, dim)
154
+
155
+ lang = lang.permute(0, 2, 1) # (B, dim, H*W)
156
+
157
+ #print("vis", vis.shape, "lang", lang.shape)
158
+ mm = torch.mul(vis.permute(0,2,1), lang)
159
+ #print(mm.shape)
160
+ mm = self.project_mm(mm) # (B, dim, H*W)
161
+
162
+ mm = mm.permute(0, 2, 1) # (B, H*W, dim)
163
+
164
+ return mm
165
+
166
+ #self.fusion = PWAM(dim, # both the visual input and for combining, num of channels
167
+ # dim, # v_in
168
+ # 768, # l_in
169
+ # dim, # key
170
+ # dim, # value
171
+ # num_heads=num_heads_fusion,
172
+ # dropout=fusion_drop)
173
+
174
+ class SpatialImageLanguageAttention(nn.Module):
175
+ def __init__(self, v_in_channels, l_in_channels, key_channels, value_channels, out_channels=None, num_heads=1):
176
+ super(SpatialImageLanguageAttention, self).__init__()
177
+ # x shape: (B, H*W, v_in_channels)
178
+ # l input shape: (B, l_in_channels, N_l)
179
+ # l_mask shape: (B, N_l, 1)
180
+ self.v_in_channels = v_in_channels
181
+ self.l_in_channels = l_in_channels
182
+ self.out_channels = out_channels
183
+ self.key_channels = key_channels
184
+ self.value_channels = value_channels
185
+ self.num_heads = num_heads
186
+ if out_channels is None:
187
+ self.out_channels = self.value_channels
188
+
189
+ # Keys: language features: (B, l_in_channels, #words)
190
+ # avoid any form of spatial normalization because a sentence contains many padding 0s
191
+ self.f_query = nn.Sequential(
192
+ nn.Conv1d(self.l_in_channels, self.key_channels, kernel_size=1, stride=1),
193
+ )
194
+
195
+ # Queries: visual features: (B, H*W, v_in_channels)
196
+ self.f_key = nn.Sequential(
197
+ nn.Conv1d(self.v_in_channels, self.key_channels, kernel_size=1, stride=1),
198
+ nn.InstanceNorm1d(self.key_channels),
199
+ )
200
+
201
+ # Values: language features: (B, l_in_channels, #words)
202
+ #self.f_value = nn.Sequential(
203
+ # nn.Conv1d(self.l_in_channels, self.value_channels, kernel_size=1, stride=1),
204
+ #)
205
+ self.f_value = nn.Sequential(
206
+ nn.Conv1d(self.v_in_channels, self.key_channels, kernel_size=1, stride=1),
207
+ nn.InstanceNorm1d(self.key_channels),
208
+ )
209
+
210
+ # Out projection
211
+ self.W = nn.Sequential(
212
+ nn.Conv1d(self.value_channels, self.out_channels, kernel_size=1, stride=1),
213
+ nn.InstanceNorm1d(self.out_channels),
214
+ )
215
+
216
+ def forward(self, x, l, l_mask):
217
+ #print('input shape', x.shape, l.shape, l_mask.shape)
218
+ l_mask = l_mask.squeeze(1)
219
+ # x shape: (B, H*W, v_in_channels)
220
+ # l input shape: (B, l_in_channels, N_l)
221
+ # l_mask shape: (B, N_l, 1)
222
+ B, HW = x.size(0), x.size(1)
223
+ x = x.permute(0, 2, 1) # (B, key_channels, H*W)
224
+ l = l.permute(0,2,1)
225
+ #l_mask = l_mask.permute(0, 2, 1) # (B, N_l, 1) -> (B, 1, N_l)
226
+ l_mask = l_mask # (B, N_l, 1) -> (B, 1, N_l)
227
+
228
+ #query = self.f_query(x) # (B, key_channels, H*W) if Conv1D
229
+ #query = query.permute(0, 2, 1) # (B, H*W, key_channels)
230
+ #key = self.f_key(l) # (B, key_channels, N_l)
231
+ #value = self.f_value(l) # (B, self.value_channels, N_l)
232
+ #key = key * l_mask # (B, key_channels, N_l)
233
+ #value = value * l_mask # (B, self.value_channels, N_l)
234
+
235
+ #print(l.shape, self.f_query)
236
+ query = self.f_query(l) # (B, key_channels, H*W) if Conv1D
237
+ query = query * l_mask # (B, key_channels, N_l)
238
+ query = query.permute(0, 2, 1) # (B, N_l, key_channels)
239
+
240
+ key = self.f_key(x) # (B, key_channels, H*W) if Conv1D
241
+ value = self.f_value(x) # (B, key_channels, H*W) if Conv1D
242
+
243
+ n_l = query.size(1)
244
+ #print(query.shape, key.shape, value.shape)
245
+
246
+ #query = query.reshape(B, HW, self.num_heads, self.key_channels//self.num_heads).permute(0, 2, 1, 3)
247
+ # (b, num_heads, H*W, self.key_channels//self.num_heads)
248
+ #key = key.reshape(B, self.num_heads, self.key_channels//self.num_heads, n_l)
249
+ # (b, num_heads, self.key_channels//self.num_heads, n_l)
250
+ #value = value.reshape(B, self.num_heads, self.value_channels//self.num_heads, n_l)
251
+ # # (b, num_heads, self.value_channels//self.num_heads, n_l)
252
+ key = key.reshape(B, self.num_heads, self.key_channels//self.num_heads, HW)
253
+ value = value.reshape(B, self.num_heads, self.key_channels//self.num_heads, HW)
254
+ # (b, num_heads, H*W, self.key_channels//self.num_heads)
255
+ #query = query.reshape(B, self.num_heads, self.key_channels//self.num_heads, n_l)
256
+ query = query.reshape(B, n_l, self.num_heads, self.key_channels//self.num_heads).permute(0, 2, 1, 3)
257
+ # (b, num_heads, self.key_channels//self.num_heads, n_l)
258
+ #value = value.reshape(B, self.num_heads, self.value_channels//self.num_heads, n_l)
259
+ #print('after reshape', query.shape, key.shape, value.shape)
260
+
261
+ l_mask = l_mask.unsqueeze(-1) # (b, 1, 1, n_l)
262
+
263
+ #sim_map = torch.matmul(query, key) # (B, self.num_heads, H*W, N_l)
264
+ sim_map = torch.matmul(query, key) # (B, self.num_heads, N_l, H*W)
265
+ sim_map = (self.key_channels ** -.5) * sim_map # scaled dot product
266
+
267
+ sim_map = sim_map + (1e4*l_mask - 1e4) # assign a very small number to padding positions
268
+ sim_map = F.softmax(sim_map, dim=-1) # (B, num_heads, h*w, N_l)
269
+ out = torch.matmul(sim_map, value.permute(0, 1, 3, 2)) # (B, num_heads, H*W, self.value_channels//num_heads)
270
+ #print('out', out.shape)
271
+ #out = out.permute(0, 2, 1, 3).contiguous().reshape(B, HW, self.value_channels) # (B, H*W, value_channels)
272
+ out = out.permute(0, 2, 1, 3).contiguous().reshape(B, n_l, self.value_channels) # (B, H*W, value_channels)
273
+ out = out.permute(0, 2, 1) # (B, value_channels, HW)
274
+ out = self.W(out) # (B, value_channels, HW)
275
+ out = out.permute(0, 2, 1) # (B, HW, value_channels)
276
+
277
+ return out
elia/bert/tokenization_bert.py ADDED
@@ -0,0 +1,545 @@
1
+ # coding=utf-8
2
+ # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """Tokenization classes."""
16
+
17
+
18
+ import collections
19
+ import logging
20
+ import os
21
+ import unicodedata
22
+ from typing import List, Optional
23
+
24
+ from .tokenization_utils import PreTrainedTokenizer, _is_control, _is_punctuation, _is_whitespace
25
+
26
+
27
+ logger = logging.getLogger(__name__)
28
+
29
+ VOCAB_FILES_NAMES = {"vocab_file": "vocab.txt"}
30
+
31
+ PRETRAINED_VOCAB_FILES_MAP = {
32
+ "vocab_file": {
33
+ "bert-base-uncased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-vocab.txt",
34
+ "bert-large-uncased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-vocab.txt",
35
+ "bert-base-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-vocab.txt",
36
+ "bert-large-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-vocab.txt",
37
+ "bert-base-multilingual-uncased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-uncased-vocab.txt",
38
+ "bert-base-multilingual-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-cased-vocab.txt",
39
+ "bert-base-chinese": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-chinese-vocab.txt",
40
+ "bert-base-german-cased": "https://int-deepset-models-bert.s3.eu-central-1.amazonaws.com/pytorch/bert-base-german-cased-vocab.txt",
41
+ "bert-large-uncased-whole-word-masking": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-whole-word-masking-vocab.txt",
42
+ "bert-large-cased-whole-word-masking": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-whole-word-masking-vocab.txt",
43
+ "bert-large-uncased-whole-word-masking-finetuned-squad": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-whole-word-masking-finetuned-squad-vocab.txt",
44
+ "bert-large-cased-whole-word-masking-finetuned-squad": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-whole-word-masking-finetuned-squad-vocab.txt",
45
+ "bert-base-cased-finetuned-mrpc": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-finetuned-mrpc-vocab.txt",
46
+ "bert-base-german-dbmdz-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-german-dbmdz-cased-vocab.txt",
47
+ "bert-base-german-dbmdz-uncased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-german-dbmdz-uncased-vocab.txt",
48
+ "TurkuNLP/bert-base-finnish-cased-v1": "https://s3.amazonaws.com/models.huggingface.co/bert/TurkuNLP/bert-base-finnish-cased-v1/vocab.txt",
49
+ "TurkuNLP/bert-base-finnish-uncased-v1": "https://s3.amazonaws.com/models.huggingface.co/bert/TurkuNLP/bert-base-finnish-uncased-v1/vocab.txt",
50
+ "wietsedv/bert-base-dutch-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/wietsedv/bert-base-dutch-cased/vocab.txt",
51
+ }
52
+ }
53
+
54
+ PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
55
+ "bert-base-uncased": 512,
56
+ "bert-large-uncased": 512,
57
+ "bert-base-cased": 512,
58
+ "bert-large-cased": 512,
59
+ "bert-base-multilingual-uncased": 512,
60
+ "bert-base-multilingual-cased": 512,
61
+ "bert-base-chinese": 512,
62
+ "bert-base-german-cased": 512,
63
+ "bert-large-uncased-whole-word-masking": 512,
64
+ "bert-large-cased-whole-word-masking": 512,
65
+ "bert-large-uncased-whole-word-masking-finetuned-squad": 512,
66
+ "bert-large-cased-whole-word-masking-finetuned-squad": 512,
67
+ "bert-base-cased-finetuned-mrpc": 512,
68
+ "bert-base-german-dbmdz-cased": 512,
69
+ "bert-base-german-dbmdz-uncased": 512,
70
+ "TurkuNLP/bert-base-finnish-cased-v1": 512,
71
+ "TurkuNLP/bert-base-finnish-uncased-v1": 512,
72
+ "wietsedv/bert-base-dutch-cased": 512,
73
+ }
74
+
75
+ PRETRAINED_INIT_CONFIGURATION = {
76
+ "bert-base-uncased": {"do_lower_case": True},
77
+ "bert-large-uncased": {"do_lower_case": True},
78
+ "bert-base-cased": {"do_lower_case": False},
79
+ "bert-large-cased": {"do_lower_case": False},
80
+ "bert-base-multilingual-uncased": {"do_lower_case": True},
81
+ "bert-base-multilingual-cased": {"do_lower_case": False},
82
+ "bert-base-chinese": {"do_lower_case": False},
83
+ "bert-base-german-cased": {"do_lower_case": False},
84
+ "bert-large-uncased-whole-word-masking": {"do_lower_case": True},
85
+ "bert-large-cased-whole-word-masking": {"do_lower_case": False},
86
+ "bert-large-uncased-whole-word-masking-finetuned-squad": {"do_lower_case": True},
87
+ "bert-large-cased-whole-word-masking-finetuned-squad": {"do_lower_case": False},
88
+ "bert-base-cased-finetuned-mrpc": {"do_lower_case": False},
89
+ "bert-base-german-dbmdz-cased": {"do_lower_case": False},
90
+ "bert-base-german-dbmdz-uncased": {"do_lower_case": True},
91
+ "TurkuNLP/bert-base-finnish-cased-v1": {"do_lower_case": False},
92
+ "TurkuNLP/bert-base-finnish-uncased-v1": {"do_lower_case": True},
93
+ "wietsedv/bert-base-dutch-cased": {"do_lower_case": False},
94
+ }
95
+
96
+
97
+ def load_vocab(vocab_file):
98
+ """Loads a vocabulary file into a dictionary."""
99
+ vocab = collections.OrderedDict()
100
+ with open(vocab_file, "r", encoding="utf-8") as reader:
101
+ tokens = reader.readlines()
102
+ for index, token in enumerate(tokens):
103
+ token = token.rstrip("\n")
104
+ vocab[token] = index
105
+ return vocab
106
+
107
+
108
+ def whitespace_tokenize(text):
109
+ """Runs basic whitespace cleaning and splitting on a piece of text."""
110
+ text = text.strip()
111
+ if not text:
112
+ return []
113
+ tokens = text.split()
114
+ return tokens
115
+
116
+
117
+ class BertTokenizer(PreTrainedTokenizer):
118
+ r"""
119
+ Constructs a BERT tokenizer. Based on WordPiece.
120
+
121
+ This tokenizer inherits from :class:`~transformers.PreTrainedTokenizer` which contains most of the methods. Users
122
+ should refer to the superclass for more information regarding methods.
123
+
124
+ Args:
125
+ vocab_file (:obj:`string`):
126
+ File containing the vocabulary.
127
+ do_lower_case (:obj:`bool`, `optional`, defaults to :obj:`True`):
128
+ Whether to lowercase the input when tokenizing.
129
+ do_basic_tokenize (:obj:`bool`, `optional`, defaults to :obj:`True`):
130
+ Whether to do basic tokenization before WordPiece.
131
+ never_split (:obj:`Iterable`, `optional`, defaults to :obj:`None`):
132
+ Collection of tokens which will never be split during tokenization. Only has an effect when
133
+ :obj:`do_basic_tokenize=True`
134
+ unk_token (:obj:`string`, `optional`, defaults to "[UNK]"):
135
+ The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
136
+ token instead.
137
+ sep_token (:obj:`string`, `optional`, defaults to "[SEP]"):
138
+ The separator token, which is used when building a sequence from multiple sequences, e.g. two sequences
139
+ for sequence classification or for a text and a question for question answering.
140
+ It is also used as the last token of a sequence built with special tokens.
141
+ pad_token (:obj:`string`, `optional`, defaults to "[PAD]"):
142
+ The token used for padding, for example when batching sequences of different lengths.
143
+ cls_token (:obj:`string`, `optional`, defaults to "[CLS]"):
144
+ The classifier token which is used when doing sequence classification (classification of the whole
145
+ sequence instead of per-token classification). It is the first token of the sequence when built with
146
+ special tokens.
147
+ mask_token (:obj:`string`, `optional`, defaults to "[MASK]"):
148
+ The token used for masking values. This is the token used when training this model with masked language
149
+ modeling. This is the token which the model will try to predict.
150
+ tokenize_chinese_chars (:obj:`bool`, `optional`, defaults to :obj:`True`):
151
+ Whether to tokenize Chinese characters.
152
+ This should likely be deactivated for Japanese:
153
+ see: https://github.com/huggingface/transformers/issues/328
154
+ """
155
+
156
+ vocab_files_names = VOCAB_FILES_NAMES
157
+ pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
158
+ pretrained_init_configuration = PRETRAINED_INIT_CONFIGURATION
159
+ max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
160
+
161
+ def __init__(
162
+ self,
163
+ vocab_file,
164
+ do_lower_case=True,
165
+ do_basic_tokenize=True,
166
+ never_split=None,
167
+ unk_token="[UNK]",
168
+ sep_token="[SEP]",
169
+ pad_token="[PAD]",
170
+ cls_token="[CLS]",
171
+ mask_token="[MASK]",
172
+ tokenize_chinese_chars=True,
173
+ **kwargs
174
+ ):
175
+ super().__init__(
176
+ unk_token=unk_token,
177
+ sep_token=sep_token,
178
+ pad_token=pad_token,
179
+ cls_token=cls_token,
180
+ mask_token=mask_token,
181
+ **kwargs,
182
+ )
183
+
184
+ if not os.path.isfile(vocab_file):
185
+ raise ValueError(
186
+ "Can't find a vocabulary file at path '{}'. To load the vocabulary from a Google pretrained "
187
+ "model use `tokenizer = BertTokenizer.from_pretrained(PRETRAINED_MODEL_NAME)`".format(vocab_file)
188
+ )
189
+ self.vocab = load_vocab(vocab_file)
190
+ self.ids_to_tokens = collections.OrderedDict([(ids, tok) for tok, ids in self.vocab.items()])
191
+ self.do_basic_tokenize = do_basic_tokenize
192
+ if do_basic_tokenize:
193
+ self.basic_tokenizer = BasicTokenizer(
194
+ do_lower_case=do_lower_case, never_split=never_split, tokenize_chinese_chars=tokenize_chinese_chars
195
+ )
196
+ self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab, unk_token=self.unk_token)
197
+
198
+ @property
199
+ def vocab_size(self):
200
+ return len(self.vocab)
201
+
202
+ def get_vocab(self):
203
+ return dict(self.vocab, **self.added_tokens_encoder)
204
+
205
+ def _tokenize(self, text):
206
+ split_tokens = []
207
+ if self.do_basic_tokenize:
208
+ for token in self.basic_tokenizer.tokenize(text, never_split=self.all_special_tokens):
209
+
210
+ # If the token is part of the never_split set
211
+ if token in self.basic_tokenizer.never_split:
212
+ split_tokens.append(token)
213
+ else:
214
+ split_tokens += self.wordpiece_tokenizer.tokenize(token)
215
+ else:
216
+ split_tokens = self.wordpiece_tokenizer.tokenize(text)
217
+ return split_tokens
218
+
219
+ def _convert_token_to_id(self, token):
220
+ """ Converts a token (str) in an id using the vocab. """
221
+ return self.vocab.get(token, self.vocab.get(self.unk_token))
222
+
223
+ def _convert_id_to_token(self, index):
224
+ """Converts an index (integer) in a token (str) using the vocab."""
225
+ return self.ids_to_tokens.get(index, self.unk_token)
226
+
227
+ def convert_tokens_to_string(self, tokens):
228
+ """ Converts a sequence of tokens (string) in a single string. """
229
+ out_string = " ".join(tokens).replace(" ##", "").strip()
230
+ return out_string
231
+
232
+ def build_inputs_with_special_tokens(
233
+ self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
234
+ ) -> List[int]:
235
+ """
236
+ Build model inputs from a sequence or a pair of sequence for sequence classification tasks
237
+ by concatenating and adding special tokens.
238
+ A BERT sequence has the following format:
239
+
240
+ - single sequence: ``[CLS] X [SEP]``
241
+ - pair of sequences: ``[CLS] A [SEP] B [SEP]``
242
+
243
+ Args:
244
+ token_ids_0 (:obj:`List[int]`):
245
+ List of IDs to which the special tokens will be added
246
+ token_ids_1 (:obj:`List[int]`, `optional`, defaults to :obj:`None`):
247
+ Optional second list of IDs for sequence pairs.
248
+
249
+ Returns:
250
+ :obj:`List[int]`: list of `input IDs <../glossary.html#input-ids>`__ with the appropriate special tokens.
251
+ """
252
+ if token_ids_1 is None:
253
+ return [self.cls_token_id] + token_ids_0 + [self.sep_token_id]
254
+ cls = [self.cls_token_id]
255
+ sep = [self.sep_token_id]
256
+ return cls + token_ids_0 + sep + token_ids_1 + sep
257
+
258
+ def get_special_tokens_mask(
259
+ self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False
260
+ ) -> List[int]:
261
+ """
262
+ Retrieves sequence ids from a token list that has no special tokens added. This method is called when adding
263
+ special tokens using the tokenizer ``prepare_for_model`` method.
264
+
265
+ Args:
266
+ token_ids_0 (:obj:`List[int]`):
267
+ List of ids.
268
+ token_ids_1 (:obj:`List[int]`, `optional`, defaults to :obj:`None`):
269
+ Optional second list of IDs for sequence pairs.
270
+ already_has_special_tokens (:obj:`bool`, `optional`, defaults to :obj:`False`):
271
+ Set to True if the token list is already formatted with special tokens for the model
272
+
273
+ Returns:
274
+ :obj:`List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
275
+ """
276
+
277
+ if already_has_special_tokens:
278
+ if token_ids_1 is not None:
279
+ raise ValueError(
280
+ "You should not supply a second sequence if the provided sequence of "
281
+ "ids is already formated with special tokens for the model."
282
+ )
283
+ return list(map(lambda x: 1 if x in [self.sep_token_id, self.cls_token_id] else 0, token_ids_0))
284
+
285
+ if token_ids_1 is not None:
286
+ return [1] + ([0] * len(token_ids_0)) + [1] + ([0] * len(token_ids_1)) + [1]
287
+ return [1] + ([0] * len(token_ids_0)) + [1]
288
+
289
+ def create_token_type_ids_from_sequences(
290
+ self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
291
+ ) -> List[int]:
292
+ """
293
+ Creates a mask from the two sequences passed to be used in a sequence-pair classification task.
294
+ A BERT sequence pair mask has the following format:
295
+
296
+ ::
297
+
298
+ 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1
299
+ | first sequence | second sequence |
300
+
301
+ if token_ids_1 is None, only returns the first portion of the mask (0's).
302
+
303
+ Args:
304
+ token_ids_0 (:obj:`List[int]`):
305
+ List of ids.
306
+ token_ids_1 (:obj:`List[int]`, `optional`, defaults to :obj:`None`):
307
+ Optional second list of IDs for sequence pairs.
308
+
309
+ Returns:
310
+ :obj:`List[int]`: List of `token type IDs <../glossary.html#token-type-ids>`_ according to the given
311
+ sequence(s).
312
+ """
313
+ sep = [self.sep_token_id]
314
+ cls = [self.cls_token_id]
315
+ if token_ids_1 is None:
316
+ return len(cls + token_ids_0 + sep) * [0]
317
+ return len(cls + token_ids_0 + sep) * [0] + len(token_ids_1 + sep) * [1]
318
+
319
+ def save_vocabulary(self, vocab_path):
320
+ """
321
+ Save the tokenizer's WordPiece vocabulary (one token per line) and special tokens file to a directory.
322
+
323
+ Args:
324
+ vocab_path (:obj:`str`):
325
+ The directory in which to save the vocabulary.
326
+
327
+ Returns:
328
+ :obj:`Tuple(str)`: Paths to the files saved.
329
+ """
330
+ index = 0
331
+ if os.path.isdir(vocab_path):
332
+ vocab_file = os.path.join(vocab_path, VOCAB_FILES_NAMES["vocab_file"])
333
+ else:
334
+ vocab_file = vocab_path
335
+ with open(vocab_file, "w", encoding="utf-8") as writer:
336
+ for token, token_index in sorted(self.vocab.items(), key=lambda kv: kv[1]):
337
+ if index != token_index:
338
+ logger.warning(
339
+ "Saving vocabulary to {}: vocabulary indices are not consecutive."
340
+ " Please check that the vocabulary is not corrupted!".format(vocab_file)
341
+ )
342
+ index = token_index
343
+ writer.write(token + "\n")
344
+ index += 1
345
+ return (vocab_file,)
346
+
347
+
348
+ class BasicTokenizer(object):
349
+ """Runs basic tokenization (punctuation splitting, lower casing, etc.)."""
350
+
351
+ def __init__(self, do_lower_case=True, never_split=None, tokenize_chinese_chars=True):
352
+ """ Constructs a BasicTokenizer.
353
+
354
+ Args:
355
+ **do_lower_case**: Whether to lower case the input.
356
+ **never_split**: (`optional`) list of str
357
+ Kept for backward compatibility purposes.
358
+ Now implemented directly at the base class level (see :func:`PreTrainedTokenizer.tokenize`)
359
+ List of tokens not to split.
360
+ **tokenize_chinese_chars**: (`optional`) boolean (default True)
361
+ Whether to tokenize Chinese characters.
362
+ This should likely be deactivated for Japanese:
363
+ see: https://github.com/huggingface/pytorch-pretrained-BERT/issues/328
364
+ """
365
+ if never_split is None:
366
+ never_split = []
367
+ self.do_lower_case = do_lower_case
368
+ self.never_split = set(never_split)
369
+ self.tokenize_chinese_chars = tokenize_chinese_chars
370
+
371
+ def tokenize(self, text, never_split=None):
372
+ """ Basic Tokenization of a piece of text.
373
+ Split on "white spaces" only, for sub-word tokenization, see WordPieceTokenizer.
374
+
375
+ Args:
376
+ **never_split**: (`optional`) list of str
377
+ Kept for backward compatibility purposes.
378
+ Now implemented directly at the base class level (see :func:`PreTrainedTokenizer.tokenize`)
379
+ List of tokens not to split.
380
+ """
381
+ # union() returns a new set containing the elements of both sets.
382
+ never_split = self.never_split.union(set(never_split)) if never_split else self.never_split
383
+
384
+ # This was added on November 1st, 2018 for the multilingual and Chinese
385
+ # models. This is also applied to the English models now, but it doesn't
386
+ # matter since the English models were not trained on any Chinese data
387
+ # and generally don't have any Chinese data in them (there are Chinese
388
+ # characters in the vocabulary because Wikipedia does have some Chinese
389
+ # words in the English Wikipedia.).
390
+ if self.tokenize_chinese_chars:
391
+ text = self._tokenize_chinese_chars(text)
392
+ orig_tokens = whitespace_tokenize(text)
393
+ split_tokens = []
394
+ for token in orig_tokens:
395
+ if self.do_lower_case and token not in never_split:
396
+ token = token.lower()
397
+ token = self._run_strip_accents(token)
398
+ split_tokens.extend(self._run_split_on_punc(token, never_split))
399
+
400
+ output_tokens = whitespace_tokenize(" ".join(split_tokens))
401
+ return output_tokens
402
+
403
+ def _run_strip_accents(self, text):
404
+ """Strips accents from a piece of text."""
405
+ text = unicodedata.normalize("NFD", text)
406
+ output = []
407
+ for char in text:
408
+ cat = unicodedata.category(char)
409
+ if cat == "Mn":
410
+ continue
411
+ output.append(char)
412
+ return "".join(output)
413
+
414
+ def _run_split_on_punc(self, text, never_split=None):
415
+ """Splits punctuation on a piece of text."""
416
+ if never_split is not None and text in never_split:
417
+ return [text]
418
+ chars = list(text)
419
+ i = 0
420
+ start_new_word = True
421
+ output = []
422
+ while i < len(chars):
423
+ char = chars[i]
424
+ if _is_punctuation(char):
425
+ output.append([char])
426
+ start_new_word = True
427
+ else:
428
+ if start_new_word:
429
+ output.append([])
430
+ start_new_word = False
431
+ output[-1].append(char)
432
+ i += 1
433
+
434
+ return ["".join(x) for x in output]
435
+
436
+ def _tokenize_chinese_chars(self, text):
437
+ """Adds whitespace around any CJK character."""
438
+ output = []
439
+ for char in text:
440
+ cp = ord(char)
441
+ if self._is_chinese_char(cp):
442
+ output.append(" ")
443
+ output.append(char)
444
+ output.append(" ")
445
+ else:
446
+ output.append(char)
447
+ return "".join(output)
448
+
449
+ def _is_chinese_char(self, cp):
450
+ """Checks whether CP is the codepoint of a CJK character."""
451
+ # This defines a "chinese character" as anything in the CJK Unicode block:
452
+ # https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block)
453
+ #
454
+ # Note that the CJK Unicode block is NOT all Japanese and Korean characters,
455
+ # despite its name. The modern Korean Hangul alphabet is a different block,
456
+ # as is Japanese Hiragana and Katakana. Those alphabets are used to write
457
+ # space-separated words, so they are not treated specially and handled
458
+ like all of the other languages.
459
+ if (
460
+ (cp >= 0x4E00 and cp <= 0x9FFF)
461
+ or (cp >= 0x3400 and cp <= 0x4DBF) #
462
+ or (cp >= 0x20000 and cp <= 0x2A6DF) #
463
+ or (cp >= 0x2A700 and cp <= 0x2B73F) #
464
+ or (cp >= 0x2B740 and cp <= 0x2B81F) #
465
+ or (cp >= 0x2B820 and cp <= 0x2CEAF) #
466
+ or (cp >= 0xF900 and cp <= 0xFAFF)
467
+ or (cp >= 0x2F800 and cp <= 0x2FA1F) #
468
+ ): #
469
+ return True
470
+
471
+ return False
472
+
473
+ def _clean_text(self, text):
474
+ """Performs invalid character removal and whitespace cleanup on text."""
475
+ output = []
476
+ for char in text:
477
+ cp = ord(char)
478
+ if cp == 0 or cp == 0xFFFD or _is_control(char):
479
+ continue
480
+ if _is_whitespace(char):
481
+ output.append(" ")
482
+ else:
483
+ output.append(char)
484
+ return "".join(output)
485
+
486
+
487
+ class WordpieceTokenizer(object):
488
+ """Runs WordPiece tokenization."""
489
+
490
+ def __init__(self, vocab, unk_token, max_input_chars_per_word=100):
491
+ self.vocab = vocab
492
+ self.unk_token = unk_token
493
+ self.max_input_chars_per_word = max_input_chars_per_word
494
+
495
+ def tokenize(self, text):
496
+ """Tokenizes a piece of text into its word pieces.
497
+
498
+ This uses a greedy longest-match-first algorithm to perform tokenization
499
+ using the given vocabulary.
500
+
501
+ For example:
502
+ input = "unaffable"
503
+ output = ["un", "##aff", "##able"]
504
+
505
+ Args:
506
+ text: A single token or whitespace separated tokens. This should have
507
+ already been passed through `BasicTokenizer`.
508
+
509
+ Returns:
510
+ A list of wordpiece tokens.
511
+ """
512
+
513
+ output_tokens = []
514
+ for token in whitespace_tokenize(text):
515
+ chars = list(token)
516
+ if len(chars) > self.max_input_chars_per_word:
517
+ output_tokens.append(self.unk_token)
518
+ continue
519
+
520
+ is_bad = False
521
+ start = 0
522
+ sub_tokens = []
523
+ while start < len(chars):
524
+ end = len(chars)
525
+ cur_substr = None
526
+ while start < end:
527
+ substr = "".join(chars[start:end])
528
+ if start > 0:
529
+ substr = "##" + substr
530
+ if substr in self.vocab:
531
+ cur_substr = substr
532
+ break
533
+ end -= 1
534
+ if cur_substr is None:
535
+ is_bad = True
536
+ break
537
+ sub_tokens.append(cur_substr)
538
+ start = end
539
+
540
+ if is_bad:
541
+ output_tokens.append(self.unk_token)
542
+ else:
543
+ output_tokens.extend(sub_tokens)
544
+ return output_tokens
545
+
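The greedy longest-match-first lookup described in `WordpieceTokenizer.tokenize` above can be exercised on its own. The snippet below is a minimal standalone sketch with a made-up toy vocabulary; it is not part of the committed file, and real models load their vocabulary from a `vocab.txt` file instead.

```python
# Minimal sketch of the greedy longest-match-first WordPiece lookup
# described in WordpieceTokenizer.tokenize above. The toy vocabulary is
# illustrative only; real tokenizers read their vocabulary from vocab.txt.
toy_vocab = {"un", "##aff", "##able", "runn", "##ing"}

def wordpiece(token, vocab=toy_vocab, unk_token="[UNK]"):
    chars = list(token)
    pieces, start = [], 0
    while start < len(chars):
        end = len(chars)
        cur = None
        while start < end:                  # shrink the candidate span from the right
            sub = "".join(chars[start:end])
            if start > 0:
                sub = "##" + sub            # continuation pieces carry the ## prefix
            if sub in vocab:
                cur = sub
                break
            end -= 1
        if cur is None:                     # no piece matched: fall back to [UNK]
            return [unk_token]
        pieces.append(cur)
        start = end
    return pieces

print(wordpiece("unaffable"))  # ['un', '##aff', '##able']
print(wordpiece("running"))    # ['runn', '##ing']
print(wordpiece("xyz"))        # ['[UNK]']
```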
elia/bert/tokenization_utils.py ADDED
@@ -0,0 +1,723 @@
1
+ # coding=utf-8
2
+ # Copyright 2020 The HuggingFace Inc. team.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """ Tokenization classes for python tokenizers.
16
+ For fast tokenizers (provided by HuggingFace's tokenizers library) see tokenization_utils_fast.py
17
+ """
18
+
19
+ import itertools
20
+ import logging
21
+ import re
22
+ import unicodedata
23
+ from typing import Dict, List, Optional, Tuple, Union
24
+
25
+ from .file_utils import add_end_docstrings
26
+ from .tokenization_utils_base import (
27
+ ENCODE_KWARGS_DOCSTRING,
28
+ ENCODE_PLUS_ADDITIONAL_KWARGS_DOCSTRING,
29
+ AddedToken,
30
+ BatchEncoding,
31
+ EncodedInput,
32
+ EncodedInputPair,
33
+ PaddingStrategy,
34
+ PreTokenizedInput,
35
+ PreTokenizedInputPair,
36
+ PreTrainedTokenizerBase,
37
+ TensorType,
38
+ TextInput,
39
+ TextInputPair,
40
+ TruncationStrategy,
41
+ )
42
+
43
+
44
+ logger = logging.getLogger(__name__)
45
+
46
+
47
+ def _is_whitespace(char):
48
+ """Checks whether `chars` is a whitespace character."""
49
+ # \t, \n, and \r are technically control characters but we treat them
50
+ # as whitespace since they are generally considered as such.
51
+ if char == " " or char == "\t" or char == "\n" or char == "\r":
52
+ return True
53
+ cat = unicodedata.category(char)
54
+ if cat == "Zs":
55
+ return True
56
+ return False
57
+
58
+
59
+ def _is_control(char):
60
+ """Checks whether `chars` is a control character."""
61
+ # These are technically control characters but we count them as whitespace
62
+ # characters.
63
+ if char == "\t" or char == "\n" or char == "\r":
64
+ return False
65
+ cat = unicodedata.category(char)
66
+ if cat.startswith("C"):
67
+ return True
68
+ return False
69
+
70
+
71
+ def _is_punctuation(char):
72
+ """Checks whether `chars` is a punctuation character."""
73
+ cp = ord(char)
74
+ # We treat all non-letter/number ASCII as punctuation.
75
+ # Characters such as "^", "$", and "`" are not in the Unicode
76
+ # Punctuation class but we treat them as punctuation anyways, for
77
+ # consistency.
78
+ if (cp >= 33 and cp <= 47) or (cp >= 58 and cp <= 64) or (cp >= 91 and cp <= 96) or (cp >= 123 and cp <= 126):
79
+ return True
80
+ cat = unicodedata.category(char)
81
+ if cat.startswith("P"):
82
+ return True
83
+ return False
84
+
85
+
86
+ def _is_end_of_word(text):
87
+ """Checks whether the last character in text is one of a punctuation, control or whitespace character."""
88
+ last_char = text[-1]
89
+ return bool(_is_control(last_char) | _is_punctuation(last_char) | _is_whitespace(last_char))
90
+
91
+
92
+ def _is_start_of_word(text):
93
+ """Checks whether the first character in text is one of a punctuation, control or whitespace character."""
94
+ first_char = text[0]
95
+ return bool(_is_control(first_char) | _is_punctuation(first_char) | _is_whitespace(first_char))
96
+
97
+
98
+ class PreTrainedTokenizer(PreTrainedTokenizerBase):
99
+ """ Base class for all slow tokenizers.
100
+
101
+ Handles all the shared methods for tokenization and special tokens, as well as methods for
102
+ downloading/caching/loading pretrained tokenizers and for adding tokens to the vocabulary.
103
+
104
+ This class also contains the added tokens in a unified way on top of all tokenizers so we don't
105
+ have to handle the specific vocabulary augmentation methods of the various underlying
106
+ dictionary structures (BPE, sentencepiece...).
107
+
108
+ Class attributes (overridden by derived classes):
109
+
110
+ - ``vocab_files_names``: a python ``dict`` with, as keys, the ``__init__`` keyword name of each vocabulary file
111
+ required by the model, and as associated values, the filename for saving the associated file (string).
112
+ - ``pretrained_vocab_files_map``: a python ``dict of dict`` the high-level keys
113
+ being the ``__init__`` keyword name of each vocabulary file required by the model, the low-level being the
114
+ `short-cut-names` (string) of the pretrained models with, as associated values, the `url` (string) to the
115
+ associated pretrained vocabulary file.
116
+ - ``max_model_input_sizes``: a python ``dict`` with, as keys, the `short-cut-names` (string) of the pretrained
117
+ models, and as associated values, the maximum length of the sequence inputs of this model, or None if the
118
+ model has no maximum input size.
119
+ - ``pretrained_init_configuration``: a python ``dict`` with, as keys, the `short-cut-names` (string) of the
120
+ pretrained models, and as associated values, a dictionary of specific arguments to pass to the
121
+ ``__init__``method of the tokenizer class for this pretrained model when loading the tokenizer with the
122
+ ``from_pretrained()`` method.
123
+
124
+ Args:
125
+ - ``model_max_length``: (`Optional`) int: the maximum length in number of tokens for the inputs to the transformer model.
126
+ When the tokenizer is loaded with `from_pretrained`, this will be set to the value stored for the associated
127
+ model in ``max_model_input_sizes`` (see above). If no value is provided, or if no associated
128
+ max_length can be found in ``max_model_input_sizes``, this will default to VERY_LARGE_INTEGER (`int(1e30)`).
129
+ - ``padding_side``: (`Optional`) string: the side on which the model should have padding applied.
130
+ Should be selected between ['right', 'left']
131
+ - ``model_input_names``: (`Optional`) List[string]: the list of the forward pass inputs accepted by the
132
+ model ("token_type_ids", "attention_mask"...).
133
+ - ``bos_token``: (`Optional`) string: a beginning of sentence token.
134
+ Will be associated to ``self.bos_token`` and ``self.bos_token_id``
135
+ - ``eos_token``: (`Optional`) string: an end of sentence token.
136
+ Will be associated to ``self.eos_token`` and ``self.eos_token_id``
137
+ - ``unk_token``: (`Optional`) string: an unknown token.
138
+ Will be associated to ``self.unk_token`` and ``self.unk_token_id``
139
+ - ``sep_token``: (`Optional`) string: a separation token (e.g. to separate context and query in an input sequence).
140
+ Will be associated to ``self.sep_token`` and ``self.sep_token_id``
141
+ - ``pad_token``: (`Optional`) string: a padding token.
142
+ Will be associated to ``self.pad_token`` and ``self.pad_token_id``
143
+ - ``cls_token``: (`Optional`) string: a classification token (e.g. to extract a summary of an input sequence
144
+ leveraging self-attention along the full depth of the model).
145
+ Will be associated to ``self.cls_token`` and ``self.cls_token_id``
146
+ - ``mask_token``: (`Optional`) string: a masking token (e.g. when training a model with masked-language
147
+ modeling). Will be associated to ``self.mask_token`` and ``self.mask_token_id``
148
+ - ``additional_special_tokens``: (`Optional`) list: a list of additional special tokens.
149
+ Adding all special tokens here ensures they won't be split by the tokenization process.
150
+ Will be associated to ``self.additional_special_tokens`` and ``self.additional_special_tokens_ids``
151
+
152
+
153
+ .. automethod:: __call__
154
+ """
155
+
156
+ def __init__(self, **kwargs):
157
+ super().__init__(**kwargs)
158
+
159
+ # Added tokens - We store this for both slow and fast tokenizers
160
+ # until the serialization of Fast tokenizers is updated
161
+ self.added_tokens_encoder: Dict[str, int] = {}
162
+ self.added_tokens_decoder: Dict[int, str] = {}
163
+ self.unique_no_split_tokens: List[str] = []
164
+
165
+ @property
166
+ def is_fast(self) -> bool:
167
+ return False
168
+
169
+ @property
170
+ def vocab_size(self) -> int:
171
+ """ Size of the base vocabulary (without the added tokens) """
172
+ raise NotImplementedError
173
+
174
+ def get_vocab(self):
175
+ """ Returns the vocabulary as a dict of {token: index} pairs. `tokenizer.get_vocab()[token]` is equivalent to `tokenizer.convert_tokens_to_ids(token)` when `token` is in the vocab. """
176
+ raise NotImplementedError()
177
+
178
+ def get_added_vocab(self) -> Dict[str, int]:
179
+ return self.added_tokens_encoder
180
+
181
+ def __len__(self):
182
+ """ Size of the full vocabulary with the added tokens """
183
+ return self.vocab_size + len(self.added_tokens_encoder)
184
+
185
+ def _add_tokens(self, new_tokens: Union[List[str], List[AddedToken]], special_tokens=False) -> int:
186
+ """
187
+ Add a list of new tokens to the tokenizer class. If the new tokens are not in the
188
+ vocabulary, they are added to it with indices starting from length of the current vocabulary.
189
+
190
+ Args:
191
+ new_tokens: string or list of string. Each string is a token to add. Tokens are only added if they are not
192
+ already in the vocabulary (tested by checking if the tokenizer assigns the index of the ``unk_token`` to them).
193
+
194
+ Returns:
195
+ Number of tokens added to the vocabulary.
196
+
197
+ Examples::
198
+
199
+ # Let's see how to increase the vocabulary of Bert model and tokenizer
200
+ tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
201
+ model = BertModel.from_pretrained('bert-base-uncased')
202
+
203
+ num_added_toks = tokenizer.add_tokens(['new_tok1', 'my_new-tok2'])
204
+ print('We have added', num_added_toks, 'tokens')
205
+ model.resize_token_embeddings(len(tokenizer)) # Notice: resize_token_embeddings expect to receive the full size of the new vocabulary, i.e. the length of the tokenizer.
206
+ """
207
+ new_tokens = [str(tok) for tok in new_tokens]
208
+
209
+ tokens_to_add = []
210
+ for token in new_tokens:
211
+ assert isinstance(token, str)
212
+ if not special_tokens and self.init_kwargs.get("do_lower_case", False):
213
+ token = token.lower()
214
+ if (
215
+ token != self.unk_token
216
+ and self.convert_tokens_to_ids(token) == self.convert_tokens_to_ids(self.unk_token)
217
+ and token not in tokens_to_add
218
+ ):
219
+ tokens_to_add.append(token)
220
+ if self.verbose:
221
+ logger.info("Adding %s to the vocabulary", token)
222
+
223
+ added_tok_encoder = dict((tok, len(self) + i) for i, tok in enumerate(tokens_to_add))
224
+ added_tok_decoder = {v: k for k, v in added_tok_encoder.items()}
225
+ self.added_tokens_encoder.update(added_tok_encoder)
226
+ self.added_tokens_decoder.update(added_tok_decoder)
227
+
228
+ # Make sure we don't split on any special tokens (even they were already in the vocab before e.g. for Albert)
229
+ if special_tokens:
230
+ self.unique_no_split_tokens = list(set(self.unique_no_split_tokens).union(set(new_tokens)))
231
+ else:
232
+ # Or on the newly added tokens
233
+ self.unique_no_split_tokens = list(set(self.unique_no_split_tokens).union(set(tokens_to_add)))
234
+
235
+ return len(tokens_to_add)
236
+
237
+ def num_special_tokens_to_add(self, pair=False):
238
+ """
239
+ Returns the number of added tokens when encoding a sequence with special tokens.
240
+
241
+ Note:
242
+ This encodes inputs and checks the number of added tokens, and is therefore not efficient. Do not put this
243
+ inside your training loop.
244
+
245
+ Args:
246
+ pair: Returns the number of added tokens in the case of a sequence pair if set to True, returns the
247
+ number of added tokens in the case of a single sequence if set to False.
248
+
249
+ Returns:
250
+ Number of tokens added to sequences
251
+ """
252
+ token_ids_0 = []
253
+ token_ids_1 = []
254
+ return len(self.build_inputs_with_special_tokens(token_ids_0, token_ids_1 if pair else None))
255
+
256
+ def tokenize(self, text: TextInput, **kwargs):
257
+ """ Converts a string in a sequence of tokens (string), using the tokenizer.
258
+ Split in words for word-based vocabulary or sub-words for sub-word-based
259
+ vocabularies (BPE/SentencePieces/WordPieces).
260
+
261
+ Take care of added tokens.
262
+
263
+ Args:
264
+ text (:obj:`string`): The sequence to be encoded.
265
+ **kwargs (:obj: `dict`): Arguments passed to the model-specific `prepare_for_tokenization` preprocessing method.
266
+ """
267
+ # Simple mapping string => AddedToken for special tokens with specific tokenization behaviors
268
+ all_special_tokens_extended = dict(
269
+ (str(t), t) for t in self.all_special_tokens_extended if isinstance(t, AddedToken)
270
+ )
271
+
272
+ text, kwargs = self.prepare_for_tokenization(text, **kwargs)
273
+
274
+ if kwargs:
275
+ logger.warning(f"Keyword arguments {kwargs} not recognized.")
276
+
277
+ # TODO: should this be in the base class?
278
+ if self.init_kwargs.get("do_lower_case", False):
279
+ # convert non-special tokens to lowercase
280
+ escaped_special_toks = [re.escape(s_tok) for s_tok in self.all_special_tokens]
281
+ pattern = r"(" + r"|".join(escaped_special_toks) + r")|" + r"(.+?)"
282
+ text = re.sub(pattern, lambda m: m.groups()[0] or m.groups()[1].lower(), text)
283
+
284
+ def split_on_token(tok, text):
285
+ result = []
286
+ tok_extended = all_special_tokens_extended.get(tok, None)
287
+ split_text = text.split(tok)
288
+ full_word = ""
289
+ for i, sub_text in enumerate(split_text):
290
+ # AddedToken can control whitespace stripping around them.
291
+ # We use them for GPT2 and Roberta to have different behavior depending on the special token
292
+ # Cf. https://github.com/huggingface/transformers/pull/2778
293
+ # and https://github.com/huggingface/transformers/issues/3788
294
+ if isinstance(tok_extended, AddedToken):
295
+ if tok_extended.single_word:
296
+ # Try to avoid splitting on token
297
+ if (
298
+ i < len(split_text) - 1
299
+ and not _is_end_of_word(sub_text)
300
+ and not _is_start_of_word(split_text[i + 1])
301
+ ):
302
+ # Don't extract the special token
303
+ full_word += sub_text + tok
304
+ elif full_word:
305
+ full_word += sub_text
306
+ result += [full_word]
307
+ full_word = ""
308
+ continue
309
+ # Strip white spaces on the right
310
+ if tok_extended.rstrip and i > 0:
311
+ # A bit counter-intuitive but we strip the left of the string
312
+ # since tok_extended.rstrip means the special token is eating all white spaces on its right
313
+ sub_text = sub_text.lstrip()
314
+ # Strip white spaces on the left
315
+ if tok_extended.lstrip and i < len(split_text) - 1:
316
+ sub_text = sub_text.rstrip() # Opposite here
317
+ else:
318
+ # We strip left and right by default
319
+ if i < len(split_text) - 1:
320
+ sub_text = sub_text.rstrip()
321
+ if i > 0:
322
+ sub_text = sub_text.lstrip()
323
+
324
+ if i == 0 and not sub_text:
325
+ result += [tok]
326
+ elif i == len(split_text) - 1:
327
+ if sub_text:
328
+ result += [sub_text]
329
+ else:
330
+ pass
331
+ else:
332
+ if sub_text:
333
+ result += [sub_text]
334
+ result += [tok]
335
+ return result
336
+
337
+ def split_on_tokens(tok_list, text):
338
+ if not text.strip():
339
+ return []
340
+ if not tok_list:
341
+ return self._tokenize(text)
342
+
343
+ tokenized_text = []
344
+ text_list = [text]
345
+ for tok in tok_list:
346
+ tokenized_text = []
347
+ for sub_text in text_list:
348
+ if sub_text not in self.unique_no_split_tokens:
349
+ tokenized_text += split_on_token(tok, sub_text)
350
+ else:
351
+ tokenized_text += [sub_text]
352
+ text_list = tokenized_text
353
+
354
+ return list(
355
+ itertools.chain.from_iterable(
356
+ (
357
+ self._tokenize(token) if token not in self.unique_no_split_tokens else [token]
358
+ for token in tokenized_text
359
+ )
360
+ )
361
+ )
362
+
363
+ no_split_token = self.unique_no_split_tokens
364
+ tokenized_text = split_on_tokens(no_split_token, text)
365
+ return tokenized_text
366
+
367
+ def _tokenize(self, text, **kwargs):
368
+ """ Converts a string in a sequence of tokens (string), using the tokenizer.
369
+ Split in words for word-based vocabulary or sub-words for sub-word-based
370
+ vocabularies (BPE/SentencePieces/WordPieces).
371
+
372
+ Do NOT take care of added tokens.
373
+ """
374
+ raise NotImplementedError
375
+
376
+ def convert_tokens_to_ids(self, tokens):
377
+ """ Converts a token string (or a sequence of tokens) in a single integer id
378
+ (or a sequence of ids), using the vocabulary.
379
+ """
380
+ if tokens is None:
381
+ return None
382
+
383
+ if isinstance(tokens, str):
384
+ return self._convert_token_to_id_with_added_voc(tokens)
385
+
386
+ ids = []
387
+ for token in tokens:
388
+ ids.append(self._convert_token_to_id_with_added_voc(token))
389
+ return ids
390
+
391
+ def _convert_token_to_id_with_added_voc(self, token):
392
+ if token is None:
393
+ return None
394
+
395
+ if token in self.added_tokens_encoder:
396
+ return self.added_tokens_encoder[token]
397
+ return self._convert_token_to_id(token)
398
+
399
+ def _convert_token_to_id(self, token):
400
+ raise NotImplementedError
401
+
402
+ def _encode_plus(
403
+ self,
404
+ text: Union[TextInput, PreTokenizedInput, EncodedInput],
405
+ text_pair: Optional[Union[TextInput, PreTokenizedInput, EncodedInput]] = None,
406
+ add_special_tokens: bool = True,
407
+ padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD,
408
+ truncation_strategy: TruncationStrategy = TruncationStrategy.DO_NOT_TRUNCATE,
409
+ max_length: Optional[int] = None,
410
+ stride: int = 0,
411
+ is_pretokenized: bool = False,
412
+ pad_to_multiple_of: Optional[int] = None,
413
+ return_tensors: Optional[Union[str, TensorType]] = None,
414
+ return_token_type_ids: Optional[bool] = None,
415
+ return_attention_mask: Optional[bool] = None,
416
+ return_overflowing_tokens: bool = False,
417
+ return_special_tokens_mask: bool = False,
418
+ return_offsets_mapping: bool = False,
419
+ return_length: bool = False,
420
+ verbose: bool = True,
421
+ **kwargs
422
+ ) -> BatchEncoding:
423
+ def get_input_ids(text):
424
+ if isinstance(text, str):
425
+ tokens = self.tokenize(text, **kwargs)
426
+ return self.convert_tokens_to_ids(tokens)
427
+ elif isinstance(text, (list, tuple)) and len(text) > 0 and isinstance(text[0], str):
428
+ if is_pretokenized:
429
+ tokens = list(itertools.chain(*(self.tokenize(t, is_pretokenized=True, **kwargs) for t in text)))
430
+ return self.convert_tokens_to_ids(tokens)
431
+ else:
432
+ return self.convert_tokens_to_ids(text)
433
+ elif isinstance(text, (list, tuple)) and len(text) > 0 and isinstance(text[0], int):
434
+ return text
435
+ else:
436
+ if is_pretokenized:
437
+ raise ValueError(
438
+ f"Input {text} is not valid. Should be a string or a list/tuple of strings when `is_pretokenized=True`."
439
+ )
440
+ else:
441
+ raise ValueError(
442
+ f"Input {text} is not valid. Should be a string, a list/tuple of strings or a list/tuple of integers."
443
+ )
444
+
445
+ if return_offsets_mapping:
446
+ raise NotImplementedError(
447
+ "return_offset_mapping is not available when using Python tokenizers."
448
+ "To use this feature, change your tokenizer to one deriving from "
449
+ "transformers.PreTrainedTokenizerFast."
450
+ "More information on available tokenizers at "
451
+ "https://github.com/huggingface/transformers/pull/2674"
452
+ )
453
+
454
+ first_ids = get_input_ids(text)
455
+ second_ids = get_input_ids(text_pair) if text_pair is not None else None
456
+
457
+ return self.prepare_for_model(
458
+ first_ids,
459
+ pair_ids=second_ids,
460
+ add_special_tokens=add_special_tokens,
461
+ padding=padding_strategy.value,
462
+ truncation=truncation_strategy.value,
463
+ max_length=max_length,
464
+ stride=stride,
465
+ pad_to_multiple_of=pad_to_multiple_of,
466
+ return_tensors=return_tensors,
467
+ prepend_batch_axis=True,
468
+ return_attention_mask=return_attention_mask,
469
+ return_token_type_ids=return_token_type_ids,
470
+ return_overflowing_tokens=return_overflowing_tokens,
471
+ return_special_tokens_mask=return_special_tokens_mask,
472
+ return_length=return_length,
473
+ verbose=verbose,
474
+ )
475
+
476
+ def _batch_encode_plus(
477
+ self,
478
+ batch_text_or_text_pairs: Union[
479
+ List[TextInput],
480
+ List[TextInputPair],
481
+ List[PreTokenizedInput],
482
+ List[PreTokenizedInputPair],
483
+ List[EncodedInput],
484
+ List[EncodedInputPair],
485
+ ],
486
+ add_special_tokens: bool = True,
487
+ padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD,
488
+ truncation_strategy: TruncationStrategy = TruncationStrategy.DO_NOT_TRUNCATE,
489
+ max_length: Optional[int] = None,
490
+ stride: int = 0,
491
+ is_pretokenized: bool = False,
492
+ pad_to_multiple_of: Optional[int] = None,
493
+ return_tensors: Optional[Union[str, TensorType]] = None,
494
+ return_token_type_ids: Optional[bool] = None,
495
+ return_attention_mask: Optional[bool] = None,
496
+ return_overflowing_tokens: bool = False,
497
+ return_special_tokens_mask: bool = False,
498
+ return_offsets_mapping: bool = False,
499
+ return_length: bool = False,
500
+ verbose: bool = True,
501
+ **kwargs
502
+ ) -> BatchEncoding:
503
+ def get_input_ids(text):
504
+ if isinstance(text, str):
505
+ tokens = self.tokenize(text, **kwargs)
506
+ return self.convert_tokens_to_ids(tokens)
507
+ elif isinstance(text, (list, tuple)) and len(text) > 0 and isinstance(text[0], str):
508
+ if is_pretokenized:
509
+ tokens = list(itertools.chain(*(self.tokenize(t, is_pretokenized=True, **kwargs) for t in text)))
510
+ return self.convert_tokens_to_ids(tokens)
511
+ else:
512
+ return self.convert_tokens_to_ids(text)
513
+ elif isinstance(text, (list, tuple)) and len(text) > 0 and isinstance(text[0], int):
514
+ return text
515
+ else:
516
+ raise ValueError(
517
+ "Input is not valid. Should be a string, a list/tuple of strings or a list/tuple of integers."
518
+ )
519
+
520
+ if return_offsets_mapping:
521
+ raise NotImplementedError(
522
+ "return_offset_mapping is not available when using Python tokenizers."
523
+ "To use this feature, change your tokenizer to one deriving from "
524
+ "transformers.PreTrainedTokenizerFast."
525
+ )
526
+
527
+ input_ids = []
528
+ for ids_or_pair_ids in batch_text_or_text_pairs:
529
+ if not isinstance(ids_or_pair_ids, (list, tuple)):
530
+ ids, pair_ids = ids_or_pair_ids, None
531
+ elif is_pretokenized and not isinstance(ids_or_pair_ids[0], (list, tuple)):
532
+ ids, pair_ids = ids_or_pair_ids, None
533
+ else:
534
+ ids, pair_ids = ids_or_pair_ids
535
+
536
+ first_ids = get_input_ids(ids)
537
+ second_ids = get_input_ids(pair_ids) if pair_ids is not None else None
538
+ input_ids.append((first_ids, second_ids))
539
+
540
+ batch_outputs = self._batch_prepare_for_model(
541
+ input_ids,
542
+ add_special_tokens=add_special_tokens,
543
+ padding_strategy=padding_strategy,
544
+ truncation_strategy=truncation_strategy,
545
+ max_length=max_length,
546
+ stride=stride,
547
+ pad_to_multiple_of=pad_to_multiple_of,
548
+ return_attention_mask=return_attention_mask,
549
+ return_token_type_ids=return_token_type_ids,
550
+ return_overflowing_tokens=return_overflowing_tokens,
551
+ return_special_tokens_mask=return_special_tokens_mask,
552
+ return_length=return_length,
553
+ return_tensors=return_tensors,
554
+ verbose=verbose,
555
+ )
556
+
557
+ return BatchEncoding(batch_outputs)
558
+
559
+ @add_end_docstrings(ENCODE_KWARGS_DOCSTRING, ENCODE_PLUS_ADDITIONAL_KWARGS_DOCSTRING)
560
+ def _batch_prepare_for_model(
561
+ self,
562
+ batch_ids_pairs: List[Union[PreTokenizedInputPair, Tuple[List[int], None]]],
563
+ add_special_tokens: bool = True,
564
+ padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD,
565
+ truncation_strategy: TruncationStrategy = TruncationStrategy.DO_NOT_TRUNCATE,
566
+ max_length: Optional[int] = None,
567
+ stride: int = 0,
568
+ pad_to_multiple_of: Optional[int] = None,
569
+ return_tensors: Optional[str] = None,
570
+ return_token_type_ids: Optional[bool] = None,
571
+ return_attention_mask: Optional[bool] = None,
572
+ return_overflowing_tokens: bool = False,
573
+ return_special_tokens_mask: bool = False,
574
+ return_length: bool = False,
575
+ verbose: bool = True,
576
+ ) -> BatchEncoding:
577
+ """ Prepares a sequence of input id, or a pair of sequences of inputs ids so that it can be used by the model.
578
+ It adds special tokens, truncates sequences if overflowing while taking into account the special tokens and
579
+ manages a moving window (with user defined stride) for overflowing tokens
580
+
581
+ Args:
582
+ batch_ids_pairs: list of tokenized input ids or input ids pairs
583
+ """
584
+
585
+ batch_outputs = {}
586
+ for first_ids, second_ids in batch_ids_pairs:
587
+ outputs = self.prepare_for_model(
588
+ first_ids,
589
+ second_ids,
590
+ add_special_tokens=add_special_tokens,
591
+ padding=PaddingStrategy.DO_NOT_PAD.value, # we pad in batch afterward
592
+ truncation=truncation_strategy.value,
593
+ max_length=max_length,
594
+ stride=stride,
595
+ pad_to_multiple_of=None, # we pad in batch afterward
596
+ return_attention_mask=False, # we pad in batch afterward
597
+ return_token_type_ids=return_token_type_ids,
598
+ return_overflowing_tokens=return_overflowing_tokens,
599
+ return_special_tokens_mask=return_special_tokens_mask,
600
+ return_length=return_length,
601
+ return_tensors=None, # We convert the whole batch to tensors at the end
602
+ prepend_batch_axis=False,
603
+ verbose=verbose,
604
+ )
605
+
606
+ for key, value in outputs.items():
607
+ if key not in batch_outputs:
608
+ batch_outputs[key] = []
609
+ batch_outputs[key].append(value)
610
+
611
+ batch_outputs = self.pad(
612
+ batch_outputs,
613
+ padding=padding_strategy.value,
614
+ max_length=max_length,
615
+ pad_to_multiple_of=pad_to_multiple_of,
616
+ return_attention_mask=return_attention_mask,
617
+ )
618
+
619
+ batch_outputs = BatchEncoding(batch_outputs, tensor_type=return_tensors)
620
+
621
+ return batch_outputs
622
+
623
+ def prepare_for_tokenization(self, text: str, is_pretokenized=False, **kwargs) -> (str, dict):
624
+ """ Performs any necessary transformations before tokenization.
625
+
626
+ This method should pop the arguments from kwargs and return kwargs as well.
627
+ We test kwargs at the end of the encoding process to be sure all the arguments have been used.
628
+ """
629
+ return (text, kwargs)
630
+
631
+ def get_special_tokens_mask(
632
+ self, token_ids_0: List, token_ids_1: Optional[List] = None, already_has_special_tokens: bool = False
633
+ ) -> List[int]:
634
+ """
635
+ Retrieves sequence ids from a token list that has no special tokens added. This method is called when adding
636
+ special tokens using the tokenizer ``prepare_for_model`` method.
637
+
638
+ Args:
639
+ token_ids_0: list of ids (must not contain special tokens)
640
+ token_ids_1: Optional list of ids (must not contain special tokens), necessary when fetching sequence ids
641
+ for sequence pairs
642
+ already_has_special_tokens: (default False) Set to True if the token list is already formatted with
643
+ special tokens for the model
644
+
645
+ Returns:
646
+ A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
647
+ """
648
+ return [0] * ((len(token_ids_1) if token_ids_1 else 0) + len(token_ids_0))
649
+
650
+ def convert_ids_to_tokens(
651
+ self, ids: Union[int, List[int]], skip_special_tokens: bool = False
652
+ ) -> Union[str, List[str]]:
653
+ """ Converts a single index or a sequence of indices (integers) in a token "
654
+ (resp.) a sequence of tokens (str), using the vocabulary and added tokens.
655
+
656
+ Args:
657
+ skip_special_tokens: Don't decode special tokens (self.all_special_tokens). Default: False
658
+ """
659
+ if isinstance(ids, int):
660
+ if ids in self.added_tokens_decoder:
661
+ return self.added_tokens_decoder[ids]
662
+ else:
663
+ return self._convert_id_to_token(ids)
664
+ tokens = []
665
+ for index in ids:
666
+ index = int(index)
667
+ if skip_special_tokens and index in self.all_special_ids:
668
+ continue
669
+ if index in self.added_tokens_decoder:
670
+ tokens.append(self.added_tokens_decoder[index])
671
+ else:
672
+ tokens.append(self._convert_id_to_token(index))
673
+ return tokens
674
+
675
+ def _convert_id_to_token(self, index: int) -> str:
676
+ raise NotImplementedError
677
+
678
+ def convert_tokens_to_string(self, tokens: List[str]) -> str:
679
+ """ Converts a sequence of tokens (string) in a single string.
680
+ The most simple way to do it is ' '.join(self.convert_ids_to_tokens(token_ids))
681
+ but we often want to remove sub-word tokenization artifacts at the same time.
682
+ """
683
+ return " ".join(self.convert_ids_to_tokens(tokens))
684
+
685
+ def decode(
686
+ self, token_ids: List[int], skip_special_tokens: bool = False, clean_up_tokenization_spaces: bool = True
687
+ ) -> str:
688
+ filtered_tokens = self.convert_ids_to_tokens(token_ids, skip_special_tokens=skip_special_tokens)
689
+
690
+ # To avoid mixing byte-level and unicode for byte-level BPE
691
+ # we need to build the string separately for added tokens and byte-level tokens
692
+ # cf. https://github.com/huggingface/transformers/issues/1133
693
+ sub_texts = []
694
+ current_sub_text = []
695
+ for token in filtered_tokens:
696
+ if skip_special_tokens and token in self.all_special_ids:
697
+ continue
698
+ if token in self.added_tokens_encoder:
699
+ if current_sub_text:
700
+ sub_texts.append(self.convert_tokens_to_string(current_sub_text))
701
+ current_sub_text = []
702
+ sub_texts.append(token)
703
+ else:
704
+ current_sub_text.append(token)
705
+ if current_sub_text:
706
+ sub_texts.append(self.convert_tokens_to_string(current_sub_text))
707
+ text = " ".join(sub_texts)
708
+
709
+ if clean_up_tokenization_spaces:
710
+ clean_text = self.clean_up_tokenization(text)
711
+ return clean_text
712
+ else:
713
+ return text
714
+
715
+ def save_vocabulary(self, save_directory) -> Tuple[str]:
716
+ """ Save the tokenizer vocabulary to a directory. This method does *NOT* save added tokens
717
+ and special token mappings.
718
+
719
+ Please use :func:`~transformers.PreTrainedTokenizer.save_pretrained` to save the full
720
+ Tokenizer state if you want to reload it using the :func:`~transformers.PreTrainedTokenizer.from_pretrained`
721
+ class method.
722
+ """
723
+ raise NotImplementedError
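The added-token handling in `PreTrainedTokenizer.tokenize` above (`unique_no_split_tokens`, `split_on_token`, `split_on_tokens`) amounts to cutting the text around tokens that must stay intact and running the model-specific `_tokenize` only on the remaining fragments. The snippet below is a condensed, standalone sketch of that idea; the whitespace `_tokenize` is a stand-in for the real sub-word step and none of this is part of the committed file.

```python
# Condensed sketch of the split-on-no-split-tokens flow used by
# PreTrainedTokenizer.tokenize above. The whitespace _tokenize is a
# stand-in for the model-specific sub-word tokenizer.
def _tokenize(text):
    return text.split()

def tokenize(text, no_split_tokens):
    fragments = [text]
    for tok in no_split_tokens:
        next_fragments = []
        for frag in fragments:
            if frag in no_split_tokens:          # already an intact special token
                next_fragments.append(frag)
                continue
            pieces = frag.split(tok)
            for i, piece in enumerate(pieces):   # re-insert the token between pieces
                piece = piece.strip()
                if piece:
                    next_fragments.append(piece)
                if i < len(pieces) - 1:
                    next_fragments.append(tok)
        fragments = next_fragments
    out = []
    for frag in fragments:
        out.extend([frag] if frag in no_split_tokens else _tokenize(frag))
    return out

print(tokenize("[CLS] spoon on the dish [SEP]", ["[CLS]", "[SEP]"]))
# ['[CLS]', 'spoon', 'on', 'the', 'dish', '[SEP]']
```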
elia/bert/tokenization_utils_base.py ADDED
The diff for this file is too large to render. See raw diff
 
elia/demo_inference.py ADDED
@@ -0,0 +1,295 @@
1
+ image_path = './image001.png'
2
+ sentence = 'spoon on the dish'
3
+ weights = '/cluster/nvme4/cyx/lavt/vis/model_best_refcoco_0508.pth'
4
+ device = 'cpu'
5
+
6
+ # pre-process the input image
7
+ from PIL import Image
8
+ import torchvision.transforms as T
9
+ import numpy as np
10
+ import datetime
11
+ import os
12
+ import time
13
+
14
+ import torch
15
+ import torch.utils.data
16
+ from torch import nn
17
+
18
+ from bert.multimodal_bert import MultiModalBert
19
+ import torchvision
20
+
21
+ from lib import multimodal_segmentation_ppm
22
+ #import transforms as T
23
+ import utils
24
+
25
+ import numpy as np
26
+ from PIL import Image
27
+ import torch.nn.functional as F
28
+
29
+ from modeling.MaskFormerModel import MaskFormerHead
30
+ from addict import Dict
31
+ #from bert.modeling_bert import BertLMPredictionHead, BertEncoder
32
+ import cv2
33
+ import textwrap
34
+
35
+ class WrapperModel(nn.Module):
36
+ def __init__(self, image_model, language_model, classifier):
37
+ super(WrapperModel, self).__init__()
38
+ self.image_model = image_model
39
+ self.language_model = language_model
40
+ self.classifier = classifier
41
+
42
+ config = Dict({
43
+ "architectures": [
44
+ "BertForMaskedLM"
45
+ ],
46
+ "attention_probs_dropout_prob": 0.1,
47
+ "gradient_checkpointing": False,
48
+ "hidden_act": "gelu",
49
+ "hidden_dropout_prob": 0.1,
50
+ "hidden_size": 512,
51
+ "initializer_range": 0.02,
52
+ "intermediate_size": 3072,
53
+ "layer_norm_eps": 1e-12,
54
+ #"max_position_embeddings": 16+20,
55
+ "model_type": "bert",
56
+ "num_attention_heads": 8,
57
+ "num_hidden_layers": 8,
58
+ "pad_token_id": 0,
59
+ "position_embedding_type": "absolute",
60
+ "transformers_version": "4.6.0.dev0",
61
+ "type_vocab_size": 2,
62
+ "use_cache": True,
63
+ "vocab_size": 30522
64
+ })
65
+
66
+
67
+
68
+ def _get_binary_mask(self, target):
69
+ # return a binary mask for each class
70
+ y, x = target.size()
71
+ target_onehot = torch.zeros(self.num_classes + 1, y, x)
72
+ target_onehot = target_onehot.scatter(dim=0, index=target.unsqueeze(0), value=1)
73
+ return target_onehot[1:]
74
+
75
+ def semantic_inference(self, mask_cls, mask_pred):
76
+ mask_cls = F.softmax(mask_cls, dim=1)[...,1:]
77
+ mask_pred = mask_pred.sigmoid()
78
+ semseg = torch.einsum("bqc,bqhw->bchw", mask_cls, mask_pred)
79
+ return semseg
80
+
81
+ def forward(self, image, sentences, attentions):
82
+ print(image.sum(), sentences.sum(), attentions.sum())
83
+ input_shape = image.shape[-2:]
84
+ l_mask = attentions.unsqueeze(dim=-1)
85
+
86
+ i0, Wh, Ww = self.image_model.forward_stem(image)
87
+ l0, extended_attention_mask = self.language_model.forward_stem(sentences, attentions)
88
+
89
+ i1 = self.image_model.forward_stage1(i0, Wh, Ww)
90
+ l1 = self.language_model.forward_stage1(l0, extended_attention_mask)
91
+ i1_residual, H, W, i1_temp, Wh, Ww = self.image_model.forward_pwam1(i1, Wh, Ww, l1, l_mask)
92
+ l1_residual, l1 = self.language_model.forward_pwam1(i1, l1, extended_attention_mask)
93
+ i1 = i1_temp
94
+
95
+ i2 = self.image_model.forward_stage2(i1, Wh, Ww)
96
+ l2 = self.language_model.forward_stage2(l1, extended_attention_mask)
97
+ i2_residual, H, W, i2_temp, Wh, Ww = self.image_model.forward_pwam2(i2, Wh, Ww, l2, l_mask)
98
+ l2_residual, l2 = self.language_model.forward_pwam2(i2, l2, extended_attention_mask)
99
+ i2 = i2_temp
100
+
101
+ i3 = self.image_model.forward_stage3(i2, Wh, Ww)
102
+ l3 = self.language_model.forward_stage3(l2, extended_attention_mask)
103
+ i3_residual, H, W, i3_temp, Wh, Ww = self.image_model.forward_pwam3(i3, Wh, Ww, l3, l_mask)
104
+ l3_residual, l3 = self.language_model.forward_pwam3(i3, l3, extended_attention_mask)
105
+ i3 = i3_temp
106
+
107
+ i4 = self.image_model.forward_stage4(i3, Wh, Ww)
108
+ l4 = self.language_model.forward_stage4(l3, extended_attention_mask)
109
+ i4_residual, H, W, i4_temp, Wh, Ww = self.image_model.forward_pwam4(i4, Wh, Ww, l4, l_mask)
110
+ l4_residual, l4 = self.language_model.forward_pwam4(i4, l4, extended_attention_mask)
111
+ i4 = i4_temp
112
+
113
+ #i1_residual, i2_residual, i3_residual, i4_residual = features
114
+ #x = self.classifier(i4_residual, i3_residual, i2_residual, i1_residual)
115
+ #x = F.interpolate(x, size=input_shape, mode='bilinear', align_corners=True)
116
+ outputs = {}
117
+ outputs['s1'] = i1_residual
118
+ outputs['s2'] = i2_residual
119
+ outputs['s3'] = i3_residual
120
+ outputs['s4'] = i4_residual
121
+
122
+ predictions = self.classifier(outputs)
123
+ return predictions
124
+
125
+ #img = Image.open(image_path).convert("RGB")
126
+ img = Image.open(image_path).convert("RGB")
127
+ img_ndarray = np.array(img) # (orig_h, orig_w, 3); for visualization
128
+ original_w, original_h = img.size # PIL .size returns width first and height second
129
+
130
+ image_transforms = T.Compose(
131
+ [
132
+ T.Resize((480, 480)),
133
+ T.ToTensor(),
134
+ T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
135
+ ]
136
+ )
137
+
138
+ img = image_transforms(img).unsqueeze(0) # (1, 3, 480, 480)
139
+ img = img.to(device) # for inference (input)
140
+
141
+ # pre-process the raw sentence
142
+ from bert.tokenization_bert import BertTokenizer
143
+ import torch
144
+ tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
145
+ sentence_tokenized = tokenizer.encode(text=sentence, add_special_tokens=True)
146
+ sentence_tokenized = sentence_tokenized[:20] # if the sentence is longer than 20 tokens, this truncates it to the first 20
147
+ # pad the tokenized sentence
148
+ padded_sent_toks = [0] * 20
149
+ padded_sent_toks[:len(sentence_tokenized)] = sentence_tokenized
150
+ # create a sentence token mask: 1 for real words; 0 for padded tokens
151
+ attention_mask = [0] * 20
152
+ attention_mask[:len(sentence_tokenized)] = [1]*len(sentence_tokenized)
153
+ # convert lists to tensors
154
+ padded_sent_toks = torch.tensor(padded_sent_toks).unsqueeze(0) # (1, 20)
155
+ attention_mask = torch.tensor(attention_mask).unsqueeze(0) # (1, 20)
156
+ padded_sent_toks = padded_sent_toks.to(device) # for inference (input)
157
+ attention_mask = attention_mask.to(device) # for inference (input)
158
+
159
+ # initialize model and load weights
160
+ #from bert.modeling_bert import BertModel
161
+ #from lib import segmentation
162
+
163
+ # construct a mini args class; like from a config file
164
+
165
+
166
+ class args:
167
+ swin_type = 'base'
168
+ window12 = True
169
+ mha = ''
170
+ fusion_drop = 0.0
171
+
172
+
173
+ #single_model = segmentation.__dict__['lavt'](pretrained='', args=args)
174
+ single_model = multimodal_segmentation_ppm.__dict__['lavt'](pretrained='',args=args)
175
+ single_model.to(device)
176
+ model_class = MultiModalBert
177
+ single_bert_model = model_class.from_pretrained('bert-base-uncased', embed_dim=single_model.backbone.embed_dim)
178
+ single_bert_model.pooler = None
179
+
180
+ input_shape = dict()
181
+ input_shape['s1'] = Dict({'channel': 128, 'stride': 4})
182
+ input_shape['s2'] = Dict({'channel': 256, 'stride': 8})
183
+ input_shape['s3'] = Dict({'channel': 512, 'stride': 16})
184
+ input_shape['s4'] = Dict({'channel': 1024, 'stride': 32})
185
+
186
+
187
+
188
+ cfg = Dict()
189
+ cfg.MODEL.SEM_SEG_HEAD.COMMON_STRIDE = 4
190
+ cfg.MODEL.MASK_FORMER.DROPOUT = 0.0
191
+ cfg.MODEL.MASK_FORMER.NHEADS = 8
192
+ cfg.MODEL.SEM_SEG_HEAD.TRANSFORMER_ENC_LAYERS = 4
193
+ cfg.MODEL.SEM_SEG_HEAD.CONVS_DIM = 256
194
+ cfg.MODEL.SEM_SEG_HEAD.MASK_DIM = 256
195
+ cfg.MODEL.SEM_SEG_HEAD.DEFORMABLE_TRANSFORMER_ENCODER_IN_FEATURES = ["s1", "s2", "s3", "s4"]
196
+
197
+ cfg.MODEL.SEM_SEG_HEAD.NUM_CLASSES = 1
198
+ cfg.MODEL.MASK_FORMER.HIDDEN_DIM = 256
199
+ cfg.MODEL.MASK_FORMER.NUM_OBJECT_QUERIES = 1
200
+ cfg.MODEL.MASK_FORMER.DIM_FEEDFORWARD = 2048
201
+ cfg.MODEL.MASK_FORMER.DEC_LAYERS = 10
202
+ cfg.MODEL.MASK_FORMER.PRE_NORM = False
203
+
204
+
205
+ maskformer_head = MaskFormerHead(cfg, input_shape)
206
+
207
+
208
+ model = WrapperModel(single_model.backbone, single_bert_model, maskformer_head)
209
+
210
+
211
+
212
+ checkpoint = torch.load(weights, map_location='cpu')
213
+
214
+ model.load_state_dict(checkpoint['model'], strict=False)
215
+ model.to(device)
216
+ model.eval()
217
+ #single_bert_model.load_state_dict(checkpoint['bert_model'])
218
+ #single_model.load_state_dict(checkpoint['model'])
219
+ #model = single_model.to(device)
220
+ #bert_model = single_bert_model.to(device)
221
+
222
+
223
+ # inference
224
+ #import torch.nn.functional as F
225
+ #last_hidden_states = bert_model(padded_sent_toks, attention_mask=attention_mask)[0]
226
+ #embedding = last_hidden_states.permute(0, 2, 1)
227
+ #output = model(img, embedding, l_mask=attention_mask.unsqueeze(-1))
228
+ #output = output.argmax(1, keepdim=True) # (1, 1, 480, 480)
229
+ #output = F.interpolate(output.float(), (original_h, original_w)) # 'nearest'; resize to the original image size
230
+ #output = output.squeeze() # (orig_h, orig_w)
231
+ #output = output.cpu().data.numpy() # (orig_h, orig_w)
232
+
233
+ output = model(img, padded_sent_toks, attention_mask)[0]
234
+ #print(output[0].keys())
235
+ #print(output[1].shape)
236
+ mask_cls_results = output["pred_logits"]
237
+ mask_pred_results = output["pred_masks"]
238
+
239
+ target_shape = img_ndarray.shape[:2]
240
+ #print(target_shape, mask_pred_results.shape)
241
+ mask_pred_results = F.interpolate(mask_pred_results, size=(480,480), mode='bilinear', align_corners=True)
242
+
243
+ pred_masks = model.semantic_inference(mask_cls_results, mask_pred_results)
244
+ #output = pred_masks[0]
245
+
246
+ #output = output.cpu()
247
+
248
+
249
+
250
+ #print(output.shape)
251
+ #output_mask = output.argmax(1).data.numpy()
252
+ #output = (output > 0.5).data.cpu().numpy()
253
+ output = torch.nn.functional.interpolate(pred_masks, target_shape)
254
+ output = (output > 0.5).data.cpu().numpy()
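+ # pred_masks is the class-score-weighted sigmoid mask, so thresholding at 0.5 yields a binary
+ # mask of shape (1, 1, orig_h, orig_w); output[0][0] below is the (orig_h, orig_w) mask.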
255
+
256
+
257
+ # show/save results
258
+ def overlay_davis(image, mask, colors=[[0, 0, 0], [255, 0, 0]], cscale=1, alpha=0.4):
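+ # Alpha-blend colors[object_id] over the image wherever mask == object_id, then blacken the
+ # contour ring (dilated mask XOR mask) around each object; returns the overlaid image.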
259
+ from scipy.ndimage.morphology import binary_dilation
260
+
261
+ colors = np.reshape(colors, (-1, 3))
262
+ colors = np.atleast_2d(colors) * cscale
263
+
264
+ im_overlay = image.copy()
265
+ object_ids = np.unique(mask)
266
+
267
+ for object_id in object_ids[1:]:
268
+ # Overlay color on binary mask
269
+ foreground = image*alpha + np.ones(image.shape)*(1-alpha) * np.array(colors[object_id])
270
+ binary_mask = mask == object_id
271
+
272
+ # Compose image
273
+ im_overlay[binary_mask] = foreground[binary_mask]
274
+
275
+ # contours = skimage.morphology.binary.binary_dilation(binary_mask) - binary_mask
276
+ contours = binary_dilation(binary_mask) ^ binary_mask
277
+ # contours = cv2.dilate(binary_mask, cv2.getStructuringElement(cv2.MORPH_CROSS,(3,3))) - binary_mask
278
+ im_overlay[contours, :] = 0
279
+
280
+ return im_overlay.astype(image.dtype)
281
+
282
+
283
+ output = output.astype(np.uint8) # (orig_h, orig_w), np.uint8
284
+ # Overlay the mask on the image
285
+ print(img_ndarray.shape, output.shape)
286
+ visualization = overlay_davis(img_ndarray, output[0][0]) # red
287
+ visualization = Image.fromarray(visualization)
288
+ # show the visualization
289
+ #visualization.show()
290
+ # Save the visualization
291
+ visualization.save('./demo/spoon_on_the_dish.jpg')
292
+
293
+
294
+
295
+
elia/requirements.txt ADDED
@@ -0,0 +1,14 @@
1
+ requests
2
+ filelock
3
+ tqdm
4
+ timm
5
+ mmcv-full==1.3.12
6
+ mmsegmentation==0.17.0
7
+ ftfy
8
+ regex
9
+ scipy
10
+ scikit-image
11
+ pycocotools==2.0.2
12
+ opencv-python==4.5.3.56
13
+ tokenizers==0.8.1rc1
14
+ h5py
elia/test_elia.py ADDED
@@ -0,0 +1,312 @@
1
+
2
+ import datetime
3
+ import os
4
+ import time
5
+
6
+ import torch
7
+ import torch.utils.data
8
+ from torch import nn
9
+
10
+ from bert.multimodal_bert import MultiModalBert
11
+ import torchvision
12
+
13
+ from lib import multimodal_segmentation_ppm
14
+ import transforms as T
15
+ import utils
16
+
17
+ import numpy as np
18
+ from PIL import Image
19
+ import torch.nn.functional as F
20
+
21
+ from modeling.MaskFormerModel import MaskFormerHead
22
+ from addict import Dict
23
+ from bert.modeling_bert import BertLMPredictionHead, BertEncoder
24
+
25
+ def get_dataset(image_set, transform, args):
26
+ from data.dataset_refer_bert import ReferDataset
27
+ ds = ReferDataset(args,
28
+ split=image_set,
29
+ image_transforms=transform,
30
+ target_transforms=None,
31
+ eval_mode=True
32
+ )
33
+ num_classes = 2
34
+ return ds, num_classes
35
+
36
+
37
+ def evaluate(model, data_loader, device):
38
+ model.eval()
39
+ metric_logger = utils.MetricLogger(delimiter=" ")
40
+
41
+ # evaluation variables
42
+ cum_I, cum_U = 0, 0
43
+ eval_seg_iou_list = [.5, .6, .7, .8, .9]
44
+ seg_correct = np.zeros(len(eval_seg_iou_list), dtype=np.int32)
45
+ seg_total = 0
46
+ mean_IoU = []
47
+ header = 'Test:'
48
+
49
+ with torch.no_grad():
50
+ for data in metric_logger.log_every(data_loader, 100, header):
51
+ image, target, sentences, attentions = data
52
+ image, target, sentences, attentions = image.to(device), target.to(device), \
53
+ sentences.to(device), attentions.to(device)
54
+ sentences = sentences.squeeze(1)
55
+ attentions = attentions.squeeze(1)
56
+ target = target.cpu().data.numpy()
57
+ for j in range(sentences.size(-1)):
58
+ #if bert_model is not None:
59
+ # last_hidden_states = bert_model(sentences[:, :, j], attention_mask=attentions[:, :, j])[0]
60
+ # embedding = last_hidden_states.permute(0, 2, 1)
61
+ # output = model(image, embedding, l_mask=attentions[:, :, j].unsqueeze(-1))
62
+ #else:
63
+ output = model(image, sentences[:, :, j], attentions[:, :, j])
64
+ mask_cls_results = output["pred_logits"]
65
+ mask_pred_results = output["pred_masks"]
66
+
67
+ target_shape = target.shape[-2:]
68
+ mask_pred_results = F.interpolate(mask_pred_results, size=target_shape, mode='bilinear', align_corners=True)
69
+
70
+ pred_masks = model.semantic_inference(mask_cls_results, mask_pred_results)
71
+ output = pred_masks[0]
72
+
73
+ output = output.cpu()
74
+ #print(output.shape)
75
+ #output_mask = output.argmax(1).data.numpy()
76
+ output_mask = (output > 0.5).data.numpy()
77
+ I, U = computeIoU(output_mask, target)
78
+ if U == 0:
79
+ this_iou = 0.0
80
+ else:
81
+ this_iou = I*1.0/U
82
+ mean_IoU.append(this_iou)
83
+ cum_I += I
84
+ cum_U += U
85
+ for n_eval_iou in range(len(eval_seg_iou_list)):
86
+ eval_seg_iou = eval_seg_iou_list[n_eval_iou]
87
+ seg_correct[n_eval_iou] += (this_iou >= eval_seg_iou)
88
+ seg_total += 1
89
+
90
+ #del image, target, sentences, attentions, output, output_mask
91
+ #if bert_model is not None:
92
+ # del last_hidden_states, embedding
93
+
94
+ mean_IoU = np.array(mean_IoU)
95
+ mIoU = np.mean(mean_IoU)
96
+ print('Final results:')
97
+ print('Mean IoU is %.2f\n' % (mIoU*100.))
98
+ results_str = ''
99
+ for n_eval_iou in range(len(eval_seg_iou_list)):
100
+ results_str += ' precision@%s = %.2f\n' % \
101
+ (str(eval_seg_iou_list[n_eval_iou]), seg_correct[n_eval_iou] * 100. / seg_total)
102
+ results_str += ' overall IoU = %.2f\n' % (cum_I * 100. / cum_U)
103
+ print(results_str)
104
+
105
+
106
+ def get_transform(args):
107
+ transforms = [T.Resize(args.img_size, args.img_size),
108
+ T.ToTensor(),
109
+ T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
110
+ ]
111
+
112
+ return T.Compose(transforms)
113
+
114
+
115
+ def computeIoU(pred_seg, gd_seg):
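+ # I is the pixel count of the intersection and U of the union of the two binary masks;
+ # the caller computes IoU = I / U and guards against U == 0.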
116
+ I = np.sum(np.logical_and(pred_seg, gd_seg))
117
+ U = np.sum(np.logical_or(pred_seg, gd_seg))
118
+
119
+ return I, U
120
+
121
+ class WrapperModel(nn.Module):
122
+ def __init__(self, image_model, language_model, classifier, args) :
123
+ super(WrapperModel, self).__init__()
124
+ self.image_model = image_model
125
+ self.language_model = language_model
126
+ self.classifier = classifier
127
+ self.lang_proj = nn.Linear(768,256)
128
+
129
+ config = Dict({
130
+ "architectures": [
131
+ "BertForMaskedLM"
132
+ ],
133
+ "attention_probs_dropout_prob": 0.1,
134
+ "gradient_checkpointing": False,
135
+ "hidden_act": "gelu",
136
+ "hidden_dropout_prob": 0.1,
137
+ "hidden_size": 512,
138
+ "initializer_range": 0.02,
139
+ "intermediate_size": 3072,
140
+ "layer_norm_eps": 1e-12,
141
+ #"max_position_embeddings": 16+20,
142
+ "model_type": "bert",
143
+ "num_attention_heads": 8,
144
+ "num_hidden_layers": 8,
145
+ "pad_token_id": 0,
146
+ "position_embedding_type": "absolute",
147
+ "transformers_version": "4.6.0.dev0",
148
+ "type_vocab_size": 2,
149
+ "use_cache": True,
150
+ "vocab_size": 30522
151
+ })
152
+ self.mlm_transformer = BertEncoder(config)
153
+
154
+ self.lang_proj = nn.Linear(768,256)
155
+ self.mlm_vis_proj = nn.Conv2d(1024,512,1)
156
+ self.mlm_lang_proj = nn.Linear(768,512)
157
+ #print(vis_proj)
158
+ self.mlm_head = BertLMPredictionHead(config)
159
+
160
+ assert args.img_size % 4 == 0
161
+ num_img_tokens = 20 + ((args.img_size // 4)//8) ** 2
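+ # e.g. with img_size = 480 (as in the demo): 20 text tokens + (480 // 4 // 8) ** 2 = 20 + 225
+ # = 245 tokens; the +1 below reserves one extra position for the position-MLP token.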
162
+ print(num_img_tokens)
163
+ self.mlm_pos_embeds = nn.Embedding(num_img_tokens+1, 512)
164
+ self.mlm_modal_embeds = nn.Embedding(3, 512)
165
+
166
+ self.mlm_mask_embed = nn.Embedding(1, 512)
167
+ self.mlm_pos_mlp = nn.Sequential(
168
+ nn.Linear(2, 512),
169
+ nn.LayerNorm(512),
170
+ nn.Linear(512,512),
171
+ nn.GELU()
172
+ )
173
+
174
+ def _get_binary_mask(self, target):
175
+ # return a binary mask for each class
176
+ y, x = target.size()
177
+ target_onehot = torch.zeros(self.num_classes + 1, y, x)
178
+ target_onehot = target_onehot.scatter(dim=0, index=target.unsqueeze(0), value=1)
179
+ return target_onehot[1:]
180
+
181
+ def semantic_inference(self, mask_cls, mask_pred):
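+ # Softmax the class logits over the query dimension and drop the first class channel, apply
+ # sigmoid to the per-query masks, then fuse them with einsum("bqc,bqhw->bchw") so each output
+ # channel is the class-score-weighted sum of query masks; with a single query and a single
+ # class this yields one foreground probability map per image.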
182
+ mask_cls = F.softmax(mask_cls, dim=1)[...,1:]
183
+ mask_pred = mask_pred.sigmoid()
184
+ semseg = torch.einsum("bqc,bqhw->bchw", mask_cls, mask_pred)
185
+ return semseg
186
+
187
+ def forward(self, image, sentences, attentions):
188
+ input_shape = image.shape[-2:]
189
+ l_mask = attentions.unsqueeze(dim=-1)
190
+
191
+ i0, Wh, Ww = self.image_model.forward_stem(image)
192
+ l0, extended_attention_mask = self.language_model.forward_stem(sentences, attentions)
193
+
194
+ i1 = self.image_model.forward_stage1(i0, Wh, Ww)
195
+ l1 = self.language_model.forward_stage1(l0, extended_attention_mask)
196
+ i1_residual, H, W, i1_temp, Wh, Ww = self.image_model.forward_pwam1(i1, Wh, Ww, l1, l_mask)
197
+ l1_residual, l1 = self.language_model.forward_pwam1(i1, l1, extended_attention_mask)
198
+ i1 = i1_temp
199
+
200
+ i2 = self.image_model.forward_stage2(i1, Wh, Ww)
201
+ l2 = self.language_model.forward_stage2(l1, extended_attention_mask)
202
+ i2_residual, H, W, i2_temp, Wh, Ww = self.image_model.forward_pwam2(i2, Wh, Ww, l2, l_mask)
203
+ l2_residual, l2 = self.language_model.forward_pwam2(i2, l2, extended_attention_mask)
204
+ i2 = i2_temp
205
+
206
+ i3 = self.image_model.forward_stage3(i2, Wh, Ww)
207
+ l3 = self.language_model.forward_stage3(l2, extended_attention_mask)
208
+ i3_residual, H, W, i3_temp, Wh, Ww = self.image_model.forward_pwam3(i3, Wh, Ww, l3, l_mask)
209
+ l3_residual, l3 = self.language_model.forward_pwam3(i3, l3, extended_attention_mask)
210
+ i3 = i3_temp
211
+
212
+ i4 = self.image_model.forward_stage4(i3, Wh, Ww)
213
+ l4 = self.language_model.forward_stage4(l3, extended_attention_mask)
214
+ i4_residual, H, W, i4_temp, Wh, Ww = self.image_model.forward_pwam4(i4, Wh, Ww, l4, l_mask)
215
+ l4_residual, l4 = self.language_model.forward_pwam4(i4, l4, extended_attention_mask)
216
+ i4 = i4_temp
217
+
218
+ #i1_residual, i2_residual, i3_residual, i4_residual = features
219
+ #x = self.classifier(i4_residual, i3_residual, i2_residual, i1_residual)
220
+ #x = F.interpolate(x, size=input_shape, mode='bilinear', align_corners=True)
221
+ outputs = {}
222
+ outputs['s1'] = i1_residual
223
+ outputs['s2'] = i2_residual
224
+ outputs['s3'] = i3_residual
225
+ outputs['s4'] = i4_residual
226
+
227
+ predictions, _ = self.classifier(outputs)
228
+ return predictions
229
+
230
+ def main(args):
231
+ #def main(local_rank, args):
232
+
233
+ #device = torch.device(args.device)
234
+ device = 'cuda'
235
+ dataset_test, _ = get_dataset(args.split, get_transform(args=args), args)
236
+ test_sampler = torch.utils.data.SequentialSampler(dataset_test)
237
+ data_loader_test = torch.utils.data.DataLoader(dataset_test, batch_size=1,
238
+ sampler=test_sampler, num_workers=args.workers)
239
+ print(args.model)
240
+ single_model = multimodal_segmentation_ppm.__dict__[args.model](pretrained='',args=args)
241
+ #single_model = MultiModalFocal(depths=[2, 2, 18, 2], embed_dim=128, focal_levels=[3, 3, 3, 3], focal_windows=[9,9,9,9], drop_path_rate=0.3)
242
+ #single_model.init_weights('./focalnet_base_lrf.pth')
243
+ checkpoint = torch.load(args.resume, map_location='cpu')
244
+ #single_model.load_state_dict(checkpoint['model'])
245
+ #model = single_model.to(device)
246
+
247
+ if args.model != 'lavt_one':
248
+ model_class = MultiModalBert
249
+ #single_bert_model = model_class.from_pretrained(args.ck_bert, embed_dim=128)
250
+ single_bert_model = model_class.from_pretrained(args.ck_bert, embed_dim=single_model.backbone.embed_dim)
251
+ # work-around for a transformers bug; need to update to a newer version of transformers to remove these two lines
252
+ if args.ddp_trained_weights:
253
+ single_bert_model.pooler = None
254
+ #single_bert_model.load_state_dict(checkpoint['bert_model'])
255
+ #bert_model = single_bert_model.to(device)
256
+ else:
257
+ bert_model = None
258
+
259
+ #model = WrapperModel(single_model.backbone, single_bert_model, single_model.classifier)
260
+ #model.load_state_dict(checkpoint['model'])
261
+ #model.to(device)
262
+ input_shape = dict()
263
+ input_shape['s1'] = Dict({'channel': 128, 'stride': 4})
264
+ input_shape['s2'] = Dict({'channel': 256, 'stride': 8})
265
+ input_shape['s3'] = Dict({'channel': 512, 'stride': 16})
266
+ input_shape['s4'] = Dict({'channel': 1024, 'stride': 32})
267
+
268
+
269
+
270
+ cfg = Dict()
271
+ cfg.MODEL.SEM_SEG_HEAD.COMMON_STRIDE = 4
272
+ cfg.MODEL.MASK_FORMER.DROPOUT = 0.0
273
+ cfg.MODEL.MASK_FORMER.NHEADS = 8
274
+ cfg.MODEL.SEM_SEG_HEAD.TRANSFORMER_ENC_LAYERS = 4
275
+ cfg.MODEL.SEM_SEG_HEAD.CONVS_DIM = 256
276
+ cfg.MODEL.SEM_SEG_HEAD.MASK_DIM = 256
277
+ cfg.MODEL.SEM_SEG_HEAD.DEFORMABLE_TRANSFORMER_ENCODER_IN_FEATURES = ["s1", "s2", "s3", "s4"]
278
+
279
+ cfg.MODEL.SEM_SEG_HEAD.NUM_CLASSES = 1
280
+ cfg.MODEL.MASK_FORMER.HIDDEN_DIM = 256
281
+ cfg.MODEL.MASK_FORMER.NUM_OBJECT_QUERIES = 1
282
+ cfg.MODEL.MASK_FORMER.DIM_FEEDFORWARD = 2048
283
+ cfg.MODEL.MASK_FORMER.DEC_LAYERS = 10
284
+ cfg.MODEL.MASK_FORMER.PRE_NORM = False
285
+
286
+
287
+ maskformer_head = MaskFormerHead(cfg, input_shape)
288
+ #maskformer_head = torch.nn.SyncBatchNorm.convert_sync_batchnorm(maskformer_head)
289
+ #maskformer_head.cuda()
290
+ #maskformer_head = torch.nn.parallel.DistributedDataParallel(maskformer_head, device_ids=[args.local_rank], find_unused_parameters=False)
291
+ #single_head = maskformer_head.module
292
+ #print(single_head)
293
+
294
+ model = WrapperModel(single_model.backbone, single_bert_model, maskformer_head, args)
295
+ model.load_state_dict(checkpoint['model'])
296
+ model.to(device)
297
+ #model.cuda()
298
+ #model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank], find_unused_parameters=True)
299
+ #single_model = model.module
300
+ #model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank], find_unused_parameters=True)
301
+ #single_model = model.module
302
+ evaluate(model, data_loader_test, device=device)
303
+
304
+
305
+ if __name__ == "__main__":
306
+ from args import get_parser
307
+ parser = get_parser()
308
+ args = parser.parse_args()
309
+ print('Image size: {}'.format(str(args.img_size)))
310
+ print(args)
311
+ main(args)
312
+ #mp.spawn(main, args=(args,), nprocs=torch.cuda.device_count())
elia/test_lavt.py ADDED
@@ -0,0 +1,139 @@
1
+ import datetime
2
+ import os
3
+ import time
4
+
5
+ import torch
6
+ import torch.utils.data
7
+ from torch import nn
8
+
9
+ from bert.modeling_bert import BertModel
10
+ import torchvision
11
+
12
+ from lib import segmentation
13
+ import transforms as T
14
+ import utils
15
+
16
+ import numpy as np
17
+ from PIL import Image
18
+ import torch.nn.functional as F
19
+
20
+
21
+ def get_dataset(image_set, transform, args):
22
+ from data.dataset_refer_bert import ReferDataset
23
+ ds = ReferDataset(args,
24
+ split=image_set,
25
+ image_transforms=transform,
26
+ target_transforms=None,
27
+ eval_mode=True
28
+ )
29
+ num_classes = 2
30
+ return ds, num_classes
31
+
32
+
33
+ def evaluate(model, data_loader, bert_model, device):
34
+ model.eval()
35
+ metric_logger = utils.MetricLogger(delimiter=" ")
36
+
37
+ # evaluation variables
38
+ cum_I, cum_U = 0, 0
39
+ eval_seg_iou_list = [.5, .6, .7, .8, .9]
40
+ seg_correct = np.zeros(len(eval_seg_iou_list), dtype=np.int32)
41
+ seg_total = 0
42
+ mean_IoU = []
43
+ header = 'Test:'
44
+
45
+ with torch.no_grad():
46
+ for data in metric_logger.log_every(data_loader, 100, header):
47
+ image, target, sentences, attentions = data
48
+ image, target, sentences, attentions = image.to(device), target.to(device), \
49
+ sentences.to(device), attentions.to(device)
50
+ sentences = sentences.squeeze(1)
51
+ attentions = attentions.squeeze(1)
52
+ target = target.cpu().data.numpy()
53
+ for j in range(sentences.size(-1)):
54
+ if bert_model is not None:
55
+ last_hidden_states = bert_model(sentences[:, :, j], attention_mask=attentions[:, :, j])[0]
56
+ embedding = last_hidden_states.permute(0, 2, 1)
57
+ output = model(image, embedding, l_mask=attentions[:, :, j].unsqueeze(-1))
58
+ else:
59
+ output = model(image, sentences[:, :, j], l_mask=attentions[:, :, j])
60
+
61
+ output = output.cpu()
62
+ output_mask = output.argmax(1).data.numpy()
63
+ I, U = computeIoU(output_mask, target)
64
+ if U == 0:
65
+ this_iou = 0.0
66
+ else:
67
+ this_iou = I*1.0/U
68
+ mean_IoU.append(this_iou)
69
+ cum_I += I
70
+ cum_U += U
71
+ for n_eval_iou in range(len(eval_seg_iou_list)):
72
+ eval_seg_iou = eval_seg_iou_list[n_eval_iou]
73
+ seg_correct[n_eval_iou] += (this_iou >= eval_seg_iou)
74
+ seg_total += 1
75
+
76
+ del image, target, sentences, attentions, output, output_mask
77
+ if bert_model is not None:
78
+ del last_hidden_states, embedding
79
+
80
+ mean_IoU = np.array(mean_IoU)
81
+ mIoU = np.mean(mean_IoU)
82
+ print('Final results:')
83
+ print('Mean IoU is %.2f\n' % (mIoU*100.))
84
+ results_str = ''
85
+ for n_eval_iou in range(len(eval_seg_iou_list)):
86
+ results_str += ' precision@%s = %.2f\n' % \
87
+ (str(eval_seg_iou_list[n_eval_iou]), seg_correct[n_eval_iou] * 100. / seg_total)
88
+ results_str += ' overall IoU = %.2f\n' % (cum_I * 100. / cum_U)
89
+ print(results_str)
90
+
91
+
92
+ def get_transform(args):
93
+ transforms = [T.Resize(args.img_size, args.img_size),
94
+ T.ToTensor(),
95
+ T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
96
+ ]
97
+
98
+ return T.Compose(transforms)
99
+
100
+
101
+ def computeIoU(pred_seg, gd_seg):
102
+ I = np.sum(np.logical_and(pred_seg, gd_seg))
103
+ U = np.sum(np.logical_or(pred_seg, gd_seg))
104
+
105
+ return I, U
106
+
107
+
108
+ def main(args):
109
+ device = torch.device(args.device)
110
+ dataset_test, _ = get_dataset(args.split, get_transform(args=args), args)
111
+ test_sampler = torch.utils.data.SequentialSampler(dataset_test)
112
+ data_loader_test = torch.utils.data.DataLoader(dataset_test, batch_size=1,
113
+ sampler=test_sampler, num_workers=args.workers)
114
+ print(args.model)
115
+ single_model = segmentation.__dict__[args.model](pretrained='',args=args)
116
+ checkpoint = torch.load(args.resume, map_location='cpu')
117
+ single_model.load_state_dict(checkpoint['model'])
118
+ model = single_model.to(device)
119
+
120
+ if args.model != 'lavt_one':
121
+ model_class = BertModel
122
+ single_bert_model = model_class.from_pretrained(args.ck_bert)
123
+ # work-around for a transformers bug; need to update to a newer version of transformers to remove these two lines
124
+ if args.ddp_trained_weights:
125
+ single_bert_model.pooler = None
126
+ single_bert_model.load_state_dict(checkpoint['bert_model'])
127
+ bert_model = single_bert_model.to(device)
128
+ else:
129
+ bert_model = None
130
+
131
+ evaluate(model, data_loader_test, bert_model, device=device)
132
+
133
+
134
+ if __name__ == "__main__":
135
+ from args import get_parser
136
+ parser = get_parser()
137
+ args = parser.parse_args()
138
+ print('Image size: {}'.format(str(args.img_size)))
139
+ main(args)
elia/train_elia.py ADDED
@@ -0,0 +1,812 @@
1
+ import datetime
2
+ import os
3
+ import time
4
+
5
+ import torch
6
+ import torch.utils.data
7
+ from torch import nn
8
+
9
+ from functools import reduce
10
+ import operator
11
+ from bert.multimodal_bert import MultiModalBert
12
+
13
+ import torchvision
14
+ from lib import multimodal_segmentation_ppm
15
+
16
+ import transforms as T
17
+ import utils
18
+ import numpy as np
19
+
20
+ import torch.nn.functional as F
21
+
22
+ import gc
23
+ from collections import OrderedDict
24
+
25
+ import torch.backends.cudnn as cudnn
26
+
27
+ #from ffrecord.torch import DataLoader,Dataset
28
+ from modeling.MaskFormerModel import MaskFormerHead
29
+ from addict import Dict
30
+
31
+ from mask2former_utils.criterion import SetCriterion, Criterion
32
+ from mask2former_utils.matcher import HungarianMatcher
33
+ from bert.modeling_bert import BertLMPredictionHead, BertEncoder
34
+
35
+
36
+
37
+
38
+ class WrapperModel(nn.Module):
39
+ def __init__(self, image_model, language_model, classifier, args) :
40
+ super(WrapperModel, self).__init__()
41
+ self.image_model = image_model
42
+ self.language_model = language_model
43
+ self.classifier = classifier
44
+
45
+ self.lang_proj = nn.Linear(768,256)
46
+
47
+ config = Dict({
48
+ "architectures": [
49
+ "BertForMaskedLM"
50
+ ],
51
+ "attention_probs_dropout_prob": 0.1,
52
+ "gradient_checkpointing": False,
53
+ "hidden_act": "gelu",
54
+ "hidden_dropout_prob": 0.1,
55
+ "hidden_size": 512,
56
+ "initializer_range": 0.02,
57
+ "intermediate_size": 3072,
58
+ "layer_norm_eps": 1e-12,
59
+ #"max_position_embeddings": 16+20,
60
+ "model_type": "bert",
61
+ "num_attention_heads": 8,
62
+ "num_hidden_layers": 8,
63
+ "pad_token_id": 0,
64
+ "position_embedding_type": "absolute",
65
+ "transformers_version": "4.6.0.dev0",
66
+ "type_vocab_size": 2,
67
+ "use_cache": True,
68
+ "vocab_size": 30522
69
+ })
70
+ self.mlm_transformer = BertEncoder(config)
71
+
72
+ self.lang_proj = nn.Linear(768,256)
73
+ self.mlm_vis_proj = nn.Conv2d(1024,512,1)
74
+ self.mlm_lang_proj = nn.Linear(768,512)
75
+ #print(vis_proj)
76
+ self.mlm_head = BertLMPredictionHead(config)
77
+
78
+ assert args.img_size % 4 == 0
79
+ num_img_tokens = 20 + ((args.img_size // 4)//8) ** 2
80
+ print(num_img_tokens)
81
+ self.mlm_pos_embeds = nn.Embedding(num_img_tokens+1, 512)
82
+ self.mlm_modal_embeds = nn.Embedding(3, 512)
83
+
84
+ self.mlm_mask_embed = nn.Embedding(1, 512)
85
+ self.mlm_pos_mlp = nn.Sequential(
86
+ nn.Linear(2, 512),
87
+ nn.LayerNorm(512),
88
+ nn.Linear(512,512),
89
+ nn.GELU()
90
+ )
91
+
92
+ def _get_binary_mask(self, target):
93
+ # 返回每类的binary mask
94
+ y, x = target.size()
95
+ target_onehot = torch.zeros(self.num_classes + 1, y, x)
96
+ target_onehot = target_onehot.scatter(dim=0, index=target.unsqueeze(0), value=1)
97
+ return target_onehot[1:]
98
+
99
+ def semantic_inference(self, mask_cls, mask_pred):
100
+ mask_cls = F.softmax(mask_cls, dim=1)[...,1:]
101
+ mask_pred = mask_pred.sigmoid()
102
+ semseg = torch.einsum("bqc,bqhw->bchw", mask_cls, mask_pred)
103
+ return semseg
104
+
105
+ def forward(self, image, sentences, attentions, mlm_targets, mlm_masks, position):
106
+ input_shape = image.shape[-2:]
107
+ l_mask = attentions.unsqueeze(dim=-1)
108
+
109
+ i0, Wh, Ww = self.image_model.forward_stem(image)
110
+ l0, extended_attention_mask = self.language_model.forward_stem(mlm_targets.squeeze(1), attentions)
111
+
112
+ i1 = self.image_model.forward_stage1(i0, Wh, Ww)
113
+ l1 = self.language_model.forward_stage1(l0, extended_attention_mask)
114
+ i1_residual, H, W, i1_temp, Wh, Ww = self.image_model.forward_pwam1(i1, Wh, Ww, l1, l_mask)
115
+ l1_residual, l1 = self.language_model.forward_pwam1(i1, l1, extended_attention_mask)
116
+ i1 = i1_temp
117
+
118
+ i2 = self.image_model.forward_stage2(i1, Wh, Ww)
119
+ l2 = self.language_model.forward_stage2(l1, extended_attention_mask)
120
+ i2_residual, H, W, i2_temp, Wh, Ww = self.image_model.forward_pwam2(i2, Wh, Ww, l2, l_mask)
121
+ l2_residual, l2 = self.language_model.forward_pwam2(i2, l2, extended_attention_mask)
122
+ i2 = i2_temp
123
+
124
+ i3 = self.image_model.forward_stage3(i2, Wh, Ww)
125
+ l3 = self.language_model.forward_stage3(l2, extended_attention_mask)
126
+ i3_residual, H, W, i3_temp, Wh, Ww = self.image_model.forward_pwam3(i3, Wh, Ww, l3, l_mask)
127
+ l3_residual, l3 = self.language_model.forward_pwam3(i3, l3, extended_attention_mask)
128
+ i3 = i3_temp
129
+
130
+ i4 = self.image_model.forward_stage4(i3, Wh, Ww)
131
+ l4 = self.language_model.forward_stage4(l3, extended_attention_mask)
132
+ i4_residual, H, W, i4_temp, Wh, Ww = self.image_model.forward_pwam4(i4, Wh, Ww, l4, l_mask)
133
+ l4_residual, l4 = self.language_model.forward_pwam4(i4, l4, extended_attention_mask)
134
+ i4 = i4_temp
135
+
136
+ #i1_residual, i2_residual, i3_residual, i4_residual = features
137
+ #x = self.classifier(i4_residual, i3_residual, i2_residual, i1_residual)
138
+ #x = F.interpolate(x, size=input_shape, mode='bilinear', align_corners=True)
139
+ outputs = {}
140
+ outputs['s1'] = i1_residual
141
+ outputs['s2'] = i2_residual
142
+ outputs['s3'] = i3_residual
143
+ outputs['s4'] = i4_residual
144
+
145
+ predictions, mask_features = self.classifier(outputs)
146
+
147
+ #print(target_reshape.shape)
148
+ #tmp = np.argwhere(target_reshape[:, 0].detach().cpu().numpy()).reshape(-1, target_reshape.shape[2]*target_reshape[3], 3)
149
+ #centroid = tmp.mean(1)
150
+ #print(centroid)
151
+ #centroid_x, centroid_y = int(centroid[1]), int(centroid[0])
152
+ #last_hidden_states = brt_model(sentences, attention_mask=attentions)[0] # (6, 10, 768)
153
+ #embedding = last_hidden_states.permute(0, 2, 1) # (B, 768, N_l) to make Conv1d happy
154
+
155
+
156
+ l0, extended_attention_mask = self.language_model.forward_stem(sentences, attentions)
157
+ l1 = self.language_model.forward_stage1(l0, extended_attention_mask)
158
+ l2 = self.language_model.forward_stage2(l1, extended_attention_mask)
159
+ l3 = self.language_model.forward_stage3(l2, extended_attention_mask)
160
+ l4 = self.language_model.forward_stage4(l3, extended_attention_mask)
161
+
162
+
163
+ mlp_embed = self.mlm_pos_mlp(position)
164
+ #print(centroid_x, centroid_y)
165
+
166
+ mlm_targets = torch.where(
167
+ mlm_masks > 0,
168
+ mlm_targets,
169
+ torch.ones_like(mlm_targets) * (-1)
170
+ )
171
+
172
+ #print(x_c4[target_reshape[:, [0]].bool()].shape)
173
+ vis_features = self.mlm_vis_proj(i4_residual).flatten(2).permute(0,2,1)
174
+ #print(l4.shape)
175
+ lang_features = self.mlm_lang_proj(l4)
176
+
177
+ #print(lang_features.shape, vis_features.shape, mlp_embed.shape)
178
+ mm_features = torch.cat([lang_features, vis_features, mlp_embed.unsqueeze(1)], dim=1)
179
+ #print(mm_features.shape)
180
+
181
+ #print(mlm_modal_embeds.weight.shape)
182
+ modal_embeds = torch.cat([self.mlm_modal_embeds.weight[0].unsqueeze(0).repeat(1, lang_features.shape[1], 1), self.mlm_modal_embeds.weight[1].unsqueeze(0).repeat(1, vis_features.shape[1], 1), self.mlm_modal_embeds.weight[2].unsqueeze(0).repeat(1,1,1)], dim=1)
183
+ #print(modal_embeds.shape)
184
+
185
+ #print(mlm_transformer)
186
+
187
+
188
+ #print(attentions.shape)
189
+ mixed_attention_mask = torch.cat([attentions.unsqueeze(-1), torch.ones(attentions.shape[0], vis_features.shape[1]+1, 1).to(attentions.device)], dim=1)
190
+ mixed_attention_mask = mixed_attention_mask.permute(0,2,1).unsqueeze(1)
191
+ mixed_attention_mask = (1-mixed_attention_mask)* -10000.0
192
+ head_mask = [None] * 8
193
+ #extended_attention_mask = get_extended_attention_mask(mixed_attention_mask, mm_features.shape, mm_features.device)
194
+ #print(mm_features.shape, mixed_attention_mask.shape, head_mask)
195
+ #print(mm_features.shape, self.mlm_pos_embeds.weight.shape, self.mlm_modal_embeds.weight.shape)
196
+ head_features = self.mlm_transformer(mm_features + self.mlm_pos_embeds.weight.unsqueeze(0) + modal_embeds, mixed_attention_mask, head_mask)[0]
197
+ #print(head_features.shape, attentions.shape)
198
+ head_features = head_features[:, :20][attentions.bool()]
199
+
200
+ #print(embedding.shape, mask_features.shape)
201
+ mlm_predictions = self.mlm_head(head_features)
202
+ mlm_predictions = mlm_predictions.reshape(-1, self.language_model.config.vocab_size)
203
+ mlm_targets = mlm_targets.squeeze(1)[attentions.bool()]
204
+ #mlm_loss = mlm_weight * nn.CrossEntropyLoss(ignore_index=-1)(mlm_predictions, mlm_targets)
205
+ #loss += mlm_loss
206
+ #mlm_loss_print=mlm_loss.item()
207
+
208
+ return predictions, mask_features, self.lang_proj((l4_residual * l_mask).sum(1)/l_mask.sum(1)), mlm_predictions, mlm_targets
209
+ # IoU calculation for validation
210
+ def IoU(pred, gt):
211
+ #pred = pred.argmax(1)
212
+ pred = (pred > 0.5)
213
+
214
+ intersection = torch.sum(torch.mul(pred, gt))
215
+ union = torch.sum(torch.add(pred, gt)) - intersection
216
+
217
+ if intersection == 0 or union == 0:
218
+ iou = 0
219
+ else:
220
+ iou = float(intersection) / float(union)
221
+
222
+ return iou, intersection, union
223
+
224
+ def get_dataset(image_set, transform, args):
225
+ from data.dataset_refer_bert_mlm import ReferDataset
226
+ ds = ReferDataset(args,
227
+ split=image_set,
228
+ image_transforms=transform,
229
+ target_transforms=None
230
+ )
231
+ num_classes = 2
232
+
233
+ return ds, num_classes
234
+
235
+
236
+
237
+ def get_transform(args):
238
+ transforms = [T.Resize(args.img_size, args.img_size),
239
+ T.ToTensor(),
240
+ T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
241
+ ]
242
+
243
+ return T.Compose(transforms)
244
+
245
+
246
+ #def criterion(input, target):
247
+ # weight = torch.FloatTensor([0.9, 1.1]).cuda()
248
+ # return nn.functional.cross_entropy(input, target, weight=weight)
249
+
250
+
251
+ def evaluate(model, data_loader):
252
+ model.eval()
253
+ metric_logger = utils.MetricLogger(delimiter=" ")
254
+ header = 'Test:'
255
+ total_its = 0
256
+ acc_ious = 0
257
+
258
+ # evaluation variables
259
+ cum_I, cum_U = 0, 0
260
+ eval_seg_iou_list = [.5, .6, .7, .8, .9]
261
+ seg_correct = np.zeros(len(eval_seg_iou_list), dtype=np.int32)
262
+ seg_total = 0
263
+ mean_IoU = []
264
+
265
+ with torch.no_grad():
266
+ for data in metric_logger.log_every(data_loader, 100, header):
267
+ total_its += 1
268
+ #image, target, sentences, attentions = data
269
+ #image, target, sentences, attentions = image.cuda(non_blocking=True),\
270
+ # target.cuda(non_blocking=True),\
271
+ # sentences.cuda(non_blocking=True),\
272
+ # attentions.cuda(non_blocking=True)
273
+
274
+ image, target, sentences, attentions, mlm_targets, mlm_masks, position = data
275
+ image, target, sentences, attentions, mlm_targets, mlm_masks, position = image.cuda(non_blocking=True),\
276
+ target.cuda(non_blocking=True),\
277
+ sentences.cuda(non_blocking=True),\
278
+ attentions.cuda(non_blocking=True), \
279
+ mlm_targets.cuda(non_blocking=True), \
280
+ mlm_masks.cuda(non_blocking=True), \
281
+ position.cuda(non_blocking=True)
282
+
283
+ sentences = sentences.squeeze(1)
284
+ attentions = attentions.squeeze(1)
285
+ #print("sentences", sentences.shape)
286
+ #print("attentions", attentions.shape)
287
+
288
+
289
+ output, mask_features, avg_lang_feature, mlm_predictions, mlm_targets = model(image, sentences, attentions, mlm_targets, mlm_masks, position)
290
+ mask_cls_results = output["pred_logits"]
291
+ mask_pred_results = output["pred_masks"]
292
+
293
+ target_shape = target.shape[-2:]
294
+ mask_pred_results = F.interpolate(mask_pred_results, size=target_shape, mode='bilinear', align_corners=True)
295
+
296
+ pred_masks = model.module.semantic_inference(mask_cls_results, mask_pred_results)
297
+ output = pred_masks[0]
298
+
299
+
300
+ iou, I, U = IoU(output, target)
301
+ acc_ious += iou
302
+ mean_IoU.append(iou)
303
+ cum_I += I
304
+ cum_U += U
305
+ for n_eval_iou in range(len(eval_seg_iou_list)):
306
+ eval_seg_iou = eval_seg_iou_list[n_eval_iou]
307
+ seg_correct[n_eval_iou] += (iou >= eval_seg_iou)
308
+ seg_total += 1
309
+ iou = acc_ious / total_its
310
+
311
+ mean_IoU = np.array(mean_IoU)
312
+ mIoU = np.mean(mean_IoU)
313
+ print('Final results:')
314
+ print('Mean IoU is %.2f\n' % (mIoU * 100.))
315
+ results_str = ''
316
+ for n_eval_iou in range(len(eval_seg_iou_list)):
317
+ results_str += ' precision@%s = %.2f\n' % \
318
+ (str(eval_seg_iou_list[n_eval_iou]), seg_correct[n_eval_iou] * 100. / seg_total)
319
+ results_str += ' overall IoU = %.2f\n' % (cum_I * 100. / cum_U)
320
+ print(results_str)
321
+
322
+ return 100 * iou, 100 * cum_I / cum_U
323
+
324
+
325
+ def train_one_epoch(model, criterion, optimizer, data_loader, lr_scheduler, epoch, print_freq,
326
+ iterations, args):
327
+ model.train()
328
+ metric_logger = utils.MetricLogger(delimiter=" ")
329
+ metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value}'))
330
+ header = 'Epoch: [{}]'.format(epoch)
331
+ train_loss = 0
332
+ total_its = 0
333
+
334
+ for data in metric_logger.log_every(data_loader, print_freq, header):
335
+ total_its += 1
336
+ #image, target, sentences, attentions = data
337
+ #image, target, sentences, attentions = image.cuda(non_blocking=True),\
338
+ # target.cuda(non_blocking=True),\
339
+ # sentences.cuda(non_blocking=True),\
340
+ # attentions.cuda(non_blocking=True)
341
+ image, target, sentences, attentions, mlm_targets, mlm_masks, position = data
342
+ image, target, sentences, attentions, mlm_targets, mlm_masks, position = image.cuda(non_blocking=True),\
343
+ target.cuda(non_blocking=True),\
344
+ sentences.cuda(non_blocking=True),\
345
+ attentions.cuda(non_blocking=True), \
346
+ mlm_targets.cuda(non_blocking=True), \
347
+ mlm_masks.cuda(non_blocking=True), \
348
+ position.cuda(non_blocking=True)
349
+
350
+ sentences = sentences.squeeze(1)
351
+ attentions = attentions.squeeze(1)
352
+ #l_mask = attentions.unsqueeze(dim=-1)
353
+
354
+ output, mask_features, avg_lang_feature, mlm_predictions, mlm_targets = model(image, sentences, attentions, mlm_targets, mlm_masks, position)
355
+ #print(avg_lang_feature.shape)
356
+ avg_lang_feature = torch.nn.functional.normalize(avg_lang_feature, dim=1)
357
+ #print("----")
358
+ #print(output.shape)
359
+ #print(mask_features.shape)
360
+ #print(avg_lang_feature.shape)
361
+ #print( mlm_predictions.shape)
362
+ #print(mlm_targets.shape)
363
+ #print("----")
364
+
365
+ target_shape = target.shape[-2:]
366
+ output['pred_masks'] = F.interpolate(output['pred_masks'], size=target_shape, mode='bilinear', align_corners=True)
367
+
368
+ if "aux_outputs" in output:
369
+ for i, aux_outputs in enumerate(output["aux_outputs"]):
370
+ output['aux_outputs'][i]['pred_masks'] = F.interpolate(output['aux_outputs'][i]['pred_masks'], size=target_shape, mode='bilinear', align_corners=True)
371
+
372
+ # pixel region
373
+ B, C, H, W = mask_features.shape
374
+
375
+ target_reshape = F.interpolate(target.unsqueeze(1).float(), size=mask_features.shape[-2:], mode='nearest').long()
376
+
377
+ target_reshape = target_reshape.repeat(1, mask_features.shape[1], 1, 1)
378
+ #print(avg_pos_feature.shape, avg_lang_feature.shape, avg_neg_feature.shape)
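+ # PLIC losses (named after the plic_* args; a summary of the code below): for every sample
+ # containing both foreground and background pixels, gather normalized pixel features from
+ # mask_features via the downsampled ground truth, build language-to-positive,
+ # positive-vs-negative, and negative-vs-positive similarity matrices, and apply a focal-style
+ # weighted cross-entropy in which index 0 (the matching pair) is the target class.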
379
+
380
+ #cl_loss = 0.0
381
+ plic_lang_loss = 0.0
382
+ plic_pos_loss = 0.0
383
+ plic_neg_loss = 0.0
384
+ for i in range(B):
385
+ if ((target_reshape[[i]] == 0).sum() != 0 and (target_reshape[[i]] == 1).sum() != 0):
386
+
387
+ avg_pos_feature = (mask_features[[i]] * target_reshape[[i]]).sum(-1).sum(-1) / target_reshape[[i]].sum(-1).sum(-1)
388
+ avg_neg_feature = (mask_features[[i]] * (1.0-target_reshape[[i]])).sum(-1).sum(-1) / (1.0-target_reshape[[i]]).sum(-1).sum(-1)
389
+ avg_pos_feature = torch.nn.functional.normalize(avg_pos_feature, dim=1)
390
+ avg_neg_feature = torch.nn.functional.normalize(avg_neg_feature, dim=1)
391
+
392
+ #avg lang feature no normalize???
393
+
394
+
395
+
396
+ pos_features = mask_features[[i]][target_reshape[[i]]==1].view(1, C, -1)
397
+ neg_features = mask_features[[i]][target_reshape[[i]]==0].view(1, C, -1)
398
+ #inter_neg_features = mask_features[[B-i-1]][target_reshape[[B-i-1]]==1].view(1, C, -1)
399
+ #neg_features = torch.cat([intra_neg_features, inter_neg_features], dim=2)
400
+
401
+ pos_features = torch.nn.functional.normalize(pos_features, dim=1)
402
+ neg_features = torch.nn.functional.normalize(neg_features, dim=1)
403
+
404
+ #print(avg_lang_feature.shape, avg_lang_feature[[i]].shape, pos_features.shape)
405
+ lang_pos_scores = torch.einsum("bq,bqn->bn", avg_lang_feature[[i]], pos_features)
406
+ lang_neg_scores = torch.einsum("bq,bqn->bn", avg_lang_feature[[i]], neg_features)
407
+
408
+ lang_matrix = torch.cat([lang_pos_scores.unsqueeze(-1), lang_neg_scores.unsqueeze(1).repeat(1, lang_pos_scores.shape[1], 1)], dim=2)
409
+ lang_labels = torch.zeros(lang_matrix.shape[1], dtype=torch.long).cuda()
410
+ lang_labels = lang_labels.unsqueeze(0).repeat(lang_matrix.shape[0], 1)
411
+
412
+ lang_score = torch.softmax(lang_matrix, -1)
413
+ lang_score = 1.0 - lang_score[:, :, 0]
414
+
415
+ pos_pos_scores = torch.einsum("bq,bqn->bn", avg_pos_feature, pos_features)
416
+ pos_neg_scores = torch.einsum("bqn,bqm->bnm", pos_features, neg_features)
417
+
418
+ pos_matrix = torch.cat([pos_pos_scores.unsqueeze(-1), pos_neg_scores], dim=2)
419
+ pos_labels = torch.zeros(pos_matrix.shape[1], dtype=torch.long).cuda()
420
+ pos_labels = pos_labels.unsqueeze(0).repeat(pos_matrix.shape[0], 1)
421
+
422
+ pos_score = torch.softmax(pos_matrix, -1)
423
+ pos_score = 1.0 - pos_score[:, :, 0]
424
+ #pos_weight = pos_weight.view(-1, pos_weight.shape[-1])
425
+
426
+ #intra_neg_features = torch.nn.functional.normalize(intra_neg_features, dim=1)
427
+ neg_neg_scores = torch.einsum("bq,bqn->bn", avg_neg_feature, neg_features)
428
+ neg_pos_scores = torch.einsum("bqn,bqm->bnm", neg_features, pos_features)
429
+
430
+ neg_matrix = torch.cat([neg_neg_scores.unsqueeze(-1), neg_pos_scores], dim=2)
431
+ neg_labels = torch.zeros(neg_matrix.shape[1], dtype=torch.long).cuda()
432
+ neg_labels = neg_labels.unsqueeze(0).repeat(neg_matrix.shape[0], 1)
433
+
434
+ neg_score = torch.softmax(neg_matrix, -1)
435
+ neg_score = 1.0 - neg_score[:, :, 0]
436
+ #neg_weight = neg_weight.view(-1, neg_weight.shape[-1])
437
+
438
+ pos_loss = (torch.pow(pos_score, args.plic_pos_alpha) * torch.nn.functional.cross_entropy(pos_matrix.view(-1, pos_matrix.shape[-1])/args.plic_pos_temp, pos_labels.view(-1), reduction='none')).mean()
439
+ neg_loss = (torch.pow(neg_score, args.plic_neg_alpha) * torch.nn.functional.cross_entropy(neg_matrix.view(-1, neg_matrix.shape[-1])/args.plic_neg_temp, neg_labels.view(-1), reduction='none')).mean()
440
+
441
+ lang_loss = (torch.pow(lang_score, args.plic_lang_alpha) * torch.nn.functional.cross_entropy(lang_matrix.view(-1, lang_matrix.shape[-1])/args.plic_lang_temp, lang_labels.view(-1), reduction='none')).mean()
442
+
443
+ plic_pos_loss += pos_loss
444
+ plic_neg_loss += neg_loss
445
+ plic_lang_loss += lang_loss
446
+ #cl_loss += 0.5 * (torch.nn.functional.cross_entropy(pos_matrix.view(-1, pos_matrix.shape[-1])/cl_temp, pos_labels.view(-1))+torch.nn.functional.cross_entropy(neg_matrix.view(-1, neg_matrix.shape[-1])/cl_temp, neg_labels.view(-1)))
447
+ plic_pos_loss = (args.plic_pos_weight * plic_pos_loss) / B
448
+ plic_neg_loss = (args.plic_neg_weight * plic_neg_loss) / B
449
+ plic_lang_loss = (args.plic_lang_weight * plic_lang_loss) / B
450
+ plic_loss = plic_pos_loss + plic_neg_loss + plic_lang_loss
451
+
452
+
453
+ #print(output.device, target.device)
454
+ losses = criterion(output, target)
455
+ weight_dict = criterion.weight_dict
456
+
457
+ loss_ce = 0.0
458
+ loss_dice = 0.0
459
+ loss_mask = 0.0
460
+ for k in list(losses.keys()):
461
+ if k in weight_dict:
462
+ losses[k] *= criterion.weight_dict[k]
463
+ if '_ce' in k:
464
+ loss_ce += losses[k]
465
+ elif '_dice' in k:
466
+ loss_dice += losses[k]
467
+ else:
468
+ loss_mask += losses[k]
469
+ else:
470
+ # remove this loss if not specified in `weight_dict`
471
+ losses.pop(k)
472
+ #loss = 0.3 * loss_ce + 0.3 * loss_dice + 0.4 * loss_mask
473
+ smlm_loss = args.smlm_weight * nn.CrossEntropyLoss(ignore_index=-1)(mlm_predictions, mlm_targets)
474
+ loss = loss_ce + loss_dice + loss_mask + plic_loss + smlm_loss
475
+
476
+
477
+ #loss = criterion(output.squeeze(1), target.float())
478
+ optimizer.zero_grad() # set_to_none=True is only available in pytorch 1.6+
479
+ loss.backward()
480
+ optimizer.step()
481
+ lr_scheduler.step()
482
+
483
+ torch.cuda.synchronize()
484
+ train_loss += loss.item()
485
+ iterations += 1
486
+ #metric_logger.update(loss=loss.item(), lr=optimizer.param_groups[0]["lr"])
487
+ metric_logger.update(loss=loss.item(), lr=optimizer.param_groups[0]["lr"], loss_ce=loss_ce.item(), loss_dice=loss_dice.item(), loss_mask=loss_mask.item(), plic_loss=plic_loss.item(), plic_lang_loss=plic_lang_loss.item(), plic_pos_loss=plic_pos_loss.item(), plic_neg_loss=plic_neg_loss.item(), smlm_loss=smlm_loss.item())
488
+ #metric_logger.update(loss=loss.item(), lr=optimizer.param_groups[0]["lr"], loss_ce=loss_ce.item(), loss_dice=loss_dice.item(), loss_mask=loss_mask.item(), cl_loss=cl_loss.item(), cl_lang_loss=cl_lang_loss_print, cl_pos_loss=cl_pos_loss_print, cl_neg_loss=cl_neg_loss_print)
489
+
490
+ #del image, target, sentences, attentions, loss, output, data
491
+ #if bert_model is not None:
492
+ # del last_hidden_states, embedding
493
+
494
+ #gc.collect()
495
+ #torch.cuda.empty_cache()
496
+ #del loss
497
+ #del cl_loss
498
+ #del cl_lang_loss
499
+ #del loss_ce
500
+ #del loss_dice
501
+ #del loss_mask
502
+ torch.cuda.synchronize()
503
+
504
+
505
+ def main(args):
506
+ #def main(local_rank, args):
507
+ #ip = os.environ['MASTER_IP']
508
+ #port = os.environ['MASTER_PORT']
509
+ #hosts = int(os.environ['WORLD_SIZE']) # number of machines (1)
510
+ #rank = int(os.environ['RANK']) # index of the current machine
511
+ #gpus = torch.cuda.device_count() # number of GPUs per machine
512
+ #print(local_rank, rank, gpus) #3 0 8
513
+ #dist.init_process_group(backend='nccl', init_method=f'tcp://{ip}:{port}', world_size=hosts*gpus, rank=rank*gpus+local_rank)
514
+ #torch.cuda.set_device(local_rank)
515
+ #dist.barrier()
516
+
517
+ ##utils.init_distributed_mode(args)
518
+ #args.distributed=True
519
+ #args.gpu = local_rank
520
+ #print(args)
521
+ ##misc.init_distributed_mode(args)
522
+
523
+ #print('job dir: {}'.format(os.path.dirname(os.path.realpath(__file__))))
524
+ #print("{}".format(args).replace(', ', ',\n'))
525
+
526
+ #device = torch.device(args.device)
527
+
528
+ # fix the seed for reproducibility
529
+ seed = args.seed + utils.get_rank()
530
+ print('seed', seed)
531
+ torch.manual_seed(seed)
532
+ np.random.seed(seed)
533
+
534
+ #cudnn.benchmark = True
535
+
536
+ dataset, num_classes = get_dataset("train",
537
+ get_transform(args=args),
538
+ args=args)
539
+ dataset_test, _ = get_dataset("val",
540
+ get_transform(args=args),
541
+ args=args)
542
+
543
+ # batch sampler
544
+ print(f"local rank {args.local_rank} / global rank {utils.get_rank()} successfully built train dataset.")
545
+ num_tasks = utils.get_world_size()
546
+ global_rank = utils.get_rank()
547
+ #num_tasks = hosts*gpus
548
+ #global_rank = rank*gpus+local_rank
549
+ train_sampler = torch.utils.data.distributed.DistributedSampler(dataset, num_replicas=num_tasks, rank=global_rank,
550
+ shuffle=True)
551
+ test_sampler = torch.utils.data.SequentialSampler(dataset_test)
552
+
553
+ # data loader
554
+ data_loader = torch.utils.data.DataLoader(
555
+ dataset, batch_size=args.batch_size,
556
+ sampler=train_sampler, num_workers=args.workers, pin_memory=True, drop_last=True)
557
+
558
+ data_loader_test = torch.utils.data.DataLoader(
559
+ dataset_test, batch_size=1, sampler=test_sampler, pin_memory=True, num_workers=args.workers)
560
+
561
+ # model initialization
562
+ print(args.model)
563
+ model = multimodal_segmentation_ppm.__dict__[args.model](pretrained=args.pretrained_swin_weights,
564
+ args=args)
565
+ model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
566
+ #model.cuda()
567
+ #model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[local_rank], find_unused_parameters=True)
568
+ #model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank], find_unused_parameters=False)
569
+ #single_model = model.module
570
+
571
+ if args.model != 'lavt_one':
572
+ model_class = MultiModalBert
573
+ bert_model = model_class.from_pretrained(args.ck_bert, embed_dim=model.backbone.embed_dim)
574
+ bert_model.pooler = None # a work-around for a bug in Transformers = 3.0.2 that appears for DistributedDataParallel
575
+ #bert_model.cuda()
576
+ bert_model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(bert_model)
577
+ #bert_model = torch.nn.parallel.DistributedDataParallel(bert_model, device_ids=[local_rank])
578
+ #single_bert_model = bert_model.module
579
+ else:
580
+ bert_model = None
581
+ single_bert_model = None
582
+
583
+ input_shape = dict()
584
+ input_shape['s1'] = Dict({'channel': 128, 'stride': 4})
585
+ input_shape['s2'] = Dict({'channel': 256, 'stride': 8})
586
+ input_shape['s3'] = Dict({'channel': 512, 'stride': 16})
587
+ input_shape['s4'] = Dict({'channel': 1024, 'stride': 32})
588
+
589
+
590
+
591
+ cfg = Dict()
592
+ cfg.MODEL.SEM_SEG_HEAD.COMMON_STRIDE = 4
593
+ cfg.MODEL.MASK_FORMER.DROPOUT = 0.0
594
+ cfg.MODEL.MASK_FORMER.NHEADS = 8
595
+ cfg.MODEL.SEM_SEG_HEAD.TRANSFORMER_ENC_LAYERS = args.transformer_enc_layers
596
+ cfg.MODEL.SEM_SEG_HEAD.CONVS_DIM = 256
597
+ cfg.MODEL.SEM_SEG_HEAD.MASK_DIM = 256
598
+ cfg.MODEL.SEM_SEG_HEAD.DEFORMABLE_TRANSFORMER_ENCODER_IN_FEATURES = ["s1", "s2", "s3", "s4"]
599
+
600
+ cfg.MODEL.SEM_SEG_HEAD.NUM_CLASSES = 1
601
+ cfg.MODEL.MASK_FORMER.HIDDEN_DIM = 256
602
+ cfg.MODEL.MASK_FORMER.NUM_OBJECT_QUERIES = args.num_object_queries
603
+ cfg.MODEL.MASK_FORMER.DIM_FEEDFORWARD = args.dim_feedforward
604
+ cfg.MODEL.MASK_FORMER.DEC_LAYERS = args.dec_layers
605
+ cfg.MODEL.MASK_FORMER.PRE_NORM = False
606
+
607
+ cfg.MODEL.MASK_FORMER.DEEP_SUPERVISION = True
608
+ cfg.MODEL.MASK_FORMER.NO_OBJECT_WEIGHT = args.no_object_weight
609
+ cfg.MODEL.MASK_FORMER.CLASS_WEIGHT = args.class_weight
610
+ cfg.MODEL.MASK_FORMER.DICE_WEIGHT = args.dice_weight
611
+ cfg.MODEL.MASK_FORMER.MASK_WEIGHT = args.mask_weight
612
+
613
+ cfg.MODEL.MASK_FORMER.TRAIN_NUM_POINTS = args.train_num_points
614
+ cfg.MODEL.MASK_FORMER.OVERSAMPLE_RATIO = 3.0
615
+ cfg.MODEL.MASK_FORMER.IMPORTANCE_SAMPLE_RATIO = 0.75
616
+ print(cfg)
617
+
618
+ maskformer_head = MaskFormerHead(cfg, input_shape)
619
+ maskformer_head = torch.nn.SyncBatchNorm.convert_sync_batchnorm(maskformer_head)
620
+ #maskformer_head.cuda()
621
+ #maskformer_head = torch.nn.parallel.DistributedDataParallel(maskformer_head, device_ids=[args.local_rank], find_unused_parameters=False)
622
+ #single_head = maskformer_head.module
623
+ #print(single_head)
624
+
625
+ model = WrapperModel(model.backbone, bert_model, maskformer_head, args)
626
+ model.cuda()
627
+ model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank], find_unused_parameters=True)
628
+ single_model = model.module
629
+
630
+ # mask2former loss
631
+ deep_supervision = cfg.MODEL.MASK_FORMER.DEEP_SUPERVISION
632
+ no_object_weight = cfg.MODEL.MASK_FORMER.NO_OBJECT_WEIGHT
633
+
634
+ # loss weights
635
+ class_weight = cfg.MODEL.MASK_FORMER.CLASS_WEIGHT
636
+ dice_weight = cfg.MODEL.MASK_FORMER.DICE_WEIGHT
637
+ mask_weight = cfg.MODEL.MASK_FORMER.MASK_WEIGHT
638
+ # self.criterion = Criterion(self.num_classes)
639
+
640
+ # building criterion
641
+
642
+ matcher = HungarianMatcher(
643
+ cost_class=class_weight,
644
+ cost_mask=mask_weight,
645
+ cost_dice=dice_weight,
646
+ num_points=cfg.MODEL.MASK_FORMER.TRAIN_NUM_POINTS,
647
+ )
648
+
649
+ weight_dict = {"loss_ce": class_weight, "loss_mask": mask_weight, "loss_dice": dice_weight}
650
+ if deep_supervision:
651
+ dec_layers = cfg.MODEL.MASK_FORMER.DEC_LAYERS
652
+ aux_weight_dict = {}
653
+ for i in range(dec_layers - 1):
654
+ aux_weight_dict.update({k + f"_{i}": v for k, v in weight_dict.items()})
655
+ weight_dict.update(aux_weight_dict)
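+ # With deep supervision, every intermediate decoder layer gets its own copy of the loss
+ # weights (keys suffixed _0 .. _{dec_layers-2}); the training loop scales each returned loss
+ # by its entry in weight_dict, so auxiliary losses are weighted like the final-layer losses.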
656
+
657
+ losses = ["labels", "masks"]
658
+ criterion = SetCriterion(
659
+ cfg.MODEL.SEM_SEG_HEAD.NUM_CLASSES,
660
+ matcher=matcher,
661
+ weight_dict=weight_dict,
662
+ eos_coef=no_object_weight,
663
+ losses=losses,
664
+ num_points=cfg.MODEL.MASK_FORMER.TRAIN_NUM_POINTS,
665
+ oversample_ratio=cfg.MODEL.MASK_FORMER.OVERSAMPLE_RATIO,
666
+ importance_sample_ratio=cfg.MODEL.MASK_FORMER.IMPORTANCE_SAMPLE_RATIO,
667
+ device='cuda'
668
+ )
669
+
670
+ if args.resume == "auto":
671
+ last_ckpt = ""
672
+ for e in range(args.epochs):
673
+ ckpt_path = os.path.join(args.output_dir, f'checkpoint-{e}.pth')
674
+ if os.path.exists(ckpt_path):
675
+ last_ckpt = ckpt_path
676
+ args.resume = last_ckpt
677
+
678
+ # resume training
679
+ if args.resume:
680
+ checkpoint = torch.load(args.resume, map_location='cpu')
681
+ single_model.load_state_dict(checkpoint['model'])
682
+ #if args.model != 'lavt_one':
683
+ # single_bert_model.load_state_dict(checkpoint['bert_model'])
684
+
685
+ # parameters to optimize
686
+ backbone_no_decay = list()
687
+ backbone_decay = list()
688
+ for name, m in single_model.image_model.named_parameters():
689
+ if 'norm' in name or 'absolute_pos_embed' in name or 'relative_position_bias_table' in name:
690
+ backbone_no_decay.append(m)
691
+ else:
692
+ backbone_decay.append(m)
693
+
694
+ params_to_optimize = [
695
+ {'params': backbone_no_decay, 'weight_decay': 0.0},
696
+ {'params': backbone_decay},
697
+ {"params": [p for p in single_model.classifier.parameters() if p.requires_grad]},
698
+ # the following are the parameters of bert
699
+ {"params": reduce(operator.concat,
700
+ [[p for p in single_model.language_model.encoder.layer[i].parameters()
701
+ if p.requires_grad] for i in range(10)])},
702
+ {"params": single_model.language_model.pwams.parameters()},
703
+ {"params": single_model.language_model.res_gates.parameters()},
704
+ {"params": single_model.language_model.norms.parameters()},
705
+ {"params": single_model.lang_proj.parameters()},
706
+ #{"params": single_model.language_model.parameters()},
707
+ {'params': single_model.mlm_head.parameters()},
708
+ {'params': single_model.mlm_vis_proj.parameters()},
709
+ {'params': single_model.mlm_lang_proj.parameters()},
710
+ {'params': single_model.mlm_transformer.parameters()},
711
+ {'params': single_model.mlm_pos_embeds.parameters()},
712
+ {'params': single_model.mlm_modal_embeds.parameters()},
713
+ {'params': single_model.mlm_mask_embed.parameters()},
714
+ {'params': single_model.mlm_pos_mlp.parameters()},
715
+ #{'params': mlm_head.parameters(), 'weight_decay': 0.0},
716
+ #{'params': mlm_vis_proj.parameters(), 'weight_decay': 0.0},
717
+ #{'params': mlm_lang_proj.parameters(), 'weight_decay': 0.0},
718
+ #{'params': mlm_transformer.parameters(), 'weight_decay': 0.0},
719
+ #{'params': mlm_pos_embeds.parameters(), 'weight_decay': 0.0},
720
+ #{'params': mlm_modal_embeds.parameters(), 'weight_decay': 0.0},
721
+ #{'params': mlm_mask_embed.parameters(), 'weight_decay': 0.0},
722
+ #{'params': mlm_pos_mlp.parameters(), 'weight_decay': 0.0},
723
+ ]
724
+
725
+
726
+ # optimizer
727
+ optimizer = torch.optim.AdamW(params_to_optimize,
728
+ lr=args.lr,
729
+ weight_decay=args.weight_decay,
730
+ amsgrad=args.amsgrad
731
+ )
732
+
733
+ # learning rate scheduler
734
+ lr_scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer,
735
+ lambda x: (1 - x / (len(data_loader) * args.epochs)) ** 0.9)
736
+
737
+ # housekeeping
738
+ start_time = time.time()
739
+ iterations = 0
740
+ best_oIoU = -0.1
741
+
742
+ # resume training (optimizer, lr scheduler, and the epoch)
743
+ if args.resume:
744
+ optimizer.load_state_dict(checkpoint['optimizer'])
745
+ lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
746
+ resume_epoch = checkpoint['epoch']
747
+ else:
748
+ resume_epoch = -999
749
+
750
+ # training loops
751
+ for epoch in range(max(0, resume_epoch+1), args.epochs):
752
+ data_loader.sampler.set_epoch(epoch)
753
+ train_one_epoch(model, criterion, optimizer, data_loader, lr_scheduler, epoch, args.print_freq,
754
+ iterations, args)
755
+ iou, overallIoU = evaluate(model, data_loader_test)
756
+
757
+ print('Average object IoU {}'.format(iou))
758
+ print('Overall IoU {}'.format(overallIoU))
759
+
760
+
761
+ dict_to_save = {'model': single_model.state_dict(),
762
+ 'optimizer': optimizer.state_dict(), 'epoch': epoch, 'args': args,
763
+ 'lr_scheduler': lr_scheduler.state_dict()}
764
+
765
+ checkpoint_path = os.path.join(args.output_dir, 'checkpoint-{}.pth'.format(epoch))
766
+ utils.save_on_master(dict_to_save, str(checkpoint_path) + '_TEMP')
767
+ if utils.is_main_process():
768
+ os.rename(str(checkpoint_path) + '_TEMP', str(checkpoint_path))
769
+
770
+ if utils.is_main_process():
771
+ ckpt_paths = []
772
+ for e in range(args.epochs):
773
+ ckpt_path = os.path.join(args.output_dir, f'checkpoint-{e}.pth')
774
+ print(ckpt_path)
775
+ if os.path.exists(ckpt_path):
776
+ ckpt_paths.append(ckpt_path)
777
+ print(ckpt_paths)
778
+ for ckpt_path in ckpt_paths[:-args.max_ckpt]:
779
+ os.remove(ckpt_path)
780
+ print("remove {:s}".format(ckpt_path))
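Each checkpoint is first written to a '*_TEMP' file and only renamed into place on the main process, so an interrupted write never leaves a truncated checkpoint-<epoch>.pth behind; the loop above then keeps only the newest args.max_ckpt epoch checkpoints. A standalone sketch of the same keep-last-N policy (the helper itself is hypothetical, the paths mirror this script):

import os

def prune_checkpoints(output_dir, epochs, max_ckpt):
    # keep only the newest `max_ckpt` epoch checkpoints (assumes max_ckpt >= 1)
    existing = [os.path.join(output_dir, f'checkpoint-{e}.pth') for e in range(epochs)]
    existing = [p for p in existing if os.path.exists(p)]
    for path in existing[:-max_ckpt]:            # everything except the newest max_ckpt files
        os.remove(path)
        print(f'removed {path}')

# e.g. prune_checkpoints(args.output_dir, args.epochs, args.max_ckpt) on the main process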
781
+
782
+
783
+ save_checkpoint = (best_oIoU < overallIoU)
784
+ if save_checkpoint:
785
+ print('Better epoch: {}\n'.format(epoch))
786
+ dict_to_save = {'model': single_model.state_dict(),
787
+ 'optimizer': optimizer.state_dict(), 'epoch': epoch, 'args': args,
788
+ 'lr_scheduler': lr_scheduler.state_dict()}
789
+
790
+ checkpoint_path = os.path.join(args.output_dir, 'model_best_{}.pth'.format(args.model_id))
791
+ utils.save_on_master(dict_to_save, checkpoint_path + '_TEMP')
792
+ if utils.is_main_process():
793
+ os.rename(str(checkpoint_path) + '_TEMP', str(checkpoint_path))
794
+ best_oIoU = overallIoU
795
+ torch.cuda.empty_cache()
796
+
797
+ # summarize
798
+ total_time = time.time() - start_time
799
+ total_time_str = str(datetime.timedelta(seconds=int(total_time)))
800
+ print('Training time {}'.format(total_time_str))
801
+
802
+
803
+ if __name__ == "__main__":
804
+ from args import get_parser
805
+ parser = get_parser()
806
+ args = parser.parse_args()
807
+ os.makedirs(args.output_dir, exist_ok=True)
808
+ # set up distributed learning
809
+ utils.init_distributed_mode(args)
810
+ print('Image size: {}'.format(str(args.img_size)))
811
+ main(args)
812
+ #mp.spawn(main, args=(args,), nprocs=torch.cuda.device_count())
elia/train_lavt.py ADDED
@@ -0,0 +1,444 @@
1
+ import haienv
2
+ haienv.set_env('lavt2')
3
+ import torch.multiprocessing as mp
4
+ import torch.distributed as dist
5
+
6
+ import datetime
7
+ import os
8
+ import time
9
+
10
+ import torch
11
+ import torch.utils.data
12
+ from torch import nn
13
+
14
+ from functools import reduce
15
+ import operator
16
+ from bert.modeling_bert import BertModel
17
+
18
+ import torchvision
19
+ from lib import segmentation
20
+
21
+ import transforms as T
22
+ import utils
23
+ import numpy as np
24
+
25
+ import torch.nn.functional as F
26
+
27
+ import gc
28
+ from collections import OrderedDict
29
+
30
+ import torch.backends.cudnn as cudnn
31
+
32
+ from ffrecord.torch import DataLoader, Dataset
33
+ def get_dataset(image_set, transform, args):
34
+ from data.dataset_refer_bert import ReferDataset
35
+ ds = ReferDataset(args,
36
+ split=image_set,
37
+ image_transforms=transform,
38
+ target_transforms=None
39
+ )
40
+ num_classes = 2
41
+
42
+ return ds, num_classes
43
+
44
+
45
+ # IoU calculation for validation
46
+ def IoU(pred, gt):
47
+ pred = pred.argmax(1)
48
+
49
+ intersection = torch.sum(torch.mul(pred, gt))
50
+ union = torch.sum(torch.add(pred, gt)) - intersection
51
+
52
+ if intersection == 0 or union == 0:
53
+ iou = 0
54
+ else:
55
+ iou = float(intersection) / float(union)
56
+
57
+ return iou, intersection, union
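IoU is computed on the argmax of the two-channel logits, treated as a binary mask: the intersection is the sum of the elementwise product with the ground truth and the union is sum(pred) + sum(gt) minus that intersection. A quick toy check of the arithmetic:

import torch

# logits: (1, 2, 2, 2) = (batch, classes, H, W); gt: (1, 2, 2) binary mask
pred = torch.tensor([[[[0.1, 0.9], [0.9, 0.1]],     # background scores
                      [[0.9, 0.1], [0.1, 0.9]]]])   # foreground scores
gt = torch.tensor([[[1, 0], [0, 0]]])

mask = pred.argmax(1)                   # [[[1, 0], [0, 1]]]
inter = torch.sum(mask * gt)            # 1 pixel overlaps
union = torch.sum(mask + gt) - inter    # 2 + 1 - 1 = 2
print(float(inter) / float(union))      # 0.5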
58
+
59
+
60
+ def get_transform(args):
61
+ transforms = [T.Resize(args.img_size, args.img_size),
62
+ T.ToTensor(),
63
+ T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
64
+ ]
65
+
66
+ return T.Compose(transforms)
67
+
68
+
69
+ def criterion(input, target):
70
+ weight = torch.FloatTensor([0.9, 1.1]).cuda()
71
+ return nn.functional.cross_entropy(input, target, weight=weight)
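The loss is a per-pixel two-class cross entropy whose class weights [0.9, 1.1] up-weight foreground pixels (class 1) slightly, compensating for the background-heavy masks typical of referring segmentation. A minimal CPU example of the same call with illustrative shapes:

import torch
from torch import nn

logits = torch.randn(2, 2, 4, 4)              # (batch, classes, H, W)
target = torch.randint(0, 2, (2, 4, 4))       # per-pixel labels in {0, 1}
weight = torch.FloatTensor([0.9, 1.1])        # background weight, foreground weight
loss = nn.functional.cross_entropy(logits, target, weight=weight)
print(loss.item())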
72
+
73
+
74
+ def evaluate(model, data_loader, bert_model):
75
+ model.eval()
76
+ metric_logger = utils.MetricLogger(delimiter=" ")
77
+ header = 'Test:'
78
+ total_its = 0
79
+ acc_ious = 0
80
+
81
+ # evaluation variables
82
+ cum_I, cum_U = 0, 0
83
+ eval_seg_iou_list = [.5, .6, .7, .8, .9]
84
+ seg_correct = np.zeros(len(eval_seg_iou_list), dtype=np.int32)
85
+ seg_total = 0
86
+ mean_IoU = []
87
+
88
+ with torch.no_grad():
89
+ for data in metric_logger.log_every(data_loader, 100, header):
90
+ total_its += 1
91
+ image, target, sentences, attentions = data
92
+ image, target, sentences, attentions = image.cuda(non_blocking=True),\
93
+ target.cuda(non_blocking=True),\
94
+ sentences.cuda(non_blocking=True),\
95
+ attentions.cuda(non_blocking=True)
96
+
97
+ sentences = sentences.squeeze(1)
98
+ attentions = attentions.squeeze(1)
99
+ #print("sentences", sentences.shape)
100
+ #print("attentions", attentions.shape)
101
+
102
+ if bert_model is not None:
103
+ last_hidden_states = bert_model(sentences, attention_mask=attentions)[0]
104
+ #print("last hidden states", last_hidden_states.shape)
105
+ embedding = last_hidden_states.permute(0, 2, 1) # (B, 768, N_l) to make Conv1d happy
106
+ attentions = attentions.unsqueeze(dim=-1) # (B, N_l, 1)
107
+ output = model(image, embedding, l_mask=attentions)
108
+ else:
109
+ output = model(image, sentences, l_mask=attentions)
110
+
111
+ iou, I, U = IoU(output, target)
112
+ acc_ious += iou
113
+ mean_IoU.append(iou)
114
+ cum_I += I
115
+ cum_U += U
116
+ for n_eval_iou in range(len(eval_seg_iou_list)):
117
+ eval_seg_iou = eval_seg_iou_list[n_eval_iou]
118
+ seg_correct[n_eval_iou] += (iou >= eval_seg_iou)
119
+ seg_total += 1
120
+ iou = acc_ious / total_its
121
+
122
+ mean_IoU = np.array(mean_IoU)
123
+ mIoU = np.mean(mean_IoU)
124
+ print('Final results:')
125
+ print('Mean IoU is %.2f\n' % (mIoU * 100.))
126
+ results_str = ''
127
+ for n_eval_iou in range(len(eval_seg_iou_list)):
128
+ results_str += ' precision@%s = %.2f\n' % \
129
+ (str(eval_seg_iou_list[n_eval_iou]), seg_correct[n_eval_iou] * 100. / seg_total)
130
+ results_str += ' overall IoU = %.2f\n' % (cum_I * 100. / cum_U)
131
+ print(results_str)
132
+
133
+ return 100 * iou, 100 * cum_I / cum_U
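evaluate reports three quantities: mean IoU (the average of per-sample IoUs), precision@X (the fraction of samples whose IoU clears each threshold in {0.5, ..., 0.9}), and overall IoU (cumulative intersection over cumulative union across the whole split). A compact sketch of the same bookkeeping from per-sample (iou, I, U) triples — the numbers are invented:

import numpy as np

samples = [(0.85, 85, 100), (0.40, 20, 50), (0.95, 190, 200)]   # illustrative only
thresholds = [.5, .6, .7, .8, .9]

ious = np.array([s[0] for s in samples])
cum_I = sum(s[1] for s in samples)
cum_U = sum(s[2] for s in samples)

print('mIoU     = %.2f' % (ious.mean() * 100.))
for t in thresholds:
    print('prec@%.1f = %.2f' % (t, (ious >= t).mean() * 100.))
print('oIoU     = %.2f' % (cum_I * 100. / cum_U))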
134
+
135
+
136
+ def train_one_epoch(model, criterion, optimizer, data_loader, lr_scheduler, epoch, print_freq,
137
+ iterations, bert_model):
138
+ model.train()
139
+ metric_logger = utils.MetricLogger(delimiter=" ")
140
+ metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value}'))
141
+ header = 'Epoch: [{}]'.format(epoch)
142
+ train_loss = 0
143
+ total_its = 0
144
+
145
+ for data in metric_logger.log_every(data_loader, print_freq, header):
146
+ total_its += 1
147
+ image, target, sentences, attentions = data
148
+ image, target, sentences, attentions = image.cuda(non_blocking=True),\
149
+ target.cuda(non_blocking=True),\
150
+ sentences.cuda(non_blocking=True),\
151
+ attentions.cuda(non_blocking=True)
152
+
153
+ sentences = sentences.squeeze(1)
154
+ attentions = attentions.squeeze(1)
155
+ #print(sentences.shape, attentions.shape, target.shape)
156
+ #print(sentences)
157
+ #print('a', sentences.shape)
158
+ #print('b', attentions.shape)
159
+
160
+ if bert_model is not None:
161
+ last_hidden_states = bert_model(sentences, attention_mask=attentions)[0] # (6, 10, 768)
162
+ #print('c', last_hidden_states.shape)
163
+
164
+ embedding = last_hidden_states.permute(0, 2, 1) # (B, 768, N_l) to make Conv1d happy
165
+ #print('e', embedding.shape)
166
+ attentions = attentions.unsqueeze(dim=-1) # (batch, N_l, 1)
167
+ #print('f', attentions.shape)
168
+ output = model(image, embedding, l_mask=attentions)
169
+ else:
170
+ output = model(image, sentences, l_mask=attentions)
171
+
172
+ loss = criterion(output, target)
173
+ optimizer.zero_grad() # set_to_none=True is only available in pytorch 1.6+
174
+ loss.backward()
175
+ optimizer.step()
176
+ lr_scheduler.step()
177
+
178
+ torch.cuda.synchronize()
179
+ train_loss += loss.item()
180
+ iterations += 1
181
+ metric_logger.update(loss=loss.item(), lr=optimizer.param_groups[0]["lr"])
182
+
183
+ del image, target, sentences, attentions, loss, output, data
184
+ if bert_model is not None:
185
+ del last_hidden_states, embedding
186
+
187
+ #gc.collect()
188
+ #torch.cuda.empty_cache()
189
+ torch.cuda.synchronize()
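When an external BERT is used, its (B, N_l, 768) last hidden states are permuted to (B, 768, N_l) so downstream Conv1d-style projections can consume them, and the attention mask gets a trailing singleton dimension to become (B, N_l, 1). A shape-only sketch of that reshaping (batch and sequence sizes are arbitrary):

import torch

B, N_l, C = 2, 20, 768                          # arbitrary batch, token count, hidden size
last_hidden_states = torch.randn(B, N_l, C)     # stand-in for bert_model(...)[0]
attentions = torch.ones(B, N_l)                 # stand-in for the attention mask

embedding = last_hidden_states.permute(0, 2, 1)   # (B, 768, N_l), Conv1d-friendly layout
l_mask = attentions.unsqueeze(dim=-1)             # (B, N_l, 1)
print(embedding.shape, l_mask.shape)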
190
+
191
+
192
+ #def main(args):
193
+ def main(local_rank, args):
194
+ ip = os.environ['MASTER_IP']
195
+ port = os.environ['MASTER_PORT']
196
+ hosts = int(os.environ['WORLD_SIZE'])  # number of machines (1 here)
197
+ rank = int(os.environ['RANK'])  # index of the current machine
198
+ gpus = torch.cuda.device_count()  # number of GPUs per machine
199
+ print(local_rank, rank, gpus) #3 0 8
200
+ dist.init_process_group(backend='nccl', init_method=f'tcp://{ip}:{port}', world_size=hosts*gpus, rank=rank*gpus+local_rank)
201
+ torch.cuda.set_device(local_rank)
202
+ dist.barrier()
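Each spawned worker derives its global rank as rank * gpus + local_rank, where rank indexes the machine (taken from the RANK environment variable) and local_rank indexes the GPU within it, so the NCCL group spans hosts * gpus processes. A small sanity check of that arithmetic, independent of any actual cluster:

# hypothetical 2-machine, 4-GPU-per-machine job
hosts, gpus = 2, 4
world_size = hosts * gpus                        # 8 processes in the NCCL group

for rank in range(hosts):                        # machine index (env RANK)
    for local_rank in range(gpus):               # GPU index within a machine
        global_rank = rank * gpus + local_rank
        print(f'machine {rank}, gpu {local_rank} -> global rank {global_rank}/{world_size}')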
203
+
204
+ #utils.init_distributed_mode(args)
205
+ args.distributed=True
206
+ args.gpu = local_rank
207
+ print(args)
208
+ #misc.init_distributed_mode(args)
209
+
210
+ print('job dir: {}'.format(os.path.dirname(os.path.realpath(__file__))))
211
+ print("{}".format(args).replace(', ', ',\n'))
212
+
213
+ device = torch.device(args.device)
214
+
215
+ # fix the seed for reproducibility
216
+ seed = args.seed + utils.get_rank()
217
+ torch.manual_seed(seed)
218
+ np.random.seed(seed)
219
+
220
+ #cudnn.benchmark = True
221
+
222
+ dataset, num_classes = get_dataset("train",
223
+ get_transform(args=args),
224
+ args=args)
225
+ dataset_test, _ = get_dataset("val",
226
+ get_transform(args=args),
227
+ args=args)
228
+
229
+ # batch sampler
230
+ print(f"local rank {args.local_rank} / global rank {utils.get_rank()} successfully built train dataset.")
231
+ #num_tasks = utils.get_world_size()
232
+ #global_rank = utils.get_rank()
233
+ num_tasks = hosts*gpus
234
+ global_rank = rank*gpus+local_rank
235
+ train_sampler = torch.utils.data.distributed.DistributedSampler(dataset, num_replicas=num_tasks, rank=global_rank,
236
+ shuffle=True)
237
+ test_sampler = torch.utils.data.SequentialSampler(dataset_test)
238
+
239
+ # data loader
240
+ data_loader = DataLoader(
241
+ dataset, batch_size=args.batch_size,
242
+ sampler=train_sampler, num_workers=args.workers, pin_memory=True, drop_last=True)
243
+
244
+ data_loader_test = DataLoader(
245
+ dataset_test, batch_size=1, sampler=test_sampler, pin_memory=True, num_workers=args.workers)
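DistributedSampler hands each process a disjoint shard of the training indices and reshuffles them every epoch once sampler.set_epoch(epoch) is called (as the training loop below does), while the test loader keeps a plain SequentialSampler so every process walks the full validation set in order. A toy illustration of the sharding, using the standard torch.utils.data classes rather than the ffrecord wrappers:

import torch
from torch.utils.data import TensorDataset, DataLoader
from torch.utils.data.distributed import DistributedSampler

dataset = TensorDataset(torch.arange(10))        # 10 dummy samples

# num_replicas/rank are passed explicitly so the sketch runs without a process group
for r in range(2):                               # pretend there are 2 workers
    sampler = DistributedSampler(dataset, num_replicas=2, rank=r, shuffle=True)
    sampler.set_epoch(0)                         # same epoch -> consistent shuffle across ranks
    loader = DataLoader(dataset, batch_size=5, sampler=sampler)
    for (batch,) in loader:
        print(f'rank {r} sees samples:', batch.tolist())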
246
+
247
+ # model initialization
248
+ print(args.model)
249
+ model = segmentation.__dict__[args.model](pretrained=args.pretrained_swin_weights,
250
+ args=args)
251
+ model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
252
+ model.cuda()
253
+ model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank], find_unused_parameters=True)
254
+ #model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank], find_unused_parameters=False)
255
+ single_model = model.module
256
+
257
+ if args.model != 'lavt_one':
258
+ model_class = BertModel
259
+ bert_model = model_class.from_pretrained(args.ck_bert)
260
+ bert_model.pooler = None  # a work-around for a bug in Transformers 3.0.2 that appears with DistributedDataParallel
261
+ bert_model.cuda()
262
+ bert_model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(bert_model)
263
+ bert_model = torch.nn.parallel.DistributedDataParallel(bert_model, device_ids=[args.local_rank])
264
+ single_bert_model = bert_model.module
265
+ else:
266
+ bert_model = None
267
+ single_bert_model = None
268
+
269
+ input_shape = dict()
270
+ input_shape['s1'] = Dict({'channel': 128, 'stride': 4})
271
+ input_shape['s2'] = Dict({'channel': 256, 'stride': 8})
272
+ input_shape['s3'] = Dict({'channel': 512, 'stride': 16})
273
+ input_shape['s4'] = Dict({'channel': 1024, 'stride': 32})
274
+
275
+
276
+
277
+ cfg = Dict()
278
+ cfg.MODEL.SEM_SEG_HEAD.COMMON_STRIDE = 4
279
+ cfg.MODEL.MASK_FORMER.DROPOUT = 0.0
280
+ cfg.MODEL.MASK_FORMER.NHEADS = 8
281
+ cfg.MODEL.SEM_SEG_HEAD.TRANSFORMER_ENC_LAYERS = 4
282
+ cfg.MODEL.SEM_SEG_HEAD.CONVS_DIM = 256
283
+ cfg.MODEL.SEM_SEG_HEAD.MASK_DIM = 256
284
+ cfg.MODEL.SEM_SEG_HEAD.DEFORMABLE_TRANSFORMER_ENCODER_IN_FEATURES = ["s1", "s2", "s3", "s4"]
285
+
286
+ cfg.MODEL.SEM_SEG_HEAD.NUM_CLASSES = 1
287
+ cfg.MODEL.MASK_FORMER.HIDDEN_DIM = 256
288
+ cfg.MODEL.MASK_FORMER.NUM_OBJECT_QUERIES = 1
289
+ cfg.MODEL.MASK_FORMER.DIM_FEEDFORWARD = 2048
290
+ cfg.MODEL.MASK_FORMER.DEC_LAYERS = 10
291
+ cfg.MODEL.MASK_FORMER.PRE_NORM = False
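Dict is not imported anywhere in this file; its usage (attribute assignment that creates nested levels on the fly) matches addict.Dict, so that is presumably where it comes from, and MaskFormerHead likewise has to be imported from the repository's MaskFormer implementation, whose module path is not visible here. A tiny sketch of the auto-nesting behaviour, under the addict assumption:

from addict import Dict    # assumption: the attribute-style Dict used above

cfg = Dict()
cfg.MODEL.MASK_FORMER.NHEADS = 8          # intermediate levels are created automatically
cfg.MODEL.SEM_SEG_HEAD.CONVS_DIM = 256

print(cfg.MODEL.MASK_FORMER.NHEADS)       # 8
print(cfg.to_dict())                      # plain nested-dict view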
292
+
293
+
294
+ maskformer_head = MaskFormerHead(cfg, input_shape)
295
+ maskformer_head = torch.nn.SyncBatchNorm.convert_sync_batchnorm(maskformer_head)
296
+ maskformer_head.cuda()
297
+ maskformer_head = torch.nn.parallel.DistributedDataParallel(maskformer_head, device_ids=[args.local_rank], find_unused_parameters=False)
298
+ single_head = maskformer_head.module
299
+ print(single_head)
300
+
301
+
302
+ if args.resume == "auto":
303
+ last_ckpt = ""
304
+ for e in range(args.epochs):
305
+ ckpt_path = os.path.join(args.output_dir, f'checkpoint-{e}.pth')
306
+ if os.path.exists(ckpt_path):
307
+ last_ckpt = ckpt_path
308
+ args.resume = last_ckpt
309
+
310
+ # resume training
311
+ if args.resume:
312
+ checkpoint = torch.load(args.resume, map_location='cpu')
313
+ single_model.load_state_dict(checkpoint['model'])
314
+ single_head.load_state_dict(checkpoint['head_model'])
315
+ if args.model != 'lavt_one':
316
+ single_bert_model.load_state_dict(checkpoint['bert_model'])
317
+
318
+ # parameters to optimize
319
+ backbone_no_decay = list()
320
+ backbone_decay = list()
321
+ for name, m in single_model.backbone.named_parameters():
322
+ if 'norm' in name or 'absolute_pos_embed' in name or 'relative_position_bias_table' in name:
323
+ backbone_no_decay.append(m)
324
+ else:
325
+ backbone_decay.append(m)
326
+
327
+ if args.model != 'lavt_one':
328
+ params_to_optimize = [
329
+ {'params': backbone_no_decay, 'weight_decay': 0.0},
330
+ {'params': backbone_decay},
331
+ {"params": [p for p in single_model.classifier.parameters() if p.requires_grad]},
332
+ # the following are the parameters of bert
333
+ {"params": reduce(operator.concat,
334
+ [[p for p in single_bert_model.encoder.layer[i].parameters()
335
+ if p.requires_grad] for i in range(10)])},
336
+ {"params": single_head.parameters()}
337
+ ]
338
+ else:
339
+ params_to_optimize = [
340
+ {'params': backbone_no_decay, 'weight_decay': 0.0},
341
+ {'params': backbone_decay},
342
+ {"params": [p for p in single_model.classifier.parameters() if p.requires_grad]},
343
+ # the following are the parameters of bert
344
+ {"params": reduce(operator.concat,
345
+ [[p for p in single_model.text_encoder.encoder.layer[i].parameters()
346
+ if p.requires_grad] for i in range(10)])},
347
+ ]
348
+
349
+ # optimizer
350
+ optimizer = torch.optim.AdamW(params_to_optimize,
351
+ lr=args.lr,
352
+ weight_decay=args.weight_decay,
353
+ amsgrad=args.amsgrad
354
+ )
355
+
356
+ # learning rate scheduler
357
+ lr_scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer,
358
+ lambda x: (1 - x / (len(data_loader) * args.epochs)) ** 0.9)
359
+
360
+ # housekeeping
361
+ start_time = time.time()
362
+ iterations = 0
363
+ best_oIoU = -0.1
364
+
365
+ # resume training (optimizer, lr scheduler, and the epoch)
366
+ if args.resume:
367
+ optimizer.load_state_dict(checkpoint['optimizer'])
368
+ lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
369
+ resume_epoch = checkpoint['epoch']
370
+ else:
371
+ resume_epoch = -999
372
+
373
+ # training loops
374
+ for epoch in range(max(0, resume_epoch+1), args.epochs):
375
+ data_loader.sampler.set_epoch(epoch)
376
+ train_one_epoch(model, criterion, optimizer, data_loader, lr_scheduler, epoch, args.print_freq,
377
+ iterations, bert_model)  # train_one_epoch in this file does not take the MaskFormer head
378
+ iou, overallIoU = evaluate(model, data_loader_test, bert_model)
379
+
380
+ print('Average object IoU {}'.format(iou))
381
+ print('Overall IoU {}'.format(overallIoU))
382
+
383
+
384
+ if single_bert_model is not None:
385
+ dict_to_save = {'model': single_model.state_dict(), 'bert_model': single_bert_model.state_dict(),
386
+ 'optimizer': optimizer.state_dict(), 'epoch': epoch, 'args': args,
387
+ 'lr_scheduler': lr_scheduler.state_dict(), 'head_model': single_head.state_dict()}
388
+ else:
389
+ dict_to_save = {'model': single_model.state_dict(),
390
+ 'optimizer': optimizer.state_dict(), 'epoch': epoch, 'args': args,
391
+ 'lr_scheduler': lr_scheduler.state_dict()}
392
+
393
+ checkpoint_path = os.path.join(args.output_dir, 'checkpoint-{}.pth'.format(epoch))
394
+ utils.save_on_master(dict_to_save, str(checkpoint_path) + '_TEMP')
395
+ if utils.is_main_process():
396
+ os.rename(str(checkpoint_path) + '_TEMP', str(checkpoint_path))
397
+
398
+ if utils.is_main_process():
399
+ ckpt_paths = []
400
+ for e in range(args.epochs):
401
+ ckpt_path = os.path.join(args.output_dir, f'checkpoint-{e}.pth')
402
+ print(ckpt_path)
403
+ if os.path.exists(ckpt_path):
404
+ ckpt_paths.append(ckpt_path)
405
+ print(ckpt_paths)
406
+ for ckpt_path in ckpt_paths[:-args.max_ckpt]:
407
+ os.remove(ckpt_path)
408
+ print("remove {:s}".format(ckpt_path))
409
+
410
+
411
+ save_checkpoint = (best_oIoU < overallIoU)
412
+ if save_checkpoint:
413
+ print('Better epoch: {}\n'.format(epoch))
414
+ if single_bert_model is not None:
415
+ dict_to_save = {'model': single_model.state_dict(), 'bert_model': single_bert_model.state_dict(),
416
+ 'optimizer': optimizer.state_dict(), 'epoch': epoch, 'args': args,
417
+ 'lr_scheduler': lr_scheduler.state_dict()}
418
+ else:
419
+ dict_to_save = {'model': single_model.state_dict(),
420
+ 'optimizer': optimizer.state_dict(), 'epoch': epoch, 'args': args,
421
+ 'lr_scheduler': lr_scheduler.state_dict()}
422
+
423
+ checkpoint_path = os.path.join(args.output_dir, 'model_best_{}.pth'.format(args.model_id))
424
+ utils.save_on_master(dict_to_save, checkpoint_path + '_TEMP')
425
+ if utils.is_main_process():
426
+ os.rename(str(checkpoint_path) + '_TEMP', str(checkpoint_path))
427
+ best_oIoU = overallIoU
428
+
429
+ # summarize
430
+ total_time = time.time() - start_time
431
+ total_time_str = str(datetime.timedelta(seconds=int(total_time)))
432
+ print('Training time {}'.format(total_time_str))
433
+
434
+
435
+ if __name__ == "__main__":
436
+ from args import get_parser
437
+ parser = get_parser()
438
+ args = parser.parse_args()
439
+ os.makedirs(args.output_dir, exist_ok=True)
440
+ # set up distributed learning
441
+ #utils.init_distributed_mode(args)
442
+ print('Image size: {}'.format(str(args.img_size)))
443
+ #main(args)
444
+ mp.spawn(main, args=(args,), nprocs=torch.cuda.device_count())
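Unlike train_elia.py, which defers to utils.init_distributed_mode, this entry point spawns one process per visible GPU itself and expects MASTER_IP, MASTER_PORT, WORLD_SIZE and RANK to already be in the environment (a cluster scheduler would normally export them). A minimal single-node sketch of that launch pattern — the environment values are placeholders, the worker body is a stub rather than main(local_rank, args), and it assumes at least one CUDA device:

import os
import torch
import torch.distributed as dist
import torch.multiprocessing as mp

def worker(local_rank, args):
    # stub standing in for main(local_rank, args)
    ip, port = os.environ['MASTER_IP'], os.environ['MASTER_PORT']
    hosts, rank = int(os.environ['WORLD_SIZE']), int(os.environ['RANK'])
    gpus = torch.cuda.device_count()
    dist.init_process_group(backend='nccl',
                            init_method=f'tcp://{ip}:{port}',
                            world_size=hosts * gpus,
                            rank=rank * gpus + local_rank)
    torch.cuda.set_device(local_rank)
    dist.barrier()

if __name__ == '__main__':
    # placeholder single-node environment; a scheduler would set these instead
    os.environ.setdefault('MASTER_IP', '127.0.0.1')
    os.environ.setdefault('MASTER_PORT', '29500')
    os.environ.setdefault('WORLD_SIZE', '1')     # number of machines
    os.environ.setdefault('RANK', '0')           # this machine's index
    mp.spawn(worker, args=(None,), nprocs=torch.cuda.device_count())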