marksaroufim commited on
Commit
3c9bd97
·
verified ·
1 Parent(s): 8064ae2

Fix generated response text

Browse files

join="" does not exist for python 3.10

And a regular print was printing text with jumbled \n

This PR instead just make the code look nice

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class Model(nn.Module):
def __init__(self) -> None:
super().__init__()

def forward(self, a, b):
return a + b


def get_inputs():
# randomly generate input tensors based on the model architecture
a = torch.randn(1, 128).cuda()
b = torch.randn(1, 128).cuda()
return [a, b]


def get_init_inputs():
# randomly generate tensors required for initialization based on the model architecture
return []

```
The example new arch with custom Triton kernels looks like this:
```
import torch
import torch.nn as nn
import torch.nn.functional as F
import triton
import triton.language as tl




@triton
.jit
def add_kernel(
x_ptr, # Pointer to first input
y_ptr, # Pointer to second input
out_ptr,
```

Files changed (1) hide show
  1. README.md +1 -1
README.md CHANGED
@@ -47,7 +47,7 @@ response = pipeline(
47
  max_length=200,
48
  truncation=True,
49
  )[0]
50
- print(prompt, response, join="")
51
  ```
52
 
53
  ## Model Details
 
47
  max_length=200,
48
  truncation=True,
49
  )[0]
50
+ print(response["generated_text"])
51
  ```
52
 
53
  ## Model Details