diff --git a/README.md b/README.md
index 5afd5a4..4355b24 100644
--- a/README.md
+++ b/README.md
@@ -124,6 +124,9 @@ OpenVINO is a toolkit from Intel that optimizes deep learning model inference fo
 4. Perform inference on the provided image using the OpenVINO model.
 5. Benchmark results, including average inference time, are logged for the OpenVINO model.
 
+## Benchmarking and Visualization
+The benchmark results for all modes are saved and visualized in a bar chart showing the average inference time for each backend, making it easy to compare the performance gains achieved by the different optimizations.
+
 #### Requirements
 Ensure you have installed the OpenVINO Toolkit and the necessary dependencies to use OpenVINO's model optimizer and inference engine.
 
diff --git a/benchmark/benchmark_models.py b/benchmark/benchmark_models.py
index e772360..6a73677 100644
--- a/benchmark/benchmark_models.py
+++ b/benchmark/benchmark_models.py
@@ -1,16 +1,215 @@
-import src.benchmark_class
-from benchmark.benchmark_utils import run_benchmark
-from src.benchmark_class import PyTorchBenchmark, ONNXBenchmark, OVBenchmark
-import openvino as ov
+import logging
+import time
+from abc import ABC, abstractmethod
+from typing import Tuple
+
+import numpy as np
 import torch
+import torch.backends.cudnn as cudnn
 import onnxruntime as ort
+import openvino as ov
+
+# Configure logging
+logging.basicConfig(filename="model.log", level=logging.INFO)
+
+
+class Benchmark(ABC):
+    """
+    Abstract class representing a benchmark.
+    """
+
+    def __init__(self, nruns: int = 100, nwarmup: int = 50):
+        self.nruns = nruns
+        self.nwarmup = nwarmup
+
+    @abstractmethod
+    def run(self):
+        """
+        Abstract method to run the benchmark.
+        """
+        pass
+
+
+class PyTorchBenchmark(Benchmark):
+    def __init__(
+        self,
+        model: torch.nn.Module,
+        device: str = "cuda",
+        input_shape: Tuple[int, int, int, int] = (32, 3, 224, 224),
+        dtype: torch.dtype = torch.float32,
+        nwarmup: int = 50,
+        nruns: int = 100,
+    ) -> None:
+        """
+        Initialize the PyTorch benchmark.
+
+        :param model: The model to be benchmarked.
+        :param device: The device to run the benchmark on ("cpu" or "cuda").
+        :param input_shape: The shape of the input data.
+        :param dtype: The data type to be used in the benchmark (typically torch.float32 or torch.float16).
+        :param nwarmup: The number of warmup runs before timing.
+        :param nruns: The number of runs for timing.
+        """
+        super().__init__(nruns=nruns, nwarmup=nwarmup)
+        self.model = model
+        self.device = device
+        self.input_shape = input_shape
+        self.dtype = dtype
+
+        cudnn.benchmark = True  # Enable cuDNN benchmarking optimization
+
+    def run(self):
+        """
+        Run the benchmark with the given model, input shape, and other parameters.
+        Log and return the average batch inference time in milliseconds.
+ """ + # Prepare input data + input_data = torch.randn(self.input_shape).to(self.device).to(self.dtype) + + # Warm up + print("Warm up ...") + with torch.no_grad(): + for _ in range(self.nwarmup): + features = self.model(input_data) + torch.cuda.synchronize() + + # Start timing + print("Start timing ...") + timings = [] + with torch.no_grad(): + for i in range(1, self.nruns + 1): + start_time = time.time() + features = self.model(input_data) + torch.cuda.synchronize() + end_time = time.time() + timings.append(end_time - start_time) + + if i % 10 == 0: + print( + f"Iteration {i}/{self.nruns}, ave batch time {np.mean(timings) * 1000:.2f} ms" + ) + + logging.info(f"Average batch time: {np.mean(timings) * 1000:.2f} ms") + return np.mean(timings) * 1000 + + +class ONNXBenchmark(Benchmark): + """ + A class used to benchmark the performance of an ONNX model. + """ + + def __init__( + self, + ort_session: ort.InferenceSession, + input_shape: tuple, + nruns: int = 100, + nwarmup: int = 50, + ): + super().__init__(nruns) + self.ort_session = ort_session + self.input_shape = input_shape + self.nwarmup = nwarmup + self.nruns = nruns + + def run(self): + print("Warming up ...") + # Adjusting the batch size in the input shape to match the expected input size of the model. + input_shape = (1,) + self.input_shape[1:] + input_data = np.random.randn(*input_shape).astype(np.float32) + + for _ in range(self.nwarmup): # Warm-up runs + _ = self.ort_session.run(None, {"input": input_data}) + + print("Starting benchmark ...") + timings = [] + + for i in range(1, self.nruns + 1): + start_time = time.time() + _ = self.ort_session.run(None, {"input": input_data}) + end_time = time.time() + timings.append(end_time - start_time) + + if i % 10 == 0: + print( + f"Iteration {i}/{self.nruns}, ave batch time {np.mean(timings) * 1000:.2f} ms" + ) + + avg_time = np.mean(timings) * 1000 + logging.info(f"Average ONNX inference time: {avg_time:.2f} ms") + return avg_time + + +class OVBenchmark(Benchmark): + def __init__( + self, model: ov.frontend.FrontEnd, input_shape: Tuple[int, int, int, int] + ): + """ + Initialize the OVBenchmark with the OpenVINO model and the input shape. + + :param model: ov.frontend.FrontEnd + The OpenVINO model. + :param input_shape: Tuple[int, int, int, int] + The shape of the model input. + """ + self.ov_model = model + self.core = ov.Core() + self.compiled_model = None + self.input_shape = input_shape + self.nwarmup = 50 + self.nruns = 100 + self.dummy_input = np.random.randn(*input_shape).astype(np.float32) + + def warmup(self): + """ + Compile the OpenVINO model for optimal execution on available hardware. + """ + self.compiled_model = self.core.compile_model(self.ov_model, "AUTO") + + def inference(self, input_data) -> dict: + """ + Perform inference on the input data using the compiled OpenVINO model. + + :param input_data: np.ndarray + The input data for the model. + :return: dict + The model's output as a dictionary. + """ + outputs = self.compiled_model(inputs={"input": input_data}) + return outputs + + def run(self): + """ + Run the benchmark on the OpenVINO model. It first warms up by compiling the model and then measures + the average inference time over a set number of runs. 
+ """ + # Warm-up runs + logging.info("Warming up ...") + for _ in range(self.nwarmup): + self.warmup() + + # Benchmarking + total_time = 0 + for i in range(1, self.nruns + 1): + start_time = time.time() + _ = self.inference(self.dummy_input) + total_time += time.time() - start_time + + if i % 10 == 0: + print( + f"Iteration {i}/{self.nruns}, ave batch time {total_time / i * 1000:.2f} ms" + ) + + avg_time = total_time / self.nruns + logging.info(f"Average inference time: {avg_time * 1000:.2f} ms") + return avg_time * 1000 def benchmark_onnx_model(ort_session: ort.InferenceSession): run_benchmark(None, None, None, ort_session, onnx=True) -def benchmark_ov_model(ov_model: ov.CompiledModel) -> src.benchmark_class.OVBenchmark: +def benchmark_ov_model(ov_model: ov.CompiledModel) -> OVBenchmark: ov_benchmark = OVBenchmark(ov_model, input_shape=(1, 3, 224, 224)) ov_benchmark.run() return ov_benchmark @@ -18,3 +217,28 @@ def benchmark_ov_model(ov_model: ov.CompiledModel) -> src.benchmark_class.OVBenc def benchmark_cuda_model(cuda_model: torch.nn.Module, device: str, dtype: torch.dtype): run_benchmark(cuda_model, device, dtype) + + +def run_benchmark( + model: torch.nn.Module, + device: str, + dtype: torch.dtype, + ort_session: ort.InferenceSession = None, + onnx: bool = False, +) -> None: + """ + Run and log the benchmark for the given model, device, and dtype. + + :param onnx: + :param ort_session: + :param model: The model to be benchmarked. + :param device: The device to run the benchmark on ("cpu" or "cuda"). + :param dtype: The data type to be used in the benchmark (typically torch.float32 or torch.float16). + """ + if onnx: + logging.info(f"Running Benchmark for ONNX") + benchmark = ONNXBenchmark(ort_session, input_shape=(32, 3, 224, 224)) + else: + logging.info(f"Running Benchmark for {device.upper()} and precision {dtype}") + benchmark = PyTorchBenchmark(model, device=device, dtype=dtype) + benchmark.run() \ No newline at end of file diff --git a/benchmark/benchmark_utils.py b/benchmark/benchmark_utils.py index 38973be..532a4b1 100644 --- a/benchmark/benchmark_utils.py +++ b/benchmark/benchmark_utils.py @@ -8,32 +8,7 @@ import torch import onnxruntime as ort -from src.benchmark_class import PyTorchBenchmark, ONNXBenchmark, OVBenchmark - - -def run_benchmark( - model: torch.nn.Module, - device: str, - dtype: torch.dtype, - ort_session: ort.InferenceSession = None, - onnx: bool = False, -) -> None: - """ - Run and log the benchmark for the given model, device, and dtype. - - :param onnx: - :param ort_session: - :param model: The model to be benchmarked. - :param device: The device to run the benchmark on ("cpu" or "cuda"). - :param dtype: The data type to be used in the benchmark (typically torch.float32 or torch.float16). 
- """ - if onnx: - logging.info(f"Running Benchmark for ONNX") - benchmark = ONNXBenchmark(ort_session, input_shape=(32, 3, 224, 224)) - else: - logging.info(f"Running Benchmark for {device.upper()} and precision {dtype}") - benchmark = PyTorchBenchmark(model, device=device, dtype=dtype) - benchmark.run() +from benchmark.benchmark_models import PyTorchBenchmark, ONNXBenchmark, OVBenchmark def run_all_benchmarks( @@ -110,7 +85,13 @@ def plot_benchmark_results(results: Dict[str, float]): # Plot plt.figure(figsize=(10, 6)) - ax = sns.barplot(x=data["Time"], y=data["Model"], hue=data["Model"], palette="rocket", legend=False) + ax = sns.barplot( + x=data["Time"], + y=data["Model"], + hue=data["Model"], + palette="rocket", + legend=False, + ) # Adding the actual values on the bars for index, value in enumerate(data["Time"]): diff --git a/main.py b/main.py index 7e34521..058d9f9 100644 --- a/main.py +++ b/main.py @@ -1,6 +1,5 @@ import logging import os.path - import torch_tensorrt from benchmark.benchmark_models import benchmark_onnx_model, benchmark_ov_model @@ -9,11 +8,16 @@ parse_arguments, init_onnx_model, init_ov_model, - init_cuda_model, export_onnx_model, + init_cuda_model, + export_onnx_model, ) from src.image_processor import ImageProcessor from prediction.prediction_models import * from src.model import ModelLoader +import warnings + +# Filter out the specific warning from torchvision +warnings.filterwarnings("ignore", category=UserWarning, module="torchvision.io.image") # Configure logging logging.basicConfig(filename="model.log", level=logging.INFO) @@ -38,18 +42,27 @@ def main(): ort_session = init_onnx_model(args.onnx_path, model_loader, device) if args.mode != "all": benchmark_onnx_model(ort_session) - predict_onnx_model(ort_session, img_batch, args.topk, model_loader.categories) + predict_onnx_model( + ort_session, img_batch, args.topk, model_loader.categories + ) # OpenVINO if args.mode in ["ov", "all"]: # Check if ONNX model wasn't exported previously if not os.path.isfile(args.onnx_path): - export_onnx_model(onnx_path=args.onnx_path, model_loader=model_loader, device=device) + export_onnx_model( + onnx_path=args.onnx_path, model_loader=model_loader, device=device + ) ov_model = init_ov_model(args.onnx_path) if args.mode != "all": ov_benchmark = benchmark_ov_model(ov_model) - predict_ov_model(ov_benchmark.compiled_model, img_batch, args.topk, model_loader.categories) + predict_ov_model( + ov_benchmark.compiled_model, + img_batch, + args.topk, + model_loader.categories, + ) # CUDA if args.mode in ["cuda", "all"]: @@ -75,11 +88,13 @@ def main(): img_batch = img_batch.to(device) else: print("Compiling TensorRT model") + batch_size = 1 if args.mode == "cuda" else 32 model = torch_tensorrt.compile( model, - inputs=[torch_tensorrt.Input((32, 3, 224, 224), dtype=precision)], + inputs=[torch_tensorrt.Input((batch_size, 3, 224, 224), dtype=precision)], enabled_precisions={precision}, truncate_long_and_double=True, + require_full_compilation=True, ) # If it is for TensorRT, determine the mode (FP32 or FP16) and store under a TensorRT key mode = "fp32" if precision == torch.float32 else "fp16" diff --git a/prediction/prediction_utils.py b/prediction/prediction_utils.py index a9ea429..25a5e7b 100644 --- a/prediction/prediction_utils.py +++ b/prediction/prediction_utils.py @@ -4,6 +4,8 @@ import torch import onnxruntime as ort import numpy as np +import torch_tensorrt + def make_prediction( @@ -16,7 +18,7 @@ def make_prediction( """ Make and print predictions for the given model, img_batch, 
     topk, and categories.
-    :param model: The model (or ONNX Runtime InferenceSession) to make predictions with.
+    :param model: The model to make predictions with.
     :param img_batch: The batch of images to make predictions on.
     :param topk: The number of top predictions to show.
     :param categories: The list of categories to label the predictions.
@@ -58,7 +60,14 @@ def make_prediction(
         prob = np.exp(prob[0]) / np.sum(np.exp(prob[0]))
     else:
         # PyTorch Model
-        logging.info(f"Running prediction for PyTorch model")
+        params = list(model.parameters())
+        if params:
+            logging.info(f"Running prediction for PyTorch_{params[0].device}")
+        elif isinstance(model, torch.nn.Module):
+            logging.info(f"Running prediction for TensorRT_{precision} model")
+        else:
+            raise ValueError("Running prediction for an unknown model type")
+
         if isinstance(img_batch, np.ndarray):
             img_batch = torch.tensor(img_batch)
         else:
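
For context, a minimal usage sketch of the refactored benchmark classes outside of `main.py`. This is not part of the diff: the ResNet-50 model, the `resnet50.onnx` path, and the ONNX input tensor name `"input"` are illustrative assumptions.

```python
# Hypothetical usage sketch; model choice, ONNX path, and the "input" tensor name
# are assumptions for illustration and are not defined by this diff.
import torch
import torchvision.models as models
import openvino as ov

from benchmark.benchmark_models import PyTorchBenchmark, OVBenchmark

# Benchmark a PyTorch model on the GPU at FP32 with the default batch of 32.
model = models.resnet50(weights="DEFAULT").eval().to("cuda")
pt_avg_ms = PyTorchBenchmark(
    model, device="cuda", input_shape=(32, 3, 224, 224), dtype=torch.float32
).run()

# Benchmark the same network from a previously exported ONNX file with OpenVINO.
ov_model = ov.Core().read_model("resnet50.onnx")  # assumes the export used input name "input"
ov_avg_ms = OVBenchmark(ov_model, input_shape=(1, 3, 224, 224)).run()

print(f"PyTorch CUDA: {pt_avg_ms:.2f} ms | OpenVINO: {ov_avg_ms:.2f} ms")
```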