Spaces:
Runtime error
Runtime error
| import torch | |
| import random | |
| import numpy as np | |
| from collections import defaultdict | |
| from torch.distributions import Categorical | |
| from torch.nn.utils.rnn import pad_sequence | |
| from scrl.model import labels_to_summary | |
| from nltk import word_tokenize | |
| from pprint import pprint | |
| def sample_from_policy( | |
| input_ids, | |
| probs, | |
| device="cuda", | |
| force_diff=True, | |
| diff_trials=1000, | |
| ): | |
| m = Categorical(probs) | |
| argmax_labels = torch.argmax(probs, dim=2) | |
| sample_labels = m.sample() | |
| if force_diff: | |
| for _ in range(diff_trials): | |
| if (argmax_labels == sample_labels).all(): | |
| sample_labels = m.sample() | |
| else: | |
| break | |
| sample_probs = m.log_prob(sample_labels) | |
| return sample_probs, sample_labels | |
def best_of_k_samples(
    args,
    manager,
    tokenizer,
    reward_generator,
    input_ids,
    batch,
    probs,
    k_samples=50,
    return_all=False
):
    """Draw ``k_samples`` label samples per input and keep the best-rewarded one.

    For each of ``k_samples`` rounds, samples labels from the policy
    (``sample_from_policy``), decodes them into summaries and scores them
    with ``reward_generator``. Then, independently for every batch item,
    selects the sample with the highest reward (first one on ties).

    Args:
        args: namespace providing ``args.device``.
        manager: accepted for interface compatibility; unused here.
        tokenizer: passed through to ``labels_to_summary``.
        reward_generator: callable ``(documents, summaries) -> (rewards, details)``
            where ``rewards`` is indexable per batch item and ``details`` is a
            dict of per-item lists — assumed to use the same keys on every
            call (TODO confirm against the reward implementation).
        input_ids: token ids for the batch, passed to sampling/decoding.
        batch: dict with at least a ``"document"`` entry.
        probs: policy probabilities, shape (batch, seq_len, n_labels).
        k_samples: number of sampling rounds.
        return_all: accepted for interface compatibility; all per-round data
            is always returned via ``sample_data`` regardless of this flag.

    Returns:
        Tuple of (best log-probs, best summaries, best rewards, best details,
        best labels, sample_data) where ``sample_data`` holds every round's
        raw batches.
    """
    batch_size = probs.size(0)
    prob_batches = []
    summary_batches = []
    reward_batches = []
    detail_batches = []
    label_batches = []
    for _ in range(k_samples):
        sample_probs, sample_labels = sample_from_policy(
            input_ids,
            probs,
            device=args.device
        )
        sample_summaries = labels_to_summary(
            input_ids, sample_labels, tokenizer
        )
        sample_rewards, sample_details = reward_generator(
            batch["document"], sample_summaries
        )
        prob_batches.append(sample_probs)
        summary_batches.append(sample_summaries)
        reward_batches.append(sample_rewards)
        detail_batches.append(sample_details)
        label_batches.append(sample_labels)

    # Per batch item, pick the round with the highest reward. max() returns
    # the first maximal index, matching the previous stable-sort behavior.
    best_indices = [
        max(range(k_samples), key=lambda j: reward_batches[j][i])
        for i in range(batch_size)
    ]

    sample_probs = torch.stack([prob_batches[j][i] for i, j in enumerate(best_indices)])
    sample_summaries = [summary_batches[j][i] for i, j in enumerate(best_indices)]
    sample_rewards = [reward_batches[j][i] for i, j in enumerate(best_indices)]
    sample_labels = torch.stack([label_batches[j][i] for i, j in enumerate(best_indices)])

    # Hoisted out of the loop: the key set is loop-invariant (always read
    # from round 0, assuming every round reports the same metrics).
    detail_keys = sorted(detail_batches[0].keys())
    sample_details = []
    for i, j in enumerate(best_indices):
        details = defaultdict(list)
        for k in detail_keys:
            details[k].append(detail_batches[j][k][i])
        sample_details.append(details)

    sample_data = {
        "probs": prob_batches,
        "rewards": reward_batches,
        "summaries": summary_batches,
        "details": detail_batches,
        "labels": label_batches,
    }
    return sample_probs, sample_summaries, sample_rewards, sample_details, sample_labels, sample_data