File size: 3,249 Bytes
889f722
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
# =========
# Configuration: Define dataset information.
#
# Each entry maps a short dataset name to a dict with:
#   id          - dataset identifier (looks like a Hugging Face Hub id; confirm
#                 against the loader)
#   columns     - dataset columns this config uses
#   filter      - optional {column: value} pairs; presumably rows are kept only
#                 when they match (confirm against the loader)
#   text_field  - column (one of `columns`) that holds the text
#   label_field - column (one of `columns`) used as the label
#   max_length  - length cap used downstream; units are not shown here
#                 (presumably tokens, confirm)
#   category    - optional {group: {"prompt": str}} map of prompt strings
# =========
datasets_config = {
    "emgsd": {
        "id": "holistic-ai/EMGSD",
        "columns": ["category", "stereotype_type", "text"],
        "filter": {"category": "stereotype"},
        "text_field": "text",
        "label_field": "stereotype_type",
        "max_length": 512,
        # One prompt per group. The group names presumably match values of the
        # "stereotype_type" label column; confirm against the data.
        "category": {
            group: {"prompt": prompt}
            for group, prompt in (
                ("lgbtq+", "The gay people are"),
                ("nationality", "The French people are"),
                ("race", "The Black people are"),
            )
        },
    },
    # The two Yahoo entries read the same dataset; they differ only in which
    # column supplies the text.
    "yahoo_question": {
        "id": "community-datasets/yahoo_answers_topics",
        "columns": ["topic", "question_title", "question_content"],
        "text_field": "question_title",
        "label_field": "topic",
        "max_length": 512,
    },
    "yahoo_answer": {
        "id": "community-datasets/yahoo_answers_topics",
        "columns": ["topic", "best_answer"],
        "text_field": "best_answer",
        "label_field": "topic",
        "max_length": 512,
    },
    "science": {
        "id": "knowledgator/Scientific-text-classification",
        "columns": ["text", "label"],
        "text_field": "text",
        "label_field": "label",
        "max_length": 512,
    },
    # Wikipedia variants: the article title serves as the label.
    "wiki256": {
        "id": "seonglae/wikipedia-256",
        "columns": ["text", "title"],
        "text_field": "text",
        "label_field": "title",
        "max_length": 512,
    },
    "wiki512": {
        "id": "seonglae/wikipedia-512",
        "columns": ["text", "title"],
        "text_field": "text",
        "label_field": "title",
        "max_length": 1024,
    },
}

def _resid_pre_hooks(n_blocks=12):
    """Return ``blocks.{i}.hook_resid_pre`` hook names for i = n_blocks-1 down to 0.

    A new list is built on every call so that separate config entries never
    share (and accidentally co-mutate) a single list object.
    """
    return [f"blocks.{block}.hook_resid_pre" for block in range(n_blocks - 1, -1, -1)]


# =========
# Configuration: Define model-specific information.
# Each entry gives the model "id", the "sae" source, and the list of "hooks"
# to use (residual-stream-pre hooks, listed from the last block down to
# block 0).
# f"{model}-{dataset}" is the key for trained models (e.g. "gpt2-emgsd").
# =========
models_config = {
    "gpt2": {
        "id": "gpt2",
        "sae": "jbloom/GPT2-Small-SAEs-Reformatted",
        "hooks": _resid_pre_hooks(12),
    },
    "gpt2-emgsd": {
        "id": "holistic-ai/gpt2-EMGSD",
        "sae": "jbloom/GPT2-Small-SAEs-Reformatted",
        "hooks": _resid_pre_hooks(12),
    },
}