File size: 2,200 Bytes
927adb7 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 |
import io
import json

import ipywidgets as widgets
import numpy as np
import tensorflow as tf
from IPython.display import display
from PIL import Image
from tensorflow.keras.models import load_model
from transformers import TFAutoModelForSequenceClassification, AutoTokenizer
# Trained Keras classifier saved in HDF5 format, loaded from the local
# filesystem (path is relative to the working directory).
# NOTE(review): "teath" looks like a typo for "teeth"; kept because it must
# match the saved model file's name — confirm before renaming either.
model_path = 'final_teath_classifier.h5'
# Use the `load_model` imported from tensorflow.keras.models so this line
# does not depend on a `tf` module alias being in scope.
model = load_model(model_path)
def preprocess_image(image: Image.Image, *, size: tuple = (256, 256)) -> np.ndarray:
    """Turn a PIL image into a single-image batch for the classifier.

    Args:
        image: Image to preprocess; it is converted to 3-channel RGB first.
        size: Target ``(width, height)`` passed to ``Image.resize``;
            defaults to 256x256.

    Returns:
        float32 array of shape ``(1, height, width, 3)`` with pixel values
        scaled to [0, 1].
    """
    # Without conversion, palette ("P") images yield palette indices instead
    # of colours, and "L"/"RGBA" images yield the wrong number of channels.
    # Convert before resizing so resampling works on real colours.
    # NOTE(review): assumes the model was trained on RGB input — confirm.
    rgb = image.convert("RGB")
    resized = rgb.resize(size)
    # Scale 8-bit pixel values to [0, 1]; float32 is Keras' default dtype.
    pixels = np.asarray(resized, dtype=np.float32) / 255.0
    # Prepend the batch axis the model call expects.
    return pixels[np.newaxis, ...]
def predict_image(image_path):
    """Classify one image with the loaded model and map it to a label.

    Args:
        image_path: Path of the image to classify. A binary file object also
            works, since it is passed straight to ``Image.open``.

    Returns:
        ``(label, probabilities)``. ``label`` is "Clean" when class 0 has the
        highest score and "Carries" otherwise. ``probabilities`` is a 1-D
        array holding the softmax of the model's output scores.
    """
    # Image.open is lazy; the context manager closes the file handle once
    # preprocessing has read the pixels.
    with Image.open(image_path) as img:
        img_array = preprocess_image(img)

    # This is an image model: the pixel batch goes straight in. (No text
    # tokenizer applies here, and none is defined in this file.)
    outputs = model(img_array.astype(np.float32), training=False)
    # Hugging Face TF models wrap their scores in `.logits`; a plain Keras
    # model returns the score tensor itself.
    scores = np.asarray(getattr(outputs, "logits", outputs), dtype=np.float64)

    # Numerically stable softmax over the class axis.
    # NOTE(review): if the model's last layer already applies softmax, this
    # flattens the reported percentages (argmax is unaffected) — confirm the
    # model outputs raw scores.
    exps = np.exp(scores - scores.max(axis=-1, keepdims=True))
    probabilities = exps / exps.sum(axis=-1, keepdims=True)

    # NOTE(review): with a single-unit (sigmoid) output, argmax is always 0
    # and every image would be labelled "Clean" — confirm the model has one
    # output unit per class.
    predicted_class = int(np.argmax(probabilities))
    # NOTE(review): "Carries" is probably meant to be "Caries"; kept as-is
    # because the label is emitted to consumers of the printed JSON.
    predict_label = "Clean" if predicted_class == 0 else "Carries"
    return predict_label, probabilities.flatten()
# Single-file picker; the browser's file dialog offers image MIME types only.
uploader = widgets.FileUpload(multiple=False, accept="image/*")
# Render the picker in the notebook output area.
display(uploader)
def on_upload(change):
    """Classify the image just picked in the uploader and print the result.

    Intended as an ``observe`` callback on the uploader's ``value`` trait, so
    ``change["new"]`` holds the new upload state. Two layouts are accepted:
    a dict mapping filename -> {"content": bytes, ...} (ipywidgets 7), or a
    sequence of dicts with a "content" entry (ipywidgets 8).

    Prints a JSON object with the predicted label and each class
    probability formatted as a percentage. Does nothing if the value is
    empty.
    """
    uploads = change["new"]
    # An empty value (e.g. after the widget is reset) has nothing to classify.
    if not uploads:
        return
    entries = list(uploads.values()) if isinstance(uploads, dict) else list(uploads)
    # ipywidgets 8 hands over a memoryview; normalise to bytes.
    image_bytes = bytes(entries[0]["content"])

    # PIL's Image.open accepts file objects, so decode straight from memory
    # instead of writing a shared temp file into the working directory.
    predict_label, probabilities = predict_image(io.BytesIO(image_bytes))

    predictions_json = {
        "predicted_class": predict_label,
        "evaluations": [f"{prob * 100:.4f}%" for prob in probabilities],
    }
    print(json.dumps(predictions_json, indent=4))
# Invoke on_upload whenever the uploader's `value` trait changes, i.e. when
# a file is picked.
uploader.observe(on_upload, names=["value"])
|