|
import os |
|
from langdetect import detect |
|
import torch.multiprocessing as mp |
|
|
|
from colbert import Indexer, Searcher |
|
from colbert.infra import ColBERTConfig, Run |
|
from colbert.utils.utils import print_message |
|
from colbert.data.collection import Collection |
|
from colbert.modeling.checkpoint import Checkpoint |
|
from colbert.indexing.index_saver import IndexSaver |
|
from colbert.search.index_storage import IndexScorer |
|
from colbert.infra.launcher import Launcher, print_memory_stats |
|
from colbert.indexing.collection_encoder import CollectionEncoder |
|
from colbert.indexing.collection_indexer import CollectionIndexer |
|
|
|
|
|
# ISO 639-1 code -> (language name, X-MOD language code) for every language used by
# the benchmark tables below. Each pair is written exactly once so the per-dataset
# tables cannot disagree about a language's name or X-MOD code.
_XMOD_LANGUAGES = {
    'ar': ('arabic', 'ar_AR'),
    'bn': ('bengali', 'bn_IN'),
    'de': ('german', 'de_DE'),
    'en': ('english', 'en_XX'),
    'es': ('spanish', 'es_XX'),
    'fa': ('persian', 'fa_IR'),
    'fi': ('finnish', 'fi_FI'),
    'fr': ('french', 'fr_XX'),
    'hi': ('hindi', 'hi_IN'),
    'id': ('indonesian', 'id_ID'),
    'it': ('italian', 'it_IT'),
    'ja': ('japanese', 'ja_XX'),
    'ko': ('korean', 'ko_KR'),
    'nl': ('dutch', 'nl_XX'),
    'pt': ('portuguese', 'pt_XX'),
    'ru': ('russian', 'ru_RU'),
    'sw': ('swahili', 'sw_KE'),
    'te': ('telugu', 'te_IN'),
    'th': ('thai', 'th_TH'),
    'vi': ('vietnamese', 'vi_VN'),
    'zh': ('chinese', 'zh_CN'),
}

# Languages covered by mMARCO (key order is preserved from the listing below).
MMARCO_LANGUAGES = {
    code: _XMOD_LANGUAGES[code]
    for code in ('ar', 'de', 'en', 'es', 'fr', 'hi', 'id', 'it', 'ja', 'nl', 'pt', 'ru', 'vi', 'zh')
}

# Languages covered by Mr. TyDi.
MRTYDI_LANGUAGES = {
    code: _XMOD_LANGUAGES[code]
    for code in ('ar', 'bn', 'en', 'fi', 'id', 'ja', 'ko', 'ru', 'sw', 'te', 'th')
}

# Languages covered by MIRACL.
MIRACL_LANGUAGES = {
    code: _XMOD_LANGUAGES[code]
    for code in ('ar', 'bn', 'en', 'es', 'fa', 'fi', 'fr', 'hi', 'id', 'ja', 'ko', 'ru', 'sw', 'te', 'th', 'zh')
}

# Union of the three tables; merged in this order so the key order is mMARCO first,
# then languages new in Mr. TyDi, then languages new in MIRACL.
ALL_LANGUAGES = {**MMARCO_LANGUAGES, **MRTYDI_LANGUAGES, **MIRACL_LANGUAGES}
|
|
|
|
|
def set_xmod_language(model, lang: str):
    """
    Set the default language code for the model. This is used when the language is not specified in the input.

    Source: https://github.com/huggingface/transformers/blob/v4.34.1/src/transformers/models/xmod/modeling_xmod.py#L687

    Args:
        model: Object exposing ``set_default_language(code)``; it receives the X-MOD
            code (e.g. ``"zh_CN"``) looked up in ``ALL_LANGUAGES``.
        lang: ISO 639-1 language code, optionally followed by a region suffix.
            Both ``langdetect``-style (``"zh-cn"``) and locale-style (``"zh_CN"``)
            suffixes are accepted, and matching is case-insensitive.

    Raises:
        KeyError: If the base language code is not a key of ``ALL_LANGUAGES``.
    """
    # Only the base language code selects the adapter, so drop any region suffix
    # ("zh-cn" / "zh_TW" -> "zh") after normalizing case and separator.
    base = lang.strip().lower().replace('_', '-').split('-')[0]

    if (value := ALL_LANGUAGES.get(base)) is None:
        raise KeyError(f"Language {base} not supported.")

    model.set_default_language(value[1])
