Open3D (C++ API)  0.17.0
TorchHelper.h
Go to the documentation of this file.
1// ----------------------------------------------------------------------------
2// - Open3D: www.open3d.org -
3// ----------------------------------------------------------------------------
4// Copyright (c) 2018-2023 www.open3d.org
5// SPDX-License-Identifier: MIT
6// ----------------------------------------------------------------------------
7
8#pragma once
9#include <torch/script.h>
10
11#include <sstream>
12#include <type_traits>
13
15
16// Macros for checking tensor properties
// Raises via TORCH_CHECK unless tensor `x` is a CUDA tensor.
// `x` is parenthesized so that expression arguments (e.g. `*p`) expand
// correctly.
#define CHECK_CUDA(x)                                           \
    do {                                                        \
        TORCH_CHECK((x).is_cuda(), #x " must be a CUDA tensor") \
    } while (0)
21
// Raises via TORCH_CHECK unless tensor `x` is contiguous.
// `x` is parenthesized so that expression arguments (e.g. `*p`) expand
// correctly.
#define CHECK_CONTIGUOUS(x)                                         \
    do {                                                            \
        TORCH_CHECK((x).is_contiguous(), #x " must be contiguous") \
    } while (0)
26
// Raises via TORCH_CHECK unless tensor `x` has dtype `torch::type`, where
// `type` is a bare dtype name such as kFloat32 (it is prefixed with torch::).
// `x` is parenthesized so that expression arguments expand correctly.
#define CHECK_TYPE(x, type)                                                  \
    do {                                                                     \
        TORCH_CHECK((x).dtype() == torch::type, #x " must have type " #type) \
    } while (0)
31
// Raises via TORCH_CHECK unless all listed tensors have the same device type
// (as decided by SameDeviceType). The message names the arguments and lists
// each tensor's sizes, type string and device (TensorInfoStr).
// TORCH_CHECK already branches on its condition, so the condition is passed
// directly instead of wrapping TORCH_CHECK(false, ...) in an extra `if`.
#define CHECK_SAME_DEVICE_TYPE(...)                                  \
    do {                                                             \
        TORCH_CHECK(SameDeviceType({__VA_ARGS__}),                   \
                    #__VA_ARGS__                                     \
                    " must all have the same device type but got " + \
                            TensorInfoStr({__VA_ARGS__}))            \
    } while (0)
42
// Raises via TORCH_CHECK unless all listed tensors have the same dtype
// (as decided by SameDtype). The message names the arguments and lists each
// tensor's sizes, type string and device (TensorInfoStr).
// TORCH_CHECK already branches on its condition, so the condition is passed
// directly instead of wrapping TORCH_CHECK(false, ...) in an extra `if`.
#define CHECK_SAME_DTYPE(...)                                  \
    do {                                                       \
        TORCH_CHECK(SameDtype({__VA_ARGS__}),                  \
                    #__VA_ARGS__                               \
                    " must all have the same dtype but got " + \
                            TensorInfoStr({__VA_ARGS__}))      \
    } while (0)
52
53// Conversion from standard types to torch types
54typedef std::remove_const<decltype(torch::kInt32)>::type TorchDtype_t;
// ToTorchDtype<T>() maps a C++ scalar type T to its torch dtype.
// Explicit specializations cover, in order: uint8_t, int8_t, int16_t,
// int32_t, int64_t, float, double -> torch::kUInt8, kInt8, kInt16, kInt32,
// kInt64, kFloat32, kFloat64.
// Any other T falls back to the primary template, which fails at runtime via
// TORCH_CHECK(false, "Unsupported type").
// NOTE(review): the primary template has no return statement after the
// TORCH_CHECK; it relies on TORCH_CHECK throwing.
55template <class T>
57 TORCH_CHECK(false, "Unsupported type");
58}
59template <>
61 return torch::kUInt8;
62}
63template <>
65 return torch::kInt8;
66}
67template <>
69 return torch::kInt16;
70}
71template <>
73 return torch::kInt32;
74}
75template <>
77 return torch::kInt64;
78}
79template <>
81 return torch::kFloat32;
82}
83template <>
85 return torch::kFloat64;
86}
87
88// convenience function for comparing standard types with torch types
89template <class T, class TDtype>
90inline bool CompareTorchDtype(const TDtype& t) {
91 return ToTorchDtype<T>() == t;
92}
93
94// convenience function to check if all tensors have the same device type
95inline bool SameDeviceType(std::initializer_list<torch::Tensor> tensors) {
96 if (tensors.size()) {
97 auto device_type = tensors.begin()->device().type();
98 for (auto t : tensors) {
99 if (device_type != t.device().type()) {
100 return false;
101 }
102 }
103 }
104 return true;
105}
106
107// convenience function to check if all tensors have the same dtype
108inline bool SameDtype(std::initializer_list<torch::Tensor> tensors) {
109 if (tensors.size()) {
110 auto dtype = tensors.begin()->dtype();
111 for (auto t : tensors) {
112 if (dtype != t.dtype()) {
113 return false;
114 }
115 }
116 }
117 return true;
118}
119
120inline std::string TensorInfoStr(std::initializer_list<torch::Tensor> tensors) {
121 std::stringstream sstr;
122 size_t count = 0;
123 for (const auto t : tensors) {
124 sstr << t.sizes() << " " << t.toString() << " " << t.device();
125 ++count;
126 if (count < tensors.size()) sstr << ", ";
127 }
128 return sstr.str();
129}
130
131// convenience function for creating a tensor for temp memory
132inline torch::Tensor CreateTempTensor(const int64_t size,
133 const torch::Device& device,
134 void** ptr = nullptr) {
135 torch::Tensor tensor = torch::empty(
136 {size}, torch::dtype(ToTorchDtype<uint8_t>()).device(device));
137 if (ptr) {
138 *ptr = tensor.data_ptr<uint8_t>();
139 }
140 return tensor;
141}
142
143inline std::vector<open3d::ml::op_util::DimValue> GetShapeVector(
144 torch::Tensor tensor) {
145 using namespace open3d::ml::op_util;
146
147 std::vector<DimValue> shape;
148 const int rank = tensor.dim();
149 for (int i = 0; i < rank; ++i) {
150 shape.push_back(tensor.size(i));
151 }
152 return shape;
153}
154
156 class TDimX,
157 class... TArgs>
158std::tuple<bool, std::string> CheckShape(torch::Tensor tensor,
159 TDimX&& dimex,
160 TArgs&&... args) {
    // Torch-tensor overload: converts the tensor's shape with GetShapeVector
    // and perfect-forwards the expected dimensions to
    // open3d::ml::op_util::CheckShape<Opt>, returning its (bool, std::string)
    // result unchanged. The CHECK_SHAPE* macros call this overload.
161 return open3d::ml::op_util::CheckShape<Opt>(GetShapeVector(tensor),
162 std::forward<TDimX>(dimex),
163 std::forward<TArgs>(args)...);
164}
165
166//
167// Macros for checking the shape of Tensors.
168// Usage:
169// {
170// using namespace open3d::ml::op_util;
171// Dim w("w");
172// Dim h("h");
173// CHECK_SHAPE(tensor1, 10, w, h); // checks if the first dim is 10
174// // and assigns w and h based on
175// // the shape of tensor1
176//
177// CHECK_SHAPE(tensor2, 10, 20, h); // this checks if the last dim
178// // of tensor2 matches the last dim
179// // of tensor1. The first two dims
180// // must match 10, 20.
181// }
182//
183//
184// See "../ShapeChecking.h" for more info and limitations.
185//
// Raises via TORCH_CHECK if `tensor` does not match the expected dimensions
// passed as the remaining arguments (checked by CheckShape with default
// options). The message names the tensor expression and appends
// CheckShape's diagnostic string.
#define CHECK_SHAPE(tensor, ...)                                   \
    do {                                                           \
        const auto cs_result_ = CheckShape(tensor, __VA_ARGS__);   \
        const bool cs_success_ = std::get<0>(cs_result_);          \
        TORCH_CHECK(cs_success_, "invalid shape for '" #tensor "', " + \
                                         std::get<1>(cs_result_))  \
    } while (0)
