Spaces:
Sleeping
Sleeping
| from flask import Flask, request | |
| import requests | |
| import os | |
| import re | |
| import textwrap | |
| from transformers import AutoModelForSeq2SeqLM | |
| from transformers import AutoTokenizer | |
| from bart import BartForConditionalGeneration | |
| from langdetect import detect | |
| import subprocess | |
# Hugging Face Hub checkpoint ids. The first pair is the default used by
# get_bot_response; the second is the one it switches to for Vietnamese input.
_DEFAULT_CHECKPOINT = "GuysTrans/bart-base-re-attention-seq-512"
_VN_CHECKPOINT = "GuysTrans/bart-base-vn-re-attention-vn-tokenizer"

# Tokenizers and models are loaded once, at import time, and shared by every
# request handled by this process.
tokenizer = AutoTokenizer.from_pretrained(_DEFAULT_CHECKPOINT)
vn_tokenizer = AutoTokenizer.from_pretrained(_VN_CHECKPOINT)
model = BartForConditionalGeneration.from_pretrained(_DEFAULT_CHECKPOINT)
vn_model = BartForConditionalGeneration.from_pretrained(_VN_CHECKPOINT)
# Greeting boilerplate that post_process replaces with an empty string.
_GREETING_PHRASES = (
    "Hello and Welcome to 'Ask A Doctor' service",
    "Hello,",
    "Hi,",
    "Hello",
    "Hi",
)

# Phrase -> replacement table. post_process applies the entries in insertion
# order, so the greetings above are handled before the service-name rewrites.
map_words = {
    **dict.fromkeys(_GREETING_PHRASES, ""),
    "Ask A Doctor": "MedForum",
    "H C M": "Med Forum",
}
# Marker phrases: post_process drops every "."-separated segment of the model
# output that contains one of these (compared case-insensitively).
# Other candidates (hello, hi, regards, dr., physician, welcome) are currently
# disabled.
word_remove_sentence = ["Welcome to"]
def generate_summary(question, model, tokenizer):
    """Generate an answer for *question* with the given model/tokenizer pair.

    The question is tokenized to exactly 512 positions (padded or truncated),
    moved to the model's device, and decoded with sampling beam search.

    Returns:
        A ``(outputs, output_str)`` tuple: the raw result of
        ``model.generate`` and the list of strings produced by
        ``tokenizer.batch_decode`` with special tokens stripped.
    """
    encoded = tokenizer(
        question,
        padding="max_length",
        truncation=True,
        max_length=512,
        return_tensors="pt",
    )
    token_ids = encoded.input_ids.to(model.device)
    mask = encoded.attention_mask.to(model.device)

    generation_options = {
        "attention_mask": mask,
        "max_new_tokens": 4096,
        "do_sample": True,
        "num_beams": 4,
        "top_k": 50,
        "early_stopping": True,
        "no_repeat_ngram_size": 2,
    }
    generated = model.generate(token_ids, **generation_options)

    decoded = tokenizer.batch_decode(generated, skip_special_tokens=True)
    return generated, decoded
# Flask application object for this service.
app = Flask(__name__)

# Facebook Graph API endpoint that send_message posts replies to.
FB_API_URL = 'https://graph.facebook.com/v2.6/me/messages'

# Token compared against `hub.verify_token` in verify_webhook. The
# VERIFY_TOKEN environment variable takes precedence; the literal is only a
# fallback so deployments that do not set the variable keep working.
# NOTE(review): this secret is committed to source -- rotate it and drop the
# fallback once every deployment provides VERIFY_TOKEN.
VERIFY_TOKEN = os.environ.get(
    'VERIFY_TOKEN', '5rApTs/BRm6jtiwApOpIdjBHe73ifm6mNGZOsYkwwAw=')

# Page access token sent with every Graph API call. It is read from the
# environment, so a missing PAGE_ACCESS_TOKEN raises KeyError at import time.
PAGE_ACCESS_TOKEN = os.environ['PAGE_ACCESS_TOKEN']
def get_bot_response(message):
    """Return the chatbot's formatted reply to *message*.

    The language of the message is detected with ``langdetect``. Vietnamese
    ("vi") input is answered by the Vietnamese model/tokenizer and template;
    everything else, including text whose language cannot be detected, uses
    the default model and the English template.

    Args:
        message: The user's text.

    Returns:
        The post-processed model output wrapped in the greeting template.
    """
    # Imported here so this function does not rely on an additional
    # module-level import.
    from langdetect.lang_detect_exception import LangDetectException

    try:
        lang = detect(message)
    except LangDetectException:
        # langdetect raises when the text has no usable features (empty
        # string, digits or emoji only). Fall back to the default model
        # instead of failing the whole request.
        lang = None

    # The brand is spelled "MedForum" throughout; the templates previously
    # opened with the misspelling "MedForRum".
    if lang == "vi":
        model_use, tokenizer_use = vn_model, vn_tokenizer
        template = "Chào mừng bạn đến với dịch vụ MedForum chatbot. %s. Cảm ơn bạn đã sử dụng MedForum."
    else:
        model_use, tokenizer_use = model, tokenizer
        template = "Welcome to MedForum chatbot service. %s. Thanks for asking on MedForum."

    # generate_summary returns (raw outputs, decoded strings); only the first
    # decoded string is used.
    _, decoded = generate_summary(message, model_use, tokenizer_use)
    return template % post_process(decoded[0])
