Conversation
Add missing newline at end of file for graph.cc
Removed TODO comments for memory allocation and deallocation.
Updated inferShape method to correctly compute the concatenated shape.
Added a missing newline at the end of the file and updated comments.
Update transpose function to reverse dimensions
Removed TODO comments and added clarification about shape.
Pull request overview
This PR implements homework assignments for a deep learning framework, completing several operator shape inference functions, broadcast logic, graph optimization passes, and memory allocation. The PR title "呢哇" (Chinese characters) doesn't describe the changes.
Changes:
- Implements shape inference for Clip, Cast, Transpose, MatMul, and Concat operators
- Implements bidirectional broadcasting shape inference (infer_broadcast)
- Adds graph optimization passes to eliminate redundant transposes and fuse transposes into matmul operations
- Implements memory allocation with first-fit strategy and block coalescing
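The bidirectional broadcasting rule can be sketched as a small stand-alone function. This is a minimal illustration of the NumPy/ONNX-style rule only; the names (`inferBroadcast`, `Shape`) are hypothetical and not the framework's actual signatures:

```cpp
#include <algorithm>
#include <cassert>
#include <vector>

using Shape = std::vector<int>;

// NumPy/ONNX-style bidirectional broadcast: align shapes from the trailing
// dimension; each aligned pair must be equal, or one side must be 1.
Shape inferBroadcast(const Shape &a, const Shape &b) {
    size_t rank = std::max(a.size(), b.size());
    Shape out(rank);
    for (size_t i = 0; i < rank; ++i) {
        // Missing leading dimensions are treated as 1.
        int dimA = i < rank - a.size() ? 1 : a[i - (rank - a.size())];
        int dimB = i < rank - b.size() ? 1 : b[i - (rank - b.size())];
        assert(dimA == dimB || dimA == 1 || dimB == 1);
        out[i] = std::max(dimA, dimB);
    }
    return out;
}
```

For example, broadcasting a `{2, 3, 4}` tensor against a `{3, 1}` tensor yields `{2, 3, 4}`.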
Reviewed changes
Copilot reviewed 8 out of 8 changed files in this pull request and generated 7 comments.
| File | Description |
|---|---|
| src/utils/operator_utils.cc | Implements broadcast shape inference and applies formatting changes to namespace braces |
| src/operators/unary.cc | Implements shape inference for Clip and Cast operators |
| src/operators/transpose.cc | Fixes default permutation to reverse dimensions per ONNX spec and implements shape inference with validation |
| src/operators/matmul.cc | Implements shape inference with batch broadcasting and transpose support |
| src/operators/concat.cc | Implements shape inference with dimension validation |
| src/core/graph.cc | Implements graph optimization passes and memory allocation with lifetime analysis |
| src/core/allocator.cc | Implements memory allocator with first-fit allocation and block coalescing |
| include/core/allocator.h | Adds data structures and helper methods for memory management |
    {
        if (r == dim)
            continue;
        IT_ASSERT(cur[r] == dims[r], "Concat dims mismatch on non-concat axis");
Similar to the other assertion, this message could be more informative by including which axis caused the mismatch. Consider: "Concat dims mismatch on non-concat axis " + std::to_string(r) + ": expected " + std::to_string(dims[r]) + ", got " + std::to_string(cur[r]).
    {
        int p = transposePermute[i];
        IT_ASSERT(p >= 0 && p < rank);
        IT_ASSERT(seen[p] == 0);
The error message for duplicate permutation indices lacks helpful details. Consider: "Duplicate permutation index " + std::to_string(p) + " at position " + std::to_string(i) + " in transpose permutation".
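As a stand-alone sketch, the permutation check amounts to verifying that every entry is in range and appears exactly once. The name `isValidPermute` is hypothetical; the real code asserts with IT_ASSERT and an informative message instead of returning a bool:

```cpp
#include <cassert>
#include <vector>

// Sketch of transpose-permutation validation: every entry must lie in
// [0, rank) and occur exactly once.
bool isValidPermute(const std::vector<int> &perm) {
    int rank = static_cast<int>(perm.size());
    std::vector<int> seen(rank, 0);
    for (int i = 0; i < rank; ++i) {
        int p = perm[i];
        if (p < 0 || p >= rank)
            return false; // out-of-range index
        if (seen[p]++)
            return false; // duplicate index
    }
    return true;
}
```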
    std::map<size_t, size_t>::iterator Allocator::findFreeBlock(size_t size)
    {
        // first-fit: find the first free block whose size is sufficient
        for (auto it = freeBlocks.begin(); it != freeBlocks.end(); ++it)
        {
            if (it->second >= size)
                return it;
        }
        return freeBlocks.end();
The findFreeBlock function uses a linear search (first-fit) through all free blocks. For workloads with many allocations and deallocations, this could become a performance bottleneck as it has O(n) complexity where n is the number of free blocks. Consider using a more efficient data structure, such as maintaining free blocks sorted by size (in addition to address) or using a best-fit strategy with size-indexed structures for better performance. However, if the number of tensors is small (typical in most ML graphs), this may not be a practical concern.
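A size-indexed structure as suggested might look like the sketch below, which uses a `std::multimap` keyed by block size to get O(log n) best-fit lookup. This is an illustration only: the real Allocator also keeps an address-ordered map for coalescing, which this sketch omits, and the class name is hypothetical.

```cpp
#include <cassert>
#include <cstddef>
#include <map>

// Best-fit free list indexed by size: lookup drops from O(n) to O(log n).
class SizeIndexedFreeList {
    std::multimap<size_t, size_t> bySize; // block size -> start address
public:
    void addFree(size_t addr, size_t size) { bySize.emplace(size, addr); }

    // On success, returns true and sets addr; any remainder of the chosen
    // block is returned to the free list.
    bool allocate(size_t size, size_t &addr) {
        auto it = bySize.lower_bound(size); // smallest block that fits
        if (it == bySize.end())
            return false;
        addr = it->second;
        size_t remaining = it->first - size;
        bySize.erase(it);
        if (remaining > 0)
            addFree(addr + size, remaining);
        return true;
    }
};
```

Best-fit here also reduces fragmentation relative to first-fit, since small requests are served from the smallest adequate block.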
    Shape dims = inputs[0]->getDims();
    // =================================== Homework ===================================
    // TODO: modify dims to return the correct shape after concat
    // REF: https://onnx.ai/onnx/operators/onnx__Concat.html#concat-13
    // =================================== Homework ===================================
    IT_ASSERT(!inputs.empty());
The assertion that inputs is not empty (line 22) comes after dims is already initialized from inputs[0] (line 16). If inputs were empty, line 16 would access an invalid element before the assertion could catch it. Move the IT_ASSERT(!inputs.empty()) check to before line 16 to ensure the precondition is validated first.
    auto tr = as<TransposeObj>(src);
    if (tr->getOutput() != in)
        continue;
    if (!isSwapLast2Permute(tr->getPermute()))
        continue;

    auto trIn = tr->getInputs(0);
    mm->replaceInput(in, trIn);
    if (inputIdx == 0)
        mm->setTransA(!mm->getTransA());
    else
        mm->setTransB(!mm->getTransB());
    toRemove.insert(src.get());
    changed = true;
Potential issue: When a transpose operator is marked for removal and added to the toRemove set, the code doesn't verify that the transpose output is only consumed by this matmul operator. If the transpose output is used by multiple consumers, removing it would break other parts of the graph. Add a check to ensure tr->getOutput()->getTargets().size() == 1 before adding to toRemove, similar to the check in rule 1 at line 157.
    m = aM;
    n = bN;
    k = aK;
The matmul m, n, k member variables are modified during inferShape, which is why the method is non-const. These values are derived from the inputs, but setting them as a side effect of inferShape can cause problems if inferShape is called multiple times or in different contexts. Consider computing them on demand via getter methods instead of caching them as mutable state, or document that they are only valid after inferShape has been called.
    else if (dimB == 1)
        out[i] = dimA;
    else
        IT_ASSERT(false, "Broadcast shape mismatch");
The assertion message "Broadcast shape mismatch" on line 31 could be more informative. Consider including the actual dimension values and indices to help with debugging. For example: "Broadcast shape mismatch at dimension " + std::to_string(i) + ": dimA=" + std::to_string(dimA) + ", dimB=" + std::to_string(dimB).
Suggested change:

    - IT_ASSERT(false, "Broadcast shape mismatch");
    + {
    +     std::string msg = "Broadcast shape mismatch at dimension " +
    +                       std::to_string(i) + ": dimA=" +
    +                       std::to_string(dimA) + ", dimB=" +
    +                       std::to_string(dimB);
    +     IT_ASSERT(false, msg);
    + }