LaferriereJC committed on
Commit
fc90506
·
verified ·
1 Parent(s): d9f7113

Updated README.md

Browse files

created proper readme

Files changed (1) hide show
  1. README.md +148 -3
README.md CHANGED
@@ -1,3 +1,148 @@
1
- ---
2
- license: mit
3
- ---
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ license: mit
3
+ ---
4
+
5
+ # Datasets
6
+
7
+ - Yale-LILY/FOLIO
8
+ - yuan-yang/MALLS-v0
9
+ - apergo-ai/text2log (1661 records)
10
+
11
+ how to load
12
+
13
+ ```
14
+ device = "cuda"
15
+ model_name_or_path = "microsoft/Phi-3-mini-4k-instruct"
16
+
17
+ model = transformers.AutoModelForCausalLM.from_pretrained(
18
+ model_name_or_path, torch_dtype=torch.bfloat16, device_map=device)
19
+
20
+ reft_model = pyreft.ReftModel.load(
21
+ "LaferriereJC/Phi-3-mini-4k-instruct-FOL-pyreft", model
22
+ )
23
+ ```
24
+
25
+ how to use
26
+ ```
27
+ !git clone https://huggingface.co/LaferriereJC/Phi-3-mini-4k-instruct-FOL-pyreft
28
+ from transformers import AutoModelForCausalLM
29
+ import torch
30
+ import pyreft
31
+ import os
32
+ import transformers
33
+
34
+ device = 'cuda'
35
+ model_name_or_path = "microsoft/Phi-3-mini-4k-instruct"
36
+
37
+ attn_implementation = "eager"
38
+ torch_dtype = torch.float16
39
+ #"microsoft/Phi-3-mini-4k-instruct"
40
+
41
+ model = transformers.AutoModelForCausalLM.from_pretrained(
42
+ model_name_or_path, torch_dtype=torch.bfloat16, device_map=device,trust_remote_code=True)
43
+
44
+
45
+ # Define the PyReFT configuration
46
+ layers = range(model.config.num_hidden_layers)
47
+ representations = [{
48
+ "component": f"model.layers[{l}].output",
49
+ "intervention": pyreft.LoreftIntervention(
50
+ embed_dim=model.config.hidden_size,
51
+ low_rank_dimension=16
52
+ )
53
+ } for l in layers]
54
+
55
+ reft_config = pyreft.ReftConfig(representations=representations)
56
+
57
+ # Initialize the PyReFT model
58
+ reft_model = pyreft.get_reft_model(model, reft_config)
59
+
60
+ # Load the saved PyReFT model
61
+ local_directory = "./Phi-3-mini-4k-instruct-FOL-pyreft"
62
+ interventions = {}
63
+ for l in layers:
64
+ component = f"model.layers[{l}].output"
65
+ file_path = os.path.join(local_directory, f"intkey_comp.{component}.unit.pos.nunit.1#0.bin")
66
+ if os.path.exists(file_path):
67
+ with open(file_path, "rb") as f:
68
+ adjusted_key = f"comp.{component}.unit.pos.nunit.1#0"
69
+ interventions[adjusted_key] = torch.load(f)
70
+
71
+ # Apply the loaded weights to the model
72
+ for component, state_dict in interventions.items():
73
+ if component in reft_model.interventions:
74
+ reft_model.interventions[component][0].load_state_dict(state_dict)
75
+ else:
76
+ print(f"Key mismatch: {component} not found in reft_model.interventions")
77
+
78
+ # Set the device to CUDA
79
+ reft_model.set_device("cuda")
80
+
81
+ # Verify the model
82
+ reft_model.print_trainable_parameters()
83
+
84
+ #model.half()
85
+ # get tokenizer
86
+ tokenizer = transformers.AutoTokenizer.from_pretrained(
87
+ model_name_or_path, model_max_length=216,
88
+ padding_side="right", use_fast=True,
89
+ attn_implementation=attn_implementation
90
+ #, add_eos_token=True, add_bos_token=True
91
+ )
92
+
93
+ tokenizer.pad_token = tokenizer.eos_token
94
+
95
+ # position info about the interventions
96
+ share_weights = True # whether the prefix and suffix interventions share weights.
97
+ positions="f3+l3" # the intervening positions: the first 3 prefix tokens (f3) and the last 3 suffix tokens (l3).
98
+ first_n, last_n = pyreft.parse_positions(positions)
99
+
100
+ terminators = [
101
+ tokenizer.eos_token_id,
102
+ ]
103
+
104
+ prompt_no_input_template = """\n<|user|>:%s</s>\n<|assistant|>:"""
105
+
106
+ test_instruction = f"""tell me something I don't know"""
107
+ # tokenize and prepare the input
108
+ prompt = prompt_no_input_template % test_instruction
109
+ prompt = tokenizer(prompt, return_tensors="pt").to(device)
110
+
111
+ unit_locations = torch.IntTensor([pyreft.get_intervention_locations(
112
+ last_position=prompt["input_ids"].shape[-1],
113
+ first_n=first_n,
114
+ last_n=last_n,
115
+ pad_mode="last",
116
+ num_interventions=len(reft_config.representations),
117
+ share_weights=share_weights
118
+ )]).permute(1, 0, 2).tolist()
119
+
120
+ _, reft_response = reft_model.generate(
121
+ prompt, unit_locations={"sources->base": (None, unit_locations)},
122
+ intervene_on_prompt=True, max_new_tokens=216, do_sample=True, top_k=50,temperature=0.7,
123
+ eos_token_id=terminators, early_stopping=True
124
+ )
125
+ print(tokenizer.decode(reft_response[0], skip_special_tokens=True))
126
+
127
+
128
+ ```
129
+
130
+ response
131
+ ```
132
+ :tell me something I don't know</s> :exists x1.(_thing(x1) & _donknow(x1))
133
+ ```
134
+
135
+ training settings
136
+ ```
137
+ per_device_train_batch_size=6,
138
+ logging_steps=1,
139
+ optim='paged_lion_8bit',
140
+ gradient_checkpointing_kwargs={"use_reentrant": False},
141
+ learning_rate=0.0003,
142
+ warmup_ratio=.1,
143
+ adam_beta2=0.95,
144
+ adam_epsilon=0.00001,
145
+ save_strategy='epoch',
146
+ max_grad_norm=1.0,
147
+ lr_scheduler_type='cosine',
148
+ ```