Spaces:
Runtime error
Runtime error
import gradio as gr | |
from transformers import pipeline, ImageClassificationPipeline | |
class MultiClassLabel(ImageClassificationPipeline):
    """Image-classification pipeline that scores each label independently.

    Overrides ``postprocess`` to apply an element-wise sigmoid to the logits
    (rather than a softmax), so every label gets its own score in [0, 1] and
    several labels can score highly for the same image.
    """

    def postprocess(self, model_outputs, top_k=5):
        """Convert raw model outputs into the ``top_k`` highest-scoring labels.

        Args:
            model_outputs: Output of the model forward pass. Only its
                ``logits`` attribute is read, and only the first row
                (``[0]``) of the sigmoid-activated logits is used.
            top_k: Maximum number of labels to return. ``None`` or a value
                larger than ``model.config.num_labels`` is clamped to
                ``num_labels``.

        Returns:
            A list of ``{"score": float, "label": ...}`` dicts, highest score
            first. Each score is the sigmoid of that label's logit, and each
            label is looked up in ``model.config.id2label``.

        Raises:
            ValueError: If the pipeline framework is neither ``"pt"`` nor ``"tf"``.
        """
        num_labels = self.model.config.num_labels
        if top_k is None or top_k > num_labels:
            top_k = num_labels

        if self.framework == "pt":
            probs = model_outputs.logits.sigmoid()[0]
            scores, ids = probs.topk(top_k)
        elif self.framework == "tf":
            # Imported lazily so TensorFlow is only required when serving a TF
            # model. Sigmoid (not softmax) keeps this branch consistent with
            # the PyTorch branch's independent per-label scores.
            import tensorflow as tf

            probs = tf.math.sigmoid(model_outputs.logits)[0]
            topk = tf.math.top_k(probs, k=top_k)
            scores, ids = topk.values.numpy(), topk.indices.numpy()
        else:
            raise ValueError(f"Unsupported framework: {self.framework}")

        return [
            {"score": score, "label": self.model.config.id2label[_id]}
            for score, _id in zip(scores.tolist(), ids.tolist())
        ]
# Build the classifier once, at import time, from the local "./sonic"
# checkpoint, using the MultiClassLabel subclass for post-processing.
pipe_aesthetic = pipeline(
    task="image-classification",
    model="./sonic",
    pipeline_class=MultiClassLabel,
)
def aesthetic(input_img):
    """Classify ``input_img`` and return a ``{label: score}`` mapping.

    The pipeline is asked for its top 5 labels, so the mapping has at most
    five entries. A plain dict like this is what a ``gr.Label`` output expects.

    Args:
        input_img: Image from the ``gr.Image(type="pil")`` input, or ``None``
            when the user submits without uploading an image.

    Returns:
        dict mapping each predicted label to its score; empty when no image
        was provided.
    """
    # Gradio's Image component yields None for an empty input; running the
    # model on None would raise inside the pipeline, so short-circuit here.
    if input_img is None:
        return {}
    predictions = pipe_aesthetic(input_img, top_k=5)
    return {pred["label"]: pred["score"] for pred in predictions}
# Single-image UI: one PIL image in, the label -> score mapping from
# `aesthetic` rendered by a gr.Label component out.
demo_aesthetic = gr.Interface(
    fn=aesthetic,
    inputs=gr.Image(type="pil"),
    outputs=gr.Label(label="characters"),
)

# gr.Parallel was removed in Gradio 4.0, and wrapping a single interface in it
# added nothing, so launch the interface directly. The guard keeps importing
# this module (e.g. for tests) from starting a server.
if __name__ == "__main__":
    demo_aesthetic.launch()