Training the CNN Model
Learn how to prepare data for and train a CNN model for plant disease detection, from splitting and augmenting the dataset to handling common pitfalls such as class imbalance and data leakage.
Split the Dataset
Divide your plant leaf images into training (70%), validation (15%), and test (15%) sets to ensure proper model evaluation.
Create Data Generators
Use data generators to load and preprocess images in batches, saving memory during training.
Apply Data Augmentation
Implement rotation, flipping, zooming, and brightness adjustments to increase dataset diversity.
Normalize Pixel Values
Scale pixel values to the range [0,1] by dividing by 255 to improve training stability.
One-Hot Encode Labels
Convert categorical disease labels to one-hot encoded vectors for multi-class classification.
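The generator pipeline later in this section does this automatically via class_mode='categorical'. If you load images manually instead, a minimal sketch of the conversion (with made-up labels) looks like this:

import numpy as np
from tensorflow.keras.utils import to_categorical

# Hypothetical string labels for a few images
labels = ['healthy', 'early_blight', 'late_blight', 'healthy']
class_names = sorted(set(labels))
label_to_index = {name: i for i, name in enumerate(class_names)}

# Map each label to an integer index, then to a one-hot vector
int_labels = np.array([label_to_index[name] for name in labels])
one_hot = to_categorical(int_labels, num_classes=len(class_names))
print(one_hot.shape)  # (4, 3): one row per image, one column per class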
Class Imbalance
Plant disease datasets often contain far more images of some diseases than of others. Use class weights or oversampling so that under-represented classes are not ignored during training.
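For the class-weight option, a minimal sketch using scikit-learn (assuming the train_generator built in the code later in this section):

import numpy as np
from sklearn.utils.class_weight import compute_class_weight

# train_generator.classes holds the integer class index of every training image
train_labels = train_generator.classes

# 'balanced' weights each class inversely to its frequency
weights = compute_class_weight(
    class_weight='balanced',
    classes=np.unique(train_labels),
    y=train_labels
)
class_weights = dict(enumerate(weights))
# Later: model.fit(..., class_weight=class_weights)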
Background Variation
Images may have different backgrounds that can confuse the model. Consider segmentation or background removal.
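One crude but common background-removal approach is an HSV colour mask that keeps only leaf-like pixels; a sketch using OpenCV, with hypothetical threshold values you would need to tune so that yellow/brown lesions are not masked out:

import cv2
import numpy as np

def remove_background(image_bgr):
    """Keep pixels inside a rough 'foliage' HSV range; black out the rest."""
    hsv = cv2.cvtColor(image_bgr, cv2.COLOR_BGR2HSV)
    lower = np.array([20, 40, 40])    # hypothetical lower bound (yellow-green)
    upper = np.array([90, 255, 255])  # hypothetical upper bound (green-cyan)
    mask = cv2.inRange(hsv, lower, upper)
    # Close small holes in the mask so lesions inside the leaf are kept
    kernel = np.ones((5, 5), np.uint8)
    mask = cv2.morphologyEx(mask, cv2.MORPH_CLOSE, kernel)
    return cv2.bitwise_and(image_bgr, image_bgr, mask=mask)

# Usage: leaf_only = remove_background(cv2.imread('leaf.jpg'))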
Lighting Conditions
Varying lighting can affect model performance. Use brightness and contrast augmentation to make the model robust.
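In Keras, brightness jitter can be added directly through the brightness_range argument of ImageDataGenerator; a small sketch (contrast jitter can be added separately, e.g. via a preprocessing_function):

from tensorflow.keras.preprocessing.image import ImageDataGenerator

# Each training image is randomly scaled to 70%-130% of its original brightness
brightness_datagen = ImageDataGenerator(
    rescale=1./255,
    brightness_range=(0.7, 1.3),
    rotation_range=20,
    horizontal_flip=True
)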
Disease Similarity
Some plant diseases look similar, making classification difficult. Ensure your dataset has clear examples of each.
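A confusion matrix on the held-out test set is a quick way to see which look-alike diseases the model mixes up; a sketch, assuming a trained Keras model named model and the non-shuffled test_generator built later in this section:

import numpy as np
from sklearn.metrics import confusion_matrix, classification_report

probs = model.predict(test_generator)   # class probabilities per test image
y_pred = np.argmax(probs, axis=1)
y_true = test_generator.classes         # true integer labels (generator not shuffled)
class_names = list(test_generator.class_indices.keys())

print(confusion_matrix(y_true, y_pred))
print(classification_report(y_true, y_pred, target_names=class_names))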
Data Leakage
Ensure images of the same plant (or repeated shots of the same leaf) don't appear in both the training and validation sets; otherwise validation scores will be misleadingly optimistic.
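If the filenames or metadata identify which plant each photo came from, a group-aware split keeps all images of one plant on the same side; a sketch with scikit-learn's GroupShuffleSplit, using a hypothetical filename convention that encodes a plant ID:

import os
from sklearn.model_selection import GroupShuffleSplit

# Hypothetical: filenames encode the source plant, e.g. 'plant017_leaf3.jpg'
image_paths = ['plant017_leaf1.jpg', 'plant017_leaf2.jpg',
               'plant018_leaf1.jpg', 'plant019_leaf1.jpg']
plant_ids = [os.path.basename(p).split('_')[0] for p in image_paths]

# Every image of a given plant lands entirely in train or entirely in holdout
splitter = GroupShuffleSplit(n_splits=1, test_size=0.25, random_state=42)
train_idx, holdout_idx = next(splitter.split(image_paths, groups=plant_ids))
train_paths = [image_paths[i] for i in train_idx]
holdout_paths = [image_paths[i] for i in holdout_idx]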
import os
import numpy as np
import pandas as pd
import tensorflow as tf
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from sklearn.model_selection import train_test_split
# Define paths and parameters
data_dir = 'plant_disease_dataset'
img_height, img_width = 224, 224
batch_size = 32
seed = 42
# Create train/validation/test splits (70% / 15% / 15%) per class
classes = sorted(
    d for d in os.listdir(data_dir)
    if os.path.isdir(os.path.join(data_dir, d))
)
train_data = []
val_data = []
test_data = []
for class_name in classes:
    class_dir = os.path.join(data_dir, class_name)
    images = os.listdir(class_dir)
    # Split images for this class: 70% train, 15% validation, 15% test
    train_imgs, temp_imgs = train_test_split(images, test_size=0.3, random_state=seed)
    val_imgs, test_imgs = train_test_split(temp_imgs, test_size=0.5, random_state=seed)
    # Add (image path, class label) pairs to the respective lists
    train_data.extend([(os.path.join(class_dir, img), class_name) for img in train_imgs])
    val_data.extend([(os.path.join(class_dir, img), class_name) for img in val_imgs])
    test_data.extend([(os.path.join(class_dir, img), class_name) for img in test_imgs])
# Put the splits into DataFrames so the generators use exactly these files
train_df = pd.DataFrame(train_data, columns=['filename', 'class'])
val_df = pd.DataFrame(val_data, columns=['filename', 'class'])
test_df = pd.DataFrame(test_data, columns=['filename', 'class'])

# Create data generators with augmentation for training
train_datagen = ImageDataGenerator(
    rescale=1./255,
    rotation_range=20,
    width_shift_range=0.2,
    height_shift_range=0.2,
    shear_range=0.2,
    zoom_range=0.2,
    horizontal_flip=True,
    fill_mode='nearest'
)

# Only rescale validation and test data (no augmentation)
val_test_datagen = ImageDataGenerator(rescale=1./255)

# Create generators that read image paths and labels from the DataFrames
train_generator = train_datagen.flow_from_dataframe(
    train_df,
    x_col='filename',
    y_col='class',
    target_size=(img_height, img_width),
    batch_size=batch_size,
    class_mode='categorical',
    shuffle=True,
    seed=seed
)

validation_generator = val_test_datagen.flow_from_dataframe(
    val_df,
    x_col='filename',
    y_col='class',
    target_size=(img_height, img_width),
    batch_size=batch_size,
    class_mode='categorical',
    shuffle=False
)

test_generator = val_test_datagen.flow_from_dataframe(
    test_df,
    x_col='filename',
    y_col='class',
    target_size=(img_height, img_width),
    batch_size=batch_size,
    class_mode='categorical',
    shuffle=False
)
# Print dataset information
print(f"Number of training samples: {len(train_data)}")
print(f"Number of validation samples: {len(val_data)}")
print(f"Number of test samples: {len(test_data)}")
print(f"Number of classes: {len(classes)}")
print(f"Class names: {classes}")
# Check class distribution
for split_name, split_data in [("Training", train_data), ("Validation", val_data), ("Test", test_data)]:
    print(f"\n{split_name} class distribution:")
    class_counts = {}
    for _, class_name in split_data:
        class_counts[class_name] = class_counts.get(class_name, 0) + 1
    for class_name, count in class_counts.items():
        print(f"  {class_name}: {count} images")