MTP Integration: Unexpectedly High Loss with Loaded Weights

#105
by parambole - opened

Hey,

I am currently working on integrating MTP (multi-token prediction) into MaxText. To verify the implementation, I loaded the publicly released MTP weights and ran a few training steps on the C4 dataset. On DeepSeek-V3, the main model's loss is low (around 2.5), but the MTP module's loss is above 12. I expected the MTP module's loss with the trained weights to be close to the main model's, or at least well below 12.
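For context on how bad a loss of 12 is: a model that spreads probability uniformly over the vocabulary has a cross-entropy of ln(V). The sketch below computes that baseline, assuming the loss is natural-log cross-entropy and that DeepSeek-V3's vocabulary size is 129,280 (the value in its published `config.json`); `uniform_cross_entropy` is just an illustrative helper, not a MaxText function.

```python
import math

# Assumption: DeepSeek-V3 vocabulary size, as listed in its published config.json.
VOCAB_SIZE = 129_280

def uniform_cross_entropy(vocab_size: int) -> float:
    """Cross-entropy (in nats) of a model that assigns equal probability to every token."""
    return math.log(vocab_size)

baseline = uniform_cross_entropy(VOCAB_SIZE)
print(f"uniform baseline: {baseline:.2f} nats")  # prints "uniform baseline: 11.77 nats"

# Compare the two losses reported above against the uniform baseline.
for observed in (2.5, 12.0):
    status = "worse than" if observed > baseline else "better than"
    print(f"loss {observed}: {status} uniform guessing")
```

Under these assumptions, an MTP loss above 12 is slightly worse than guessing uniformly, which suggests the module's predictions are not lining up with its targets at all rather than being merely under-trained.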

I'm not sure whether my expectation is correct (that, with the loaded weights, the MTP module should already be well optimized and reach a low loss on C4), whether this is expected behavior, or whether something is wrong in my code. Any help would be appreciated.
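One thing worth ruling out is a label-offset mismatch. Per the DeepSeek-V3 technical report, the main head at position i predicts token i+1, while the depth-k MTP module predicts token i+1+k, so the depth-1 module's labels are shifted by two. Below is a minimal sketch in plain Python (not MaxText's JAX code); `mtp_targets`, `mean_nll`, and the toy log-probabilities are hypothetical illustrations, not MaxText APIs.

```python
import math
from typing import List, Sequence, Tuple

def mtp_targets(tokens: Sequence[int], depth: int) -> List[Tuple[int, int]]:
    """Return (position, target_token) pairs for a prediction head at `depth`.

    depth=0 is the main next-token head (position i predicts tokens[i + 1]);
    depth=k >= 1 is the k-th MTP module, which (per the DeepSeek-V3 report)
    predicts tokens[i + 1 + k]. Positions without a valid target are dropped.
    """
    offset = 1 + depth
    return [(i, tokens[i + offset]) for i in range(len(tokens) - offset)]

def mean_nll(log_probs: Sequence[Sequence[float]], pairs: List[Tuple[int, int]]) -> float:
    """Mean negative log-likelihood of each target token at its position."""
    return -sum(log_probs[i][t] for i, t in pairs) / len(pairs)

# Toy setup: a vocabulary of 5 tokens and the sequence 0, 1, 2, 3, 4.
VOCAB = 5
tokens = [0, 1, 2, 3, 4]

def peaked_log_probs(target: int) -> List[float]:
    """Log-probabilities with 0.9 on `target` and the rest spread uniformly."""
    rest = 0.1 / (VOCAB - 1)
    return [math.log(0.9) if v == target else math.log(rest) for v in range(VOCAB)]

# Hypothetical well-trained depth-1 MTP head: at position i it favors tokens[i + 2]
# (clamped to the last token at the end of the sequence).
log_probs = [peaked_log_probs(tokens[min(i + 2, len(tokens) - 1)]) for i in range(len(tokens))]

aligned = mean_nll(log_probs, mtp_targets(tokens, depth=1))     # labels shifted by 2
misaligned = mean_nll(log_probs, mtp_targets(tokens, depth=0))  # labels shifted by 1
print(f"aligned: {aligned:.3f}, misaligned: {misaligned:.3f}")
```

The same good predictions score a low loss when scored against the depth-1 targets and a much higher one when scored against next-token targets, so checking the label shift (and which embedding is fed into the MTP module) is a cheap first step.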
