Deploying PyTorch Segmentation Model on NVIDIA Jetson Orin with TensorRT C++
In this blog post, I describe a simplified approach to deploying a PyTorch model for real-time image segmentation on an NVIDIA Jetson computer using ROS (Robot Operating System) and NVIDIA TensorRT C++.

Building a complete pipeline, from training a PyTorch model to deploying it with TensorRT in a C++ inference application, involves several steps. I am using a pre-trained segmentation model built with PyTorch. Once the model is trained, it should be saved in PyTorch's format (`.pt` or `.pth`). I have modified PIDNet to work with just two classes: background and object. In the usual approach, PIDNet produces its output at one-eighth of the input resolution and then upsamples it; my version instead outputs the segmentation map directly at the original image resolution. The model is then exported to ONNX, and finally TensorRT's `trtexec` utility converts the ONNX model into the `.engine` format for the target platform. (IMPORTANT! An `.engine` file generated by TensorRT should be used on the same platform on which it was converted from ONNX. The engine is optimized for that platform's specific hardware configuration, including the GPU architecture and the number of available CUDA cores, and it is also tied to the TensorRT version used to build it.) I used a Python script to publish video data to a ROS topic; in a real scenario you would have a camera node streaming the images. Both the video streaming script and the C++ inference code are attached to this post.
1. Performing Segmentation Inference with TensorRT
The C++ program for segmentation inference in a ROS environment can be divided into several key parts: initialization, preprocessing, inference, postprocessing, and publishing the segmentation results to a ROS topic.
1. Initialization
In the initialization phase, the program sets up the ROS node and the TensorRT inference engine. This involves loading the pre-trained model from a serialized engine file and creating an execution context for performing inference. The program also creates a CUDA stream for asynchronous execution and allocates GPU memory buffers for the input and output data.
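As a concrete check on how large these buffers are, the arithmetic below assumes a batch size of 1, a three-channel 512x384 (width x height) float32 input, and a two-class float32 output at full input resolution. These shapes are assumptions based on the example size used later in this post; the constant and function names are illustrative, so substitute your engine's actual binding dimensions.

```cpp
#include <cstddef>

// Assumed shapes (not read from the engine): batch 1,
// 3 x 384 x 512 float32 input, 2-class float32 output at input resolution.
constexpr std::size_t kBatchSize = 1;
constexpr std::size_t kInputChannels = 3;
constexpr std::size_t kInputH = 384;
constexpr std::size_t kInputW = 512;
constexpr std::size_t kNumClasses = 2;  // background and object

// Bytes for the NCHW float32 input binding.
constexpr std::size_t inputBytes() {
    return kBatchSize * kInputChannels * kInputH * kInputW * sizeof(float);
}

// Bytes for the N x classes x H x W float32 output binding.
constexpr std::size_t outputBytes() {
    return kBatchSize * kNumClasses * kInputH * kInputW * sizeof(float);
}

// Compile-time checks, assuming a 4-byte float (true on Jetson and x86-64).
static_assert(inputBytes() == 2359296, "1 x 3 x 384 x 512 floats");
static_assert(outputBytes() == 1572864, "1 x 2 x 384 x 512 floats");
```

With these assumptions the input buffer needs 2,359,296 bytes (2.25 MiB) and the output buffer 1,572,864 bytes (1.5 MiB), which is small enough that allocating both once at startup and reusing them for every frame is cheap.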
Once you have the `.pt` file from PyTorch, export it to ONNX; you can then use the following command to convert the ONNX model to the `.engine` format:
trtexec --onnx=model.onnx --saveEngine=model.engine
Run `trtexec --help` to see more options.
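// Headers needed by this listing
#include <NvInfer.h>
#include <cuda_runtime_api.h>
#include <fstream>
#include <iostream>
#include <stdexcept>
#include <string>
#include <vector>
// Placeholder values: adjust them to your engine. The tensor names must match
// the input/output names chosen when the model was exported to ONNX.
static const int DEVICE = 0;
static const int BATCH_SIZE = 1;
static const int INPUT_H = 384;
static const int INPUT_W = 512;
static const int NUM_CLASSES = 2;        // background and object
static const int OUTPUT_H = INPUT_H;     // output at full input resolution
static const int OUTPUT_W = INPUT_W;
static const int OUTPUT_SIZE = NUM_CLASSES * OUTPUT_H * OUTPUT_W;
static const char* const INPUT_BLOB_NAME = "input";
static const char* const OUTPUT_BLOB_NAME = "output";
using namespace nvinfer1;
class Logger : public nvinfer1::ILogger {
public:
    void log(Severity severity, const char* msg) noexcept override {
        // remove this 'if' if you need more logged info
        if ((severity == Severity::kERROR) || (severity == Severity::kINTERNAL_ERROR)) {
            std::cout << msg << "\n";
        }
    }
} gLogger;
class SegmentationInference {
public:
    SegmentationInference() {
        cudaSetDevice(DEVICE);
        std::string engine_name = "yourEngineFilePath.engine";
        deserializeEngine(engine_name);
        // Look up the bindings and allocate the GPU buffers once, not on every frame.
        inputIndex_ = engine_->getBindingIndex(INPUT_BLOB_NAME);
        outputIndex_ = engine_->getBindingIndex(OUTPUT_BLOB_NAME);
        if (engine_->getNbBindings() != 2 || inputIndex_ < 0 || outputIndex_ < 0) {
            throw std::runtime_error("Unexpected engine bindings.");
        }
        cudaMalloc(&buffers_[inputIndex_], BATCH_SIZE * 3 * INPUT_H * INPUT_W * sizeof(float));
        cudaMalloc(&buffers_[outputIndex_], BATCH_SIZE * OUTPUT_SIZE * sizeof(float));
    }
    ~SegmentationInference() {
        cudaFree(buffers_[0]);
        cudaFree(buffers_[1]);
        cudaStreamDestroy(stream_);
        // Since TensorRT 8.0, delete replaces the deprecated destroy() calls.
        delete context_;
        delete engine_;
        delete runtime_;
    }
    // input:  host array of BATCH_SIZE * 3 * INPUT_H * INPUT_W floats (CHW)
    // output: host array of BATCH_SIZE * OUTPUT_SIZE floats (classes x H x W)
    void doInference(const float* input, float* output) {
        cudaMemcpyAsync(buffers_[inputIndex_], input, BATCH_SIZE * 3 * INPUT_H * INPUT_W * sizeof(float), cudaMemcpyHostToDevice, stream_);
        // Engines built from ONNX have an explicit batch dimension, so enqueueV2
        // (no batch-size argument) is used instead of the deprecated enqueue().
        context_->enqueueV2(buffers_, stream_, nullptr);
        cudaMemcpyAsync(output, buffers_[outputIndex_], BATCH_SIZE * OUTPUT_SIZE * sizeof(float), cudaMemcpyDeviceToHost, stream_);
        cudaStreamSynchronize(stream_);
    }
private:
    IRuntime* runtime_ = nullptr;
    ICudaEngine* engine_ = nullptr;
    IExecutionContext* context_ = nullptr;
    cudaStream_t stream_ = nullptr;
    void* buffers_[2] = {nullptr, nullptr};
    int inputIndex_ = -1;
    int outputIndex_ = -1;
    void deserializeEngine(const std::string& engine_name) {
        std::ifstream file(engine_name, std::ios::binary);
        if (!file.good()) {
            throw std::runtime_error("Failed to open engine file.");
        }
        file.seekg(0, file.end);
        size_t size = file.tellg();
        file.seekg(0, file.beg);
        std::vector<char> trt_model_stream(size);  // released automatically
        file.read(trt_model_stream.data(), size);
        file.close();
        runtime_ = createInferRuntime(gLogger);
        engine_ = runtime_->deserializeCudaEngine(trt_model_stream.data(), size);
        if (engine_ == nullptr) {
            throw std::runtime_error("Failed to deserialize engine.");
        }
        context_ = engine_->createExecutionContext();
        cudaStreamCreate(&stream_);
    }
};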
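The file-reading part of `deserializeEngine` can be pulled out into a standalone helper that also reports an empty file or a short read instead of passing a bad buffer to TensorRT. The sketch below is hypothetical (the name `readEngineFile` is mine, not TensorRT's) and only reads bytes, so it does not depend on TensorRT at all.

```cpp
#include <cstddef>
#include <fstream>
#include <stdexcept>
#include <string>
#include <vector>

// Reads a serialized engine (or any binary file) fully into memory.
// Throws std::runtime_error if the file is missing, empty, or cannot be read.
std::vector<char> readEngineFile(const std::string& path) {
    // Open positioned at the end so tellg() reports the file size.
    std::ifstream file(path, std::ios::binary | std::ios::ate);
    if (!file) {
        throw std::runtime_error("Failed to open engine file: " + path);
    }
    const std::streamsize size = file.tellg();
    if (size <= 0) {
        throw std::runtime_error("Engine file is empty: " + path);
    }
    std::vector<char> buffer(static_cast<std::size_t>(size));
    file.seekg(0, std::ios::beg);
    if (!file.read(buffer.data(), size)) {
        throw std::runtime_error("Failed to read engine file: " + path);
    }
    return buffer;
}
```

The returned blob would then be handed to the runtime as `runtime_->deserializeCudaEngine(blob.data(), blob.size())`.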
2. Preprocessing
Before feeding the input image into the model, the program performs a series of preprocessing steps:
- Resizing: The input image is resized to match the dimensions expected by the model (e.g., 512x384 pixels).
- Normalization: The pixel values are scaled to [0, 1] (divided by 255); then the per-channel mean is subtracted and the result is divided by the per-channel standard deviation, matching the normalization used during model training.
- Channel Reordering: The memory layout is changed from HWC (interleaved, as read by OpenCV) to CHW (planar), because the original model was built in PyTorch, which expects images in CHW format by default. The loop below also swaps OpenCV's BGR channel order to RGB.
- Flattening: The preprocessed image is stored as a flat array of floats that is copied into the model's input buffer.
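// Preprocessing (img: the CV_8UC3 BGR frame from ROS; requires <opencv2/opencv.hpp> and <vector>)
// Assumed ImageNet mean/std in BGR order to match OpenCV; use the values from your training setup.
const cv::Scalar img_mean(0.406, 0.456, 0.485);
const cv::Scalar img_std(0.225, 0.224, 0.229);
cv::resize(img, img, cv::Size(INPUT_W, INPUT_H));
img.convertTo(img, CV_32FC3, 1.0 / 255.0);
cv::subtract(img, img_mean, img);
cv::divide(img, img_std, img);
// Rearranging from HWC (BGR) to CHW (RGB); a std::vector avoids a ~2.4 MB stack array
std::vector<float> data(BATCH_SIZE * 3 * INPUT_H * INPUT_W);
for (int i = 0; i < INPUT_H * INPUT_W; i++) {
    const cv::Vec3f& px = img.at<cv::Vec3f>(i);
    data[i] = px[2];                          // R
    data[i + INPUT_H * INPUT_W] = px[1];      // G
    data[i + 2 * INPUT_H * INPUT_W] = px[0];  // B
}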
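For reference, the same steps after resizing (scaling to [0, 1], mean/std normalization, BGR-to-RGB swap, and HWC-to-CHW reordering) can be written in plain C++ without OpenCV. This is a minimal sketch; the function name `bgrHwcToRgbChw` is hypothetical, and the mean/std values are supplied by the caller in RGB order rather than fixed.

```cpp
#include <cstddef>
#include <cstdint>
#include <vector>

// Converts an interleaved 8-bit BGR image (HWC, OpenCV's default layout)
// into a planar RGB float tensor (CHW), normalized per channel as
// (pixel / 255 - mean) / std. meanRGB and stdRGB are in RGB order.
std::vector<float> bgrHwcToRgbChw(const std::uint8_t* bgr, int height, int width,
                                  const float meanRGB[3], const float stdRGB[3]) {
    const std::size_t plane = static_cast<std::size_t>(height) * width;
    std::vector<float> chw(3 * plane);
    for (std::size_t i = 0; i < plane; ++i) {
        const std::uint8_t* px = bgr + 3 * i;  // px[0] = B, px[1] = G, px[2] = R
        for (int c = 0; c < 3; ++c) {
            // Output plane c (R, G, B) reads input channel 2 - c.
            const float v = px[2 - c] / 255.0f;
            chw[c * plane + i] = (v - meanRGB[c]) / stdRGB[c];
        }
    }
    return chw;
}
```

For a continuous `CV_8UC3` frame `img` that has already been resized, a call might look like `bgrHwcToRgbChw(img.data, img.rows, img.cols, mean, stdev)`, for example with the commonly used ImageNet values {0.485, 0.456, 0.406} and {0.229, 0.224, 0.225}, assuming those match your training setup.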
3. Inference
The inference phase is where the actual segmentation is performed:
- Input Copying: The preprocessed input data is copied to the input buffer in the GPU memory.
- Model Execution: The inference engine executes the model using the input data to produce the output segmentation map.
- Output Copying: The output data, which contains the raw per-class scores for each pixel (unnormalized logits, unless a softmax layer was included in the exported model), is copied back from GPU memory to a buffer in host memory.
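If per-pixel probabilities are needed rather than raw scores (for example, to apply a confidence threshold other than 0.5), a softmax can be applied across the class planes of the copied-back output. The sketch below assumes the planar classes x H x W layout that the postprocessing loop indexes; `softmaxOverClasses` is a hypothetical helper, not part of TensorRT.

```cpp
#include <algorithm>
#include <cmath>
#include <cstddef>
#include <vector>

// Applies a per-pixel softmax across the class planes of a planar
// (numClasses x planeSize) score buffer, in place.
void softmaxOverClasses(std::vector<float>& scores, int numClasses, std::size_t planeSize) {
    for (std::size_t i = 0; i < planeSize; ++i) {
        // Subtract the per-pixel maximum so std::exp cannot overflow.
        float maxScore = scores[i];
        for (int c = 1; c < numClasses; ++c) {
            maxScore = std::max(maxScore, scores[c * planeSize + i]);
        }
        float sum = 0.0f;
        for (int c = 0; c < numClasses; ++c) {
            float& s = scores[c * planeSize + i];
            s = std::exp(s - maxScore);
            sum += s;
        }
        for (int c = 0; c < numClasses; ++c) {
            scores[c * planeSize + i] /= sum;
        }
    }
}
```

Because softmax preserves the ordering of the scores, the argmax class of each pixel is unchanged; the extra pass is only worth its cost when actual probability values are used downstream.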
4. Postprocessing
After inference, the raw output needs to be postprocessed to generate a usable segmentation mask:
- Class Selection: For each pixel, the class with the highest score is chosen (argmax over the class channels). With only two classes, this is equivalent to thresholding the object probability at 0.5 after a softmax.
- Color Mapping: The class indices are mapped to colors for visualization. For instance, different objects or features might be assigned different colors in the final segmentation mask.
- Resizing: The segmentation mask might be resized back to the original dimensions of the input image if needed.
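// Postprocessing (prob: the BATCH_SIZE * OUTPUT_SIZE floats returned by doInference)
const int cls = 2;  // number of classes: background and object
cv::Mat result(OUTPUT_H, OUTPUT_W, CV_8UC3, cv::Scalar(0, 0, 0));  // background stays black
for (int i = 0; i < OUTPUT_H * OUTPUT_W; i++) {
    float fmax = prob[i];
    int index = 0;
    for (int j = 1; j < cls; j++) {
        if (prob[i + j * OUTPUT_H * OUTPUT_W] > fmax) {
            index = j;
            fmax = prob[i + j * OUTPUT_H * OUTPUT_W];
        }
    }
    if (index == 1) {  // object class is drawn in white
        result.at<cv::Vec3b>(i) = cv::Vec3b(255, 255, 255);
    }
}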
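A variant of the loop above that produces a single-channel mask (255 for the object class, 0 otherwise) instead of a three-channel image is sketched below; such a mask can be published with the `mono8` encoding. The helper name `argmaxMask` is hypothetical; as in the original loop, ties are resolved in favor of the lower class index.

```cpp
#include <cstddef>
#include <cstdint>
#include <vector>

// Builds a single-channel mask from a planar (numClasses x planeSize) score
// buffer: a pixel is 255 if its highest-scoring class is objectClass, else 0.
std::vector<std::uint8_t> argmaxMask(const std::vector<float>& scores, int numClasses,
                                     std::size_t planeSize, int objectClass) {
    std::vector<std::uint8_t> mask(planeSize, 0);
    for (std::size_t i = 0; i < planeSize; ++i) {
        int best = 0;
        float bestScore = scores[i];
        for (int c = 1; c < numClasses; ++c) {
            const float s = scores[c * planeSize + i];
            if (s > bestScore) {  // strict '>' keeps the lower index on ties
                best = c;
                bestScore = s;
            }
        }
        if (best == objectClass) {
            mask[i] = 255;
        }
    }
    return mask;
}
```

To hand the mask to OpenCV without copying, it could be wrapped as `cv::Mat(OUTPUT_H, OUTPUT_W, CV_8UC1, mask.data())`; that `cv::Mat` does not own the data, so `mask` must outlive it.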
5. Publishing the Results
The segmentation mask is published as a ROS Image message on the "/segmentation_output" topic. This involves converting the mask back into an image format that can be handled by ROS and other nodes in the system.
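In a ROS node this conversion is usually done with cv_bridge, e.g. `cv_bridge::CvImage(header, "bgr8", result).toImageMsg()` for the three-channel `result`, or `"mono8"` for a single-channel mask. To show what the conversion amounts to, the sketch below packs an 8-bit image into the row-major layout used by `sensor_msgs/Image`, where `step` is the number of bytes per row. `ImageMsgSketch` and `packImage` are hypothetical stand-ins for illustration, reduced to the message's data fields.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <string>
#include <vector>

// Hypothetical stand-in mirroring the data fields of sensor_msgs/Image
// (the header and is_bigendian fields are omitted).
struct ImageMsgSketch {
    std::uint32_t height = 0;
    std::uint32_t width = 0;
    std::string encoding;  // "mono8" or "bgr8" in this sketch
    std::uint32_t step = 0;  // bytes per row
    std::vector<std::uint8_t> data;
};

// Packs a row-major 8-bit image with interleaved channels into the message layout.
ImageMsgSketch packImage(const std::vector<std::uint8_t>& pixels, std::uint32_t height,
                         std::uint32_t width, std::uint32_t channels) {
    assert(channels == 1 || channels == 3);  // only mono8 and bgr8 are handled
    ImageMsgSketch msg;
    msg.height = height;
    msg.width = width;
    msg.encoding = (channels == 1) ? "mono8" : "bgr8";
    msg.step = width * channels;
    const std::size_t bytes = static_cast<std::size_t>(msg.step) * height;
    assert(pixels.size() >= bytes);
    msg.data.assign(pixels.begin(), pixels.begin() + static_cast<std::ptrdiff_t>(bytes));
    return msg;
}
```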
2. Streaming Video with ROS and OpenCV
Finally, I set up a video stream that can be processed in real time. For this, I used a ROS node written in Python that uses OpenCV to read frames from a video file and publish them as ROS Image messages on the "/input_video_stream" topic. The node streams frames continuously, resizing each one to the desired dimensions before publishing. CvBridge converts between ROS Image messages and OpenCV images, which makes it easy to integrate with other ROS nodes that subscribe to the video stream.
If you have any questions or need help with any of this, please let me know. You can leave a comment or send me an email: t109318417@ntut.org.tw
Dependencies required:
- OpenCV 4.5.5 (built from source with CUDA enabled)
- TensorRT 8.5.1.7 or 8.6.1.6
- CUDA 11.8
Credits:
https://github.com/wang-xinyu/tensorrtx/blob/master/unet/
https://github.com/XuJiacong/PIDNet
https://github.com/Darth-Kronos/PIDNet_TensorRT