|
|
|
|
|
|
|
|
|
class CustomIndexer(Indexer):
    """``Indexer`` whose encoding step launches ``custom_encode`` instead of the stock encoder.

    ``custom_encode`` builds a ``CustomCollectionIndexer``, which selects an X-MOD
    language adapter when the checkpoint is an X-MOD model.
    """

    def _Indexer__launch(self, collection):
        """Launch ``custom_encode`` via ``Launcher`` across ``config.nranks`` ranks.

        The method is deliberately named ``_Indexer__launch``. Python name-mangles a
        ``__launch`` defined inside ``class Indexer`` to ``_Indexer__launch``, and the
        base class's own ``self.__launch(...)`` call resolves to that mangled name.
        A method spelled ``__launch`` here would be mangled to
        ``_CustomIndexer__launch`` and never be reached by the base class.
        NOTE(review): assumes the installed ColBERT ``Indexer`` declares a private
        ``__launch`` (as upstream does) — verify when upgrading ColBERT.

        Args:
            collection: The collection to encode, passed through to ``custom_encode``.
        """
        manager = mp.Manager()
        # One shared list and one single-slot queue per rank.
        shared_lists = [manager.list() for _ in range(self.config.nranks)]
        shared_queues = [manager.Queue(maxsize=1) for _ in range(self.config.nranks)]

        launcher = Launcher(custom_encode)
        launcher.launch(self.config, collection, shared_lists, shared_queues, self.verbose)

    # Keep the previously defined (mangled) name ``_CustomIndexer__launch`` available
    # for any caller that referenced it directly.
    __launch = _Indexer__launch
|
|
|
def custom_encode(config, collection, shared_lists, shared_queues, verbose: int = 3):
    """Per-rank worker: build a ``CustomCollectionIndexer`` and run it.

    ``shared_queues`` is accepted because the launching side passes it along with
    the other arguments; this function does not use it.
    """
    CustomCollectionIndexer(config=config, collection=collection, verbose=verbose).run(shared_lists)
|
|
|
class CustomCollectionIndexer(CollectionIndexer):
    """``CollectionIndexer`` that selects an X-MOD language adapter before encoding.

    ``__init__`` re-implements the setup itself (the parent constructor is not
    called) and inserts a language-detection step between loading the checkpoint
    and moving it to the GPU.
    """

    def __init__(self, config: ColBERTConfig, collection, verbose=2):
        """
        Args:
            config: Indexing configuration; ``config.rank`` / ``config.nranks``
                identify this process, and ``config.checkpoint`` is loaded.
            collection: Anything accepted by ``Collection.cast``.
            verbose: When > 1, rank 0 prints the configuration via ``config.help()``.

        Raises:
            KeyError: If the checkpoint's model is X-MOD and the language detected
                from the first passage is not in ``ALL_LANGUAGES``.
        """
        # Local import: only needed to seed langdetect below.
        from langdetect import DetectorFactory

        self.verbose = verbose
        self.config = config
        self.rank, self.nranks = self.config.rank, self.config.nranks
        self.use_gpu = self.config.total_visible_gpus > 0

        if self.config.rank == 0 and self.verbose > 1:
            self.config.help()

        self.collection = Collection.cast(collection)
        self.checkpoint = Checkpoint(self.config.checkpoint, colbert_config=self.config)

        if type(self.checkpoint.bert).__name__.lower().startswith("xmod"):
            # langdetect is non-deterministic unless DetectorFactory.seed is set. Each
            # rank runs this constructor independently and the searcher repeats the
            # detection, so an unseeded detector could pick different adapters for
            # the same text. Note: this is a process-wide langdetect setting; an
            # explicit seed chosen elsewhere is left untouched.
            if DetectorFactory.seed is None:
                DetectorFactory.seed = 0
            language = detect(self.collection[0])
            Run().print_main(f"#> Setting X-MOD language adapters to {language}.")
            set_xmod_language(self.checkpoint.bert, lang=language)

        if self.use_gpu:
            self.checkpoint = self.checkpoint.cuda()

        self.encoder = CollectionEncoder(config, self.checkpoint)
        self.saver = IndexSaver(config)

        print_memory_stats(f'RANK:{self.rank}')
