SitEmb-v1.5-Qwen3 is a situated embedding model: it encodes a text chunk together with its surrounding context to produce context-aware chunk embeddings for retrieval.

Transformers Usage

import torch

from transformers import AutoTokenizer, AutoModel
from tqdm import tqdm
from more_itertools import chunked


residual = True            # also pool a chunk-only ("residual") embedding at the first <|endoftext|> position
residual_factor = 0.5      # weight of the residual score when fused with the full situated score

tokenizer = AutoTokenizer.from_pretrained(
    "Qwen/Qwen3-Embedding-8B",
    use_fast=True,
    padding_side='left',
)

model = AutoModel.from_pretrained(
    "SituatedEmbedding/SitEmb-v1.5-Qwen3",
    torch_dtype=torch.bfloat16,
    device_map={"": 0},
)

def _pooling(last_hidden_state, attention_mask, pooling, normalize, input_ids=None, match_idx=None):
    if pooling in ['cls', 'first']:
        reps = last_hidden_state[:, 0]
    elif pooling in ['mean', 'avg', 'average']:
        masked_hiddens = last_hidden_state.masked_fill(~attention_mask[..., None].bool(), 0.0)
        reps = masked_hiddens.sum(dim=1) / attention_mask.sum(dim=1)[..., None]
    elif pooling in ['last', 'eos']:
        left_padding = (attention_mask[:, -1].sum() == attention_mask.shape[0])
        if left_padding:
            reps = last_hidden_state[:, -1]
        else:
            sequence_lengths = attention_mask.sum(dim=1) - 1
            batch_size = last_hidden_state.shape[0]
            reps = last_hidden_state[torch.arange(batch_size, device=last_hidden_state.device), sequence_lengths]
    elif pooling == 'ext':
        if match_idx is None:
            # default mean
            masked_hiddens = last_hidden_state.masked_fill(~attention_mask[..., None].bool(), 0.0)
            reps = masked_hiddens.sum(dim=1) / attention_mask.sum(dim=1)[..., None]
        else:
            for k in range(input_ids.shape[0]):
                sep_index = input_ids[k].tolist().index(match_idx)
                attention_mask[k][sep_index:] = 0
            masked_hiddens = last_hidden_state.masked_fill(~attention_mask[..., None].bool(), 0.0)
            reps = masked_hiddens.sum(dim=1) / attention_mask.sum(dim=1)[..., None]
    else:
        raise ValueError(f'unknown pooling method: {pooling}')
    if normalize:
        reps = torch.nn.functional.normalize(reps, p=2, dim=-1)
    return reps


def first_eos_token_pooling(
        last_hidden_states,
        first_eos_position,
        normalize,
):
    batch_size = last_hidden_states.shape[0]
    reps = last_hidden_states[torch.arange(batch_size, device=last_hidden_states.device), first_eos_position]
    if normalize:
        reps = torch.nn.functional.normalize(reps, p=2, dim=-1)
    return reps

def get_detailed_instruct(task_description: str, query: str) -> str:
    # Query prompt format used by Qwen3-Embedding-style models.
    return f'Instruct: {task_description}\nQuery:{query}'


def encode_query(tokenizer, model, pooling, queries, batch_size, normalize, max_length, residual):
    task = "Given a search query, retrieve relevant chunks from fictions that answer the query"
    sents = []
    for query in queries:
        sents.append(get_detailed_instruct(task, query))

    # queries carry no situating context, so they are encoded without the residual branch
    return encode_passage(tokenizer, model, pooling, sents, batch_size, normalize, max_length)


def encode_passage(tokenizer, model, pooling, passages, batch_size, normalize, max_length, residual=False):
    pas_embs = []
    pas_embs_residual = []
    total = len(passages) // batch_size + (1 if len(passages) % batch_size != 0 else 0)
    with tqdm(total=total) as pbar:
        for sent_b in chunked(passages, batch_size):
            batch_dict = tokenizer(sent_b, max_length=max_length, padding=True, truncation=True,
                                   return_tensors='pt').to(model.device)
            if residual:
                batch_list_dict = tokenizer(sent_b, max_length=max_length, padding=True, truncation=True, )
                input_ids = batch_list_dict['input_ids']
                attention_mask = batch_list_dict['attention_mask']
                max_len = len(input_ids[0])
                input_starts = [max_len - sum(att) for att in attention_mask]
                eos_pos = []
                for ii, it in zip(input_ids, input_starts):
                    # the first <|endoftext|> (pad/eos token) after the real tokens marks the end of the chunk
                    pos = ii.index(tokenizer.pad_token_id, it)
                    eos_pos.append(pos)
                eos_pos = torch.tensor(eos_pos).to(model.device)
            else:
                eos_pos = None
            with torch.no_grad():
                outputs = model(**batch_dict)
            pemb_ = _pooling(outputs.last_hidden_state, batch_dict['attention_mask'], pooling, normalize)
            if residual:
                remb_ = first_eos_token_pooling(outputs.last_hidden_state, eos_pos, normalize)
                pas_embs_residual.append(remb_)
            pas_embs.append(pemb_)
            pbar.update(1)
    pas_embs = torch.cat(pas_embs, dim=0)
    if pas_embs_residual:
        pas_embs_residual = torch.cat(pas_embs_residual, dim=0)
    else:
        pas_embs_residual = None
    return pas_embs, pas_embs_residual

your_query = "Your Query"

query_hidden, _ = encode_query(
    tokenizer, model, pooling="eos", queries=[your_query],
    batch_size=8, normalize=True, max_length=8192, residual=residual,
)

passage_affix = "The context in which the chunk is situated is given below. Please encode the chunk by being aware of the context. Context:\n"
your_chunk = "Your Chunk"
your_context = "Your Context"

candidate_hidden, candidate_hidden_residual = encode_passage(
    tokenizer, model, pooling="eos", passages=[f"{your_chunk}<|endoftext|>{passage_affix}{your_context}"],
    batch_size=4, normalize=True, max_length=8192, residual=residual,
)
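
# The situated input format above is chunk, then <|endoftext|>, then the context affix and the
# surrounding context. The helper below is only a sketch of how several such inputs could be
# assembled from one long document; build_situated_passages, the character-level slicing, and the
# chunk/context sizes are illustrative assumptions, not part of the model's prescribed usage.
def build_situated_passages(document, chunk_chars=2000, context_chars=8000):
    chunks = [document[i:i + chunk_chars] for i in range(0, len(document), chunk_chars)]
    passages = []
    for i, chunk in enumerate(chunks):
        # use a window of text around each chunk as its situating context
        start = max(0, i * chunk_chars - context_chars // 2)
        end = min(len(document), (i + 1) * chunk_chars + context_chars // 2)
        context = document[start:end]
        passages.append(f"{chunk}<|endoftext|>{passage_affix}{context}")
    return passages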

query2candidate = query_hidden @ candidate_hidden.T    # [num_queries, num_candidates]
if candidate_hidden_residual is not None:
    # similarity against the chunk-only (residual) embeddings
    query2candidate_residual = query_hidden @ candidate_hidden_residual.T
    if residual_factor == 1.:
        query2candidate = query2candidate_residual
    elif residual_factor == 0.:
        pass
    else:
        # interpolate between the situated score and the residual score
        query2candidate = query2candidate * (1. - residual_factor) + query2candidate_residual * residual_factor

print(query2candidate.tolist())
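
# To retrieve over many chunks, encode all situated passages in one call to encode_passage and
# rank them per query by the fused score. A minimal sketch reusing query2candidate from above
# (the cut-off of 5 is illustrative):
top_k = min(5, query2candidate.shape[1])
top_scores, top_indices = torch.topk(query2candidate, k=top_k, dim=-1)
for qi in range(query2candidate.shape[0]):
    print(f"Query {qi}: candidates {top_indices[qi].tolist()} with scores {top_scores[qi].tolist()}")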

Base model: Qwen/Qwen3-8B-Base