Spaces:
Runtime error
Runtime error
Commit
·
8079453
1
Parent(s):
a5e4f9a
Update lora.py
Browse files
lora.py
CHANGED
|
@@ -114,7 +114,7 @@ class LoRAModule(torch.nn.Module):
|
|
| 114 |
|
| 115 |
lx = self.lora_up(lx)
|
| 116 |
|
| 117 |
-
return org_forwarded + lx * self.multiplier
|
| 118 |
|
| 119 |
|
| 120 |
class LoRAInfModule(LoRAModule):
|
|
@@ -219,7 +219,12 @@ class LoRAInfModule(LoRAModule):
|
|
| 219 |
|
| 220 |
def default_forward(self, x):
|
| 221 |
# print("default_forward", self.lora_name, x.size())
|
| 222 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 223 |
|
| 224 |
def forward(self, x):
|
| 225 |
if not self.enabled:
|
|
|
|
| 114 |
|
| 115 |
lx = self.lora_up(lx)
|
| 116 |
|
| 117 |
+
return org_forwarded + lx * self.multiplier * scale
|
| 118 |
|
| 119 |
|
| 120 |
class LoRAInfModule(LoRAModule):
|
|
|
|
| 219 |
|
| 220 |
def default_forward(self, x):
|
| 221 |
# print("default_forward", self.lora_name, x.size())
|
| 222 |
+
org_forward = self.org_forward(x)
|
| 223 |
+
lora_up_down = self.lora_up(self.lora_down(x))
|
| 224 |
+
print(org_forward)
|
| 225 |
+
print(lora_up_down)
|
| 226 |
+
print(self.multiplier)
|
| 227 |
+
return org_forward + lora_up_down * self.multiplier #* self.scale
|
| 228 |
|
| 229 |
def forward(self, x):
|
| 230 |
if not self.enabled:
|