import numpy as np
import torch
import numba
from scipy.sparse import coo_array, csc_array


def get_coo_storage(ns_total, storage, use_storage):
    if not use_storage:
        # coo_uix = np.empty(ns_total, dtype=int)
        coo_six = np.empty(ns_total, dtype=int)
        coo_data = np.empty(ns_total, dtype=np.float32)
        return coo_six, coo_data

    if hasattr(storage, "coo_data"):
        if storage.coo_data.size < ns_total:
            # del storage.coo_uix
            del storage.coo_six
            del storage.coo_data
            # storage.coo_uix = np.empty(ns_total, dtype=int)
            storage.coo_six = np.empty(ns_total, dtype=int)
            storage.coo_data = np.empty(ns_total, dtype=np.float32)
    else:
        # storage.coo_uix = np.empty(ns_total, dtype=int)
        storage.coo_six = np.empty(ns_total, dtype=int)
        storage.coo_data = np.empty(ns_total, dtype=np.float32)

    # return storage.coo_uix, storage.coo_six, storage.coo_data
    return storage.coo_six, storage.coo_data


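# Illustrative sketch (not part of the original code): the `storage` argument
# above is assumed to be any attribute-bearing namespace whose buffers are
# reused across calls and reallocated only when a larger request arrives. A
# types.SimpleNamespace is enough to demonstrate the pattern.
def _demo_get_coo_storage():
    from types import SimpleNamespace

    storage = SimpleNamespace()
    # first call allocates buffers of size 128 on the namespace
    six, data = get_coo_storage(128, storage, use_storage=True)
    # a larger request reallocates; an equal or smaller one reuses the buffers
    six, data = get_coo_storage(256, storage, use_storage=True)
    return six, data

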
def coo_to_torch(coo_array, dtype, transpose=False, is_coalesced=True, copy_data=False):
    coo = (
        torch.from_numpy(coo_array.coords[int(transpose)]),
        torch.from_numpy(coo_array.coords[1 - int(transpose)]),
    )
    s0, s1 = coo_array.shape
    if transpose:
        s0, s1 = s1, s0
    res = torch.sparse_coo_tensor(
        torch.row_stack(coo),
        torch.asarray(coo_array.data, dtype=dtype, copy=copy_data),
        size=(s0, s1),
        is_coalesced=is_coalesced,
    )
    return res


def coo_to_scipy(coo_tensor):
    data = coo_tensor.values().numpy(force=True)
    coords = coo_tensor.indices().numpy(force=True)
    return coo_array((data, coords), shape=coo_tensor.shape)


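# Illustrative sketch (values are made up): round-trip a small scipy COO array
# through coo_to_torch and coo_to_scipy. The tensor is built with
# is_coalesced=True, which is valid here because the indices below are already
# sorted and unique.
def _demo_coo_round_trip():
    ii = np.array([0, 1, 1])
    jj = np.array([2, 0, 3])
    vv = np.array([1.0, 2.0, 3.0], dtype=np.float32)
    arr = coo_array((vv, (ii, jj)), shape=(2, 4))
    coo_tensor = coo_to_torch(arr, torch.float)  # torch sparse_coo_tensor, shape (2, 4)
    arr_back = coo_to_scipy(coo_tensor)          # back to a scipy coo_array
    return arr, coo_tensor, arr_back

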
def get_csc_storage(ns_total, storage, use_storage):
    if not use_storage:
        csc_row_indices = np.empty(ns_total, dtype=int)
        csc_data = np.empty(ns_total, dtype=np.float32)
        return csc_row_indices, csc_data

    if hasattr(storage, "csc_data"):
        if storage.csc_data.size < ns_total:
            del storage.csc_row_indices
            del storage.csc_data
            storage.csc_row_indices = np.empty(ns_total, dtype=int)
            storage.csc_data = np.empty(ns_total, dtype=np.float32)
    else:
        storage.csc_row_indices = np.empty(ns_total, dtype=int)
        storage.csc_data = np.empty(ns_total, dtype=np.float32)

    return storage.csc_row_indices, storage.csc_data


sig = "void(i8, i8[::1], i8[::1], i8[::1], f4[::1], f4[::1])"


@numba.njit(sig, error_model="numpy", nogil=True, parallel=True)
def csc_insert(row, write_offsets, inds, csc_indices, csc_data, liks):
    """Insert elements into a CSC sparse array

    To use this, you need to know the indptr in advance. Then this function
    can help you insert a row into the array. All of the row's nonzero
    elements must be inserted at once, in a single call, and rows must be
    written in order.

    (However, the columns within each row can be unordered.)

    write_offsets should be initialized at the indptr -- each entry is a
    column's "write head": the position in csc_indices/csc_data where that
    column's next entry will be written.

    This function writes `row` into csc_indices and the values in liks into
    csc_data at each of the columns in inds, then increments those columns'
    write heads so that the next call, for the following row, lands in the
    right places.

    This is equivalent to:
        data_ixs = write_offsets[inds]
        csc_indices[data_ixs] = row
        csc_data[data_ixs] = liks
        write_offsets[inds] += 1
    """
    for j in numba.prange(inds.shape[0]):
        ind = inds[j]
        data_ix = write_offsets[ind]
        csc_indices[data_ix] = row
        csc_data[data_ix] = liks[j]
        write_offsets[ind] += 1


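# Illustrative sketch (names and values are made up): build a small CSC array
# row by row with csc_insert. The indptr is computed up front from per-column
# counts, and write_offsets starts as a copy of indptr[:-1].
def _demo_csc_insert():
    nrows, ncols = 3, 4
    row_inds = [
        np.array([0, 2], dtype=np.int64),
        np.array([1], dtype=np.int64),
        np.array([3, 0], dtype=np.int64),
    ]
    row_liks = [
        np.array([1.0, 2.0], dtype=np.float32),
        np.array([3.0], dtype=np.float32),
        np.array([4.0, 5.0], dtype=np.float32),
    ]

    # per-column counts -> indptr
    counts = np.zeros(ncols, dtype=np.int64)
    for inds in row_inds:
        np.add.at(counts, inds, 1)
    indptr = np.concatenate([np.zeros(1, dtype=np.int64), np.cumsum(counts)])

    nnz = int(indptr[-1])
    csc_indices = np.empty(nnz, dtype=np.int64)
    csc_data = np.empty(nnz, dtype=np.float32)

    # each column's write head starts at its indptr
    write_offsets = indptr[:-1].copy()
    for row, (inds, liks) in enumerate(zip(row_inds, row_liks)):
        csc_insert(row, write_offsets, inds, csc_indices, csc_data, liks)

    return csc_array((csc_data, csc_indices, indptr), shape=(nrows, ncols))

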
def coo_sparse_mask_rows(coo, keep_mask):
    """Row indexing with a boolean mask."""
    if keep_mask.all():
        return coo

    kept_label_indices = np.flatnonzero(keep_mask)
    ii, jj = coo.coords
    ixs = np.searchsorted(kept_label_indices, ii)
    ixs.clip(0, kept_label_indices.size - 1, out=ixs)
    valid = np.flatnonzero(kept_label_indices[ixs] == ii)
    coo = coo_array(
        (coo.data[valid], (ixs[valid], jj[valid])),
        shape=(kept_label_indices.size, coo.shape[1]),
    )
    return coo


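# Illustrative sketch (values are made up): keep rows 0 and 2 of a 3x3 COO
# array. Surviving rows are renumbered, so old row 2 becomes new row 1 and the
# result has shape (2, 3).
def _demo_coo_sparse_mask_rows():
    ii = np.array([0, 1, 2, 2])
    jj = np.array([0, 1, 0, 2])
    vv = np.array([1.0, 2.0, 3.0, 4.0], dtype=np.float32)
    arr = coo_array((vv, (ii, jj)), shape=(3, 3))
    keep = np.array([True, False, True])
    return coo_sparse_mask_rows(arr, keep)

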
def csc_sparse_mask_rows(csc, keep_mask, in_place=False):
    if keep_mask.all():
        return csc

    if not in_place:
        csc = csc.copy()

    rowix_dtype = csc.indices.dtype
    kept_row_inds = np.flatnonzero(keep_mask).astype(rowix_dtype)
    oldrow_to_newrow = np.zeros(len(keep_mask), dtype=rowix_dtype)
    oldrow_to_newrow[kept_row_inds] = np.arange(len(kept_row_inds))
    nnz = _csc_sparse_mask_rows(
        csc.indices, csc.indptr, csc.data, oldrow_to_newrow, keep_mask
    )

    return csc_array(
        (csc.data[:nnz], csc.indices[:nnz], csc.indptr),
        shape=(len(kept_row_inds), csc.shape[1]),
    )


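# Illustrative sketch (values are made up): drop the middle row of a small
# dense matrix after converting it to CSC. The kernel below compacts data and
# indices over the existing buffers and rewrites indptr; the result here has
# shape (2, 3).
def _demo_csc_sparse_mask_rows():
    dense = np.array(
        [
            [1.0, 0.0, 2.0],
            [0.0, 3.0, 0.0],
            [4.0, 0.0, 5.0],
        ],
        dtype=np.float32,
    )
    csc = csc_array(dense)
    keep = np.array([True, False, True])
    return csc_sparse_mask_rows(csc, keep)

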
sigs = [
    "i8(i8[::1], i8[::1], f4[::1], i8[::1], bool_[::1])",
    "i8(i4[::1], i4[::1], f4[::1], i4[::1], bool_[::1])",
]


@numba.njit(sigs, error_model="numpy", nogil=True)
def _csc_sparse_mask_rows(indices, indptr, data, oldrow_to_newrow, keep_mask):
    write_ix = 0

    column = 0
    column_end = indptr[1]

    for read_ix in range(len(indices)):
        row = indices[read_ix]
        if not keep_mask[row]:
            continue

        # write data for this sample
        indices[write_ix] = oldrow_to_newrow[row]
        data[write_ix] = data[read_ix]
        write_ix += 1

        # close out any columns that ended before this entry; the entry just
        # written belongs to a later column, so those columns end at write_ix - 1
        while read_ix >= column_end:
            indptr[column + 1] = write_ix - 1
            column += 1
            column_end = indptr[column + 1]

    # remaining columns have no further kept entries
    while column < len(indptr) - 1:
        indptr[column + 1] = write_ix
        column += 1
        column_end = indptr[column + 1]

    return write_ix


# @numba.njit(sigs, error_model="numpy", nogil=True)
# def _csc_sparse_mask_rows(indices, indptr, data, oldrow_to_newrow, keep_mask):
#     write_ix = 0

#     column_start = indptr[0]
#     for column in range(len(indptr) - 1):
#         column_kept_count = 0
#         column_end = indptr[column + 1]

#         for read_ix in range(column_start, column_end):
#             row = indices[read_ix]
#             if not keep_mask[row]:
#                 continue

#             indices[write_ix] = oldrow_to_newrow[row]
#             data[write_ix] = data[read_ix]
#             column_kept_count += 1
#             write_ix += 1

#         # indptr[column] is not column_start.
#         indptr[column + 1] = indptr[column] + column_kept_count
#         column_start = column_end

#     return write_ix