-
Notifications
You must be signed in to change notification settings - Fork 1
/
tensor.h
152 lines (133 loc) · 3.19 KB
/
tensor.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
#pragma once

#include <cstdint>
#include <cstdlib>
#include <cstring>
#include <string>
#include <utility>
#include <vector>
// class Tensor {};
// class Tensor {
// public:
// void Resize(const DDimLite &ddim) { dims_ = ddim; }
// size_t data_size() const { return this->dims().production(); }
// size_t memory_size() const { return memory_size_; }
// size_t offset() const { return offset_; }
// void *raw_data() {
// return static_cast<char *>(
// (static_cast<char *>(buffer_->data()) + offset_));
// }
// private:
// DDimLite dims_;
// std::shared_ptr<Buffer> buffer_;
// LoD lod_;
// size_t memory_size_{};
// /// @brief Buffer may be shared with other tensors
// size_t offset_{0};
// };
// Element type carried by a Tensor. INVALID is the sentinel for an
// unrecognized / unsupported type (see getElementSize, which reports
// size 0 for it).
enum class DataType {
  FLOAT,
  HALF,
  INT8,
  INT16,
  INT32,
  INT64,
  DOUBLE,
  BOOL,
  INVALID,
};
// Where a Tensor's buffer lives. Only the CPU path is implemented in this
// header (Tensor allocates with malloc); CUDA is declared for callers that
// tag placement.
enum class DeviceType {
  CPU,
  CUDA,
};
// Size in bytes of one element of type `t`.
// Returns 0 for DataType::INVALID (and any future unhandled value).
inline int64_t getElementSize(DataType t) noexcept {
  switch (t) {
    case DataType::BOOL:
    case DataType::INT8:
      return 1;
    case DataType::HALF:
    case DataType::INT16:
      return 2;
    case DataType::FLOAT:
    case DataType::INT32:
      return 4;
    case DataType::INT64:
    case DataType::DOUBLE:
      return 8;
    default:
      return 0;
  }
}
// Forward declaration so this header does not force an include of the
// TensorRT headers (NvInfer.h) on every consumer.
// NOTE(review): TensorRT declares this as `enum class DataType : int32_t`.
// A redeclaration must use the same underlying type; the bare form below
// defaults to `int`, which matches only where `int32_t` aliases `int` —
// TODO confirm against the NvInfer.h in use (consider `: int32_t` here).
namespace nvinfer1
{
enum class DataType;
}
// Translate a TensorRT dtype into this header's DataType enum.
// Any value not explicitly listed (TensorRT's kBOOL among them) maps to
// DataType::BOOL, matching the original fall-through behavior.
inline DataType GetDataType(const nvinfer1::DataType& dtype) {
  switch (dtype) {
    case nvinfer1::DataType::kFLOAT:
      return DataType::FLOAT;
    case nvinfer1::DataType::kHALF:
      return DataType::HALF;
    case nvinfer1::DataType::kINT8:
      return DataType::INT8;
    case nvinfer1::DataType::kINT32:
      return DataType::INT32;
    default:  // kBOOL (and anything unrecognized)
      return DataType::BOOL;
  }
}
class Tensor {
public:
void Reshape(const std::vector<int>& data_shape, DataType data_type) {
dims_.assign(data_shape.begin(), data_shape.end());
data_type_ = data_type;
size_t size = 0;
for (auto dim : data_shape) {
size *= dim;
}
data_ = reinterpret_cast<uint8_t*>(malloc(size));
}
template <typename T>
T* mutable_data() {
return reinterpret_cast<T*>(data());
}
template <typename T>
T* data() const {
return static_cast<const T*>(data());
}
template <typename T>
void CopyFromCpu(const T* data) {
size_t ele_size = numel() * sizeof(T);
}
template <typename T>
void CopyToCpu(T* data) const {}
std::vector<int> shape() const { return dims_; }
int64_t numel() const {
int64_t res = 0L;
for (auto dim : dims_) {
res *= dim;
}
return res;
}
const std::string& name() const;
DataType type() const { return data_type_; }
DeviceType device() const { return device_type_; }
const void* data() const { return static_cast<const void*>(data_); }
void* data() { return data_; }
~Tensor() {
if (data_ != nullptr) {
free(data_);
data_ = nullptr;
}
}
// private:
// Tensor(const Tensor& tensor) = delete;
// Tensor(const Tensor&& tensor) = delete;
// Tensor& operator=(const Tensor&) = delete;
// Tensor& operator=(const Tensor&&) = delete;
protected:
void* data_{nullptr};
std::vector<int> dims_;
DataType data_type_{DataType::FLOAT};
DeviceType device_type_{DeviceType::CPU};
};
// Associates a tensor with a string name — presumably the binding name it
// is looked up by in the surrounding runtime; verify against callers.
struct NamedTensor {
  std::string name;
  Tensor tensor;
};