Fine-tuning

#2 by alexeyskov - opened

Good afternoon!

I'm trying to fine-tune the model on custom dialogues and ran into a problem launching training.

I load the model with quantization as follows:

@dataclass
class Config:
    model_name: str = "../preload/GigaChat-20B-A3B-instruct-v1.5/"
    new_model: str = "./training/new_gigachat_20B/"
    torch_dtype: torch.dtype = torch.bfloat16

cfg = Config() 

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=cfg.torch_dtype,
    bnb_4bit_use_double_quant=True,
)

model = AutoModelForCausalLM.from_pretrained(
    cfg.model_name,
    quantization_config=bnb_config,
    device_map="auto",
    trust_remote_code=True
)
model.generation_config = GenerationConfig.from_pretrained(cfg.model_name)

tokenizer = AutoTokenizer.from_pretrained(cfg.model_name, trust_remote_code=True)

# use EOS as the padding token for both the tokenizer and the model config
tokenizer.pad_token_id = tokenizer.eos_token_id
model.config.pad_token_id = model.config.eos_token_id

The following LoRA config is used for fine-tuning:

peft_config = LoraConfig(
    r=8,
    lora_alpha=16,
    lora_dropout=0,
    bias="none",
    task_type="CAUSAL_LM",
    target_modules=["gate_proj", "up_proj"]
)
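
As a quick sanity check, one can list which submodules these target names actually match (a sketch using the model loaded above; in DeepSeek-style MoE layers, gate_proj and up_proj typically occur once per expert, so LoRA wraps every expert's projections):

# Sketch: count the modules that LoRA's target_modules will match.
matched = [
    name for name, _ in model.named_modules()
    if name.rsplit(".", 1)[-1] in ("gate_proj", "up_proj")
]
print(len(matched), matched[:4])  # how many modules get adapters, and a few names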

During training I use the following data collator; the dialogues are pre-encoded via tokenizer.apply_chat_template(dialogue, tokenize=False, add_generation_prompt=False):

response_template = "assistant<|role_sep|>"

collator = DataCollatorForCompletionOnlyLM(
    response_template=tokenizer.encode(
        response_template, 
        add_special_tokens=False
    ),
    tokenizer=tokenizer,
    mlm=False
)
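
One caveat worth verifying (a small sketch reusing the tokenizer above): DataCollatorForCompletionOnlyLM locates the response span by matching token IDs, and a template string can tokenize differently inside the full chat-formatted text than in isolation, so the encoded template should occur verbatim in a templated sample:

# Sketch: confirm the encoded response_template occurs in the token IDs of a
# fully templated dialogue; if it does not, the collator cannot find the
# assistant span and the example's labels end up fully masked.
dialogue = [
    {"role": "user", "content": "Привет!"},
    {"role": "assistant", "content": "Привет!"},
]
text = tokenizer.apply_chat_template(dialogue, tokenize=False, add_generation_prompt=False)
ids = tokenizer(text, add_special_tokens=False)["input_ids"]
tpl = tokenizer.encode("assistant<|role_sep|>", add_special_tokens=False)
print(any(ids[i:i + len(tpl)] == tpl for i in range(len(ids) - len(tpl) + 1)))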

I train via SFTTrainer from trl with the following parameters:

training_config = SFTConfig(
    output_dir=cfg.new_model,
    overwrite_output_dir=True,
    per_device_train_batch_size=1,
    per_device_eval_batch_size=1,
    gradient_accumulation_steps=8,
    optim="paged_adamw_32bit",
    num_train_epochs=100,
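    # note: when both are set, max_steps takes precedence over num_train_epochs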
    max_steps=100,
    eval_strategy="steps",
    logging_steps=20,
    warmup_steps=50,
    logging_strategy="steps",
    learning_rate=2e-4,
    fp16=False,
    bf16=False,
    group_by_length=True,
    gradient_checkpointing=False,
    label_names=["labels"],
    packing=False,
    max_seq_length=512,
    dataset_num_proc=2,
    dataset_text_field="text"
)

trainer = SFTTrainer(
    model=model,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    peft_config=peft_config,
    data_collator=collator,
    processing_class=tokenizer,
    args=training_config
)

And when I start training with trainer.train(), I get the following error:

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
Cell In[40], line 1
----> 1 trainer.train()

File ~/venv/lib/python3.10/site-packages/transformers/trainer.py:2241, in Trainer.train(self, resume_from_checkpoint, trial, ignore_keys_for_eval, **kwargs)
   2239         hf_hub_utils.enable_progress_bars()
   2240 else:
-> 2241     return inner_training_loop(
   2242         args=args,
   2243         resume_from_checkpoint=resume_from_checkpoint,
   2244         trial=trial,
   2245         ignore_keys_for_eval=ignore_keys_for_eval,
   2246     )

File ~/venv/lib/python3.10/site-packages/transformers/trainer.py:2548, in Trainer._inner_training_loop(self, batch_size, args, resume_from_checkpoint, trial, ignore_keys_for_eval)
   2541 context = (
   2542     functools.partial(self.accelerator.no_sync, model=model)
   2543     if i != len(batch_samples) - 1
   2544     and self.accelerator.distributed_type != DistributedType.DEEPSPEED
   2545     else contextlib.nullcontext
   2546 )
   2547 with context():
-> 2548     tr_loss_step = self.training_step(model, inputs, num_items_in_batch)
   2550 if (
   2551     args.logging_nan_inf_filter
   2552     and not is_torch_xla_available()
   2553     and (torch.isnan(tr_loss_step) or torch.isinf(tr_loss_step))
   2554 ):
   2555     # if loss is nan or inf simply add the average of previous logged losses
   2556     tr_loss = tr_loss + tr_loss / (1 + self.state.global_step - self._globalstep_last_logged)

File ~/venv/lib/python3.10/site-packages/transformers/trainer.py:3698, in Trainer.training_step(self, model, inputs, num_items_in_batch)
   3695     return loss_mb.reduce_mean().detach().to(self.args.device)
   3697 with self.compute_loss_context_manager():
-> 3698     loss = self.compute_loss(model, inputs, num_items_in_batch=num_items_in_batch)
   3700 del inputs
   3701 if (
   3702     self.args.torch_empty_cache_steps is not None
   3703     and self.state.global_step % self.args.torch_empty_cache_steps == 0
   3704 ):

File ~/venv/lib/python3.10/site-packages/trl/trainer/sft_trainer.py:438, in SFTTrainer.compute_loss(self, model, inputs, return_outputs, num_items_in_batch)
    434 def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
    435     """
    436     Compute training loss and additionally compute token accuracies
    437     """
--> 438     (loss, outputs) = super().compute_loss(
    439         model, inputs, return_outputs=True, num_items_in_batch=num_items_in_batch
    440     )
    442     # Compute token accuracy if we have labels and if the model is not using Liger (no logits)
    443     if "labels" in inputs and not self.args.use_liger:

File ~/venv/lib/python3.10/site-packages/transformers/trainer.py:3759, in Trainer.compute_loss(self, model, inputs, return_outputs, num_items_in_batch)
   3757         loss_kwargs["num_items_in_batch"] = num_items_in_batch
   3758     inputs = {**inputs, **loss_kwargs}
-> 3759 outputs = model(**inputs)
   3760 # Save past state if it exists
   3761 # TODO: this needs to be fixed and made cleaner later.
   3762 if self.args.past_index >= 0:

File ~/venv/lib/python3.10/site-packages/torch/nn/modules/module.py:1736, in Module._wrapped_call_impl(self, *args, **kwargs)
   1734     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1735 else:
-> 1736     return self._call_impl(*args, **kwargs)

File ~/venv/lib/python3.10/site-packages/torch/nn/modules/module.py:1747, in Module._call_impl(self, *args, **kwargs)
   1742 # If we don't have any hooks, we want to skip the rest of the logic in
   1743 # this function, and just call forward.
   1744 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1745         or _global_backward_pre_hooks or _global_backward_hooks
   1746         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1747     return forward_call(*args, **kwargs)
   1749 result = None
   1750 called_always_called_hooks = set()

File ~/venv/lib/python3.10/site-packages/peft/peft_model.py:1719, in PeftModelForCausalLM.forward(self, input_ids, attention_mask, inputs_embeds, labels, output_attentions, output_hidden_states, return_dict, task_ids, **kwargs)
   1717     with self._enable_peft_forward_hooks(**kwargs):
   1718         kwargs = {k: v for k, v in kwargs.items() if k not in self.special_peft_forward_args}
-> 1719         return self.base_model(
   1720             input_ids=input_ids,
   1721             attention_mask=attention_mask,
   1722             inputs_embeds=inputs_embeds,
   1723             labels=labels,
   1724             output_attentions=output_attentions,
   1725             output_hidden_states=output_hidden_states,
   1726             return_dict=return_dict,
   1727             **kwargs,
   1728         )
   1730 batch_size = _get_batch_size(input_ids, inputs_embeds)
   1731 if attention_mask is not None:
   1732     # concat prompt attention mask

File ~/venv/lib/python3.10/site-packages/torch/nn/modules/module.py:1736, in Module._wrapped_call_impl(self, *args, **kwargs)
   1734     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1735 else:
-> 1736     return self._call_impl(*args, **kwargs)

File ~/venv/lib/python3.10/site-packages/torch/nn/modules/module.py:1747, in Module._call_impl(self, *args, **kwargs)
   1742 # If we don't have any hooks, we want to skip the rest of the logic in
   1743 # this function, and just call forward.
   1744 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1745         or _global_backward_pre_hooks or _global_backward_hooks
   1746         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1747     return forward_call(*args, **kwargs)
   1749 result = None
   1750 called_always_called_hooks = set()

File ~/venv/lib/python3.10/site-packages/peft/tuners/tuners_utils.py:197, in BaseTuner.forward(self, *args, **kwargs)
    196 def forward(self, *args: Any, **kwargs: Any):
--> 197     return self.model.forward(*args, **kwargs)

File ~/venv/lib/python3.10/site-packages/accelerate/hooks.py:170, in add_hook_to_module.<locals>.new_forward(module, *args, **kwargs)
    168         output = module._old_forward(*args, **kwargs)
    169 else:
--> 170     output = module._old_forward(*args, **kwargs)
    171 return module._hf_hook.post_forward(module, output)

File ~/.cache/huggingface/modules/transformers_modules/ai-sage/GigaChat-20B-A3B-base/a32fc228a93b8c22084d5e227a44afffb674bca5/modelling_deepseek.py:1279, in DeepseekForCausalLM.forward(self, input_ids, attention_mask, position_ids, past_key_values, inputs_embeds, labels, use_cache, output_attentions, output_hidden_states, return_dict, cache_position, num_logits_to_keep, **kwargs)
   1276 return_dict = return_dict if return_dict is not None else self.config.use_return_dict
   1278 # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
-> 1279 outputs = self.model(
   1280     input_ids=input_ids,
   1281     attention_mask=attention_mask,
   1282     position_ids=position_ids,
   1283     past_key_values=past_key_values,
   1284     inputs_embeds=inputs_embeds,
   1285     use_cache=use_cache,
   1286     output_attentions=output_attentions,
   1287     output_hidden_states=output_hidden_states,
   1288     return_dict=return_dict,
   1289 )
   1291 hidden_states = outputs[0]
   1292 # Only compute necessary logits, and do not upcast them to float if we are not computing the loss

File ~/venv/lib/python3.10/site-packages/torch/nn/modules/module.py:1736, in Module._wrapped_call_impl(self, *args, **kwargs)
   1734     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1735 else:
-> 1736     return self._call_impl(*args, **kwargs)

File ~/venv/lib/python3.10/site-packages/torch/nn/modules/module.py:1747, in Module._call_impl(self, *args, **kwargs)
   1742 # If we don't have any hooks, we want to skip the rest of the logic in
   1743 # this function, and just call forward.
   1744 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1745         or _global_backward_pre_hooks or _global_backward_hooks
   1746         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1747     return forward_call(*args, **kwargs)
   1749 result = None
   1750 called_always_called_hooks = set()

File ~/venv/lib/python3.10/site-packages/accelerate/hooks.py:170, in add_hook_to_module.<locals>.new_forward(module, *args, **kwargs)
    168         output = module._old_forward(*args, **kwargs)
    169 else:
--> 170     output = module._old_forward(*args, **kwargs)
    171 return module._hf_hook.post_forward(module, output)

File ~/.cache/huggingface/modules/transformers_modules/ai-sage/GigaChat-20B-A3B-base/a32fc228a93b8c22084d5e227a44afffb674bca5/modelling_deepseek.py:1030, in DeepseekModel.forward(self, input_ids, attention_mask, position_ids, past_key_values, inputs_embeds, use_cache, output_attentions, output_hidden_states, return_dict, cache_position, **flash_attn_kwargs)
   1018     layer_outputs = self._gradient_checkpointing_func(
   1019         decoder_layer.__call__,
   1020         hidden_states,
   (...)
   1027         position_embeddings,
   1028     )
   1029 else:
-> 1030     layer_outputs = decoder_layer(
   1031         hidden_states,
   1032         attention_mask=causal_mask,
   1033         position_ids=position_ids,
   1034         past_key_value=past_key_values,
   1035         output_attentions=output_attentions,
   1036         use_cache=use_cache,
   1037         cache_position=cache_position,
   1038         position_embeddings=position_embeddings,
   1039         **flash_attn_kwargs,
   1040     )
   1042 hidden_states = layer_outputs[0]
   1044 if use_cache:

File ~/venv/lib/python3.10/site-packages/torch/nn/modules/module.py:1736, in Module._wrapped_call_impl(self, *args, **kwargs)
   1734     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1735 else:
-> 1736     return self._call_impl(*args, **kwargs)

File ~/venv/lib/python3.10/site-packages/torch/nn/modules/module.py:1747, in Module._call_impl(self, *args, **kwargs)
   1742 # If we don't have any hooks, we want to skip the rest of the logic in
   1743 # this function, and just call forward.
   1744 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1745         or _global_backward_pre_hooks or _global_backward_hooks
   1746         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1747     return forward_call(*args, **kwargs)
   1749 result = None
   1750 called_always_called_hooks = set()

File ~/venv/lib/python3.10/site-packages/accelerate/hooks.py:170, in add_hook_to_module.<locals>.new_forward(module, *args, **kwargs)
    168         output = module._old_forward(*args, **kwargs)
    169 else:
--> 170     output = module._old_forward(*args, **kwargs)
    171 return module._hf_hook.post_forward(module, output)

File ~/.cache/huggingface/modules/transformers_modules/ai-sage/GigaChat-20B-A3B-base/a32fc228a93b8c22084d5e227a44afffb674bca5/modelling_deepseek.py:773, in DeepseekDecoderLayer.forward(self, hidden_states, attention_mask, position_ids, past_key_value, output_attentions, use_cache, cache_position, position_embeddings, **kwargs)
    771 residual = hidden_states
    772 hidden_states = self.post_attention_layernorm(hidden_states)
--> 773 hidden_states = self.mlp(hidden_states)
    774 hidden_states = residual + hidden_states
    776 outputs = (hidden_states,)

File ~/venv/lib/python3.10/site-packages/torch/nn/modules/module.py:1736, in Module._wrapped_call_impl(self, *args, **kwargs)
   1734     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1735 else:
-> 1736     return self._call_impl(*args, **kwargs)

File ~/venv/lib/python3.10/site-packages/torch/nn/modules/module.py:1747, in Module._call_impl(self, *args, **kwargs)
   1742 # If we don't have any hooks, we want to skip the rest of the logic in
   1743 # this function, and just call forward.
   1744 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1745         or _global_backward_pre_hooks or _global_backward_hooks
   1746         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1747     return forward_call(*args, **kwargs)
   1749 result = None
   1750 called_always_called_hooks = set()

File ~/venv/lib/python3.10/site-packages/accelerate/hooks.py:170, in add_hook_to_module.<locals>.new_forward(module, *args, **kwargs)
    168         output = module._old_forward(*args, **kwargs)
    169 else:
--> 170     output = module._old_forward(*args, **kwargs)
    171 return module._hf_hook.post_forward(module, output)

File ~/.cache/huggingface/modules/transformers_modules/ai-sage/GigaChat-20B-A3B-base/a32fc228a93b8c22084d5e227a44afffb674bca5/modelling_deepseek.py:318, in DeepseekMoE.forward(self, hidden_states)
    316 if self.training:
    317     y = self.moe_train(hidden_states, flat_topk_idx, topk_weight.view(-1, 1))
--> 318     y = y.view(*orig_shape)
    319     y = AddAuxiliaryLoss.apply(y, aux_loss)
    320 else:

RuntimeError: shape '[1, 512, 2048]' is invalid for input of size 6291456

Library versions:

torch==2.5.1+cu121
transformers==4.48.3
trl==0.15.0
peft==0.14.0

I'm also attaching a minimal example that reproduces the error:

import os

os.environ["WANDB_DISABLED"] = "true"
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

import torch

from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    GenerationConfig, 
)
from peft import (
    LoraConfig
)

from datasets import Dataset
from trl import SFTTrainer, SFTConfig
from trl import DataCollatorForCompletionOnlyLM
from dataclasses import dataclass


@dataclass
class Config:
    model_name: str = "ai-sage/GigaChat-20B-A3B-instruct-v1.5"
    new_model: str = "./new_gigachat_20B/"
    torch_dtype: torch.dtype = torch.bfloat16


def get_example_data():
    return [
        [
            {"role": "system", "content": "Работа системы 1"},
            {"role": "user", "content": "Привет!"},
            {"role": "assistant", "content": "Привет!"}
        ],
        [
            {"role": "system", "content": "Работа системы 2"},
            {"role": "user", "content": "Привет! Как дела?"},
            {"role": "assistant", "content": "Привет! Хорошо! У тебя как?"}
        ]
    ]


if __name__ == "__main__":
    cfg = Config() 


    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=cfg.torch_dtype,
        bnb_4bit_use_double_quant=True,
    )

    model = AutoModelForCausalLM.from_pretrained(
        cfg.model_name,
        quantization_config=bnb_config,
        device_map="auto",
        trust_remote_code=True
    )
    model.generation_config = GenerationConfig.from_pretrained(cfg.model_name)

    tokenizer = AutoTokenizer.from_pretrained(cfg.model_name, trust_remote_code=True)

    tokenizer.pad_token_id = tokenizer.eos_token_id
    model.config.pad_token_id = model.config.eos_token_id

    peft_config = LoraConfig(
        r=8,
        lora_alpha=16,
        lora_dropout=0,
        bias="none",
        task_type="CAUSAL_LM",
        target_modules=["gate_proj", "up_proj"]
    )

    train_dataset_texts = []
    for d in get_example_data():
        train_dataset_texts.append(
            {
                "text": tokenizer.apply_chat_template(d, tokenize=False, add_generation_prompt=False)
            }
        )
    val_dataset_texts = train_dataset_texts

    train_dataset = Dataset.from_list(train_dataset_texts)
    eval_dataset = Dataset.from_list(val_dataset_texts)

    response_template = "assistant<|role_sep|>"

    collator = DataCollatorForCompletionOnlyLM(
        response_template=tokenizer.encode(
            response_template, 
            add_special_tokens=False
        ),
        tokenizer=tokenizer,
        mlm=False
    )

    training_config = SFTConfig(
        output_dir=cfg.new_model,
        overwrite_output_dir=True,
        per_device_train_batch_size=1,
        per_device_eval_batch_size=1,
        gradient_accumulation_steps=8,
        optim="paged_adamw_32bit",
        num_train_epochs=100,
        max_steps=100,
        eval_strategy="steps",
        logging_steps=20,
        warmup_steps=50,
        logging_strategy="steps",
        learning_rate=2e-4,
        fp16=False,
        bf16=False,
        group_by_length=True,
        gradient_checkpointing=False,
        label_names=["labels"],
        packing=False,
        max_seq_length=512,
        dataset_num_proc=2
    )

    trainer = SFTTrainer(
        model=model,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        peft_config=peft_config,
        data_collator=collator,
        processing_class=tokenizer,
        args=training_config
    )

    trainer.train()

I noticed that the same error can be reproduced if you don't load the generation config before inference via model.generation_config = GenerationConfig.from_pretrained(cfg.model_name).
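
A minimal sketch of the inference call this refers to (names as in the snippets above):

# Sketch: inference with the repo's generation config loaded first.
model.generation_config = GenerationConfig.from_pretrained(cfg.model_name)
input_ids = tokenizer.apply_chat_template(
    [{"role": "user", "content": "Привет!"}],
    add_generation_prompt=True,
    return_tensors="pt",
).to(model.device)
out = model.generate(input_ids, max_new_tokens=32)
print(tokenizer.decode(out[0], skip_special_tokens=True))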

Could you advise how to fine-tune the model correctly?

I based this on working fine-tuning code for LLaMA models, but I understand this one is MoE; maybe there are tutorials on fine-tuning GigaChat?

Thanks in advance!

ai-sage org

Good afternoon! We have fixed the problem and updated the code (specifically, the modelling file).

Please pull the model again and restart fine-tuning.

Thank you!
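
A minimal sketch of making sure the updated repository code is actually picked up (force_download re-fetches both the remote code and the weights; bnb_config as defined in the snippets above):

# Sketch: force a fresh download so the fixed modelling_deepseek.py is used.
model = AutoModelForCausalLM.from_pretrained(
    "ai-sage/GigaChat-20B-A3B-instruct-v1.5",
    quantization_config=bnb_config,
    device_map="auto",
    trust_remote_code=True,
    force_download=True,  # bypass the local cache
)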

Good afternoon! If you additionally pass return_dict=True when creating the model, training starts. Updated model-creation code:

model = AutoModelForCausalLM.from_pretrained(
    cfg.model_name,
    quantization_config=bnb_config,
    device_map="auto",
    trust_remote_code=True,
    return_dict=True
)

Training now runs, but the following error occurs at validation:

Traceback (most recent call last):
  File "/home/alexey/programming/work/CarLLM_Project/experiments/min_example.py", line 137, in <module>
    trainer.train()
  File "/home/alexey/venv/lib/python3.10/site-packages/transformers/trainer.py", line 2241, in train
    return inner_training_loop(
  File "/home/alexey/venv/lib/python3.10/site-packages/transformers/trainer.py", line 2612, in _inner_training_loop
    self._maybe_log_save_evaluate(
  File "/home/alexey/venv/lib/python3.10/site-packages/transformers/trainer.py", line 3085, in _maybe_log_save_evaluate
    metrics = self._evaluate(trial, ignore_keys_for_eval)
  File "/home/alexey/venv/lib/python3.10/site-packages/transformers/trainer.py", line 3039, in _evaluate
    metrics = self.evaluate(ignore_keys=ignore_keys_for_eval)
  File "/home/alexey/venv/lib/python3.10/site-packages/transformers/trainer.py", line 4105, in evaluate
    output = eval_loop(
  File "/home/alexey/venv/lib/python3.10/site-packages/transformers/trainer.py", line 4299, in evaluation_loop
    losses, logits, labels = self.prediction_step(model, inputs, prediction_loss_only, ignore_keys=ignore_keys)
  File "/home/alexey/venv/lib/python3.10/site-packages/transformers/trainer.py", line 4515, in prediction_step
    loss, outputs = self.compute_loss(model, inputs, return_outputs=True)
  File "/home/alexey/venv/lib/python3.10/site-packages/trl/trainer/sft_trainer.py", line 464, in compute_loss
    (loss, outputs) = super().compute_loss(
  File "/home/alexey/venv/lib/python3.10/site-packages/transformers/trainer.py", line 3759, in compute_loss
    outputs = model(**inputs)
  File "/home/alexey/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/alexey/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/alexey/venv/lib/python3.10/site-packages/peft/peft_model.py", line 1719, in forward
    return self.base_model(
  File "/home/alexey/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/alexey/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/alexey/venv/lib/python3.10/site-packages/peft/tuners/tuners_utils.py", line 197, in forward
    return self.model.forward(*args, **kwargs)
  File "/home/alexey/venv/lib/python3.10/site-packages/accelerate/hooks.py", line 170, in new_forward
    output = module._old_forward(*args, **kwargs)
  File "/home/alexey/.cache/huggingface/modules/transformers_modules/ai-sage/GigaChat-20B-A3B-base/7c4e5992e33198828505e86fbdfb27c5f4c24cd9/modelling_deepseek.py", line 1302, in forward
    outputs = self.model(
  File "/home/alexey/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/alexey/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/alexey/venv/lib/python3.10/site-packages/accelerate/hooks.py", line 170, in new_forward
    output = module._old_forward(*args, **kwargs)
  File "/home/alexey/.cache/huggingface/modules/transformers_modules/ai-sage/GigaChat-20B-A3B-base/7c4e5992e33198828505e86fbdfb27c5f4c24cd9/modelling_deepseek.py", line 1053, in forward
    layer_outputs = decoder_layer(
  File "/home/alexey/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/alexey/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/alexey/venv/lib/python3.10/site-packages/accelerate/hooks.py", line 170, in new_forward
    output = module._old_forward(*args, **kwargs)
  File "/home/alexey/.cache/huggingface/modules/transformers_modules/ai-sage/GigaChat-20B-A3B-base/7c4e5992e33198828505e86fbdfb27c5f4c24cd9/modelling_deepseek.py", line 796, in forward
    hidden_states = self.mlp(hidden_states)
  File "/home/alexey/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/alexey/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/alexey/venv/lib/python3.10/site-packages/accelerate/hooks.py", line 170, in new_forward
    output = module._old_forward(*args, **kwargs)
  File "/home/alexey/.cache/huggingface/modules/transformers_modules/ai-sage/GigaChat-20B-A3B-base/7c4e5992e33198828505e86fbdfb27c5f4c24cd9/modelling_deepseek.py", line 344, in forward
    y = self.moe_infer(hidden_states, flat_topk_idx, topk_weight).view(*orig_shape) # removed unnecessary .view(-1, 1)
  File "/home/alexey/venv/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
  File "/home/alexey/.cache/huggingface/modules/transformers_modules/ai-sage/GigaChat-20B-A3B-base/7c4e5992e33198828505e86fbdfb27c5f4c24cd9/modelling_deepseek.py", line 371, in moe_infer
    expert_out.mul_(flat_expert_weights[idxs[start_idx:end_idx]])
RuntimeError: The size of tensor a (2048) must match the size of tensor b (6) at non-singleton dimension 1
../aten/src/ATen/native/cuda/IndexKernel.cu:93: operator(): block: [0,0,0], thread: [0,0,0] Assertion `-sizes[i] <= index && index < sizes[i] && "index out of bounds"` failed.
../aten/src/ATen/native/cuda/IndexKernel.cu:93: operator(): block: [0,0,0], thread: [1,0,0] Assertion `-sizes[i] <= index && index < sizes[i] && "index out of bounds"` failed.
../aten/src/ATen/native/cuda/IndexKernel.cu:93: operator(): block: [0,0,0], thread: [2,0,0] Assertion `-sizes[i] <= index && index < sizes[i] && "index out of bounds"` failed.
../aten/src/ATen/native/cuda/IndexKernel.cu:93: operator(): block: [0,0,0], thread: [3,0,0] Assertion `-sizes[i] <= index && index < sizes[i] && "index out of bounds"` failed.
../aten/src/ATen/native/cuda/IndexKernel.cu:93: operator(): block: [0,0,0], thread: [4,0,0] Assertion `-sizes[i] <= index && index < sizes[i] && "index out of bounds"` failed.
../aten/src/ATen/native/cuda/IndexKernel.cu:93: operator(): block: [0,0,0], thread: [5,0,0] Assertion `-sizes[i] <= index && index < sizes[i] && "index out of bounds"` failed.
../aten/src/ATen/native/cuda/IndexKernel.cu:93: operator(): block: [0,0,0], thread: [6,0,0] Assertion `-sizes[i] <= index && index < sizes[i] && "index out of bounds"` failed.
../aten/src/ATen/native/cuda/IndexKernel.cu:93: operator(): block: [0,0,0], thread: [7,0,0] Assertion `-sizes[i] <= index && index < sizes[i] && "index out of bounds"` failed.
../aten/src/ATen/native/cuda/IndexKernel.cu:93: operator(): block: [0,0,0], thread: [8,0,0] Assertion `-sizes[i] <= index && index < sizes[i] && "index out of bounds"` failed.
../aten/src/ATen/native/cuda/IndexKernel.cu:93: operator(): block: [0,0,0], thread: [9,0,0] Assertion `-sizes[i] <= index && index < sizes[i] && "index out of bounds"` failed.
../aten/src/ATen/native/cuda/IndexKernel.cu:93: operator(): block: [0,0,0], thread: [10,0,0] Assertion `-sizes[i] <= index && index < sizes[i] && "index out of bounds"` failed.
../aten/src/ATen/native/cuda/IndexKernel.cu:93: operator(): block: [0,0,0], thread: [11,0,0] Assertion `-sizes[i] <= index && index < sizes[i] && "index out of bounds"` failed.

If you add return_dict=True to the reproduction code I sent earlier, it also hits this error once training reaches the validation step. Can you suggest what else I might have missed during training?
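
For reference, training itself proceeds if mid-run evaluation is disabled. A temporary sketch, assuming validation can be deferred until the MoE eval path is fixed:

# Sketch: skip evaluation so the failing moe_infer path is never reached.
training_config = SFTConfig(
    output_dir=cfg.new_model,
    per_device_train_batch_size=1,
    gradient_accumulation_steps=8,
    max_steps=100,
    eval_strategy="no",  # no mid-run validation
    logging_steps=20,
    learning_rate=2e-4,
    max_seq_length=512,
    packing=False,
)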
