spandey8 commited on
Commit
3650b90
·
verified ·
1 Parent(s): 8f734de

Upload 3 files

Browse files
gradient_reversal/__init__.py ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ from .functional import revgrad
2
+ from .module import GradientReversal
gradient_reversal/functional.py ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from torch.autograd import Function
2
+
3
+ class GradientReversal(Function):
4
+ @staticmethod
5
+ def forward(ctx, x, alpha):
6
+ ctx.save_for_backward(x, alpha)
7
+ return x
8
+
9
+ @staticmethod
10
+ def backward(ctx, grad_output):
11
+ grad_input = None
12
+ _, alpha = ctx.saved_tensors
13
+ if ctx.needs_input_grad[0]:
14
+ grad_input = - alpha*grad_output
15
+ return grad_input, None
16
+ revgrad = GradientReversal.apply
gradient_reversal/module.py ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from .functional import revgrad
2
+ import torch
3
+ from torch import nn
4
+
5
+ class GradientReversal(nn.Module):
6
+ def __init__(self, alpha):
7
+ super().__init__()
8
+ self.alpha = torch.tensor(alpha, requires_grad=False)
9
+
10
+ def forward(self, x):
11
+ return revgrad(x, self.alpha)