Source code for zoo.pipeline.api.keras.datasets.reuters

#
# Copyright 2018 Analytics Zoo Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#


from bigdl.dataset import base
import numpy as np

from six.moves import cPickle


[docs]def download_reuters(dest_dir): """Download pre-processed reuters newswire data :argument dest_dir: destination directory to store the data :return The absolute path of the stored data """ file_name = 'reuters.pkl' file_abs_path = base.maybe_download(file_name, dest_dir, 'https://s3.amazonaws.com/text-datasets/reuters.pkl') return file_abs_path
[docs]def load_data(dest_dir='/tmp/.zoo/dataset', nb_words=None, oov_char=2, test_split=0.2): """Load reuters dataset. :argument dest_dir: where to cache the data (relative to `~/.zoo/dataset`). nb_words: number of words to keep, the words are already indexed by frequency so that the less frequent words would be abandoned oov_char: index to pad the abandoned words, if None, one abandoned word would be taken place with its next word and total length -= 1 test_split: the ratio to split part of dataset to test data, the remained data would be train data :return the train, test separated reuters dataset. """ path = download_reuters(dest_dir) f = open(path, 'rb') x, y = cPickle.load(f) # return data and label, need to separate to train and test f.close() shuffle_by_seed([x, y]) if not nb_words: nb_words = max([max(s) for s in x]) if oov_char is not None: new_x = [] for s in x: new_s = [] for word in s: if word >= nb_words: new_s.append(oov_char) else: new_s.append(word) new_x.append(new_s) else: new_x = [] for s in x: new_s = [] for word in s: if word < nb_words: new_s.append(word) new_x.append(new_s) x = new_x split_index = int(len(x) * (1 - test_split)) x_train, y_train = x[:split_index], y[:split_index] x_test, y_test = x[split_index:], y[split_index:] return (x_train, y_train), (x_test, y_test)
[docs]def shuffle_by_seed(arr_list, seed=0): for arr in arr_list: np.random.seed(seed) np.random.shuffle(arr)
[docs]def get_word_index(dest_dir='/tmp/.zoo/dataset', filename='reuters_word_index.pkl'): """Retrieves the dictionary mapping word indices back to words. # Arguments dest_dir: where to cache the data (relative to `~/.zoo/dataset`). filename: dataset file name # Returns The word index dictionary. """ path = base.maybe_download(filename, dest_dir, 'https://s3.amazonaws.com/text-datasets/reuters_word_index.pkl') f = open(path, 'rb') data = cPickle.load(f, encoding='latin1') f.close() return data
if __name__ == "__main__": print('Processing text dataset') (x_train, y_train), (x_test, y_test) = load_data() print('finished processing text')