Fine-tuning
Good afternoon!
I am trying to fine-tune the model on custom dialogues and have run into a problem when launching training.
I load the model with quantization as follows:
@dataclass
class Config:
    model_name = "../preload/GigaChat-20B-A3B-instruct-v1.5/"
    new_model = "./training/new_gigachat_20B/"
    torch_dtype = torch.bfloat16

cfg = Config()

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=cfg.torch_dtype,
    bnb_4bit_use_double_quant=True,
)
model = AutoModelForCausalLM.from_pretrained(
    cfg.model_name,
    quantization_config=bnb_config,
    device_map="auto",
    trust_remote_code=True
)
model.generation_config = GenerationConfig.from_pretrained(cfg.model_name)

tokenizer = AutoTokenizer.from_pretrained(cfg.model_name, trust_remote_code=True)
tokenizer.pad_token_id = tokenizer.eos_token_id
model.config.pad_token_id = model.config.eos_token_id
The following LoRA config is used for fine-tuning:
peft_config = LoraConfig(
    r=8,
    lora_alpha=16,
    lora_dropout=0,
    bias="none",
    task_type="CAUSAL_LM",
    target_modules=["gate_proj", "up_proj"]
)
During training I use the following data collator; the dialogues are pre-encoded with tokenizer.apply_chat_template(dialogue, tokenize=False, add_generation_prompt=False):
response_template = "assistant<|role_sep|>"
collator = DataCollatorForCompletionOnlyLM(
    response_template=tokenizer.encode(
        response_template,
        add_special_tokens=False
    ),
    tokenizer=tokenizer,
    mlm=False
)
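As a quick sanity check one can run here (my own sketch, not part of the original setup; `dialogue` stands for any of the training conversations), it is worth verifying that the encoded response template actually occurs inside an encoded example, since otherwise the collator warns and masks the whole example:

# Sketch: verify the response-template token ids occur verbatim in an encoded
# dialogue. `dialogue` is assumed to be a list of {"role", "content"} messages.
sample_text = tokenizer.apply_chat_template(dialogue, tokenize=False, add_generation_prompt=False)
sample_ids = tokenizer(sample_text, add_special_tokens=False)["input_ids"]
template_ids = tokenizer.encode("assistant<|role_sep|>", add_special_tokens=False)

def contains_subsequence(haystack, needle):
    # True if `needle` appears as a contiguous sub-list of `haystack`
    return any(
        haystack[i:i + len(needle)] == needle
        for i in range(len(haystack) - len(needle) + 1)
    )

assert contains_subsequence(sample_ids, template_ids), "response template ids not found in encoded dialogue"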
I train with SFTTrainer from trl using the following parameters:
training_config = SFTConfig(
    output_dir=cfg.new_model,
    overwrite_output_dir=True,
    per_device_train_batch_size=1,
    per_device_eval_batch_size=1,
    gradient_accumulation_steps=8,
    optim="paged_adamw_32bit",
    num_train_epochs=100,
    max_steps=100,
    eval_strategy="steps",
    logging_steps=20,
    warmup_steps=50,
    logging_strategy="steps",
    learning_rate=2e-4,
    fp16=False,
    bf16=False,
    group_by_length=True,
    gradient_checkpointing=False,
    label_names=["labels"],
    packing=False,
    max_seq_length=512,
    dataset_num_proc=2,
    dataset_text_field="text"
)
trainer = SFTTrainer(
    model=model,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    peft_config=peft_config,
    data_collator=collator,
    processing_class=tokenizer,
    args=training_config
)
And when I launch training with trainer.train(), I get the following error:
---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
Cell In[40], line 1
----> 1 trainer.train()
File ~/venv/lib/python3.10/site-packages/transformers/trainer.py:2241, in Trainer.train(self, resume_from_checkpoint, trial, ignore_keys_for_eval, **kwargs)
2239 hf_hub_utils.enable_progress_bars()
2240 else:
-> 2241 return inner_training_loop(
2242 args=args,
2243 resume_from_checkpoint=resume_from_checkpoint,
2244 trial=trial,
2245 ignore_keys_for_eval=ignore_keys_for_eval,
2246 )
File ~/venv/lib/python3.10/site-packages/transformers/trainer.py:2548, in Trainer._inner_training_loop(self, batch_size, args, resume_from_checkpoint, trial, ignore_keys_for_eval)
2541 context = (
2542 functools.partial(self.accelerator.no_sync, model=model)
2543 if i != len(batch_samples) - 1
2544 and self.accelerator.distributed_type != DistributedType.DEEPSPEED
2545 else contextlib.nullcontext
2546 )
2547 with context():
-> 2548 tr_loss_step = self.training_step(model, inputs, num_items_in_batch)
2550 if (
2551 args.logging_nan_inf_filter
2552 and not is_torch_xla_available()
2553 and (torch.isnan(tr_loss_step) or torch.isinf(tr_loss_step))
2554 ):
2555 # if loss is nan or inf simply add the average of previous logged losses
2556 tr_loss = tr_loss + tr_loss / (1 + self.state.global_step - self._globalstep_last_logged)
File ~/venv/lib/python3.10/site-packages/transformers/trainer.py:3698, in Trainer.training_step(self, model, inputs, num_items_in_batch)
3695 return loss_mb.reduce_mean().detach().to(self.args.device)
3697 with self.compute_loss_context_manager():
-> 3698 loss = self.compute_loss(model, inputs, num_items_in_batch=num_items_in_batch)
3700 del inputs
3701 if (
3702 self.args.torch_empty_cache_steps is not None
3703 and self.state.global_step % self.args.torch_empty_cache_steps == 0
3704 ):
File ~/venv/lib/python3.10/site-packages/trl/trainer/sft_trainer.py:438, in SFTTrainer.compute_loss(self, model, inputs, return_outputs, num_items_in_batch)
434 def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
435 """
436 Compute training loss and additionally compute token accuracies
437 """
--> 438 (loss, outputs) = super().compute_loss(
439 model, inputs, return_outputs=True, num_items_in_batch=num_items_in_batch
440 )
442 # Compute token accuracy if we have labels and if the model is not using Liger (no logits)
443 if "labels" in inputs and not self.args.use_liger:
File ~/venv/lib/python3.10/site-packages/transformers/trainer.py:3759, in Trainer.compute_loss(self, model, inputs, return_outputs, num_items_in_batch)
3757 loss_kwargs["num_items_in_batch"] = num_items_in_batch
3758 inputs = {**inputs, **loss_kwargs}
-> 3759 outputs = model(**inputs)
3760 # Save past state if it exists
3761 # TODO: this needs to be fixed and made cleaner later.
3762 if self.args.past_index >= 0:
File ~/venv/lib/python3.10/site-packages/torch/nn/modules/module.py:1736, in Module._wrapped_call_impl(self, *args, **kwargs)
1734 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1735 else:
-> 1736 return self._call_impl(*args, **kwargs)
File ~/venv/lib/python3.10/site-packages/torch/nn/modules/module.py:1747, in Module._call_impl(self, *args, **kwargs)
1742 # If we don't have any hooks, we want to skip the rest of the logic in
1743 # this function, and just call forward.
1744 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1745 or _global_backward_pre_hooks or _global_backward_hooks
1746 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1747 return forward_call(*args, **kwargs)
1749 result = None
1750 called_always_called_hooks = set()
File ~/venv/lib/python3.10/site-packages/peft/peft_model.py:1719, in PeftModelForCausalLM.forward(self, input_ids, attention_mask, inputs_embeds, labels, output_attentions, output_hidden_states, return_dict, task_ids, **kwargs)
1717 with self._enable_peft_forward_hooks(**kwargs):
1718 kwargs = {k: v for k, v in kwargs.items() if k not in self.special_peft_forward_args}
-> 1719 return self.base_model(
1720 input_ids=input_ids,
1721 attention_mask=attention_mask,
1722 inputs_embeds=inputs_embeds,
1723 labels=labels,
1724 output_attentions=output_attentions,
1725 output_hidden_states=output_hidden_states,
1726 return_dict=return_dict,
1727 **kwargs,
1728 )
1730 batch_size = _get_batch_size(input_ids, inputs_embeds)
1731 if attention_mask is not None:
1732 # concat prompt attention mask
File ~/venv/lib/python3.10/site-packages/torch/nn/modules/module.py:1736, in Module._wrapped_call_impl(self, *args, **kwargs)
1734 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1735 else:
-> 1736 return self._call_impl(*args, **kwargs)
File ~/venv/lib/python3.10/site-packages/torch/nn/modules/module.py:1747, in Module._call_impl(self, *args, **kwargs)
1742 # If we don't have any hooks, we want to skip the rest of the logic in
1743 # this function, and just call forward.
1744 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1745 or _global_backward_pre_hooks or _global_backward_hooks
1746 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1747 return forward_call(*args, **kwargs)
1749 result = None
1750 called_always_called_hooks = set()
File ~/venv/lib/python3.10/site-packages/peft/tuners/tuners_utils.py:197, in BaseTuner.forward(self, *args, **kwargs)
196 def forward(self, *args: Any, **kwargs: Any):
--> 197 return self.model.forward(*args, **kwargs)
File ~/venv/lib/python3.10/site-packages/accelerate/hooks.py:170, in add_hook_to_module.<locals>.new_forward(module, *args, **kwargs)
168 output = module._old_forward(*args, **kwargs)
169 else:
--> 170 output = module._old_forward(*args, **kwargs)
171 return module._hf_hook.post_forward(module, output)
File ~/.cache/huggingface/modules/transformers_modules/ai-sage/GigaChat-20B-A3B-base/a32fc228a93b8c22084d5e227a44afffb674bca5/modelling_deepseek.py:1279, in DeepseekForCausalLM.forward(self, input_ids, attention_mask, position_ids, past_key_values, inputs_embeds, labels, use_cache, output_attentions, output_hidden_states, return_dict, cache_position, num_logits_to_keep, **kwargs)
1276 return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1278 # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
-> 1279 outputs = self.model(
1280 input_ids=input_ids,
1281 attention_mask=attention_mask,
1282 position_ids=position_ids,
1283 past_key_values=past_key_values,
1284 inputs_embeds=inputs_embeds,
1285 use_cache=use_cache,
1286 output_attentions=output_attentions,
1287 output_hidden_states=output_hidden_states,
1288 return_dict=return_dict,
1289 )
1291 hidden_states = outputs[0]
1292 # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
File ~/venv/lib/python3.10/site-packages/torch/nn/modules/module.py:1736, in Module._wrapped_call_impl(self, *args, **kwargs)
1734 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1735 else:
-> 1736 return self._call_impl(*args, **kwargs)
File ~/venv/lib/python3.10/site-packages/torch/nn/modules/module.py:1747, in Module._call_impl(self, *args, **kwargs)
1742 # If we don't have any hooks, we want to skip the rest of the logic in
1743 # this function, and just call forward.
1744 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1745 or _global_backward_pre_hooks or _global_backward_hooks
1746 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1747 return forward_call(*args, **kwargs)
1749 result = None
1750 called_always_called_hooks = set()
File ~/venv/lib/python3.10/site-packages/accelerate/hooks.py:170, in add_hook_to_module.<locals>.new_forward(module, *args, **kwargs)
168 output = module._old_forward(*args, **kwargs)
169 else:
--> 170 output = module._old_forward(*args, **kwargs)
171 return module._hf_hook.post_forward(module, output)
File ~/.cache/huggingface/modules/transformers_modules/ai-sage/GigaChat-20B-A3B-base/a32fc228a93b8c22084d5e227a44afffb674bca5/modelling_deepseek.py:1030, in DeepseekModel.forward(self, input_ids, attention_mask, position_ids, past_key_values, inputs_embeds, use_cache, output_attentions, output_hidden_states, return_dict, cache_position, **flash_attn_kwargs)
1018 layer_outputs = self._gradient_checkpointing_func(
1019 decoder_layer.__call__,
1020 hidden_states,
(...)
1027 position_embeddings,
1028 )
1029 else:
-> 1030 layer_outputs = decoder_layer(
1031 hidden_states,
1032 attention_mask=causal_mask,
1033 position_ids=position_ids,
1034 past_key_value=past_key_values,
1035 output_attentions=output_attentions,
1036 use_cache=use_cache,
1037 cache_position=cache_position,
1038 position_embeddings=position_embeddings,
1039 **flash_attn_kwargs,
1040 )
1042 hidden_states = layer_outputs[0]
1044 if use_cache:
File ~/venv/lib/python3.10/site-packages/torch/nn/modules/module.py:1736, in Module._wrapped_call_impl(self, *args, **kwargs)
1734 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1735 else:
-> 1736 return self._call_impl(*args, **kwargs)
File ~/venv/lib/python3.10/site-packages/torch/nn/modules/module.py:1747, in Module._call_impl(self, *args, **kwargs)
1742 # If we don't have any hooks, we want to skip the rest of the logic in
1743 # this function, and just call forward.
1744 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1745 or _global_backward_pre_hooks or _global_backward_hooks
1746 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1747 return forward_call(*args, **kwargs)
1749 result = None
1750 called_always_called_hooks = set()
File ~/venv/lib/python3.10/site-packages/accelerate/hooks.py:170, in add_hook_to_module.<locals>.new_forward(module, *args, **kwargs)
168 output = module._old_forward(*args, **kwargs)
169 else:
--> 170 output = module._old_forward(*args, **kwargs)
171 return module._hf_hook.post_forward(module, output)
File ~/.cache/huggingface/modules/transformers_modules/ai-sage/GigaChat-20B-A3B-base/a32fc228a93b8c22084d5e227a44afffb674bca5/modelling_deepseek.py:773, in DeepseekDecoderLayer.forward(self, hidden_states, attention_mask, position_ids, past_key_value, output_attentions, use_cache, cache_position, position_embeddings, **kwargs)
771 residual = hidden_states
772 hidden_states = self.post_attention_layernorm(hidden_states)
--> 773 hidden_states = self.mlp(hidden_states)
774 hidden_states = residual + hidden_states
776 outputs = (hidden_states,)
File ~/venv/lib/python3.10/site-packages/torch/nn/modules/module.py:1736, in Module._wrapped_call_impl(self, *args, **kwargs)
1734 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1735 else:
-> 1736 return self._call_impl(*args, **kwargs)
File ~/venv/lib/python3.10/site-packages/torch/nn/modules/module.py:1747, in Module._call_impl(self, *args, **kwargs)
1742 # If we don't have any hooks, we want to skip the rest of the logic in
1743 # this function, and just call forward.
1744 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1745 or _global_backward_pre_hooks or _global_backward_hooks
1746 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1747 return forward_call(*args, **kwargs)
1749 result = None
1750 called_always_called_hooks = set()
File ~/venv/lib/python3.10/site-packages/accelerate/hooks.py:170, in add_hook_to_module.<locals>.new_forward(module, *args, **kwargs)
168 output = module._old_forward(*args, **kwargs)
169 else:
--> 170 output = module._old_forward(*args, **kwargs)
171 return module._hf_hook.post_forward(module, output)
File ~/.cache/huggingface/modules/transformers_modules/ai-sage/GigaChat-20B-A3B-base/a32fc228a93b8c22084d5e227a44afffb674bca5/modelling_deepseek.py:318, in DeepseekMoE.forward(self, hidden_states)
316 if self.training:
317 y = self.moe_train(hidden_states, flat_topk_idx, topk_weight.view(-1, 1))
--> 318 y = y.view(*orig_shape)
319 y = AddAuxiliaryLoss.apply(y, aux_loss)
320 else:
RuntimeError: shape '[1, 512, 2048]' is invalid for input of size 6291456
Library versions:
torch==2.5.1+cu121
transformers==4.48.3
trl==0.15.0
peft==0.14.0
I am also attaching a minimal example to reproduce the error:
import os
os.environ["WANDB_DISABLED"] = "true"
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

import torch
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    GenerationConfig,
)
from peft import LoraConfig
from datasets import Dataset
from trl import SFTTrainer, SFTConfig
from trl import DataCollatorForCompletionOnlyLM
from dataclasses import dataclass


@dataclass
class Config:
    model_name = "ai-sage/GigaChat-20B-A3B-instruct-v1.5"
    new_model = "./new_gigachat_20B/"
    torch_dtype = torch.bfloat16


def get_example_data():
    return [
        [
            {"role": "system", "content": "Работа системы 1"},
            {"role": "user", "content": "Привет!"},
            {"role": "assistant", "content": "Привет!"}
        ],
        [
            {"role": "system", "content": "Работа системы 2"},
            {"role": "user", "content": "Привет! Как дела?"},
            {"role": "assistant", "content": "Привет! Хорошо! У тебя как?"}
        ]
    ]


if __name__ == "__main__":
    cfg = Config()

    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=cfg.torch_dtype,
        bnb_4bit_use_double_quant=True,
    )
    model = AutoModelForCausalLM.from_pretrained(
        cfg.model_name,
        quantization_config=bnb_config,
        device_map="auto",
        trust_remote_code=True
    )
    model.generation_config = GenerationConfig.from_pretrained(cfg.model_name)

    tokenizer = AutoTokenizer.from_pretrained(cfg.model_name, trust_remote_code=True)
    tokenizer.pad_token_id = tokenizer.eos_token_id
    model.config.pad_token_id = model.config.eos_token_id

    peft_config = LoraConfig(
        r=8,
        lora_alpha=16,
        lora_dropout=0,
        bias="none",
        task_type="CAUSAL_LM",
        target_modules=["gate_proj", "up_proj"]
    )

    train_dataset_texts = []
    for d in get_example_data():
        train_dataset_texts.append(
            {
                "text": tokenizer.apply_chat_template(d, tokenize=False, add_generation_prompt=False)
            }
        )
    val_dataset_texts = train_dataset_texts

    train_dataset = Dataset.from_list(train_dataset_texts)
    eval_dataset = Dataset.from_list(val_dataset_texts)

    response_template = "assistant<|role_sep|>"
    collator = DataCollatorForCompletionOnlyLM(
        response_template=tokenizer.encode(
            response_template,
            add_special_tokens=False
        ),
        tokenizer=tokenizer,
        mlm=False
    )

    training_config = SFTConfig(
        output_dir=cfg.new_model,
        overwrite_output_dir=True,
        per_device_train_batch_size=1,
        per_device_eval_batch_size=1,
        gradient_accumulation_steps=8,
        optim="paged_adamw_32bit",
        num_train_epochs=100,
        max_steps=100,
        eval_strategy="steps",
        logging_steps=20,
        warmup_steps=50,
        logging_strategy="steps",
        learning_rate=2e-4,
        fp16=False,
        bf16=False,
        group_by_length=True,
        gradient_checkpointing=False,
        label_names=["labels"],
        packing=False,
        max_seq_length=512,
        dataset_num_proc=2
    )

    trainer = SFTTrainer(
        model=model,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        peft_config=peft_config,
        data_collator=collator,
        processing_class=tokenizer,
        args=training_config
    )

    trainer.train()
I also noticed that the same error can be reproduced at inference time if the generation config is not loaded beforehand with model.generation_config = GenerationConfig.from_pretrained(cfg.model_name).
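For reference, a minimal inference sketch of what I mean (assuming the model's chat template and the standard generate() API, with model, tokenizer and cfg defined as above):

# Sketch: load the generation config, then run a short generation.
# Skipping the first line is what reproduced the same shape error for me.
model.generation_config = GenerationConfig.from_pretrained(cfg.model_name)
prompt = tokenizer.apply_chat_template(
    [{"role": "user", "content": "Привет!"}],
    tokenize=False,
    add_generation_prompt=True,
)
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
with torch.no_grad():
    output_ids = model.generate(**inputs, max_new_tokens=32)
print(tokenizer.decode(output_ids[0, inputs["input_ids"].shape[-1]:], skip_special_tokens=True))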
Could you advise how to fine-tune the model correctly?
I based this on working fine-tuning code for LLaMA models, but I understand this is a MoE model; maybe there are tutorials on fine-tuning GigaChat?
Thanks in advance!
Good afternoon! The problem has been fixed and the code updated (more precisely, the modelling file).
Please try pulling the model again and restarting the fine-tuning.
Thank you!
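For reference, one way to re-fetch the updated files is to force a fresh snapshot download (a sketch, not an official instruction; force_download re-downloads every file in the repository, so removing only the cached module under ~/.cache/huggingface/modules/transformers_modules is a lighter alternative):

# Sketch: refresh the locally cached repository so the updated
# modelling_deepseek.py is picked up before reloading the model.
from huggingface_hub import snapshot_download

snapshot_download("ai-sage/GigaChat-20B-A3B-instruct-v1.5", force_download=True)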
Good afternoon, if the parameter return_dict=True is also added when creating the model, training starts. Updated model creation code:
model = AutoModelForCausalLM.from_pretrained(
    cfg.model_name,
    quantization_config=bnb_config,
    device_map="auto",
    trust_remote_code=True,
    return_dict=True
)
Training now starts, but the following error occurs during validation:
Traceback (most recent call last):
File "/home/alexey/programming/work/CarLLM_Project/experiments/min_example.py", line 137, in <module>
trainer.train()
File "/home/alexey/venv/lib/python3.10/site-packages/transformers/trainer.py", line 2241, in train
return inner_training_loop(
File "/home/alexey/venv/lib/python3.10/site-packages/transformers/trainer.py", line 2612, in _inner_training_loop
self._maybe_log_save_evaluate(
File "/home/alexey/venv/lib/python3.10/site-packages/transformers/trainer.py", line 3085, in _maybe_log_save_evaluate
metrics = self._evaluate(trial, ignore_keys_for_eval)
File "/home/alexey/venv/lib/python3.10/site-packages/transformers/trainer.py", line 3039, in _evaluate
metrics = self.evaluate(ignore_keys=ignore_keys_for_eval)
File "/home/alexey/venv/lib/python3.10/site-packages/transformers/trainer.py", line 4105, in evaluate
output = eval_loop(
File "/home/alexey/venv/lib/python3.10/site-packages/transformers/trainer.py", line 4299, in evaluation_loop
losses, logits, labels = self.prediction_step(model, inputs, prediction_loss_only, ignore_keys=ignore_keys)
File "/home/alexey/venv/lib/python3.10/site-packages/transformers/trainer.py", line 4515, in prediction_step
loss, outputs = self.compute_loss(model, inputs, return_outputs=True)
File "/home/alexey/venv/lib/python3.10/site-packages/trl/trainer/sft_trainer.py", line 464, in compute_loss
(loss, outputs) = super().compute_loss(
File "/home/alexey/venv/lib/python3.10/site-packages/transformers/trainer.py", line 3759, in compute_loss
outputs = model(**inputs)
File "/home/alexey/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/home/alexey/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
return forward_call(*args, **kwargs)
File "/home/alexey/venv/lib/python3.10/site-packages/peft/peft_model.py", line 1719, in forward
return self.base_model(
File "/home/alexey/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/home/alexey/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
return forward_call(*args, **kwargs)
File "/home/alexey/venv/lib/python3.10/site-packages/peft/tuners/tuners_utils.py", line 197, in forward
return self.model.forward(*args, **kwargs)
File "/home/alexey/venv/lib/python3.10/site-packages/accelerate/hooks.py", line 170, in new_forward
output = module._old_forward(*args, **kwargs)
File "/home/alexey/.cache/huggingface/modules/transformers_modules/ai-sage/GigaChat-20B-A3B-base/7c4e5992e33198828505e86fbdfb27c5f4c24cd9/modelling_deepseek.py", line 1302, in forward
outputs = self.model(
File "/home/alexey/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/home/alexey/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
return forward_call(*args, **kwargs)
File "/home/alexey/venv/lib/python3.10/site-packages/accelerate/hooks.py", line 170, in new_forward
output = module._old_forward(*args, **kwargs)
File "/home/alexey/.cache/huggingface/modules/transformers_modules/ai-sage/GigaChat-20B-A3B-base/7c4e5992e33198828505e86fbdfb27c5f4c24cd9/modelling_deepseek.py", line 1053, in forward
layer_outputs = decoder_layer(
File "/home/alexey/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/home/alexey/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
return forward_call(*args, **kwargs)
File "/home/alexey/venv/lib/python3.10/site-packages/accelerate/hooks.py", line 170, in new_forward
output = module._old_forward(*args, **kwargs)
File "/home/alexey/.cache/huggingface/modules/transformers_modules/ai-sage/GigaChat-20B-A3B-base/7c4e5992e33198828505e86fbdfb27c5f4c24cd9/modelling_deepseek.py", line 796, in forward
hidden_states = self.mlp(hidden_states)
File "/home/alexey/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/home/alexey/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
return forward_call(*args, **kwargs)
File "/home/alexey/venv/lib/python3.10/site-packages/accelerate/hooks.py", line 170, in new_forward
output = module._old_forward(*args, **kwargs)
File "/home/alexey/.cache/huggingface/modules/transformers_modules/ai-sage/GigaChat-20B-A3B-base/7c4e5992e33198828505e86fbdfb27c5f4c24cd9/modelling_deepseek.py", line 344, in forward
y = self.moe_infer(hidden_states, flat_topk_idx, topk_weight).view(*orig_shape) # removed unnecessary .view(-1, 1)
File "/home/alexey/venv/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
return func(*args, **kwargs)
File "/home/alexey/.cache/huggingface/modules/transformers_modules/ai-sage/GigaChat-20B-A3B-base/7c4e5992e33198828505e86fbdfb27c5f4c24cd9/modelling_deepseek.py", line 371, in moe_infer
expert_out.mul_(flat_expert_weights[idxs[start_idx:end_idx]])
RuntimeError: The size of tensor a (2048) must match the size of tensor b (6) at non-singleton dimension 1
../aten/src/ATen/native/cuda/IndexKernel.cu:93: operator(): block: [0,0,0], thread: [0,0,0] Assertion `-sizes[i] <= index && index < sizes[i] && "index out of bounds"` failed.
../aten/src/ATen/native/cuda/IndexKernel.cu:93: operator(): block: [0,0,0], thread: [1,0,0] Assertion `-sizes[i] <= index && index < sizes[i] && "index out of bounds"` failed.
../aten/src/ATen/native/cuda/IndexKernel.cu:93: operator(): block: [0,0,0], thread: [2,0,0] Assertion `-sizes[i] <= index && index < sizes[i] && "index out of bounds"` failed.
../aten/src/ATen/native/cuda/IndexKernel.cu:93: operator(): block: [0,0,0], thread: [3,0,0] Assertion `-sizes[i] <= index && index < sizes[i] && "index out of bounds"` failed.
../aten/src/ATen/native/cuda/IndexKernel.cu:93: operator(): block: [0,0,0], thread: [4,0,0] Assertion `-sizes[i] <= index && index < sizes[i] && "index out of bounds"` failed.
../aten/src/ATen/native/cuda/IndexKernel.cu:93: operator(): block: [0,0,0], thread: [5,0,0] Assertion `-sizes[i] <= index && index < sizes[i] && "index out of bounds"` failed.
../aten/src/ATen/native/cuda/IndexKernel.cu:93: operator(): block: [0,0,0], thread: [6,0,0] Assertion `-sizes[i] <= index && index < sizes[i] && "index out of bounds"` failed.
../aten/src/ATen/native/cuda/IndexKernel.cu:93: operator(): block: [0,0,0], thread: [7,0,0] Assertion `-sizes[i] <= index && index < sizes[i] && "index out of bounds"` failed.
../aten/src/ATen/native/cuda/IndexKernel.cu:93: operator(): block: [0,0,0], thread: [8,0,0] Assertion `-sizes[i] <= index && index < sizes[i] && "index out of bounds"` failed.
../aten/src/ATen/native/cuda/IndexKernel.cu:93: operator(): block: [0,0,0], thread: [9,0,0] Assertion `-sizes[i] <= index && index < sizes[i] && "index out of bounds"` failed.
../aten/src/ATen/native/cuda/IndexKernel.cu:93: operator(): block: [0,0,0], thread: [10,0,0] Assertion `-sizes[i] <= index && index < sizes[i] && "index out of bounds"` failed.
../aten/src/ATen/native/cuda/IndexKernel.cu:93: operator(): block: [0,0,0], thread: [11,0,0] Assertion `-sizes[i] <= index && index < sizes[i] && "index out of bounds"` failed.
If return_dict=True is added to the reproduction code I sent earlier, it also hits this error once training reaches the validation step. Could you advise what else I might have missed in the training setup?