In [2]:
from transformers import GPT2LMHeadModel, GPT2Tokenizer
import torch
# Use the first GPU when CUDA is available; otherwise fall back to CPU so the
# notebook still runs (slowly) instead of failing on the `.to(DEVICE)` calls.
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Base checkpoint (Hugging Face Hub id) that is fine-tuned below.
model_name_or_path = "sberbank-ai/rugpt3small_based_on_gpt2"
tokenizer = GPT2Tokenizer.from_pretrained(model_name_or_path)
model = GPT2LMHeadModel.from_pretrained(model_name_or_path).to(DEVICE)

In [3]:
# Load the whole raw joke corpus into memory as a single UTF-8 string.
# Presumably this file was produced by the commented-out scraper cell at the
# end of the notebook (one joke per line) — confirm if the data changes.
with open('anekdoty.txt', mode='r', encoding='utf-8') as corpus_file:
    text = corpus_file.read()

In [4]:
from transformers import TextDataset, DataCollatorForLanguageModeling

# Save the training corpus to a .txt file so TextDataset can read it.
# The encoding is explicit because the corpus is Cyrillic (it was read as UTF-8
# above). The platform default (e.g. cp1252 on Windows) can fail or corrupt it,
# while TextDataset reads the file back as UTF-8.
train_path = 'train_dataset.txt'
with open(train_path, "w", encoding="utf-8") as f:
    f.write(text)

# Build the dataset: the whole file is tokenized and cut into consecutive
# 32-token blocks (block_size=32). This is where the chunking happens.
# NOTE(review): TextDataset is deprecated in recent transformers releases;
# consider the `datasets` library if this notebook is kept up to date.
train_dataset = TextDataset(tokenizer=tokenizer, file_path=train_path, block_size=32)

# The collator only batches the already-cut blocks. With mlm=False it prepares
# causal-LM batches (labels are copies of input_ids) instead of masking tokens.
data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)



In [5]:
from transformers import Trainer, TrainingArguments

# Peak learning rate. It drives the hand-built AdamW below, and is also passed
# to TrainingArguments so the hyperparameters Trainer records with the run show
# the value actually used rather than the library default.
# NOTE(review): 1e-3 is high for fine-tuning a pretrained GPT-2; values around
# 5e-5 are more common — worth re-checking against the loss curve.
LEARNING_RATE = 0.001

training_args = TrainingArguments(
    output_dir="./finetuned",
    overwrite_output_dir=True,
    num_train_epochs=30,
    per_device_train_batch_size=32,
    per_device_eval_batch_size=16,  # unused: no eval_dataset is given to Trainer
    warmup_steps=10,
    # Effective batch per optimizer step = 32 blocks x 32 accumulation steps.
    # The recorded run ended at global_step=240 / epoch~27.9, i.e. only about
    # 8 optimizer steps per epoch, so it stopped short of 30 full epochs.
    gradient_accumulation_steps=32,
    learning_rate=LEARNING_RATE,
)

trainer = Trainer(
    model=model,
    args=training_args,
    data_collator=data_collator,
    train_dataset=train_dataset,
    # A pre-built optimizer replaces the one Trainer would create from
    # training_args, so optimizer settings there (e.g. weight_decay) do not
    # apply to it. None for the scheduler lets Trainer build its default
    # schedule (with warmup_steps) around this optimizer.
    optimizers=(torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE), None),
)

In [5]:
train_output = trainer.train()

# Persist the final weights and tokenizer to output_dir ("./finetuned"), which
# the reload cell below reads from. Trainer only writes periodic checkpoints
# (output_dir/checkpoint-<step>) every `save_steps` optimizer steps.
# NOTE(review): that is 500 by default — confirm for the installed version.
# This run ends at global_step=240, so without an explicit save a fresh
# Restart & Run All finds no model files there.
trainer.save_model(training_args.output_dir)
tokenizer.save_pretrained(training_args.output_dir)

train_output

Step,Training Loss


TrainOutput(global_step=240, training_loss=0.9343911488850911, metrics={'train_runtime': 4515.8084, 'train_samples_per_second': 58.428, 'train_steps_per_second': 0.053, 'total_flos': 4011240960000000.0, 'train_loss': 0.9343911488850911, 'epoch': 27.927272727272726})

In [9]:
# Reload the fine-tuned checkpoint from disk. This rebinds the notebook-level
# `model` and `tokenizer` to the saved versions and puts the model on DEVICE.
model_path = "finetuned"
model = GPT2LMHeadModel.from_pretrained(model_path).to(DEVICE)
tokenizer = GPT2Tokenizer.from_pretrained(model_path)

