File size: 2,057 Bytes
fe51ab9
 
75db47e
 
 
 
 
fe51ab9
 
297a61e
 
fe51ab9
 
2e8ba25
fe51ab9
75db47e
fe51ab9
b07e0da
 
2e8ba25
de37095
297a61e
52d04c3
2319997
52d04c3
2319997
52d04c3
9e94506
 
 
 
 
297a61e
9e94506
b07e0da
297a61e
 
4f70792
b07e0da
 
 
 
fe51ab9
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
import gradio as gr

from off_topic import OffTopicDetector, Translator


# Model identifiers handed to the project wrappers; presumably Hugging Face Hub
# checkpoint names (NLLB for translation, CLIP for image/text scoring) -- confirm
# against off_topic.Translator / off_topic.OffTopicDetector.
_TRANSLATOR_CHECKPOINT = "facebook/nllb-200-distilled-600M"
_DETECTOR_CHECKPOINT = "openai/clip-vit-base-patch32"

# Built once at import time and shared by every request handled by the app.
translator = Translator(_TRANSLATOR_CHECKPOINT)
detector = OffTopicDetector(
    _DETECTOR_CHECKPOINT,
    # NOTE(review): the meaning of "V" is defined inside OffTopicDetector.
    image_size="V",
    # The detector receives the translator, presumably to translate item
    # titles when `use_title` is requested -- confirm in off_topic.
    translator=translator,
)


def validate(item_id: str, use_title: bool, threshold: float):
    """Split an item's pictures into valid and invalid sets for its domain.

    Args:
        item_id: Listing identifier forwarded to ``detector.predict_probas_item``.
        use_title: Forwarded as ``use_title``; the UI labels it
            "Use translated item title".
        threshold: Minimum score for an image to be considered valid.

    Returns:
        A ``(markdown, valid_images, invalid_images)`` tuple: a Markdown heading
        naming the domain reported by the detector, the images whose score is at
        or above ``threshold``, and all remaining images. Both lists keep the
        detector's image order.
    """
    # Only the images, the domain and the per-image "valid" scores are used;
    # the other two values returned by the detector are ignored.
    images, domain, _, valid_probas, _ = detector.predict_probas_item(item_id, use_title=use_title)

    valid_images, invalid_images = [], []
    for i, image in enumerate(images):
        # Assumes valid_probas[i] squeezes to a single-element score (e.g. a
        # 1-element tensor) -- TODO confirm against OffTopicDetector.
        # A single comparison per image guarantees each picture ends up in
        # exactly one gallery, even when the score compares false both ways
        # (e.g. NaN), in which case it is reported as invalid.
        if valid_probas[i].squeeze() >= threshold:
            valid_images.append(image)
        else:
            invalid_images.append(image)
    return f"## Domain: {domain}", valid_images, invalid_images


def _hr():
    """Render a horizontal rule separating sections of the page."""
    gr.HTML("<hr>")


def _results_gallery(label):
    """Create a responsive image gallery (1/2/3 columns) titled *label*."""
    return gr.Gallery(label=label).style(grid=[1, 2, 3], height="auto")


# Listings offered as one-click examples. Each example row is
# (item ID, use translated title, threshold), matching the form inputs below.
_EXAMPLE_ITEM_IDS = (
    "MLC572974424",
    "MLU449951849",
    "MLA1293465558",
    "MLB3184663685",
    "MLC1392230619",
    "MCO546152796",
)

with gr.Blocks() as demo:
    gr.Markdown("""
                # Off topic image detector
                ### This app takes an item ID and classifies its pictures as valid/invalid depending on whether they relate to the domain in which it's been listed.
                Input an item ID or select one of the preloaded examples below.""")

    # Input form: item ID on the left, options stacked in a column, then the button.
    with gr.Row():
        item_id = gr.Textbox(label="Item ID")
        with gr.Column():
            use_title = gr.Checkbox(label="Use translated item title", value=True)
            threshold = gr.Number(label="Threshold", value=0.25, precision=2)
        submit = gr.Button("Submit")

    # Results: domain heading followed by the two classification galleries.
    _hr()
    domain = gr.Markdown()
    valid = _results_gallery("Valid images")
    _hr()
    invalid = _results_gallery("Invalid images")

    # The button and the examples drive the same handler with the same wiring.
    form_inputs = [item_id, use_title, threshold]
    form_outputs = [domain, valid, invalid]
    submit.click(fn=validate, inputs=form_inputs, outputs=form_outputs)

    _hr()
    # cache_examples=True asks Gradio to precompute the example outputs by
    # running `validate` on each row.
    gr.Examples(
        examples=[[example_id, True, 0.25] for example_id in _EXAMPLE_ITEM_IDS],
        inputs=form_inputs,
        outputs=form_outputs,
        fn=validate,
        cache_examples=True,
    )

demo.launch()