Amossofer committed
Commit e45fce6 · 1 Parent(s): 8756636
Files changed (1)
  1. app.py +16 -28
app.py CHANGED
@@ -11,40 +11,28 @@ model = AutoModelForCausalLM.from_pretrained(
 )
 model.eval()
 
-def blend_generate(sysA, sysB, wA, wB, user_message, max_new_tokens, temperature, top_p):
-    promptA = f"<|system|>{sysA}\n<|user|>{user_message}\n<|assistant|>"
-    promptB = f"<|system|>{sysB}\n<|user|>{user_message}\n<|assistant|>"
-
-    idsA = tokenizer(promptA, return_tensors="pt").input_ids.to(model.device)
-    idsB = tokenizer(promptB, return_tensors="pt").input_ids.to(model.device)
-
-    outA, outB = idsA.clone(), idsB.clone()
-    response = ""
-
-    for _ in range(max_new_tokens):
-        with torch.no_grad():
-            logitsA = model(input_ids=outA).logits[:, -1, :]
-            logitsB = model(input_ids=outB).logits[:, -1, :]
-
-        blended = wA * logitsA + wB * logitsB
-        blended = blended / temperature
-
-        probs = F.softmax(blended, dim=-1)
-        sorted_probs, sorted_idx = torch.sort(probs, descending=True)
-        cum = torch.cumsum(sorted_probs, dim=-1)
-        sorted_probs[cum > top_p] = 0
-        sorted_probs = sorted_probs / sorted_probs.sum(dim=-1, keepdim=True)
-
-        token = sorted_idx[:, torch.multinomial(sorted_probs, 1)].squeeze()
-        outA = torch.cat([outA, token.unsqueeze(0).unsqueeze(0)], dim=1)
-        outB = torch.cat([outB, token.unsqueeze(0).unsqueeze(0)], dim=1)
-
-        token_str = tokenizer.decode(token)
-        response += token_str
-        yield response
-
-        if token.item() == tokenizer.eos_token_id:
-            break
+def blend_generate(prompt, wa, wb):
+    input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device)
+
+    with torch.no_grad():
+        output_a = model_a(input_ids)
+        output_b = model_b(input_ids)
+
+    logits_a = output_a.logits[:, -1, :]
+    logits_b = output_b.logits[:, -1, :]
+
+    # Weighted sum of raw logits (before softmax)
+    blended_logits = wa * logits_a + wb * logits_b
+
+    # Apply softmax safely to get a valid probability distribution
+    probs = torch.softmax(blended_logits, dim=-1)
+
+    # Sample a token from the valid probability distribution
+    token = torch.multinomial(probs, 1)
+    next_token_id = token.item()
+    next_token = tokenizer.decode([next_token_id])
+
+    return next_token
 
 with gr.Blocks() as demo:
     gr.Markdown("## Blended Prompt Chat (TinyLlama)")
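Usage note: the new blend_generate returns a single decoded token sampled from the weighted mixture of the two models' next-token distributions, so producing a whole reply requires calling it in a loop and re-feeding the text generated so far. The sketch below illustrates such a loop; it assumes blend_generate, tokenizer, model_a, model_b, and device are defined as the new app.py code expects, while blend_chat, its default weights, and the eos check are illustrative assumptions rather than code from this commit.

```python
# Illustrative sketch only: loop the one-step blend_generate to build a reply.
# Assumes blend_generate, tokenizer, model_a, model_b, device exist as in
# app.py; blend_chat itself is a hypothetical helper, not part of the commit.
def blend_chat(user_message, wa=0.5, wb=0.5, max_new_tokens=128):
    # Same chat-template style the earlier blend_generate used
    prompt = f"<|system|>You are a helpful assistant.\n<|user|>{user_message}\n<|assistant|>"
    response = ""
    for _ in range(max_new_tokens):
        # Each call re-tokenizes the whole string: simple for a demo, but not
        # token-exact and quadratic in sequence length.
        next_token = blend_generate(prompt + response, wa, wb)
        # blend_generate returns a decoded string, so compare against the
        # eos token string rather than the token id.
        if next_token == tokenizer.eos_token:
            break
        response += next_token
    return response
```

When wa + wb = 1 the blend is a convex combination, so the blended logits stay on roughly the same scale as either model's own logits before the softmax; unnormalized weights also rescale the logits and therefore sharpen or flatten the sampling distribution.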