Upload 3 files
Browse files- gradient_reversal/__init__.py +2 -0
- gradient_reversal/functional.py +16 -0
- gradient_reversal/module.py +11 -0
gradient_reversal/__init__.py
ADDED
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
1 |
+
from .functional import revgrad
|
2 |
+
from .module import GradientReversal
|
gradient_reversal/functional.py
ADDED
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from torch.autograd import Function
|
2 |
+
|
3 |
+
class GradientReversal(Function):
|
4 |
+
@staticmethod
|
5 |
+
def forward(ctx, x, alpha):
|
6 |
+
ctx.save_for_backward(x, alpha)
|
7 |
+
return x
|
8 |
+
|
9 |
+
@staticmethod
|
10 |
+
def backward(ctx, grad_output):
|
11 |
+
grad_input = None
|
12 |
+
_, alpha = ctx.saved_tensors
|
13 |
+
if ctx.needs_input_grad[0]:
|
14 |
+
grad_input = - alpha*grad_output
|
15 |
+
return grad_input, None
|
16 |
+
revgrad = GradientReversal.apply
|
gradient_reversal/module.py
ADDED
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from .functional import revgrad
|
2 |
+
import torch
|
3 |
+
from torch import nn
|
4 |
+
|
5 |
+
class GradientReversal(nn.Module):
|
6 |
+
def __init__(self, alpha):
|
7 |
+
super().__init__()
|
8 |
+
self.alpha = torch.tensor(alpha, requires_grad=False)
|
9 |
+
|
10 |
+
def forward(self, x):
|
11 |
+
return revgrad(x, self.alpha)
|