---
license: mit
---

# ESM-2 QLoRA for Binding Site Prediction

In this model, we wanted to see how the performance metrics were affected by adapting additional weight matrices with QLoRA. The choice of which weight matrices to adapt was shown to be by far the most important hyperparameter for improving performance metrics, whereas hyperparameters such as rank and scaling factor were shown to be of negligible importance, with lower rank being just as good as higher rank. So, we decided to compare using adapters for only the query, key, and value weight matrices against using adapters for all possible weight matrices. The comparison for the first epoch can be seen below. Note the minor improvements in every reported metric except `eval_loss`, which is higher, for the model using every possible weight matrix (this model).

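As a rough illustration of what "adapting additional weight matrices" means in practice, the sketch below compares two adapter target lists against a handful of module names. The module names are assumptions based on the layer layout of the Hugging Face ESM implementation, and `ALL_LINEAR` is a hypothetical stand-in for "all possible weight matrices"; neither is taken from the training code of this model. The suffix matching only approximates how a list passed as `target_modules` to `peft.LoraConfig` is resolved.

```python
# Assumed module names for one encoder layer of a token-classification ESM-2
# model; check them against `model.named_modules()` before relying on them.
EXAMPLE_MODULES = [
    "esm.encoder.layer.0.attention.self.query",
    "esm.encoder.layer.0.attention.self.key",
    "esm.encoder.layer.0.attention.self.value",
    "esm.encoder.layer.0.attention.output.dense",
    "esm.encoder.layer.0.intermediate.dense",
    "esm.encoder.layer.0.output.dense",
    "classifier",
]

QKV_ONLY = ["query", "key", "value"]
# Hypothetical wider target list: attention projections plus every "dense" layer.
ALL_LINEAR = ["query", "key", "value", "dense"]


def targeted(module_names, target_modules):
    """Return the module names selected by exact or dotted-suffix match."""
    return [
        name
        for name in module_names
        if any(name == t or name.endswith("." + t) for t in target_modules)
    ]


qkv = targeted(EXAMPLE_MODULES, QKV_ONLY)
everything = targeted(EXAMPLE_MODULES, ALL_LINEAR)

print(len(qkv), "modules per layer adapted with query/key/value only")
print(len(everything), "modules per layer adapted with the wider target list")
```

Under these assumptions the wider list doubles the number of adapted matrices per encoder layer, which is the kind of change the comparison below is meant to measure.
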
### This model

```python
# Test (epoch 1)
{'eval_loss': 0.41490185260772705,
 'eval_accuracy': 0.8625347674451358,
 'eval_precision': 0.11370668247419904,
 'eval_recall': 0.7800926533683039,
 'eval_f1': 0.19848246486644372,
 'eval_auc': 0.8222331548742136,
 'eval_mcc': 0.2639007297474409}
```

### Query, Key, Value only Model:

```python
# Test (epoch 1)
{'eval_loss': 0.3398605287075043,
 'eval_accuracy': 0.8557050926566265,
 'eval_precision': 0.10792930844408741,
 'eval_recall': 0.7726298654561553,
 'eval_f1': 0.18940102955847055,
 'eval_auc': 0.8150939843855006,
 'eval_mcc': 0.2535956911257298}
```

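To make the size of these differences concrete, the following sketch subtracts the query/key/value-only metrics from the metrics of this model. The numbers are copied from the two blocks above; the variable names are illustrative only.

```python
# Epoch 1 test metrics, copied from the two blocks above.
this_model = {
    "eval_loss": 0.41490185260772705,
    "eval_accuracy": 0.8625347674451358,
    "eval_precision": 0.11370668247419904,
    "eval_recall": 0.7800926533683039,
    "eval_f1": 0.19848246486644372,
    "eval_auc": 0.8222331548742136,
    "eval_mcc": 0.2639007297474409,
}
qkv_only = {
    "eval_loss": 0.3398605287075043,
    "eval_accuracy": 0.8557050926566265,
    "eval_precision": 0.10792930844408741,
    "eval_recall": 0.7726298654561553,
    "eval_f1": 0.18940102955847055,
    "eval_auc": 0.8150939843855006,
    "eval_mcc": 0.2535956911257298,
}

# Positive delta: this model scores higher than the query/key/value-only model.
delta = {name: this_model[name] - qkv_only[name] for name in this_model}

for name, value in delta.items():
    print(f"{name:15s} {value:+.4f}")
```

Every delta is positive. For `eval_loss` a positive delta means the loss is higher (worse), while for the remaining metrics the gains are all on the order of 0.01 or less.
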
The metrics on the datasets [mentioned here](https://github.com/hamzagamouh/pt-lm-gnn) can be [found here](https://huggingface.co/AmelieSchreiber/esm2_t6_8m_qlora_binding_sites_v1/blob/main/pdb_structure_metrics.txt).

## Testing for Overfitting

Notably, adding adapters for the additional weight matrices appears to act as a more robust form of regularization, and these models appear to generalize better.

### Epoch 1:

```python
# Train metrics
{'eval_loss': 0.35603779554367065,
 'eval_accuracy': 0.8439650327744697,
 'eval_precision': 0.11529132737114746,
 'eval_recall': 0.9162279099673907,
 'eval_f1': 0.20481078411524478,
 'eval_auc': 0.8792862815250805,
 'eval_mcc': 0.29286338236467047}

# Test metrics
{'eval_loss': 0.3942357003688812,
 'eval_accuracy': 0.8246741787222583,
 'eval_precision': 0.0942294455869611,
 'eval_recall': 0.8169195154212542,
 'eval_f1': 0.16896879944226734,
 'eval_auc': 0.8208833317810486,
 'eval_mcc': 0.23939865094539936}
```
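
One simple way to read these numbers as an overfitting check is to look at the gap between train and test for each metric. The sketch below does that with the epoch 1 values copied from the block above; the variable names are illustrative, and which size of gap counts as overfitting is a judgment call that this sketch does not settle.

```python
# Epoch 1 train and test metrics, copied from the block above.
train = {
    "eval_loss": 0.35603779554367065,
    "eval_accuracy": 0.8439650327744697,
    "eval_precision": 0.11529132737114746,
    "eval_recall": 0.9162279099673907,
    "eval_f1": 0.20481078411524478,
    "eval_auc": 0.8792862815250805,
    "eval_mcc": 0.29286338236467047,
}
test = {
    "eval_loss": 0.3942357003688812,
    "eval_accuracy": 0.8246741787222583,
    "eval_precision": 0.0942294455869611,
    "eval_recall": 0.8169195154212542,
    "eval_f1": 0.16896879944226734,
    "eval_auc": 0.8208833317810486,
    "eval_mcc": 0.23939865094539936,
}

# Gap = train minus test. A large positive gap on a score (or a large negative
# gap on the loss) suggests the model fits the training data better than unseen data.
gap = {name: train[name] - test[name] for name in train}

# Print the metrics with the largest absolute gap first.
for name in sorted(gap, key=lambda n: abs(gap[n]), reverse=True):
    print(f"{name:15s} {gap[name]:+.4f}")
```

Here the largest gap is in recall (roughly 0.10), followed by AUC and MCC, while accuracy and precision differ by about 0.02.
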