# CorrSteer / config.py
# Origin: Hugging Face Space "CorrSteer" (author: seonglae),
# commit 889f722 — "feat: hf space corr-steer".
# ---------------------------------------------------------------------------
# Dataset registry: short dataset name -> settings for that dataset.
#
# Per-entry keys:
#   id          - Hugging Face dataset repository id.
#   columns     - column names listed for the dataset (presumably the columns
#                 the loader keeps — confirm against the loading code).
#   text_field  - column that holds the input text.
#   label_field - column that holds the label.
#   max_length  - per-dataset length limit (units not visible here; presumably
#                 tokens — confirm).
#   filter      - optional {column: value} mapping; only "emgsd" sets it
#                 (presumably rows are kept where column == value — confirm).
#   category    - optional mapping from a label value to extra settings; for
#                 "emgsd" each entry holds a single "prompt" string.
# ---------------------------------------------------------------------------
datasets_config = dict(
    emgsd=dict(
        id="holistic-ai/EMGSD",
        columns=["category", "stereotype_type", "text"],
        filter=dict(category="stereotype"),
        text_field="text",
        label_field="stereotype_type",
        max_length=512,
        # "lgbtq+" is not a valid Python identifier, so this mapping stays a
        # literal instead of a keyword-argument dict(...) call.
        category={
            "lgbtq+": dict(prompt="The gay people are"),
            "nationality": dict(prompt="The French people are"),
            "race": dict(prompt="The Black people are"),
        },
    ),
    # Two views of the same dataset: question titles vs. best answers.
    yahoo_question=dict(
        id="community-datasets/yahoo_answers_topics",
        columns=["topic", "question_title", "question_content"],
        text_field="question_title",
        label_field="topic",
        max_length=512,
    ),
    yahoo_answer=dict(
        id="community-datasets/yahoo_answers_topics",
        columns=["topic", "best_answer"],
        text_field="best_answer",
        label_field="topic",
        max_length=512,
    ),
    science=dict(
        id="knowledgator/Scientific-text-classification",
        columns=["text", "label"],
        text_field="text",
        label_field="label",
        max_length=512,
    ),
    # Wikipedia variants are labelled by article title.
    wiki256=dict(
        id="seonglae/wikipedia-256",
        columns=["text", "title"],
        text_field="text",
        label_field="title",
        max_length=512,
    ),
    wiki512=dict(
        id="seonglae/wikipedia-512",
        columns=["text", "title"],
        text_field="text",
        label_field="title",
        max_length=1024,
    ),
)
# ---------------------------------------------------------------------------
# Model registry.
#
# Keys are a base model name ("gpt2") or, for trained models,
# f"{model}-{dataset}" (e.g. "gpt2-emgsd"). Each entry gives:
#   id    - Hugging Face model id.
#   sae   - source repository for the SAEs.
#   hooks - hook points to use, ordered from block 11 down to block 0.
#
# Both models share the same SAE source and hook layout, so the entries are
# generated. Every entry gets its own freshly built hooks list; the lists are
# never shared between models.
# ---------------------------------------------------------------------------
models_config = {
    model_key: {
        "id": repo_id,
        "sae": "jbloom/GPT2-Small-SAEs-Reformatted",
        "hooks": [
            f"blocks.{layer}.hook_resid_pre" for layer in reversed(range(12))
        ],
    }
    for model_key, repo_id in (
        ("gpt2", "gpt2"),
        ("gpt2-emgsd", "holistic-ai/gpt2-EMGSD"),
    )
}