```python
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import torch

# 1. Load the reranker model
reranker_tokenizer = AutoTokenizer.from_pretrained("BAAI/bge-reranker-base")
reranker_model = AutoModelForSequenceClassification.from_pretrained("BAAI/bge-reranker-base")

def rerank_documents(query: str, docs: list, top_k: int = 5) -> list:
    """
    Re-rank the retrieved documents by relevance to the query.
    """
    # Build (query, document) pairs for the cross-encoder
    pairs = [(query, doc) for doc in docs]
    inputs = reranker_tokenizer.batch_encode_plus(
        pairs,
        padding=True,
        truncation=True,
        return_tensors="pt",
        max_length=512,
    )
    with torch.no_grad():
        scores = reranker_model(**inputs).logits.squeeze(-1)  # (batch_size,)
    scores = scores.tolist()
    # Sort documents by score in descending order
    sorted_docs = [doc for _, doc in sorted(zip(scores, docs), key=lambda x: x[0], reverse=True)]
    return sorted_docs[:top_k]
```
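A minimal usage sketch of `rerank_documents`, assuming documents have already been retrieved by a first-stage search; the `retrieved_docs` list and the query below are hypothetical examples, not from the original:

```python
# Hypothetical retrieved documents (assumed for illustration only)
retrieved_docs = [
    "BGE rerankers are cross-encoders that score each (query, document) pair.",
    "The capital of France is Paris.",
    "Cross-encoder reranking is typically applied after a first-stage vector search.",
]

# Re-rank and keep only the top 2 most relevant documents
top_docs = rerank_documents("How does a reranker work?", retrieved_docs, top_k=2)
for rank, doc in enumerate(top_docs, start=1):
    print(f"{rank}. {doc}")
```

Because the cross-encoder scores every pair jointly, it is more accurate but slower than embedding similarity, which is why it is usually applied only to the small candidate set returned by the initial retrieval step.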