Spaces:
Sleeping
Sleeping
| import json | |
| import re | |
| import time | |
| import uuid | |
| from pathlib import Path | |
| from transformers import AutoModel, AutoTokenizer | |
| from utils_demo import * | |
| from concrete.ml.common.serialization.loaders import load | |
| from concrete.ml.deployment import FHEModelClient, FHEModelServer | |
# Decrypted probability at or above which decrypt_output() replaces a token
# with a pseudonym.
TOLERANCE_PROBA: float = 0.77

# Filesystem layout, anchored at the directory containing this module.
CURRENT_DIR: Path = Path(__file__).parent
# Directory handed to FHEModelClient / FHEModelServer.
DEPLOYMENT_DIR: Path = CURRENT_DIR.joinpath("deployment")
# Keys and intermediate ciphertext/result files live here.
KEYS_DIR: Path = DEPLOYMENT_DIR.joinpath(".fhe_keys")
class FHEAnonymizer:
    """Pseudonymize sensitive words in a text using an FHE-compiled model.

    Each non-whitespace token of the input text is embedded with the
    ``obi/deid_roberta_i2b2`` model and then quantized and encrypted by the
    client. The server evaluates it on the ciphertext, and the client
    decrypts the result. Tokens whose decrypted probability reaches
    ``TOLERANCE_PROBA`` are replaced by a short UUID-based pseudonym, which
    is persisted in the UUID mapping file.

    The steps exchange their artefacts through files in ``KEYS_DIR``:
    the evaluation key, encrypted query, encrypted output, timings,
    reconstructed sentence and flagged words.
    """

    # Splits text into runs of word characters (also allowing . / - @ inside
    # a word) and runs of whitespace/punctuation. encrypt_query and
    # decrypt_output share this pattern so both produce the same token
    # sequence; decrypt_output relies on that to pair token i with
    # encrypted result i.
    _TOKEN_PATTERN = re.compile(r"(\b[\w\.\/\-@]+\b|[\s,.!?;:'\"-]+)")
    # Whitespace-only tokens are never sent to the model.
    _WHITESPACE_PATTERN = re.compile(r"^\s+$")

    def __init__(self):
        # Tokenizer and model used to embed each token before encryption.
        self.tokenizer = AutoTokenizer.from_pretrained("obi/deid_roberta_i2b2")
        self.embeddings_model = AutoModel.from_pretrained("obi/deid_roberta_i2b2")
        self.punctuation_list = PUNCTUATION_LIST
        # Persistent token -> pseudonym mapping, reused across queries.
        self.uuid_map = read_json(MAPPING_UUID_PATH)
        self.client = FHEModelClient(DEPLOYMENT_DIR, key_dir=KEYS_DIR)
        self.server = FHEModelServer(DEPLOYMENT_DIR)
        # Set by generate_key(). If it is still unset when run_server()
        # needs it, the key file written by generate_key() is reloaded.
        self.evaluation_key = None

    @classmethod
    def _tokenize(cls, text):
        """Return the tokens of ``text`` that are sent to the model, in order.

        Whitespace-only tokens are dropped. Punctuation runs such as
        ``", "`` are kept.
        """
        return [
            token
            for token in cls._TOKEN_PATTERN.findall(text)
            if not cls._WHITESPACE_PATTERN.match(token)
        ]

    def _pseudonymize(self, token):
        """Return the pseudonym for ``token``, creating and recording one if new.

        A new pseudonym is the first 8 characters of a random UUID4. It is
        stored in ``self.uuid_map`` so the same token maps to the same
        pseudonym later.
        """
        if token not in self.uuid_map:
            # Generate a UUID only for unseen tokens (not on every lookup).
            self.uuid_map[token] = str(uuid.uuid4())[:8]
        return self.uuid_map[token]

    def _get_evaluation_key(self):
        """Return the serialized evaluation key, loading it from disk if needed.

        Falls back to ``KEYS_DIR / "evaluation_key"`` (written by
        ``generate_key``). This lets the server step run on an instance
        that did not generate the keys itself, instead of failing with
        ``AttributeError``.
        """
        evaluation_key = getattr(self, "evaluation_key", None)
        if evaluation_key is None:
            evaluation_key = (KEYS_DIR / "evaluation_key").read_bytes()
            self.evaluation_key = evaluation_key
        return evaluation_key

    def generate_key(self):
        """Create fresh client keys and persist the serialized evaluation key.

        Calls ``clean_directory()`` first. That is a ``utils_demo`` helper
        which presumably removes stale keys/ciphertexts (confirm). Then it
        generates the private and evaluation keys, keeps the serialized
        evaluation key on ``self.evaluation_key`` and writes it to
        ``KEYS_DIR / "evaluation_key"``.
        """
        clean_directory()
        # Creates the private and evaluation keys on the client side.
        self.client.generate_private_and_evaluation_keys()
        self.evaluation_key = self.client.get_serialized_evaluation_keys()
        assert isinstance(self.evaluation_key, bytes)
        evaluation_key_path = KEYS_DIR / "evaluation_key"
        with evaluation_key_path.open("wb") as f:
            f.write(self.evaluation_key)

    def encrypt_query(self, text: str):
        """Embed, quantize and encrypt every non-whitespace token of ``text``.

        The serialized ciphertexts (one per token, in token order) are
        pickled to ``KEYS_DIR / "encrypted_quantized_query"``.
        """
        encrypted_tokens = []
        for token in self._tokenize(text):
            emb_x = get_batch_text_representation(
                [token], self.embeddings_model, self.tokenizer
            )
            encrypted_x = self.client.quantize_encrypt_serialize(emb_x)
            assert isinstance(encrypted_x, bytes)
            encrypted_tokens.append(encrypted_x)
        write_pickle(KEYS_DIR / "encrypted_quantized_query", encrypted_tokens)

    def run_server(self):
        """Run the FHE model on each encrypted token produced by ``encrypt_query``.

        Returns:
            A tuple ``(encrypted_output, timing)``. ``encrypted_output``
            holds the encrypted results in query order. ``timing`` holds the
            per-token server run time in minutes. Both lists are also
            pickled into ``KEYS_DIR``.
        """
        encrypted_tokens = read_pickle(KEYS_DIR / "encrypted_quantized_query")
        # Only fetch the key when there is something to evaluate.
        evaluation_key = self._get_evaluation_key() if encrypted_tokens else None
        encrypted_output, timing = [], []
        for enc_x in encrypted_tokens:
            start_time = time.time()
            enc_y = self.server.run(enc_x, evaluation_key)
            timing.append((time.time() - start_time) / 60.0)
            encrypted_output.append(enc_y)
        write_pickle(KEYS_DIR / "encrypted_output", encrypted_output)
        write_pickle(KEYS_DIR / "encrypted_timing", timing)
        return encrypted_output, timing

    def decrypt_output(self, text):
        """Decrypt the server output and pseudonymize the sensitive tokens of ``text``.

        ``text`` is expected to be the string passed to ``encrypt_query``.
        It is re-tokenized the same way, and token i is paired with
        encrypted result i (``IndexError`` if ``text`` has more tokens than
        there are results).

        Side effects:
            - Updates ``self.uuid_map`` and saves it to ``MAPPING_UUID_PATH``.
            - Pickles the reconstructed sentence to ``KEYS_DIR``.
            - Pickles the ``(token, probability)`` pairs of flagged tokens
              to ``KEYS_DIR``.
        """
        encrypted_output = read_pickle(KEYS_DIR / "encrypted_output")
        decrypted_output, identified_words_with_prob = [], []

        for i, token in enumerate(self._tokenize(text)):
            prediction_proba = self.client.deserialize_decrypt_dequantize(
                encrypted_output[i]
            )
            # NOTE(review): column 1 is presumably the "sensitive" class
            # probability; confirm against the model's label order.
            probability = prediction_proba[0][1]
            if probability >= TOLERANCE_PROBA:
                identified_words_with_prob.append((token, probability))
                decrypted_output.append(self._pseudonymize(token))
            else:
                decrypted_output.append(token)

        # Persist the UUID map so pseudonyms stay stable across queries.
        with open(MAPPING_UUID_PATH, "w") as file:
            json.dump(self.uuid_map, file)
        # NOTE(review): tokens are re-joined with single spaces, so the
        # original spacing around punctuation tokens is not preserved.
        write_pickle(KEYS_DIR / "reconstructed_sentence", " ".join(decrypted_output))
        write_pickle(KEYS_DIR / "identified_words_with_prob", identified_words_with_prob)

    def run_server_and_decrypt_output(self, text):
        """Run the server step, then decrypt and pseudonymize ``text``."""
        self.run_server()
        self.decrypt_output(text)