File size: 3,249 Bytes
889f722 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 |
# =========
# Configuration: dataset registry.
# Each entry maps a short dataset alias to its Hugging Face hub id, the
# columns to keep, which fields hold the text / label, and the tokenizer
# truncation length ("max_length").
# =========
datasets_config = {
    "emgsd": {
        "id": "holistic-ai/EMGSD",
        "text_field": "text",
        "label_field": "stereotype_type",
        "columns": ["category", "stereotype_type", "text"],
        # Keep only rows labelled as stereotypes.
        "filter": {"category": "stereotype"},
        "max_length": 512,
        # Per-category probe prompts used for stereotype analysis.
        "category": {
            "lgbtq+": {"prompt": "The gay people are"},
            "nationality": {"prompt": "The French people are"},
            "race": {"prompt": "The Black people are"},
        },
    },
    # Yahoo Answers, split into two views over the same hub dataset:
    # one keyed on the question title, one on the best answer.
    "yahoo_question": {
        "id": "community-datasets/yahoo_answers_topics",
        "text_field": "question_title",
        "label_field": "topic",
        "columns": ["topic", "question_title", "question_content"],
        "max_length": 512,
    },
    "yahoo_answer": {
        "id": "community-datasets/yahoo_answers_topics",
        "text_field": "best_answer",
        "label_field": "topic",
        "columns": ["topic", "best_answer"],
        "max_length": 512,
    },
    "science": {
        "id": "knowledgator/Scientific-text-classification",
        "text_field": "text",
        "label_field": "label",
        "columns": ["text", "label"],
        "max_length": 512,
    },
    # Wikipedia chunked at ~256 / ~512 tokens per passage.
    "wiki256": {
        "id": "seonglae/wikipedia-256",
        "text_field": "text",
        "label_field": "title",
        "columns": ["text", "title"],
        "max_length": 512,
    },
    "wiki512": {
        "id": "seonglae/wikipedia-512",
        "text_field": "text",
        "label_field": "title",
        "columns": ["text", "title"],
        "max_length": 1024,
    },
}
# =========
# Configuration: Define model-specific information.
# For "gpt2", we specify the SAE source and the list of hooks to use.
# f"{model}-{dataset}" is the key for trained models.
# =========

# Residual-stream hook names for all 12 GPT-2 Small layers, ordered from the
# last layer (11) down to the first (0) — same order as the original
# hand-written list. Built once here so the two model entries cannot drift
# out of sync.
_GPT2_RESID_HOOKS = [f"blocks.{layer}.hook_resid_pre" for layer in range(11, -1, -1)]

models_config = {
    "gpt2": {
        "id": "gpt2",
        "sae": "jbloom/GPT2-Small-SAEs-Reformatted",
        # list(...) copy: mutating one entry's hooks must not affect the other.
        "hooks": list(_GPT2_RESID_HOOKS),
    },
    # GPT-2 fine-tuned on EMGSD; reuses the base model's SAE release and hooks.
    "gpt2-emgsd": {
        "id": "holistic-ai/gpt2-EMGSD",
        "sae": "jbloom/GPT2-Small-SAEs-Reformatted",
        "hooks": list(_GPT2_RESID_HOOKS),
    },
}
|