Metal079's picture
Update app.py
308d90e
raw
history blame
1.31 kB
import gradio as gr
from transformers import pipeline, ImageClassificationPipeline
class MultiClassLabel(ImageClassificationPipeline):
    """Image-classification pipeline that scores every label independently.

    Each logit is passed through an element-wise sigmoid rather than a
    softmax, so label scores do not compete with each other and need not sum
    to 1. Several labels can therefore score highly for the same image.
    """

    def postprocess(self, model_outputs, top_k=5):
        """Return the ``top_k`` highest-scoring labels for the first image.

        Args:
            model_outputs: Output of the model forward pass; only ``.logits``
                is read. Only row ``[0]`` is used -- presumably the pipeline
                hands over one image at a time; TODO confirm for batched calls.
            top_k: Maximum number of labels to return. It is clamped to the
                number of labels in the model config.

        Returns:
            A list of ``{"score": float, "label": str}`` dicts ordered by
            descending score.

        Raises:
            ValueError: If the pipeline framework is neither "pt" nor "tf".
        """
        # Asking topk for more entries than there are classes is invalid.
        top_k = min(top_k, self.model.config.num_labels)

        if self.framework == "pt":
            probs = model_outputs.logits.sigmoid()[0]
            scores, ids = probs.topk(top_k)
        elif self.framework == "tf":
            # Imported lazily: TensorFlow is only needed when the pipeline
            # actually runs on the TF backend.
            import tensorflow as tf

            # Sigmoid (not softmax) so TF scores match the PyTorch branch.
            probs = tf.math.sigmoid(model_outputs.logits)[0]
            best = tf.math.top_k(probs, k=top_k)
            scores, ids = best.values.numpy(), best.indices.numpy()
        else:
            raise ValueError(f"Unsupported framework: {self.framework}")

        id2label = self.model.config.id2label
        return [
            {"score": score, "label": id2label[label_id]}
            for score, label_id in zip(scores.tolist(), ids.tolist())
        ]
# Build the image-classification pipeline from the local "./sonic" model
# directory (a relative path, so it resolves against the current working
# directory). Post-processing is handled by the MultiClassLabel subclass.
pipe_aesthetic = pipeline(
    task="image-classification",
    model="./sonic",
    pipeline_class=MultiClassLabel,
)
def aesthetic(input_img):
    """Classify ``input_img`` and return a ``{label: score}`` mapping.

    The mapping is the shape ``gr.Label`` expects as its value.

    Args:
        input_img: Image from the ``gr.Image(type="pil")`` input, or ``None``
            when the form is submitted without an image.

    Returns:
        Up to five labels mapped to their scores, or an empty dict when no
        image was supplied. If the pipeline returns the same label twice,
        the later score wins.
    """
    # Gradio passes None when Submit is clicked with no image uploaded; the
    # pipeline is not expected to accept None, so show an empty label instead
    # of surfacing an error.
    if input_img is None:
        return {}
    predictions = pipe_aesthetic(input_img, top_k=5)
    return {pred["label"]: pred["score"] for pred in predictions}
# Web UI: upload a PIL image, show the top-scoring labels under "characters".
demo_aesthetic = gr.Interface(
    fn=aesthetic,
    inputs=gr.Image(type="pil"),
    outputs=gr.Label(label="characters"),
)
# Launch the Interface directly. Wrapping a single Interface in gr.Parallel
# adds nothing, and gr.Parallel is no longer provided in Gradio 4.x.
demo_aesthetic.launch()