File size: 4,027 Bytes
babe057
 
 
 
 
 
 
 
 
 
bc085bc
babe057
 
 
 
 
 
 
 
bc085bc
 
 
 
 
 
babe057
 
 
 
 
 
bc085bc
 
 
babe057
 
 
 
 
 
 
 
 
 
 
 
 
 
bc085bc
babe057
 
 
 
 
 
 
 
 
bc085bc
 
 
 
 
 
 
 
babe057
 
 
 
 
 
 
 
 
 
 
 
 
bc085bc
babe057
 
bc085bc
 
babe057
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
# app.py

import gradio as gr
import torch.nn.functional as F
from sentence_transformers import SentenceTransformer


def load_model(model_name):
    if model_name == "GLuCoSE-base-ja-v2":
        return SentenceTransformer("pkshatech/GLuCoSE-base-ja-v2")
    elif model_name == "RoSEtta-base":
        return SentenceTransformer("pkshatech/RoSEtta-base", trust_remote_code=True)
    elif model_name == "ruri-large":
        return SentenceTransformer("cl-nagoya/ruri-large")


def get_similarities(model_name, sentences):
    model = load_model(model_name)

    if model_name in ["GLuCoSE-base-ja-v2", "RoSEtta-base"]:
        sentences = [
            "query: " + s if i % 2 == 0 else "passage: " + s
            for i, s in enumerate(sentences)
        ]
    elif model_name == "ruri-large":
        sentences = [
            "クエリ: " + s if i % 2 == 0 else "文章: " + s
            for i, s in enumerate(sentences)
        ]

    embeddings = model.encode(sentences, convert_to_tensor=True)
    similarities = F.cosine_similarity(
        embeddings.unsqueeze(0), embeddings.unsqueeze(1), dim=2
    )

    return similarities.cpu().numpy()


def format_similarities(similarities):
    return "\n".join([" ".join([f"{val:.4f}" for val in row]) for row in similarities])


def process_input(model_name, input_text):
    sentences = [s.strip() for s in input_text.split("\n") if s.strip()]
    similarities = get_similarities(model_name, sentences)
    return format_similarities(similarities)


models = ["GLuCoSE-base-ja-v2", "RoSEtta-base", "ruri-large"]

with gr.Blocks() as demo:
    gr.Markdown("# Sentence Similarity Demo")

    with gr.Row():
        with gr.Column():
            model_dropdown = gr.Dropdown(
                choices=models, label="Select Model", value=models[0]
            )
            input_text = gr.Textbox(
                lines=5,
                label="Input Sentences (one per line)",
                placeholder="Enter query and passage pairs, alternating lines.",
            )
            gr.Markdown("""
            **Note:** Prefixes ('query:' / 'passage:' or 'クエリ:' / '文章:') are added automatically. Just input your sentences.
            """)
            submit_btn = gr.Button(value="Calculate Similarities")

        with gr.Column():
            output_text = gr.Textbox(label="Similarity Matrix", lines=10)

    submit_btn.click(
        process_input, inputs=[model_dropdown, input_text], outputs=output_text
    )

    gr.Examples(
        examples=[
            [
                "GLuCoSE-base-ja-v2",
                "PKSHAはどんな会社ですか?\n研究開発したアルゴリズムを、多くの企業のソフトウエア・オペレーションに導入しています。",
            ],
            [
                "RoSEtta-base",
                "PKSHAはどんな会社ですか?\n研究開発したアルゴリズムを、多くの企業のソフトウエア・オペレーションに導入しています。",
            ],
            [
                "ruri-large",
                "瑠璃色はどんな色?\n瑠璃色(るりいろ)は、紫みを帯びた濃い青。名は、半貴石の瑠璃(ラピスラズリ、英: lapis lazuli)による。JIS慣用色名では「こい紫みの青」(略号 dp-pB)と定義している[1][2]。\nワシやタカのように、鋭いくちばしと爪を持った大型の鳥類を総称して「何類」というでしょう?\nワシ、タカ、ハゲワシ、ハヤブサ、コンドル、フクロウが代表的である。これらの猛禽類はリンネ前後の時代(17~18世紀)には鷲類・鷹類・隼類及び梟類に分類された。ちなみにリンネは狩りをする鳥を単一の目(もく)にまとめ、vultur(コンドル、ハゲワシ)、falco(ワシ、タカ、ハヤブサなど)、strix(フクロウ)、lanius(モズ)の4属を含めている。",
            ],
        ],
        inputs=[model_dropdown, input_text],
    )

demo.launch()