Source code for zoo.orca.learn.pytorch.estimator

#
# Copyright 2018 Analytics Zoo Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from zoo.pipeline.estimator.estimator import Estimator as SparkEstimator
from zoo.orca.learn.pytorch.training_operator import TrainingOperator
from zoo.orca.data import SparkXShards
from bigdl.optim.optimizer import MaxEpoch
from zoo.feature.common import FeatureSet

import torch
from torch.optim.optimizer import Optimizer as TorchOptimizer
from torch.utils.data import DataLoader


class Estimator(object):
    """Common interface of the Orca PyTorch estimators.

    All instance methods defined here are empty placeholders that return
    ``None``; the backend-specific wrappers override them. Use
    :meth:`from_torch` to build a concrete estimator.
    """

    def fit(self, data, epochs, **kwargs):
        """Placeholder: does nothing and returns ``None``."""
        return None

    def predict(self, data, **kwargs):
        """Placeholder: does nothing and returns ``None``."""
        return None

    def evaluate(self, data, **kwargs):
        """Placeholder: does nothing and returns ``None``."""
        return None

    def get_model(self):
        """Placeholder: does nothing and returns ``None``."""
        return None

    def save(self, checkpoint):
        """Placeholder: does nothing and returns ``None``."""
        return None

    def load(self, checkpoint):
        """Placeholder: does nothing and returns ``None``."""
        return None

    def shutdown(self, force=False):
        """Placeholder: does nothing and returns ``None``."""
        return None

    @staticmethod
    def from_torch(*,
                   model,
                   optimizer,
                   loss=None,
                   scheduler_creator=None,
                   training_operator_cls=TrainingOperator,
                   initialization_hook=None,
                   config=None,
                   scheduler_step_freq="batch",
                   use_tqdm=False,
                   workers_per_node=1,
                   model_dir=None,
                   backend="horovod"):
        """
        Build the estimator wrapper that matches ``backend``.

        All arguments are keyword-only.

        :param model: forwarded as ``model_creator`` (horovod) or ``model`` (bigdl).
        :param optimizer: forwarded as ``optimizer_creator`` (horovod) or ``optimizer`` (bigdl).
        :param loss: forwarded as ``loss_creator`` (horovod) or ``loss`` (bigdl).
        :param scheduler_creator: forwarded to the horovod wrapper only.
        :param training_operator_cls: forwarded to the horovod wrapper only.
        :param initialization_hook: forwarded to the horovod wrapper only.
        :param config: forwarded to the horovod wrapper only.
        :param scheduler_step_freq: forwarded to the horovod wrapper only.
        :param use_tqdm: forwarded to the horovod wrapper only.
        :param workers_per_node: forwarded to the horovod wrapper only.
        :param model_dir: forwarded to the bigdl wrapper only.
        :param backend: ``"horovod"`` or ``"bigdl"``.
        :return: a ``PyTorchHorovodEstimatorWrapper`` for ``"horovod"``, a
                 ``PytorchSparkEstimatorWrapper`` for ``"bigdl"``.
        :raises ValueError: if ``backend`` is neither of the two supported values.
        """
        if backend == "horovod":
            horovod_settings = dict(
                model_creator=model,
                optimizer_creator=optimizer,
                loss_creator=loss,
                scheduler_creator=scheduler_creator,
                training_operator_cls=training_operator_cls,
                initialization_hook=initialization_hook,
                config=config,
                scheduler_step_freq=scheduler_step_freq,
                use_tqdm=use_tqdm,
                workers_per_node=workers_per_node,
            )
            return PyTorchHorovodEstimatorWrapper(**horovod_settings)

        if backend == "bigdl":
            # The bigdl wrapper is always created with the "float" bigdl_type.
            return PytorchSparkEstimatorWrapper(model=model,
                                                loss=loss,
                                                optimizer=optimizer,
                                                model_dir=model_dir,
                                                bigdl_type="float")

        raise ValueError("only horovod and bigdl backend are supported for now")
