feat: multiple optimization profiles for disjoint input shape regimes#4325
feat: multiple optimization profiles for disjoint input shape regimes#4325cehongwang wants to merge 6 commits into
Conversation
f32fed3 to
427643d
Compare
427643d to
2cd4797
Compare
f907b64 to
9f9055a
Compare
Add support for defining N optimization profiles at compile time via the list-based ``Input.profiles`` API and selecting the active profile at runtime (manual pin by index, or opt-in shape-based auto-selection). - AOT (torch.export) compile path builds one TRT optimization profile per declared profile index; submodules inherit the profile count via propagation across graph breaks. - Python and C++ runtimes expose a matching primitive engine API (set_active_profile / num_optimization_profiles / _active_profile_index / _auto_select_profiles) so the two runtimes remain interchangeable. - Profile selection is exposed through the optimization_profile context manager; auto-selection uses lazy (first-fitting) profile selection. - Backward compatible: engines without declared profiles keep the historical single-profile (dynamic) / no-profile (static) behavior. Includes an example and runtime tests covering dynamic submodule inputs.
2cd4797 to
a0eeae7
Compare
| if (profile_index == active_profile_index) { | ||
| return; | ||
| } | ||
| auto stream = c10::cuda::getCurrentCUDAStream(device_info.id); |
There was a problem hiding this comment.
Does this work with the green context pr?
| } | ||
| const auto& dims = ranges_it->second; | ||
| auto sizes = inputs[i].sizes(); | ||
| for (size_t d = 0; d < sizes.size(); ++d) { |
There was a problem hiding this comment.
Can we cache only what the dynamic dimension is for each profile and its ranges? Then we dont need to search mostly static dims
| for (const auto& name : in_binding_names) { | ||
| is_shape_inference_io[name] = cuda_engine->isShapeInferenceIO(name.c_str()); | ||
| } | ||
| if (num_optimization_profiles <= 1) { |
There was a problem hiding this comment.
What not do this first to short cut everything?
| } | ||
| } | ||
|
|
||
| void TRTEngine::set_active_profile(int64_t profile_index) { |
There was a problem hiding this comment.
I kind of feel like this function should take stream as argument and put it on the caller to give the right stream instead of getting the current stream in the function body
| bool fits = true; | ||
| for (size_t i = 0; i < in_binding_names.size() && fits; ++i) { | ||
| const auto& name = in_binding_names[i]; | ||
| if (i >= inputs.size() || is_shape_inference_io[name]) { |
There was a problem hiding this comment.
I feel like this function doesnt make a ton of sense and could potentially lead to thrashing.
Fundamentally in the auto mode I think first we should see if the active profile is valid for the set of inputs.
The mechanism we use to map from the list to an input binding name should be identical to the code we use in execute_engine. Its fundamentally the same job. In fact we should only do it once and reuse this mapping result. There should be some map from index in the input list to a name.
If the inputs fit we should short cut and return.
I also think because of this the responsibility of changing the opt profile should be on these methods and not the caller. That way we can just no-op if it fits.
| """ | ||
| for p in range(self.num_optimization_profiles): | ||
| fits = True | ||
| for i, name in enumerate(self.in_binding_names): |
There was a problem hiding this comment.
Same sort of feedback here. Construct a index -> name map, then iterate through the indexes, if they fit all just return other wise go through your profile dim cache to find the next fit.
Add support for defining N optimization profiles at compile time via the list-based
Input.profilesAPI and selecting the active profile at runtime (manual pin by index, or opt-in shape-based auto-selection).Includes an example and runtime tests covering dynamic submodule inputs.
Description
Please include a summary of the change and which issue is fixed. Please also include relevant motivation and context. List any dependencies that are required for this change.
Fixes # (issue)
Type of change
Please delete options that are not relevant and/or add your own.
Checklist: