|
from transformers import PreTrainedModel
|
|
from typing import Optional
|
|
import torch
|
|
|
|
class MinerUModel(PreTrainedModel):
    """``PreTrainedModel`` wrapper around the MinerU sub-model bundle.

    The sub-models are produced by ``model_loader.MinerUModelLoader.load_models``
    and stored in ``self.models``; ``forward`` dispatches to the entry keyed
    ``"layout"``.
    """

    def __init__(self, config, *, model_path: str = "./"):
        """Initialise the wrapper and load the sub-models exactly once.

        Args:
            config: Configuration forwarded to ``PreTrainedModel.__init__``,
                which validates it and stores it as ``self.config``.
            model_path: Location passed to ``MinerUModelLoader.load_models``.
                Defaults to ``"./"``, the path the loader previously had
                hard-coded, so existing ``MinerUModel(config)`` calls behave
                as before.
        """
        super().__init__(config)  # also sets self.config
        self._setup_models(model_path)

    def _setup_models(self, model_path: str = "./"):
        """Load the sub-models from *model_path* into ``self.models``."""
        # Local import kept from the original; presumably to defer loading the
        # project-local loader until a model is actually built — TODO confirm.
        from model_loader import MinerUModelLoader

        self.models = MinerUModelLoader.load_models(model_path)

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
        """Build a model whose sub-models are loaded from *pretrained_model_name_or_path*.

        Unlike ``PreTrainedModel.from_pretrained``, this does not read weights
        or a config file itself; loading is delegated to ``MinerUModelLoader``.

        Args:
            pretrained_model_name_or_path: Location handed to the sub-model loader.
            *model_args: Accepted for signature compatibility; ignored.
            **kwargs: Only ``config`` is consumed. Any other keyword is ignored.

        Returns:
            A new ``MinerUModel`` instance.
        """
        config = kwargs.pop("config", None)
        if config is None:
            # PreTrainedModel.__init__ rejects anything that is not a
            # PretrainedConfig, so a missing config must fall back to a default.
            from transformers import PretrainedConfig

            config = PretrainedConfig()
        # __init__ already loads the sub-models; loading them again here would
        # double the (expensive) load for no benefit.
        return cls(config, model_path=pretrained_model_name_or_path)

    def forward(self, input_data):
        """Run the ``"layout"`` sub-model on *input_data* and return its output."""
        return self.models["layout"](input_data)
|
|
|
|
def load_model():
    """Build and return a ``MinerUModel`` from the current directory (``"./"``)."""
    return MinerUModel.from_pretrained("./")
|
|
|
|
def inference(pdf_content, model=None):
    """Run the MinerU model on *pdf_content* and return its output.

    Args:
        pdf_content: Input forwarded unchanged to the model. The accepted type
            depends on the loaded ``"layout"`` sub-model, which is not visible
            here.
        model: Optional already-loaded model (any callable). When ``None``
            (the default), a new model is built with ``load_model()``, which
            reloads every sub-model. Callers processing many documents should
            load once and pass the model in.

    Returns:
        Whatever the model returns for *pdf_content*.
    """
    if model is None:
        model = load_model()
    # Pure inference: no autograd graph is needed, which saves memory.
    with torch.no_grad():
        return model(pdf_content)