Spaces:
Runtime error
Runtime error
| import json | |
| import werkzeug | |
| import tensorflow as tf | |
| from config import config, parseArgs, configPDF | |
| from extract_feature import get_img_feat, build_model | |
| from main import setSession, loadWeights, setSavers | |
| from model import MACnet | |
| from preprocess import Preprocesser | |
| import warnings | |
def predict(image, question):
    """Answer a question about an image with a pretrained MAC network.

    Every call re-parses the configuration, rebuilds the CNN feature extractor
    and the MACnet graph from scratch in a fresh default graph, restores the
    trained weights from the checkpoint selected by ``loadWeights``, and runs a
    single evaluation batch.

    Args:
        image: Passed unchanged to ``get_img_feat``; the accepted type is
            defined there (path / array / PIL image?) -- TODO confirm.
        question: Passed unchanged to ``Preprocesser.preprocessData``;
            presumably a single question string -- verify against that helper.

    Returns:
        str: An English sentence built from the model's prediction
        (see ``_format_answer``).
    """
    # `os` is not imported at module level, but is needed to pin visible GPUs.
    import os

    parseArgs()
    configPDF()
    # NOTE(review): "a+" appends a new JSON dump on every call, so this file
    # accumulates concatenated objects rather than one valid JSON document.
    with open(config.configFile(), "a+") as outFile:
        json.dump(vars(config), outFile)

    if config.gpus != "":
        config.gpusNum = len(config.gpus.split(","))
        # Restrict TensorFlow to the requested GPUs before any session exists.
        os.environ["CUDA_VISIBLE_DEVICES"] = config.gpus

    # Start each request from an empty default graph so repeated calls do not
    # keep adding ops to the same graph.
    tf.reset_default_graph()
    tf.logging.set_verbosity(tf.logging.ERROR)

    # Image features from the CNN backbone.
    cnn_model = build_model()
    imageData = get_img_feat(cnn_model, image)

    # Question preprocessing also returns the embeddings and answer dictionary
    # that the MACnet graph is constructed from.
    preprocessor = Preprocesser()
    qData, embeddings, answerDict = preprocessor.preprocessData(question)
    model = MACnet(embeddings, answerDict)

    # Created after the model so it covers all of the model's variables.
    init = tf.global_variables_initializer()
    savers = setSavers(model)
    saver, emaSaver = savers["saver"], savers["emaSaver"]
    sessionConfig = setSession()

    with tf.Session(config=sessionConfig) as sess:
        # Freeze the graph: any op accidentally created from here on raises.
        sess.graph.finalize()
        epoch = loadWeights(sess, saver, init)
        # Restore again through the EMA saver from the same epoch's weights
        # file (presumably the moving-average weights) -- confirm in setSavers.
        emaSaver.restore(sess, config.weightsFile(epoch))
        # Last argument presumably disables training mode -- confirm against
        # MACnet.runBatch.
        evalRes = model.runBatch(sess, qData, imageData, False)

    return _format_answer(evalRes)


def _format_answer(evalRes):
    """Turn a raw MACnet prediction label into an English sentence."""
    if evalRes in ("top", "bottom"):
        return "The caption at the %s side of the object." % evalRes
    if evalRes == "True":
        return "There is at least one title object in this image."
    if evalRes == "False":
        # A negative prediction must not be reported as a positive one.
        return "There is no title object in this image."
    # Any other label is interpolated directly (presumably an object count).
    return "This image contain %s specific object(s)." % evalRes