---
license: cc-by-nc-sa-4.0
language: en
tags:
- splade
- conversational-search
- multi-turn-retrieval
- query-expansion
- document-expansion
- passage-retrieval
- knowledge-distillation
pipeline_tag: fill-mask
---

## DiSCo: LLM Knowledge Distillation for Efficient Sparse Retrieval in Conversational Search

This model is a conversational search adaptation of the original [SPLADE++ (CoCondenser-EnsembleDistil)](https://huggingface.co/naver/splade-cocondenser-ensembledistil) model. It retains the original document encoder and **fine-tunes the query encoder on QReCC**, a dataset designed for multi-turn conversational search. Training is performed via **distillation from human rewrites**, allowing the model to better capture the semantics of conversational queries.

For more details, see the original paper:

* DiSCo SPLADE - SIGIR 2025 full paper: https://arxiv.org/abs/2410.14609

> **Note:** This is the **query encoder**. For inference, you also need the corresponding [document encoder](https://huggingface.co/naver/splade-cocondenser-ensembledistil), which remains unchanged from the original SPLADE++ checkpoint. SPLADE can use an asymmetric architecture: separate models for query and document representation (see the scoring sketch at the end of this card).

## Usage

The input format is a flattened version of the conversational history, from the current query back to the first turn:

    q_n [SEP] a_{n-1} [SEP] q_{n-1} [SEP] ... [SEP] a_0 [SEP] q_0

Below is an example script for encoding a conversation:

```python
from transformers import AutoTokenizer, AutoModelForMaskedLM
import torch.nn.functional as F
import torch

model = AutoModelForMaskedLM.from_pretrained("slupart/splade-disco-human")
tokenizer = AutoTokenizer.from_pretrained("slupart/splade-disco-human")
model.eval()

# Conversation as (query, answer) pairs; the last turn has no answer yet.
conv = [
    ("what's the weather like today?", "it's sunny."),
    ("should I wear sunscreen?", "yes, UV index is high."),
    ("do I need sunglasses?", "definitely."),
    ("where can I buy sunglasses?", "try the optician nearby."),
    ("how much do they cost?", None)
]

# Flatten the history: current query first, then earlier answers and queries
# from the most recent turn back to the first, skipping missing answers.
parts = [conv[-1][0]] + [x for q, a in reversed(conv[:-1]) for x in (a, q) if x]
text = " [SEP] ".join(parts)

inputs = tokenizer(text, return_tensors="pt")
with torch.no_grad():
    logits = model(**inputs).logits

# Sparse vocabulary-sized representation: max over sequence positions of ReLU(logits).
sparse = F.relu(logits).max(1).values.squeeze(0)

# List the non-zero vocabulary tokens with their weights, highest first.
scores = [(tokenizer.convert_ids_to_tokens([i.item()])[0], sparse[i].item())
          for i in torch.nonzero(sparse).squeeze(1)]
for token, score in sorted(scores, key=lambda x: -x[1]):
    print(f"Token: {token:15} | Score: {score:.4f}")
```

## Citation

If you use our checkpoint, please cite our work:

```
@article{lupart2024disco,
  title={DiSCo Meets LLMs: A Unified Approach for Sparse Retrieval and Contextual Distillation in Conversational Search},
  author={Lupart, Simon and Aliannejadi, Mohammad and Kanoulas, Evangelos},
  journal={arXiv preprint arXiv:2410.14609},
  year={2024}
}
```
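
## Scoring a Query-Document Pair (Sketch)

The query vector above is only half of the pipeline. Below is a minimal sketch of scoring one passage against a conversational query, assuming the standard SPLADE dot product between sparse query and document vectors. The passage text and the `encode` helper are illustrative, not part of this repository; the weighting mirrors the display script above, whereas the original SPLADE formulation applies a log(1 + ReLU) saturation.

```python
from transformers import AutoTokenizer, AutoModelForMaskedLM
import torch.nn.functional as F
import torch

# Query side: this checkpoint (query encoder fine-tuned on QReCC).
q_model = AutoModelForMaskedLM.from_pretrained("slupart/splade-disco-human")
q_tok = AutoTokenizer.from_pretrained("slupart/splade-disco-human")
# Document side: the unchanged SPLADE++ checkpoint.
d_model = AutoModelForMaskedLM.from_pretrained("naver/splade-cocondenser-ensembledistil")
d_tok = AutoTokenizer.from_pretrained("naver/splade-cocondenser-ensembledistil")
q_model.eval()
d_model.eval()

def encode(model, tokenizer, text):
    """Illustrative helper: text -> |V|-dim sparse vector (max-pooled ReLU logits)."""
    inputs = tokenizer(text, return_tensors="pt", truncation=True)
    with torch.no_grad():
        logits = model(**inputs).logits
    return F.relu(logits).max(1).values.squeeze(0)

# Flattened conversational query (same format as above) and a made-up passage.
query = "how much do they cost? [SEP] try the optician nearby. [SEP] where can I buy sunglasses?"
passage = "Sunglasses typically cost between 10 and 200 dollars, depending on the brand."

q_vec = encode(q_model, q_tok, query)
d_vec = encode(d_model, d_tok, passage)

# Relevance score: dot product of the two sparse vocabulary vectors.
print(f"score: {torch.dot(q_vec, d_vec).item():.4f}")
```

In practice, SPLADE document vectors are precomputed and stored in an inverted index, so retrieval does not require a brute-force dot product against every passage as in this sketch.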