From 333d0bd4800acc56fd969e77f2c9ec3a88a79cd3 Mon Sep 17 00:00:00 2001 From: Liu Liu Date: Mon, 21 Oct 2024 14:39:25 -0400 Subject: [PATCH] Update to how gradientCheckpointing works. --- WORKSPACE | 4 ++-- deps.bzl | 4 ++-- nnc/Model.swift | 17 ++++++++++++++--- 3 files changed, 18 insertions(+), 7 deletions(-) diff --git a/WORKSPACE b/WORKSPACE index a4bfb36160c..29681ab91aa 100644 --- a/WORKSPACE +++ b/WORKSPACE @@ -5,9 +5,9 @@ load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive") git_repository( name = "ccv", - commit = "e7c76392ec4ad529797484c6759fd2f8da974d99", + commit = "20d998dc3c7008060df6fdfa17e932c9dae93d31", remote = "https://github.com/liuliu/ccv.git", - shallow_since = "1728857981 -0400", + shallow_since = "1729535387 -0400", ) load("@ccv//config:ccv.bzl", "ccv_deps", "ccv_setting") diff --git a/deps.bzl b/deps.bzl index 99e700b2d62..2369ce96da2 100644 --- a/deps.bzl +++ b/deps.bzl @@ -17,8 +17,8 @@ def s4nnc_deps(): git_repository, name = "ccv", remote = "https://github.com/liuliu/ccv.git", - commit = "e7c76392ec4ad529797484c6759fd2f8da974d99", - shallow_since = "1728857981 -0400", + commit = "20d998dc3c7008060df6fdfa17e932c9dae93d31", + shallow_since = "1729535387 -0400", ) _maybe( diff --git a/nnc/Model.swift b/nnc/Model.swift index e2e3eff8a4d..339a792e804 100644 --- a/nnc/Model.swift +++ b/nnc/Model.swift @@ -205,12 +205,23 @@ public class Model { * Whether to enable gradient checkpointing for this model. Once it is enabled, we will re-run * the model forward pass again during backward pass. This is effective at reducing memory usage. */ - public var gradientCheckpointing: Bool { + public var gradientCheckpointing: Bool? { get { - ccv_cnnp_model_gradient_checkpointing(cModel) != 0 + let value = ccv_cnnp_model_gradient_checkpointing(cModel) + if value == 1 { + return true + } else if value == -1 { + return false + } else { + return nil + } } set { - ccv_cnnp_model_set_gradient_checkpointing(cModel, newValue ? 1 : 0) + if let newValue = newValue { + ccv_cnnp_model_set_gradient_checkpointing(cModel, newValue ? 1 : -1) + } else { + ccv_cnnp_model_set_gradient_checkpointing(cModel, 0) + } } }