abhinavkashyap92 commited on
Commit
cf2363e
·
1 Parent(s): 06db20c

Create Initial Model Card for flan-t5-xl-lora

Browse files
Files changed (1) hide show
  1. README.md +96 -0
README.md ADDED
@@ -0,0 +1,96 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ license: apache-2.0
3
+ datasets:
4
+ - tatsu-lab/alpaca
5
+ ---
6
+
7
+ ## 🍮 🦙 Flan-Alpaca: Instruction Tuning from Humans and Machines
8
+
9
+ Our [repository](https://github.com/declare-lab/flan-alpaca) contains code for extending the [Stanford Alpaca](https://github.com/tatsu-lab/stanford_alpaca)
10
+ synthetic instruction tuning to existing instruction-tuned models such as [Flan-T5](https://arxiv.org/abs/2210.11416).
11
+ The pretrained models and demos are available on HuggingFace 🤗 :
12
+
13
+ | Model | Parameters | Training GPUs |
14
+ |---------------------------------------------------------------------------|------------|-----------------|
15
+ | [Flan-Alpaca-Base](https://huggingface.co/declare-lab/flan-alpaca-base) | 220M | 1x A6000 |
16
+ | [Flan-Alpaca-Large](https://huggingface.co/declare-lab/flan-alpaca-large) | 770M | 1x A6000 |
17
+ | [Flan-Alpaca-XL](https://huggingface.co/declare-lab/flan-alpaca-xl) | 3B | 1x A6000 |
18
+ | [Flan-Alpaca-XXL](https://huggingface.co/declare-lab/flan-alpaca-xxl) | 11B | 4x A6000 (FSDP) |
19
+
20
+ ### Why?
21
+
22
+ [Alpaca](https://crfm.stanford.edu/2023/03/13/alpaca.html) represents an exciting new direction
23
+ to approximate the performance of large language models (LLMs) like ChatGPT cheaply and easily.
24
+ Concretely, they leverage an LLM such as GPT-3 to generate instructions as synthetic training data.
25
+ The synthetic data which covers more than 50k tasks can then be used to finetune a smaller model.
26
+ However, the original implementation is less accessible due to licensing constraints of the
27
+ underlying [LLaMA](https://ai.facebook.com/blog/large-language-model-llama-meta-ai/) model.
28
+ Furthermore, users have noted [potential noise](https://github.com/tloen/alpaca-lora/issues/65) in the synthetic
29
+ dataset. Hence, it may be better to explore a fully accessible model that is already trained on high-quality (but
30
+ less diverse) instructions such as [Flan-T5](https://arxiv.org/abs/2210.11416).
31
+
32
+ ### Usage
33
+ This uses Huggingface PEFT library for Parameter Efficient Fine Tuning
34
+
35
+ ```
36
+ import torch
37
+ from peft import PeftModel
38
+ from transformers import GenerationConfig
39
+
40
+
41
+ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
42
+
43
+
44
+ BASE_MODEL = "google/flan-t5-xl"
45
+ LORA_WEIGHTS = "declare-lab/flan-alpaca-xl-lora"
46
+ TEMPERATURE = 1.0
47
+ TOP_P = 0.75
48
+ TOP_K = 40
49
+ NUM_BEAMS = 4
50
+ MAX_NEW_TOKENS = 128
51
+
52
+ if torch.cuda.is_available():
53
+ device = "cuda"
54
+ else:
55
+ device = "cpu"
56
+
57
+
58
+ if device == "cuda":
59
+ model = AutoModelForSeq2SeqLM.from_pretrained(
60
+ BASE_MODEL,
61
+ device_map="auto",
62
+ )
63
+ model = PeftModel.from_pretrained(model, LORA_WEIGHTS, force_download=True)
64
+ else:
65
+ model = AutoModelForSeq2SeqLM.from_pretrained(
66
+ BASE_MODEL, device_map={"": device}, low_cpu_mem_usage=True
67
+ )
68
+ model = PeftModel.from_pretrained(
69
+ model,
70
+ LORA_WEIGHTS,
71
+ device_map={"": device},
72
+ )
73
+
74
+
75
+ prompt = "Write a short email to show that 42 is the optimal seed for training neural networks"
76
+
77
+ tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)
78
+ input_ids = tokenizer(prompt, return_tensors="pt").input_ids
79
+ input_ids = input_ids.to(device)
80
+
81
+ generation_config = GenerationConfig(
82
+ temperature=TEMPERATURE,
83
+ top_p=TOP_P,
84
+ top_k=TOP_K,
85
+ num_beams=NUM_BEAMS,
86
+ )
87
+ generation_output = model.generate(
88
+ input_ids=input_ids,
89
+ generation_config=generation_config,
90
+ return_dict_in_generate=True,
91
+ output_scores=True,
92
+ max_new_tokens=MAX_NEW_TOKENS,
93
+ )
94
+ print(tokenizer.batch_decode(generation_output.sequences)[0])
95
+
96
+ ```