Batched generation (batch_size > 1) produces incorrect outputs: possible causal mask issue?

#9, opened by vconchel

Generation isn't working properly for me when batch_size > 1. The longest sample in the batch is generated normally, but the others are full of spaces and repeat many words. Is this a common issue?

I solved it by changing lines 567-583 in modeling_ouro.py from

```python
mask_kwargs = {
    "config": self.config,
    "input_embeds": inputs_embeds,
    "attention_mask": attention_mask,
    "cache_position": cache_position,
    "past_key_values": past_key_values,
    "position_ids": position_ids,
}
# Create the masks
causal_mask_mapping = {
    "full_attention": create_causal_mask(**mask_kwargs),
}
# The sliding window alternating layers are not always activated depending on the config
if self.has_sliding_layers:
    causal_mask_mapping["sliding_attention"] = (
        create_sliding_window_causal_mask(**mask_kwargs)
    )
```

to

```python
from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask

# Key length = past_key_values_length + input length, taken from the cache directly
mask_kwargs = {
    "attention_mask": attention_mask,
    "input_shape": inputs_embeds.shape[:2],
    "inputs_embeds": inputs_embeds,
    "past_key_values_length": past_key_values.get_seq_length() if past_key_values is not None else 0,
}
# Create the masks
causal_mask_mapping = {
    "full_attention": _prepare_4d_causal_attention_mask(**mask_kwargs),
}

# The sliding window alternating layers are not always activated depending on the config
if self.has_sliding_layers:
    causal_mask_mapping["sliding_attention"] = _prepare_4d_causal_attention_mask(
        **mask_kwargs,
        sliding_window=self.config.sliding_window,  # config is an object, not a dict
    )
```
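As a minimal sketch of what the replacement above computes (not the transformers implementation), the pure-Python helper below builds a boolean mask of shape (batch, 1, query_length, past_length + query_length): a query may attend to a key only if the key is not in the future and is not padding. The key dimension is sized from the cached length plus the current input length, which is what passing past_key_values_length achieves. Assumptions: causal_padding_mask is a hypothetical name; the real _prepare_4d_causal_attention_mask returns an additive float mask (0 or the dtype minimum) rather than booleans, and the sliding-window case is left out here.

```python
from typing import List

Mask4D = List[List[List[List[bool]]]]


def causal_padding_mask(padding_mask: List[List[int]], query_length: int, past_length: int) -> Mask4D:
    """Hypothetical helper: boolean (batch, 1, query_length, past_length + query_length) mask.

    padding_mask has one row per sequence (1 = real token, 0 = padding) covering
    every key position: cached tokens plus the tokens in the current step.
    True means "may attend".
    """
    kv_length = past_length + query_length
    batch = []
    for row in padding_mask:
        if len(row) != kv_length:
            raise ValueError("padding mask must cover past_length + query_length keys")
        head = []
        for q in range(query_length):
            q_pos = past_length + q  # absolute position of this query token
            # causal (no future keys) AND not a padding key
            head.append([k <= q_pos and row[k] == 1 for k in range(kv_length)])
        batch.append([head])  # single head dimension, broadcast over all heads
    return batch


# Decoding step: 4 cached tokens, 1 new token. Sequence 0 is the longest,
# sequence 1 was left-padded by two tokens.
pad = [[1, 1, 1, 1, 1], [0, 0, 1, 1, 1]]
mask = causal_padding_mask(pad, query_length=1, past_length=4)
print(len(mask[0][0][0]))  # 5 key columns
print(mask[1][0][0])       # [False, False, True, True, True]
```

Because the key dimension always spans cached plus new tokens, the padded sequence keeps its two padding keys masked during decoding.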

Is there a more straightforward solution?

Hey, thank you for this hint! I dug into it myself and found that batch_size > 1 does not work properly with attn_implementation="eager" (lots of whitespace in the output) or "sdpa" (it crashes outright), while the "flash_attention_2" backend works fine.

The root cause I found is a bug in UniversalTransformerCache: UniversalTransformerCache.get_mask_sizes returns the wrong KV length during autoregressive steps (it always returns query_length instead of cached_length + query_length). This makes the 4D attention mask too small along the key dimension, so the padding positions get broadcast away and batched generation is corrupted for every sequence except the longest (unpadded) one.
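A toy model of this failure is sketched below in pure Python, under stated assumptions. ToyCache, mask_row and broadcast_keys are hypothetical names, not Ouro or transformers APIs; the sketch assumes get_mask_sizes returns (kv_length, kv_offset) with kv_offset = 0 and that the prompt is left-padded. It compares, for one decoding step, the size described above as correct with the size reported in this thread.

```python
from typing import Dict, List, Tuple


class ToyCache:
    """Hypothetical stand-in for a KV cache; it only records how many tokens it holds."""

    def __init__(self, cached_length: int) -> None:
        self.cached_length = cached_length

    def get_mask_sizes_reported(self, query_length: int) -> Tuple[int, int]:
        # Behaviour reported in this thread: the cached tokens are ignored.
        return query_length, 0

    def get_mask_sizes_fixed(self, query_length: int) -> Tuple[int, int]:
        # Keys = tokens already in the cache + tokens processed in this step.
        return self.cached_length + query_length, 0


def mask_row(padding: List[int], kv_length: int, query_position: int) -> List[bool]:
    """Allowed keys for one query, built only over the first kv_length key positions."""
    return [k <= query_position and padding[k] == 1 for k in range(kv_length)]


def broadcast_keys(row: List[bool], total_keys: int) -> List[bool]:
    """Mimic tensor broadcasting of the mask's key dimension against the real key length."""
    if len(row) == total_keys:
        return row
    if len(row) == 1:
        return row * total_keys  # a size-1 dimension is repeated for every key
    raise ValueError("mask key dimension is not broadcastable")


cache = ToyCache(cached_length=4)  # 4-token (padded) prompt already cached
padding: Dict[str, List[int]] = {"longest": [1, 1, 1, 1, 1], "padded": [0, 0, 1, 1, 1]}
query_position = 4                 # the single new token in this decoding step
total_keys = cache.cached_length + 1

sizes = {"fixed": cache.get_mask_sizes_fixed(1), "reported": cache.get_mask_sizes_reported(1)}
for name, (kv_length, _offset) in sizes.items():
    for seq, pad in padding.items():
        row = broadcast_keys(mask_row(pad, kv_length, query_position), total_keys)
        print(name, seq, row)
# fixed longest [True, True, True, True, True]
# fixed padded [False, False, True, True, True]
# reported longest [True, True, True, True, True]
# reported padded [False, False, False, False, False]
```

In this toy, the fixed size masks only the two padding keys of the shorter sequence. With the reported size, the single mask column for key 0 is copied to every key: the longest sequence happens to come out right, while the padded one ends up with every key masked. How a fully masked row is then handled (near-uniform attention, NaNs, or an error) depends on the attention backend, which may be why eager and sdpa fail differently; that link is a guess, not something verified here.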

See https://huggingface.co/ByteDance/Ouro-2.6B-Thinking/discussions/8
