Skip to content

Commit

Permalink
save
Browse files Browse the repository at this point in the history
  • Loading branch information
rdspring1 committed Nov 2, 2024
1 parent d9447b4 commit b7ca169
Showing 1 changed file with 55 additions and 6 deletions.
61 changes: 55 additions & 6 deletions csrc/python_frontend/segmentation.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@ namespace nvfuser::python_frontend {

class FusionDefinition;

//! ===========================================================================
//
//! setupSegmentation runs the segmentation algorithm on CPP Fusion to create
//! SegmentedFusion. It returns the number of segments in SegmentedFusion.
//!
Expand All @@ -27,11 +29,56 @@ class FusionDefinition;
//! from cloned Vals to original fusion state index.
//! 4) Get extents for cloned fusion
//! 5) Create SchedulerRuntimeInfo
//! 6) Run segmentation algorithm usin cloned fusion, input arguments, and
//! 6) Run segmentation algorithm using cloned fusion, input arguments, and
//! scheduler runtime infomation.
//! 7) Get sequential order of fusion segments using prepareGroupOrder.
//! 8) Return the number of segments created by segmentation algorithm.
//! ===========================================================================
//!
//! buildSegment creates the CPP Fusion for a given segment id, translate it to
//! the python FusionDefinition, then returns a mapping from segment fusion
//! state indices to the original fusion state indices.
//!
//! Why do we need a map from the segment's fusion index space to the original
//! fusion index space?
//!
//! * The original FusionDefinition is decomposed into a sequence of segment
//! FusionDefinitions.
//! * Each FusionDefinition has an independent index space.
//! * At runtime, the original FusionDefinition acts an argument manager,
//! gathering input arguments and storing output results.
//! * To perform this function, it requires a map from the segment index space
//! to the original index space.
//!
//! NOTE: Steps 4a through 4d are run for every fusion segment. However,
//! sometimes the python definition needs the extents of the original fusion's
//! input tensors as extra arguments.
//!
//! Details:
//! 1) Use segment id to get SegmentedGroup from group_run_order_.
//! 2) Create CPP Fusion for SegmentedGroup.
//! * IrCloner acts as a map from fusion segment to the original fusion.
//! 3) Translate CPP Fusion to Python FusionDefinition
//! 4) Create map from segment fusion indices to original fusion indices.
//! a) Get original Vals for SegmentedGroup's inputs and outputs.
//! b) Map original Vals to their original fusion indices.
//! c) Map original Vals to their segment Vals
//! d) Map segment Vals to their fusion indices.
//! e) Return map if the number of input arguments for python definition
//! matches the number of input arguments for CPP fusion.
//! f) Create a map from segment to original extents.
//! g) Create a map from segment fusion indices to original extents.
//! h) Find segment inputs that are missing from segment to original
//! indices map.
//! i) Get segment CPP Vals for the missing segment fusion indices.
//! j) Map segment CPP Vals to original CPP Vals.
//! k) Map original CPP Vals to their corresponding fusion indices.
//! l) Add missing mappings to segment to original indices map.
//! 5) Return the mapping from the segmented FusionDefinition index space to
//! original FusionDefinition index space.
//!
//! ===========================================================================
//
//! prepareGroupOrder is similar to prepareRuntimeOrder. It generates the
//! sequential order of SegmentedGroups in SegmentedFusion.
//!
Expand All @@ -48,22 +95,24 @@ class FusionDefinition;
//! 9) End For
//! 10) Fail if none of the SegmentedGroups are available to run.
//! 11) End While
//! ===========================================================================
class SegmentationState {
public:
//! Run segmentation algorithm on FusionDefinition.
// Run segmentation algorithm on FusionDefinition.
int64_t setupSegmentation(
Fusion* fusion,
const std::unordered_map<const Val*, int64_t>& map_value_to_original_fid,
const at::ArrayRef<c10::IValue>& inputs);

//! Given SegmentedFusion and vector of FusionDefinition objects for the
//! fusion segments, create the fusion segments and clone their state to the
//! FusionDefinitions.
// Given an empty FusionDefinition and a segment id, buildSegment creates the
// CPP Fusion, translates it to the python FusionDefinition, then return a
// mapping from segment fusion state indices to the original fusion state
// indices.
NVF_API std::unordered_map<int64_t, int64_t> buildSegment(
FusionDefinition& other,
int64_t segment_id);

//! Perform a topological sort on SegmentedFusion to segment order.
// Perform a topological sort on SegmentedFusion to segment order.
void prepareGroupOrder();

private:
Expand Down

0 comments on commit b7ca169

Please sign in to comment.