How to perform a forward pass when loading the model with bfloat16?
#12 by etoml
Getting this error when I load the model with bfloat16 and try to pass an image through it: expected mat1 and mat2 to have the same dtype, but got: float != struct c10::BFloat16
The error occurs because the two tensors involved in a matrix multiplication have different data types. I am guessing the image tensor you pass in is still float32, while the model weights have been converted to bfloat16.
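For illustration, here is a minimal standalone sketch (not your model, just a plain linear layer) that reproduces the same mismatch: bfloat16 weights fed a float32 input raise the same RuntimeError, and casting the input makes it go away.

```python
import torch

layer = torch.nn.Linear(4, 4).to(torch.bfloat16)  # weights in bfloat16
x = torch.randn(1, 4)                             # input still float32

try:
    layer(x)                       # mixed dtypes -> RuntimeError
except RuntimeError as err:
    print(err)                     # "... mat1 and mat2 to have the same dtype ..."

print(layer(x.to(torch.bfloat16)))  # casting the input fixes it
```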
Try this:
inputs["images"] = inputs["images"].to(torch.bfloat16)