class PyTorchHorovodEstimatorWrapper(Estimator):
    """Estimator that delegates every operation to a ``PyTorchHorovodEstimator``."""

    def __init__(self,
                 *,
                 model_creator,
                 optimizer_creator,
                 loss_creator=None,
                 scheduler_creator=None,
                 training_operator_cls=TrainingOperator,
                 initialization_hook=None,
                 config=None,
                 scheduler_step_freq="batch",
                 use_tqdm=False,
                 workers_per_node=1):
        """Create the underlying ``PyTorchHorovodEstimator`` from the given settings.

        All arguments are keyword-only and are passed through unchanged.
        """
        # Imported here rather than at module level, as in the original code.
        from zoo.orca.learn.pytorch.pytorch_horovod_estimator import PyTorchHorovodEstimator

        settings = dict(
            model_creator=model_creator,
            optimizer_creator=optimizer_creator,
            loss_creator=loss_creator,
            scheduler_creator=scheduler_creator,
            training_operator_cls=training_operator_cls,
            initialization_hook=initialization_hook,
            config=config,
            scheduler_step_freq=scheduler_step_freq,
            use_tqdm=use_tqdm,
            workers_per_node=workers_per_node,
        )
        self.estimator = PyTorchHorovodEstimator(**settings)

    def fit(self, data, epochs=1, profile=False, reduce_results=True, info=None):
        """
        Train the model by delegating to the underlying estimator's ``train``.

        :param data: (callable) a function that takes a config dict as input and returns a
               data loader containing the training data.
        :param epochs: (int) Number of epochs to train the model.
        :param profile: (bool) Returns time stats for the training procedure.
        :param reduce_results: (bool) Whether to average all metrics across all workers into
               one dict. If a metric is a non-numerical value (or nested dictionaries), one
               value will be randomly selected among the workers. If False, returns a list
               of dicts.
        :param info: (dict) Optional dictionary passed to the training operator for
               ``train_epoch`` and ``train_batch``.
        :return: (list) A list of stats whose length will be equal to ``epochs``. Each stats
                 entry is a dictionary of training metrics. Custom metrics can be provided by
                 passing in a custom ``training_operator_cls``. If ``reduce_results=False``,
                 this will return a list of metric dictionaries whose length will be equal
                 to ``num_workers``.
        """
        train_stats = self.estimator.train(data_creator=data,
                                           epochs=epochs,
                                           profile=profile,
                                           reduce_results=reduce_results,
                                           info=info)
        return train_stats

    def predict(self, data, **kwargs):
        """Placeholder: does nothing and returns ``None``."""
        return None

    def evaluate(self, data, num_steps=None, profile=False, info=None):
        """
        Evaluate the model by delegating to the underlying estimator's ``validate``.

        :param data: (callable) a function that takes a config dict as input and returns a
               data loader containing the validation data.
        :param num_steps: (int) Number of batches to compute update steps on. This
               corresponds also to the number of times
               ``TrainingOperator.validate_batch`` is called.
        :param profile: (bool) Returns time stats for the evaluation procedure.
        :param info: (dict) Optional dictionary passed to the training operator for
               ``validate`` and ``validate_batch``.
        :return: A dictionary of metrics for validation. Custom metrics can be provided by
                 passing in a custom ``training_operator_cls``.
        """
        validation_stats = self.estimator.validate(data_creator=data,
                                                   num_steps=num_steps,
                                                   profile=profile,
                                                   info=info)
        return validation_stats

    def get_model(self):
        """Returns the learned model(s)."""
        learned = self.estimator.get_model()
        return learned

    def save(self, checkpoint):
        """
        Saves the Estimator state to the provided checkpoint path.

        :param checkpoint: (str) Path to target checkpoint file.
        :return: whatever the underlying estimator's ``save`` returns.
        """
        outcome = self.estimator.save(checkpoint=checkpoint)
        return outcome

    def load(self, checkpoint):
        """
        Loads the Estimator and all workers from the provided checkpoint.

        :param checkpoint: (str) Path to target checkpoint file.
        :return: whatever the underlying estimator's ``load`` returns.
        """
        outcome = self.estimator.load(checkpoint=checkpoint)
        return outcome

    def shutdown(self, force=False):
        """
        Shuts down workers and releases resources.

        :param force: forwarded unchanged to the underlying estimator's ``shutdown``.
        :return: whatever the underlying estimator's ``shutdown`` returns.
        """
        outcome = self.estimator.shutdown(force=force)
        return outcome
