Spaces:
No application file
No application file
# utils.py
def _parse_percent(raw):
    """Return *raw* parsed as a non-negative int, or None if it is not one."""
    try:
        value = int(raw)
    except ValueError:
        return None
    return value if value >= 0 else None


def _apportion_to_100(weights: dict) -> dict:
    """
    Rescale *weights* proportionally to integers that sum to exactly 100.

    Uses largest-remainder apportionment: each share is floored, then the
    points still missing are handed out, one each, to the keys with the
    largest fractional remainders (ties go to the earlier key). Arithmetic
    is done with ``Fraction`` so float weights cannot introduce rounding
    drift. If the weights sum to zero or less there is no meaningful
    proportion, so 100 is split as evenly as possible instead.

    *weights* must be non-empty; it is not modified.
    """
    from fractions import Fraction

    exact = {key: Fraction(w) for key, w in weights.items()}
    total = sum(exact.values())
    if total <= 0:
        # No usable proportions: treat every key as equally weighted.
        exact = dict.fromkeys(exact, Fraction(1))
        total = len(exact)
    shares = {key: divmod(w * 100, total) for key, w in exact.items()}
    result = {key: int(quotient) for key, (quotient, _) in shares.items()}
    leftover = 100 - sum(result.values())  # always in [0, len(weights))
    # sorted() is stable (also with reverse=True), so ties keep key order.
    by_remainder = sorted(shares, key=lambda key: shares[key][1], reverse=True)
    for key in by_remainder[:leftover]:
        result[key] += 1
    return result


def prompt_weights(default_weights: dict, *, input_fn=None) -> dict:
    """
    Interactively collect weights from the user (in %) and make them sum to 100.

    For each key of *default_weights* the user is asked for a whole-number
    percentage. Pressing Enter (or end-of-input) keeps the default, and
    anything that is not a non-negative integer is rejected in favour of
    the default. If the collected weights do not sum to 100 they are
    rescaled proportionally to integers summing to exactly 100.

    Args:
        default_weights: Mapping of weight name to default percentage.
            Not modified.
        input_fn: Optional replacement for the builtin ``input`` (called
            with the prompt string, must return the reply), e.g. for
            scripted use or tests.

    Returns:
        A new dict with the same keys. Values are returned as entered or
        defaulted when they already sum to 100; otherwise they are rescaled
        integers. An empty *default_weights* yields an empty dict.
    """
    # Resolve ``input`` at call time so later patches of builtins still apply.
    reader = input if input_fn is None else input_fn
    print("\n⚖️ Set Weights (press Enter to keep default values)")
    new_w = {}
    for key, val in default_weights.items():
        label = key.replace('_WEIGHT', '').title()
        try:
            raw = reader(f" {label} (%), default {val}: ").strip()
        except EOFError:
            # stdin closed (e.g. non-interactive run): behave like Enter.
            raw = ""
        if raw == "":
            new_w[key] = val
            continue
        parsed = _parse_percent(raw)
        if parsed is None:
            print(f" Invalid input for {key}, keeping default {val}.")
            new_w[key] = val
        else:
            new_w[key] = parsed
    total = sum(new_w.values())
    # An empty mapping cannot sum to 100; return it as-is rather than crash.
    if new_w and total != 100:
        if total > 0:
            print(f" Weights sum to {total}, normalizing to 100 proportionally.")
        else:
            print(f" Weights sum to {total}, splitting 100 evenly.")
        new_w = _apportion_to_100(new_w)
    print(" Final Weights:", new_w)
    return new_w