Skip to content
Snippets Groups Projects
svhn.py 1.03 KiB
Newer Older
Carlos Vieira's avatar
Carlos Vieira committed
from keras.utils.data_utils import get_file
import numpy as np
from scipy.io import loadmat

def load_data():
    """Loads the SVHN dataset.

    The train and test `.mat` files are obtained through `get_file` from
    the `housenumbers` directory on `ufldl.stanford.edu` and parsed with
    `scipy.io.loadmat`. This function takes no arguments.

    # Returns
        Tuple of Numpy arrays: `(x_train, y_train), (x_test, y_test)`.
        The `x` arrays are the `'X'` entries of the files with their last
        axis (the sample axis) moved to the front. The `y` arrays are the
        `'y'` entries with every label equal to 10 replaced by 0.
    """
    origin = 'http://ufldl.stanford.edu/housenumbers'

    train_mat = get_file("svhn_train_32x32.mat", origin=f"{origin}/train_32x32.mat")
    test_mat = get_file("svhn_test_32x32.mat", origin=f"{origin}/test_32x32.mat")
    train = loadmat(train_mat)
    test = loadmat(test_mat)

    # The files keep the sample index on the last axis; bring it to the
    # front. `moveaxis` is used rather than newaxis/swapaxes/squeeze because
    # `squeeze` would also drop any other length-1 axis (for example the
    # batch axis itself when a file holds a single sample).
    x_train = np.moveaxis(train['X'], -1, 0)
    x_test = np.moveaxis(test['X'], -1, 0)

    y_train = train['y']
    y_test = test['y']

    # Remap label 10 to 0 in place (presumably SVHN stores the digit zero
    # as class 10 -- confirm against the dataset description).
    y_train[y_train == 10] = 0
    y_test[y_test == 10] = 0

    return (x_train, y_train), (x_test, y_test)