MMADS committed
Commit 7891910 · 1 Parent(s): 118b9b4

updated the model card

Files changed (1)
  1. README.md +86 -8
README.md CHANGED
@@ -83,21 +83,99 @@ Potential downstream uses include:
  Use the code below to get started with the model.

  ```python
+ # How to Get Started with the Model
+
  import torch
  from transformers import RobertaTokenizer, RobertaForSequenceClassification
- import json

- # Load the model, tokenizer, and configuration
+ # Load the model and tokenizer
  model_path = "MMADS/MoralFoundationsClassifier"
  model = RobertaForSequenceClassification.from_pretrained(model_path)
  tokenizer = RobertaTokenizer.from_pretrained(model_path)

- # Load label names
- with open(f"{model_path}/label_names.json", 'r') as f:
-     label_names = json.load(f)
-
- # Your function to make predictions
- ### ...
+ # Define label names based on Moral Foundations Theory
+ # Each foundation has a virtue (positive) and vice (negative) dimension
+ label_names = [
+     "care_virtue",       # Compassion, kindness, nurturing
+     "care_vice",         # Harm, cruelty, suffering
+     "fairness_virtue",   # Justice, equality, reciprocity
+     "fairness_vice",     # Cheating, inequality, injustice
+     "loyalty_virtue",    # Loyalty, patriotism, self-sacrifice
+     "loyalty_vice",      # Betrayal, treason, disloyalty
+     "authority_virtue",  # Respect, tradition, order
+     "authority_vice",    # Subversion, disobedience, chaos
+     "sanctity_virtue",   # Purity, sanctity, nobility
+     "sanctity_vice"      # Degradation, contamination, impurity
+ ]
+
+ # Function to make predictions
+ def predict_moral_foundations(texts, threshold=0.65):
+     """
+     Predict moral foundations present in a batch of texts.
+
+     Args:
+         texts (list of str): A list of input texts to analyze.
+         threshold (float): Probability threshold for positive prediction (default: 0.65).
+
+     Returns:
+         list of dict: A list of dictionaries, one for each input text.
+     """
+     # Tokenize and prepare input
+     # The tokenizer handles a list of strings automatically, creating a batch.
+     inputs = tokenizer(texts, return_tensors="pt", truncation=True,
+                        padding=True, max_length=512)
+
+     # Move to GPU if available
+     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+     model.to(device)
+     inputs = {k: v.to(device) for k, v in inputs.items()}
+
+     # Get predictions
+     with torch.no_grad():
+         outputs = model(**inputs)
+         logits = outputs.logits
+         probabilities = torch.sigmoid(logits)
+
+     # Format results for the entire batch
+     all_results = []
+     batch_probs = probabilities.cpu().numpy()
+
+     for single_text_probs in batch_probs:
+         results = {}
+         for i, label in enumerate(label_names):
+             results[label] = {
+                 "probability": float(single_text_probs[i]),
+                 "predicted": bool(single_text_probs[i] > threshold)
+             }
+         all_results.append(results)
+
+     return all_results
+
+ # Example usage with a list of texts
+ texts = [
+     "You don't actually believe what you're saying.",
+     "People are calling you stupid but you're just good old fashioned lying.",
+     "Even if you've never held employment in your life there is no way you think employers just hand out sick days whenever their employees feel like it.",
+     "Troll on."
+ ]
+ all_predictions = predict_moral_foundations(texts)
+
+ # Display detected foundations for each text
+ for i, text in enumerate(texts):
+     print(f"Analyzing: '{text}'")
+     print("Detected moral foundations:")
+
+     predictions = all_predictions[i]
+     detected_foundations = False
+     for foundation, data in predictions.items():
+         if data['predicted']:
+             print(f" - {foundation}: {data['probability']:.3f}")
+             detected_foundations = True
+
+     if not detected_foundations:
+         print(" - None")
+
+     print("-" * 30)  # Separator for clarity

  ```

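The updated example scores each text against all ten labels by hand. For a quicker check, similar multi-label scores can also be obtained through the `transformers` pipeline API. This is a minimal sketch, assuming the repository's `config.json` maps class indices to the label names listed above (`id2label`); if it does not, the explicit function shown in the README remains the reliable path.

```python
from transformers import pipeline

# Sketch only: assumes the hosted config carries id2label for the ten
# moral-foundation labels; otherwise scores come back as LABEL_0..LABEL_9.
classifier = pipeline(
    "text-classification",
    model="MMADS/MoralFoundationsClassifier",
)

# top_k=None returns a score for every label instead of only the best one;
# function_to_apply="sigmoid" treats the task as multi-label rather than
# applying a softmax across labels.
scores = classifier(
    "Troll on.",
    top_k=None,
    function_to_apply="sigmoid",
)
print(scores)
```

Any thresholding (for example the 0.65 cutoff used in the README function) would then be applied to these per-label scores by the caller.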