import csv
import json
import pickle
import random
import re
import string

import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F

import textattack
import textattack.shared.attacked_text as atk
import transformers

# Construct our four components for `Attack`
from textattack.constraints.pre_transformation import (
    RepeatModification,
    StopwordModification,
)
from textattack.constraints.semantics import WordEmbeddingDistance
from textattack.search_methods import GreedyWordSwapWIR
from textattack.transformations import WordSwapEmbedding
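
# This script attempts to "correct" adversarially perturbed IMDB reviews:
# it ranks word importance with gradients, swaps the most important words for
# more frequent synonyms, and uses the resulting prediction flips to decide
# which class each text should be assigned.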

class InvertedText:
    def __init__(
        self,
        swapped_indexes,
        score,
        attacked_text,
        new_class,
    ):
        self.attacked_text = attacked_text
        self.swapped_indexes = swapped_indexes  # dict mapping swapped indexes to their synonym
        self.score = score  # probability of the newly predicted class
        self.new_class = new_class  # class predicted after inversion

    def __repr__(self):
        return (
            f"InvertedText:\n attacked_text='{self.attacked_text}', "
            f"\n swapped_indexes={self.swapped_indexes},\n score={self.score}"
        )

def count_matching_classes(original, corrected, perturbed_texts=None):
    """Count positions where the corrected class matches the original one."""
    if len(original) != len(corrected):
        raise ValueError("Arrays must have the same length")

    hard_samples = []
    easy_samples = []
    matching_count = 0

    for i in range(len(corrected)):
        if original[i] == corrected[i]:
            matching_count += 1
            if perturbed_texts is not None:
                easy_samples.append(perturbed_texts[i])
        elif perturbed_texts is not None:
            hard_samples.append(perturbed_texts[i])

    return matching_count, hard_samples, easy_samples
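
# Example with hypothetical inputs:
#   count_matching_classes([1, 0], [1, 1], ["text_a", "text_b"])
# returns (1, ["text_b"], ["text_a"]): one matching class, "text_b" is a hard
# sample (class not recovered) and "text_a" an easy one.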

class Flow_Corrector:
    def __init__(
        self,
        attack,
        word_rank_file="en_full_ranked.json",
        word_freq_file="en_full_freq.json",
        wir_threshold=0.3,
    ):
        self.attack = attack
        self.attack.cuda_()
        self.wir_threshold = wir_threshold
        with open(word_rank_file, "r") as f:
            self.word_ranked_frequence = json.load(f)
        with open(word_freq_file, "r") as f:
            self.word_frequence = json.load(f)
        self.victim_model = attack.goal_function.model
    def wir_gradient(
        self,
        attack,
        victim_model,
        detected_text,
    ):
        """Order word indices by the L1 norm of their gradient (word importance ranking)."""
        _, indices_to_order = attack.get_indices_to_order(detected_text)

        index_scores = np.zeros(len(indices_to_order))
        grad_output = victim_model.get_grad(detected_text.tokenizer_input)
        gradient = grad_output["gradient"]
        word2token_mapping = detected_text.align_with_model_tokens(victim_model)
        for i, index in enumerate(indices_to_order):
            matched_tokens = word2token_mapping[index]
            if not matched_tokens:
                index_scores[i] = 0.0
            else:
                agg_grad = np.mean(gradient[matched_tokens], axis=0)
                index_scores[i] = np.linalg.norm(agg_grad, ord=1)
        index_order = np.array(indices_to_order)[(-index_scores).argsort()]
        return index_order
    def get_syn_freq_dict(
        self,
        index_order,
        detected_text,
    ):
        """Map each important index to synonyms that are more frequent than the original word."""
        most_frequent_syn_dict = {}
        no_syn = []
        freq_threshold = len(self.word_ranked_frequence) / 10

        for idx in index_order:
            # get the synonyms proposed by the attack transformation at this index
            try:
                synonyms = [
                    attacked_text.words[idx]
                    for attacked_text in self.attack.get_transformations(
                        detected_text, detected_text, indices_to_modify=[idx]
                    )
                ]
                # keep synonyms that exist in the dataset, with their frequency rank
                ranked_synonyms = {
                    syn: self.word_ranked_frequence[syn]
                    for syn in synonyms
                    if syn in self.word_ranked_frequence
                    and self.word_ranked_frequence[syn] < freq_threshold
                    and self.word_ranked_frequence[detected_text.words[idx]]
                    > self.word_ranked_frequence[syn]
                }
                # keep the most frequent synonyms for this index
                if ranked_synonyms:
                    most_frequent_syn_dict[idx] = list(ranked_synonyms.keys())
            except KeyError:
                # the original word has no frequency rank in the dataset
                no_syn.append(idx)

        return most_frequent_syn_dict
    def build_candidates(
        self, detected_text, most_frequent_syn_dict: dict, max_attempt: int
    ):
        """Build up to `max_attempt` candidate texts by randomly swapping in synonyms."""
        candidates = {}
        for _ in range(max_attempt):
            syn_dict = {}
            current_text = detected_text
            for index in most_frequent_syn_dict.keys():
                syn = random.choice(most_frequent_syn_dict[index])
                syn_dict[index] = syn
                current_text = current_text.replace_word_at_index(index, syn)
            candidates[current_text] = syn_dict
        return candidates
    def find_dominant_class(self, inverted_texts):
        class_counts = {}  # count of each new class
        for text in inverted_texts:
            new_class = text.new_class
            class_counts[new_class] = class_counts.get(new_class, 0) + 1
        # Find the most dominant class
        most_dominant_class = max(class_counts, key=class_counts.get)
        return most_dominant_class
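
    # Example: for inverted texts whose new_class values are [2, 2, 1],
    # find_dominant_class returns 2.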
    def correct(self, detected_texts):
        corrected_classes = []
        for detected_text in detected_texts:
            # convert to an AttackedText
            detected_text = atk.AttackedText(detected_text)

            # keep only the most important indexes (top wir_threshold fraction)
            index_order = self.wir_gradient(
                self.attack, self.victim_model, detected_text
            )
            index_order = index_order[: int(len(index_order) * self.wir_threshold)]

            # get synonyms that satisfy the frequency conditions
            most_frequent_syn_dict = self.get_syn_freq_dict(index_order, detected_text)

            # generate M candidates
            candidates = self.build_candidates(
                detected_text, most_frequent_syn_dict, max_attempt=100
            )

            original_probs = F.softmax(self.victim_model([detected_text.text]), dim=1)
            original_class = torch.argmax(original_probs).item()
            original_golden_prob = float(original_probs[0][original_class])

            nbr_inverted = 0
            inverted_texts = []  # list of InvertedText objects
            bad, impr = 0, 0
            dict_deltas = {}

            batch_inputs = [candidate.text for candidate in candidates.keys()]
            batch_outputs = self.victim_model(batch_inputs)
            probabilities = F.softmax(batch_outputs, dim=1)

            for i, (candidate, syn_dict) in enumerate(candidates.items()):
                corrected_class = torch.argmax(probabilities[i]).item()
                new_golden_probability = float(probabilities[i][corrected_class])
                if corrected_class != original_class:
                    nbr_inverted += 1
                    inverted_texts.append(
                        InvertedText(
                            syn_dict, new_golden_probability, candidate, corrected_class
                        )
                    )
                else:
                    delta = new_golden_probability - original_golden_prob
                    if delta <= 0:
                        bad += 1
                    else:
                        impr += 1
                    dict_deltas[candidate] = delta

            if len(original_probs[0]) > 2 and len(inverted_texts) >= len(candidates) / (
                len(original_probs[0])
            ):
                # multiclass case: enough candidates flipped, pick the dominant new class
                dominant_class = self.find_dominant_class(inverted_texts)
            elif len(inverted_texts) >= len(candidates) / 2:
                # binary case: at least half the candidates flipped, take the inverted class
                dominant_class = self.find_dominant_class(inverted_texts)
            else:
                dominant_class = original_class

            corrected_classes.append(dominant_class)

        return corrected_classes
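
# Minimal usage sketch for Flow_Corrector (hypothetical inputs; the `attack`
# object and the frequency JSON files are constructed/loaded further below):
#   corrector = Flow_Corrector(attack, wir_threshold=0.2)
#   predicted_classes = corrector.correct(["a possibly perturbed review ..."])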

def remove_brackets(text):
    """Strip the [[ ]] markers that TextAttack places around modified words."""
    text = text.replace("[[", "")
    text = text.replace("]]", "")
    return text


def clean_text(text):
    """Replace every punctuation character with a space."""
    pattern = "[" + re.escape(string.punctuation) + "]"
    cleaned_text = re.sub(pattern, " ", text)
    return cleaned_text
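
# Example: remove_brackets("a [[great]] movie") -> "a great movie";
# clean_text("great, movie!") -> "great  movie " (punctuation becomes spaces).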

# Load model, tokenizer, and model_wrapper
model = transformers.AutoModelForSequenceClassification.from_pretrained(
    "textattack/bert-base-uncased-imdb"
)
tokenizer = transformers.AutoTokenizer.from_pretrained(
    "textattack/bert-base-uncased-imdb"
)
model_wrapper = textattack.models.wrappers.HuggingFaceModelWrapper(model, tokenizer)

goal_function = textattack.goal_functions.UntargetedClassification(model_wrapper)
constraints = [
    RepeatModification(),
    StopwordModification(),
    WordEmbeddingDistance(min_cos_sim=0.9),
]
transformation = WordSwapEmbedding(max_candidates=50)
search_method = GreedyWordSwapWIR(wir_method="gradient")

# Construct the actual attack
attack = textattack.Attack(goal_function, constraints, transformation, search_method)
attack.cuda_()
| results = pd.read_csv("IMDB_results.csv") | |
| perturbed_texts = [ | |
| results["perturbed_text"][i] | |
| for i in range(len(results)) | |
| if results["result_type"][i] == "Successful" | |
| ] | |
| original_texts = [ | |
| results["original_text"][i] | |
| for i in range(len(results)) | |
| if results["result_type"][i] == "Successful" | |
| ] | |
| perturbed_texts = [remove_brackets(text) for text in perturbed_texts] | |
| original_texts = [remove_brackets(text) for text in original_texts] | |
| perturbed_texts = [clean_text(text) for text in perturbed_texts] | |
| original_texts = [clean_text(text) for text in original_texts] | |
| victim_model = attack.goal_function.model | |
| print("Getting corrected classes") | |
| print("This may take a while ...") | |
| # we can use directly resultds in csv file | |
| original_classes = [ | |
| torch.argmax(F.softmax(victim_model(original_text), dim=1)).item() | |
| for original_text in original_texts | |
| ] | |

batch_size = 1000
num_batches = (len(perturbed_texts) + batch_size - 1) // batch_size  # ceil division

batched_perturbed_texts = []
batched_original_texts = []
batched_original_classes = []
for i in range(num_batches):
    start = i * batch_size
    end = min(start + batch_size, len(perturbed_texts))
    batched_perturbed_texts.append(perturbed_texts[start:end])
    batched_original_texts.append(original_texts[start:end])
    batched_original_classes.append(original_classes[start:end])

print(batched_original_classes)

hard_samples_list = []
easy_samples_list = []

# Open a CSV file for writing
csv_filename = "flow_correction_results_imdb.csv"
with open(csv_filename, "w", newline="") as csvfile:
    fieldnames = ["wir_threshold", "batch_num", "match_perturbed", "match_original"]
    writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
    # Write the header row
    writer.writeheader()

    # Iterate over the batched lists
    batch_num = 0
    for perturbed, original, classes in zip(
        batched_perturbed_texts, batched_original_texts, batched_original_classes
    ):
        batch_num += 1
        print(f"Processing batch number: {batch_num}")

        for i in range(2):
            wir_threshold = 0.1 * (i + 1)
            print(f"Setting word importance threshold to: {wir_threshold}")
            corrector = Flow_Corrector(
                attack,
                word_rank_file="en_full_ranked.json",
                word_freq_file="en_full_freq.json",
                wir_threshold=wir_threshold,
            )

            # Correct perturbed texts
            print("Correcting perturbed texts...")
            corrected_perturbed_classes = corrector.correct(perturbed)
            match_perturbed, hard_samples, easy_samples = count_matching_classes(
                classes, corrected_perturbed_classes, perturbed
            )
            hard_samples_list.extend(hard_samples)
            easy_samples_list.extend(easy_samples)
            print(f"Number of matching classes (perturbed): {match_perturbed}")

            # Correct original texts
            print("Correcting original texts...")
            corrected_original_classes = corrector.correct(original)
            match_original, _, _ = count_matching_classes(
                classes, corrected_original_classes, original
            )
            print(f"Number of matching classes (original): {match_original}")

            # Write results to the CSV file
            print("Writing results to CSV file...")
            writer.writerow(
                {
                    "wir_threshold": wir_threshold,
                    "batch_num": batch_num,
                    "match_perturbed": match_perturbed / len(perturbed),
                    "match_original": match_original / len(original),
                }
            )
            print("-" * 20)

print("Saving samples for further statistical analysis")
# Save hard_samples_list and easy_samples_list to files
with open("hard_samples.pkl", "wb") as f:
    pickle.dump(hard_samples_list, f)
with open("easy_samples.pkl", "wb") as f:
    pickle.dump(easy_samples_list, f)