---
language: en
tags:
- astronomy
- TDAMM
- classification
- multi-label
- NASA
- astrophysics
---
# TDAMM Multi-Label Classification Model

This model performs multi-label classification of text into Time Domain and Multi-Messenger Astronomy (TDAMM) topics. Each input can be assigned any number of topic labels, including none.
## Model Description

- Base Model: astroBERT
- Task: Multi-label classification
- Training Data: NASA and non-NASA documents related to TDAMM topics
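
In multi-label classification, each topic label is scored independently, so a document can belong to several TDAMM topics at once or to none. The sketch below uses plain Python and made-up logits for three hypothetical labels (not real model output). It contrasts the per-label sigmoid, which the usage example applies to the model's logits, with a softmax, whose probabilities must sum to 1 and therefore compete for a single winner.

```python
import math
from typing import List


def sigmoid(x: float) -> float:
    """Independent probability for one label (multi-label setting)."""
    return 1.0 / (1.0 + math.exp(-x))


def softmax(xs: List[float]) -> List[float]:
    """Mutually exclusive probabilities (single-label setting); sums to 1."""
    m = max(xs)  # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]


# Made-up logits for three hypothetical labels (illustration only)
logits = [2.0, 1.5, -3.0]

multi_label = [sigmoid(x) for x in logits]  # ~[0.881, 0.818, 0.047]
single_label = softmax(logits)              # probabilities sum to 1.0

# With a 0.5 threshold, two labels are active at the same time
print([p > 0.5 for p in multi_label])  # [True, True, False]
```

Because the sigmoid outputs are not normalized across labels, their sum can exceed 1. This is expected for a multi-label model and is why each probability is compared to a threshold on its own rather than taking an argmax.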
## Usage

```python
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import torch

# Load the tokenizer and the fine-tuned classifier from the Hugging Face Hub
tokenizer = AutoTokenizer.from_pretrained("nasa-impact/tdamm-classification")
model = AutoModelForSequenceClassification.from_pretrained("nasa-impact/tdamm-classification")

# Prepare input (text longer than 512 tokens is truncated)
text = "Your astronomical text here"
inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True, max_length=512)

# Get predictions: each label receives an independent probability via sigmoid
with torch.no_grad():
    outputs = model(**inputs)
    probabilities = torch.sigmoid(outputs.logits)

# Convert to binary predictions (threshold = 0.5)
# Shape is (batch_size, num_labels); column i corresponds to model.config.id2label[i]
predictions = (probabilities > 0.5).int()
```
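
To score several documents in one pass, the tokenizer also accepts a list of strings. With `padding=True` it pads them to the longest sequence in the batch, and `torch.sigmoid(outputs.logits).tolist()` then yields one row of label probabilities per document. The helper below is a hypothetical sketch: the name `decode_multilabel` and the per-label threshold option are not part of this model or of `transformers`. It turns such rows into label names using a mapping like `model.config.id2label`. The 0.5 default mirrors the threshold used above. Label-specific cutoffs are an assumption that you would tune on held-out data.

```python
from typing import Dict, List, Optional


def decode_multilabel(
    probabilities: List[List[float]],
    id2label: Dict[int, str],
    threshold: float = 0.5,
    per_label_thresholds: Optional[Dict[str, float]] = None,
) -> List[List[str]]:
    """Hypothetical helper: map per-label probabilities to label names per document."""
    per_label_thresholds = per_label_thresholds or {}
    results = []
    for row in probabilities:
        labels = []
        for i, p in enumerate(row):
            name = id2label[i]
            # Use a label-specific cutoff if given, otherwise the global threshold
            cutoff = per_label_thresholds.get(name, threshold)
            if p > cutoff:
                labels.append(name)
        results.append(labels)
    return results


# In practice: probabilities = torch.sigmoid(outputs.logits).tolist()
probabilities = [[0.91, 0.40, 0.07], [0.20, 0.65, 0.55]]  # made-up values
id2label = {0: "label_a", 1: "label_b", 2: "label_c"}      # hypothetical names

print(decode_multilabel(probabilities, id2label, per_label_thresholds={"label_c": 0.6}))
# [['label_a'], ['label_b']]
```

If the checkpoint's config does not define label names, `id2label` falls back to generic names such as `LABEL_0`.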