How to perform forward pass when loading the model with bfloat16?

#12
by etoml - opened

Getting this error when I load the model with bfloat16 and try to pass an image through it: expected mat1 and mat2 to have the same dtype, but got: float != c10::BFloat16
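For reference, a setup along these lines can reproduce the error; this is only a sketch, and the model id, processor call, and image path are placeholders, not taken from this thread:

import torch
from PIL import Image
from transformers import AutoModel, AutoProcessor

model_id = "org/model-id"  # placeholder
processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True)
model = AutoModel.from_pretrained(model_id, torch_dtype=torch.bfloat16, trust_remote_code=True)

inputs = processor(images=Image.open("example.jpg"), return_tensors="pt")
outputs = model(**inputs)  # image tensors are still float32 -> dtype mismatch in the matmul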

The error occurs because of a data-type mismatch in a matrix multiplication: I am guessing the image inputs are still float32 while the model weights have been cast to bfloat16.
Try this:

inputs["images"] = inputs["images"].to(torch.bfloat16)
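As a more general pattern (a sketch, assuming inputs is the dict returned by the processor as in the snippet above), you can cast every floating-point tensor and leave integer tensors such as token ids untouched:

inputs = {k: v.to(torch.bfloat16) if torch.is_tensor(v) and v.is_floating_point() else v
          for k, v in inputs.items()}

with torch.no_grad():
    outputs = model(**inputs)  # forward pass now runs entirely in bfloat16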
