This README provides a comprehensive guide to developing an AI application for classifying different species of flowers using deep learning. The project involves training a classifier to recognize flower species, which could be utilized in applications such as a mobile app that identifies flowers from camera images.
In this project, you will:
- Load and preprocess the image dataset
- Train an image classifier
- Use the trained classifier to predict image content
By the end of the project, you will have a command-line application capable of training on any set of labeled images. The skills and methods developed here can be applied to various other image classification tasks.
- Python (3.x)
- PyTorch
- torchvision
- Pillow (the maintained fork of PIL, the Python Imaging Library)
- Matplotlib
- json (part of the Python standard library)
Start by importing the necessary packages. Keeping all imports at the beginning of your code is a good practice.
# TODO: Import necessary packages
Dataset Overview:
- The dataset consists of images split into three sets: training, validation, and testing.
- The training set will be augmented with transformations such as random scaling, cropping, and flipping.
- The validation and testing sets will only be resized and cropped to 224x224 pixels, with no random augmentation.
Tasks:
- Define transformations for each dataset.
- Load the datasets using ImageFolder.
- Create data loaders.
# TODO: Define data transformations
# TODO: Load datasets with ImageFolder
# TODO: Define data loaders
Load a JSON file containing the mapping from category labels to category names. This mapping will help in interpreting the classifier's output.
import json
# Load label mapping
with open('cat_to_name.json', 'r') as f:
    cat_to_name = json.load(f)
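To interpret the classifier's output, the mapping usually has to be combined with the dataset's class-to-index table. The sketch below assumes that the class folder names match the keys of `cat_to_name`; the helper name `index_to_name` and the example values are hypothetical.

```python
def index_to_name(class_to_idx, cat_to_name):
    """Map each model output index to a human-readable category name."""
    return {idx: cat_to_name[label] for label, idx in class_to_idx.items()}


# Hypothetical values for illustration only. ImageFolder sorts folder names
# as strings, so "10" is assigned an index before "2".
class_to_idx = {"1": 0, "10": 1, "2": 2}
cat_to_name = {"1": "rose", "2": "tulip", "10": "daisy"}

names = index_to_name(class_to_idx, cat_to_name)
print(names[1])  # daisy
```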
Tasks:
- Use a pre-trained network (e.g., VGG) as the feature extractor.
- Define and train a new feed-forward network as the classifier.
- Tune hyperparameters and track performance on the validation set.
Note: Ensure only the classifier layers are trained while the pre-trained network weights remain frozen.
# TODO: Build and train the classifier
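One possible shape for this step is sketched below. The helper names (`freeze`, `build_classifier`, `train_one_epoch`), the layer sizes, the dropout rate, and the choice of `LogSoftmax` output are illustrative assumptions, not requirements of the project.

```python
from torch import nn, optim


def freeze(module):
    """Turn off gradients so the optimizer cannot update these weights."""
    for param in module.parameters():
        param.requires_grad = False


def build_classifier(in_features, hidden_units, n_classes, dropout=0.2):
    """Feed-forward head that outputs log-probabilities."""
    return nn.Sequential(
        nn.Linear(in_features, hidden_units),
        nn.ReLU(),
        nn.Dropout(dropout),
        nn.Linear(hidden_units, n_classes),
        nn.LogSoftmax(dim=1),
    )


def train_one_epoch(model, loader, optimizer, criterion, device="cpu"):
    """Run one pass over loader and return the mean training loss."""
    model.train()
    total_loss, n_seen = 0.0, 0
    for images, labels in loader:
        images, labels = images.to(device), labels.to(device)
        optimizer.zero_grad()
        loss = criterion(model(images), labels)
        loss.backward()
        optimizer.step()
        total_loss += loss.item() * images.size(0)
        n_seen += images.size(0)
    return total_loss / n_seen


# Possible wiring for VGG16, whose classifier takes 25088 inputs.
# Assumes a torchvision version that accepts the `weights` argument:
#   model = torchvision.models.vgg16(weights="DEFAULT")
#   freeze(model)                                   # freeze everything first
#   model.classifier = build_classifier(25088, 512, n_classes)
#   optimizer = optim.Adam(model.classifier.parameters(), lr=0.001)
#   criterion = nn.NLLLoss()                        # pairs with LogSoftmax
```

Passing only `model.classifier.parameters()` to the optimizer, in addition to freezing, keeps the pre-trained weights unchanged. Calling `train_one_epoch` once per epoch and then measuring validation loss or accuracy gives the performance tracking mentioned above.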
Evaluate the trained network on the test dataset to estimate its performance on unseen data. Aim for an accuracy around 70%.
# TODO: Evaluate the network on the test set
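A minimal accuracy computation might look like the sketch below. The function name `accuracy` is hypothetical, and the loader is assumed to yield `(images, labels)` batches, as a `DataLoader` over an `ImageFolder` dataset does.

```python
import torch


def accuracy(model, loader, device="cpu"):
    """Return the fraction of correctly classified examples in loader."""
    model.eval()  # switch off dropout and similar training-only behavior
    correct, total = 0, 0
    with torch.no_grad():  # gradients are not needed for evaluation
        for images, labels in loader:
            images, labels = images.to(device), labels.to(device)
            predictions = model(images).argmax(dim=1)
            correct += (predictions == labels).sum().item()
            total += labels.size(0)
    return correct / total
```

Because `argmax` is unaffected by monotonic transformations, this works whether the model outputs raw scores, probabilities, or log-probabilities.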
Save the trained model, including the classifier weights and additional information such as class-to-index mappings. This allows for easy future use and inference.
# TODO: Save the model checkpoint
Write a function to load the model checkpoint and reconstruct the model for future use.
# TODO: Implement checkpoint loading function
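Saving and loading can be sketched as a round trip. This sketch assumes that only the classifier head is stored and that the pre-trained feature extractor is re-created from torchvision at load time; the checkpoint keys and the helper `make_classifier` are hypothetical.

```python
import torch
from torch import nn


def make_classifier(layer_sizes):
    """Hypothetical helper: Linear layers with ReLU between them."""
    layers = []
    pairs = list(zip(layer_sizes[:-1], layer_sizes[1:]))
    for i, (n_in, n_out) in enumerate(pairs):
        layers.append(nn.Linear(n_in, n_out))
        if i < len(pairs) - 1:
            layers.append(nn.ReLU())
    layers.append(nn.LogSoftmax(dim=1))
    return nn.Sequential(*layers)


def save_checkpoint(path, classifier, class_to_idx, layer_sizes):
    """Store the classifier weights plus what is needed to rebuild it."""
    torch.save({
        "layer_sizes": list(layer_sizes),
        "class_to_idx": dict(class_to_idx),
        "state_dict": classifier.state_dict(),
    }, path)


def load_checkpoint(path):
    """Rebuild the classifier and return it with its class-to-index mapping."""
    checkpoint = torch.load(path, map_location="cpu")
    classifier = make_classifier(checkpoint["layer_sizes"])
    classifier.load_state_dict(checkpoint["state_dict"])
    return classifier, checkpoint["class_to_idx"]
```

The checkpoint holds only tensors and plain Python containers, so it does not depend on pickled custom classes. Storing the layer sizes next to the weights lets the loader rebuild the architecture without hard-coding it.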
Tasks:
- Write a function to preprocess images for the model: resizing, cropping, normalizing.
- Convert images from PIL format to PyTorch tensors.
def process_image(image):
    ''' Preprocess a PIL image for a PyTorch model. '''
    # TODO: Implement preprocessing steps
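One way to fill in this stub is sketched below. The 256-pixel resize target, the 224-pixel center crop, and the ImageNet normalization statistics are assumptions chosen to match common torchvision practice; the `resize_to` and `crop_to` parameters are additions for illustration.

```python
import numpy as np
import torch
from PIL import Image

# Assumption: ImageNet channel statistics, as used by torchvision's
# pre-trained models.
IMAGENET_MEAN = np.array([0.485, 0.456, 0.406])
IMAGENET_STD = np.array([0.229, 0.224, 0.225])


def process_image(image, resize_to=256, crop_to=224):
    """Resize, center-crop, and normalize a PIL image; return a float tensor."""
    image = image.convert("RGB")

    # Scale so the shorter side equals resize_to, keeping the aspect ratio.
    width, height = image.size
    scale = resize_to / min(width, height)
    image = image.resize((round(width * scale), round(height * scale)))

    # Cut a crop_to x crop_to square out of the center.
    width, height = image.size
    left = (width - crop_to) // 2
    top = (height - crop_to) // 2
    image = image.crop((left, top, left + crop_to, top + crop_to))

    # Scale pixel values to [0, 1], then normalize each color channel.
    array = np.asarray(image, dtype=np.float32) / 255.0
    array = (array - IMAGENET_MEAN) / IMAGENET_STD

    # PIL stores images as height x width x channel; PyTorch expects
    # channel x height x width.
    return torch.from_numpy(array.transpose(2, 0, 1)).float()
```

The result has shape `(3, 224, 224)` with the default arguments, so it needs a batch dimension (for example via `unsqueeze(0)`) before it is passed to a model.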
Implement a function to use the trained model for making predictions. This function should return the top K probable classes and their associated probabilities.
def predict(image_path, model, topk=5):
    ''' Predict the class of an image using a trained model. '''
    # TODO: Implement prediction function
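The core of the prediction step can be sketched as below. To keep the example self-contained it starts from an already preprocessed tensor rather than an image path; the name `predict_from_tensor` and the `class_to_idx` argument are hypothetical, and opening the file and preprocessing it are left to `process_image`.

```python
import torch


def predict_from_tensor(image_tensor, model, class_to_idx, topk=5):
    """Return (probabilities, class_labels) for the topk most likely classes.

    image_tensor is a single preprocessed image of shape (C, H, W).
    class_to_idx maps class labels to model output indices.
    """
    idx_to_class = {idx: label for label, idx in class_to_idx.items()}
    model.eval()
    with torch.no_grad():
        output = model(image_tensor.unsqueeze(0))  # add a batch dimension
        # softmax yields probabilities from raw scores, and it also recovers
        # them unchanged if the model already outputs log-probabilities.
        probs = torch.softmax(output, dim=1)
    top_probs, top_idx = probs.topk(topk, dim=1)
    labels = [idx_to_class[i] for i in top_idx[0].tolist()]
    return top_probs[0].tolist(), labels
```

A path-based `predict` could then open the file with `PIL.Image.open`, call `process_image`, and pass the result to this helper.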
Verify the model's predictions by displaying the input image along with the top 5 predicted classes and their probabilities. Use Matplotlib to visualize these results.
# TODO: Display an image along with the top 5 predicted classes
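A possible layout for this check is the image on top and a horizontal bar chart of probabilities underneath. The function name `plot_prediction` and the figure size are illustrative choices, and the class names are assumed to have been looked up already, for example through `cat_to_name`.

```python
import matplotlib.pyplot as plt


def plot_prediction(image, probs, names, title=None):
    """Show image above a horizontal bar chart of class probabilities."""
    fig, (ax_img, ax_bar) = plt.subplots(nrows=2, figsize=(5, 8))

    # Top panel: the input image (a PIL image or an array both work).
    ax_img.imshow(image)
    ax_img.axis("off")
    if title:
        ax_img.set_title(title)

    # Bottom panel: one bar per predicted class.
    positions = list(range(len(names)))
    ax_bar.barh(positions, probs)
    ax_bar.set_yticks(positions)
    ax_bar.set_yticklabels(names)
    ax_bar.invert_yaxis()  # most probable class at the top
    ax_bar.set_xlabel("Probability")

    fig.tight_layout()
    return fig
```

The function returns the figure instead of showing it, so the caller can decide between `plt.show()` and `fig.savefig(...)`.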
- Ensure that the workspace remains active during long-running tasks to prevent disconnection.
- If the model checkpoint exceeds 1 GB, consider reducing the size of the hidden layers to avoid saving issues.