Training the CNN Model

Learn how to train a CNN model for plant disease detection with proper techniques and best practices

Data Preparation Steps
Essential steps to prepare your dataset for CNN training

Split the Dataset

Divide your plant leaf images into training (70%), validation (15%), and test (15%) sets to ensure proper model evaluation.

Create Data Generators

Use data generators to load and preprocess images in batches rather than loading the whole dataset into memory at once.
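To show where the memory saving comes from, here is a minimal framework-free sketch: the generator keeps only file paths and decodes one batch at a time. `load_image` and `batch_generator` are hypothetical names. `load_image` returns a dummy array instead of decoding a real file (in practice you would use PIL or similar); the Keras `ImageDataGenerator` iterators do the equivalent work internally.

```python
import numpy as np


def load_image(path, size=(224, 224)):
    """Hypothetical loader: returns a dummy RGB array instead of decoding `path`."""
    return np.zeros((*size, 3), dtype=np.uint8)


def batch_generator(paths, labels, batch_size=32, shuffle=True, seed=42):
    """Yield (images, labels) batches indefinitely, loading one batch at a time."""
    rng = np.random.default_rng(seed)
    indices = np.arange(len(paths))
    while True:
        if shuffle:
            rng.shuffle(indices)  # new sample order every epoch
        for start in range(0, len(indices), batch_size):
            batch_idx = indices[start:start + batch_size]
            # Only the images of this batch are decoded and held in memory
            images = np.stack([load_image(paths[i]) for i in batch_idx])
            yield images, np.asarray([labels[i] for i in batch_idx])


paths = [f"img_{i}.jpg" for i in range(70)]
gen = batch_generator(paths, list(range(70)), batch_size=32)
x, y = next(gen)
print(x.shape, y.shape)  # (32, 224, 224, 3) (32,)
```

Memory use depends on `batch_size`, not on the dataset size, which is why large leaf-image collections can be trained on modest hardware.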

Apply Data Augmentation

Implement rotation, flipping, zooming, and brightness adjustments to increase dataset diversity.
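The `ImageDataGenerator` settings in the data preparation code apply these transforms automatically. As a framework-free illustration of what the geometric ones do, here is a NumPy sketch. It swaps in simplified versions (rotations by multiples of 90 degrees instead of arbitrary angles, and a nearest-neighbour zoom), and it assumes square images such as the 224x224 inputs used in this guide; `augment` is a hypothetical helper name.

```python
import numpy as np


def augment(image, rng):
    """Randomly flip, rotate by a multiple of 90 degrees, and zoom into a square HxWxC image."""
    if rng.random() < 0.5:
        image = image[:, ::-1]  # horizontal flip
    if rng.random() < 0.5:
        image = image[::-1, :]  # vertical flip
    # 0, 90, 180 or 270 degrees; keeps the shape only because the image is square
    image = np.rot90(image, k=rng.integers(4))

    # Zoom in by up to 20%: crop a random window, then resize it back (nearest neighbour)
    h, w = image.shape[:2]
    scale = rng.uniform(0.8, 1.0)
    ch, cw = int(h * scale), int(w * scale)
    top = rng.integers(0, h - ch + 1)
    left = rng.integers(0, w - cw + 1)
    crop = image[top:top + ch, left:left + cw]
    rows = np.arange(h) * ch // h  # map each output row to a crop row
    cols = np.arange(w) * cw // w  # map each output column to a crop column
    return crop[rows][:, cols]


rng = np.random.default_rng(42)
leaf = rng.random((224, 224, 3))
augmented = augment(leaf, rng)
print(augmented.shape)  # (224, 224, 3)
```

Because flips and nearest-neighbour resampling only rearrange existing pixels, the augmented image contains no new colour values; only the spatial layout changes.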

Normalize Pixel Values

Scale pixel values to the range [0,1] by dividing by 255 to improve training stability.
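A minimal NumPy sketch of the same scaling that `rescale=1./255` performs in `ImageDataGenerator`; the tiny 1x1 image is hypothetical. Converting to `float32` matches the dtype Keras models use by default and halves memory compared with NumPy's default `float64`.

```python
import numpy as np

# Hypothetical 8-bit RGB image (values 0-255), e.g. as decoded from a JPEG
image = np.array([[[0, 128, 255]]], dtype=np.uint8)

# Convert to float32 and scale into [0, 1]
normalized = image.astype(np.float32) / 255.0
print(normalized.min(), normalized.max())  # 0.0 1.0
```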

One-Hot Encode Labels

Convert categorical disease labels to one-hot encoded vectors for multi-class classification.
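When the Keras generators are created with `class_mode='categorical'`, they produce one-hot labels automatically, so manual encoding is only needed if you load labels yourself. A minimal NumPy sketch follows, with hypothetical class names listed alphabetically (the order Keras generators use when inferring classes); `tf.keras.utils.to_categorical` does the same job once labels are integer indices.

```python
import numpy as np

# Hypothetical class names (alphabetical) and a few example labels
class_names = ['bacterial_spot', 'early_blight', 'healthy', 'late_blight']
labels = ['healthy', 'late_blight', 'bacterial_spot']

# Map each name to an integer index, then pick rows of the identity matrix
class_to_index = {name: i for i, name in enumerate(class_names)}
indices = np.array([class_to_index[name] for name in labels])
one_hot = np.eye(len(class_names), dtype=np.float32)[indices]
print(one_hot.tolist())
# [[0.0, 0.0, 1.0, 0.0], [0.0, 0.0, 0.0, 1.0], [1.0, 0.0, 0.0, 0.0]]
```

Each row has a single 1 at the index of its class, which is the target format expected by a softmax output layer trained with categorical cross-entropy.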

Common Data Challenges
Issues to watch for when preparing plant disease datasets

Class Imbalance

Plant disease datasets are often imbalanced, with many more images of some diseases (or of healthy leaves) than of others. Use class weights or oversampling so the model does not simply favor the majority classes.
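A minimal sketch of class weighting, assuming a hypothetical label list for an imbalanced training split. It uses the "balanced" formula `n_samples / (n_classes * count)`, the same one scikit-learn's `compute_class_weight(class_weight='balanced', ...)` applies, so rarer classes contribute more to the loss.

```python
from collections import Counter

# Hypothetical, imbalanced training labels
train_labels = ['healthy'] * 600 + ['early_blight'] * 300 + ['late_blight'] * 100

counts = Counter(train_labels)
class_names = sorted(counts)  # alphabetical, matching Keras generator indices
n_samples, n_classes = len(train_labels), len(class_names)

# Balanced weighting: rare classes get proportionally larger weights
class_weight = {
    i: n_samples / (n_classes * counts[name]) for i, name in enumerate(class_names)
}
print({i: round(w, 2) for i, w in class_weight.items()})  # {0: 1.11, 1: 0.56, 2: 3.33}
```

The resulting dict can be passed as `model.fit(..., class_weight=class_weight)`. Its keys must be the integer indices the generator assigns to each class, so check them against `train_generator.class_indices`.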

Background Variation

Images taken against different backgrounds (lab benches, soil, other foliage) can lead the model to learn background cues instead of leaf symptoms. Consider segmentation or background removal.
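One simple form of background removal is a colour threshold. The sketch below uses the Excess Green index (`2G - R - B`) to keep greenish pixels and zero out the rest; the function name, the 0.1 threshold, and the example pixels are hypothetical. Treat this as a crude heuristic: lesions are often yellow or brown, so a green-only mask can cut out exactly the symptoms the model needs. Inspect masks visually, or use a trained segmentation model, before relying on it.

```python
import numpy as np


def remove_background(image, threshold=0.1):
    """Zero out non-green pixels of an RGB float image in [0, 1] using Excess Green."""
    r, g, b = image[..., 0], image[..., 1], image[..., 2]
    exg = 2 * g - r - b         # high for green vegetation, near zero or negative elsewhere
    mask = exg > threshold      # True where the pixel is likely leaf tissue
    return image * mask[..., None], mask


leaf_px = [0.2, 0.6, 0.2]   # greenish pixel
soil_px = [0.5, 0.4, 0.3]   # brownish pixel
image = np.array([[leaf_px, soil_px]])
cleaned, mask = remove_background(image)
print(mask.tolist())  # [[True, False]]
```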

Lighting Conditions

Varying lighting can affect model performance. Use brightness and contrast augmentation to make the model robust.
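`ImageDataGenerator` accepts a `brightness_range` argument but has no contrast option; in TensorFlow, `tf.image.random_brightness` and `tf.image.random_contrast` cover both. To show what these transforms do, here is a NumPy sketch; the function name and default ranges are hypothetical choices. Contrast is scaled around each channel's mean, the same convention `tf.image.adjust_contrast` uses.

```python
import numpy as np


def random_brightness_contrast(image, rng, max_brightness=0.2, contrast_range=(0.8, 1.2)):
    """Randomly shift brightness and scale contrast of a float image in [0, 1]."""
    image = image + rng.uniform(-max_brightness, max_brightness)  # brightness shift
    mean = image.mean(axis=(0, 1), keepdims=True)                 # per-channel mean
    image = (image - mean) * rng.uniform(*contrast_range) + mean  # contrast scaling
    return np.clip(image, 0.0, 1.0)                               # keep values valid


rng = np.random.default_rng(0)
leaf = rng.random((224, 224, 3))
out = random_brightness_contrast(leaf, rng)
print(out.shape, out.min() >= 0.0, out.max() <= 1.0)  # (224, 224, 3) True True
```

Keep the ranges modest: extreme brightness or contrast changes can wash out the colour differences that distinguish some diseases.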

Disease Similarity

Some plant diseases look similar, making classification difficult. Ensure your dataset has clear examples of each.

Data Leakage

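Make sure images of the same plant (or the same leaf photographed several times) never appear in more than one of the training, validation, and test sets. Otherwise the model can memorize individual plants, and validation and test scores will overestimate real-world accuracy.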
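The per-class random split in the data preparation code does not prevent this kind of leakage by itself. A minimal sketch of a group-aware split follows: whole plants are assigned to a subset, so every image of a plant lands in the same one. It assumes a hypothetical filename scheme (`plant017_img3.jpg`) from which a plant ID can be parsed; adapt `plant_id` to however your dataset records plant identity. scikit-learn's `GroupShuffleSplit` offers a library alternative.

```python
import random
from collections import defaultdict


def plant_id(filename):
    """Hypothetical naming scheme: 'plant017_img3.jpg' -> 'plant017'."""
    return filename.split('_')[0]


def group_split(filenames, train=0.7, val=0.15, seed=42):
    """Split filenames so that all images of one plant land in the same subset."""
    groups = defaultdict(list)
    for name in filenames:
        groups[plant_id(name)].append(name)

    # Shuffle plant IDs (not images) reproducibly, then cut into 70/15/15
    ids = sorted(groups)
    random.Random(seed).shuffle(ids)
    n_train = round(len(ids) * train)
    n_val = round(len(ids) * val)
    split_ids = {
        'train': ids[:n_train],
        'val': ids[n_train:n_train + n_val],
        'test': ids[n_train + n_val:],
    }
    return {split: [f for pid in pids for f in groups[pid]] for split, pids in split_ids.items()}


files = [f"plant{p:03d}_img{i}.jpg" for p in range(20) for i in range(3)]
splits = group_split(files)
print({k: len(v) for k, v in splits.items()})  # {'train': 42, 'val': 9, 'test': 9}
```

The percentages now apply to plants rather than images, so subset sizes in images can drift from exactly 70/15/15 when plants have different numbers of photos.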

Data Preparation Code
Python code for preparing your plant disease dataset
import os
from collections import Counter

import numpy as np
import pandas as pd
import tensorflow as tf
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from sklearn.model_selection import train_test_split

# Define paths and parameters
data_dir = 'plant_disease_dataset'  # expects one sub-folder of images per class
img_height, img_width = 224, 224
batch_size = 32
seed = 42

# Create 70/15/15 train/validation/test splits within each class.
# Sort and keep only sub-folders (skips stray files) so class order is reproducible.
classes = sorted(
    d for d in os.listdir(data_dir) if os.path.isdir(os.path.join(data_dir, d))
)
train_data = []
val_data = []
test_data = []

for class_name in classes:
    class_dir = os.path.join(data_dir, class_name)
    images = sorted(os.listdir(class_dir))

    # 70% train; the remaining 30% is halved into 15% validation and 15% test
    train_imgs, temp_imgs = train_test_split(images, test_size=0.3, random_state=seed)
    val_imgs, test_imgs = train_test_split(temp_imgs, test_size=0.5, random_state=seed)

    # Add (image path, class name) pairs to the respective lists
    train_data.extend([(os.path.join(class_dir, img), class_name) for img in train_imgs])
    val_data.extend([(os.path.join(class_dir, img), class_name) for img in val_imgs])
    test_data.extend([(os.path.join(class_dir, img), class_name) for img in test_imgs])

# Put each split in a DataFrame so the generators read exactly these files
train_df = pd.DataFrame(train_data, columns=['filename', 'class'])
val_df = pd.DataFrame(val_data, columns=['filename', 'class'])
test_df = pd.DataFrame(test_data, columns=['filename', 'class'])

# Create data generators with augmentation for training
train_datagen = ImageDataGenerator(
    rescale=1./255,
    rotation_range=20,
    width_shift_range=0.2,
    height_shift_range=0.2,
    shear_range=0.2,
    zoom_range=0.2,
    brightness_range=(0.8, 1.2),  # simulate varying lighting conditions
    horizontal_flip=True,
    fill_mode='nearest'
)

# Only rescale validation and test data (no augmentation)
val_test_datagen = ImageDataGenerator(rescale=1./255)

# Build generators from the DataFrames. flow_from_directory(subset=...) would read
# data_dir directly and ignore the manual splits above (its subsets also depend on
# validation_split being set). class_mode='categorical' one-hot encodes the labels.
def make_generator(datagen, df, shuffle):
    return datagen.flow_from_dataframe(
        df,
        x_col='filename',
        y_col='class',
        classes=classes,  # same class-to-index mapping for every split
        target_size=(img_height, img_width),
        batch_size=batch_size,
        class_mode='categorical',
        shuffle=shuffle,
        seed=seed
    )

train_generator = make_generator(train_datagen, train_df, shuffle=True)
validation_generator = make_generator(val_test_datagen, val_df, shuffle=False)
test_generator = make_generator(val_test_datagen, test_df, shuffle=False)

# Print dataset information
print(f"Number of training samples: {len(train_data)}")
print(f"Number of validation samples: {len(val_data)}")
print(f"Number of test samples: {len(test_data)}")
print(f"Number of classes: {len(classes)}")
print(f"Class indices: {train_generator.class_indices}")

# Check class distribution in each split
for split_name, split_data in [("Training", train_data), ("Validation", val_data), ("Test", test_data)]:
    print(f"\n{split_name} class distribution:")
    class_counts = Counter(class_name for _, class_name in split_data)
    for class_name, count in class_counts.items():
        print(f"  {class_name}: {count} images")