Model generates the same output for different images
I am downloading the model as shown in the documentation:
```python
import os

from transformers import AutoProcessor, ShieldGemma2ForImageClassification

HF_TOKEN = os.getenv("HF_TOKEN")
MODEL_ID = "google/shieldgemma-2-4b-it"
MODEL_DIR = "model"

if __name__ == "__main__":
    # Create the local directory that will hold the saved model and processor
    os.makedirs(MODEL_DIR, exist_ok=True)

    processor = AutoProcessor.from_pretrained(MODEL_ID, token=HF_TOKEN)
    model = ShieldGemma2ForImageClassification.from_pretrained(MODEL_ID, token=HF_TOKEN)

    # Save both so they can later be loaded with local_files_only=True
    processor.save_pretrained(MODEL_DIR)
    model.save_pretrained(MODEL_DIR)
    print(f"Model and processor saved to {MODEL_DIR}")
```
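One way to back up the claim that repeated downloads are intact is to fingerprint every file the script above writes to `MODEL_DIR` and compare two copies. This is a minimal sketch; `hash_model_dir` is a hypothetical helper and `"model_redownload"` is an assumed second download directory, neither part of the original code.

```python
import hashlib
from pathlib import Path
from typing import Dict


def hash_model_dir(model_dir: str) -> Dict[str, str]:
    """Return a SHA-256 hex digest for every file under model_dir.

    Keys are paths relative to model_dir (POSIX style), so two directories
    with the same layout can be compared directly.
    """
    root = Path(model_dir)
    digests: Dict[str, str] = {}
    for path in sorted(p for p in root.rglob("*") if p.is_file()):
        h = hashlib.sha256()
        with path.open("rb") as f:
            # Read in 1 MiB chunks so multi-GB weight shards never sit fully in memory
            for chunk in iter(lambda: f.read(1 << 20), b""):
                h.update(chunk)
        digests[path.relative_to(root).as_posix()] = h.hexdigest()
    return digests
```

Usage, assuming a second copy was saved to the hypothetical `model_redownload` directory: `hash_model_dir("model") == hash_model_dir("model_redownload")`. Matching digests show the two copies are byte-for-byte identical; they do not by themselves show that the inference code is correct.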
I am then loading the model using this function:
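```python
import torch
from transformers import AutoProcessor, ShieldGemma2ForImageClassification

# MODEL_DIR, logger and _log_hardware_stats are defined elsewhere in the app.
# Module-level cache so the model is loaded only once per process.
processor = None
model = None


def load_model():
    global processor, model
    if processor is None or model is None:
        _log_hardware_stats("Before model load")
        device = "cuda" if torch.cuda.is_available() else "cpu"
        processor = AutoProcessor.from_pretrained(MODEL_DIR, local_files_only=True)
        model = ShieldGemma2ForImageClassification.from_pretrained(
            MODEL_DIR,
            torch_dtype=torch.bfloat16,
            local_files_only=True,
        ).eval()
        model = model.to(device)
        _log_hardware_stats("After model load")
        logger.info("Shield Gemma 2 model loaded successfully.")
    return processor, model
```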
Finally, I use it like this:
```python
import io

import torch
from fastapi import File, UploadFile
from PIL import Image

# `app` (the FastAPI instance), `processor`, `model` and POLICY_LABELS are defined elsewhere.


@app.post("/predict/image")
async def predict_image(file: UploadFile = File(...)):
    try:
        image_bytes = await file.read()
        image = Image.open(io.BytesIO(image_bytes))
        inputs = processor(
            images=[image], policies=POLICY_LABELS, return_tensors="pt"
        ).to(model.device)
        with torch.no_grad():
            output = model(**inputs)
        probabilities = output.probabilities.tolist()
        results = dict(zip(POLICY_LABELS, probabilities[0]))
        return {"filename": file.filename, "moderation_results": results}
    except Exception as e:
        return {"error": str(e)}
```
I am new to using a large model like this for inference. I'm not sure whether I am missing something, but any image passed to the endpoint defined above produces the same output probabilities. The probabilities change only when the policies change, meaning the output depends only on the policies. I have re-downloaded the model multiple times and know that the downloaded files are not corrupted.
Hi @tathagat10,
Welcome to the Google Gemma family of open-source models. ShieldGemma models produce their output from the input image together with the policy you provide, and judge whether the image violates that policy. The output is the probability of the 'Yes' and 'No' tokens, with a higher 'Yes' score indicating higher confidence that the image violates the specified policy. 'Yes' means the image violates the policy; 'No' means it does not.
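Because each policy gets its own 'Yes'/'No' pair, the result needs one row per policy rather than a single row. The sketch below assumes that, for a single image, `output.probabilities.tolist()` has one `[Yes, No]` row per policy in the same order as the `policies` list (as the model card example suggests); `map_policy_probabilities` is a hypothetical helper, not part of `transformers`.

```python
from typing import Dict, Sequence


def map_policy_probabilities(
    probabilities: Sequence[Sequence[float]],
    policy_labels: Sequence[str],
) -> Dict[str, Dict[str, float]]:
    """Pair each policy label with its own [Yes, No] probability row.

    Assumes a single-image request: row i belongs to policy_labels[i],
    column 0 is P('Yes') and column 1 is P('No').
    """
    if len(probabilities) != len(policy_labels):
        raise ValueError(
            f"expected one row per policy: got {len(probabilities)} rows "
            f"for {len(policy_labels)} policies"
        )
    results: Dict[str, Dict[str, float]] = {}
    for label, row in zip(policy_labels, probabilities):
        yes, no = row  # assumed column order: 'Yes' first, then 'No'
        results[label] = {"yes": float(yes), "no": float(no)}
    return results
```

Under that assumption, the endpoint would call `map_policy_probabilities(output.probabilities.tolist(), POLICY_LABELS)`. By contrast, `dict(zip(POLICY_LABELS, probabilities[0]))` pairs the policy labels with the Yes/No values of the first policy only, so the reported numbers would not correspond to the labels they are attached to.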
Thanks.
Good