def verify_webhook(req):
    """Answer a webhook verification request.

    Returns the ``hub.challenge`` query argument when ``hub.verify_token``
    equals VERIFY_TOKEN, and the string ``"incorrect"`` otherwise.
    """
    supplied_token = req.args.get("hub.verify_token")
    if supplied_token != VERIFY_TOKEN:
        return "incorrect"
    return req.args.get("hub.challenge")
def respond(sender, message):
    """Build the bot's reply to *message*, send it to *sender*, and return it."""
    reply = get_bot_response(message)
    send_message(sender, reply)
    return reply
def is_user_message(message):
    """Tell whether a messaging event is a text message sent by the user.

    The result is truthy only when the event has a non-empty ``message`` with
    non-empty ``text`` and is not flagged ``is_echo``. As with an ``and``
    chain, the first falsy value encountered is returned as-is.
    """
    body = message.get('message')
    if not body:
        return body

    text = body.get('text')
    if not text:
        return text

    return not body.get("is_echo")
def listen():
    """Handle requests to the `/webhook` endpoint.

    GET requests are treated as webhook verification and delegated to
    verify_webhook. POST requests carry messaging events: every user text
    message found in the payload is answered through respond, and "ok" is
    returned to acknowledge the delivery. Any other method falls through and
    returns None.
    """
    if request.method == 'GET':
        return verify_webhook(request)

    if request.method == 'POST':
        payload = request.json or {}
        # One delivery can batch several entries, and an entry is not
        # guaranteed to carry a 'messaging' list, so walk all of them rather
        # than assuming payload['entry'][0]['messaging'] exists.
        for entry in payload.get('entry', []):
            for event in entry.get('messaging', []):
                if is_user_message(event):
                    text = event['message']['text']
                    sender_id = event['sender']['id']
                    respond(sender_id, text)
        return "ok"
def send_message(recipient_id, text, timeout=10):
    """Send *text* to a Facebook user through the Graph API.

    Args:
        recipient_id: Id of the user who receives the message.
        text: Message body.
        timeout: Seconds to wait for Facebook before giving up. Without a
            timeout ``requests`` waits indefinitely, which would block the
            webhook handler on a stalled connection.

    Returns:
        The decoded JSON body of the Graph API response.

    Raises:
        requests.exceptions.RequestException: On connection errors or when
            the timeout expires.
    """
    payload = {
        'message': {'text': text},
        'recipient': {'id': recipient_id},
        'notification_type': 'regular',
    }
    auth = {'access_token': PAGE_ACCESS_TOKEN}

    response = requests.post(
        FB_API_URL,
        params=auth,
        json=payload,
        timeout=timeout,
    )
    return response.json()
def chat():
    """Answer a JSON request of the form ``{"message": ...}``.

    Returns a dict with the bot's reply under the ``"message"`` key.
    """
    user_message = request.json['message']
    return {"message": get_bot_response(user_message)}
def post_process(output, replacements=None, drop_markers=None):
    """Clean up raw model output before it is shown to the user.

    Steps:
      1. Split on "." and drop every segment that contains one of the marker
         phrases (case-insensitive).
      2. Replace each phrase of the replacement table, in table order,
         matching case-insensitively and only on whole-phrase boundaries.
      3. Dedent, strip and wrap the text to 120 columns.

    Args:
        output: Text produced by the model.
        replacements: Optional phrase -> replacement mapping; defaults to the
            module-level ``map_words``.
        drop_markers: Optional iterable of marker phrases; defaults to the
            module-level ``word_remove_sentence``.

    Returns:
        The cleaned, wrapped text.
    """
    if replacements is None:
        replacements = map_words
    if drop_markers is None:
        drop_markers = word_remove_sentence

    # Build a new list instead of calling remove() on the list being
    # iterated: removing in place skips the element that follows each removed
    # one, so consecutive matching segments used to survive.
    lowered_markers = [marker.lower() for marker in drop_markers]
    kept_segments = [
        segment
        for segment in output.split(".")
        if not any(marker in segment.lower() for marker in lowered_markers)
    ]
    output = ".".join(kept_segments)

    for phrase, replacement in replacements.items():
        # Phrases are literal text, so escape them. The lookarounds keep short
        # entries such as "Hi" from matching inside words like "this" once
        # matching is case-insensitive.
        pattern = r"(?<!\w)" + re.escape(phrase) + r"(?!\w)"
        # flags must be passed by keyword: the fourth positional argument of
        # re.sub is `count`, so a positional re.I limits the number of
        # replacements to 2 and does not enable case-insensitive matching.
        output = re.sub(
            pattern,
            lambda _match, text=replacement: text,
            output,
            flags=re.I,
        )

    return textwrap.fill(textwrap.dedent(output).strip(), width=120)
# Start an autossh-managed reverse tunnel at import time so that serveo.net
# forwards "guysmedchatt" port 80 to the local server on port 7860. The
# process is launched in the background and not waited on.
# An earlier plain-ssh variant that forwarded to local port 5000 is no longer
# used.
_TUNNEL_COMMAND = [
    "autossh",
    "-M", "0",
    "-tt",
    "-o", "StrictHostKeyChecking=no",
    "-i", "id_rsa",
    "-R", "guysmedchatt:80:localhost:7860",
    "serveo.net",
]
subprocess.Popen(_TUNNEL_COMMAND)