File size: 4,200 Bytes
1bb90fa
2c140eb
 
 
 
 
 
 
 
 
 
 
 
 
 
777a816
2c140eb
3388bb6
2c140eb
 
 
 
 
 
970c656
 
 
 
 
 
 
2c140eb
970c656
2c140eb
 
 
 
970c656
 
 
 
 
5648cf2
970c656
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
import data
import torch
import gradio as gr
from models import imagebind_model
from models.imagebind_model import ModalityType


# Pick the first CUDA device when one is available, otherwise fall back to CPU.
if torch.cuda.is_available():
    device = "cuda:0"
else:
    device = "cpu"

# Build the pretrained "huge" ImageBind model once at import time, put it in
# eval mode and move it to the selected device.
model = imagebind_model.imagebind_huge(pretrained=True)
model.eval()
model.to(device)


def _parse_labels(text_list):
    """Split a "|"-separated string into candidate labels.

    Surrounding whitespace is stripped from every entry, empty entries (from an
    empty textbox or a stray/trailing "|") are dropped, and duplicates are
    removed keeping their first occurrence. ``None`` or "" yields ``[]``.
    """
    if not text_list:
        return []
    stripped = (label.strip() for label in text_list.split("|"))
    # dict.fromkeys de-duplicates while preserving first-seen order; duplicate
    # labels would otherwise be scored separately and then overwrite each
    # other as keys of the returned score dict.
    return list(dict.fromkeys(label for label in stripped if label))


def _checked_labels(media_path, text_list, media_name):
    """Validate one zero-shot request and return its candidate labels.

    Validation happens before any data loading or model call, so bad input
    fails fast with a readable message.

    Raises:
        ValueError: if no media file was supplied, or ``text_list`` contains
            no non-empty candidate label.
    """
    if not media_path:
        raise ValueError(f"Please provide an input {media_name}.")
    labels = _parse_labels(text_list)
    if not labels:
        raise ValueError(
            'Please provide at least one candidate text, separated by "|".'
        )
    return labels


def _zeroshot_scores(labels, media_path, modality, load_media):
    """Score ``labels`` against a single media file with the global model.

    Args:
        labels: Non-empty list of candidate texts.
        media_path: Filepath of the image/audio input.
        modality: ``ModalityType`` key under which the media input is passed.
        load_media: ``data`` loader that turns a list of paths into model input.

    Returns:
        Dict mapping each label to its softmax probability over ``labels``.
    """
    inputs = {
        ModalityType.TEXT: data.load_and_transform_text(labels, device),
        modality: load_media([media_path], device),
    }

    with torch.no_grad():
        embeddings = model(inputs)

    logits = embeddings[modality] @ embeddings[ModalityType.TEXT].T
    # Only one media path is passed, so dim 0 is presumably a batch of one;
    # squeeze(0) then leaves one score per label.
    scores = torch.softmax(logits, dim=-1).squeeze(0).tolist()
    return dict(zip(labels, scores))


def image_text_zeroshot(image, text_list):
    """Classify an image against "|"-separated candidate texts.

    Args:
        image: Filepath of the input image.
        text_list: Candidate labels separated by "|".

    Returns:
        Dict mapping each candidate label to its probability.

    Raises:
        ValueError: if ``image`` is missing or no candidate label is given.
    """
    labels = _checked_labels(image, text_list, "image")
    return _zeroshot_scores(
        labels, image, ModalityType.VISION, data.load_and_transform_vision_data
    )


def audio_text_zeroshot(audio, text_list):
    """Classify an audio clip against "|"-separated candidate texts.

    Args:
        audio: Filepath of the input audio.
        text_list: Candidate labels separated by "|".

    Returns:
        Dict mapping each candidate label to its probability.

    Raises:
        ValueError: if ``audio`` is missing or no candidate label is given.
    """
    labels = _checked_labels(audio, text_list, "audio")
    return _zeroshot_scores(
        labels, audio, ModalityType.AUDIO, data.load_and_transform_audio_data
    )


def inference(
    task,
    image=None,
    audio=None,
    text_list=None,
):
    """Dispatch a Gradio request to the zero-shot function for ``task``.

    Args:
        task: Either "image-text" or "audio-text".
        image: Image filepath (used by "image-text").
        audio: Audio filepath (used by "audio-text").
        text_list: Candidate labels separated by "|".

    Returns:
        Dict mapping each candidate label to its probability.

    Raises:
        NotImplementedError: for any other ``task`` value.
        ValueError: propagated from the task function on invalid input.
    """
    if task == "image-text":
        return image_text_zeroshot(image, text_list)
    if task == "audio-text":
        return audio_text_zeroshot(audio, text_list)
    raise NotImplementedError(f"Unsupported task: {task!r}")


def main():
    """Build and launch the Gradio zero-shot classification demo.

    The inputs use the top-level component classes (``gr.Radio``,
    ``gr.Image``, ``gr.Audio``, ``gr.Textbox``) rather than the ``gr.inputs``
    namespace, which Gradio deprecated in 3.x and removed in 4.0. ``value=``
    replaces the deprecated ``default=`` keyword for the initial selection.
    """
    inputs = [
        gr.Radio(
            choices=[
                "image-text",
                "audio-text",
            ],
            type="value",
            value="image-text",
            label="Task",
        ),
        gr.Image(type="filepath", label="Input image"),
        gr.Audio(type="filepath", label="Input audio"),
        gr.Textbox(lines=1, label="Candidate texts"),
    ]

    # Each example row follows the order of ``inputs``:
    # task, image path, audio path, candidate texts.
    examples = [
        ["image-text", "assets/dog_image.jpg", None, "A dog|A car|A bird"],
        ["image-text", "assets/car_image.jpg", None, "A dog|A car|A bird"],
        ["audio-text", None, "assets/bird_audio.wav", "A dog|A car|A bird"],
        ["audio-text", None, "assets/dog_audio.wav", "A dog|A car|A bird"],
    ]

    iface = gr.Interface(
        inference,
        inputs,
        "label",
        examples=examples,
        # NOTE(review): the "Duplicate Space" link below points to
        # OFA-Sys/chinese-clip-zero-shot-image-classification, which looks
        # copied from another Space — confirm the intended target.
        description="""<p>This is a simple demo of ImageBind for zero-shot cross-modal understanding (now including image classification and audio classification). Please refer to the original <a href='https://arxiv.org/abs/2305.05665' target='_blank'>paper</a> and <a href='https://github.com/facebookresearch/ImageBind' target='_blank'>repo</a> for more details.<br>
                    To test your own cases, you can upload an image or an audio, and provide the candidate texts separated by "|".<br>
                    You can duplicate this space and run it privately: <a href='https://huggingface.co/spaces/OFA-Sys/chinese-clip-zero-shot-image-classification?duplicate=true'><img src='https://img.shields.io/badge/-Duplicate%20Space-blue?labelColor=white&style=flat&logo=&logoWidth=14' alt='Duplicate Space'></a></p>""",
        title="ImageBind: Zero-shot Cross-modal Understanding",
    )

    iface.launch()


# Launch the demo only when this file is executed as a script.
if __name__ == "__main__":
    main()