Spaces:
Runtime error
Runtime error
| // Copyright (c) Facebook, Inc. and its affiliates. | |
| namespace detectron2 { | |
| at::Tensor ROIAlignRotated_forward_cpu( | |
| const at::Tensor& input, | |
| const at::Tensor& rois, | |
| const float spatial_scale, | |
| const int pooled_height, | |
| const int pooled_width, | |
| const int sampling_ratio); | |
| at::Tensor ROIAlignRotated_backward_cpu( | |
| const at::Tensor& grad, | |
| const at::Tensor& rois, | |
| const float spatial_scale, | |
| const int pooled_height, | |
| const int pooled_width, | |
| const int batch_size, | |
| const int channels, | |
| const int height, | |
| const int width, | |
| const int sampling_ratio); | |
| at::Tensor ROIAlignRotated_forward_cuda( | |
| const at::Tensor& input, | |
| const at::Tensor& rois, | |
| const float spatial_scale, | |
| const int pooled_height, | |
| const int pooled_width, | |
| const int sampling_ratio); | |
| at::Tensor ROIAlignRotated_backward_cuda( | |
| const at::Tensor& grad, | |
| const at::Tensor& rois, | |
| const float spatial_scale, | |
| const int pooled_height, | |
| const int pooled_width, | |
| const int batch_size, | |
| const int channels, | |
| const int height, | |
| const int width, | |
| const int sampling_ratio); | |
| // Interface for Python | |
| inline at::Tensor ROIAlignRotated_forward( | |
| const at::Tensor& input, | |
| const at::Tensor& rois, | |
| const double spatial_scale, | |
| const int64_t pooled_height, | |
| const int64_t pooled_width, | |
| const int64_t sampling_ratio) { | |
| if (input.is_cuda()) { | |
| return ROIAlignRotated_forward_cuda( | |
| input, | |
| rois, | |
| spatial_scale, | |
| pooled_height, | |
| pooled_width, | |
| sampling_ratio); | |
| AT_ERROR("Detectron2 is not compiled with GPU support!"); | |
| } | |
| return ROIAlignRotated_forward_cpu( | |
| input, rois, spatial_scale, pooled_height, pooled_width, sampling_ratio); | |
| } | |
| inline at::Tensor ROIAlignRotated_backward( | |
| const at::Tensor& grad, | |
| const at::Tensor& rois, | |
| const double spatial_scale, | |
| const int64_t pooled_height, | |
| const int64_t pooled_width, | |
| const int64_t batch_size, | |
| const int64_t channels, | |
| const int64_t height, | |
| const int64_t width, | |
| const int64_t sampling_ratio) { | |
| if (grad.is_cuda()) { | |
| return ROIAlignRotated_backward_cuda( | |
| grad, | |
| rois, | |
| spatial_scale, | |
| pooled_height, | |
| pooled_width, | |
| batch_size, | |
| channels, | |
| height, | |
| width, | |
| sampling_ratio); | |
| AT_ERROR("Detectron2 is not compiled with GPU support!"); | |
| } | |
| return ROIAlignRotated_backward_cpu( | |
| grad, | |
| rois, | |
| spatial_scale, | |
| pooled_height, | |
| pooled_width, | |
| batch_size, | |
| channels, | |
| height, | |
| width, | |
| sampling_ratio); | |
| } | |
| } // namespace detectron2 | |