# OCR Translate v0.1
# 创建人:曾逸夫
# 创建时间:2022-06-14
# email: zyfiy1314@163.com
import os
import gradio as gr
import nltk
import pytesseract
from nltk.tokenize import sent_tokenize
from transformers import MarianMTModel, MarianTokenizer
# Fetch the NLTK "punkt" sentence-tokenizer data needed by sent_tokenize() in translate().
# NOTE(review): this runs at import time on every start-up; nltk.download presumably
# skips the download when the data is already up to date — confirm.
nltk.download('punkt')
# ----------- Translation -----------
# https://huggingface.co/Helsinki-NLP/opus-mt-en-zh
modchoice = "Helsinki-NLP/opus-mt-en-zh"  # model name (English -> Chinese MarianMT checkpoint)
tokenizer = MarianTokenizer.from_pretrained(modchoice)  # tokenizer, shared by translate()
model = MarianMTModel.from_pretrained(modchoice)  # translation model, shared by translate()
OCR_TR_DESCRIPTION = '''# OCR Translate v0.1
基于Tesseract的OCR翻译系统
'''
# Directory containing the example images offered in the UI
img_dir = "./data"
# Languages available to the local tesseract binary, used as the checkbox choices.
# The output is split into lines; [1:-1] drops the first line (tesseract's header,
# e.g. "List of available languages ...") and the empty string left after the final
# newline. If tesseract is not installed the output is "" and this evaluates to [].
# The shell command is a fixed string, so no user input reaches the shell.
choices = os.popen('tesseract --list-langs').read().split('\n')[1:-1]
# tesseract语言列表转pytesseract语言
def ocr_lang(lang_list):
    """Join tesseract language codes into the ``lang`` string pytesseract expects.

    Several languages are combined with ``+`` (as for tesseract's ``-l``
    option), e.g. ``["eng", "chi_sim"]`` -> ``"eng+chi_sim"``.

    Args:
        lang_list: iterable of language codes (as listed by
            ``tesseract --list-langs``). It is not modified.

    Returns:
        The ``+``-joined string. A single code is returned unchanged and an
        empty selection yields ``""``.
    """
    # str.join leaves the caller's list untouched; inserting "+" separators
    # into the list itself would make repeated calls on the same list
    # accumulate separators.
    return "+".join(lang_list)
# ocr tesseract
def ocr_tesseract(img, languages):
    """Run Tesseract OCR on an image and return the recognised text.

    Args:
        img: the value of the ``gr.Image(type="pil")`` input, or ``None``
            when no image is loaded (e.g. after pressing "Clear").
        languages: language codes ticked in the checkbox group.

    Returns:
        The text produced by ``pytesseract.image_to_string``, or ``""`` when
        no image is given.
    """
    if img is None:
        # Nothing to recognise; pytesseract rejects a None image with an error.
        return ""
    # With no language ticked, ocr_lang() would give "" and tesseract would be
    # called with an empty "-l" value. lang=None (pytesseract's default) omits
    # the option so tesseract falls back to its default language.
    lang = ocr_lang(languages) if languages else None
    return pytesseract.image_to_string(img, lang=lang)
# 示例
def set_example_image(example: list) -> dict:
    """Load the clicked example into the image input component.

    Args:
        example: the clicked Dataset sample row; in this app each row holds a
            single entry, the example image's path.

    Returns:
        A ``gr.Image`` update whose value is that first entry.
    """
    image_path = example[0]
    return gr.Image.update(value=image_path)
# 清除
def clear_content():
    """Return ``None`` as the new value of the bound output component.

    Wired to both "Clear" buttons to empty the image input and the
    extracted-text box.
    """
    cleared_value = None
    return cleared_value
