Commit 5b24071 (verified) · orionweller · Parent: cb354c2

Create README.md

Files changed (1): README.md (+482 −0)
---
license: mit
datasets:
- jhu-clsp/mmbert-decay
- jhu-clsp/mmbert-midtraining
- jhu-clsp/mmbert-pretrain-p1-fineweb2-langs
- jhu-clsp/mmbert-pretrain-p2-fineweb2-remaining
- jhu-clsp/mmbert-pretrain-p3-others
pipeline_tag: fill-mask
---
11
+
12
+ # mmBERT: A Modern Multilingual Encoder
13
+
14
+ [![License: MIT](https://img.shields.io/badge/License-MIT-yellow.svg)](https://opensource.org/licenses/MIT)
15
+ [![Paper](https://img.shields.io/badge/Paper-Arxiv-red)](https://arxiv.org/abs/2509.06888)
16
+ [![Model](https://img.shields.io/badge/🤗%20Hugging%20Face-Model-blue)](https://huggingface.co/jhu-clsp/mmBERT-base)
17
+ [![Collection](https://img.shields.io/badge/🤗%20Model%20Collection-blue)](https://huggingface.co/collections/jhu-clsp/mmbert-a-modern-multilingual-encoder-68b725831d7c6e3acc435ed4)
18
+ [![GitHub](https://img.shields.io/badge/GitHub-Code-black)](https://github.com/jhu-clsp/mmBERT)
19
+
20
+ > TL;DR: A state-of-the-art multilingual encoder trained on 3T+ tokens across 1800+ languages, introducing novel techniques for learning low-resource languages during the decay phase.
21
+
22
+ mmBERT is a modern multilingual encoder that significantly outperforms previous generation models like XLM-R on classification, embedding, and retrieval tasks. Built on the ModernBERT architecture with novel multilingual training innovations, mmBERT demonstrates that low-resource languages can be effectively learned during the decay phase of training. It is also significantly faster than any previous multilingual encoder.
23
+
## Table of Contents
- [Quick Start](#quick-start)
- [Model Description](#model-description)
- [Novel Training Innovations](#novel-training-innovations)
- [Model Family](#model-family)
- [Training Data](#training-data)
- [Model Architecture](#model-architecture)
- [Usage Examples](#usage-examples)
- [Fine-tuning Examples](#fine-tuning-examples)
- [Citation](#citation)

## Quick Start

### Installation

ModernBERT-based models require a recent version of `transformers` (4.48.0 or later):

```bash
pip install torch
pip install "transformers>=4.48.0"
```

### Usage

```python
from transformers import AutoTokenizer, AutoModel

tokenizer = AutoTokenizer.from_pretrained("jhu-clsp/mmBERT-base")
model = AutoModel.from_pretrained("jhu-clsp/mmBERT-base")

inputs = tokenizer("Hello world", return_tensors="pt")
outputs = model(**inputs)
```

## Model Description

mmBERT represents the first significant advancement over XLM-R for massively multilingual encoder models. Key features include:

1. **Massive Language Coverage** - Trained on over 1800 languages with a progressive language-inclusion strategy
2. **Modern Architecture** - Built on the ModernBERT foundation with Flash Attention 2 and unpadding techniques
3. **Novel Training Recipe** - Introduces inverse mask scheduling and inverse temperature sampling
4. **Open Training Data** - The complete 3T+ token dataset is publicly available
5. **Decay Phase Innovation** - Demonstrates effective learning of low-resource languages in the final training phase

The model uses bidirectional attention with a masked language modeling objective, optimized specifically for multilingual understanding and cross-lingual transfer.

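Because mmBERT is trained with masked language modeling, the quickest way to probe it is through the `fill-mask` pipeline. The snippet below is a minimal sketch; it assumes the checkpoint's mask token is `[MASK]` (check `tokenizer.mask_token` if unsure):

```python
from transformers import pipeline

# Fill-mask sketch; assumes the tokenizer's mask token is "[MASK]".
unmasker = pipeline("fill-mask", model="jhu-clsp/mmBERT-base")
for prediction in unmasker("The capital of France is [MASK]."):
    print(prediction["token_str"], round(prediction["score"], 3))
```
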
## Novel Training Innovations

**Progressive Language Addition**: Start with 60 high-resource languages, expand to 110 mid-resource languages, then include all 1833 languages in the decay phase.

**Inverse Mask Schedule**: Reduce the mask ratio from 30% → 15% → 5% across training phases for progressively refined learning.

**Inverse Temperature Sampling**: Adjust multilingual sampling from a high-resource bias (τ=0.7) toward uniform sampling (τ=0.3); both schedules are sketched below.

**Model Merging**: Combine English-focused, high-resource, and all-language decay variants using TIES merging.

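To make the two schedules concrete, here is an illustrative sketch (not the actual training code): it assumes language sampling probabilities of the form p_i ∝ n_i^τ, and the mid-training temperature is an assumed interpolation between the published endpoints.

```python
import numpy as np

# Illustrative schedule: mask ratios and the endpoint temperatures come from
# the description above; the mid-training temperature (0.5) is an assumption.
PHASES = {
    "pre-training": {"mask_ratio": 0.30, "tau": 0.7},  # high-resource bias
    "mid-training": {"mask_ratio": 0.15, "tau": 0.5},  # assumed intermediate value
    "decay":        {"mask_ratio": 0.05, "tau": 0.3},  # closer to uniform
}

def language_sampling_probs(token_counts, tau):
    """Temperature sampling: p_i proportional to n_i ** tau.

    tau = 1.0 reproduces the raw corpus proportions (high-resource bias);
    tau -> 0 approaches uniform sampling across languages.
    """
    counts = np.asarray(token_counts, dtype=np.float64)
    weights = counts ** tau
    return weights / weights.sum()

# Example: three languages with very different amounts of data.
token_counts = [1_000_000_000, 50_000_000, 1_000_000]
for phase, cfg in PHASES.items():
    probs = language_sampling_probs(token_counts, cfg["tau"])
    print(f"{phase:13s} mask_ratio={cfg['mask_ratio']:.2f} sampling={np.round(probs, 3)}")
```
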
## Model Family

| Model | Total Params | Non-embed Params | Languages | Download |
|:------|:-------------|:------------------|:----------|:---------|
| [mmBERT-small](https://huggingface.co/jhu-clsp/mmBERT-small) | 140M | 42M | 1800+ | [![Download](https://img.shields.io/badge/🤗-Download-blue)](https://huggingface.co/jhu-clsp/mmBERT-small) |
| [mmBERT-base](https://huggingface.co/jhu-clsp/mmBERT-base) | 307M | 110M | 1800+ | [![Download](https://img.shields.io/badge/🤗-Download-blue)](https://huggingface.co/jhu-clsp/mmBERT-base) |

## Training Data

The mmBERT training data is publicly available for every phase:

| Phase | Dataset | Tokens | Description |
|:------|:--------|:-------|:------------|
| Pre-training P1 | [mmbert-pretrain-p1](https://huggingface.co/datasets/jhu-clsp/mmbert-pretrain-p1-fineweb2-langs) | 2.3T | 60 languages, foundational training |
| Pre-training P2 | [mmbert-pretrain-p2](https://huggingface.co/datasets/jhu-clsp/mmbert-pretrain-p2-fineweb2-remaining) | - | Extension data for the pre-training phase |
| Pre-training P3 | [mmbert-pretrain-p3](https://huggingface.co/datasets/jhu-clsp/mmbert-pretrain-p3-others) | - | Final pre-training data |
| Mid-training | [mmbert-midtraining](https://huggingface.co/datasets/jhu-clsp/mmbert-midtraining) | 600B | 110 languages, context extension to 8K |
| Decay Phase | [mmbert-decay](https://huggingface.co/datasets/jhu-clsp/mmbert-decay) | 100B | 1833 languages, premium quality |

**Data Sources:**
- **Filtered DCLM**: high-quality English content
- **FineWeb2**: broad multilingual web coverage (1800+ languages)
- **FineWeb2-HQ**: filtered subset covering 20 high-resource languages
- **Code**: StarCoder and ProLong repositories
- **Academic**: ArXiv papers and PeS2o scientific content
- **Reference**: Wikipedia (MegaWika) and textbooks
- **Community**: StackExchange discussions

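Each phase dataset can be streamed directly with 🤗 Datasets without downloading the full corpus. A minimal sketch (it assumes the default configuration and a `train` split; check the individual dataset cards for the exact subset names):

```python
from datasets import load_dataset

# Stream a few decay-phase examples; the default config and "train" split
# are assumptions - see the dataset card for the available subsets.
ds = load_dataset("jhu-clsp/mmbert-decay", split="train", streaming=True)
for example in ds.take(3):
    print(example)
```
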
## Model Architecture

| Parameter | mmBERT-small | mmBERT-base |
|:----------|:-------------|:------------|
| Layers | 22 | 22 |
| Hidden Size | 384 | 768 |
| Intermediate Size | 1152 | 1152 |
| Attention Heads | 6 | 12 |
| Total Parameters | 140M | 307M |
| Non-embedding Parameters | 42M | 110M |
| Max Sequence Length | 8192 | 8192 |
| Vocabulary Size | 256,000 | 256,000 |
| Tokenizer | Gemma 2 | Gemma 2 |

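These values can be cross-checked against the released configuration. A small sketch using the standard ModernBERT config attributes in `transformers` (the printed values should match the table above):

```python
from transformers import AutoConfig

config = AutoConfig.from_pretrained("jhu-clsp/mmBERT-base")
# Standard ModernBERT config fields; values should line up with the table above.
print(config.num_hidden_layers, config.hidden_size, config.intermediate_size)
print(config.num_attention_heads, config.max_position_embeddings, config.vocab_size)
```
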
## Usage Examples

### Masked Language Modeling

```python
from transformers import AutoTokenizer, AutoModelForMaskedLM
import torch

tokenizer = AutoTokenizer.from_pretrained("jhu-clsp/mmBERT-base")
model = AutoModelForMaskedLM.from_pretrained("jhu-clsp/mmBERT-base")

def predict_masked_token(text):
    inputs = tokenizer(text, return_tensors="pt")
    with torch.no_grad():
        outputs = model(**inputs)

    mask_indices = torch.where(inputs["input_ids"] == tokenizer.mask_token_id)
    predictions = outputs.logits[mask_indices]
    top_tokens = torch.topk(predictions, 5, dim=-1)

    return [tokenizer.decode(token) for token in top_tokens.indices[0]]

# Works across languages
texts = [
    "The capital of France is [MASK].",
    "La capital de España es [MASK].",
    "Die Hauptstadt von Deutschland ist [MASK]."
]

for text in texts:
    predictions = predict_masked_token(text)
    print(f"Text: {text}")
    print(f"Predictions: {predictions}")
```

### Cross-lingual Embeddings

```python
from transformers import AutoTokenizer, AutoModel
import torch
from sklearn.metrics.pairwise import cosine_similarity

tokenizer = AutoTokenizer.from_pretrained("jhu-clsp/mmBERT-base")
model = AutoModel.from_pretrained("jhu-clsp/mmBERT-base")

def get_embeddings(texts):
    inputs = tokenizer(texts, padding=True, truncation=True, return_tensors="pt")

    with torch.no_grad():
        outputs = model(**inputs)
        # Mean-pool over real tokens only, so padding does not dilute the embedding
        mask = inputs["attention_mask"].unsqueeze(-1).float()
        embeddings = (outputs.last_hidden_state * mask).sum(dim=1) / mask.sum(dim=1)

    return embeddings.numpy()

multilingual_texts = [
    "Artificial intelligence is transforming technology",
    "La inteligencia artificial está transformando la tecnología",
    "L'intelligence artificielle transforme la technologie",
    "人工智能正在改变技术"
]

embeddings = get_embeddings(multilingual_texts)
similarities = cosine_similarity(embeddings)
print("Cross-lingual similarity matrix:")
print(similarities)
```

## Fine-tuning Examples

### Dense Retrieval with Sentence Transformers

<details>
<summary>Click to expand dense retrieval fine-tuning example</summary>

```python
import argparse
from datasets import load_dataset
from sentence_transformers import (
    SentenceTransformer,
    SentenceTransformerTrainer,
    SentenceTransformerTrainingArguments,
)
from sentence_transformers.evaluation import TripletEvaluator
from sentence_transformers.losses import CachedMultipleNegativesRankingLoss
from sentence_transformers.training_args import BatchSamplers

def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--lr", type=float, default=8e-5)
    parser.add_argument("--model_name", type=str, default="jhu-clsp/mmBERT-base")
    args = parser.parse_args()

    lr = args.lr
    model_name = args.model_name
    model_shortname = model_name.split("/")[-1]

    model = SentenceTransformer(model_name)

    # MS MARCO (query, positive, hard-negative) triplets
    dataset = load_dataset(
        "sentence-transformers/msmarco-co-condenser-margin-mse-sym-mnrl-mean-v1",
        "triplet-hard",
        split="train",
    )
    dataset_dict = dataset.train_test_split(test_size=1_000, seed=12)
    train_dataset = dataset_dict["train"].select(range(1_250_000))
    eval_dataset = dataset_dict["test"]

    # In-batch negatives with gradient caching, enabling the large batch size below
    loss = CachedMultipleNegativesRankingLoss(model, mini_batch_size=16)
    run_name = f"{model_shortname}-DPR-{lr}"

    training_args = SentenceTransformerTrainingArguments(
        output_dir=f"output/{model_shortname}/{run_name}",
        num_train_epochs=1,
        per_device_train_batch_size=512,
        per_device_eval_batch_size=512,
        warmup_ratio=0.05,
        fp16=False,
        bf16=True,
        batch_sampler=BatchSamplers.NO_DUPLICATES,
        learning_rate=lr,
        save_strategy="steps",
        save_steps=500,
        save_total_limit=2,
        logging_steps=500,
        run_name=run_name,
    )

    dev_evaluator = TripletEvaluator(
        anchors=eval_dataset["query"],
        positives=eval_dataset["positive"],
        negatives=eval_dataset["negative"],
        name="msmarco-co-condenser-dev",
    )
    dev_evaluator(model)  # baseline evaluation before fine-tuning

    trainer = SentenceTransformerTrainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        loss=loss,
        evaluator=dev_evaluator,
    )
    trainer.train()

    model.save_pretrained(f"output/{model_shortname}/{run_name}/final")
    model.push_to_hub(run_name, private=False)

if __name__ == "__main__":
    main()
```

</details>

### Cross-lingual Classification

<details>
<summary>Click to expand multilingual classification fine-tuning example</summary>

```python
from transformers import (
    AutoTokenizer,
    AutoModelForSequenceClassification,
    TrainingArguments,
    Trainer
)
from datasets import load_dataset
import numpy as np
from sklearn.metrics import accuracy_score, f1_score

def compute_metrics(eval_pred):
    predictions, labels = eval_pred
    predictions = np.argmax(predictions, axis=1)
    return {
        'accuracy': accuracy_score(labels, predictions),
        'f1': f1_score(labels, predictions, average='weighted')
    }

def main():
    model_name = "jhu-clsp/mmBERT-base"

    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForSequenceClassification.from_pretrained(
        model_name,
        num_labels=3
    )

    # Fine-tune on English XNLI; the multilingual pretraining allows the resulting
    # classifier to be evaluated zero-shot on the other XNLI languages.
    dataset = load_dataset("xnli", "en")

    def tokenize_function(examples):
        texts = [f"{p} {tokenizer.sep_token} {h}"
                 for p, h in zip(examples["premise"], examples["hypothesis"])]

        return tokenizer(
            texts,
            truncation=True,
            padding=True,
            max_length=512
        )

    train_dataset = dataset["train"].map(tokenize_function, batched=True)
    eval_dataset = dataset["validation"].map(tokenize_function, batched=True)

    training_args = TrainingArguments(
        output_dir="./mmbert-xnli",
        learning_rate=3e-5,
        per_device_train_batch_size=32,
        per_device_eval_batch_size=32,
        num_train_epochs=3,
        weight_decay=0.01,
        eval_strategy="epoch",
        save_strategy="epoch",
        load_best_model_at_end=True,
        metric_for_best_model="f1",
        greater_is_better=True,
    )

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        compute_metrics=compute_metrics,
    )

    trainer.train()

if __name__ == "__main__":
    main()
```

</details>

### Multilingual Reranking

<details>
<summary>Click to expand multilingual reranking fine-tuning example</summary>

```python
import logging
from datasets import load_dataset
from sentence_transformers.cross_encoder import (
    CrossEncoder,
    CrossEncoderModelCardData,
    CrossEncoderTrainer,
    CrossEncoderTrainingArguments,
)
from sentence_transformers.cross_encoder.evaluation import CrossEncoderNanoBEIREvaluator
from sentence_transformers.cross_encoder.losses import BinaryCrossEntropyLoss
from sentence_transformers.util import mine_hard_negatives
from sentence_transformers import SentenceTransformer
import torch

def main():
    model_name = "jhu-clsp/mmBERT-base"
    train_batch_size = 32
    num_epochs = 2
    num_hard_negatives = 7

    model = CrossEncoder(
        model_name,
        model_card_data=CrossEncoderModelCardData(
            language="multilingual",
            license="mit",
        ),
    )

    full_dataset = load_dataset("sentence-transformers/gooaq", split="train").select(range(50_000))
    dataset_dict = full_dataset.train_test_split(test_size=1_000, seed=42)
    train_dataset = dataset_dict["train"]
    eval_dataset = dataset_dict["test"]

    embedding_model = SentenceTransformer("sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2", device="cpu")
    hard_train_dataset = mine_hard_negatives(
        train_dataset,
        embedding_model,
        num_negatives=num_hard_negatives,
        margin=0,
        range_min=0,
        range_max=100,
        sampling_strategy="top",
        batch_size=2048,
        output_format="labeled-pair",
        use_faiss=True,
    )

    loss = BinaryCrossEntropyLoss(model=model, pos_weight=torch.tensor(num_hard_negatives))

    nano_beir_evaluator = CrossEncoderNanoBEIREvaluator(
        dataset_names=["msmarco", "nfcorpus", "nq"],
        batch_size=train_batch_size,
    )

    args = CrossEncoderTrainingArguments(
        output_dir="./mmbert-reranker",
        num_train_epochs=num_epochs,
        per_device_train_batch_size=train_batch_size,
        per_device_eval_batch_size=train_batch_size,
        learning_rate=2e-5,
        warmup_ratio=0.1,
        fp16=False,
        bf16=True,
        dataloader_num_workers=4,
        load_best_model_at_end=True,
        metric_for_best_model="eval_msmarco_ndcg@10",
        eval_strategy="steps",
        eval_steps=1000,
        save_strategy="steps",
        save_steps=1000,
        save_total_limit=2,
        logging_steps=200,
        seed=42,
    )

    trainer = CrossEncoderTrainer(
        model=model,
        args=args,
        train_dataset=hard_train_dataset,
        loss=loss,
        evaluator=nano_beir_evaluator,
    )
    trainer.train()

    model.save_pretrained("./mmbert-reranker/final")

if __name__ == "__main__":
    main()
```

</details>

## Citation

If you use mmBERT in your research, please cite our work:

```bibtex
@misc{marone2025mmbertmodernmultilingualencoder,
      title={mmBERT: A Modern Multilingual Encoder with Annealed Language Learning},
      author={Marc Marone and Orion Weller and William Fleshman and Eugene Yang and Dawn Lawrie and Benjamin Van Durme},
      year={2025},
      eprint={2509.06888},
      archivePrefix={arXiv},
      primaryClass={cs.CL},
      url={https://arxiv.org/abs/2509.06888},
}
```