class PytorchSparkEstimatorWrapper(Estimator):
    """Estimator that trains a PyTorch model through the BigDL/Spark ``SparkEstimator``."""

    def __init__(self, model, loss, optimizer, model_dir=None, bigdl_type="float"):
        """
        Wrap the PyTorch objects and create the underlying ``SparkEstimator``.

        :param model: PyTorch model, converted with ``TorchModel.from_pytorch``.
        :param loss: PyTorch loss, converted with ``TorchLoss.from_pytorch``; when ``None``
               a default-constructed ``TorchLoss()`` is used instead.
        :param optimizer: when ``None`` a default BigDL ``SGD()`` is used; a
               ``torch.optim.Optimizer`` instance is converted with
               ``TorchOptim.from_pytorch``; any other value is passed through unchanged
               (presumably an already-BigDL optimizer -- TODO confirm).
        :param model_dir: forwarded unchanged to ``SparkEstimator``.
        :param bigdl_type: forwarded unchanged to ``SparkEstimator``.
        """
        from zoo.pipeline.api.torch import TorchModel, TorchLoss, TorchOptim

        self.loss = TorchLoss() if loss is None else TorchLoss.from_pytorch(loss)

        if optimizer is None:
            from bigdl.optim.optimizer import SGD
            optimizer = SGD()
        elif isinstance(optimizer, TorchOptimizer):
            optimizer = TorchOptim.from_pytorch(optimizer)

        self.model = TorchModel.from_pytorch(model)
        self.estimator = SparkEstimator(self.model, optimizer, model_dir, bigdl_type=bigdl_type)

    def fit(self, data, epochs=1, batch_size=32, validation_data=None, validation_methods=None,
            checkpoint_trigger=None):
        """
        Train the model.

        :param data: training data; a ``SparkXShards``, a PyTorch ``DataLoader`` or a callable
               data creator.
        :param epochs: number of epochs; wrapped in ``MaxEpoch`` as the end trigger.
        :param batch_size: must be greater than 0. It is only forwarded on the
               ``SparkXShards`` path; the DataLoader/callable path calls
               ``train_minibatch`` without it.
        :param validation_data: optional validation data. It must be a ``SparkXShards`` when
               ``data`` is a ``SparkXShards``, and a ``DataLoader`` or a callable otherwise.
        :param validation_methods: converted with ``Metrics.convert_metrics_list``.
        :param checkpoint_trigger: converted with ``Trigger.convert_trigger``.
        :return: ``self``.
        :raises AssertionError: if ``batch_size`` is not positive or ``validation_data`` has
                the wrong type for the chosen path.
        :raises ValueError: if ``data`` is of none of the supported kinds.
        """
        from zoo.orca.data.utils import to_sample
        from zoo.orca.learn.metrics import Metrics
        from zoo.orca.learn.trigger import Trigger

        end_trigger = MaxEpoch(epochs)
        assert batch_size > 0, "batch_size should be greater than 0"
        validation_methods = Metrics.convert_metrics_list(validation_methods)
        checkpoint_trigger = Trigger.convert_trigger(checkpoint_trigger)

        if isinstance(data, SparkXShards):
            train_rdd = data.rdd.flatMap(to_sample)
            train_feature_set = FeatureSet.sample_rdd(train_rdd)
            if validation_data is None:
                val_feature_set = None
            else:
                assert isinstance(validation_data, SparkXShards), \
                    "validation_data should be a SparkXShards"
                val_feature_set = FeatureSet.sample_rdd(validation_data.rdd.flatMap(to_sample))

            self.estimator.train(train_feature_set, self.loss, end_trigger, checkpoint_trigger,
                                 val_feature_set, validation_methods, batch_size)
        elif isinstance(data, DataLoader) or callable(data):
            # NOTE(review): the two empty-string arguments are passed for the training data
            # only; their meaning is defined by FeatureSet.pytorch_dataloader -- confirm.
            train_feature_set = FeatureSet.pytorch_dataloader(data, "", "")
            if validation_data is None:
                val_feature_set = None
            else:
                # Check validation_data itself. The previous code tested callable(data),
                # which let any validation_data through when data was a callable and
                # rejected a callable validation_data when data was a DataLoader.
                assert isinstance(validation_data, DataLoader) or callable(validation_data), \
                    "validation_data should be a pytorch DataLoader or a callable data_creator"
                val_feature_set = FeatureSet.pytorch_dataloader(validation_data)

            self.estimator.train_minibatch(train_feature_set, self.loss, end_trigger,
                                           checkpoint_trigger, val_feature_set,
                                           validation_methods)
        else:
            raise ValueError("Data and validation data should be SparkXShards, DataLoaders or "
                             "callable data_creators but get " + data.__class__.__name__)
        return self

    def predict(self, data, **kwargs):
        """Not implemented for this backend: does nothing and returns ``None``."""
        pass

    def evaluate(self, data, validation_methods=None, batch_size=32):
        """
        Evaluate the model on ``data``.

        :param data: validation data; a ``SparkXShards``, a PyTorch ``DataLoader`` or a
               callable data creator. Must not be ``None``.
        :param validation_methods: converted with ``Metrics.convert_metrics_list``.
        :param batch_size: only forwarded on the ``SparkXShards`` path.
        :return: the result of ``SparkEstimator.evaluate`` (``SparkXShards`` input) or
                 ``SparkEstimator.evaluate_minibatch`` (DataLoader/callable input).
        :raises AssertionError: if ``data`` is ``None``.
        :raises ValueError: if ``data`` is of none of the supported kinds.
        """
        from zoo.orca.data.utils import to_sample
        from zoo.orca.learn.metrics import Metrics

        assert data is not None, "validation data shouldn't be None"
        validation_methods = Metrics.convert_metrics_list(validation_methods)

        if isinstance(data, SparkXShards):
            val_feature_set = FeatureSet.sample_rdd(data.rdd.flatMap(to_sample))
            return self.estimator.evaluate(val_feature_set, validation_methods, batch_size)
        elif isinstance(data, DataLoader) or callable(data):
            val_feature_set = FeatureSet.pytorch_dataloader(data)
            return self.estimator.evaluate_minibatch(val_feature_set, validation_methods)
        else:
            raise ValueError("Data should be a SparkXShards, a DataLoader or a callable "
                             "data_creator, but get " + data.__class__.__name__)

    def get_model(self):
        """Return the wrapped model converted back with ``TorchModel.to_pytorch``."""
        return self.model.to_pytorch()

    def save(self, checkpoint):
        """Not implemented for this backend: does nothing and returns ``None``."""
        pass

    def load(self, checkpoint):
        """Not implemented for this backend: does nothing and returns ``None``."""
        pass

    def shutdown(self, force=False):
        """Not implemented for this backend: does nothing and returns ``None``."""
        pass

    def clear_gradient_clipping(self):
        """
        Clear gradient clipping parameters. In this case, gradient clipping will not be applied.
        In order to take effect, it needs to be called before fit.

        :return:
        """
        self.estimator.clear_gradient_clipping()

    def set_constant_gradient_clipping(self, min, max):
        """
        Set constant gradient clipping during the training process.
        In order to take effect, it needs to be called before fit.

        :param min: The minimum value to clip by.
        :param max: The maximum value to clip by.
        :return:
        """
        self.estimator.set_constant_gradient_clipping(min=min, max=max)

    def set_l2_norm_gradient_clipping(self, clip_norm):
        """
        Clip gradient to a maximum L2-Norm during the training process.
        In order to take effect, it needs to be called before fit.

        :param clip_norm: Gradient L2-Norm threshold.
        :return:
        """
        self.estimator.set_l2_norm_gradient_clipping(clip_norm=clip_norm)