nli-entailment-verifier-xxl

Model description

nli-entailment-verifier-xxl is based on the flan-t5-xxl model and is fine-tuned with a ranking objective: given a premise and a pair of hypotheses, rank the better-supported hypothesis higher. Please refer to our paper, Are Machines Better at Complex Reasoning? Unveiling Human-Machine Inference Gaps in Entailment Verification, for more details.

It is built to verify whether a given premise supports a hypothesis. It works for both NLI-style datasets and chain-of-thought (CoT) rationales. The model is specifically trained to handle multi-sentence premises, similar to what we expect in CoT rationales and other modern LLM use cases.

Note: You can use 4-bit/8-bit quantization to reduce GPU memory usage.
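
As a rough illustration of the note above, one way to load the model with quantized weights is through the bitsandbytes integration in transformers. This is a sketch under assumptions not stated in this model card: it assumes the bitsandbytes and accelerate packages are installed and a CUDA GPU is available, and the helper name load_quantized_verifier is hypothetical.

```python
def load_quantized_verifier(model_name="soumyasanyal/nli-entailment-verifier-xxl", four_bit=True):
    """Load the verifier with 4-bit (or 8-bit) weights.

    Hypothetical helper for illustration; it is not part of the model card.
    """
    # Imported lazily so the heavy dependencies are only needed when the helper is called.
    from transformers import AutoModelForSeq2SeqLM, BitsAndBytesConfig

    # Assumption: bitsandbytes and accelerate are installed and a CUDA GPU is available.
    quant_config = BitsAndBytesConfig(load_in_4bit=four_bit, load_in_8bit=not four_bit)
    return AutoModelForSeq2SeqLM.from_pretrained(
        model_name,
        quantization_config=quant_config,
        device_map="auto",  # let accelerate place the layers on the available devices
    )
```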

Usage

from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import torch

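def get_score(model, tokenizer, input_ids):
    # First token ID of each answer word (the tokenizer also appends an EOS token, which is ignored).
    pos_id = tokenizer('Yes').input_ids[0]
    neg_id = tokenizer('No').input_ids[0]

    # T5 starts decoding from the pad token (ID 0), so a single decoder step
    # gives the logits for the first answer token.
    decoder_input_ids = torch.zeros((input_ids.size(0), 1), dtype=torch.long, device=input_ids.device)
    with torch.no_grad():
        logits = model(input_ids, decoder_input_ids=decoder_input_ids).logits

    # Softmax over only the 'Yes' and 'No' logits; return the probability of 'Yes'.
    posneg_logits = torch.stack([logits[:, 0, pos_id], logits[:, 0, neg_id]], dim=1)
    return torch.nn.functional.softmax(posneg_logits, dim=1)[:, 0]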

tokenizer = AutoTokenizer.from_pretrained('google/flan-t5-xxl')
model = AutoModelForSeq2SeqLM.from_pretrained('soumyasanyal/nli-entailment-verifier-xxl')

premise = "A fossil fuel is a kind of natural resource. Coal is a kind of fossil fuel."
hypothesis = "Coal is a kind of natural resource."
prompt = f"Premise: {premise}\nHypothesis: {hypothesis}\nGiven the premise, is the hypothesis correct?\nAnswer:"

input_ids = tokenizer(prompt, return_tensors='pt').input_ids

scores = get_score(model, tokenizer, input_ids)
print(f'Premise entails the hypothesis: {bool(scores >= 0.5)}')

Premise entails the hypothesis: False
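
The score returned by get_score is a two-way softmax over the 'Yes' and 'No' logits, which is the same as a sigmoid of their difference. The sketch below reproduces that arithmetic in plain Python and uses it to pick the better-supported hypothesis from a set of candidates, in the spirit of the ranking objective described in the model description. The logit values and the helper names yes_probability and rank_hypotheses are hypothetical, made up for illustration; they are not outputs of the model.

```python
import math

def yes_probability(yes_logit, no_logit):
    """Two-way softmax over the 'Yes'/'No' logits, i.e. sigmoid(yes_logit - no_logit)."""
    return 1.0 / (1.0 + math.exp(no_logit - yes_logit))

def rank_hypotheses(logit_pairs):
    """Sort hypotheses by their 'Yes' probability, highest first.

    logit_pairs maps each hypothesis string to a (yes_logit, no_logit) tuple.
    """
    scored = {hyp: yes_probability(*pair) for hyp, pair in logit_pairs.items()}
    return sorted(scored.items(), key=lambda item: item[1], reverse=True)

# Made-up logits for illustration only; real values would come from the model.
example_logits = {
    "Coal is a kind of natural resource.": (2.0, 0.0),
    "Coal is a kind of animal.": (-1.0, 1.0),
}
ranking = rank_hypotheses(example_logits)
best_hypothesis, best_score = ranking[0]
print(f"{best_hypothesis} (score {best_score:.3f})")
```

Thresholding the score at 0.5, as the usage code does, amounts to checking whether the 'Yes' logit is at least as large as the 'No' logit.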
