#
# Copyright 2018 Analytics Zoo Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import importlib
import os
import sys
from zoo.common.utils import callZooFunc
from bigdl.nn.layer import Model as BModel
from zoo.pipeline.api.net.graph_net import GraphNet
if sys.version >= '3':
long = int
unicode = str
[docs]class JavaToPython:
# TODO: Add more mapping here as it only support Model and Sequential for now.
def __init__(self, jvalue, bigdl_type="float"):
self.jvaule = jvalue
self.jfullname = callZooFunc(bigdl_type,
"getRealClassNameOfJValue",
jvalue)
[docs] def get_python_class(self):
"""
Redirect the jvalue to the proper python class.
:param jvalue: Java object create by Py4j
:return: A proper Python wrapper which would be a Model, Sequential...
"""
jpackage_name = ".".join(self.jfullname.split(".")[:-1])
pclass_name = self._get_py_name(self.jfullname.split(".")[-1])
base_module = self._load_ppackage_by_jpackage(jpackage_name)
if pclass_name in dir(base_module):
pclass = getattr(base_module, pclass_name)
assert "from_jvalue" in dir(pclass), \
"pclass: {} should implement from_jvalue method".format(pclass)
return pclass
raise Exception("No proper python class for: {}".format(self.jfullname))
def _get_py_name(self, jclass_name):
if jclass_name == "Model":
return "Model"
elif jclass_name == "Sequential":
return "Sequential"
else:
raise Exception("Not supported type: {}".format(jclass_name))
def _load_ppackage_by_jpackage(self, jpackage_name):
if "com.intel.analytics.zoo.pipeline.api.keras.models":
return importlib.import_module('zoo.pipeline.api.keras.models')
raise Exception("Not supported package: {}".format(jpackage_name))
[docs]class Net:
[docs] @staticmethod
def from_jvalue(jvalue, bigdl_type="float"):
pclass = JavaToPython(jvalue).get_python_class()
return getattr(pclass, "from_jvalue")(jvalue, bigdl_type)
[docs] @staticmethod
def load_bigdl(model_path, weight_path=None, bigdl_type="float"):
"""
Load a pre-trained BigDL model.
:param model_path: The path to the pre-trained model.
:param weight_path: The path to the weights of the pre-trained model. Default is None.
:return: A pre-trained model.
"""
jmodel = callZooFunc(bigdl_type, "netLoadBigDL", model_path, weight_path)
return GraphNet.from_jvalue(jmodel)
[docs] @staticmethod
def load(model_path, weight_path=None, bigdl_type="float"):
"""
Load an existing Analytics Zoo model defined in Keras-style(with weights).
:param model_path: The path to load the saved model.
Local file system, HDFS and Amazon S3 are supported.
HDFS path should be like 'hdfs://[host]:[port]/xxx'.
Amazon S3 path should be like 's3a://bucket/xxx'.
:param weight_path: The path for pre-trained weights if any. Default is None.
:return: An Analytics Zoo model.
"""
jmodel = callZooFunc(bigdl_type, "netLoad", model_path, weight_path)
return Net.from_jvalue(jmodel, bigdl_type)
[docs] @staticmethod
def load_torch(path, bigdl_type="float"):
"""
Load a pre-trained Torch model.
:param path: The path containing the pre-trained model.
:return: A pre-trained model.
"""
jmodel = callZooFunc(bigdl_type, "netLoadTorch", path)
return GraphNet.from_jvalue(jmodel, bigdl_type)
[docs] @staticmethod
def load_caffe(def_path, model_path, bigdl_type="float"):
"""
Load a pre-trained Caffe model.
:param def_path: The path containing the caffe model definition.
:param model_path: The path containing the pre-trained caffe model.
:return: A pre-trained model.
"""
jmodel = callZooFunc(bigdl_type, "netLoadCaffe", def_path, model_path)
return GraphNet.from_jvalue(jmodel, bigdl_type)
[docs] @staticmethod
def load_keras(json_path=None, hdf5_path=None, by_name=False):
"""
Load a pre-trained Keras model.
:param json_path: The json path containing the keras model definition. Default is None.
:param hdf5_path: The HDF5 path containing the pre-trained keras model weights
with or without the model architecture. Default is None.
:param by_name: by default the architecture should be unchanged.
If set as True, only layers with the same name will be loaded.
:return: A BigDL model.
"""
return BModel.load_keras(json_path, hdf5_path, by_name)