ASA core modules #641

Open
wants to merge 443 commits into
base: develop

Conversation

Member

@balos1 balos1 commented Jan 14, 2025

This PR adds the core modules that will support adjoint sensitivity analysis in a package-agnostic way.

Member

@gardner48 gardner48 left a comment

Partway through src

CHANGELOG.md (outdated, resolved)
doc/shared/figs/sunadjoint_ckpt_fixed.png (outdated, resolved)
doc/shared/sunadjoint/SUNAdjointCheckpointScheme.rst (outdated, resolved)
Comment on lines +214 to +217
}

SUNCheckCall(SUNDataNode_HasChildren(step_data_node, &has_children));
if (!has_children)
Member

Couldn't this be

Suggested change
- }
- SUNCheckCall(SUNDataNode_HasChildren(step_data_node, &has_children));
- if (!has_children)
+ }
+ else

Member Author

I'm not sure which line this is referring to.

Member

@gardner48 gardner48 left a comment

Continuing through src

test/unit_tests/sundials/CMakeLists.txt (outdated, resolved)
test/unit_tests/sundials/CMakeLists.txt (outdated, resolved)

if (!(step_num % IMPL_MEMBER(self, interval)))
{
if (stage_num == 0) { *yes_or_no = SUNTRUE; }
Member

This will be z[0] rather than y_n for methods with an implicit first stage.
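
For reference, the check in the snippet above requests a checkpoint only at stage 0 of every step whose index is a multiple of the interval. A minimal standalone sketch of that predicate follows; the function name and the plain integer arguments are hypothetical stand-ins, not the PR's implementation. It only decides when to save. It says nothing about which state is available at stage 0, which is the point of the comment above.

```c
#include <assert.h>
#include <stdbool.h>
#include <stdint.h>

/* Hypothetical stand-in for the fixed-interval check quoted above:
   request a checkpoint only at stage 0 of every `interval`-th step. */
static bool needs_checkpoint(int64_t step_num, int64_t stage_num,
                             int64_t interval)
{
  assert(interval > 0); /* a zero interval would divide by zero below */
  return (step_num % interval == 0) && (stage_num == 0);
}
```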

return SUN_ERR_CHECKPOINT_NOT_FOUND;
}

SUNCheckCall(SUNDataNode_GetDataNvector(solution_node, *yout, t));
Member

Since this function takes an N_Vector, why is the input to this function an N_Vector*?

Member Author

I do not understand this comment.
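
One reading of the question: N_Vector is already a pointer to a struct, so a function that only fills an existing vector can take the handle by value, and a pointer to the handle is needed only when the callee replaces the handle itself. A minimal sketch of that distinction, using a hypothetical `Vec` type in place of N_Vector (none of these names come from SUNDIALS):

```c
#include <assert.h>
#include <stddef.h>

/* Hypothetical miniature of an opaque vector handle; `Vec` plays the
   role of a pointer-to-struct handle such as N_Vector. */
typedef struct VecContent
{
  double data[3];
  size_t len;
}* Vec;

/* Filling an existing vector: the handle by value is enough, because the
   callee writes through the pointer and never replaces it. */
static void load_into(Vec out, const double* src)
{
  assert(out != NULL && src != NULL);
  for (size_t i = 0; i < out->len; ++i) { out->data[i] = src[i]; }
}

/* Replacing the handle itself: only this case needs a pointer to it. */
static void rebind(Vec* out, Vec replacement)
{
  assert(out != NULL);
  *out = replacement;
}
```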

{
SUNFunctionBegin(self->sunctx);

void* queue = NULL;
Member

Here and below, I think a NULL queue will cause an error with SYCL

Comment on lines +309 to +310
SUNMemoryType buffer_mem_type = N_VGetDeviceArrayPointer(v) ? SUNMEMTYPE_DEVICE
: SUNMEMTYPE_HOST;
Member

Here and below, user-supplied vectors may not implement N_VGetDeviceArrayPointer but still use device data. I think we'll need a new NVector function to query which memory space the data is stored in.
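
To make the concern concrete, here is a sketch of an explicit memory-space query that falls back to the device-pointer inference used in the diff. Everything in it is hypothetical: `DemoVector`, `MemSpace`, and the `get_mem_space` operation are illustrative stand-ins, and no such NVector operation is confirmed by this thread.

```c
#include <assert.h>
#include <stddef.h>

typedef enum { MEM_HOST, MEM_DEVICE } MemSpace; /* hypothetical */

typedef struct DemoVector DemoVector;
struct DemoVector
{
  void* (*get_device_ptr)(const DemoVector*);   /* optional, may be NULL */
  MemSpace (*get_mem_space)(const DemoVector*); /* hypothetical new op   */
};

/* Inference mirroring the diff: treat the data as device memory only if
   a device pointer accessor exists and returns a non-NULL pointer. */
static MemSpace infer_from_device_ptr(const DemoVector* v)
{
  assert(v != NULL);
  return (v->get_device_ptr != NULL && v->get_device_ptr(v) != NULL)
           ? MEM_DEVICE
           : MEM_HOST;
}

/* Prefer the explicit query when the vector provides one. */
static MemSpace query_mem_space(const DemoVector* v)
{
  assert(v != NULL);
  if (v->get_mem_space != NULL) { return v->get_mem_space(v); }
  return infer_from_device_ptr(v);
}

/* Example user vector op: device data, but no device pointer accessor. */
static MemSpace always_device(const DemoVector* v)
{
  (void)v;
  return MEM_DEVICE;
}
```

With a vector that sets only `get_mem_space` to `always_device`, the inference alone reports host memory while the explicit query reports device memory, which is the mismatch the comment describes.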


void* queue = NULL;

SUNMemoryType leaf_mem_type = SUNMEMTYPE_HOST;
Member

Should this be leaf_data->type?

if (leaf_mem_type == buffer_mem_type)
{
sunrealtype* data_ptr = leaf_data->ptr;
*t = data_ptr[0];
Member

Here and below, can we assume data_ptr is host data or should this use the memory helper copy?

if (leaf_mem_type == buffer_mem_type)
{
sunrealtype* data_ptr = leaf_data->ptr;
data_ptr[0] = t;
Member

Here and below, can we assume data_ptr is host data or does this need to use the memory helper copy?

Member

@gardner48 gardner48 left a comment

Finished pass over src, starting on test

src/sundials/sundials_adjointstepper.c (outdated, resolved)
return SUN_SUCCESS;
}

SUNErrCode SUNAdjointStepper_ReInit(SUNAdjointStepper self, N_Vector y0,
Member

[discussion] The current behavior is like ReInit, but we do not have a SUNStepperReInit function, so the user would need to reinitialize the forward and adjoint steppers. We could add SUNStepperReInit with the same signature as SUNStepperReset. This would require some updates to the ARKODE ReInit functions that would be wrapped by a generic ReInit, so that they accept NULL inputs in order to retain the current RHS functions (plus some other changes to handle NULL inputs).


while ((direction == -one && t > tout) || (direction == one && t < tout))
{
SUNCheckCall(SUNAdjointStepper_OneStep(self, tout, sens, tret));
Member

To get interpolated output at the desired time, this should use SUNStepper_Evolve (or we'll need to add a SUNStepper function to get interpolated output).
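
The loop quoted above stops at the first step that reaches or passes tout, so the time it returns can differ from the one requested. A minimal sketch with a hypothetical fixed-step stand-in for the stepper illustrates this; `one_step`, `evolve_by_steps`, and the step size `h` are assumptions made for the example, not SUNDIALS API.

```c
#include <assert.h>

/* Hypothetical fixed-step "stepper": advance t by h in the given direction. */
static double one_step(double t, double h, double direction)
{
  return t + direction * h;
}

/* Mirrors the shape of the quoted loop: keep taking whole steps until t
   reaches or passes tout, in either time direction. */
static double evolve_by_steps(double t0, double tout, double h)
{
  assert(h > 0.0); /* a non-positive step would never terminate */
  const double direction = (tout < t0) ? -1.0 : 1.0;
  double t               = t0;
  while ((direction < 0.0 && t > tout) || (direction > 0.0 && t < tout))
  {
    t = one_step(t, h, direction);
  }
  return t; /* may lie beyond tout: nothing interpolates back to tout */
}
```

Integrating backward from 1.0 toward 0.5 with h = 0.375 takes two steps and ends at 0.25, past the requested time, which is why the comment asks for an interpolating evolve call.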


SUNErrCode retcode = SUN_SUCCESS;
sunrealtype t = self->tf;
SUNCheckCall(SUNStepper_OneStep(adj_sunstepper, tout, sens, &t));
Member

Is managing integrating to and loading checkpoints handled inside the SUNStepper wrapper? With the addition of SUNStepper_ResetCheckpointIndex I would have expected to see that in the SUNAdjointStepper functions for Evolve and OneStep.

self->last_flag = adj_sunstepper->last_flag;

self->step_idx--;
self->nst++;
Member

Do we need to store nst or can we get this from the adj_sunstepper?

Member

Check SUNStlVector_* return codes?

int retval;
int64_t idx;
int64_t retval;
sunbooleantype collision;

if (map == NULL || key == NULL || value == NULL) { return (-1); }
Member

Suggested change
- if (map == NULL || key == NULL || value == NULL) { return (-1); }
+ if (map == NULL || key == NULL || value == NULL) { return SUNHASHMAP_ERROR; }

/* Extract the elapsed times from the hash map */
SUNHashMap_Values(p->map, (void***)&values, sizeof(sunTimerStruct));
sunTimerStruct* reduced =
(sunTimerStruct*)malloc(p->map->size * sizeof(sunTimerStruct));
for (i = 0; i < p->map->size; ++i) { reduced[i] = *values[i]; }
(sunTimerStruct*)malloc(map_size * sizeof(sunTimerStruct));
Member

Check for malloc failure.
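
A minimal sketch of the requested check, using a hypothetical `DemoTimer` struct and helper in place of the profiler's types. How the real function should report the failure (which error code, what cleanup) is left open here.

```c
#include <stdint.h>
#include <stdlib.h>

/* Hypothetical stand-in for the timer struct copied out of the hash map. */
typedef struct
{
  double average;
  double maximum;
} DemoTimer;

/* Copy `count` timers out of an array of pointers. Returns NULL on bad
   input or if the allocation fails; the caller owns and frees the result. */
static DemoTimer* copy_timers(DemoTimer* const* values, int64_t count)
{
  if (values == NULL || count <= 0) { return NULL; }

  DemoTimer* reduced = (DemoTimer*)malloc((size_t)count * sizeof(DemoTimer));
  if (reduced == NULL) { return NULL; } /* allocation failed: do not write */

  for (int64_t i = 0; i < count; ++i) { reduced[i] = *values[i]; }
  return reduced;
}
```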

@@ -507,7 +527,7 @@ SUNErrCode sunCollectTimers(SUNProfiler p)
MPI_Op_free(&MPI_sunTimerStruct_MAXANDSUM);

/* Update the values that are in this rank's hash map. */
for (i = 0; i < p->map->size; ++i)
for (int64_t i = 0; i < map_size; ++i)
Member

Suggested change
- for (int64_t i = 0; i < map_size; ++i)
+ for (int i = 0; i < map_size; ++i)

@@ -320,23 +449,37 @@ SUNErrCode SUNHashMap_Sort(SUNHashMap map, SUNHashMapKeyValue** sorted,
**Returns:**
* A SUNErrCode indicating success or a failure
*/
#if SUNDIALS_MPI_ENABLED
Member

I think this was needed to avoid an unused function warning, because this is only called when MPI is enabled (or at least it was before).
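
The pattern being described can be sketched as follows: a static helper with a single, conditionally compiled caller is wrapped in the same guard, so it is never compiled without its caller. `DEMO_MPI_ENABLED` is a hypothetical stand-in for SUNDIALS_MPI_ENABLED, and both function names are illustrative.

```c
#define DEMO_MPI_ENABLED 0 /* hypothetical stand-in for SUNDIALS_MPI_ENABLED */

#if DEMO_MPI_ENABLED
/* Compiled only when its sole caller is compiled, so a non-MPI build
   never sees a defined-but-unused static function. */
static int reduce_across_ranks(int local) { return local; }
#endif

static int collect(int local)
{
#if DEMO_MPI_ENABLED
  return reduce_across_ranks(local);
#else
  return local; /* single-process build: nothing to reduce */
#endif
}
```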

3 participants