This project contains a Python script for training a simple object detection model using TensorFlow. The model is trained to recognize and locate MNIST digits placed on a larger canvas.
The main.py script performs the following steps:
- Loads the MNIST dataset using tensorflow_datasets.
- For each digit, it creates a larger image (75x75) and places the 28x28 digit at a random location.
- It defines a convolutional neural network (CNN) model with two outputs:
  - A classification head to predict the digit (0-9).
  - A regression head to predict the bounding box coordinates of the digit.
- The model is compiled with losses and metrics for both classification and bounding box regression.
- The script trains the model on the generated dataset.
- After training, it evaluates the model's performance, plots training metrics, and displays some predictions with bounding boxes.
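
The digit placement step above can be sketched with NumPy alone. This is a minimal illustration, not the code from main.py: the function name `place_digit_on_canvas`, the normalized `(xmin, ymin, xmax, ymax)` box format, and the all-ones stand-in digit are assumptions made for the example.

```python
import numpy as np

CANVAS_SIZE = 75  # side length of the larger image, as described above
DIGIT_SIZE = 28   # side length of an MNIST digit


def place_digit_on_canvas(digit, rng):
    """Paste a 28x28 digit at a random position on a blank 75x75 canvas.

    Returns the canvas and a bounding box (xmin, ymin, xmax, ymax)
    normalized to [0, 1]. The box format is an assumption for this
    sketch and may differ from what main.py uses.
    """
    # Largest top-left offset that still keeps the digit fully inside the canvas.
    max_offset = CANVAS_SIZE - DIGIT_SIZE
    x = int(rng.integers(0, max_offset + 1))
    y = int(rng.integers(0, max_offset + 1))

    canvas = np.zeros((CANVAS_SIZE, CANVAS_SIZE), dtype=digit.dtype)
    canvas[y:y + DIGIT_SIZE, x:x + DIGIT_SIZE] = digit

    bbox = np.array(
        [x, y, x + DIGIT_SIZE, y + DIGIT_SIZE], dtype=np.float32
    ) / CANVAS_SIZE
    return canvas, bbox


rng = np.random.default_rng(0)
# Stand-in for a real MNIST digit: a 28x28 block of ones.
fake_digit = np.ones((DIGIT_SIZE, DIGIT_SIZE), dtype=np.float32)
canvas, bbox = place_digit_on_canvas(fake_digit, rng)
```

Normalizing the box by the canvas size keeps the regression targets in the range 0 to 1, which is a common choice for a linear regression head, though pixel coordinates would work as well.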
The script requires the following Python libraries:
- tensorflow
- tensorflow-datasets
- numpy
- matplotlib
- Pillow
You can install them using pip:

    pip install -r requirements.txt

To run the script, execute the following command in your terminal:

    python main.py

The script will start the training process, and you will see the model's progress for each epoch. After training is complete, it will display evaluation results and show plots for the training metrics.
The model consists of three main parts:
- Feature Extractor: A simple CNN with three Conv2D layers to extract features from the input image.
- Classifier: A dense layer with a softmax activation function to predict the class of the digit.
- Bounding Box Regressor: A dense layer with a linear activation to predict the coordinates of the bounding box.
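
A model with these three parts could be assembled with the Keras functional API roughly as follows. This is a sketch under stated assumptions rather than the exact architecture in main.py: the filter counts, pooling layers, hidden dense layer, layer names, and the choice of losses (sparse categorical cross-entropy for integer labels, mean squared error for the box) are illustrative guesses.

```python
import tensorflow as tf


def build_model(input_shape=(75, 75, 1)):
    """Build a two-headed CNN: digit classification plus bounding box regression."""
    inputs = tf.keras.Input(shape=input_shape)

    # Feature extractor: three Conv2D layers (filter counts are assumed).
    x = inputs
    for filters in (16, 32, 64):
        x = tf.keras.layers.Conv2D(filters, 3, activation="relu")(x)
        x = tf.keras.layers.MaxPooling2D(2)(x)
    x = tf.keras.layers.Flatten()(x)
    x = tf.keras.layers.Dense(128, activation="relu")(x)  # assumed hidden layer

    # Classifier: softmax over the ten digit classes.
    class_out = tf.keras.layers.Dense(
        10, activation="softmax", name="classification"
    )(x)
    # Bounding box regressor: four linear outputs, one per box coordinate.
    bbox_out = tf.keras.layers.Dense(4, name="bounding_box")(x)

    return tf.keras.Model(inputs=inputs, outputs=[class_out, bbox_out])


model = build_model()
# One loss and one metric per head, keyed by the output layer names.
model.compile(
    optimizer="adam",
    loss={
        "classification": "sparse_categorical_crossentropy",  # assumes integer labels
        "bounding_box": "mse",
    },
    metrics={"classification": "accuracy", "bounding_box": "mse"},
)
```

Naming the output layers lets the losses and metrics be given as dictionaries, so each head is paired with its own loss explicitly rather than by position.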