The research paper associated with this project can be found here: Transferability Prediction for Model Recommendation: A Graph Learning based Approach
Transfer learning has emerged as a popular approach to improving the performance of target tasks using knowledge from related source models. However, it remains challenging to select the most suitable pre-trained models, especially among multiple heterogeneous model architectures. To address this, in this paper we formulate the transferability prediction problem as a bipartite graph learning problem and propose a source-data-free transferability prediction method. Leveraging task metadata as node information, a customized Graph Attention Network (GAT) is employed to predict the transferability among tasks, i.e., the graph edges. By learning low-dimensional task embeddings and predicting edges in the latent space, our method effectively captures intrinsic task relationships and infers transferability for unseen task pairs. Experimental results on general image classification tasks and on segmentation tasks in autonomous driving scenarios demonstrate the effectiveness of our method, with improvements of 15% and 12% over state-of-the-art methods, respectively.
- Model Metadata (`source_model_meta_data.csv`): Contains metadata for each model, detailing aspects such as architecture, input size, model capacity, complexity, pre-training dataset, and performance metrics.
- Dataset Metadata (`target_dataset_meta_data.csv`): Includes metadata for various datasets, describing features such as the number of classes, categories, sample size, and domain/modality specifics.
- Training Records (`record_domainnet_10_270.csv`): Historical records of training sessions, stored in CSV format, providing insight into training dynamics and outcomes.
- Transferability Scores (`LogME.csv`): A matrix of transferability scores across multiple models and tasks, facilitating comparative studies and benchmarking.
- Partial Visualization (`./images/transferability_matrix.jpg`): A graphical representation of the transferability relationships between different models and tasks.
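For orientation, the sketch below loads these files with pandas. Only the file paths are taken from this repository; the assumed CSV layouts (e.g., models as rows in `LogME.csv`) may differ from the actual files.

```python
import pandas as pd

# Metadata for each pre-trained source model
models = pd.read_csv("source_model_meta_data.csv")

# Metadata for each target dataset
datasets = pd.read_csv("target_dataset_meta_data.csv")

# Historical fine-tuning records for the 270 tasks
records = pd.read_csv("record_domainnet_10_270.csv")

# Model-by-task transferability score matrix
# (index_col=0 assumes the first column holds model names)
scores = pd.read_csv("LogME.csv", index_col=0)

print(models.shape, datasets.shape, records.shape, scores.shape)
```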
Each model in our repository is characterized by the following attributes:
- Architecture: The design of the model, crucial for feature capture and generalization.
- Input Size: Dimensionality of model inputs, impacting information processing capabilities.
- Model Capacity: Number of parameters, indicative of learning and representational power.
- Model Complexity: Memory consumption, used as a proxy for complexity.
- Pre-trained Dataset: The dataset on which the model was initially trained, affecting its biases and pre-existing knowledge.
- Model Performance: Empirical measure of model accuracy on benchmark datasets.
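As a concrete illustration of the capacity attribute, the parameter count of any torchvision architecture can be computed directly. This sketch assumes torchvision is installed and uses `resnet18`, whose count matches the table below.

```python
from torchvision import models

# Instantiate one of the listed architectures; the parameter count
# does not depend on the weight values, so no download is needed.
model = models.resnet18()

# Model capacity: total number of parameters.
num_params = sum(p.numel() for p in model.parameters())
print(f"resnet18 parameters: {num_params:,}")  # -> 11,689,512 (matches the table)
```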
Below is a table describing some of the models included in this repository:
Architecture Family | Model Name | # of Parameters | acc@1 (on ImageNet-1K) | GFLOPS |
---|---|---|---|---|
convnext | convnext_small | 50,223,688 | 83.616 | 8.68 |
convnext | convnext_tiny | 28,589,128 | 82.52 | 4.46 |
densenet | densenet121 | 7,978,856 | 74.434 | 2.83 |
... | ... | ... | ... | ... |
resnet | resnet101 | 44,549,160 | 81.886 | 7.8 |
resnet | resnet50 | 25,557,032 | 80.858 | 4.09 |
resnet | resnet18 | 11,689,512 | 69.758 | 1.81 |
resnet | wide_resnet50_2 | 68,883,240 | 78.468 | 11.4 |
Datasets are described by:
- Number of Predefined Classes: Reflects the categorical complexity.
- Categories: High-level description crucial for domain-specific tasks.
- Sample Size: Total number of samples, important for training efficacy.
- Domain: Specifies the particular field or environment that the dataset represents, which can influence model performance and applicability.
- Labels: Detailed descriptions of the labels within the dataset, providing insights into the types of outputs the model needs to handle and their possible relationships.
Below is a table describing some of the datasets included in this repository:
Dataset Name | Num Predefined Classes | Domain | Categories | Labels | Train Sample Size |
---|---|---|---|---|---|
1_painting_2 | 10 | painting | 2 | saxophone, flying_saucer | 20 |
2_painting_2 | 10 | painting | 2 | leaf, jail | 20 |
3_painting_2 | 10 | painting | 2 | flying_saucer, jail | 20 |
... | ... | ... | ... | ... | ... |
269_clipart_32 | 50 | clipart | 32 | tornado, harp, ..., birthday_cake | 1471 |
270_clipart_32 | 50 | clipart | 32 | spoon, flying_saucer, ..., paint_can | 1494 |
In the context of pre-trained model selection, the primary criterion is the transferability of a source model to a target task. Consider a set of source models $M_S = \{m_i\}_{i=1}^{m}$, a target task $D_T = \{(x_T^i, y_T^i)\}_{i=1}^{n}$, and a model $M = (\theta, h_s)$ trained on a source task $D_S$, where $\theta$ denotes the feature extraction layers applied to an input $x_s$, and $h_s$ denotes the decoder layer that maps the extracted features to the label $y_s$. We adopt the widely used transfer learning strategy known as "retrain head": the parameters of the source task's feature extractor $\theta$ are kept frozen, while the parameters of the target task's decoder (readout function) $h_t$ are fine-tuned. Empirical transferability is then defined as the average accuracy of the source model on the target task:

$$\mathrm{Trf}(M \rightarrow D_T) = \mathbb{E}\left[\operatorname{acc}\big(y_T^i,\ h_t(\theta(x_T^i))\big)\right]$$

This equation represents the expected accuracy of the model $M$ when its pretrained feature extractor $\theta$ is combined with a newly trained decoder $h_t$ on the target dataset $D_T$.
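A minimal PyTorch sketch of this "retrain head" protocol is shown below. The backbone choice, data loaders, and training budget are illustrative placeholders, not values from the paper.

```python
import torch
import torch.nn as nn
from torchvision import models

def empirical_transferability(train_loader, test_loader, num_classes, device="cpu"):
    """Freeze the pretrained feature extractor (theta), retrain only the
    decoder head (h_t) on the target task, and return test accuracy."""
    model = models.resnet18(weights="IMAGENET1K_V1")  # placeholder backbone
    for p in model.parameters():                      # freeze theta
        p.requires_grad = False
    model.fc = nn.Linear(model.fc.in_features, num_classes)  # new head h_t
    model.to(device)

    opt = torch.optim.Adam(model.fc.parameters(), lr=1e-3)
    loss_fn = nn.CrossEntropyLoss()
    model.train()
    for _ in range(5):                                # illustrative budget
        for x, y in train_loader:
            x, y = x.to(device), y.to(device)
            opt.zero_grad()
            loss_fn(model(x), y).backward()
            opt.step()

    model.eval()
    correct = total = 0
    with torch.no_grad():
        for x, y in test_loader:
            pred = model(x.to(device)).argmax(dim=1).cpu()
            correct += (pred == y).sum().item()
            total += y.numel()
    return correct / total                            # average accuracy on D_T
```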
Below is a graphical representation of the transferability relationships between different models and tasks:

![Transferability matrix](./images/transferability_matrix.jpg)
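To regenerate a figure like this yourself, the following sketch renders the `LogME.csv` score matrix with matplotlib; it assumes the first CSV column holds model names, which may differ from the actual layout.

```python
import pandas as pd
import matplotlib.pyplot as plt

# Assumed layout: rows = source models, columns = target tasks
scores = pd.read_csv("LogME.csv", index_col=0)

fig, ax = plt.subplots(figsize=(12, 4))
im = ax.imshow(scores.values, aspect="auto", cmap="viridis")
ax.set_xlabel("Target task")
ax.set_ylabel("Source model")
fig.colorbar(im, ax=ax, label="LogME score")
fig.savefig("images/transferability_matrix.jpg", bbox_inches="tight")
```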
The `embedding.ipynb` notebook preprocesses the metadata of models and datasets, performing tasks such as one-hot encoding and normalization to prepare the data for further analysis.
The source embedding step processes the source model metadata into an embedding file, `Model_embedding.npy`. This file contains the processed embeddings of the model metadata, which are used in assessing the transferability of models to various target tasks.
Command to generate the source (model) embedding. Note that a `.ipynb` notebook cannot be run with `python` directly; one option is to convert it to a script first:

```bash
jupyter nbconvert --to script embedding.ipynb
python embedding.py --input source_model_meta_data.csv --output Model_embedding.npy
```
Run the source model metadata through `embedding.ipynb` to generate the `Model_embedding.npy` file.
Run the target dataset metadata through `embedding.ipynb` to generate the `Dataset_embedding.npy` file.
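A sketch of the preprocessing that `embedding.ipynb` performs, under assumed column names (the real CSV schema may differ): categorical fields are one-hot encoded, numeric fields are standardized, and the result is saved as a `.npy` array.

```python
import numpy as np
import pandas as pd
from sklearn.compose import ColumnTransformer
from sklearn.preprocessing import OneHotEncoder, StandardScaler

meta = pd.read_csv("source_model_meta_data.csv")

# Hypothetical column split; adjust to the actual CSV schema.
categorical = ["architecture", "pretrained_dataset"]
numeric = ["input_size", "num_parameters", "memory", "accuracy"]

pipeline = ColumnTransformer([
    ("onehot", OneHotEncoder(handle_unknown="ignore", sparse_output=False), categorical),
    ("scale", StandardScaler(), numeric),
])

embeddings = pipeline.fit_transform(meta)
np.save("Model_embedding.npy", embeddings.astype(np.float32))
```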
The pipeline also incorporates model scoring metrics such as LogME, LEEP, and H-score as indicators of model transferability.
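Of these, the H-score is the simplest to reproduce. Below is a numpy sketch following the commonly cited definition $H(f) = \mathrm{tr}\big(\mathrm{cov}(f)^{-1}\,\mathrm{cov}(\mathbb{E}[f \mid y])\big)$, using a pseudo-inverse for numerical stability; it is an illustration, not the exact implementation used in this repository.

```python
import numpy as np

def h_score(features: np.ndarray, labels: np.ndarray) -> float:
    """H-score of features f (n_samples x dim) w.r.t. labels y.
    Higher values suggest better transferability of the representation."""
    f = features - features.mean(axis=0)   # center the features
    cov_f = (f.T @ f) / len(f)             # feature covariance

    # Replace each sample's features by its class-conditional mean E[f | y].
    g = np.zeros_like(f)
    for c in np.unique(labels):
        mask = labels == c
        g[mask] = f[mask].mean(axis=0)
    cov_g = (g.T @ g) / len(g)             # inter-class covariance

    return float(np.trace(np.linalg.pinv(cov_f) @ cov_g))
```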
- Graph Attention Network (GAT)
- Run `python scgat.py` to reproduce the results reported in the paper.
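For readers who want the gist without opening `scgat.py`, here is a hedged PyTorch Geometric sketch of the core idea: a GAT encoder over the task graph followed by a dot-product edge decoder in the latent space. The class name and dimensions are placeholders, and `scgat.py` remains the authoritative implementation.

```python
import torch
import torch.nn as nn
from torch_geometric.nn import GATConv

class TransferabilityGAT(nn.Module):
    """Encode model/dataset nodes with GAT layers, then score candidate
    (model, dataset) edges by a dot product in the latent space."""
    def __init__(self, in_dim: int, hidden_dim: int = 64, heads: int = 4):
        super().__init__()
        self.gat1 = GATConv(in_dim, hidden_dim, heads=heads)
        self.gat2 = GATConv(hidden_dim * heads, hidden_dim, heads=1)

    def encode(self, x, edge_index):
        h = torch.relu(self.gat1(x, edge_index))
        return self.gat2(h, edge_index)

    def decode(self, z, edge_pairs):
        # edge_pairs: 2 x E tensor of (source node, target node) indices
        return (z[edge_pairs[0]] * z[edge_pairs[1]]).sum(dim=-1)

# Usage sketch: x holds the metadata embeddings of all nodes,
# edge_index the observed transferability edges, and edge_pairs
# the unseen (model, dataset) pairs whose edges are to be predicted.
```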