diff --git a/Dockerfile b/Dockerfile
index 07a1fea..8367f80 100644
--- a/Dockerfile
+++ b/Dockerfile
@@ -6,11 +6,11 @@ RUN apt-get update && apt-get install -y \
     python3-pip \
     git
 
-# Install Python packages
-RUN pip3 install torch torchvision torch-tensorrt pandas Pillow numpy packaging onnx
-ø
 # Set the working directory
 WORKDIR /workspace
 
 # Copy local project files to /workspace in the image
 COPY . /workspace
+
+# Install Python packages
+RUN pip3 install --no-cache-dir -r /workspace/requirements.txt
\ No newline at end of file
diff --git a/README.md b/README.md
index 7d11329..8a1dbe0 100644
--- a/README.md
+++ b/README.md
@@ -7,9 +7,9 @@
 5. [Inference Benchmark Results](#inference-benchmark-results)
    - [Example of Results](#example-of-results)
    - [Explanation of Results](#explanation-of-results)
-6. [Author](#author)
-7. [References](#references)
-8. [Notes](#notes)
+6. [ONNX Exporter](#onnx-exporter)
+7. [Author](#author)
+8. [References](#references)
 
 ## Overview
 This project demonstrates how to perform inference with a PyTorch model and optimize it using NVIDIA TensorRT. The script loads a pre-trained ResNet-50 model from torchvision, performs inference on a user-provided image, and prints the top-K predicted classes. Additionally, the script benchmarks the model's performance in the following configurations: CPU, CUDA, TensorRT-FP32, and TensorRT-FP16, providing insights into the speedup gained through optimization.
@@ -30,19 +30,20 @@
 docker build -t awesome-tesnorrt .
 docker run --gpus all --rm -it awesome-tesnorrt
 
 # 3. Run the Script inside the Container
-python src/main.py --image_path /path-to-image/image.jpg --topk 2
+python src/main.py
 ```
 
 ### Arguments
-- `--image_path`: Specifies the path to the image you want to predict.
+- `--image_path`: (Optional) Specifies the path to the image you want to predict.
 - `--topk`: (Optional) Specifies the number of top predictions to show. Defaults to 5 if not provided.
+- `--onnx`: (Optional) If set, exports the ResNet-50 model to ONNX and runs the prediction and benchmark only for the ONNX model.
 
 ## Example Command
 ```sh
-python src/main.py --image_path ./inference/cat3.jpg --topk 3 --show_image
+python src/main.py --image_path ./inference/cat3.jpg --topk 3 --onnx
 ```
-This command will run predictions on the image at the specified path, show the top 3 predictions, and display the image. If you do not want to display the image, omit the `--show_image` flag. For the default 5 top predictions, omit the `--topk` argument or set it to 5.
+This command will run predictions on the image at the specified path and show the top 3 predictions using the exported ONNX model with ONNX Runtime. For the default 5 top predictions, omit the `--topk` argument or set it to 5.
 
 ## Inference Benchmark Results
 
@@ -58,6 +59,7 @@ My prediction: %33 tabby
 My prediction: %26 Egyptian cat
 Running Benchmark for CPU
 Average batch time: 942.47 ms
+Average ONNX inference time: 15.59 ms
 Running Benchmark for CUDA
 Average batch time: 41.02 ms
 Compiling and Running Inference Benchmark for TensorRT with precision: torch.float32
@@ -70,16 +72,16 @@ Average batch time: 7.25 ms
 - First k lines show the topk predictions. For example, `My prediction: %33 tabby` displays the highest confidence prediction made by the model for the input image, confidence level (`%33`), and the predicted class (`tabby`).
 - The following lines provide information about the average batch time for running the model in different configurations:
   - `Running Benchmark for CPU` and `Average batch time: 942.47 ms` indicate the average batch time when running the model on the CPU.
+  - `Average ONNX inference time: 15.59 ms` indicates the average per-image inference time (batch size 1) when running the ONNX model on the CPU with ONNX Runtime.
   - `Running Benchmark for CUDA` and `Average batch time: 41.02 ms` indicate the average batch time when running the model on CUDA.
   - `Compiling and Running Inference Benchmark for TensorRT with precision: torch.float32` and `Average batch time: 19.20 ms` show the average batch time when running the model with TensorRT using `float32` precision.
   - `Compiling and Running Inference Benchmark for TensorRT with precision: torch.float16` and `Average batch time: 7.25 ms` indicate the average batch time when running the model with TensorRT using `float16` precision.
 
+## ONNX Exporter
+The ONNX Exporter utility is integrated into this project to convert the PyTorch model to ONNX format, enabling inference and benchmarking with ONNX Runtime. The ONNX model can provide hardware-agnostic optimizations and is widely supported across platforms and devices.
+
 ## Author
 [DimaBir](https://github.com/DimaBir)
 
 ## References
 - [ResNetTensorRT Project](https://github.com/DimaBir/ResNetTensorRT/tree/main)
-
-## Notes
-- The project uses a Docker container built on top of the NVIDIA TensorRT image to ensure that all dependencies, including CUDA and TensorRT, are correctly installed and configured.
-- Please ensure you have the NVIDIA Container Toolkit installed to run the container with GPU support.
diff --git a/inference/cat3.jpg b/inference/cat3.jpg
new file mode 100644
index 0000000..0a81671
Binary files /dev/null and b/inference/cat3.jpg differ
diff --git a/inference/fan.jpg b/inference/fan.jpg
new file mode 100644
index 0000000..7090ab2
Binary files /dev/null and b/inference/fan.jpg differ
diff --git a/inference/image-2.jpg b/inference/image-2.jpg
deleted file mode 100644
index c1f993c..0000000
Binary files a/inference/image-2.jpg and /dev/null differ
diff --git a/inference/vase.jpg b/inference/vase.jpg
new file mode 100644
index 0000000..d6e0917
Binary files /dev/null and b/inference/vase.jpg differ
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 0000000..91b1f90
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1,9 @@
+torch
+torchvision
+torch-tensorrt
+pandas
+Pillow
+numpy
+packaging
+onnx
+onnxruntime
\ No newline at end of file
diff --git a/src/benchmark.py b/src/benchmark.py
index ab2ecc6..4708c28 100644
--- a/src/benchmark.py
+++ b/src/benchmark.py
@@ -1,16 +1,35 @@
 import time
 from typing import Tuple
+from abc import ABC, abstractmethod
 
 import numpy as np
 import torch
 import torch.backends.cudnn as cudnn
 import logging
+import onnxruntime as ort
 
 # Configure logging
 logging.basicConfig(filename="model.log", level=logging.INFO)
 
 
-class Benchmark:
+class Benchmark(ABC):
+    """
+    Abstract class representing a benchmark.
+    """
+
+    def __init__(self, nruns: int = 100, nwarmup: int = 50):
+        self.nruns = nruns
+        self.nwarmup = nwarmup
+
+    @abstractmethod
+    def run(self) -> None:
+        """
+        Abstract method to run the benchmark.
+        """
+        pass
+
+
+class PyTorchBenchmark(Benchmark):
     def __init__(
         self,
         model: torch.nn.Module,
@@ -74,3 +93,43 @@ def run(self) -> None:
     print(f"Input shape: {input_data.size()}")
     print(f"Output features size: {features.size()}")
     logging.info(f"Average batch time: {np.mean(timings) * 1000:.2f} ms")
+
+
+class ONNXBenchmark(Benchmark):
+    """
+    A class used to benchmark the performance of an ONNX model.
+    """
+
+    def __init__(
+        self,
+        ort_session: ort.InferenceSession,
+        input_shape: tuple,
+        nruns: int = 100,
+        nwarmup: int = 50,
+    ):
+        super().__init__(nruns)
+        self.ort_session = ort_session
+        self.input_shape = input_shape
+        self.nwarmup = nwarmup
+        self.nruns = nruns
+
+    def run(self) -> None:
+        print("Warming up ...")
+        # Adjusting the batch size in the input shape to match the expected input size of the model.
+        input_shape = (1,) + self.input_shape[1:]
+        input_data = np.random.randn(*input_shape).astype(np.float32)
+
+        for _ in range(self.nwarmup):  # Warm-up runs
+            _ = self.ort_session.run(None, {"input": input_data})
+
+        print("Starting benchmark ...")
+        timings = []
+
+        for _ in range(self.nruns):
+            start_time = time.time()
+            _ = self.ort_session.run(None, {"input": input_data})
+            end_time = time.time()
+            timings.append(end_time - start_time)
+
+        avg_time = np.mean(timings) * 1000
+        logging.info(f"Average ONNX inference time: {avg_time:.2f} ms")
diff --git a/src/main.py b/src/main.py
index ab50a40..e278216 100644
--- a/src/main.py
+++ b/src/main.py
@@ -3,58 +3,100 @@
 import onnx
 import torch
 import torch_tensorrt
-from typing import List, Tuple
+from typing import List, Tuple, Union
+import onnxruntime as ort
+import numpy as np
 
 from model import ModelLoader
 from image_processor import ImageProcessor
-from benchmark import Benchmark
-from src.onnx_exporter import ONNXExporter
+from benchmark import PyTorchBenchmark, ONNXBenchmark
+from onnx_exporter import ONNXExporter
 
 # Configure logging
-logging.basicConfig(filename='model.log', level=logging.INFO)
+logging.basicConfig(filename="model.log", level=logging.INFO)
 
 
-def run_benchmark(model: torch.nn.Module, device: str, dtype: torch.dtype) -> None:
+def run_benchmark(
+    model: torch.nn.Module,
+    device: str,
+    dtype: torch.dtype,
+    ort_session: ort.InferenceSession = None,
+    onnx: bool = False,
+) -> None:
     """
     Run and log the benchmark for the given model, device, and dtype.
 
+    :param onnx: Whether to benchmark the ONNX model with ONNX Runtime instead of the PyTorch model.
+    :param ort_session: The ONNX Runtime InferenceSession to benchmark when onnx is True.
     :param model: The model to be benchmarked.
     :param device: The device to run the benchmark on ("cpu" or "cuda").
     :param dtype: The data type to be used in the benchmark (typically torch.float32 or torch.float16).
     """
-    logging.info(f"Running Benchmark for {device.upper()}")
-    benchmark = Benchmark(model, device=device, dtype=dtype)
+    if onnx:
+        logging.info("Running Benchmark for ONNX")
+        benchmark = ONNXBenchmark(ort_session, input_shape=(32, 3, 224, 224))
+    else:
+        logging.info(f"Running Benchmark for {device.upper()}")
+        benchmark = PyTorchBenchmark(model, device=device, dtype=dtype)
     benchmark.run()
 
 
 def make_prediction(
-    model: torch.nn.Module,
-    img_batch: torch.Tensor,
+    model: Union[torch.nn.Module, ort.InferenceSession],
+    img_batch: Union[torch.Tensor, np.ndarray],
     topk: int,
     categories: List[str],
-    precision: torch.dtype,
+    precision: torch.dtype = None,
 ) -> None:
     """
     Make and print predictions for the given model, img_batch, topk, and categories.
 
-    :param model: The model to make predictions with.
+    :param model: The model (or ONNX Runtime InferenceSession) to make predictions with.
     :param img_batch: The batch of images to make predictions on.
     :param topk: The number of top predictions to show.
     :param categories: The list of categories to label the predictions.
-    :param precision: The data type to be used for the predictions (typically torch.float32 or torch.float16).
+    :param precision: The data type to be used for the predictions (typically torch.float32 or torch.float16) for PyTorch models.
     """
-    # Clone img_batch to avoid in-place modifications
-    img_batch = img_batch.clone().to(precision)
+    is_onnx_model = isinstance(model, ort.InferenceSession)
+
+    if is_onnx_model:
+        # Get the input name for the ONNX model.
+        input_name = model.get_inputs()[0].name
+
+        # Run the model with the properly named input.
+        ort_inputs = {input_name: img_batch}
+        ort_outs = model.run(None, ort_inputs)
+
+        # Assuming the model returns a list with one array of class probabilities.
+        if len(ort_outs) > 0:
+            prob = ort_outs[0]
+
+            # Checking if prob has more than one dimension and selecting the right one.
+            if prob.ndim > 1:
+                prob = prob[0]
+
+            # Apply Softmax to get probabilities
+            prob = np.exp(prob) / np.sum(np.exp(prob))
+
+    else:  # PyTorch Model
+        img_batch = img_batch.clone().to(precision)
+
+        model.eval()
+        with torch.no_grad():
+            outputs = model(img_batch.to(precision))
+            prob = torch.nn.functional.softmax(outputs[0], dim=0)
+            prob = prob.cpu().numpy()
 
-    model.eval()
-    with torch.no_grad():
-        outputs = model(img_batch.to(precision))
-        prob = torch.nn.functional.softmax(outputs[0], dim=0)
+    top_indices = prob.argsort()[-topk:][::-1]
+    top_probs = prob[top_indices]
 
-    probs, classes = torch.topk(prob, topk)
     for i in range(topk):
-        probability = probs[i].item()
-        class_label = categories[0][int(classes[i])]
+        probability = top_probs[i]
+        if is_onnx_model:
+            # Accessing the DataFrame by row number using .iloc[]
+            class_label = categories.iloc[top_indices[i]].item()
+        else:
+            class_label = categories[0][int(top_indices[i])]
         logging.info(f"#{i + 1}: {int(probability * 100)}% {class_label}")
@@ -68,7 +110,6 @@ def main() -> None:
     parser.add_argument(
         "--image_path",
         type=str,
-        required=True,
        default="./inference/cat3.jpg",
         help="Path to the image to predict",
     )
@@ -76,9 +117,7 @@
         "--topk", type=int, default=5, help="Number of top predictions to show"
     )
     parser.add_argument(
-        "--onnx",
-        action="store_true",
-        help="If we want export model to ONNX format"
+        "--onnx", action="store_true", help="Export the model to ONNX format and benchmark it with ONNX Runtime"
     )
     parser.add_argument(
         "--onnx_path",
@@ -101,50 +140,62 @@ def main() -> None:
         onnx_path = args.onnx_path
 
         # Export the model to ONNX format using ONNXExporter
-        onnx_exporter = ONNXExporter(model_loader.model, onnx_path)
+        onnx_exporter = ONNXExporter(model_loader.model, device, onnx_path)
         onnx_exporter.export_model()
 
-        # check if model was loaded successfully
-        model = onnx.load(onnx_path)
-
-        # Check the model well-formed
-        onnx.checker.check_model(model)
-
-        print(onnx.helper.printable_graph(model.graph))
-        exit(0)
-
-    # Make and log predictions for CPU
-    print("Making prediction with CPU model")
-    make_prediction(
-        model_loader.model.to("cpu"), img_batch.to("cpu"), args.topk, model_loader.categories, torch.float32
-    )
-
-    # Run benchmarks for CPU and CUDA
-    run_benchmark(model_loader.model.to("cpu"), "cpu", torch.float32)
-    run_benchmark(model_loader.model.to("cuda"), "cuda", torch.float32)
+        # Create ONNX Runtime session
+        ort_session = ort.InferenceSession(
+            onnx_path, providers=["CPUExecutionProvider"]
+        )
 
-    # Trace CUDA model
-    print("Tracing CUDA model")
-    traced_model = torch.jit.trace(
-        model_loader.model, [torch.randn((1, 3, 224, 224)).to("cuda")]
-    )
+        # Run benchmark
+        run_benchmark(None, None, None, ort_session, onnx=True)
 
-    # Compile, run benchmarks and make predictions with TensorRT models
-    for precision in [torch.float32, torch.float16]:
-        logging.info(
-            f"Running Inference Benchmark for TensorRT with precision: {precision}"
-        )
-        trt_model = torch_tensorrt.compile(
-            traced_model,
-            inputs=[torch_tensorrt.Input((32, 3, 224, 224), dtype=precision)],
-            enabled_precisions={precision},
-            truncate_long_and_double=True,
-        )
-        run_benchmark(trt_model, "cuda", precision)
-        print("Making prediction with TensorRT model")
+        # Make prediction
+        print(f"Making prediction with {ort.get_device()} for ONNX model")
         make_prediction(
-            trt_model, img_batch, args.topk, model_loader.categories, precision
+            ort_session,
+            img_batch.cpu().numpy(),
+            topk=args.topk,
+            categories=model_loader.categories,
         )
+    else:
+        # Define configurations for which to run benchmarks and make predictions
+        configs = [
+            ("cpu", torch.float32),
+            ("cuda", torch.float32),
+            ("cuda", torch.float16),
+        ]
+
+        for device, precision in configs:
+            model = model_loader.model.to(device)
+
+            if device == "cuda":
+                print(f"Tracing {device} model")
+                model = torch.jit.trace(
+                    model, [torch.randn((1, 3, 224, 224)).to(device)]
+                )
+
+            if device == "cuda" and precision == torch.float16:
+                print("Compiling TensorRT model")
+                model = torch_tensorrt.compile(
+                    model,
+                    inputs=[torch_tensorrt.Input((32, 3, 224, 224), dtype=precision)],
+                    enabled_precisions={precision},
+                    truncate_long_and_double=True,
+                )
+
+            print(f"Making prediction with {device} model in {precision} precision")
+            make_prediction(
+                model,
+                img_batch.to(device),
+                args.topk,
+                model_loader.categories,
+                precision,
+            )
+
+            print(f"Running Benchmark for {device} model in {precision} precision")
+            run_benchmark(model, device, precision)
 
 
 if __name__ == "__main__":
diff --git a/src/onnx_exporter.py b/src/onnx_exporter.py
index 8405bd7..c11bec2 100644
--- a/src/onnx_exporter.py
+++ b/src/onnx_exporter.py
@@ -1,18 +1,28 @@
 import torch
-from torch.onnx import export
+from torch.onnx import export, TrainingMode
 from torchvision import models
+
 
 class ONNXExporter:
-    def __init__(self, model: torchvision.models.ResNet, onnx_path: str):
+    def __init__(self, model, device, onnx_path: str):
         self.model = model
         self.onnx_path = onnx_path
+        self.device = device
 
     def export_model(self):
-        self.mode.eval()
+        self.model.eval()
 
         # Define dummy input tensor
-        x = torch.randn(1, 3, 224, 224).to(self.model.device)
+        x = torch.randn(1, 3, 224, 224).to(self.device)
 
         # Export model as ONNX
-        export(self.model, x, self.onnx_path, verbose=True, input_names=['input'], output_names=['output'])
-        print(f"Model exported to {self.onnx_path}")
\ No newline at end of file
+        export(
+            self.model,
+            x,
+            self.onnx_path,
+            training=TrainingMode.EVAL,
+            verbose=True,
+            input_names=["input"],
+            output_names=["output"],
+        )
+        print(f"Model exported to {self.onnx_path}")
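For reference, a minimal sketch of how the pieces introduced in this diff fit together: export the model with `ONNXExporter`, open an ONNX Runtime session, and time it with `ONNXBenchmark`. Loading ResNet-50 directly from torchvision (rather than through the project's `ModelLoader`), the `resnet50.onnx` output path, and running from the `src/` directory (to match the imports used in `main.py`) are assumptions for illustration, not part of the diff.

```python
import torch
from torchvision import models
import onnxruntime as ort

from onnx_exporter import ONNXExporter  # same import style as src/main.py
from benchmark import ONNXBenchmark

# Assumption: load a pretrained ResNet-50 straight from torchvision;
# main.py obtains the same model via ModelLoader.
model = models.resnet50(weights=models.ResNet50_Weights.DEFAULT)
device = "cuda" if torch.cuda.is_available() else "cpu"
model = model.to(device)

# Export to ONNX with the exporter added in this diff ("resnet50.onnx" is an assumed path).
exporter = ONNXExporter(model, device, "resnet50.onnx")
exporter.export_model()

# Create a CPU ONNX Runtime session, mirroring main.py.
ort_session = ort.InferenceSession("resnet50.onnx", providers=["CPUExecutionProvider"])

# Benchmark the ONNX model; the average inference time is written to model.log.
ONNXBenchmark(ort_session, input_shape=(32, 3, 224, 224)).run()
```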