From 38d6e052f3489bc85350ff3f713494d9de3cde4c Mon Sep 17 00:00:00 2001
From: Manas Vardhan
Date: Mon, 16 Feb 2026 09:04:56 -0800
Subject: [PATCH] Fix outdated torch_logs tutorial by removing CUDA device check

The tutorial was wrapped in a CUDA device capability check that caused the
entire example to be skipped when built on CI without a compatible GPU. This
resulted in the docs showing 'Skipping because torch.compile is not supported
on this device' instead of actual log output.

torch.compile works on CPU, so this fix removes the device check and uses
dynamic device selection (CUDA if available, otherwise CPU). This ensures the
tutorial produces meaningful log output regardless of the build environment.

Fixes pytorch/pytorch#137285
---
 recipes_source/torch_logs.py | 48 +++++++++++++++++-------------------
 1 file changed, 23 insertions(+), 25 deletions(-)

diff --git a/recipes_source/torch_logs.py b/recipes_source/torch_logs.py
index b5c3f0bd8ac..527bcd71f3f 100644
--- a/recipes_source/torch_logs.py
+++ b/recipes_source/torch_logs.py
@@ -32,51 +32,49 @@
 import torch
 
-# exit cleanly if we are on a device that doesn't support torch.compile
-if torch.cuda.get_device_capability() < (7, 0):
-    print("Skipping because torch.compile is not supported on this device.")
-else:
-    @torch.compile()
-    def fn(x, y):
-        z = x + y
-        return z + 2
+@torch.compile()
+def fn(x, y):
+    z = x + y
+    return z + 2
 
-    inputs = (torch.ones(2, 2, device="cuda"), torch.zeros(2, 2, device="cuda"))
+
+device = "cuda" if torch.cuda.is_available() else "cpu"
+inputs = (torch.ones(2, 2, device=device), torch.zeros(2, 2, device=device))
 
 # print separator and reset dynamo
 # between each example
-    def separator(name):
-        print(f"==================={name}=========================")
-        torch._dynamo.reset()
+def separator(name):
+    print(f"==================={name}=========================")
+    torch._dynamo.reset()
 
-    separator("Dynamo Tracing")
+separator("Dynamo Tracing")
 
 # View dynamo tracing
 # TORCH_LOGS="+dynamo"
-    torch._logging.set_logs(dynamo=logging.DEBUG)
-    fn(*inputs)
+torch._logging.set_logs(dynamo=logging.DEBUG)
+fn(*inputs)
 
-    separator("Traced Graph")
+separator("Traced Graph")
 
 # View traced graph
 # TORCH_LOGS="graph"
-    torch._logging.set_logs(graph=True)
-    fn(*inputs)
+torch._logging.set_logs(graph=True)
+fn(*inputs)
 
-    separator("Fusion Decisions")
+separator("Fusion Decisions")
 
 # View fusion decisions
 # TORCH_LOGS="fusion"
-    torch._logging.set_logs(fusion=True)
-    fn(*inputs)
+torch._logging.set_logs(fusion=True)
+fn(*inputs)
 
-    separator("Output Code")
+separator("Output Code")
 
 # View output code generated by inductor
 # TORCH_LOGS="output_code"
-    torch._logging.set_logs(output_code=True)
-    fn(*inputs)
+torch._logging.set_logs(output_code=True)
+fn(*inputs)
 
-    separator("")
+separator("")
 
 ######################################################################
 # Conclusion
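
Testing notes (not part of the diff above): a minimal sketch of the
device-selection pattern this change introduces, using only calls that already
appear in the recipe (torch.compile, torch.cuda.is_available,
torch._logging.set_logs). On a CPU-only machine it should still emit dynamo
log output instead of skipping:

    import logging

    import torch

    # Pick CUDA when available, otherwise fall back to CPU (the same pattern
    # the patch adds to the recipe).
    device = "cuda" if torch.cuda.is_available() else "cpu"

    @torch.compile()
    def fn(x, y):
        z = x + y
        return z + 2

    # Programmatic equivalent of running with TORCH_LOGS="+dynamo".
    torch._logging.set_logs(dynamo=logging.DEBUG)

    inputs = (torch.ones(2, 2, device=device), torch.zeros(2, 2, device=device))
    print(fn(*inputs))  # 2x2 tensor of 3s, preceded by dynamo debug logs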
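
Also worth noting for reviewers: the separator() helper in the recipe calls
torch._dynamo.reset() because, once fn has been compiled, later calls reuse the
cached compilation and changing the log settings alone would typically produce
no new output. Resetting dynamo forces a recompile under the new settings. A
short sketch under that assumption, mirroring the calls used in the recipe:

    import logging

    import torch

    @torch.compile()
    def fn(x, y):
        return x + y + 2

    inputs = (torch.ones(2, 2), torch.zeros(2, 2))

    torch._logging.set_logs(dynamo=logging.DEBUG)
    fn(*inputs)            # first call compiles and emits dynamo tracing logs

    torch._dynamo.reset()  # drop cached compilation so the next call recompiles
    torch._logging.set_logs(graph=True)
    fn(*inputs)            # recompiles and prints the traced graph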