In [70]:
def _cut_after_first_sentence(text, start=0, terminators=".!?\u2026"):
    """Return ``text`` truncated right after its earliest sentence terminator.

    Only positions at or after ``start`` are searched, so punctuation inside the
    prompt does not cut the continuation short. The terminator itself is kept
    (``\\u2026`` is the single-character ellipsis). If no terminator is found,
    ``text`` is returned unchanged.
    """
    hits = [pos for pos in (text.find(ch, start) for ch in terminators) if pos != -1]
    if not hits:
        return text
    return text[:min(hits) + 1]


def generate_jokes(prompt, temperature, top_p, max_length, num_return_sequences,
                   lm=None, tok=None, device=None):
    """Sample several continuations of ``prompt`` and keep the first sentence of each.

    Args:
        prompt: text the model continues.
        temperature, top_p: sampling parameters passed straight to ``generate``.
        max_length: passed to ``generate``; it counts the prompt tokens too.
        num_return_sequences: how many samples to draw.
        lm, tok, device: optional model / tokenizer / device. Each defaults to
            the notebook-level ``model`` / ``tokenizer`` / ``DEVICE``.

    Returns:
        A list of ``num_return_sequences`` strings. Each is decoded with special
        tokens skipped and cut right after the first ``.``, ``!``, ``?`` or
        ``…`` that follows the prompt. Text without a terminator is returned
        whole.
    """
    # Fall back to the notebook globals so existing call sites keep working.
    lm = model if lm is None else lm
    tok = tokenizer if tok is None else tok
    device = DEVICE if device is None else device

    input_ids = tok.encode(prompt, return_tensors='pt').to(device)

    # Draw several independent samples in one call.
    outputs = lm.generate(
        input_ids=input_ids,
        do_sample=True,
        temperature=temperature,
        top_p=top_p,
        max_length=max_length,
        num_return_sequences=num_return_sequences,
    )

    jokes = []
    for output in outputs:
        generated_text = tok.decode(output, skip_special_tokens=True)
        # Search for the sentence end only after the prompt, and only when the
        # decoded text really starts with it (decoding may not round-trip exactly).
        start = len(prompt) if generated_text.startswith(prompt) else 0
        jokes.append(_cut_after_first_sentence(generated_text, start))
    return jokes

In [73]:
# Draw 4 samples (temperature 1, top_p 0.9, max_length 30 tokens including the
# prompt) and keep the first sentence of each.
# The prompt is named `prompt`, not `text`, so running this cell does not
# overwrite the training corpus held in `text` that the dataset cell writes out.
prompt = "–®–ª–∞ –°–∞—à–∞ –ø–æ —à–æ—Å—Å–µ"
print(generate_jokes(prompt, 1, 0.9, 30, 4))

['Шла Саша по шоссе, громко разговаривая с шофером.', 'Шла Саша по шоссе, громко матерясь и упирая руку в ширинку.', 'Шла Саша по шоссе, несла пургу и, как раз, дождь.', 'Шла Саша по шоссе, но не за трактором.']


In [10]:
# One generation using sampling combined with 2 beams, a high temperature (1.5)
# and top_p=0.9; max_length=30 counts the prompt tokens as well.
# The prompt is named `prompt`, not `text`, so the training corpus in `text`
# is not overwritten.
# NOTE(review): the saved output below does not start with this prompt, so it
# came from an earlier version of this cell — re-run to refresh it.
prompt = "–æ–¥–Ω–∞–∂–¥—ã —è –ø—Ä–∏—à–µ–ª –∏–∑ —à–∫–æ–ª—ã"
input_ids = tokenizer.encode(prompt, return_tensors="pt").to(DEVICE)
model.eval()
with torch.no_grad():
    out = model.generate(
        input_ids,
        do_sample=True,
        num_beams=2,
        temperature=1.5,
        top_p=0.9,
        max_length=30,
    )

# Only one sequence is requested, so decode just the first row instead of
# decoding every row and discarding the rest.
generated_text = tokenizer.decode(out[0], skip_special_tokens=True)
print()
print(generated_text)


однажды я проваливал экзамен по истории.
— Вино с возрастом становится лучше. Я становлюсь лучше с вином…
— Сними


In [8]:
# model.save_pretrained('./finetuned')
# tokenizer.save_pretrained('./finetuned')

In [38]:
# import requests
# from bs4 import BeautifulSoup
# import re

# # –§—É–Ω–∫—Ü–∏—è –¥–ª—è –ø–æ–ª—É—á–µ–Ω–∏—è —à—É—Ç–æ–∫ —Å –æ–¥–Ω–æ–π —Å—Ç—Ä–∞–Ω–∏—Ü—ã
# def get_jokes_from_page(url):
#     response = requests.get(url, headers=headers)
#     response.raise_for_status()  # –ü—Ä–æ–≤–µ—Ä–∫–∞ –Ω–∞ –æ—à–∏–±–∫–∏ –∑–∞–ø—Ä–æ—Å–∞