194
// Like CHECK_SHAPE, but runs CheckShape with the COMBINE_FIRST_DIMS option
// (see ShapeChecking.h for the option's semantics).
// CSOpt is fully qualified so the macro does not require CSOpt to be visible
// unqualified at the call site.
#define CHECK_SHAPE_COMBINE_FIRST_DIMS(tensor, ...)                          \
    do {                                                                     \
        bool cs_success_ = false;                                            \
        std::string cs_errstr_;                                              \
        std::tie(cs_success_, cs_errstr_) =                                  \
                CheckShape<open3d::ml::op_util::CSOpt::COMBINE_FIRST_DIMS>( \
                        tensor, __VA_ARGS__);                                \
        TORCH_CHECK(cs_success_,                                             \
                    "invalid shape for '" #tensor "', " + cs_errstr_)        \
    } while (0)
204
// Like CHECK_SHAPE, but runs CheckShape with the IGNORE_FIRST_DIMS option
// (see ShapeChecking.h for the option's semantics).
// CSOpt is fully qualified so the macro does not require CSOpt to be visible
// unqualified at the call site.
#define CHECK_SHAPE_IGNORE_FIRST_DIMS(tensor, ...)                          \
    do {                                                                    \
        bool cs_success_ = false;                                           \
        std::string cs_errstr_;                                             \
        std::tie(cs_success_, cs_errstr_) =                                 \
                CheckShape<open3d::ml::op_util::CSOpt::IGNORE_FIRST_DIMS>( \
                        tensor, __VA_ARGS__);                               \
        TORCH_CHECK(cs_success_,                                            \
                    "invalid shape for '" #tensor "', " + cs_errstr_)       \
    } while (0)
214
// Like CHECK_SHAPE, but runs CheckShape with the COMBINE_LAST_DIMS option
// (see ShapeChecking.h for the option's semantics).
// CSOpt is fully qualified so the macro does not require CSOpt to be visible
// unqualified at the call site.
#define CHECK_SHAPE_COMBINE_LAST_DIMS(tensor, ...)                          \
    do {                                                                    \
        bool cs_success_ = false;                                           \
        std::string cs_errstr_;                                             \
        std::tie(cs_success_, cs_errstr_) =                                 \
                CheckShape<open3d::ml::op_util::CSOpt::COMBINE_LAST_DIMS>( \
                        tensor, __VA_ARGS__);                               \
        TORCH_CHECK(cs_success_,                                            \
                    "invalid shape for '" #tensor "', " + cs_errstr_)       \
    } while (0)
224
// Like CHECK_SHAPE, but runs CheckShape with the IGNORE_LAST_DIMS option
// (see ShapeChecking.h for the option's semantics).
// CSOpt is fully qualified so the macro does not require CSOpt to be visible
// unqualified at the call site.
#define CHECK_SHAPE_IGNORE_LAST_DIMS(tensor, ...)                          \
    do {                                                                   \
        bool cs_success_ = false;                                          \
        std::string cs_errstr_;                                            \
        std::tie(cs_success_, cs_errstr_) =                                \
                CheckShape<open3d::ml::op_util::CSOpt::IGNORE_LAST_DIMS>( \
                        tensor, __VA_ARGS__);                              \
        TORCH_CHECK(cs_success_,                                           \
                    "invalid shape for '" #tensor "', " + cs_errstr_)      \
    } while (0)
TorchDtype_t ToTorchDtype< int64_t >()
Definition: TorchHelper.h:76
TorchDtype_t ToTorchDtype< uint8_t >()
Definition: TorchHelper.h:60
std::string TensorInfoStr(std::initializer_list< torch::Tensor > tensors)
Definition: TorchHelper.h:120
std::vector< open3d::ml::op_util::DimValue > GetShapeVector(torch::Tensor tensor)
Definition: TorchHelper.h:143
TorchDtype_t ToTorchDtype< int16_t >()
Definition: TorchHelper.h:68
TorchDtype_t ToTorchDtype< int8_t >()
Definition: TorchHelper.h:64
TorchDtype_t ToTorchDtype< double >()
Definition: TorchHelper.h:84
bool SameDtype(std::initializer_list< torch::Tensor > tensors)
Definition: TorchHelper.h:108
bool SameDeviceType(std::initializer_list< torch::Tensor > tensors)
Definition: TorchHelper.h:95
std::remove_const< decltype(torch::kInt32)>::type TorchDtype_t
Definition: TorchHelper.h:54
TorchDtype_t ToTorchDtype()
Definition: TorchHelper.h:56
torch::Tensor CreateTempTensor(const int64_t size, const torch::Device &device, void **ptr=nullptr)
Definition: TorchHelper.h:132
std::tuple< bool, std::string > CheckShape(torch::Tensor tensor, TDimX &&dimex, TArgs &&... args)
Definition: TorchHelper.h:158
TorchDtype_t ToTorchDtype< int32_t >()
Definition: TorchHelper.h:72
bool CompareTorchDtype(const TDtype &t)
Definition: TorchHelper.h:90
TorchDtype_t ToTorchDtype< float >()
Definition: TorchHelper.h:80
int size
Definition: FilePCD.cpp:40
int count
Definition: FilePCD.cpp:42
char type
Definition: FilePCD.cpp:41
Definition: ShapeChecking.h:16
CSOpt
Check shape options.
Definition: ShapeChecking.h:405