# + tags=["hide_inp"]
desc = """
### Question Answering with Retrieval

Chain that answers questions with embedding-based retrieval. [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/srush/MiniChain/blob/master/examples/qa.ipynb)
(Adapted from [OpenAI Notebook](https://github.com/openai/openai-cookbook/blob/main/examples/Question_answering_using_embeddings.ipynb).)
"""
# -
# $
import datasets
import numpy as np
from minichain import prompt, transform, show, OpenAIEmbed, OpenAI

# We use Hugging Face Datasets as the database, adding a FAISS index
# over its precomputed "embeddings" column for nearest-neighbor search.
olympics = datasets.load_from_disk("olympics.data")
olympics.add_faiss_index("embeddings")
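
# A minimal sketch (kept as a comment) of how a dataset like `olympics.data`
# could be prepared. The file shipped with this example is assumed to already
# contain a "content" column and a precomputed "embeddings" column; `embed_fn`
# below is a hypothetical embedding function, not part of this example.
#
# passages = datasets.Dataset.from_dict({"content": ["passage text ...", "..."]})
# passages = passages.map(lambda row: {"embeddings": embed_fn(row["content"])})
# passages.save_to_disk("olympics.data")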

# Fast KNN retrieval prompt

@prompt(OpenAIEmbed())
def embed(model, inp):
    # Embed the query text.
    return model(inp)

@transform()
def get_neighbors(inp, k):
    # Return the contents of the k passages nearest to the query embedding.
    res = olympics.get_nearest_examples("embeddings", np.array(inp), k)
    return res.examples["content"]
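
# For reference, the index can also be queried directly (hedged sketch: the
# random query vector is only a placeholder, and its dimension must match the
# stored embeddings, e.g. 1536 for OpenAI's text-embedding-ada-002):
#
# scores, examples = olympics.get_nearest_examples(
#     "embeddings", np.random.rand(1536).astype("float32"), k=3)
# print(examples["content"])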

@prompt(OpenAI(), template_file="qa.pmpt.tpl")
def get_result(model, query, neighbors):
    # Fill the qa.pmpt.tpl template with the question and retrieved passages.
    return model(dict(question=query, docs=neighbors))

def qa(query):
    # Chain: embed the query, fetch its 3 nearest passages, then answer.
    n = get_neighbors(embed(query), 3)
    return get_result(query, n)
# $

questions = ["Who won the 2020 Summer Olympics men's high jump?",
             "Why was the 2020 Summer Olympics originally postponed?",
             "In the 2020 Summer Olympics, how many gold medals did the country which won the most medals win?",
             "What is the total number of medals won by France?",
             "What is the tallest mountain in the world?"]

gradio = show(qa,
              examples=questions,
              subprompts=[embed, get_result],
              description=desc,
              code=open("qa.py", "r").read().split("$")[1].strip().strip("#").strip(),
              )

if __name__ == "__main__":
    gradio.queue().launch()