Snapshot
Browse files
- expand.py +26 -16
- expand_llm.py +1 -1
- expand_test.py +29 -28
- run.py +3 -2
expand.py
CHANGED
|
@@ -1,26 +1,33 @@
|
|
| 1 |
from collections import defaultdict
|
| 2 |
-
from dataclasses import dataclass
|
| 3 |
-
from typing import Protocol
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 4 |
|
| 5 |
@dataclass
|
| 6 |
class Series:
|
| 7 |
id: int
|
| 8 |
tokens: list[int]
|
| 9 |
budget: float
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 10 |
|
| 11 |
@dataclass
|
| 12 |
class Batch:
|
| 13 |
items: list[Series]
|
| 14 |
|
| 15 |
-
@dataclass
|
| 16 |
-
class ExpansionOne:
|
| 17 |
-
token: int
|
| 18 |
-
cost: float
|
| 19 |
-
|
| 20 |
@dataclass
|
| 21 |
class ExpansionOneResult:
|
| 22 |
series: Series
|
| 23 |
-
expansions: list[
|
| 24 |
|
| 25 |
@dataclass
|
| 26 |
class ExpansionOneResultBatch:
|
|
@@ -33,7 +40,7 @@ class ExpanderOneBatch(Protocol):
|
|
| 33 |
@dataclass
|
| 34 |
class ExpansionResult:
|
| 35 |
series: Series
|
| 36 |
-
expansions: list[list[
|
| 37 |
|
| 38 |
@dataclass
|
| 39 |
class ExpansionResultBatch:
|
|
@@ -42,7 +49,12 @@ class ExpansionResultBatch:
|
|
| 42 |
def compute_new_series(result: ExpansionOneResult) -> list[Series]:
|
| 43 |
results = []
|
| 44 |
for expansion in result.expansions:
|
| 45 |
-
results.append(Series(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 46 |
return results
|
| 47 |
|
| 48 |
def compute_expansions(original_series: list[Series], expanded_series: list[Series]) -> ExpansionResultBatch:
|
|
@@ -51,16 +63,14 @@ def compute_expansions(original_series: list[Series], expanded_series: list[Seri
|
|
| 51 |
# group original series by id
|
| 52 |
original_series_by_id = {s.id: s for s in original_series}
|
| 53 |
# group expanded series by id
|
| 54 |
-
expanded_series_by_id: dict[int, list[list[
|
| 55 |
for s in expanded_series:
|
| 56 |
-
|
|
|
|
| 57 |
results = []
|
| 58 |
for id, s in original_series_by_id.items():
|
| 59 |
expansions = expanded_series_by_id[id]
|
| 60 |
-
|
| 61 |
-
l = len(s.tokens)
|
| 62 |
-
trimmed_expansions = [e[l:] for e in expansions if len(e) > l]
|
| 63 |
-
expansion_result = ExpansionResult(series=s, expansions=trimmed_expansions)
|
| 64 |
results.append(expansion_result)
|
| 65 |
return ExpansionResultBatch(items=results)
|
| 66 |
|
|
|
|
| 1 |
from collections import defaultdict
|
| 2 |
+
from dataclasses import dataclass, field
|
| 3 |
+
from typing import Protocol, Self
|
| 4 |
+
|
| 5 |
+
@dataclass
class Expansion:
    """One expansion step: a token together with its cost."""

    # Token id contributed by this step.
    token: int
    # Value that Series.get_remaining_budget adds to the series budget.
    cost: float


@dataclass
class Series:
    """A token sequence identified by ``id`` plus the expansions applied to it.

    ``tokens`` and ``budget`` are the starting values; ``expansions`` lists the
    steps appended so far and defaults to a fresh empty list per instance.
    """

    id: int
    tokens: list[int]
    budget: float
    expansions: list[Expansion] = field(default_factory=list)

    def get_all_tokens(self) -> list[int]:
        """Return the starting tokens followed by each expansion's token."""
        appended = [step.token for step in self.expansions]
        return self.tokens + appended

    def get_remaining_budget(self) -> float:
        """Return ``budget`` plus the summed cost of every expansion."""
        total_cost = sum(step.cost for step in self.expansions)
        return self.budget + total_cost
|
| 22 |
|
| 23 |
@dataclass
class Batch:
    """Container for a list of ``Series`` objects."""

    # Member series of this batch.
    items: list[Series]
|
| 26 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 27 |
@dataclass
class ExpansionOneResult:
    """A series paired with the candidate expansions found for it."""

    # The series the candidates belong to.
    series: Series
    # Candidate expansions for that series.
    expansions: list[Expansion]
|
| 31 |
|
| 32 |
@dataclass
|
| 33 |
class ExpansionOneResultBatch:
|
|
|
|
| 40 |
@dataclass
class ExpansionResult:
    """A series paired with the expansion lists collected for it."""

    # The series the expansion lists belong to.
    series: Series
    # One inner list of expansions per collected entry.
    expansions: list[list[Expansion]]
|
| 44 |
|
| 45 |
@dataclass
|
| 46 |
class ExpansionResultBatch:
|
|
|
|
| 49 |
def compute_new_series(result: ExpansionOneResult) -> list[Series]:
    """Build one new ``Series`` per candidate expansion in ``result``.

    Each returned series keeps the id, tokens and budget of ``result.series``
    and carries that series' expansions followed by a single candidate.
    """
    parent = result.series
    return [
        Series(
            id=parent.id,
            tokens=parent.tokens,  # the parent's list object itself, not a copy
            expansions=parent.expansions + [candidate],
            budget=parent.budget,
        )
        for candidate in result.expansions
    ]
|
| 59 |
|
| 60 |
def compute_expansions(original_series: list[Series], expanded_series: list[Series]) -> ExpansionResultBatch:
|
|
|
|
| 63 |
# group original series by id
|
| 64 |
original_series_by_id = {s.id: s for s in original_series}
|
| 65 |
# group expanded series by id
|
| 66 |
+
expanded_series_by_id: dict[int, list[list[Expansion]]] = defaultdict(list)
|
| 67 |
for s in expanded_series:
|
| 68 |
+
if len(s.expansions) != 0:
|
| 69 |
+
expanded_series_by_id[s.id].append(s.expansions)
|
| 70 |
results = []
|
| 71 |
for id, s in original_series_by_id.items():
|
| 72 |
expansions = expanded_series_by_id[id]
|
| 73 |
+
expansion_result = ExpansionResult(series=s, expansions=expansions)
|
|
|
|
|
|
|
|
|
|
| 74 |
results.append(expansion_result)
|
| 75 |
return ExpansionResultBatch(items=results)
|
| 76 |
|
expand_llm.py
CHANGED
|
@@ -15,6 +15,6 @@ class ExpanderOneBatchLLM:
|
|
| 15 |
next_tokens = find_next_tokens(self.model, inputs, self.tokenizer)
|
| 16 |
results = []
|
| 17 |
for s, next_tokens in zip(batch.items, next_tokens):
|
| 18 |
-
expansions = [
|
| 19 |
results.append(ExpansionOneResult(series=s, expansions=expansions))
|
| 20 |
return ExpansionOneResultBatch(items=results)
|
|
|
|
| 15 |
next_tokens = find_next_tokens(self.model, inputs, self.tokenizer)
|
| 16 |
results = []
|
| 17 |
for s, next_tokens in zip(batch.items, next_tokens):
|
| 18 |
+
expansions = [Expansion(token=token, cost=logprob) for token, logprob in next_tokens if logprob + s.get_remaining_budget() >= 0]
|
| 19 |
results.append(ExpansionOneResult(series=s, expansions=expansions))
|
| 20 |
return ExpansionOneResultBatch(items=results)
|
expand_test.py
CHANGED
|
@@ -1,5 +1,5 @@
|
|
| 1 |
from dataclasses import dataclass
|
| 2 |
-
from expand import Series, ExpanderOneBatch,
|
| 3 |
|
| 4 |
possible_sequences = [
|
| 5 |
[1, 21, 31, 41],
|
|
@@ -9,11 +9,12 @@ possible_sequences = [
|
|
| 9 |
[1, 22, 34, 41],
|
| 10 |
]
|
| 11 |
|
| 12 |
-
def expand_series(series: Series) -> list[
|
| 13 |
-
|
| 14 |
-
|
| 15 |
-
|
| 16 |
-
|
|
|
|
| 17 |
|
| 18 |
class HardcodedExpanderOneBatch(ExpanderOneBatch):
|
| 19 |
def expand(self, batch: Batch) -> ExpansionOneResultBatch:
|
|
@@ -38,8 +39,8 @@ def test_expander_budget_one():
|
|
| 38 |
expanded = expander.expand(Batch(items=[s]))
|
| 39 |
expected = ExpansionOneResultBatch(
|
| 40 |
items=[ExpansionOneResult(series=s, expansions=[
|
| 41 |
-
|
| 42 |
-
|
| 43 |
])]
|
| 44 |
)
|
| 45 |
assert expected == expanded
|
|
@@ -49,8 +50,8 @@ def test_expander_budget_two():
|
|
| 49 |
expanded = expander.expand(Batch(items=[s]))
|
| 50 |
expected = ExpansionOneResultBatch(
|
| 51 |
items=[ExpansionOneResult(series=s, expansions=[
|
| 52 |
-
|
| 53 |
-
|
| 54 |
])]
|
| 55 |
)
|
| 56 |
assert expected == expanded
|
|
@@ -68,8 +69,8 @@ def test_expander_budget_one_two_tokens():
|
|
| 68 |
expanded = expander.expand(Batch(items=[s]))
|
| 69 |
expected = ExpansionOneResultBatch(
|
| 70 |
items=[ExpansionOneResult(series=s, expansions=[
|
| 71 |
-
|
| 72 |
-
|
| 73 |
])]
|
| 74 |
)
|
| 75 |
assert expected == expanded
|
|
@@ -81,12 +82,12 @@ def test_expander_budget_one_two_tokens_two_series():
|
|
| 81 |
expected = ExpansionOneResultBatch(
|
| 82 |
items=[
|
| 83 |
ExpansionOneResult(series=s1, expansions=[
|
| 84 |
-
|
| 85 |
-
|
| 86 |
]),
|
| 87 |
ExpansionOneResult(series=s2, expansions=[
|
| 88 |
-
|
| 89 |
-
|
| 90 |
])
|
| 91 |
]
|
| 92 |
)
|
|
@@ -102,15 +103,15 @@ def test_expand_01():
|
|
| 102 |
ExpansionResult(
|
| 103 |
series=Series(id=0, tokens=[1, 21], budget=1.0),
|
| 104 |
expansions=[
|
| 105 |
-
[31],
|
| 106 |
-
[32],
|
| 107 |
]
|
| 108 |
),
|
| 109 |
ExpansionResult(
|
| 110 |
series=Series(id=1, tokens=[1, 22], budget=1.0),
|
| 111 |
expansions=[
|
| 112 |
-
[33],
|
| 113 |
-
[34],
|
| 114 |
]
|
| 115 |
),
|
| 116 |
])
|
|
@@ -125,16 +126,16 @@ def test_expand_02():
|
|
| 125 |
ExpansionResult(
|
| 126 |
series=Series(id=0, tokens=[1, 21], budget=2.0),
|
| 127 |
expansions=[
|
| 128 |
-
[31, 41],
|
| 129 |
-
[31, 42],
|
| 130 |
-
[32, 41],
|
| 131 |
]
|
| 132 |
),
|
| 133 |
ExpansionResult(
|
| 134 |
series=Series(id=1, tokens=[1, 22], budget=1.0),
|
| 135 |
expansions=[
|
| 136 |
-
[33],
|
| 137 |
-
[34],
|
| 138 |
]
|
| 139 |
),
|
| 140 |
])
|
|
@@ -149,9 +150,9 @@ def test_expand_03():
|
|
| 149 |
ExpansionResult(
|
| 150 |
series=Series(id=0, tokens=[1, 21], budget=3.0),
|
| 151 |
expansions=[
|
| 152 |
-
[31, 41],
|
| 153 |
-
[31, 42],
|
| 154 |
-
[32, 41, 51],
|
| 155 |
]
|
| 156 |
),
|
| 157 |
ExpansionResult(
|
|
|
|
| 1 |
from dataclasses import dataclass
|
| 2 |
+
from expand import Series, ExpanderOneBatch, Expansion, Batch, ExpansionOneResult, ExpansionOneResultBatch, ExpansionResult, ExpansionResultBatch, expand
|
| 3 |
|
| 4 |
possible_sequences = [
|
| 5 |
[1, 21, 31, 41],
|
|
|
|
| 9 |
[1, 22, 34, 41],
|
| 10 |
]
|
| 11 |
|
| 12 |
+
def expand_series(series: Series) -> list[Expansion]:
    """Return the affordable next-token expansions for ``series``.

    A next token is whatever follows the series' full token list in one of
    ``possible_sequences``; duplicates are dropped, keeping first-seen order.
    Every candidate gets a cost of -1.0 and is kept only when that cost plus
    the series' remaining budget is non-negative.
    """
    prefix = series.get_all_tokens()
    depth = len(prefix)

    # Collect the token right after the prefix from every matching sequence.
    next_tokens = []
    for sequence in possible_sequences:
        if sequence[:depth] == prefix and len(sequence) > depth:
            next_tokens.append(sequence[depth])

    remaining = series.get_remaining_budget()
    affordable = []
    # dict.fromkeys de-duplicates while preserving insertion order.
    for token in dict.fromkeys(next_tokens):
        candidate = Expansion(token=token, cost=-1.0)
        if candidate.cost + remaining >= 0:
            affordable.append(candidate)
    return affordable
|
| 18 |
|
| 19 |
class HardcodedExpanderOneBatch(ExpanderOneBatch):
|
| 20 |
def expand(self, batch: Batch) -> ExpansionOneResultBatch:
|
|
|
|
| 39 |
expanded = expander.expand(Batch(items=[s]))
|
| 40 |
expected = ExpansionOneResultBatch(
|
| 41 |
items=[ExpansionOneResult(series=s, expansions=[
|
| 42 |
+
Expansion(token=21, cost=-1.0),
|
| 43 |
+
Expansion(token=22, cost=-1.0),
|
| 44 |
])]
|
| 45 |
)
|
| 46 |
assert expected == expanded
|
|
|
|
| 50 |
expanded = expander.expand(Batch(items=[s]))
|
| 51 |
expected = ExpansionOneResultBatch(
|
| 52 |
items=[ExpansionOneResult(series=s, expansions=[
|
| 53 |
+
Expansion(token=21, cost=-1.0),
|
| 54 |
+
Expansion(token=22, cost=-1.0),
|
| 55 |
])]
|
| 56 |
)
|
| 57 |
assert expected == expanded
|
|
|
|
| 69 |
expanded = expander.expand(Batch(items=[s]))
|
| 70 |
expected = ExpansionOneResultBatch(
|
| 71 |
items=[ExpansionOneResult(series=s, expansions=[
|
| 72 |
+
Expansion(token=33, cost=-1.0),
|
| 73 |
+
Expansion(token=34, cost=-1.0),
|
| 74 |
])]
|
| 75 |
)
|
| 76 |
assert expected == expanded
|
|
|
|
| 82 |
expected = ExpansionOneResultBatch(
|
| 83 |
items=[
|
| 84 |
ExpansionOneResult(series=s1, expansions=[
|
| 85 |
+
Expansion(token=41, cost=-1.0),
|
| 86 |
+
Expansion(token=42, cost=-1.0),
|
| 87 |
]),
|
| 88 |
ExpansionOneResult(series=s2, expansions=[
|
| 89 |
+
Expansion(token=33, cost=-1.0),
|
| 90 |
+
Expansion(token=34, cost=-1.0),
|
| 91 |
])
|
| 92 |
]
|
| 93 |
)
|
|
|
|
| 103 |
ExpansionResult(
|
| 104 |
series=Series(id=0, tokens=[1, 21], budget=1.0),
|
| 105 |
expansions=[
|
| 106 |
+
[Expansion(token=31, cost=-1.0)],
|
| 107 |
+
[Expansion(token=32, cost=-1.0)],
|
| 108 |
]
|
| 109 |
),
|
| 110 |
ExpansionResult(
|
| 111 |
series=Series(id=1, tokens=[1, 22], budget=1.0),
|
| 112 |
expansions=[
|
| 113 |
+
[Expansion(token=33, cost=-1.0)],
|
| 114 |
+
[Expansion(token=34, cost=-1.0)],
|
| 115 |
]
|
| 116 |
),
|
| 117 |
])
|
|
|
|
| 126 |
ExpansionResult(
|
| 127 |
series=Series(id=0, tokens=[1, 21], budget=2.0),
|
| 128 |
expansions=[
|
| 129 |
+
[Expansion(token=31, cost=-1.0), Expansion(token=41, cost=-1.0)],
|
| 130 |
+
[Expansion(token=31, cost=-1.0), Expansion(token=42, cost=-1.0)],
|
| 131 |
+
[Expansion(token=32, cost=-1.0), Expansion(token=41, cost=-1.0)],
|
| 132 |
]
|
| 133 |
),
|
| 134 |
ExpansionResult(
|
| 135 |
series=Series(id=1, tokens=[1, 22], budget=1.0),
|
| 136 |
expansions=[
|
| 137 |
+
[Expansion(token=33, cost=-1.0)],
|
| 138 |
+
[Expansion(token=34, cost=-1.0)],
|
| 139 |
]
|
| 140 |
),
|
| 141 |
])
|
|
|
|
| 150 |
ExpansionResult(
|
| 151 |
series=Series(id=0, tokens=[1, 21], budget=3.0),
|
| 152 |
expansions=[
|
| 153 |
+
[Expansion(token=31, cost=-1.0), Expansion(token=41, cost=-1.0)],
|
| 154 |
+
[Expansion(token=31, cost=-1.0), Expansion(token=42, cost=-1.0)],
|
| 155 |
+
[Expansion(token=32, cost=-1.0), Expansion(token=41, cost=-1.0), Expansion(token=51, cost=-1.0)],
|
| 156 |
]
|
| 157 |
),
|
| 158 |
ExpansionResult(
|
run.py
CHANGED
|
@@ -29,7 +29,7 @@ expander = ExpanderOneBatchLLM(model, tokenizer)
|
|
| 29 |
#%%
|
| 30 |
series = []
|
| 31 |
for i, x in enumerate(contexts):
|
| 32 |
-
series.append(Series(id=i, tokens=x, budget=
|
| 33 |
|
| 34 |
#%%
|
| 35 |
batch = Batch(items=series)
|
|
@@ -42,7 +42,8 @@ def print_expansions(expansions: ExpansionResultBatch):
|
|
| 42 |
for result in expansions.items:
|
| 43 |
for expansion in result.expansions:
|
| 44 |
# convert tokens to string
|
| 45 |
-
|
|
|
|
| 46 |
print(f"{result.series.id}: {expansion} {s}")
|
| 47 |
|
| 48 |
print_expansions(expanded)
|
|
|
|
| 29 |
#%%
|
| 30 |
series = []
|
| 31 |
for i, x in enumerate(contexts):
|
| 32 |
+
series.append(Series(id=i, tokens=x, budget=7.0))
|
| 33 |
|
| 34 |
#%%
|
| 35 |
batch = Batch(items=series)
|
|
|
|
| 42 |
for result in expansions.items:
|
| 43 |
for expansion in result.expansions:
|
| 44 |
# convert tokens to string
|
| 45 |
+
tokens = [e.token for e in expansion]
|
| 46 |
+
s = tokenizer.decode(tokens)
|
| 47 |
print(f"{result.series.id}: {expansion} {s}")
|
| 48 |
|
| 49 |
print_expansions(expanded)
|