Source code for zoo.pipeline.api.net.graph_net

#
# Copyright 2018 Analytics Zoo Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import sys

from bigdl.nn.layer import Model as BModel
from zoo.feature.image import ImageSet
from zoo.feature.text import TextSet
from zoo.pipeline.api.keras.base import ZooKerasLayer
from zoo.pipeline.api.keras.utils import *
from bigdl.nn.layer import Layer
from zoo.common.utils import callZooFunc

if sys.version >= '3':
    long = int
    unicode = str


[docs]class GraphNet(BModel): def __init__(self, input, output, jvalue=None, bigdl_type="float", **kwargs): super(BModel, self).__init__(jvalue, to_list(input), to_list(output), bigdl_type, **kwargs)
[docs] def predict(self, x, batch_per_thread=4, distributed=True): """ Use a model to do prediction. # Arguments x: Prediction data. A Numpy array or RDD of Sample or ImageSet. batch_per_thread: The default value is 4. When distributed is True,the total batch size is batch_per_thread * rdd.getNumPartitions. When distributed is False the total batch size is batch_per_thread * numOfCores. distributed: Boolean. Whether to do prediction in distributed mode or local mode. Default is True. In local mode, x must be a Numpy array. """ if isinstance(x, ImageSet) or isinstance(x, TextSet): results = callZooFunc(self.bigdl_type, "zooPredict", self.value, x, batch_per_thread) return ImageSet(results) if isinstance(x, ImageSet) else TextSet(results) if distributed: if isinstance(x, np.ndarray): data_rdd = to_sample_rdd(x, np.zeros([x.shape[0]])) elif isinstance(x, RDD): data_rdd = x else: raise TypeError("Unsupported prediction data type: %s" % type(x)) results = callZooFunc(self.bigdl_type, "zooPredict", self.value, data_rdd, batch_per_thread) return results.map(lambda result: Layer.convert_output(result)) else: if isinstance(x, np.ndarray) or isinstance(x, list): results = callZooFunc(self.bigdl_type, "zooPredict", self.value, self._to_jtensors(x), batch_per_thread) return [Layer.convert_output(result) for result in results] else: raise TypeError("Unsupported prediction data type: %s" % type(x))
[docs] def flattened_layers(self, include_container=False): jlayers = callZooFunc(self.bigdl_type, "getFlattenSubModules", self, include_container) layers = [Layer.of(jlayer) for jlayer in jlayers] return layers
@property def layers(self): jlayers = callZooFunc(self.bigdl_type, "getSubModules", self) layers = [Layer.of(jlayer) for jlayer in jlayers] return layers
[docs] @staticmethod def from_jvalue(jvalue, bigdl_type="float"): """ Create a Python Model base on the given java value :param jvalue: Java object create by Py4j :return: A Python Model """ model = GraphNet([], [], jvalue=jvalue, bigdl_type=bigdl_type) model.value = jvalue return model
[docs] def new_graph(self, outputs): """ Specify a list of nodes as output and return a new graph using the existing nodes :param outputs: A list of nodes specified :return: A graph model """ value = callZooFunc(self.bigdl_type, "newGraph", self.value, outputs) return self.from_jvalue(value, self.bigdl_type)
[docs] def freeze_up_to(self, names): """ Freeze the model from the bottom up to the layers specified by names (inclusive). This is useful for finetuning a model :param names: A list of module names to be Freezed :return: current graph model """ callZooFunc(self.bigdl_type, "freezeUpTo", self.value, names)
[docs] def unfreeze(self, names=None): """ "unfreeze" module, i.e. make the module parameters(weight/bias, if exists) to be trained(updated) in training process. If 'names' is a non-empty list, unfreeze layers that match given names :param names: list of module names to be unFreezed. Default is None. :return: current graph model """ callZooFunc(self.bigdl_type, "unFreeze", self.value, names)
[docs] def to_keras(self): value = callZooFunc(self.bigdl_type, "netToKeras", self.value) return ZooKerasLayer.of(value, self.bigdl_type)