diff --git a/transfer-learning.py b/transfer-learning.py index f185273f9c58c65857b284bdf120e728e653f81a..df54a490781af0117fc8cce48a49126e79e25da9 100644 --- a/transfer-learning.py +++ b/transfer-learning.py @@ -505,8 +505,11 @@ datagen = ImageDataGenerator( vertical_flip=False ) +val_samples = int(x_train.shape[0] * 0.1) +# round up to nearest multiple of batch_size +val_samples += batch_size - (val_samples % batch_size) x_train, x_val, y_train, y_val = train_test_split( - x_train, y_train, test_size=0.1) + x_train, y_train, test_size=val_samples) datagen.fit(x_train) convs = 1