File size: 2,564 Bytes
0df179d
 
 
 
 
 
ae08a8d
 
 
0df179d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ae08a8d
 
 
0df179d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch

# Model loading and setup
model_name = "jhu-clsp/FollowIR-7B"
model = AutoModelForCausalLM.from_pretrained(model_name)
if torch.cuda.is_available():
    model = model.cuda()

# Left padding keeps each prompt's final token at position -1, which is where
# the relevance scorer reads next-token logits from.
tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side="left")
# Reuse EOS as the padding token (presumably the tokenizer ships without a
# dedicated pad token — verify against the model's tokenizer config).
tokenizer.pad_token = tokenizer.eos_token

# Build the token -> id map once and pull out the two answer tokens the model
# is asked to emit. A KeyError here means the vocabulary lacks "true"/"false"
# as whole tokens, in which case the scoring approach would not work anyway.
_vocab = tokenizer.get_vocab()
token_false_id = _vocab["false"]
token_true_id = _vocab["true"]
del _vocab

# Prompt template; keep it byte-identical, since the relevance scores depend on
# the exact wording and layout. It looks like the [INST] ... [/INST] chat
# format — confirm against the model card.
# NOTE(review): the template starts with a literal "<s>" and the tokenizer may
# also prepend its own BOS token; confirm whether that duplication is intended.
template = """<s> [INST] You are an expert Google searcher, whose job is to determine if the following document is relevant to the query (true/false). Answer using only one word, one of those two choices.

Query: {query}
Document: {text}
Relevant (only output one word, either "true" or "false"): [/INST] """

def check_relevance(query, instruction, passage):
    """Score how relevant ``passage`` is to ``query`` plus ``instruction``.

    The query and instruction are joined with a single space and inserted,
    together with the passage, into ``template``. The model's logits for the
    last prompt position are compared at the "true" and "false" token ids
    and normalised over just those two tokens.

    Args:
        query: Search query text.
        instruction: Additional instructions/criteria appended to the query.
            Empty values are skipped so the prompt does not end up with a
            dangling space.
        passage: Document text to judge.

    Returns:
        The probability of "true" (relative to "false"), formatted as a
        string with four decimal places.
    """
    # Join only the non-empty parts so an empty instruction adds no trailing space.
    full_query = " ".join(part for part in (query, instruction) if part)
    prompt = template.format(query=full_query, text=passage)

    tokens = tokenizer(
        [prompt],
        padding=True,
        truncation=True,
        return_tensors="pt",
        pad_to_multiple_of=None,
    )
    # Put the inputs on whatever device the model's weights live on.
    tokens = tokens.to(model.device)

    # Pure inference: without this, every request would record an autograd
    # graph for the whole forward pass and keep its activations alive.
    with torch.inference_mode():
        last_logits = model(**tokens).logits[:, -1, :]
        # Keep only the two answer tokens, ordered [false, true] so that
        # column 1 holds the "true" probability after normalisation.
        pair = torch.stack(
            [last_logits[:, token_false_id], last_logits[:, token_true_id]],
            dim=1,
        )
        probs = torch.nn.functional.log_softmax(pair, dim=1).exp()
        score = probs[:, 1].item()

    return f"{score:.4f}"

# Gradio Interface: the three text inputs sit in a left column and the
# formatted relevance score is shown in a right column.
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("# FollowIR Relevance Checker")
    gr.Markdown("This app uses the FollowIR-7B model to determine the relevance of a passage to a given query and instruction.")

    with gr.Row():
        # Left column: user inputs plus the button that triggers scoring.
        with gr.Column():
            query_box = gr.Textbox(
                label="Query",
                placeholder="Enter your search query here",
            )
            instruction_box = gr.Textbox(
                label="Instruction",
                placeholder="Enter additional instructions or criteria",
            )
            passage_box = gr.Textbox(
                label="Passage",
                placeholder="Enter the passage to check for relevance",
                lines=5,
            )
            check_btn = gr.Button("Check Relevance")

        # Right column: the string returned by check_relevance.
        with gr.Column():
            score_box = gr.Textbox(label="Relevance Probability")

    # The input components are passed positionally as
    # (query, instruction, passage), matching check_relevance's signature.
    check_btn.click(
        fn=check_relevance,
        inputs=[query_box, instruction_box, passage_box],
        outputs=[score_box],
    )

demo.launch()