C++调用Libtorch常见函数

#创建变量 
std::vector<torch::jit::IValue> inputs;
inputs.push_back(torch::ones({ 1,3,224,224 }));
torch::jit::IValue inputs;

#定义模型变量
torch::jit::script::Module model = torch::jit::load("path");

at::Tensor output = model.forward(inputs).toTensor();
#获取尺寸

ouput.sizes()
int heigh = output.size(0);
int weight = output.size(1);

torch::Tensor out_tensor = output.detach(); # requires_grad为false,
out_tensor = out_tensor.squeeze().detach().permute({ 1, 2, 0 });
// squeeze 减少图像尺寸 permute 交换维度
out_tensor = out_tensor.mul(255).clamp(0, 255).to(torch::kU8); //*255,转uint8 
out_tensor = out_tensor.to(torch::kCPU); //迁移至CPU
cv::Mat resultImg(img_h, img_w, CV_8UC3, out_tensor.data_ptr()); // 将Tensor数据拷贝至Mat
// cv::cvtColor(resultImg, resultImg, CV_RGB2BGR); 


#
cv::Mat tensor2Mat(torch::Tensor &i_tensor)
{
	int height = i_tensor.size(0), width = i_tensor.size(1);
	//i_tensor = i_tensor.to(torch::kF32);
	i_tensor = i_tensor.to(torch::kCPU);
	cv::Mat o_Mat(cv::Size(width, height), CV_32F, i_tensor.data_ptr());
	return o_Mat;
}

版权声明:本文为weixin_43474255原创文章,遵循CC 4.0 BY-SA版权协议,转载请附上原文出处链接和本声明。