diff --git a/transfer-learning.py b/transfer-learning.py
index f185273f9c58c65857b284bdf120e728e653f81a..df54a490781af0117fc8cce48a49126e79e25da9 100644
--- a/transfer-learning.py
+++ b/transfer-learning.py
@@ -505,8 +505,11 @@ datagen = ImageDataGenerator(
     vertical_flip=False
 )
 
+val_samples = int(x_train.shape[0] * 0.1)
+# round up to nearest multiple of batch_size
+val_samples += batch_size - (val_samples % batch_size)
 x_train, x_val, y_train, y_val = train_test_split(
-    x_train, y_train, test_size=0.1) 
+    x_train, y_train, test_size=val_samples)
 datagen.fit(x_train)
 
 convs = 1