diff --git a/setup_tpu.sh b/setup_tpu.sh index 36fe071..3c6896f 100755 --- a/setup_tpu.sh +++ b/setup_tpu.sh @@ -106,4 +106,46 @@ DefaultLimitNOFILE=infinity EOF" # Reload the systemd configuration -sudo systemctl daemon-reload \ No newline at end of file +sudo systemctl daemon-reload + +# Check for --mount-gcs argument +for arg in "$@" +do + case $arg in + --mount-gcs=*) + GCS_BUCKET="${arg#*=}" + shift + ;; + --dev) + DEV_MODE=true + shift + ;; + esac +done + +if [ -n "$GCS_BUCKET" ]; then + # URL of the file to download + FILE_URL="https://raw.githubusercontent.com/AshishKumar4/FlaxDiff/main/datasets/gcsfuse.sh" + # Local path to save the downloaded file + LOCAL_FILE="gcsfuse.sh" + + # Download the file + curl -o $LOCAL_FILE $FILE_URL + + # Make the script executable + chmod +x $LOCAL_FILE + + # Run the script with the specified arguments + ./$LOCAL_FILE DATASET_GCS_BUCKET=$GCS_BUCKET MOUNT_PATH=/mnt/gcs_mount +fi + +if [ "$DEV_MODE" = true ]; then + # Create 'research' directory in the home folder + mkdir -p $HOME/research + + # Clone the repository into the 'research' directory + git clone git@github.com:AshishKumar4/FlaxDiff.git $HOME/research +else + # Download the training.py file into the home folder + wget -O $HOME/training.py https://github.com/AshishKumar4/FlaxDiff/raw/main/training.py +fi \ No newline at end of file diff --git a/training_tpu.py b/training.py similarity index 99% rename from training_tpu.py rename to training.py index 3e3b2b0..fbdd0ff 100644 --- a/training_tpu.py +++ b/training.py @@ -138,7 +138,7 @@ def data_source(): return data_source -def data_source_cc12m(source="/home/mrwhite0racle/research/FlaxDiff/datasets/gcs_mount/arrayrecord/cc12m/"): +def data_source_cc12m(source="/mnt/gcs_mount/arrayrecord/cc12m/"): def data_source(): cc12m_records_path = source cc12m_records = [os.path.join(cc12m_records_path, i) for i in os.listdir(