Skip to content

Commit

Permalink
Define THPStorage struct only once (rather than N times) (pytorch#14802)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: pytorch#14802

The definetion of THPStorage does not depend on any Real, its macro
defintion is unnecessary, refactor the code so that THPStorage is not macro
defined.

Reviewed By: ezyang

Differential Revision: D13340445

fbshipit-source-id: 343393d0a36c868b9a06eea2ad9b80f5e395e947
  • Loading branch information
lhuang04 authored and facebook-github-bot committed Dec 5, 2018
1 parent ca6311d commit 524574a
Show file tree
Hide file tree
Showing 10 changed files with 39 additions and 43 deletions.
1 change: 0 additions & 1 deletion torch/csrc/Storage.h
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
#ifndef THP_STORAGE_INC
#define THP_STORAGE_INC

#define THPStorage TH_CONCAT_3(THP,Real,Storage)
#define THPStorageStr TH_CONCAT_STRING_3(torch.,Real,Storage)
#define THPStorageClass TH_CONCAT_3(THP,Real,StorageClass)
#define THPStorage_(NAME) TH_CONCAT_4(THP,Real,Storage_,NAME)
Expand Down
5 changes: 5 additions & 0 deletions torch/csrc/StorageDefs.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
#pragma once
struct THPStorage {
PyObject_HEAD
THWStorage *cdata;
};
1 change: 0 additions & 1 deletion torch/csrc/cuda/Storage.h
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
#ifndef THCP_STORAGE_INC
#define THCP_STORAGE_INC

#define THCPStorage TH_CONCAT_3(THCP,Real,Storage)
#define THCPStorageStr TH_CONCAT_STRING_3(torch.cuda.,Real,Storage)
#define THCPStorageClass TH_CONCAT_3(THCP,Real,StorageClass)
#define THCPStorage_(NAME) TH_CONCAT_4(THCP,Real,Storage_,NAME)
Expand Down
1 change: 0 additions & 1 deletion torch/csrc/cuda/override_macros.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
#define THWTensor_(NAME) THCTensor_(NAME)

#define THPStorage_(NAME) TH_CONCAT_4(THCP,Real,Storage_,NAME)
#define THPStorage THCPStorage
#define THPStorageBaseStr THCPStorageBaseStr
#define THPStorageStr THCPStorageStr
#define THPStorageClass THCPStorageClass
Expand Down
1 change: 0 additions & 1 deletion torch/csrc/cuda/restore_macros.h
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
#define THPTensorClass TH_CONCAT_3(THP,Real,TensorClass)
#define THPTensor_(NAME) TH_CONCAT_4(THP,Real,Tensor_,NAME)

#define THPStorage TH_CONCAT_3(THP,Real,Storage)
#define THPStorageStr TH_CONCAT_STRING_3(torch.,Real,Storage)
#define THPStorageClass TH_CONCAT_3(THP,Real,StorageClass)
#define THPStorage_(NAME) TH_CONCAT_4(THP,Real,Storage_,NAME)
Expand Down
1 change: 0 additions & 1 deletion torch/csrc/cuda/undef_macros.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
#undef THPTensorType

#undef THPStorage_
#undef THPStorage
#undef THPStorageBaseStr
#undef THPStorageStr
#undef THPStorageClass
Expand Down
50 changes: 25 additions & 25 deletions torch/csrc/generic/Storage.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -291,37 +291,37 @@ void THPStorage_(initCopyMethods)()
#ifndef THD_GENERIC_FILE
auto& h = THWStorage_(copy_functions);
// copy from CPU types
THPInsertStorageCopyFunction<THPStorage, THPByteStorage>(&THPByteStorageType, h, &THWStorage_(copyByte));
THPInsertStorageCopyFunction<THPStorage, THPCharStorage>(&THPCharStorageType, h, &THWStorage_(copyChar));
THPInsertStorageCopyFunction<THPStorage, THPShortStorage>(&THPShortStorageType, h, &THWStorage_(copyShort));
THPInsertStorageCopyFunction<THPStorage, THPIntStorage>(&THPIntStorageType, h, &THWStorage_(copyInt));
THPInsertStorageCopyFunction<THPStorage, THPLongStorage>(&THPLongStorageType, h, &THWStorage_(copyLong));
THPInsertStorageCopyFunction<THPStorage, THPHalfStorage>(&THPHalfStorageType, h, &THWStorage_(copyHalf));
THPInsertStorageCopyFunction<THPStorage, THPFloatStorage>(&THPFloatStorageType, h, &THWStorage_(copyFloat));
THPInsertStorageCopyFunction<THPStorage, THPDoubleStorage>(&THPDoubleStorageType, h, &THWStorage_(copyDouble));
THPInsertStorageCopyFunction<THPStorage, THPStorage>(&THPByteStorageType, h, &THWStorage_(copyByte));
THPInsertStorageCopyFunction<THPStorage, THPStorage>(&THPCharStorageType, h, &THWStorage_(copyChar));
THPInsertStorageCopyFunction<THPStorage, THPStorage>(&THPShortStorageType, h, &THWStorage_(copyShort));
THPInsertStorageCopyFunction<THPStorage, THPStorage>(&THPIntStorageType, h, &THWStorage_(copyInt));
THPInsertStorageCopyFunction<THPStorage, THPStorage>(&THPLongStorageType, h, &THWStorage_(copyLong));
THPInsertStorageCopyFunction<THPStorage, THPStorage>(&THPHalfStorageType, h, &THWStorage_(copyHalf));
THPInsertStorageCopyFunction<THPStorage, THPStorage>(&THPFloatStorageType, h, &THWStorage_(copyFloat));
THPInsertStorageCopyFunction<THPStorage, THPStorage>(&THPDoubleStorageType, h, &THWStorage_(copyDouble));
#ifdef THC_GENERIC_FILE
// copy from GPU types
THPInsertStorageCopyFunction<THPStorage, THCPByteStorage>(&THCPByteStorageType, h, &THWStorage_(copyCudaByte));
THPInsertStorageCopyFunction<THPStorage, THCPCharStorage>(&THCPCharStorageType, h, &THWStorage_(copyCudaChar));
THPInsertStorageCopyFunction<THPStorage, THCPShortStorage>(&THCPShortStorageType, h, &THWStorage_(copyCudaShort));
THPInsertStorageCopyFunction<THPStorage, THCPIntStorage>(&THCPIntStorageType, h, &THWStorage_(copyCudaInt));
THPInsertStorageCopyFunction<THPStorage, THCPLongStorage>(&THCPLongStorageType, h, &THWStorage_(copyCudaLong));
THPInsertStorageCopyFunction<THPStorage, THCPFloatStorage>(&THCPFloatStorageType, h, &THWStorage_(copyCudaFloat));
THPInsertStorageCopyFunction<THPStorage, THCPDoubleStorage>(&THCPDoubleStorageType, h, &THWStorage_(copyCudaDouble));
THPInsertStorageCopyFunction<THPStorage, THCPHalfStorage>(&THCPHalfStorageType, h, &THWStorage_(copyCudaHalf));
THPInsertStorageCopyFunction<THPStorage, THPStorage>(&THCPByteStorageType, h, &THWStorage_(copyCudaByte));
THPInsertStorageCopyFunction<THPStorage, THPStorage>(&THCPCharStorageType, h, &THWStorage_(copyCudaChar));
THPInsertStorageCopyFunction<THPStorage, THPStorage>(&THCPShortStorageType, h, &THWStorage_(copyCudaShort));
THPInsertStorageCopyFunction<THPStorage, THPStorage>(&THCPIntStorageType, h, &THWStorage_(copyCudaInt));
THPInsertStorageCopyFunction<THPStorage, THPStorage>(&THCPLongStorageType, h, &THWStorage_(copyCudaLong));
THPInsertStorageCopyFunction<THPStorage, THPStorage>(&THCPFloatStorageType, h, &THWStorage_(copyCudaFloat));
THPInsertStorageCopyFunction<THPStorage, THPStorage>(&THCPDoubleStorageType, h, &THWStorage_(copyCudaDouble));
THPInsertStorageCopyFunction<THPStorage, THPStorage>(&THCPHalfStorageType, h, &THWStorage_(copyCudaHalf));
// add CPU <- GPU copies to base type
#define THPCpuStorage TH_CONCAT_3(THP, Real, Storage)
/// #define THPCpuStorage TH_CONCAT_3(THP, Real, Storage)
#define THCpuStorage_(name) TH_CONCAT_4(TH, Real, Storage_, name)
extern THPCopyList THCpuStorage_(copy_functions);
auto& b = THCpuStorage_(copy_functions);
THPInsertStorageCopyFunction<THPCpuStorage, THCPByteStorage>(&THCPByteStorageType, b, &THCpuStorage_(copyCudaByte));
THPInsertStorageCopyFunction<THPCpuStorage, THCPCharStorage>(&THCPCharStorageType, b, &THCpuStorage_(copyCudaChar));
THPInsertStorageCopyFunction<THPCpuStorage, THCPShortStorage>(&THCPShortStorageType, b, &THCpuStorage_(copyCudaShort));
THPInsertStorageCopyFunction<THPCpuStorage, THCPIntStorage>(&THCPIntStorageType, b, &THCpuStorage_(copyCudaInt));
THPInsertStorageCopyFunction<THPCpuStorage, THCPLongStorage>(&THCPLongStorageType, b, &THCpuStorage_(copyCudaLong));
THPInsertStorageCopyFunction<THPCpuStorage, THCPFloatStorage>(&THCPFloatStorageType, b, &THCpuStorage_(copyCudaFloat));
THPInsertStorageCopyFunction<THPCpuStorage, THCPDoubleStorage>(&THCPDoubleStorageType, b, &THCpuStorage_(copyCudaDouble));
THPInsertStorageCopyFunction<THPCpuStorage, THCPHalfStorage>(&THCPHalfStorageType, b, &THCpuStorage_(copyCudaHalf));
THPInsertStorageCopyFunction<THPStorage, THPStorage>(&THCPByteStorageType, b, &THCpuStorage_(copyCudaByte));
THPInsertStorageCopyFunction<THPStorage, THPStorage>(&THCPCharStorageType, b, &THCpuStorage_(copyCudaChar));
THPInsertStorageCopyFunction<THPStorage, THPStorage>(&THCPShortStorageType, b, &THCpuStorage_(copyCudaShort));
THPInsertStorageCopyFunction<THPStorage, THPStorage>(&THCPIntStorageType, b, &THCpuStorage_(copyCudaInt));
THPInsertStorageCopyFunction<THPStorage, THPStorage>(&THCPLongStorageType, b, &THCpuStorage_(copyCudaLong));
THPInsertStorageCopyFunction<THPStorage, THPStorage>(&THCPFloatStorageType, b, &THCpuStorage_(copyCudaFloat));
THPInsertStorageCopyFunction<THPStorage, THPStorage>(&THCPDoubleStorageType, b, &THCpuStorage_(copyCudaDouble));
THPInsertStorageCopyFunction<THPStorage, THPStorage>(&THCPHalfStorageType, b, &THCpuStorage_(copyCudaHalf));
#undef THCpuStorage
#undef THCpuStorage_
#endif
Expand Down
5 changes: 1 addition & 4 deletions torch/csrc/generic/Storage.h
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,7 @@
#define TH_GENERIC_FILE "generic/Storage.h"
#else

struct THPStorage {
PyObject_HEAD
THWStorage *cdata;
};
#include "torch/csrc/StorageDefs.h"

THP_API PyObject * THPStorage_(New)(THWStorage *ptr);
extern PyObject *THPStorageClass;
Expand Down
9 changes: 0 additions & 9 deletions torch/csrc/generic/utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,15 +7,6 @@
#else
#define GENERATE_SPARSE 1
#endif

template<>
void THPPointer<THPStorage>::free() {
if (ptr)
Py_DECREF(ptr);
}

template class THPPointer<THPStorage>;

#undef GENERATE_SPARSE

#endif
8 changes: 8 additions & 0 deletions torch/csrc/utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -234,3 +234,11 @@ void THPPointer<THTensor>::free() {
THTensor_free(LIBRARY_STATE ptr);
}
}

template<>
void THPPointer<THPStorage>::free() {
if (ptr)
Py_DECREF(ptr);
}

template class THPPointer<THPStorage>;

0 comments on commit 524574a

Please sign in to comment.