diff --git a/src/Tree.js b/src/Tree.js index ee845fc..1c17c9b 100644 --- a/src/Tree.js +++ b/src/Tree.js @@ -6,6 +6,14 @@ import {FlattenedNode} from './shapes/nodeShapes'; import TreeState, {State} from './state/TreeState'; export default class Tree extends React.Component { + constructor(props) { + super(props); + this.state = { + topStickyHeader: null, + }; + this._listRef = React.createRef(); + } + _cache = new CellMeasurerCache({ fixedWidth: true, minHeight: 20, @@ -35,8 +43,66 @@ export default class Tree extends React.Component { : nodes[index]; }; + isGroupHeader = node => { + return node.children && node.children.length > 0 && node.deepness === 0; + }; + + componentDidMount() { + if (this._listRef.current) { + const list = this._listRef.current; + const grid = list && list.Grid; + if (grid) { + this.handleScroll({ + scrollTop: grid.state.scrollTop, + }); + } + } + } + + getAllHeaders = () => { + const rowCount = this.getRowCount(); + const headers = []; + let cumulativeHeight = 0; + + for (let i = 0; i < rowCount; i++) { + const node = this.getNode(i); + + if (this.isGroupHeader(node)) { + headers.push({ + node, + index: i, + top: cumulativeHeight, + }); + } + + cumulativeHeight += this._cache.rowHeight({index: i}); + } + + return headers; + }; + + handleScroll = ({scrollTop}) => { + if (!this._listRef.current) return; + + const allHeaders = this.getAllHeaders(); + + const topStickyHeader = allHeaders.filter(h => h.top <= scrollTop).pop() || null; + + const currentStickyId = + this.state.topStickyHeader && this.state.topStickyHeader.node && this.state.topStickyHeader.node.id; + const newStickyId = topStickyHeader && topStickyHeader.node && topStickyHeader.node.id; + + if (currentStickyId !== newStickyId) { + this.setState({ + topStickyHeader, + }); + } + }; + rowRenderer = ({node, key, measure, style, NodeRenderer, index}) => { const {nodeMarginLeft} = this.props; + const isHeader = this.isGroupHeader(node); + const className = isHeader ? 'tree-group-header' : ''; return ( + ); + }; + + renderStickyHeader = () => { + const {topStickyHeader} = this.state; + if (!topStickyHeader) return null; + + const {NodeRenderer, nodeMarginLeft} = this.props; + const index = topStickyHeader.index; + const currentNode = this.getNode(index); + + return ( + ); }; @@ -66,25 +163,58 @@ export default class Tree extends React.Component { ); }; + componentDidUpdate(prevProps) { + if (prevProps.nodes !== this.props.nodes) { + this._cache.clearAll(); + if (this._listRef.current) { + this._listRef.current.recomputeRowHeights(); + } + + this.forceUpdate(); + } + } + render() { const {nodes, width, scrollToIndex, scrollToAlignment} = this.props; + const {topStickyHeader} = this.state; + const stickyHeaderHeight = topStickyHeader ? this._cache.rowHeight({index: topStickyHeader.index}) : 0; return ( - - {({height, width: autoWidth}) => ( - (this._list = r)} - height={height} - rowCount={this.getRowCount()} - rowHeight={this._cache.rowHeight} - rowRenderer={this.measureRowRenderer(nodes)} - width={width || autoWidth} - scrollToIndex={scrollToIndex} - scrollToAlignment={scrollToAlignment} - /> + + {topStickyHeader && ( + + {this.renderStickyHeader()} + )} - + + + {({height, width: autoWidth}) => ( + + )} + + ); } }