This is the model SitEmb-v1.5-Qwen3, a situated embedding model built on Qwen/Qwen3-Embedding-8B that encodes each chunk together with the context it is situated in.
Transformers Usage
import torch
from transformers import AutoTokenizer, AutoModel
from tqdm import tqdm
from more_itertools import chunked

# Whether to also compute the residual (chunk-only) embedding, and how strongly to
# weight it when fusing with the full situated similarity at the end.
residual = True
residual_factor = 0.5

# The tokenizer is loaded from Qwen/Qwen3-Embedding-8B. Left padding keeps the final
# position of every sequence as the EOS token used for 'eos' pooling.
tokenizer = AutoTokenizer.from_pretrained(
    "Qwen/Qwen3-Embedding-8B",
    use_fast=True,
    padding_side='left',
)
model = AutoModel.from_pretrained(
    "SituatedEmbedding/SitEmb-v1.5-Qwen3",
    torch_dtype=torch.bfloat16,
    device_map={"": 0},
)
def _pooling(last_hidden_state, attention_mask, pooling, normalize, input_ids=None, match_idx=None):
    # Pool token-level hidden states into one embedding per sequence.
    if pooling in ['cls', 'first']:
        reps = last_hidden_state[:, 0]
    elif pooling in ['mean', 'avg', 'average']:
        masked_hiddens = last_hidden_state.masked_fill(~attention_mask[..., None].bool(), 0.0)
        reps = masked_hiddens.sum(dim=1) / attention_mask.sum(dim=1)[..., None]
    elif pooling in ['last', 'eos']:
        # With left padding, the last position of every sequence is a real token.
        left_padding = (attention_mask[:, -1].sum() == attention_mask.shape[0])
        if left_padding:
            reps = last_hidden_state[:, -1]
        else:
            sequence_lengths = attention_mask.sum(dim=1) - 1
            batch_size = last_hidden_state.shape[0]
            reps = last_hidden_state[torch.arange(batch_size, device=last_hidden_state.device), sequence_lengths]
    elif pooling == 'ext':
        if match_idx is None:
            # Default to mean pooling when no separator token id is given.
            masked_hiddens = last_hidden_state.masked_fill(~attention_mask[..., None].bool(), 0.0)
            reps = masked_hiddens.sum(dim=1) / attention_mask.sum(dim=1)[..., None]
        else:
            # Mean-pool only over tokens before the first occurrence of match_idx.
            for k in range(input_ids.shape[0]):
                sep_index = input_ids[k].tolist().index(match_idx)
                attention_mask[k][sep_index:] = 0
            masked_hiddens = last_hidden_state.masked_fill(~attention_mask[..., None].bool(), 0.0)
            reps = masked_hiddens.sum(dim=1) / attention_mask.sum(dim=1)[..., None]
    else:
        raise ValueError(f'unknown pooling method: {pooling}')
    if normalize:
        reps = torch.nn.functional.normalize(reps, p=2, dim=-1)
    return reps
def first_eos_token_pooling(
    last_hidden_states,
    first_eos_position,
    normalize,
):
    # Take the hidden state at the first <|endoftext|> position, i.e. the
    # representation of the chunk alone, before its appended context.
    batch_size = last_hidden_states.shape[0]
    reps = last_hidden_states[torch.arange(batch_size, device=last_hidden_states.device), first_eos_position]
    if normalize:
        reps = torch.nn.functional.normalize(reps, p=2, dim=-1)
    return reps
def get_detailed_instruct(task_description: str, query: str) -> str:
    # Instruction template from the Qwen3-Embedding usage examples (added here as an
    # assumption, since the original snippet uses this helper without defining it).
    return f'Instruct: {task_description}\nQuery:{query}'

def encode_query(tokenizer, model, pooling, queries, batch_size, normalize, max_length, residual):
    task = "Given a search query, retrieve relevant chunks from fictions that answer the query"
    sents = []
    for query in queries:
        sents.append(get_detailed_instruct(task, query))
    return encode_passage(tokenizer, model, pooling, sents, batch_size, normalize, max_length)
def encode_passage(tokenizer, model, pooling, passages, batch_size, normalize, max_length, residual=False):
    pas_embs = []
    pas_embs_residual = []
    total = len(passages) // batch_size + (1 if len(passages) % batch_size != 0 else 0)
    with tqdm(total=total) as pbar:
        for sent_b in chunked(passages, batch_size):
            batch_dict = tokenizer(sent_b, max_length=max_length, padding=True, truncation=True,
                                   return_tensors='pt').to(model.device)
            if residual:
                # Re-tokenize without tensors to locate, for each passage, the first
                # <|endoftext|> token (the tokenizer's pad token id), which separates the
                # chunk from its appended context. With left padding, the search starts
                # right after the padding tokens.
                batch_list_dict = tokenizer(sent_b, max_length=max_length, padding=True, truncation=True)
                input_ids = batch_list_dict['input_ids']
                attention_mask = batch_list_dict['attention_mask']
                max_len = len(input_ids[0])
                input_starts = [max_len - sum(att) for att in attention_mask]
                eos_pos = []
                for ii, it in zip(input_ids, input_starts):
                    pos = ii.index(tokenizer.pad_token_id, it)
                    eos_pos.append(pos)
                eos_pos = torch.tensor(eos_pos).to(model.device)
            else:
                eos_pos = None
            outputs = model(**batch_dict)
            pemb_ = _pooling(outputs.last_hidden_state, batch_dict['attention_mask'], pooling, normalize)
            if residual:
                # Residual embedding: the chunk-only representation at the separator position.
                remb_ = first_eos_token_pooling(outputs.last_hidden_state, eos_pos, normalize)
                pas_embs_residual.append(remb_)
            pas_embs.append(pemb_)
            pbar.update(1)
    pas_embs = torch.cat(pas_embs, dim=0)
    if pas_embs_residual:
        pas_embs_residual = torch.cat(pas_embs_residual, dim=0)
    else:
        pas_embs_residual = None
    return pas_embs, pas_embs_residual
your_query = "Your Query"
query_hidden, _ = encode_query(
    tokenizer, model, pooling="eos", queries=[your_query],
    batch_size=8, normalize=True, max_length=8192, residual=residual,
)

# Passages are encoded as "chunk <|endoftext|> context" so that the model can
# situate the chunk within its surrounding context.
passage_affix = "The context in which the chunk is situated is given below. Please encode the chunk by being aware of the context. Context:\n"
your_chunk = "Your Chunk"
your_context = "Your Context"
candidate_hidden, candidate_hidden_residual = encode_passage(
    tokenizer, model, pooling="eos", passages=[f"{your_chunk}<|endoftext|>{passage_affix}{your_context}"],
    batch_size=4, normalize=True, max_length=8192, residual=residual,
)
query2candidate = query_hidden @ candidate_hidden.T  # [num_queries, num_candidates]

# Fuse the situated similarity with the residual (chunk-only) similarity.
if candidate_hidden_residual is not None:
    query2candidate_residual = query_hidden @ candidate_hidden_residual.T
    if residual_factor == 1.:
        query2candidate = query2candidate_residual
    elif residual_factor == 0.:
        pass
    else:
        query2candidate = query2candidate * (1. - residual_factor) + query2candidate_residual * residual_factor

print(query2candidate.tolist())
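
In practice you will usually encode many candidate chunks in one call and rank them by the fused score. The sketch below is illustrative only: the candidate chunk/context strings and the top-k value are placeholders, and it simply reuses encode_passage, residual, and residual_factor from the snippet above.

# Illustrative sketch: rank several placeholder candidate chunks against the query above.
candidate_passages = [
    f"{chunk}<|endoftext|>{passage_affix}{context}"
    for chunk, context in [
        ("Chunk A", "Context A"),
        ("Chunk B", "Context B"),
        ("Chunk C", "Context C"),
    ]
]
cand_hidden, cand_hidden_residual = encode_passage(
    tokenizer, model, pooling="eos", passages=candidate_passages,
    batch_size=4, normalize=True, max_length=8192, residual=residual,
)
scores = query_hidden @ cand_hidden.T
if cand_hidden_residual is not None:
    scores = scores * (1. - residual_factor) + (query_hidden @ cand_hidden_residual.T) * residual_factor
top_values, top_indices = torch.topk(scores, k=min(3, scores.shape[1]), dim=1)
print(top_values.tolist(), top_indices.tolist())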