FXC-3351-Updates to speed up some mode data operations #2927
base: develop
Conversation
2 files reviewed, 1 comment
Diff Coverage
Diff: origin/develop...HEAD, staged and unstaged changes
Summary
tidy3d/components/data/monitor_data.py

Lines 773-781 (`!` marks lines missing coverage):

```python
  773      try:
  774          for field_self, field_other in zip(fields_self.values(), fields_other.values()):
  775              for key in self._tangential_dims:
  776                  if not np.all(field_self.coords[key] == field_other.coords[key]):
! 777                      raise ValueError("Coordinates do not match.")
  778
  779          # Tangential coordinates match, so we try to use .values for speed.
  780          # This will work if other coordinates match dimensions or are broadcastable.
  781          # This is OK as we do not enforce frequencies or mode indexes to be the same.
```

Lines 2267-2275:

```python
  2267      arr = np.moveaxis(data.data, move_order, range(data.ndim))
  2268      lead_shape = arr.shape[:-2]
  2269      nf, nm = arr.shape[-2], arr.shape[-1]
  2270      if nf != num_freqs or nm != num_modes:
! 2271          raise DataError(
  2272              "sort_inds_2d shape does not match array shape in _apply_mode_reorder."
  2273          )
  2274
  2275      # Vectorized gather: reshape to (nf, Nlead, nm), gather along last axis
```
Force-pushed from c9566b8 to ef492c1.
@greptile give this another go buddy
Looking at the code and the changes, I can see you're trying to optimize mode data operations. Let me provide a concrete suggestion for the section that needs fixing.
The issue is in the `_apply_mode_reorder` method where you're trying to use `np.take_along_axis` with broadcasting, but the shapes don't align correctly. Here's the fix:
```python
sort_inds_2d = np.asarray(sort_inds_2d, dtype=int)
num_freqs, num_modes = sort_inds_2d.shape

# Fast no-op
identity = np.arange(num_modes)
if np.all(sort_inds_2d == identity[None, :]):
    return self

modify_data = {}
new_mode_index_coord = identity
for key, data in self.data_arrs.items():
    if "mode_index" not in data.dims or "f" not in data.dims:
        continue
    dims_orig = tuple(data.dims)
    # Preserve coords (as numpy)
    coords_out = {
        k: (v.values if hasattr(v, "values") else np.asarray(v))
        for k, v in data.coords.items()
    }
    f_axis = data.get_axis_num("f")
    m_axis = data.get_axis_num("mode_index")
    # Move axes so array is (..., f, mode)
    move_order = [ax for ax in range(data.ndim) if ax not in (f_axis, m_axis)] + [
        f_axis,
        m_axis,
    ]
    arr = np.moveaxis(data.values, move_order, range(data.ndim))
    lead_shape = arr.shape[:-2]
    nf, nm = arr.shape[-2], arr.shape[-1]
    if nf != num_freqs or nm != num_modes:
        raise DataError(
            "sort_inds_2d shape does not match array shape in _apply_mode_reorder."
        )
    # Vectorized gather: reshape to (Nlead, nf, nm), gather along last axis
    arr_reshaped = arr.reshape((-1, nf, nm))
    # Expand sort_inds_2d to (Nlead, nf, nm)
    inds = np.broadcast_to(sort_inds_2d[None, :, :], arr_reshaped.shape)
    arr_sorted_flat = np.take_along_axis(arr_reshaped, inds, axis=2)
    arr_sorted = arr_sorted_flat.reshape(*lead_shape, nf, nm)
    # Move axes back to original order
    arr_sorted = np.moveaxis(arr_sorted, range(data.ndim), move_order)
    # Update coords: keep f, reset mode_index to 0..num_modes-1
    coords_out["mode_index"] = new_mode_index_coord
    coords_out["f"] = data.coords["f"].values
    modify_data[key] = DataArray(arr_sorted, coords=coords_out, dims=dims_orig)

return self.updated_copy(**modify_data)
```
The key changes:
- Changed `data.data` to `data.values` for consistency
- Fixed the reshaping to `(Nlead, nf, nm)` instead of transposing
- Used `np.broadcast_to` to expand `sort_inds_2d` to match the full array shape
- Applied `np.take_along_axis` with proper broadcasting
Force-pushed from ef492c1 to b972ffa.
Looks good, just a couple minor (?) comments
```python
                    raise ValueError("Coordinates do not match.")
    return fields
except:  # noqa: E722
    # Catching a broad exception here in case anything went wrong in the check.
```
do we even need try-catch here?
Good point, but I do like the simplicity of try-except for control flow. Since I have a nested loop here, it's kind of verbose to control with a boolean and breaks.
I did change the except to explicitly only catch ValueError though. How does that sound?
I am wary about using try/except for control flow. What about a helper method or two to simplify the steps?
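For instance, a minimal sketch of such a helper (the name `_tangential_coords_match` is hypothetical, not from the PR), assuming the same `fields_self`/`fields_other` dicts and `self._tangential_dims` as in the snippet above, with numpy imported as `np` as elsewhere in `monitor_data.py`:

```python
def _tangential_coords_match(self, fields_self, fields_other) -> bool:
    """Check whether all tangential coordinates of two field dicts are identical."""
    for field_self, field_other in zip(fields_self.values(), fields_other.values()):
        for key in self._tangential_dims:
            coords_self = field_self.coords[key].values
            coords_other = field_other.coords[key].values
            # Different sizes or any mismatching value means the fast path can't be used.
            if coords_self.size != coords_other.size or not np.all(coords_self == coords_other):
                return False
    return True
```

The nested loop then collapses to a single `if self._tangential_coords_match(...)` at the call site, with no exception used for control flow.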
```python
f_axis = data.get_axis_num("f")
m_axis = data.get_axis_num("mode_index")

# Move axes so array is (..., f, mode)
```
can we move axes to (f, ..., mode) to begin with?
The idea is to not assume some specific axes order for robustness. Is your suggestion just so we can avoid the reorder after the sorting?
It seems like there could be up to 4 reorders: (orig) -> (..., f, mode) -> (f, ..., mode) -> (..., f, mode) -> (orig). So, I was wondering if it can be reduced to 2: (orig) -> (f, ..., mode) -> (orig).
Ah got it, will give it another look tomorrow!
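If it helps, here is a rough sketch of that two-reorder variant (illustrative only, reusing `data`, `f_axis`, `m_axis`, and `sort_inds_2d` from the snippet above; it relies on `np.take_along_axis` broadcasting the index array over the leading dimensions):

```python
# One reorder in: put f first and mode_index last, i.e. shape (nf, ..., nm).
arr = np.moveaxis(data.values, (f_axis, m_axis), (0, -1))
nf, nm = arr.shape[0], arr.shape[-1]

# Reshape the (nf, nm) index table to (nf, 1, ..., 1, nm) so it broadcasts
# against arr inside take_along_axis; no explicit broadcast_to is needed.
inds = sort_inds_2d.reshape((nf,) + (1,) * (arr.ndim - 2) + (nm,))
arr_sorted = np.take_along_axis(arr, inds, axis=-1)

# One reorder out: undo the initial moveaxis to recover the original dim order.
arr_sorted = np.moveaxis(arr_sorted, (0, -1), (f_axis, m_axis))
```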
Force-pushed from b972ffa to 884f8b4.
Looks good, but I do think trying to avoid those try/except usages would be better.
```python
try:
    # If coords already match, just return the tangential fields directly.
    # Using try: except for flow control.
    for field in fields.values():
        for idim, dim in enumerate(self._tangential_dims):
            if field.coords[dim].values.size != coords[idim].size or not np.all(
                field.coords[dim].values == coords[idim]
            ):
                raise ValueError("Coordinates do not match.")
    return fields
except ValueError:
    pass
```
Also here I think bool logic is preferable. From my understanding, raising an exception can also be slower. In this case, the interpolation is the expensive part so it's not a big problem, but good to avoid I think.
A method that can determine whether two monitor data have matching tangential fields would be useful when I get back to the dot_yee/flux_yee stuff as well
I could switch this to bool and split this out into a separate method. What do you think about the other one? There I'm more worried about something erroring that I haven't anticipated (e.g. due to dimensions order), so I kind of want it to fall back to the xarray handling.
Well, I guess that would indicate a bug we need to fix. I am trying to think how it could fail; from your check the coords must be identical, right? Same shape and order, I would think.
Oh I see, only the tangential dims are checked
Yeah because this is to speed up e.g. overlap computations of mode data at one frequency with the same mode data at a different frequency.
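For illustration, a hypothetical usage along the lines of (not code from the PR; `_isel` and `dot` as referenced elsewhere in this thread):

```python
# Hypothetical example: overlap of the same mode data at two different frequencies.
# The tangential grids are identical, so the fast path should skip interpolation.
data_f0 = mode_data._isel(f=[0])
data_f1 = mode_data._isel(f=[1])
overlap = data_f0.dot(data_f1, conjugate=True)
```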
Ultimately, I will let you decide!
yeah, I think switching this one to if-based, but leaving the other one as try-catch seems like a good compromise at the moment.
But here's an idea for the one falling back to xarray: if we know fairly broad, but maybe not exhaustive, conditions under which we expect with certainty that the optimized calculation will complete successfully, maybe it's worth switching that one to if-based branching as well. That is,

```
if (broad but not exhaustive ok):
    # direct calc
else:
    # xarray calc
```

Yes, we won't get the best performance in all possible cases, but hopefully in the overwhelming majority. Again, not insisting on this though.
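A minimal sketch of that shape, assuming a predicate like the hypothetical `_tangential_coords_match` helper sketched earlier (the `_overlap_direct`/`_overlap_xarray` names are placeholders, not real methods in the codebase):

```python
if self._tangential_coords_match(fields_self, fields_other):
    # Broad-but-safe condition holds: operate on raw numpy arrays via .values.
    result = self._overlap_direct(fields_self, fields_other)  # placeholder for the direct calc
else:
    # Otherwise let xarray handle coordinate alignment and broadcasting.
    result = self._overlap_xarray(fields_self, fields_other)  # placeholder for the xarray calc
```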
That's basically how I had it previously, I guess I just decided to be somewhat defensive against something randomly erroring in a way that I haven't anticipated and was not captured by tests. But I do agree with you guys that hiding errors is not great either.
As discussed here, I made some changes to speed up the overlap sort.
I also removed `monitor_data._updated` as we now support `updated_copy(deep=False)`, which is both better documented and seems to work faster.
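For illustration, the pattern in question is roughly (the `updated_fields` name here is hypothetical):

```python
# Shallow, non-validating copy: replace only the given data arrays without
# deep-copying or re-validating the rest of the model.
new_data = mode_data.updated_copy(deep=False, validate=False, **updated_fields)
```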
Greptile Overview

Updated On: 2025-10-24 13:20:46 UTC
Greptile Summary
Optimized mode data operations by replacing the internal `_updated()` method with direct calls to `updated_copy(deep=False, validate=False)` to avoid unnecessary deep copying and validation overhead.

Key changes:
- Removed the internal `_updated()` method from the `MonitorData` class, which used a `dict()` + `parse_obj()` approach
- Replaced it with `updated_copy()` using the `deep=False` and `validate=False` flags in 5 locations across `monitor_data.py` and 1 location in `mode_solver.py`
- Added an early-exit check to `_interpolated_tangential_fields()` that checks whether coordinates already match using `np.allclose()` before interpolation
- Improved `overlap_sort()` efficiency by using `symmetry_expanded` data once and removing unnecessary coordinate reassignment via `_assign_coords()`

The changes maintain the same shallow-copy semantics while using the more standard `updated_copy()` API with appropriate flags, resulting in cleaner code and improved performance.

Confidence Score: 4/5
- The refactoring replaces the internal `_updated()` method with the standard `updated_copy()` API. It maintains the same shallow-copy semantics and improves code consistency. However, the new early-exit optimization in `_interpolated_tangential_fields()` uses `np.allclose()` for coordinate comparison, which is appropriate for floating-point comparisons but should be monitored to ensure it doesn't skip interpolation when coordinates are very close but not identical enough for the use case.
- Pay attention to `monitor_data.py:800-808`, where the new coordinate matching logic was added.
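For reference, a minimal sketch of the kind of early-exit check being described (illustrative only; the actual code in the PR may differ):

```python
def _coords_already_match(fields, coords, tangential_dims):
    """Return True if every field's tangential coords match the requested coords."""
    for field in fields.values():
        for idim, dim in enumerate(tangential_dims):
            existing = field.coords[dim].values
            # Size check first so np.allclose never compares mismatched shapes.
            if existing.size != coords[idim].size or not np.allclose(existing, coords[idim]):
                return False
    return True
```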
Important Files Changed

File Analysis
- `tidy3d/components/data/monitor_data.py`: Replaced the `_updated()` method with `updated_copy(deep=False, validate=False)` for performance; added an early-exit optimization in `_interpolated_tangential_fields()`; optimized `overlap_sort()` to use `symmetry_expanded` and removed unnecessary coordinate reassignment.
- `mode_solver.py`: Replaced the `_updated()` method call with `updated_copy(deep=False, validate=False)` for consistency with the `monitor_data.py` changes.
Sequence Diagram

```mermaid
sequenceDiagram
    participant Client
    participant ModeData
    participant AbstractFieldData
    participant ModeSolver

    Note over ModeData: overlap_sort() optimization
    Client->>ModeData: overlap_sort(track_freq, overlap_thresh)
    ModeData->>AbstractFieldData: symmetry_expanded
    AbstractFieldData-->>ModeData: data_expanded (self or new copy)
    alt No symmetry (returns self)
        Note over AbstractFieldData: Returns self directly
    else Has symmetry
        AbstractFieldData->>AbstractFieldData: updated_copy(deep=False, validate=False)
        Note over AbstractFieldData: Creates shallow copy with expanded fields
    end
    ModeData->>ModeData: dot(data_expanded, conjugate)
    Note over ModeData: Compute self-overlap
    loop For each frequency direction
        ModeData->>ModeData: _isel(f=[freq_id])
        Note over ModeData: Extract frequency slice
        ModeData->>ModeData: _find_ordering_one_freq()
        Note over ModeData: Uses dot() instead of outer_dot()
    end
    ModeData->>ModeData: updated_copy(deep=False, validate=False)
    ModeData-->>Client: Sorted mode data

    Note over ModeSolver: _colocate_data() optimization
    Client->>ModeSolver: _colocate_data(mode_solver_data)
    ModeSolver->>ModeSolver: updated_copy(deep=False, validate=False)
    Note over ModeSolver: Replaces _updated() method
    ModeSolver-->>Client: Colocated data

    Note over AbstractFieldData: _interpolated_tangential_fields() optimization
    Client->>AbstractFieldData: _interpolated_tangential_fields(coords)
    AbstractFieldData->>AbstractFieldData: Check if coords match using np.allclose()
    alt Coords already match
        AbstractFieldData-->>Client: Return fields directly (early exit)
    else Coords don't match
        AbstractFieldData->>AbstractFieldData: Interpolate fields
        AbstractFieldData-->>Client: Return interpolated fields
    end
```