#
# Copyright 2018 Analytics Zoo Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from py4j.protocol import Py4JError
from zoo.orca.data.utils import *
from zoo.orca import OrcaContext
from zoo.common.nncontext import init_nncontext
from zoo import ZooContext, get_node_and_core_number
from zoo.util import nest
[docs]class XShards(object):
"""
A collection of data which can be pre-processed in parallel.
"""
[docs] def collect(self):
"""
Returns a list that contains all of the elements in this XShards
:return: list of elements
"""
pass
[docs] def num_partitions(self):
"""
return the number of partitions in this XShards
:return: an int
"""
pass
[docs] @classmethod
def load_pickle(cls, path, minPartitions=None):
"""
Load XShards from pickle files.
:param path: The pickle file path/directory
:param minPartitions: The minimum partitions for the XShards
:return: SparkXShards object
"""
sc = init_nncontext()
return SparkXShards(sc.pickleFile(path, minPartitions))
[docs] @staticmethod
def partition(data):
"""
Partition local in memory data and form a SparkXShards
:param data: np.ndarray, a tuple, list, dict of np.ndarray, or a nested structure
made of tuple, list, dict with ndarray as the leaf value
:return: a SparkXShards
"""
sc = init_nncontext()
node_num, core_num = get_node_and_core_number()
total_core_num = node_num * core_num
import numpy as np
type_err_msg = """
The types supported in zoo.orca.data.XShards.partition are
1. np.ndarray
2. a tuple, list, dict of np.ndarray
3. nested structure made of tuple, list, dict with ndarray as the leaf value
But got data of type {}
""".format(type(data))
supported_types = {list, tuple, dict}
if isinstance(data, np.ndarray):
arrays = np.array_split(data, total_core_num)
rdd = sc.parallelize(arrays)
else:
assert type(data) in supported_types, type_err_msg
flattened = nest.flatten(data)
data_length = len(flattened[0])
data_to_be_shard = []
for i in range(total_core_num):
data_to_be_shard.append([])
for x in flattened:
assert len(x) == data_length, \
"the ndarrays in data must all have the same size in first dimension, " \
"got first ndarray of size {} and another {}".format(data_length, len(x))
x_parts = np.array_split(x, total_core_num)
for idx, x_part in enumerate(x_parts):
data_to_be_shard[idx].append(x_part)
data_to_be_shard = [nest.pack_sequence_as(data, shard) for shard in data_to_be_shard]
rdd = sc.parallelize(data_to_be_shard)
data_shards = SparkXShards(rdd)
return data_shards
[docs]class RayXShards(XShards):
"""
A collection of data which can be pre-processed in parallel on Ray
"""
def __init__(self, partitions):
assert all([isinstance(partition, RayPartition) for partition in partitions]), \
"partitions should be list of RayPartition"
self.partitions = partitions
[docs] def collect(self):
return [partition.get_data() for partition in self.partitions]
[docs] def num_partitions(self):
return len(self.partitions)
[docs] def get_partitions(self):
"""
Return the list of RayPartition of the RayXShards
"""
return self.partitions
[docs] def colocate_actors(self, actors):
"""
Sort Ray actors and RayPartitions by node_ip so that each actor is colocated
with the data partition on the same node.
"""
if self.partitions[0].node_ip:
# Assume that the partitions are already sorted by node_ip
import ray
actor_ips = ray.get([actor.get_node_ip.remote() for actor in actors])
actor_zip_ips = list(zip(actors, actor_ips))
actor_zip_ips.sort(key=lambda x: x[1])
for i in range(len(actors)):
actor_ip = actor_zip_ips[i][1]
partition_ip = self.partitions[i].node_ip
assert actor_ip == partition_ip
return [actor_ip[0] for actor_ip in actor_zip_ips]
else:
return actors
[docs]class RayPartition(object):
"""
A partition of RayXShards containing the plasma ObjectID, the plasma object_store_address,
and the node of the partition.
"""
def __init__(self, object_id, node_ip, object_store_address):
# The object_id would contain a list of data from a partition of Spark RDD.
self.object_id = object_id
self.node_ip = node_ip
self.object_store_address = object_store_address
[docs] def get_data(self):
import pyarrow.plasma as plasma
# Default num_retries=-1 would try 80 times.
self.client = plasma.connect(self.object_store_address, num_retries=5)
return self.client.get(self.object_id)
def __del__(self):
if self.object_store_address:
import logging
logging.basicConfig(level=logging.WARNING)
logger = logging.getLogger()
try:
if "client" not in self.__dict__:
import pyarrow.plasma as plasma
self.client = plasma.connect(self.object_store_address, num_retries=5)
if self.client.contains(self.object_id):
self.client.delete([self.object_id])
assert not self.client.contains(self.object_id)
logger.info("Removed data from plasma object store on node " + str(self.node_ip))
self.client.disconnect()
del self.client
except Exception as e:
logger.warning(e)
logger.warning("Error occurred when removing the data from the plasma store "
"on node " + str(self.node_ip))
[docs]class SparkXShards(XShards):
"""
A collection of data which can be pre-processed in parallel on Spark
"""
def __init__(self, rdd, transient=False):
self.rdd = rdd
self.user_cached = False
if transient:
self.eager = False
else:
self.eager = OrcaContext._eager_mode
self.rdd.cache()
if self.eager:
self.compute()
self.type = {}
[docs] def collect(self):
"""
Returns a list that contains all of the elements in this SparkXShards
:return: a list of data elements.
"""
return self.rdd.collect()
[docs] def cache(self):
"""
Persist this SparkXShards in memory
:return:
"""
self.user_cached = True
self.rdd.cache()
return self
[docs] def uncache(self):
"""
Make this SparkXShards as non-persistent, and remove all blocks for it from memory
:return:
"""
self.user_cached = False
if self.is_cached():
try:
self.rdd.unpersist()
except Py4JError:
print("Try to unpersist an uncached rdd")
return self
def _uncache(self):
if not self.user_cached:
self.uncache()
[docs] def is_cached(self):
return self.rdd.is_cached
[docs] def compute(self):
self.rdd.count()
return self
[docs] def num_partitions(self):
"""
Get number of partitions for this SparkXShards.
:return: number of partitions.
"""
return self.rdd.getNumPartitions()
[docs] def repartition(self, num_partitions):
"""
Return a new SparkXShards that has exactly num_partitions partitions.
:param num_partitions: target number of partitions
:return: a new SparkXshards object.
"""
if self._get_class_name() == 'pandas.core.frame.DataFrame':
import pandas as pd
if num_partitions > self.rdd.getNumPartitions():
rdd = self.rdd\
.flatMap(lambda df: df.apply(lambda row: (row[0], row.values.tolist()), axis=1)
.values.tolist())\
.partitionBy(num_partitions)
schema = self._get_schema()
def merge_rows(iter):
data = [value[1] for value in list(iter)]
if data:
df = pd.DataFrame(data=data, columns=schema['columns'])\
.astype(schema['dtypes'])
return [df]
else:
# no data in this partition
return iter
repartitioned_shard = SparkXShards(rdd.mapPartitions(merge_rows))
else:
def combine_df(iter):
dfs = list(iter)
if len(dfs) > 0:
return [pd.concat(dfs)]
else:
return iter
rdd = self.rdd.coalesce(num_partitions)
repartitioned_shard = SparkXShards(rdd.mapPartitions(combine_df))
elif self._get_class_name() == 'list':
if num_partitions > self.rdd.getNumPartitions():
rdd = self.rdd \
.flatMap(lambda data: data) \
.repartition(num_partitions)
repartitioned_shard = SparkXShards(rdd.mapPartitions(
lambda iter: [list(iter)]))
else:
rdd = self.rdd.coalesce(num_partitions)
from functools import reduce
repartitioned_shard = SparkXShards(rdd.mapPartitions(
lambda iter: [reduce(lambda l1, l2: l1 + l2, iter)]))
elif self._get_class_name() == 'numpy.ndarray':
elem = self.rdd.first()
shape = elem.shape
dtype = elem.dtype
if len(shape) > 0:
if num_partitions > self.rdd.getNumPartitions():
rdd = self.rdd\
.flatMap(lambda data: list(data))\
.repartition(num_partitions)
repartitioned_shard = SparkXShards(rdd.mapPartitions(
lambda iter: np.stack([list(iter)], axis=0)
.astype(dtype)))
else:
rdd = self.rdd.coalesce(num_partitions)
from functools import reduce
repartitioned_shard = SparkXShards(rdd.mapPartitions(
lambda iter: [np.concatenate(list(iter), axis=0)]))
else:
repartitioned_shard = SparkXShards(self.rdd.repartition(num_partitions))
else:
repartitioned_shard = SparkXShards(self.rdd.repartition(num_partitions))
self._uncache()
return repartitioned_shard
[docs] def partition_by(self, cols, num_partitions=None):
"""
Return a new SparkXShards partitioned using the specified columns.
This is only applicable for SparkXShards of Pandas DataFrame.
:param cols: specified columns to partition by.
:param num_partitions: target number of partitions. If not specified,
the new SparkXShards would keep the current partition number.
:return: a new SparkXShards.
"""
if self._get_class_name() == 'pandas.core.frame.DataFrame':
import pandas as pd
schema = self._get_schema()
# if partition by a column
if isinstance(cols, str):
if cols not in schema['columns']:
raise Exception("The partition column is not in the DataFrame")
# change data to key value pairs
rdd = self.rdd.flatMap(
lambda df: df.apply(lambda row: (row[cols], row.values.tolist()), axis=1)
.values.tolist())
partition_num = self.rdd.getNumPartitions() if not num_partitions \
else num_partitions
# partition with key
partitioned_rdd = rdd.partitionBy(partition_num)
else:
raise Exception("Only support partition by a column name")
def merge(iterator):
data = [value[1] for value in list(iterator)]
if data:
df = pd.DataFrame(data=data, columns=schema['columns']).astype(schema['dtypes'])
return [df]
else:
# no data in this partition
return []
# merge records to df in each partition
partitioned_shard = SparkXShards(partitioned_rdd.mapPartitions(merge))
self._uncache()
return partitioned_shard
else:
raise Exception("Currently only support partition by for XShards"
" of Pandas DataFrame")
[docs] def unique(self):
"""
Return a unique list of elements of this SparkXShards.
This is only applicable for SparkXShards of Pandas Series.
:return: a unique list of elements of this SparkXShards.
"""
if self._get_class_name() == 'pandas.core.series.Series':
import pandas as pd
rdd = self.rdd.map(lambda s: s.unique())
import numpy as np
result = rdd.reduce(lambda list1, list2: pd.unique(np.concatenate((list1, list2),
axis=0)))
return result
else:
# we may support numpy or other types later
raise Exception("Currently only support unique() on XShards of Pandas Series")
[docs] def split(self):
"""
Split SparkXShards into multiple SparkXShards.
Each element in the SparkXShards needs be a list or tuple with same length.
:return: Splits of SparkXShards. If element in the input SparkDataShard is not
list or tuple, return list of input SparkDataShards.
"""
# get number of splits
list_split_length = self.rdd.map(lambda data: len(data) if isinstance(data, list) or
isinstance(data, tuple) else 1).collect()
# check if each element has same splits
if list_split_length.count(list_split_length[0]) != len(list_split_length):
raise Exception("Cannot split this XShards because its partitions "
"have different split length")
else:
if list_split_length[0] > 1:
def get_data(order):
def transform(data):
return data[order]
return transform
split_shard_list = [SparkXShards(self.rdd.map(get_data(i)))
for i in range(list_split_length[0])]
self._uncache()
return split_shard_list
else:
return [self]
[docs] def zip(self, other):
"""
Zips this SparkXShards with another one, returning key-value pairs with the first element
in each SparkXShards, second element in each SparkXShards, etc. Assumes that the two
SparkXShards have the *same number of partitions* and the *same number of elements
in each partition*(e.g. one was made through a transform_shard on the other
:param other: another SparkXShards
:return:
"""
assert isinstance(other, SparkXShards), "other should be a SparkXShards"
assert self.num_partitions() == other.num_partitions(), \
"The two SparkXShards should have the same number of partitions"
try:
rdd = self.rdd.zip(other.rdd)
zipped_shard = SparkXShards(rdd)
other._uncache()
self._uncache()
return zipped_shard
except Exception:
raise ValueError("The two SparkXShards should have the same number of elements "
"in each partition")
def __len__(self):
return self.rdd.map(lambda data: len(data) if hasattr(data, '__len__') else 1)\
.reduce(lambda l1, l2: l1 + l2)
[docs] def save_pickle(self, path, batchSize=10):
"""
Save this SparkXShards as a SequenceFile of serialized objects.
The serializer used is pyspark.serializers.PickleSerializer, default batch size is 10.
:param path: target path.
:param batchSize: batch size for each sequence file chunk.
"""
self.rdd.saveAsPickleFile(path, batchSize)
return self
def __del__(self):
self.uncache()
def __getitem__(self, key):
def get_data(data):
assert hasattr(data, '__getitem__'), \
"No selection operation available for this XShards"
try:
value = data[key]
except:
raise Exception("Invalid key for this XShards")
return value
return SparkXShards(self.rdd.map(get_data), transient=True)
# Tested on pyarrow 0.17.0; 0.16.0 would get errors.
[docs] def to_ray(self):
"""
Put data of this SparkXShards to Ray cluster object store.
:return: a new RayXShards which contains data of this SparkXShards.
"""
from zoo.ray import RayContext
ray_ctx = RayContext.get()
object_store_address = ray_ctx.address_info["object_store_address"]
def put_to_plasma(ids):
def f(index, iterator):
import pyarrow.plasma as plasma
from zoo.util.utils import get_node_ip
res = list(iterator)
client = plasma.connect(object_store_address)
target_id = ids[index]
# If the ObjectID exists in plasma, we assume a task trial
# succeeds and the data is already in the object store.
if not client.contains(target_id):
object_id = client.put(res, target_id)
assert object_id == target_id, \
"Errors occurred when putting data into plasma object store"
client.disconnect()
yield target_id, get_node_ip()
return f
# Create plasma ObjectIDs beforehand instead of creating a random one every time to avoid
# memory leak in case errors occur when putting data into plasma and Spark would retry.
# ObjectIDs in plasma is a byte string of length 20 containing characters and numbers.
# The random generation of ObjectIDs is often good enough to ensure unique IDs.
import pyarrow.plasma as plasma
object_ids = [plasma.ObjectID.from_random() for i in range(self.rdd.getNumPartitions())]
object_id_node_ips = self.rdd.mapPartitionsWithIndex(put_to_plasma(object_ids)).collect()
self.uncache()
# Sort the data according to the node_ips.
object_id_node_ips.sort(key=lambda x: x[1])
partitions = [RayPartition(object_id=id_ip[0], node_ip=id_ip[1],
object_store_address=object_store_address)
for id_ip in object_id_node_ips]
return RayXShards(partitions)
def _for_each(self, func, *args, **kwargs):
def utility_func(x, func, *args, **kwargs):
try:
result = func(x, *args, **kwargs)
except Exception as e:
return e
return result
result_rdd = self.rdd.map(lambda x: utility_func(x, func, *args, **kwargs))
return result_rdd
def _get_schema(self):
if 'schema' in self.type:
return self.type['schema']
else:
if self._get_class_name() == 'pandas.core.frame.DataFrame':
import pandas as pd
columns, dtypes = self.rdd.map(lambda x: (x.columns, x.dtypes)).first()
self.type['schema'] = {'columns': columns, 'dtypes': dtypes}
return self.type['schema']
return None
def _get_class_name(self):
if 'class_name' in self.type:
return self.type['class_name']
else:
self.type['class_name'] = self._for_each(get_class_name).first()
return self.type['class_name']
[docs]class SharedValue(object):
def __init__(self, data):
sc = init_nncontext()
self.broadcast_data = sc.broadcast(data)
self._value = None
@property
def value(self):
self._value = self.broadcast_data.value
return self._value
[docs] def unpersist(self):
self.broadcast_data.unpersist()