Skip to content

Commit

Permalink
feat: added option to setup gcsfuse and download training code in set…
Browse files Browse the repository at this point in the history
…up_tpu.sh
  • Loading branch information
AshishKumar4 committed Aug 2, 2024
1 parent 1768cf1 commit ad27b94
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 2 deletions.
44 changes: 43 additions & 1 deletion setup_tpu.sh
Original file line number Diff line number Diff line change
Expand Up @@ -106,4 +106,46 @@ DefaultLimitNOFILE=infinity
EOF"

# Reload the systemd configuration
sudo systemctl daemon-reload
sudo systemctl daemon-reload

# Check for --mount-gcs argument
for arg in "$@"
do
case $arg in
--mount-gcs=*)
GCS_BUCKET="${arg#*=}"
shift
;;
--dev)
DEV_MODE=true
shift
;;
esac
done

if [ -n "$GCS_BUCKET" ]; then
# URL of the file to download
FILE_URL="https://raw.githubusercontent.com/AshishKumar4/FlaxDiff/main/datasets/gcsfuse.sh"
# Local path to save the downloaded file
LOCAL_FILE="gcsfuse.sh"

# Download the file
curl -o $LOCAL_FILE $FILE_URL

# Make the script executable
chmod +x $LOCAL_FILE

# Run the script with the specified arguments
./$LOCAL_FILE DATASET_GCS_BUCKET=$GCS_BUCKET MOUNT_PATH=/mnt/gcs_mount
fi

if [ "$DEV_MODE" = true ]; then
# Create 'research' directory in the home folder
mkdir -p $HOME/research

# Clone the repository into the 'research' directory
git clone git@github.com:AshishKumar4/FlaxDiff.git $HOME/research
else
# Download the training.py file into the home folder
wget -O $HOME/training.py https://github.com/AshishKumar4/FlaxDiff/raw/main/training.py
fi
2 changes: 1 addition & 1 deletion training_tpu.py → training.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,7 @@ def data_source():
return data_source


def data_source_cc12m(source="/home/mrwhite0racle/research/FlaxDiff/datasets/gcs_mount/arrayrecord/cc12m/"):
def data_source_cc12m(source="/mnt/gcs_mount/arrayrecord/cc12m/"):
def data_source():
cc12m_records_path = source
cc12m_records = [os.path.join(cc12m_records_path, i) for i in os.listdir(
Expand Down

0 comments on commit ad27b94

Please sign in to comment.