Source code for zoo.common.utils

#
# Copyright 2018 Analytics Zoo Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from bigdl.util.common import Sample as BSample, JTensor as BJTensor,\
    JavaCreator, _get_gateway, _java2py, _py2java
import numpy as np
import os
import tempfile
import uuid
import shutil

from urllib.parse import urlparse


def convert_to_safe_path(input_path, follow_symlinks=True):
    """
    Normalize input_path into an absolute path.

    :param input_path: the path to normalize.
    :param follow_symlinks: when truthy the path goes through
           os.path.realpath (symbolic links are resolved); otherwise it
           only goes through os.path.abspath.
    :return: the normalized path.
    """
    normalize = os.path.realpath if follow_symlinks else os.path.abspath
    return normalize(input_path)
def to_list_of_numpy(elements):
    """
    Normalize elements into a list of numpy ndarrays.

    :param elements: a np.ndarray, a scalar (as judged by np.isscalar), or
           a list whose items are ndarrays and/or scalars.
    :return: a list of ndarrays; ndarrays are passed through as the same
             objects and scalars are wrapped with np.array.
    :raise ValueError: if elements, or any item of a list, is neither an
           ndarray nor a scalar.
    """
    def _as_ndarray(item):
        # ndarrays are kept untouched, scalars become 0-dim arrays.
        if isinstance(item, np.ndarray):
            return item
        if np.isscalar(item):
            return np.array(item)
        raise ValueError("Wrong type: %s" % type(item))

    if isinstance(elements, list):
        return [_as_ndarray(item) for item in elements]
    # A single ndarray/scalar becomes a one-element list; any other
    # non-list container is rejected by the helper.
    return [_as_ndarray(elements)]
def get_file_list(path, recursive=False):
    """
    List paths under path by calling the "listPaths" function through
    callZooFunc.

    :param path: the location to list; forwarded unchanged.
    :param recursive: forwarded unchanged to "listPaths"; defaults to False.
    :return: whatever "listPaths" returns once converted back to Python.
    """
    listing = callZooFunc("float", "listPaths", path, recursive)
    return listing
def is_local_path(path):
    """
    Tell whether path refers to the local file system.

    :param path: a plain path or a URL-style location.
    :return: True when the URL scheme parsed from path is empty or is
             "file" (case-insensitive), False otherwise.
    """
    scheme = urlparse(path).scheme.lower()
    return scheme in ("", "file")
def append_suffix(prefix, path):
    """
    Build a file name made of prefix plus the file extension of path.

    The extension is taken with os.path.splitext, so only a dot in the last
    path component counts. Splitting the whole path on "." (the previous
    approach) appended everything after the last dot, which leaked host
    names and directory segments (including "/") into the result whenever
    the last component had no extension, e.g. "hdfs://host/model" or
    "hdfs://host.example.com/dir/model".

    :param prefix: the base name to extend.
    :param path: the path or URL whose extension should be reused.
    :return: prefix followed by the extension of path (with its leading
             "."), or prefix alone when path has no extension.
    """
    extension = os.path.splitext(path)[1]
    return prefix + extension
def save_file(save_func, path, **kwargs):
    """
    Save something to path, which may be local or remote.

    For a local path (see is_local_path) save_func is called directly on
    path. Otherwise save_func writes to a uniquely named file in the system
    temp directory, that file is uploaded with put_local_file_to_remote,
    and the temp file is removed afterwards.

    :param save_func: callable invoked as save_func(target_path, **kwargs).
    :param path: the destination path or URL.
    :param kwargs: extra keyword arguments forwarded to save_func.
    """
    if is_local_path(path):
        save_func(path, **kwargs)
        return

    file_name = append_suffix(str(uuid.uuid1()), path)
    temp_path = os.path.join(tempfile.gettempdir(), file_name)
    try:
        save_func(temp_path, **kwargs)
        put_local_file_to_remote(temp_path, path)
    finally:
        # save_func may fail before the temp file exists; a missing file
        # during cleanup must not replace the error that is propagating.
        try:
            os.remove(temp_path)
        except FileNotFoundError:
            pass
def load_from_file(load_func, path):
    """
    Load something from path, which may be local or remote.

    For a local path (see is_local_path) load_func is called directly on
    path. Otherwise the remote file is first copied with
    get_remote_file_to_local to a uniquely named file in the system temp
    directory, load_func is called on that copy, and the copy is removed
    whether or not loading succeeds.

    :param load_func: callable invoked as load_func(source_path).
    :param path: the source path or URL.
    :return: whatever load_func returns.
    """
    if is_local_path(path):
        return load_func(path)

    local_name = append_suffix(str(uuid.uuid1()), path)
    local_copy = os.path.join(tempfile.gettempdir(), local_name)
    get_remote_file_to_local(path, local_copy)
    try:
        loaded = load_func(local_copy)
    finally:
        os.remove(local_copy)
    return loaded