|
|
|
|
|
|
|
|
|
class CustomSearcher(Searcher):
    """``Searcher`` that selects an X-MOD language adapter on the loaded checkpoint.

    ``__init__`` re-implements the setup itself (the parent constructor is not
    called) and inserts a language-detection step, based on the collection's first
    passage, between loading the checkpoint and moving it to the GPU.
    """

    def __init__(self, index, checkpoint=None, collection=None, config=None, index_root=None, verbose: int = 3):
        """
        Args:
            index: Index name, joined onto ``index_root``.
            checkpoint: Checkpoint to load; defaults to the one recorded in the
                index's config.
            collection: Anything accepted by ``Collection.cast``; defaults to the
                merged config's ``collection``.
            config: Optional ``ColBERTConfig`` merged with ``Run().config``.
            index_root: Directory holding the index; defaults to the merged
                config's ``index_root_``.
            verbose: When > 1, memory stats are printed before setup begins.

        Raises:
            ValueError: If ``load_index_with_mmap`` is set while GPUs are visible.
            KeyError: If the checkpoint's model is X-MOD and the language detected
                from the first passage is not in ``ALL_LANGUAGES``.
        """
        # Local import: only needed to seed langdetect below.
        from langdetect import DetectorFactory

        self.verbose = verbose
        if self.verbose > 1:
            print_memory_stats()

        initial_config = ColBERTConfig.from_existing(config, Run().config)

        default_index_root = initial_config.index_root_
        index_root = index_root if index_root else default_index_root
        self.index = os.path.join(index_root, index)
        self.index_config = ColBERTConfig.load_from_index(self.index)

        self.checkpoint = checkpoint or self.index_config.checkpoint
        self.checkpoint_config = ColBERTConfig.load_from_checkpoint(self.checkpoint)
        self.config = ColBERTConfig.from_existing(self.checkpoint_config, self.index_config, initial_config)

        # Reject the invalid mmap + GPU combination before loading the checkpoint
        # and moving it to the GPU, rather than after.
        use_gpu = self.config.total_visible_gpus > 0
        load_index_with_mmap = self.config.load_index_with_mmap
        if load_index_with_mmap and use_gpu:
            raise ValueError("Memory-mapped index can only be used with CPU!")

        self.collection = Collection.cast(collection or self.config.collection)
        self.configure(checkpoint=self.checkpoint, collection=self.collection)

        self.checkpoint = Checkpoint(self.checkpoint, colbert_config=self.config, verbose=self.verbose)

        if type(self.checkpoint.bert).__name__.lower().startswith("xmod"):
            # Seed langdetect so detection on the same passage is reproducible and
            # agrees with the language chosen at indexing time. This is a
            # process-wide langdetect setting; an explicit seed chosen elsewhere is
            # left untouched.
            if DetectorFactory.seed is None:
                DetectorFactory.seed = 0
            language = detect(self.collection[0])
            print_message(f"#> Setting X-MOD language adapters to {language}.")
            set_xmod_language(self.checkpoint.bert, lang=language)

        if use_gpu:
            self.checkpoint = self.checkpoint.cuda()

        self.ranker = IndexScorer(self.index, use_gpu, load_index_with_mmap)

        print_memory_stats()
|
|