Skip to content

Commit

Permalink
Replace mjQUICKSORT with faster, native sorting function. fixes #1638
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 688519096
Change-Id: I05ab1576703c2458968ef81915ac682a8beac391
  • Loading branch information
kbayes authored and copybara-github committed Oct 22, 2024
1 parent 9e1aa37 commit 2b0629d
Show file tree
Hide file tree
Showing 5 changed files with 172 additions and 82 deletions.
43 changes: 25 additions & 18 deletions src/engine/engine_collision_driver.c
Original file line number Diff line number Diff line change
Expand Up @@ -224,10 +224,8 @@ int mj_isElemActive(const mjModel* m, int f, int e) {
//----------------------------- collision detection entry point ------------------------------------

// compare contact pairs by their geom/elem/vert IDs
quicksortfunc(contactcompare, context, el1, el2) {
static inline int contactcompare(const mjContact* c1, const mjContact* c2, void* context) {
const mjModel* m = (const mjModel*) context;
mjContact* c1 = (mjContact*)el1;
mjContact* c2 = (mjContact*)el2;

// get colliding object ids
int con1_obj1 = c1->geom[0] >= 0 ? c1->geom[0] : (c1->elem[0] >= 0 ? c1->elem[0] : c1->vert[0]);
Expand Down Expand Up @@ -258,6 +256,9 @@ quicksortfunc(contactcompare, context, el1, el2) {
return 0;
}

// define contactSort function for sorting contacts
mjSORT(contactSort, mjContact, contactcompare)



// main collision function
Expand Down Expand Up @@ -364,8 +365,14 @@ void mj_collision(const mjModel* m, mjData* d) {
int ncon_after = d->ncon;

// sort contacts
mjQUICKSORT(d->contact + ncon_before, ncon_after - ncon_before,
sizeof(mjContact), contactcompare, (void*) m);
int n = ncon_after - ncon_before;
if (n > 1) {
mj_markStack(d);
mjContact* buf = (mjContact*)mj_stackAllocByte(d, n * sizeof(mjContact),
_Alignof(mjContact));
contactSort(d->contact + ncon_before, buf, n, (void*)m);
mj_freeStack(d);
}
}

// process bodyflex pair: all-to-all
Expand Down Expand Up @@ -1006,10 +1013,7 @@ typedef struct _mjtSAP mjtSAP;


// comparison function for SAP
quicksortfunc(SAPcompare, context, el1, el2) {
mjtSAP* obj1 = (mjtSAP*)el1;
mjtSAP* obj2 = (mjtSAP*)el2;

static inline int SAPcmp(mjtSAP* obj1, mjtSAP* obj2, void* context) {
if (obj1->value < obj2->value) {
return -1;
} else if (obj1->value == obj2->value) {
Expand All @@ -1019,6 +1023,8 @@ quicksortfunc(SAPcompare, context, el1, el2) {
}
}

// define SAPsort function for sorting SAP sorting
mjSORT(SAPsort, mjtSAP, SAPcmp)


// given list of axis-aligned bounding boxes in AAMM (xmin[3], xmax[3]) format,
Expand All @@ -1043,7 +1049,8 @@ static int mj_SAP(mjData* d, const mjtNum* aamm, int n, int axis, int* pair, int
}

// sort along specified axis
mjQUICKSORT(sortbuf, 2*n, sizeof(mjtSAP), SAPcompare, 0);
mjtSAP* buf = (mjtSAP*) mj_stackAllocByte(d, 2*n*sizeof(mjtSAP), _Alignof(mjtSAP));
SAPsort(sortbuf, buf, 2*n, NULL);

// define the other two axes
int axisA, axisB;
Expand Down Expand Up @@ -1133,19 +1140,18 @@ static void updateCov(mjtNum cov[9], const mjtNum vec[3], const mjtNum cen[3]) {


// comparison function for unsigned ints
quicksortfunc(uintcompare, context, el1, el2) {
unsigned int n1 = *(unsigned int*)el1;
unsigned int n2 = *(unsigned int*)el2;

if (n1 < n2) {
static inline int uintcmp(int* i, int* j, void* context) {
if ((unsigned) *i < (unsigned) *j) {
return -1;
} else if (n1 == n2) {
} else if (*i == *j) {
return 0;
} else {
return 1;
}
}

// define bfsort function for sorting bodyflex pairs
mjSORT(bfsort, int, uintcmp)


// broadphase collision detector
Expand Down Expand Up @@ -1281,8 +1287,9 @@ int mj_broadphase(const mjModel* m, mjData* d, int* bfpair, int maxpair) {
endbroad:

// sort bodyflex pairs by signature
if (npair) {
mjQUICKSORT(bfpair, npair, sizeof(int), uintcompare, 0);
if (npair > 1) {
int* buf = mj_stackAllocInt(d, npair);
bfsort(bfpair, buf, npair, NULL);
}

mj_freeStack(d);
Expand Down
17 changes: 11 additions & 6 deletions src/engine/engine_collision_sdf.c
Original file line number Diff line number Diff line change
Expand Up @@ -306,11 +306,10 @@ static void undoTransformation(const mjModel* m, const mjData* d, int g,
//---------------------------- narrow phase -----------------------------------------------

// comparison function for contact sorting
quicksortfunc(distcompare, dist, i1, i2) {
mjtNum d1 = ((mjtNum*)dist)[*(int*)i1];
mjtNum d2 = ((mjtNum*)dist)[*(int*)i2];

if (d1 < d2) {
static inline int distcmp(int* i, int* j, void* context) {
mjtNum d1 = ((mjtNum*)context)[*i];
mjtNum d2 = ((mjtNum*)context)[*j];
if (d1 < d2) {
return -1;
} else if (d1 == d2) {
return 0;
Expand All @@ -319,6 +318,9 @@ quicksortfunc(distcompare, dist, i1, i2) {
}
}

// define distSort function for contact sorting
mjSORT(distSort, int, distcmp)

// check if the collision point already exists
static int isknown(const mjtNum* points, const mjtNum x[3], int cnt) {
for (int i = 0; i < cnt; i++) {
Expand Down Expand Up @@ -642,7 +644,10 @@ int mjc_MeshSDF(const mjModel* m, const mjData* d, mjContact* con, int g1, int g
}

// sort contacts using depth
mjQUICKSORT(index, ncandidate, sizeof(int), distcompare, dist);
if (ncandidate > 1) {
int buf[MAXMESHPNT];
distSort(index, buf, ncandidate, dist);
}

// add only the first mjMAXCONPAIR pairs
for (int i=0; i < mju_min(ncandidate, mjMAXCONPAIR); i++) {
Expand Down
88 changes: 51 additions & 37 deletions src/engine/engine_sort.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,46 +15,60 @@
#ifndef MUJOCO_SRC_ENGINE_ENGINE_SORT_H_
#define MUJOCO_SRC_ENGINE_ENGINE_SORT_H_

#if !defined(__cplusplus)
#include <stddef.h>
#include <stdlib.h>
// threshold size of a run to do insertion sort on
#define _mjRUNSIZE 32

// sorting functions using q_sort_s/r
#ifdef _WIN32
#define mjQUICKSORT(buf, elnum, elsz, func, context) \
qsort_s(buf, elnum, elsz, func, context)
#define quicksortfunc(name, context, el1, el2) \
static int name(void* context, const void* el1, const void* el2)
#else // assumes POSIX
#ifdef __APPLE__
#define mjQUICKSORT(buf, elnum, elsz, func, context) \
qsort_r(buf, elnum, elsz, context, func)
#define quicksortfunc(name, context, el1, el2) \
static int name(void* context, const void* el1, const void* el2)
#else // non-Apple
#define mjQUICKSORT(buf, elnum, elsz, func, context) \
qsort_r(buf, elnum, elsz, func, context)
#define quicksortfunc(name, context, el1, el2) \
static int name(const void* el1, const void* el2, void* context)
#endif
#endif
#else
#include <algorithm>
#include <cstddef>
#include <cstdlib>
// insertion sort sub-macro that runs on a sub-array [start, ..., end)
#define _mjINSERTION_SORT(type, arr, start, end, cmp, context) \
{ \
for (int j = start + 1; j < end; j++) { \
type tmp = arr[j]; \
int k = j - 1; \
for (; k >= start && cmp(arr + k, &tmp, context) > 0; k--) { \
arr[k + 1] = arr[k]; \
} \
arr[k + 1] = tmp; \
} \
}

// sorting function using std::sort
template <typename T>
void mjQUICKSORT(T* buf, size_t elnum, size_t elsz,
int (*compare)(const void* a, const void* b, void* c),
void* context) {
std::sort(buf, buf + elnum, [compare, context](const T& a, const T& b) {
return compare(&a, &b, context) < 0;
});
// sub-macro that merges two sub-sorted arrays [start, ..., mid), [mid, ..., end) together
#define _mjMERGE(type, arr, buf, start, mid, end, cmp, context) \
{ \
int len1 = mid - start, len2 = end - mid; \
type* left = buf, *right = buf + len1; \
for (int i = 0; i < len1; i++) left[i] = arr[start + i]; \
for (int i = 0; i < len2; i++) right[i] = arr[mid + i]; \
int i = 0, j = 0, k = start; \
while (i < len1 && j < len2) { \
if (cmp(left + i, right + j, context) <= 0) { \
arr[k++] = left[i++]; \
} else { \
arr[k++] = right[j++]; \
} \
} \
while (i < len1) arr[k++] = left[i++]; \
while (j < len2) arr[k++] = right[j++]; \
}

#define quicksortfunc(name, context, el1, el2) \
static int name(const void* el1, const void* el2, void* context)
#endif
// defines an inline stable sorting function via tiled merge sorting (timsort)
// function is of form:
// void name(type* arr, type* buf, int n, void* context)
// where arr is the array of size n to be sorted inplace and buf is a buffer of size n.
#define mjSORT(name, type, cmp) \
static inline void name(type* arr, type* buf, int n, void* context) { \
for (int start = 0; start < n; start += _mjRUNSIZE) { \
int end = (start + _mjRUNSIZE < n) ? start + _mjRUNSIZE : n; \
_mjINSERTION_SORT(type, arr, start, end, cmp, context); \
} \
for (int len = _mjRUNSIZE; len < n; len *= 2) { \
for (int start = 0; start < n; start += 2*len) { \
int mid = start + len; \
int end = (start + 2*len < n) ? start + 2*len : n; \
if (mid < end) { \
_mjMERGE(type, arr, buf, start, mid, end, cmp, context); \
} \
} \
} \
}

#endif // MUJOCO_SRC_ENGINE_ENGINE_SORT_H_
15 changes: 11 additions & 4 deletions src/render/render_gl3.c
Original file line number Diff line number Diff line change
Expand Up @@ -766,10 +766,10 @@ static void setView(int view, mjrRect viewport, const mjvScene* scn, const mjrCo


// comparison function for geom sorting
quicksortfunc(geomcompare, context, el1, el2) {
static inline int geomcmp(int* i, int* j, void* context) {
mjvGeom* geom = (mjvGeom*) context;
float d1 = geom[*(int*)el1].camdist;
float d2 = geom[*(int*)el2].camdist;
float d1 = geom[*i].camdist;
float d2 = geom[*j].camdist;

if (d1 < d2) {
return -1;
Expand All @@ -780,6 +780,9 @@ quicksortfunc(geomcompare, context, el1, el2) {
}
}

// define geomSort function for sorting geoms
mjSORT(geomSort, int, geomcmp)



// adjust light n position and direction
Expand Down Expand Up @@ -907,7 +910,11 @@ void mjr_render(mjrRect viewport, mjvScene* scn, const mjrContext* con) {
}

// sort transparent geoms according to distance to camera
mjQUICKSORT(scn->geomorder, nt, sizeof(int), geomcompare, scn->geoms);
if (nt > 1) {
int *buf = (int*) mju_malloc(nt * sizeof(int));
geomSort(scn->geomorder, buf, nt, scn->geoms);
mju_free(buf);
}

// allow only one reflective geom
int j = 0;
Expand Down
Loading

0 comments on commit 2b0629d

Please sign in to comment.