def get_remote_file_to_local(remote_path, local_path, over_write=False):
    """
    Call the "getRemoteFileToLocal" function through callZooFunc.

    :param remote_path: source location; forwarded unchanged.
    :param local_path: destination on the local file system; forwarded
           unchanged.
    :param over_write: forwarded unchanged; defaults to False.
    """
    callZooFunc(
        "float",
        "getRemoteFileToLocal",
        remote_path,
        local_path,
        over_write,
    )
def put_local_file_to_remote(local_path, remote_path, over_write=False):
    """
    Call the "putLocalFileToRemote" function through callZooFunc.

    :param local_path: source on the local file system; forwarded
           unchanged.
    :param remote_path: destination location; forwarded unchanged.
    :param over_write: forwarded unchanged; defaults to False.
    """
    callZooFunc(
        "float",
        "putLocalFileToRemote",
        local_path,
        remote_path,
        over_write,
    )
def set_core_number(num):
    """
    Call the "setCoreNumber" function through callZooFunc.

    :param num: the value forwarded unchanged to "setCoreNumber"
           (presumably the number of cores to use -- confirm against the
           JVM-side implementation).
    """
    callZooFunc(
        "float",
        "setCoreNumber",
        num,
    )
def callZooFunc(bigdl_type, name, *args):
    """
    Call API in PythonBigDL.

    Every positional argument is converted with _py2java, then each invoker
    in JavaCreator.instance(bigdl_type, gateway).value is tried in order
    until one of them successfully runs a method called name.

    :param bigdl_type: passed to JavaCreator.instance to pick the invokers.
    :param name: name of the method to look up on each invoker.
    :param args: positional arguments for the method.
    :return: the result of the first successful call, converted with
             _java2py.
    :raise Exception: any error other than a "Method <name> ... does not
           exist" one is re-raised immediately; if every invoker reports
           the method as missing, the last such error is raised; if there
           are no invokers, a "Cannot find function" Exception is raised.
    """
    gateway = _get_gateway()
    java_args = [_py2java(gateway, arg) for arg in args]
    last_error = Exception("Cannot find function: %s" % name)
    method_marker = "Method {}".format(name)
    for invoker in JavaCreator.instance(bigdl_type, gateway).value:
        # hasattr(invoker, name) is always true here, so the only way to
        # find out whether the method exists is to actually invoke it.
        try:
            method = getattr(invoker, name)
            converted = _java2py(gateway, method(*java_args))
        except Exception as exc:
            last_error = exc
            message = str(exc)
            # Only a "method does not exist" failure lets us move on to
            # the next invoker; everything else is a real error.
            if "does not exist" not in message or method_marker not in message:
                raise exc
        else:
            return converted
    raise last_error
class JTensor(BJTensor):
    """
    Thin subclass of the BigDL JTensor imported as BJTensor.
    """

    def __init__(self, storage, shape, bigdl_type="float", indices=None):
        # Construction is delegated entirely to the BigDL base class.
        super().__init__(storage, shape, bigdl_type, indices)

    @classmethod
    def from_ndarray(cls, a_ndarray, bigdl_type="float"):
        """
        Convert a ndarray to a DenseTensor which would be used in Java side.

        :param a_ndarray: the np.ndarray to wrap, or None.
        :param bigdl_type: forwarded to the constructor; defaults to "float".
        :return: None when a_ndarray is None, otherwise an instance of cls
                 built from the array and its shape.
        """
        if a_ndarray is not None:
            assert isinstance(a_ndarray, np.ndarray), \
                "input should be a np.ndarray, not %s" % type(a_ndarray)
            return cls(a_ndarray, a_ndarray.shape, bigdl_type)
        return None
class Sample(BSample):
    """
    Thin subclass of the BigDL Sample imported as BSample.
    """

    def __init__(self, features, labels, bigdl_type="float"):
        # Construction is delegated entirely to the BigDL base class.
        super().__init__(features, labels, bigdl_type)

    @classmethod
    def from_ndarray(cls, features, labels, bigdl_type="float"):
        """
        Build a Sample from numpy features and labels.

        Both arguments are normalized with to_list_of_numpy (so each may be
        an ndarray, a scalar, or a list of those) and every resulting array
        is wrapped in a JTensor together with its shape.

        :param features: the feature data.
        :param labels: the label data.
        :param bigdl_type: forwarded to the constructor; defaults to "float".
        :return: an instance of cls holding the JTensor lists.
        """
        feature_arrays = to_list_of_numpy(features)
        label_arrays = to_list_of_numpy(labels)
        feature_tensors = [JTensor(arr, arr.shape) for arr in feature_arrays]
        label_tensors = [JTensor(arr, arr.shape) for arr in label_arrays]
        return cls(
            features=feature_tensors,
            labels=label_tensors,
            bigdl_type=bigdl_type)