MTP Integration: Unexpectedly High Loss with Loaded Weights
#105 by parambole
Hey,
I am currently working on integrating MTP (multi-token prediction) into MaxText. To verify the implementation, I loaded the open MTP weights and ran a few training steps on the C4 dataset. With DeepSeek V3, the main model's loss is low, around 2.5, but the MTP module's loss is over 12. Since the MTP weights are already trained, I expected the MTP loss to be close to the main model's loss, or at least well below 12.
I am not sure whether my expectation is right, i.e., that the MTP module with the loaded weights should already be optimized and give a low loss on C4. Is the high loss expected behavior, or is something wrong with my code? Any help would be appreciated.
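
For context, here is a rough sanity check I am using to frame the numbers. It is only a sketch under stated assumptions, not the MaxText implementation. It assumes a vocabulary size of 129,280 for DeepSeek V3 (please check against the model config), that the loss is cross-entropy in nats, and that the MTP module at depth k is trained to predict the token k positions past the main model's next-token target. The helper names `uniform_baseline_loss` and `mtp_targets` are hypothetical and exist only for this illustration.

```python
import math


def uniform_baseline_loss(vocab_size: int) -> float:
    """Cross-entropy (in nats) of a model that gives every token equal probability."""
    return math.log(vocab_size)


def mtp_targets(tokens, depth, ignore_id=-1):
    """Build targets for an MTP module at the given depth (1 = first MTP module).

    Assumption: position i of the main model predicts tokens[i + 1], so the
    MTP module at `depth` predicts tokens[i + 1 + depth]. Trailing positions
    that have no target are filled with ignore_id so they can be masked out.
    """
    shift = 1 + depth
    return list(tokens[shift:]) + [ignore_id] * min(shift, len(tokens))


VOCAB_SIZE = 129_280  # assumed DeepSeek V3 vocabulary size; verify in the config
baseline = uniform_baseline_loss(VOCAB_SIZE)
print(f"uniform baseline: {baseline:.2f} nats")

# Toy sequence: the first MTP module should be scored against tokens shifted by 2.
toy_targets = mtp_targets([10, 11, 12, 13, 14], depth=1)
print(toy_targets)
```

If those assumptions hold, an MTP loss above 12 is no better than guessing uniformly over the vocabulary. That makes me suspect the weights are not actually being used (for example a loading, mapping, or target-alignment problem) rather than the module simply being a bit worse than the main model, but I may be missing something.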