PEFT
English
MrLight commited on
Commit
47e6886
·
1 Parent(s): 92cfbfa

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +62 -0
README.md CHANGED
@@ -1,3 +1,65 @@
1
  ---
2
  license: llama2
3
  ---
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  ---
2
  license: llama2
3
  ---
4
+
5
+
6
+ # RepLLaMA-7B-Passage
7
+
8
+ [Fine-Tuning LLaMA for Multi-Stage Text Retrieval](TODO).
9
+ Xueguang Ma, Liang Wang, Nan Yang, Furu Wei, Jimmy Lin, arXiv 2023
10
+
11
+ This model is fine-tuned from LLaMA-2-7B using LoRA and the embedding size is 4096.
12
+
13
+ ## Usage
14
+
15
+ Below is an example to encode a query and a document, and then compute their similarity using their embedding.
16
+
17
+ ```python
18
+ import torch
19
+ from transformers import AutoModel, AutoTokenizer
20
+ from peft import PeftModel, PeftConfig
21
+
22
+ def get_model(peft_model_name):
23
+ config = PeftConfig.from_pretrained(peft_model_name)
24
+ base_model = AutoModel.from_pretrained(config.base_model_name_or_path)
25
+ model = PeftModel.from_pretrained(base_model, peft_model_name)
26
+ model = model.merge_and_unload()
27
+ model.eval()
28
+ return model
29
+
30
+ # Load the tokenizer and model
31
+ tokenizer = AutoTokenizer.from_pretrained('meta-llama/Llama-2-7b-hf')
32
+ model = get_model('castorini/repllama-v1-7b-lora-passage')
33
+
34
+ # Define query and document inputs
35
+ query = "What is llama?"
36
+ title = "Llama"
37
+ passage = "The llama is a domesticated South American camelid, widely used as a meat and pack animal by Andean cultures since the pre-Columbian era."
38
+ query_input = tokenizer(f'query: {query}</s>', return_tensors='pt')
39
+ document_input = tokenizer(f'passage: {title} {passage}</s>', return_tensors='pt')
40
+
41
+ # Run the model forward to compute embeddings and query-document similarity score
42
+ with torch.no_grad():
43
+ # compute query embedding
44
+ query_outputs = model(**query_input)
45
+ query_embedding = query_outputs.last_hidden_state[0][-1]
46
+ query_embedding = torch.nn.functional.normalize(query_embedding, p=2, dim=0)
47
+
48
+ # compute document embedding
49
+ document_outputs = model(**document_input)
50
+ document_embeddings = document_outputs.last_hidden_state[0][-1]
51
+ document_embeddings = torch.nn.functional.normalize(document_embeddings, p=2, dim=0)
52
+
53
+ # compute similarity score
54
+ score = torch.dot(query_embedding, document_embeddings)
55
+ print(score)
56
+
57
+ ```
58
+
59
+ ## Citation
60
+
61
+ If you find our paper or models helpful, please consider cite as follows:
62
+
63
+ ```
64
+ TODO
65
+ ```