# Source code for zoo.orca.learn.pytorch.torch_runner

#
# Copyright 2018 Analytics Zoo Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

# Copyright 2017 The Ray Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#  http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# This file is adapted from
# https://github.com/ray-project/ray/blob/master/python/ray/util/sgd/torch/torch_runner.py

from filelock import FileLock
import logging
import inspect
import io
import itertools
import os
import tempfile
import torch
import torch.nn as nn

import ray
from zoo.orca import OrcaContext
from zoo.orca.learn.pytorch.constants import SCHEDULER_STEP, NUM_STEPS
from zoo.orca.learn.pytorch.training_operator import TrainingOperator
from zoo.orca.learn.pytorch import utils

# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Prefer `collections.abc.Iterable`; fall back to the legacy location in
# `collections` on interpreters where the `abc` submodule import fails.
try:
    from collections.abc import Iterable
except ImportError:
    from collections import Iterable


class TorchRunner:
    """Manages a PyTorch model for training.

    The runner owns the models, optimizers, optional loss and optional
    learning-rate schedulers produced by the user-supplied creator
    functions, and drives them through a training operator.

    Args:
        model_creator: Callable invoked as ``model_creator(config)``;
            returns one ``nn.Module`` or an iterable of them.
        optimizer_creator: Callable invoked as
            ``optimizer_creator(models, config)`` where ``models`` is the
            single model or the collection of models (see ``given_models``).
        loss_creator: Optional. Either a subclass of
            ``torch.nn.modules.loss._Loss`` (instantiated with no
            arguments) or a callable invoked as ``loss_creator(config)``.
        scheduler_creator: Optional callable invoked as
            ``scheduler_creator(optimizers, config)``.
        training_operator_cls: Class used to build the training operator;
            defaults to ``TrainingOperator``.
        config: Dict handed to every creator; defaults to an empty dict.
        use_tqdm: Forwarded to the training operator.
        scheduler_step_freq: Stored in the ``info`` dict under
            ``SCHEDULER_STEP`` for every training epoch.
    """

    def __init__(self,
                 model_creator,
                 optimizer_creator,
                 loss_creator=None,
                 scheduler_creator=None,
                 training_operator_cls=None,
                 config=None,
                 use_tqdm=False,
                 scheduler_step_freq=None):
        self.model_creator = model_creator
        self.optimizer_creator = optimizer_creator
        self.loss_creator = loss_creator
        self.scheduler_creator = scheduler_creator
        self.training_operator_cls = training_operator_cls or TrainingOperator
        self.config = {} if config is None else config

        self.timers = utils.TimerCollection()
        self.epochs = 0
        self.models = None
        self.optimizers = None
        self.criterion = None
        self.schedulers = None
        self.train_loader = None
        self.validation_loader = None
        self.training_operator = None
        self.use_tqdm = use_tqdm
        self.scheduler_step_freq = scheduler_step_freq

    @staticmethod
    def _as_component_list(components):
        """Wrap a lone component in a list; return collections unchanged.

        A single ``nn.Module`` is always wrapped, even when it is itself
        iterable (``nn.Sequential`` and ``nn.ModuleList`` define
        ``__iter__``). Otherwise such a model would be mistaken for a
        collection of independent models made of its own layers.
        """
        if isinstance(components, nn.Module):
            return [components]
        if not isinstance(components, Iterable):
            return [components]
        return components

    def _create_loader(self, data_creator):
        """Call ``data_creator(self.config)``, under a file lock if requested.

        When ``OrcaContext.serialize_data_creation`` is truthy the call is
        wrapped in a ``FileLock`` on ``.orcadata.lock`` in the temp directory.
        """
        if OrcaContext.serialize_data_creation:
            lock_path = os.path.join(tempfile.gettempdir(), ".orcadata.lock")
            with FileLock(lock_path):
                return data_creator(self.config)
        return data_creator(self.config)

    def _create_loss(self):
        """Build ``self.criterion`` from ``loss_creator`` when one was given."""
        if not self.loss_creator:
            return
        logger.debug("Creating loss.")
        is_loss_class = inspect.isclass(self.loss_creator) and issubclass(
            self.loss_creator, torch.nn.modules.loss._Loss)
        if is_loss_class:
            # A torch loss class is instantiated directly, without config.
            self.criterion = self.loss_creator()
        else:
            self.criterion = self.loss_creator(self.config)

    def _create_schedulers_if_available(self):
        """Build ``self.schedulers`` when a scheduler creator was given."""
        # Learning rate schedules are optional.
        if not self.scheduler_creator:
            return
        schedulers = self.scheduler_creator(self.given_optimizers, self.config)
        self.schedulers = self._as_component_list(schedulers)

    def setup_components(self):
        """Runs the creator functions without any distributed coordination.

        Populates ``self.models``, ``self.optimizers`` and, when the
        matching creators were supplied, ``self.schedulers`` and
        ``self.criterion``.

        Raises:
            AssertionError: If any created model is not an ``nn.Module``.
        """
        logger.debug("Creating model")
        self.models = self._as_component_list(self.model_creator(self.config))
        assert all(isinstance(model, nn.Module) for model in self.models), (
            "All models must be PyTorch models: {}.".format(self.models))

        logger.debug("Creating optimizer.")
        self.optimizers = self._as_component_list(
            self.optimizer_creator(self.given_models, self.config))

        self._create_schedulers_if_available()
        self._create_loss()

    def setup_operator(self):
        """Create the training operator from the current components."""
        self.training_operator = self.training_operator_cls(
            self.config,
            models=self.models,
            optimizers=self.optimizers,
            criterion=self.criterion,
            train_loader=self.train_loader,
            validation_loader=self.validation_loader,
            world_rank=0,
            schedulers=self.schedulers,
            use_tqdm=self.use_tqdm)

    def get_node_ip(self):
        """Returns the IP address of the current node."""
        return ray.services.get_node_ip_address()

    def find_free_port(self):
        """Finds a free port on the current node."""
        return utils.find_free_port()

    def with_sampler(self, loader):
        """Hook for subclasses to attach a sampler to ``loader``.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError(
            "Please implement with_sampler in the subclass of TorchRunner")

    @staticmethod
    def should_wrap_dataloader(loader):
        """Return True for a ``DataLoader`` over a non-iterable-style dataset."""
        from torch.utils.data import DataLoader, IterableDataset
        return (isinstance(loader, DataLoader)
                and not isinstance(loader.dataset, IterableDataset))

    def train_epochs(self, data_creator, epochs=1, profile=False, info=None):
        """Create the training loader once and run ``epochs`` epochs on it.

        Args:
            data_creator: Callable invoked as ``data_creator(config)``.
            epochs: Number of calls to ``train_epoch``.
            profile: Forwarded to ``train_epoch``.
            info: Forwarded to ``train_epoch``.

        Returns:
            A list with the stats dict of each epoch, in order.
        """
        loader = self._create_loader(data_creator)
        if TorchRunner.should_wrap_dataloader(loader):
            loader = self.with_sampler(loader)
        return [self.train_epoch(loader, profile=profile, info=info)
                for _ in range(epochs)]

    def train_epoch(self, data_loader, profile=False, info=None):
        """Runs a training epoch and updates the model parameters.

        Args:
            data_loader: Iterable of batches; ``iter()`` is called on it.
            profile: When truthy, timers are enabled and their stats are
                added to the result under ``"profile"``.
            info: Optional dict passed to the operator. The
                ``SCHEDULER_STEP`` entry is written into it in place.

        Returns:
            A dict with ``epoch`` first, followed by the operator's stats.
        """
        logger.debug("Begin Training Step %d", self.epochs + 1)
        info = info or {}
        self._toggle_profiling(profile=profile)

        info.update({
            SCHEDULER_STEP: self.scheduler_step_freq
        })
        with self.timers.record("train_epoch"):
            data_loader = iter(data_loader)
            train_stats = self.training_operator.train_epoch(data_loader, info)

        self.epochs += 1
        # This is so that `epochs` is first in ordering.
        stats = dict(epoch=self.epochs, **train_stats)
        if profile:
            stats.update(profile=self.timers.stats())
        return stats

    def validate(self, data_creator, num_steps=None, profile=False, info=None):
        """Evaluates the model on the validation data set.

        Args:
            data_creator: Callable invoked as ``data_creator(config)``.
            num_steps: When truthy, at most this many items are taken from
                the loader.
            profile: When truthy, timer stats are added to the result under
                ``"profile"``.
            info: Optional dict forwarded to the operator.

        Returns:
            The stats returned by the operator's ``validate``.
        """
        info = info or {}
        self._toggle_profiling(profile=profile)
        loader = self._create_loader(data_creator)

        with self.timers.record("validation"):
            if TorchRunner.should_wrap_dataloader(loader):
                loader = iter(self.with_sampler(loader))
            if num_steps:
                loader = itertools.islice(loader, num_steps)
            validation_stats = self.training_operator.validate(
                loader, info=info)

        if profile:
            validation_stats.update(profile=self.timers.stats())
        return validation_stats

    def _toggle_profiling(self, profile=False):
        """Enables/Disables and resets timing profiles."""
        if profile:
            self.timers.enable()
            self.timers.reset()
        else:
            self.timers.disable()
        self.training_operator._set_timers(self.timers)

    def state_dict(self):
        """Returns the state of the runner.

        The dict holds ``epoch``, ``operator``, ``models`` and
        ``optimizers``; ``schedulers`` is included only when schedulers
        exist.
        """
        state = {
            "epoch": self.epochs,
            "operator": self.training_operator.state_dict(),
            "models": [model.state_dict() for model in self.models],
            "optimizers": [opt.state_dict() for opt in self.optimizers],
        }
        if self.schedulers:
            state["schedulers"] = [
                scheduler.state_dict() for scheduler in self.schedulers
            ]
        return state

    def load_state_dict(self, state):
        """Restores the runner from a dict produced by ``state_dict``.

        Args:
            state: Dict with the keys ``models``, ``optimizers``, ``epoch``
                and ``operator``, plus ``schedulers`` when this runner has
                schedulers.
        """
        for model, model_state in zip(self.models, state["models"]):
            model.load_state_dict(model_state)
        for optimizer, opt_state in zip(self.optimizers, state["optimizers"]):
            optimizer.load_state_dict(opt_state)
        if self.schedulers:
            for scheduler, sched_state in zip(self.schedulers,
                                              state["schedulers"]):
                scheduler.load_state_dict(sched_state)

        self.epochs = state["epoch"]
        # Restore the operator from its own entry, not from whichever
        # component state the loops above happened to visit last.
        self.training_operator.load_state_dict(state["operator"])

    def state_stream(self):
        """Returns a bytes object for the state dict."""
        _buffer = io.BytesIO()
        torch.save(self.state_dict(), _buffer)
        return _buffer.getvalue()

    def load_state_stream(self, byte_obj):
        """Loads the training state dict from a bytes object."""
        _buffer = io.BytesIO(byte_obj)
        # NOTE(review): torch.load unpickles its input; only pass bytes that
        # come from a trusted source such as `state_stream`.
        state_dict = torch.load(_buffer)
        return self.load_state_dict(state_dict)

    def apply(self, fn):
        """Call ``fn`` with no arguments and return its result."""
        return fn()

    def apply_operator(self, fn):
        """Call ``fn`` with the training operator and return its result."""
        return fn(self.training_operator)

    def shutdown(self):
        """Attempts to shut down the worker."""
        del self.training_operator
        del self.validation_loader
        del self.train_loader
        del self.criterion
        del self.optimizers
        del self.models
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    @property
    def given_models(self):
        """The collection of models, or the sole model when there is one."""
        if len(self.models) > 1:
            return self.models
        return self.models[0]

    @property
    def given_optimizers(self):
        """The collection of optimizers, or the sole one when there is one."""
        if len(self.optimizers) > 1:
            return self.optimizers
        return self.optimizers[0]

    @property
    def given_schedulers(self):
        """The schedulers, the sole scheduler, or the falsy value if unset."""
        if not self.schedulers:
            return self.schedulers
        if len(self.schedulers) > 1:
            return self.schedulers
        return self.schedulers[0]