# 翻译
def translate(input_text):
    """Translate English text to Chinese with the module-level MarianMT model.

    The text is split into sentences with NLTK, translated as one padded
    batch, and the translated sentences are concatenated.

    Args:
        input_text: text to translate (the OCR result box), possibly ``None``.

    Returns:
        The translated text, or a system prompt when there is nothing to
        translate.
    """
    # Reference: https://huggingface.co/docs/transformers/model_doc/marian
    # Whitespace-only input counts as empty too. Tesseract output for a blank
    # image may be just whitespace / a form feed. sent_tokenize() would turn
    # that into an empty sentence list, which the tokenizer/model presumably
    # cannot process as a batch.
    if input_text is None or not input_text.strip():
        return "系统提示:没有可翻译的内容!"
    sentences = sent_tokenize(input_text)
    batch = tokenizer(sentences, return_tensors="pt", padding=True)
    translated = model.generate(**batch)
    tgt_text = tokenizer.batch_decode(translated, skip_special_tokens=True)
    # Chinese output needs no separator between sentences.
    return "".join(tgt_text)
# Filename suffixes accepted as example images (matched case-insensitively).
_EXAMPLE_IMAGE_SUFFIXES = (".png", ".jpg", ".jpeg", ".bmp", ".gif", ".tif", ".tiff", ".webp")


def _list_example_images(directory):
    """Return the sorted paths of the image files directly inside ``directory``.

    Hidden files (``.DS_Store``, ``.gitkeep``, ...) and files without an image
    suffix are skipped, since gr.Image cannot show them as examples. A missing
    directory yields ``[]`` instead of aborting start-up.
    """
    if not os.path.isdir(directory):
        return []
    return [
        os.path.join(directory, name)
        for name in sorted(os.listdir(directory))
        if not name.startswith(".") and name.lower().endswith(_EXAMPLE_IMAGE_SUFFIXES)
    ]


def main():
    """Build the two-step Gradio UI (OCR extraction, then translation) and launch it.

    Step 01: upload an image (or click an example from ``img_dir``), pick the
    Tesseract languages and extract the text. Step 02: translate the extracted
    text. The app is opened in the browser.
    """
    with gr.Blocks(css='style.css') as ocr_tr:
        gr.Markdown(OCR_TR_DESCRIPTION)
        # -------------- Step 01: OCR text extraction --------------
        with gr.Box():
            with gr.Row():
                gr.Markdown("### Step 01: 文字提取")
            with gr.Row():
                with gr.Column():
                    with gr.Row():
                        inputs_img = gr.Image(image_mode="RGB", source="upload", type="pil", label="图片")
                    with gr.Row():
                        inputs_lang = gr.CheckboxGroup(choices=choices, type="value", value=['eng'], label='语言')
                    with gr.Row():
                        clear_img_btn = gr.Button('Clear')
                        ocr_btn = gr.Button(value='OCR 提取', variant="primary")
                with gr.Column():
                    # One single-entry sample row per example image file.
                    example_images = gr.Dataset(components=[inputs_img],
                                                samples=[[path] for path in _list_example_images(img_dir)])
        # -------------- Step 02: translation --------------
        with gr.Box():
            with gr.Row():
                gr.Markdown("### Step 02: 翻译")
            with gr.Row():
                with gr.Column():
                    with gr.Row():
                        outputs_text = gr.Textbox(label="提取内容", lines=20)
                    with gr.Row():
                        clear_text_btn = gr.Button('Clear')
                        translate_btn = gr.Button(value='翻译', variant="primary")
                with gr.Column():
                    outputs_tr_text = gr.Textbox(label="翻译内容", lines=20)
        # ---------------------- OCR Tesseract ----------------------
        ocr_btn.click(fn=ocr_tesseract, inputs=[inputs_img, inputs_lang], outputs=[outputs_text])
        clear_img_btn.click(fn=clear_content, inputs=[], outputs=[inputs_img])
        # Clicking an example loads its path into the image input.
        example_images.click(fn=set_example_image, inputs=[example_images], outputs=[inputs_img])
        # ---------------------- Translation ----------------------
        translate_btn.click(fn=translate, inputs=[outputs_text], outputs=[outputs_tr_text])
        clear_text_btn.click(fn=clear_content, inputs=[], outputs=[outputs_text])
    ocr_tr.launch(inbrowser=True)


if __name__ == '__main__':
    main()