Fix generated response text
Browse filesjoin="" does not exist for python 3.10
And a regular print was printing text with jumbled \n
This PR instead just make the code look nice
```python
import torch
import torch.nn as nn
import torch.nn.functional as F
class Model(nn.Module):
def __init__(self) -> None:
super().__init__()
def forward(self, a, b):
return a + b
def get_inputs():
# randomly generate input tensors based on the model architecture
a = torch.randn(1, 128).cuda()
b = torch.randn(1, 128).cuda()
return [a, b]
def get_init_inputs():
# randomly generate tensors required for initialization based on the model architecture
return []
```
The example new arch with custom Triton kernels looks like this:
```
import torch
import torch.nn as nn
import torch.nn.functional as F
import triton
import triton.language as tl
@triton
.jit
def add_kernel(
x_ptr, # Pointer to first input
y_ptr, # Pointer to second input
out_ptr,
```
@@ -47,7 +47,7 @@ response = pipeline(
|
|
47 |
max_length=200,
|
48 |
truncation=True,
|
49 |
)[0]
|
50 |
-
print(
|
51 |
```
|
52 |
|
53 |
## Model Details
|
|
|
47 |
max_length=200,
|
48 |
truncation=True,
|
49 |
)[0]
|
50 |
+
print(response["generated_text"])
|
51 |
```
|
52 |
|
53 |
## Model Details
|