#define DEFINE_TRT_ENTRYPOINTS 1
#define DEFINE_TRT_LEGACY_PARSER_ENTRYPOINT 0

#include "argsParser.h"
#include "buffers.h"
#include "common.h"
#include "logger.h"
#include "parserOnnxConfig.h"

#include "NvInfer.h"
#include <cuda_runtime_api.h>

#include <fstream>
#include <iostream>
using namespace nvinfer1;
using samplesCommon::SampleUniquePtr;
std::shared_ptr<nvinfer1::ICudaEngine> mEngine; //!< The TensorRT engine used to run the network

bool constructNetwork(SampleUniquePtr<nvinfer1::IBuilder>& builder,
    SampleUniquePtr<nvinfer1::INetworkDefinition>& network, SampleUniquePtr<nvinfer1::IBuilderConfig>& config,
    SampleUniquePtr<nvonnxparser::IParser>& parser);

bool processInput(float* fileData, const samplesCommon::BufferManager& buffers);
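For orientation, the members and methods excerpted above belong to a single class; here is a rough skeleton reconstructed from the member documentation at the end of this listing (a sketch, not the verbatim header):

// Sketch of the class shape implied by this listing and its documentation.
class TensorRTEngine
{
public:
    TensorRTEngine(const samplesCommon::OnnxSampleParams& params)
        : mParams(params)
    {
    }

    bool build(); // builds the network engine
    bool infer(); // runs the TensorRT inference engine for this sample

private:
    bool constructNetwork(SampleUniquePtr<nvinfer1::IBuilder>& builder,
        SampleUniquePtr<nvinfer1::INetworkDefinition>& network,
        SampleUniquePtr<nvinfer1::IBuilderConfig>& config,
        SampleUniquePtr<nvonnxparser::IParser>& parser); // parse ONNX into a TensorRT network
    bool processInput(float* fileData, const samplesCommon::BufferManager& buffers);
    bool verifyOutput(const samplesCommon::BufferManager& buffers); // called from infer(), below

    samplesCommon::OnnxSampleParams mParams;        // parameters for the sample
    nvinfer1::Dims mInputDims;                      // dimensions of the network input
    nvinfer1::Dims mOutputDims;                     // dimensions of the network output
    int mNumber{0};                                 // the number to classify
    std::shared_ptr<nvinfer1::IRuntime> mRuntime;   // deserializes the engine
    std::shared_ptr<nvinfer1::ICudaEngine> mEngine; // runs the network
};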
auto builder = SampleUniquePtr<nvinfer1::IBuilder>(nvinfer1::createInferBuilder(sample::gLogger.getTRTLogger()));

auto network = SampleUniquePtr<nvinfer1::INetworkDefinition>(builder->createNetworkV2(0));

auto config = SampleUniquePtr<nvinfer1::IBuilderConfig>(builder->createBuilderConfig());

auto parser = SampleUniquePtr<nvonnxparser::IParser>(
    nvonnxparser::createParser(*network, sample::gLogger.getTRTLogger()));
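The error handling between these creation calls is elided in this listing; each factory can return nullptr, and build() returns false when that happens. A condensed sketch of those checks:

// Condensed form of the per-object checks the full sample performs after each call.
if (!builder || !network || !config || !parser)
{
    return false;
}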
auto profileStream = samplesCommon::makeCudaStream();
config->setProfileStream(*profileStream);

SampleUniquePtr<IHostMemory> plan{builder->buildSerializedNetwork(*network, *config)};
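buildSerializedNetwork returns the engine as a host-memory blob, so it can optionally be written to disk and reloaded later to skip the slow build step. A sketch of that follow-on, with a hypothetical file name:

// Optional: persist the plan so later runs can deserialize it directly.
// "mnist.engine" is a hypothetical output path, not part of this sample.
std::ofstream engineFile("mnist.engine", std::ios::binary);
engineFile.write(static_cast<char const*>(plan->data()), static_cast<std::streamsize>(plan->size()));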
mRuntime = std::shared_ptr<nvinfer1::IRuntime>(createInferRuntime(sample::gLogger.getTRTLogger()));

mEngine = std::shared_ptr<nvinfer1::ICudaEngine>(
    mRuntime->deserializeCudaEngine(plan->data(), plan->size()), samplesCommon::InferDeleter());
ASSERT(network->getNbInputs() == 1);
mInputDims = network->getInput(0)->getDimensions();

ASSERT(network->getNbOutputs() == 1);
mOutputDims = network->getOutput(0)->getDimensions();
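The dimensions are taken from the parsed network rather than hard-coded; for the mnist.onnx shipped with the TensorRT samples the input is typically 1x1x28x28 and the output 1x10. An illustrative (not part of the sample) way to log what was actually parsed:

// Illustrative only: print each input dimension, e.g. "1x1x28x28".
for (int32_t d = 0; d < mInputDims.nbDims; ++d)
{
    sample::gLogInfo << mInputDims.d[d] << (d + 1 < mInputDims.nbDims ? "x" : "\n");
}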
bool TensorRTEngine::constructNetwork(SampleUniquePtr<nvinfer1::IBuilder>& builder,
    SampleUniquePtr<nvinfer1::INetworkDefinition>& network, SampleUniquePtr<nvinfer1::IBuilderConfig>& config,
    SampleUniquePtr<nvonnxparser::IParser>& parser)
{
    auto parsed = parser->parseFromFile(locateFile(mParams.onnxFileName, mParams.dataDirs).c_str(),
        static_cast<int>(sample::gLogger.getReportableSeverity()));
    if (!parsed)
    {
        return false;
    }
if (mParams.fp16)
{
    config->setFlag(BuilderFlag::kFP16);
}
if (mParams.bf16)
{
    config->setFlag(BuilderFlag::kBF16);
}
if (mParams.int8)
{
    config->setFlag(BuilderFlag::kINT8);
    samplesCommon::setAllDynamicRanges(network.get(), 127.0F, 127.0F);
}

samplesCommon::enableDLA(builder.get(), config.get(), mParams.dlaCore);
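setAllDynamicRanges is a samplesCommon helper that stands in for real INT8 calibration: it stamps a symmetric quantization range onto every network tensor so the builder can quantize without a calibrator. Roughly, it does something like the following simplified sketch (the actual helper distinguishes input and output ranges):

// Simplified sketch of what setAllDynamicRanges(network, 127.0F, 127.0F) does.
for (int32_t i = 0; i < network->getNbInputs(); ++i)
{
    network->getInput(i)->setDynamicRange(-127.0F, 127.0F);
}
for (int32_t i = 0; i < network->getNbLayers(); ++i)
{
    auto* layer = network->getLayer(i);
    for (int32_t j = 0; j < layer->getNbOutputs(); ++j)
    {
        layer->getOutput(j)->setDynamicRange(-127.0F, 127.0F);
    }
}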
samplesCommon::BufferManager buffers(mEngine);

auto context = SampleUniquePtr<nvinfer1::IExecutionContext>(mEngine->createExecutionContext());
for (int32_t i = 0, e = mEngine->getNbIOTensors(); i < e; i++)
{
    auto const name = mEngine->getIOTensorName(i);
    context->setTensorAddress(name, buffers.getDeviceBuffer(name));
}
ASSERT(mParams.inputTensorNames.size() == 1);
buffers.copyInputToDevice();

bool status = context->executeV2(buffers.getDeviceBindings().data());
if (!status)
{
    return false;
}

buffers.copyOutputToHost();

if (!verifyOutput(buffers))
{
    return false;
}
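executeV2 runs synchronously. Because the tensor addresses were already bound with setTensorAddress above, the same context could instead be driven asynchronously with IExecutionContext::enqueueV3; a sketch assuming a caller-owned CUDA stream:

// Asynchronous alternative to executeV2 (sketch; CUDA error checks omitted).
cudaStream_t stream;
cudaStreamCreate(&stream);
bool const ok = context->enqueueV3(stream); // uses the addresses bound above
cudaStreamSynchronize(stream);              // wait for inference to finish
cudaStreamDestroy(stream);
return ok;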
float* hostDataBuffer = static_cast<float*>(buffers.getHostBuffer(mParams.inputTensorNames[0]));
for (int i = 0; i < inputC * inputH * inputW; i++)
{
    hostDataBuffer[i] = fileData[i];
}
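Here fileData already holds floats in the range the model expects, so the copy is verbatim. In the stock MNIST sample the digit instead arrives as raw 8-bit PGM pixels and is inverted and scaled during this copy; a sketch of that variant, where rawPixels is a hypothetical uint8_t buffer:

// Hypothetical variant for raw 8-bit pixel input: invert and scale to [0, 1].
for (int i = 0; i < inputC * inputH * inputW; i++)
{
    hostDataBuffer[i] = 1.0F - static_cast<float>(rawPixels[i]) / 255.0F;
}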
samplesCommon::OnnxSampleParams params;
if (args.dataDirs.empty()) // use default directories if the user hasn't provided any
{
    params.dataDirs.push_back("data/mnist/");
    params.dataDirs.push_back("data/samples/mnist/");
}
else // use the data directories provided by the user
{
    params.dataDirs = args.dataDirs;
}

params.onnxFileName = "mnist.onnx";
params.inputTensorNames.push_back("Input3");
params.outputTensorNames.push_back("Plus214_Output_0");
params.dlaCore = args.useDLACore;
params.int8 = args.runInInt8;
params.fp16 = args.runInFp16;
params.bf16 = args.runInBf16;
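These parameters feed straight into the class constructor; a sketch of how main() typically wires the pieces together, alongside the test bookkeeping shown next (the argument parsing itself is elided from this listing):

samplesCommon::Args args;
samplesCommon::parseArgs(args, argc, argv); // fills args from the command line
TensorRTEngine sample(initializeSampleParams(args));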
auto sampleTest = sample::gLogger.defineTest(gSampleName, argc, argv);

sample::gLogger.reportTestStart(sampleTest);

sample::gLogInfo << "Building and running a GPU inference engine for Onnx MNIST" << std::endl;

if (!sample.build())
{
    return sample::gLogger.reportFail(sampleTest);
}
if (!sample.infer())
{
    return sample::gLogger.reportFail(sampleTest);
}

return sample::gLogger.reportPass(sampleTest);
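A typical invocation, assuming the data layout and flag names used across the TensorRT samples:

./sample_onnx_mnist --datadir=data/mnist/ --fp16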
The TensorRTEngine class implements a generic TensorRT model.
TensorRTEngine(const samplesCommon::OnnxSampleParams &params)
Constructs the sample from the given parameters.
bool build()
Builds the network engine.
bool infer()
Runs the TensorRT inference engine for this sample.
bool constructNetwork(SampleUniquePtr< nvinfer1::IBuilder > &builder, SampleUniquePtr< nvinfer1::INetworkDefinition > &network, SampleUniquePtr< nvinfer1::IBuilderConfig > &config, SampleUniquePtr< nvonnxparser::IParser > &parser)
Parses an ONNX model for MNIST and creates a TensorRT network.
bool processInput(float *fileData, const samplesCommon::BufferManager &buffers)
Reads the input and stores the result in a managed buffer.
samplesCommon::OnnxSampleParams mParams
The parameters for the sample.
nvinfer1::Dims mInputDims
The dimensions of the input to the network.
nvinfer1::Dims mOutputDims
The dimensions of the output to the network.
int mNumber
The number to classify.
std::shared_ptr< nvinfer1::IRuntime > mRuntime
The TensorRT runtime used to deserialize the engine.
std::shared_ptr< nvinfer1::ICudaEngine > mEngine
The TensorRT engine used to run the network.
const std::string gSampleName
The name of the sample, used when reporting test results.
samplesCommon::OnnxSampleParams initializeSampleParams(const samplesCommon::Args &args)
Initializes members of the params struct using the command line args.
std::string locateFile(const std::string &filepathSuffix, const std::vector< std::string > &directories, bool reportError=true)
Locates a file by searching the given directories, optionally reporting an error if it cannot be found.