
Banana Leaf Disease Detection using Vision Transformer model

Banana cultivation is a significant agricultural activity in many tropical and subtropical regions, providing a vital source of income and nutrition. However, the health and productivity of banana plants are frequently threatened by various leaf diseases. These diseases can drastically reduce yield and quality, leading to substantial economic losses for farmers and impacting food security.


Traditional methods of disease detection in banana plants rely heavily on manual inspection by experts, which can be time-consuming, labor-intensive, and prone to human error. Additionally, by the time symptoms are visually noticeable, the disease may have already spread extensively, making effective management and control more challenging.


To address these challenges, the Banana Leaf Disease Detection project leverages advanced technologies such as machine learning, computer vision, and remote sensing. By utilizing these innovative approaches, the project aims to develop an automated and efficient system for early detection and diagnosis of banana leaf diseases.

Project Overview

Banana Leaf Disease Detection using Vision Transformer and CNN models aims to address a critical agricultural challenge. Banana plants are an essential source of food and income in tropical regions, but diseases can seriously harm their health and productivity. This project uses advanced technology to develop an efficient and accurate system for detecting diseases in banana leaves early.


The combination of Vision Transformer and CNN models offers a powerful approach to solving the problem. Traditional methods of disease detection are time-consuming and labor-intensive. By applying machine learning and computer vision, this project automates disease detection, reducing the reliance on manual inspection. The system helps farmers detect diseases at an early stage, leading to better disease management and increased crop yields.


Prerequisites

Before starting this project, it is important to have a basic understanding of machine learning and computer vision. Familiarity with deep learning frameworks such as TensorFlow and Keras is essential, as they will be used to build the models. A working knowledge of Python programming is required, as it will be used to implement the entire workflow. Experience with data preprocessing, model evaluation, and optimization is also beneficial.


You should have access to a dataset of banana leaf images, labeled with different types of diseases. The dataset will be used to train and validate the models. You will also need access to Google Colab or a similar platform for running the models, as well as the necessary packages such as TensorFlow, Keras, and scikit-learn (sklearn) for building and training the models.


Approach

The approach to this project combines advanced machine learning techniques with practical agricultural applications. The project uses a hybrid model, combining the strengths of Vision Transformers and Convolutional Neural Networks (CNNs). Vision Transformers excel at capturing long-range dependencies in image data, while CNNs are powerful for extracting local features. By integrating these two models, the project creates a more robust system for banana leaf disease detection.


The project uses images of banana leaves, which are preprocessed and fed into both models. Each model predicts which disease (if any) a leaf shows, and the final prediction is made by a voting system that combines the two outputs. The aim is more reliable predictions than either model gives alone, since the ensemble draws on the strengths of both architectures.
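The voting step can be sketched as soft voting: average the class-probability vectors produced by the two models and pick the class with the highest mean. This is a minimal sketch under stated assumptions, not the project's exact method: it assumes both models end in a softmax over the same class order, the helper name `soft_vote` is hypothetical, and the probability arrays are made-up numbers rather than real predictions. With only two models, plain majority voting on hard labels cannot resolve a disagreement, which is one reason to average probabilities instead.

```python
import numpy as np

def soft_vote(prob_vit, prob_cnn, weights=(0.5, 0.5)):
    """Hypothetical helper: weighted average of two models' softmax outputs.

    prob_vit, prob_cnn: arrays of shape (n_samples, n_classes), same class order.
    Returns the averaged probabilities and the winning class index per sample.
    """
    prob_vit = np.asarray(prob_vit, dtype=float)
    prob_cnn = np.asarray(prob_cnn, dtype=float)
    if prob_vit.shape != prob_cnn.shape:
        raise ValueError("Both models must score the same samples and classes")
    w_vit, w_cnn = weights
    combined = (w_vit * prob_vit + w_cnn * prob_cnn) / (w_vit + w_cnn)
    return combined, combined.argmax(axis=1)

# Illustrative outputs for 2 images and 4 classes (not real model predictions)
vit = np.array([[0.10, 0.60, 0.20, 0.10],
                [0.40, 0.30, 0.20, 0.10]])
cnn = np.array([[0.20, 0.30, 0.40, 0.10],
                [0.10, 0.20, 0.60, 0.10]])
probs, labels = soft_vote(vit, cnn)
print(labels)  # → [1 2]
```

If the probabilities come from calling each model's `predict` on a generator, create that generator with `shuffle=False`: `flow_from_directory` shuffles by default, so two separate passes would otherwise return predictions in different image orders.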


Workflow and Methodology

The Overall Workflow of this Project Includes:

  • Data Collection and Preparation: The first step involves collecting a dataset of banana leaf images, categorized by the type of disease present. After that, the dataset is prepared for model training.
  • Model Building: Two models are built in parallel: a Vision Transformer model and a CNN model. Each model is trained separately to predict banana leaf diseases.
  • Model Training: Both models are trained using the prepared dataset. Class weights are adjusted to account for any imbalances in the dataset.
  • Model Evaluation: After training, both models are evaluated on a separate validation dataset to check their accuracy.
  • Voting Mechanism: The final step involves using an ensemble voting mechanism to combine the predictions of both models and make a more accurate decision.

The Methodology Involves:

  • Data Augmentation: Data augmentation techniques are applied to increase the diversity of the training data.
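  • Model Optimization: Learning rate schedules and early stopping are used during training: the schedule adjusts the learning rate as training progresses, and early stopping halts training once validation performance stops improving, which helps limit overfitting.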
  • Performance Evaluation: The models' performance is evaluated using metrics such as precision, recall, accuracy, and AUC (Area Under the Curve).
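Precision, recall, and accuracy can all be read off a confusion matrix. The sketch below computes them with NumPy alone, assuming integer class labels; the helper name `classification_metrics` and the toy labels are hypothetical. For class k, precision is TP / (TP + FP), taken from column k of the matrix, and recall is TP / (TP + FN), taken from row k. AUC needs predicted scores rather than hard labels, so it is left to `tf.keras.metrics.AUC` or scikit-learn's `roc_auc_score`.

```python
import numpy as np

def classification_metrics(y_true, y_pred, n_classes):
    """Hypothetical helper: confusion matrix, per-class precision/recall, accuracy."""
    y_true = np.asarray(y_true)
    y_pred = np.asarray(y_pred)
    # cm[i, j] counts samples whose true class is i and predicted class is j
    cm = np.zeros((n_classes, n_classes), dtype=int)
    np.add.at(cm, (y_true, y_pred), 1)
    tp = np.diag(cm).astype(float)
    predicted = cm.sum(axis=0)   # column sums: how often each class was predicted
    actual = cm.sum(axis=1)      # row sums: how often each class really occurs
    # Avoid division by zero for classes that are never predicted or never present
    precision = np.divide(tp, predicted, out=np.zeros_like(tp), where=predicted > 0)
    recall = np.divide(tp, actual, out=np.zeros_like(tp), where=actual > 0)
    accuracy = tp.sum() / cm.sum()
    return cm, precision, recall, accuracy

# Toy labels for 3 classes (illustrative only)
y_true = [0, 0, 1, 1, 2, 2]
y_pred = [0, 1, 1, 1, 2, 0]
cm, precision, recall, accuracy = classification_metrics(y_true, y_pred, n_classes=3)
print(precision, recall, accuracy)
```

In practice, scikit-learn's `classification_report` (imported later in this article) reports the same per-class precision and recall.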

Data Collection

The success of any machine learning project largely depends on the quality of the data used for training. For this project, a high-quality dataset of banana leaf images is essential. The dataset must include images of healthy banana leaves as well as leaves affected by various diseases like cordana, pestalotiopsis, and sigatoka.


Data Preparation

Once the dataset is collected, it must be carefully prepared before training the models. The dataset is split into training and validation sets, with 80% of the images used for training and 20% reserved for validation. Ideally, every disease category is represented by a similar number of images; any remaining imbalance is handled later with class weights.
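An 80:20 split that keeps each class's proportions the same in both subsets (a stratified split) can be sketched in plain Python. The file names and the helper `stratified_split` are hypothetical; in the Step 2 code of this article the split is made by `ImageDataGenerator(validation_split=0.2)` instead.

```python
import random
from collections import defaultdict

def stratified_split(samples, val_fraction=0.2, seed=42):
    """Hypothetical helper: split (path, label) pairs class by class.

    Each class contributes about val_fraction of its own images to the
    validation set, so class proportions match in both subsets.
    """
    by_class = defaultdict(list)
    for path, label in samples:
        by_class[label].append(path)
    rng = random.Random(seed)  # fixed seed keeps the split reproducible
    train, val = [], []
    for label, paths in sorted(by_class.items()):
        paths = sorted(paths)
        rng.shuffle(paths)
        n_val = int(round(len(paths) * val_fraction))
        val += [(p, label) for p in paths[:n_val]]
        train += [(p, label) for p in paths[n_val:]]
    return train, val

# Illustrative file list: 10 'healthy' and 5 'sigatoka' images
samples = [(f"healthy_{i}.jpg", "healthy") for i in range(10)] + \
          [(f"sigatoka_{i}.jpg", "sigatoka") for i in range(5)]
train, val = stratified_split(samples)
print(len(train), len(val))
```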

Data augmentation techniques are used to increase the size of the dataset. These techniques include rotating, flipping, and zooming the images to create new variations. This helps to improve the robustness of the model by exposing it to a wider variety of inputs during training.
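Rotation, flipping, and zooming can be illustrated with NumPy on a single image array. This sketch is a simplification: it uses only 90-degree rotations, a horizontal flip, and a center-crop "zoom" followed by nearest-neighbour upsampling, so it needs no image library. Keras' `ImageDataGenerator` instead samples random angles and zoom factors through arguments such as `rotation_range`, `horizontal_flip`, and `zoom_range`. The function name `simple_augment` is hypothetical, and the input is assumed to be a square (H, W, C) array.

```python
import numpy as np

def simple_augment(img, rng):
    """Hypothetical helper: randomly flipped/rotated/zoomed copy of a square (H, W, C) image."""
    out = img
    if rng.random() < 0.5:
        out = out[:, ::-1]                      # horizontal flip
    # Rotate by 0, 90, 180 or 270 degrees (swaps H and W for non-square input)
    out = np.rot90(out, k=rng.integers(0, 4))
    if rng.random() < 0.5:
        # "Zoom in": crop the central 80% and upsample back by nearest neighbour
        h, w = out.shape[:2]
        ch, cw = int(h * 0.8), int(w * 0.8)
        top, left = (h - ch) // 2, (w - cw) // 2
        crop = out[top:top + ch, left:left + cw]
        rows = np.arange(h) * ch // h
        cols = np.arange(w) * cw // w
        out = crop[rows][:, cols]
    return np.ascontiguousarray(out)

rng = np.random.default_rng(0)
leaf = np.random.default_rng(1).random((224, 224, 3))   # stand-in for a real leaf image
augmented = [simple_augment(leaf, rng) for _ in range(4)]
print([a.shape for a in augmented])
```

Every output pixel is copied from the input, so no new colours are invented; random-angle rotation, as in `ImageDataGenerator`, would additionally interpolate and fill the corners.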


Data Preparation Workflow

  1. Resizing: All images are resized to a common size of 224x224 pixels to ensure consistency.

  2. Normalization: The pixel values of the images are normalized to lie between 0 and 1. This helps the model converge faster during training.

  3. Data Augmentation: Techniques such as rotation, zoom, and flipping are applied to increase the size of the dataset.

  4. Splitting the Data: The dataset is split into training and validation sets in an 80:20 ratio.

  5. Handling Class Imbalance: Class weights are calculated and applied to handle any imbalances in the dataset.
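The first two steps can be sketched for a single image with NumPy: resize to 224x224, then divide by 255, which is the effect of `rescale=1./255` in the generator code in Step 2. The helper `resize_and_normalize` is hypothetical and assumes an 8-bit RGB array; nearest-neighbour resizing is used here for simplicity, while in the actual pipeline `flow_from_directory(target_size=(224, 224))` performs the resizing.

```python
import numpy as np

def resize_and_normalize(img_uint8, size=(224, 224)):
    """Hypothetical helper: nearest-neighbour resize, then scale pixels to [0, 1]."""
    h, w = img_uint8.shape[:2]
    out_h, out_w = size
    # For each output row/column, pick the source row/column it maps onto
    rows = np.arange(out_h) * h // out_h
    cols = np.arange(out_w) * w // out_w
    resized = img_uint8[rows][:, cols]
    # Same effect as ImageDataGenerator(rescale=1./255)
    return resized.astype(np.float32) / 255.0

# Random 480x640 "photo" standing in for a real leaf image
photo = np.random.default_rng(0).integers(0, 256, size=(480, 640, 3), dtype=np.uint8)
x = resize_and_normalize(photo)
print(x.shape, x.dtype, x.min() >= 0.0, x.max() <= 1.0)
```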


Explanation of the Code

Step 1:

Install and import the necessary packages.

!pip install tensorflow
!pip install keras

Importing Libraries

import warnings
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import plotly.express as px
import plotly.subplots as sp
import plotly.graph_objects as go
import plotly.figure_factory as ff
from plotly.subplots import make_subplots
from PIL import Image
from scipy.ndimage import sobel
from skimage.filters import scharr
from sklearn.utils import class_weight
from sklearn.metrics import classification_report, confusion_matrix, precision_score, recall_score, f1_score
import tensorflow as tf
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import (Input, Conv2D, MaxPooling2D, BatchNormalization, Dropout, Flatten,
                                     GlobalAveragePooling2D, Dense, Activation, LayerNormalization, Rescaling)
from tensorflow.keras.preprocessing.image import ImageDataGenerator, img_to_array, load_img
from tensorflow.keras.applications import ResNet50, InceptionV3
from tensorflow.keras.applications.resnet50 import preprocess_input
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint, LearningRateScheduler
from tensorflow.keras.utils import to_categorical
from tensorflow.keras.metrics import Precision, Recall, AUC
from tensorflow.keras.optimizers import Adam
warnings.filterwarnings('ignore')

You can mount your Google Drive in a Google Colab notebook with this piece of code. This makes it easy to view files stored in Google Drive for data manipulation, analysis, and training models in the Colab environment.

#Mount Google Drive
from google.colab import drive
drive.mount('/content/drive')

Step 2:

Data collection and preparation

We split the dataset into training and validation sets, using 80% of the images for training and the remaining 20% for validation. This block of code loads and preprocesses the image dataset from Google Drive for Keras model training. It sets up an ImageDataGenerator that rescales pixel values to [0, 1] and reserves 20% of the images for validation; note that no random augmentation options (such as rotation_range or horizontal_flip) are set in this generator. Using flow_from_directory, it generates batches of 32 images resized to 224x224 pixels with one-hot (categorical) labels for multi-class classification.

# Read dataset from google drive
data_dir = '/content/drive/MyDrive/banana leaf/AugmentedSet'
datagen = ImageDataGenerator(validation_split=0.2, rescale=1./255)
# Create training and validation generators
train_generator = datagen.flow_from_directory(
    data_dir,
    target_size=(224, 224),
    batch_size=32,
    class_mode='categorical',
    subset='training'
)
val_generator = datagen.flow_from_directory(
    data_dir,
    target_size=(224, 224),
    batch_size=32,
    class_mode='categorical',
    subset='validation'
)

This block of code calculates and displays class weights to address class imbalance in the dataset and prints the class labels for a sample batch. It uses class_weight.compute_class_weight from sklearn to calculate balanced class weights from the training labels and stores them in a dictionary keyed by class index. It then prints the mapping of class names to numerical labels and builds class_weights_with_labels, which maps each class name to its weight. Finally, it decodes the one-hot labels of one batch from the training generator into class names, a quick check that the labels are read correctly.

# Compute class weights
class_weights = class_weight.compute_class_weight(
    'balanced',
    classes=np.unique(train_generator.classes),
    y=train_generator.classes
)
class_weights = {i: class_weights[i] for i in range(len(class_weights))}
# Print the class indices to see the mapping
print("Class indices:", train_generator.class_indices)
# Map class indices to class names and display weights
class_labels = {v: k for k, v in train_generator.class_indices.items()}
class_weights_with_labels = {class_labels[i]: class_weights[i] for i in class_weights}
print("Class weights with labels:")
for class_name, weight in class_weights_with_labels.items():
    print(f"Class: {class_name}, Weight: {weight:.4f}")
# You can also print the class labels for a sample batch
x, y = next(train_generator)
print("Sample batch class names:", [class_labels[np.argmax(label)] for label in y])
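The 'balanced' mode of compute_class_weight gives each class the weight n_samples / (n_classes × count of that class), so rarer classes receive larger weights. The NumPy sketch below reproduces that formula on made-up label counts; `balanced_class_weights` is a hypothetical name, and it assumes every class index from 0 to n_classes − 1 appears at least once (a missing class would cause a division by zero).

```python
import numpy as np

def balanced_class_weights(labels):
    """Hypothetical helper mirroring the 'balanced' formula:
    weight[c] = n_samples / (n_classes * count[c])."""
    labels = np.asarray(labels)
    counts = np.bincount(labels)            # number of samples per class index
    n_samples, n_classes = len(labels), len(counts)
    weights = n_samples / (n_classes * counts)
    return {c: float(w) for c, w in enumerate(weights)}

# Illustrative labels: class 0 is twice as common as classes 1 and 2
labels = [0] * 40 + [1] * 20 + [2] * 20
print(balanced_class_weights(labels))
```

The result has the same form as the class_weights dictionary built above, which Keras accepts through `model.fit(..., class_weight=class_weights)`.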

Step 3:

Display sample images with their class labels

This block of code visualizes a batch of images from the training dataset with their class labels using Matplotlib and Seaborn. It sets the plot style to "dark" with sns.set(style="dark") and fetches a batch of images and labels from train_generator. A function plot_images is defined to display these images in a 4x8 grid, removing axes and setting each subplot's title to the corresponding class name (decoded from the one-hot encoded label). Finally, it calls plot_images to show the images and their labels, providing a visual check of the batch.

sns.set(style="dark")
# Fetch a batch of images and labels
images, labels = next(train_generator)
# Function to plot images with class names
def plot_images(images_arr, labels_arr, class_labels):
    fig, axes = plt.subplots(4, 8, figsize=(20, 10))
    axes = axes.flatten()
    for img, lbl, ax in zip(images_arr, labels_arr, axes):
        ax.imshow(img)
        ax.axis('off')
        class_name = class_labels[np.argmax(lbl)]
        ax.set_title(class_name)
    plt.tight_layout()
    plt.show()
# Plot images with class names
plot_images(images, labels, class_labels)

Plot Sobel edge maps in grayscale

This block of code converts each image in the batch to grayscale, applies a Sobel filter, and visualizes the resulting edge maps with their class labels. The apply_sobel function computes a grayscale image from the RGB channels using luminance weights (0.2989, 0.5870, 0.1140), takes the Sobel derivative along each axis with scipy.ndimage.sobel, and combines the two with np.hypot into a gradient-magnitude image that highlights edges. The plot_sobel_images function takes the original images, the Sobel-filtered images, the labels, and the class index-to-name mapping, and draws each filtered image in grayscale in a 4x8 grid with its class name as the title and the axes turned off. Calling it gives a quick view of the edge structure the filter picks up for each class.

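def apply_sobel(images_arr):
    """Return the Sobel gradient magnitude of each RGB image as a 2-D array."""
    sobel_images = []
    for img in images_arr:
        # Convert RGB to grayscale with standard luminance weights
        gray_img = np.dot(img[..., :3], [0.2989, 0.5870, 0.1140])
        # Intensity derivatives along rows (axis=0) and along columns (axis=1)
        sobel_x = sobel(gray_img, axis=0, mode='constant')
        sobel_y = sobel(gray_img, axis=1, mode='constant')
        # Combine both directions into the gradient magnitude
        sobel_img = np.hypot(sobel_x, sobel_y)
        sobel_images.append(sobel_img)
    return np.array(sobel_images)
sobel_images = apply_sobel(images)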
def plot_sobel_images(original_images, sobel_images, labels_arr, class_labels):
    fig, axes = plt.subplots(4, 8, figsize=(20, 10))
    axes = axes.flatten()
    for orig_img, sob_img, lbl, ax in zip(original_images, sobel_images, labels_arr, axes):
        ax.imshow(sob_img, cmap='gray')
        ax.axis('off')
        class_name = class_labels[np.argmax(lbl)]
        ax.set_title(class_name, color='black')
    plt.tight_layout()
    plt.show()
plot_sobel_images(images, sobel_images, labels, class_labels)
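To see what the Sobel filter responds to, the sketch below applies the same 3x3 Sobel kernels with plain NumPy to a synthetic image containing a single vertical edge. It is a re-implementation for illustration, not the scipy.ndimage.sobel call used above: border pixels are repeated before filtering, whereas the code above uses mode='constant'. The name `sobel_magnitude` is hypothetical. The magnitude is non-zero only in the two columns on either side of the edge, so sharp intensity changes stand out while flat regions stay dark.

```python
import numpy as np

def sobel_magnitude(gray):
    """Hypothetical helper: 3x3 Sobel gradient magnitude of a 2-D array.

    Border pixels are repeated before filtering (np.pad mode='edge').
    """
    p = np.pad(np.asarray(gray, dtype=float), 1, mode="edge")
    # Derivative [-1, 0, 1] across columns, then smoothing [1, 2, 1] down rows
    dx = p[:, 2:] - p[:, :-2]
    gx = dx[:-2] + 2 * dx[1:-1] + dx[2:]
    # Derivative [-1, 0, 1] down rows, then smoothing [1, 2, 1] across columns
    dy = p[2:, :] - p[:-2, :]
    gy = dy[:, :-2] + 2 * dy[:, 1:-1] + dy[:, 2:]
    return np.hypot(gx, gy)

# Synthetic 8x8 image: dark left half, bright right half (one vertical edge)
img = np.zeros((8, 8))
img[:, 4:] = 1.0
mag = sobel_magnitude(img)
print(mag[0])  # → [0. 0. 0. 4. 4. 0. 0. 0.]
```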