#     soup = BeautifulSoup(response.text, 'html.parser')

#     # –ù–∞—Ö–æ–¥–∏–º –≤—Å–µ –∞–Ω–µ–∫–¥–æ—Ç—ã –Ω–∞ —Å—Ç—Ä–∞–Ω–∏—Ü–µ
#     jokes = soup.find_all('div', class_='anekdot-text')  # –ó–∞–º–µ–Ω–∏—Ç–µ —Å–µ–ª–µ–∫—Ç–æ—Ä –Ω–∞ –ø—Ä–∞–≤–∏–ª—å–Ω—ã–π

#     page_jokes = []
#     for joke in jokes:
#         # –ò–∑–≤–ª–µ–∫–∞–µ–º —Ç–µ–∫—Å—Ç –∞–Ω–µ–∫–¥–æ—Ç–∞
#         joke_text = joke.get_text(strip=True)
        
#         # –£–¥–∞–ª—è–µ–º —Ü–∏—Ñ—Ä—ã –∏ —Å–∏–º–≤–æ–ª—ã –≤ –∫–æ–Ω—Ü–µ —Ç–µ–∫—Å—Ç–∞
#         joke_text_cleaned = re.sub(r'\d+[\#\d]*$', '', joke_text).strip()
        
#         # –î–æ–±–∞–≤–ª—è–µ–º –æ—á–∏—â–µ–Ω–Ω—ã–π —Ç–µ–∫—Å—Ç –≤ —Å–ø–∏—Å–æ–∫
#         page_jokes.append(joke_text_cleaned)
    
#     return page_jokes

# # URL-—à–∞–±–ª–æ–Ω –¥–ª—è —Å—Ç—Ä–∞–Ω–∏—Ü
# base_url = "https://anekdotovstreet.com/korotkie-anekdoty/{}/"

# # –ó–∞–≥–æ–ª–æ–≤–∫–∏ –¥–ª—è –∏–º–∏—Ç–∞—Ü–∏–∏ –±—Ä–∞—É–∑–µ—Ä–∞
# headers = {
#     'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36'
# }

# # –û—Ç–∫—Ä—ã–≤–∞–µ–º —Ñ–∞–π–ª –¥–ª—è –∑–∞–ø–∏—Å–∏ –∞–Ω–µ–∫–¥–æ—Ç–æ–≤
# with open('anekdoty.txt', 'w', encoding='utf-8') as file:
#     for page_number in range(2, 400):
#         # –§–æ—Ä–º–∏—Ä—É–µ–º URL –¥–ª—è —Ç–µ–∫—É—â–µ–π —Å—Ç—Ä–∞–Ω–∏—Ü—ã
#         url = base_url.format(page_number)
#         print(f"–°–æ–±–∏—Ä–∞—é —à—É—Ç–∫–∏ —Å–æ —Å—Ç—Ä–∞–Ω–∏—Ü—ã {page_number}...")

#         # –ü–æ–ª—É—á–∞–µ–º —à—É—Ç–∫–∏ —Å —Ç–µ–∫—É—â–µ–π —Å—Ç—Ä–∞–Ω–∏—Ü—ã
#         jokes = get_jokes_from_page(url)
        
#         # –ï—Å–ª–∏ —à—É—Ç–æ–∫ –Ω–µ—Ç, –∑–Ω–∞—á–∏—Ç, —Å—Ç—Ä–∞–Ω–∏—Ü—ã –∑–∞–∫–æ–Ω—á–∏–ª–∏—Å—å (–æ–ø—Ü–∏–æ–Ω–∞–ª—å–Ω–æ)
#         if not jokes:
#             print(f"–®—É—Ç–∫–∏ –Ω–∞ —Å—Ç—Ä–∞–Ω–∏—Ü–µ {page_number} –Ω–µ –Ω–∞–π–¥–µ–Ω—ã.")
#             continue
        
#         # –ó–∞–ø–∏—Å—ã–≤–∞–µ–º —à—É—Ç–∫–∏ –≤ —Ñ–∞–π–ª
#         for joke in jokes:
#             file.write(joke + '\n')

# print("–ê–Ω–µ–∫–¥–æ—Ç—ã —É—Å–ø–µ—à–Ω–æ —Å–æ—Ö—Ä–∞–Ω–µ–Ω—ã –≤ —Ñ–∞–π–ª 'anekdoty.txt'.")