From c03c42cb66159b718696f6959f15d78370709f4b Mon Sep 17 00:00:00 2001 From: Xiaobanli-new <61371975+Xiaobanli-new@users.noreply.github.com> Date: Tue, 3 Feb 2026 21:34:49 +0800 Subject: [PATCH 01/17] =?UTF-8?q?=E4=B9=A0=E9=A2=98=E4=B8=80?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- include/core/allocator.h | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/include/core/allocator.h b/include/core/allocator.h index 002601d..ce7ef3e 100644 --- a/include/core/allocator.h +++ b/include/core/allocator.h @@ -8,7 +8,8 @@ #include #include -namespace infini { +namespace infini +{ class Allocator { private: @@ -27,6 +28,9 @@ namespace infini { // TODO:可能需要设计一个数据结构来存储free block,以便于管理和合并 // HINT: 可以使用一个 map 来存储 free block,key 为 block 的起始/结尾地址,value 为 block 的大小 // =================================== 作业 =================================== + std::map freeBlocks; + void addFreeBlock(size_t addr, size_t size); + std::map::iterator findFreeBlock(size_t size); public: Allocator(Runtime runtime); From 3ecf2475a922f72c3cff8f368b75806ee1258d30 Mon Sep 17 00:00:00 2001 From: Xiaobanli-new <61371975+Xiaobanli-new@users.noreply.github.com> Date: Tue, 3 Feb 2026 21:39:01 +0800 Subject: [PATCH 02/17] =?UTF-8?q?=E4=B9=A0=E9=A2=98=E4=B8=80?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/core/allocator.cc | 87 ++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 86 insertions(+), 1 deletion(-) diff --git a/src/core/allocator.cc b/src/core/allocator.cc index ff593ae..0a695a2 100644 --- a/src/core/allocator.cc +++ b/src/core/allocator.cc @@ -33,7 +33,28 @@ namespace infini // TODO: 设计一个算法来分配内存,返回起始地址偏移量 // =================================== 作业 =================================== - return 0; + auto it = findFreeBlock(size); + if (it != freeBlocks.end()) + { + const size_t addr = it->first; + const size_t blockSize = it->second; + IT_ASSERT(blockSize >= size); + freeBlocks.erase(it); + + const size_t remain = blockSize - size; + if (remain > 0) + { + freeBlocks.emplace(addr + size, remain); + } + + used += size; + return addr; + } + + const size_t addr = peak; + peak += size; + used += size; + return addr; } void Allocator::free(size_t addr, size_t size) @@ -44,6 +65,70 @@ namespace infini // =================================== 作业 =================================== // TODO: 设计一个算法来回收内存 // =================================== 作业 =================================== + IT_ASSERT(used >= size); + used -= size; + addFreeBlock(addr, size); + } + std::map::iterator Allocator::findFreeBlock(size_t size) + { + // first-fit: 找到第一个 size 足够的空闲块 + for (auto it = freeBlocks.begin(); it != freeBlocks.end(); ++it) + { + if (it->second >= size) + return it; + } + return freeBlocks.end(); + } + + void Allocator::addFreeBlock(size_t addr, size_t size) + { + // 插入一个 free block,并与左右相邻块合并(coalescing) + auto it = freeBlocks.lower_bound(addr); + + // 尝试与左侧块合并 + if (it != freeBlocks.begin()) + { + auto left = std::prev(it); + const size_t leftAddr = left->first; + const size_t leftSize = left->second; + if (leftAddr + leftSize == addr) + { + addr = leftAddr; + size += leftSize; + freeBlocks.erase(left); + } + } + + // 尝试与右侧块合并(重新定位迭代器) + it = freeBlocks.lower_bound(addr); + if (it != freeBlocks.end()) + { + const size_t rightAddr = it->first; + const size_t rightSize = it->second; + if (addr + size == rightAddr) + { + size += rightSize; + freeBlocks.erase(it); + } + } + + freeBlocks.emplace(addr, size); + + // 
若空闲块位于堆顶(addr+size==peak),则可以把 peak 往回收缩。 + // 进一步:如果收缩后的新 peak 仍然与另一个空闲块相邻,也可以继续收缩。 + while (!freeBlocks.empty()) + { + auto it = freeBlocks.upper_bound(peak); + if (it == freeBlocks.begin()) + break; + --it; + const size_t blockAddr = it->first; + const size_t blockSize = it->second; + if (blockAddr + blockSize != peak) + break; + peak = blockAddr; + freeBlocks.erase(it); + } } void *Allocator::getPtr() From 5ddc1fe2ded8ce28295c930cbd7fa5fb17b85a06 Mon Sep 17 00:00:00 2001 From: Xiaobanli-new <61371975+Xiaobanli-new@users.noreply.github.com> Date: Tue, 3 Feb 2026 21:39:36 +0800 Subject: [PATCH 03/17] =?UTF-8?q?=E4=B9=A0=E9=A2=98=E4=B8=80?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Added new includes and optimized graph processing rules to eliminate redundant transpose operations and merge them into matrix multiplication attributes. --- src/core/graph.cc | 228 +++++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 223 insertions(+), 5 deletions(-) diff --git a/src/core/graph.cc b/src/core/graph.cc index 3a90637..1a1ae0d 100644 --- a/src/core/graph.cc +++ b/src/core/graph.cc @@ -1,4 +1,7 @@ #include "core/graph.h" +#include "core/blob.h" +#include "operators/matmul.h" +#include "operators/transpose.h" #include #include #include @@ -106,6 +109,166 @@ namespace infini // 1. 去除冗余的算子(例如,两个相邻的算子都是 transpose 算子,且做的是相反的操作,可以将其全部删除) // 2. 合并算子(例如,矩阵乘算子中含有属性transA、transB,如果其输入存在transpose,且对最后两个维度做交换,就可以将transpose融入到矩阵乘算子的属性中去) // =================================== 作业 =================================== + + IT_ASSERT(topo_sort() == true); + + auto isSwapLast2Permute = [](const std::vector &perm) -> bool + { + const int r = static_cast(perm.size()); + if (r < 2) + return false; + for (int i = 0; i < r - 2; ++i) + if (perm[i] != i) + return false; + return perm[r - 2] == r - 1 && perm[r - 1] == r - 2; + }; + + auto isInversePermute = [](const std::vector &p1, + const std::vector &p2) -> bool + { + if (p1.size() != p2.size()) + return false; + const int r = static_cast(p1.size()); + std::vector inv(r, -1); + for (int i = 0; i < r; ++i) + { + const int v = p1[i]; + if (v < 0 || v >= r || inv[v] != -1) + return false; + inv[v] = i; + } + return inv == p2; + }; + + bool changed = true; + while (changed) + { + changed = false; + + // 规则 1:消除连续 transpose(perm 互逆) + for (size_t i = 0; i < ops.size(); ++i) + { + auto op1 = ops[i]; + if (op1->getOpType() != OpType::Transpose) + continue; + auto t1 = as(op1); + auto out1 = t1->getOutput(); + auto targets = out1->getTargets(); + if (targets.size() != 1) + continue; + auto op2 = targets[0]; + if (!op2 || op2->getOpType() != OpType::Transpose) + continue; + auto t2 = as(op2); + if (t2->getInputs(0) != out1) + continue; + if (!isInversePermute(t1->getPermute(), t2->getPermute())) + continue; + + auto in = t1->getInputs(0); + auto out2 = t2->getOutput(); + for (auto &consumer : out2->getTargets()) + consumer->replaceInput(out2, in); + + ops.erase(std::remove(ops.begin(), ops.end(), op1), ops.end()); + ops.erase(std::remove(ops.begin(), ops.end(), op2), ops.end()); + changed = true; + break; + } + if (changed) + continue; + + // 规则 2:将 transpose(交换最后两维) 融合到 matmul 的 transA/transB + std::unordered_set toRemove; + for (auto &op : ops) + { + if (op->getOpType() != OpType::MatMul) + continue; + auto mm = as(op); + for (int inputIdx = 0; inputIdx < 2; ++inputIdx) + { + auto in = mm->getInputs(inputIdx); + auto src = in->getSource(); + if (!src || src->getOpType() != OpType::Transpose) + continue; + auto tr = 
as(src); + if (tr->getOutput() != in) + continue; + if (!isSwapLast2Permute(tr->getPermute())) + continue; + + auto trIn = tr->getInputs(0); + mm->replaceInput(in, trIn); + if (inputIdx == 0) + mm->setTransA(!mm->getTransA()); + else + mm->setTransB(!mm->getTransB()); + toRemove.insert(src.get()); + changed = true; + } + } + if (!toRemove.empty()) + { + ops.erase(std::remove_if(ops.begin(), ops.end(), + [&](const Operator &op) + { return toRemove.count(op.get()) != 0; }), + ops.end()); + } + } + + // 清理不再被任何算子引用的张量 + { + std::unordered_set referenced; + referenced.reserve(tensors.size()); + for (auto &op : ops) + { + for (auto &t : op->getInputs()) + referenced.insert(t.get()); + for (auto &t : op->getOutputs()) + referenced.insert(t.get()); + } + TensorVec kept; + kept.reserve(tensors.size()); + for (auto &t : tensors) + if (referenced.count(t.get()) != 0) + kept.emplace_back(t); + tensors = std::move(kept); + } + + // 重新构建 pred/succ 与 tensor source/target + for (auto &t : tensors) + { + t->targets.clear(); + t->source.reset(); + } + for (auto &op : ops) + { + op->predecessors.clear(); + op->successors.clear(); + } + for (auto &op : ops) + { + for (auto &input : op->getInputs()) + { + if (input) + { + input->addTarget(op); + if (auto pred = input->getSource()) + { + pred->addSuccessors(op); + op->addPredecessors(pred); + } + } + } + for (auto &output : op->getOutputs()) + { + if (output) + output->setSource(op); + } + } + + sorted = false; + IT_ASSERT(topo_sort() == true); } Tensor GraphObj::getTensor(int fuid) const @@ -148,10 +311,65 @@ namespace infini // topological sorting first IT_ASSERT(topo_sort() == true); - // =================================== 作业 =================================== - // TODO:利用 allocator 给计算图分配内存 - // HINT: 获取分配好的内存指针后,可以调用 tensor 的 setDataBlob 函数给 tensor 绑定内存 - // =================================== 作业 =================================== + std::unordered_map remainingUses; + std::unordered_map bytes; + std::unordered_map offsets; + remainingUses.reserve(tensors.size()); + bytes.reserve(tensors.size()); + offsets.reserve(tensors.size()); + + std::unordered_set keepAlive; + keepAlive.reserve(tensors.size()); + for (auto &t : tensors) + { + bytes[t.get()] = t->getBytes(); + remainingUses[t.get()] = static_cast(t->getTargets().size()); + if (t->getTargets().empty()) + keepAlive.insert(t.get()); + } + + auto ensureAlloc = [&](const Tensor &t) + { + auto *p = t.get(); + if (offsets.find(p) == offsets.end()) + offsets[p] = allocator.alloc(bytes[p]); + }; + + // 输入张量:dataMalloc 后会 setData + for (auto &t : getInputs()) + ensureAlloc(t); + + // 遍历 op:分配输出、回收“已完成最后一次使用”的输入 + for (auto &op : ops) + { + for (auto &out : op->getOutputs()) + ensureAlloc(out); + + for (auto &in : op->getInputs()) + { + auto *p = in.get(); + auto it = remainingUses.find(p); + if (it == remainingUses.end()) + continue; + if (it->second > 0) + --(it->second); + if (it->second == 0 && keepAlive.count(p) == 0) + { + auto offIt = offsets.find(p); + if (offIt != offsets.end()) + allocator.free(offIt->second, bytes[p]); + } + } + } + + void *base = allocator.getPtr(); + for (auto &t : tensors) + { + auto it = offsets.find(t.get()); + IT_ASSERT(it != offsets.end(), "Tensor not allocated in dataMalloc"); + void *ptr = static_cast(static_cast(base) + it->second); + t->setDataBlob(make_ref(runtime, ptr)); + } allocator.info(); } @@ -227,4 +445,4 @@ namespace infini return true; } -} // namespace infini \ No newline at end of file +} // namespace infini From a7a23f00668530291f1c698314dc0766ceb7c4a6 Mon 
Sep 17 00:00:00 2001 From: Xiaobanli-new <61371975+Xiaobanli-new@users.noreply.github.com> Date: Tue, 3 Feb 2026 21:42:56 +0800 Subject: [PATCH 04/17] =?UTF-8?q?=E4=B9=A0=E9=A2=98=E4=BA=8C?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Implement shape inference for Transpose operation. --- src/operators/transpose.cc | 20 +++++++++++++++++--- 1 file changed, 17 insertions(+), 3 deletions(-) diff --git a/src/operators/transpose.cc b/src/operators/transpose.cc index faab2b6..8f367cd 100644 --- a/src/operators/transpose.cc +++ b/src/operators/transpose.cc @@ -26,15 +26,29 @@ namespace infini { const auto A = inputs[0]; auto input_dim = A->getDims(); - auto output_dim = input_dim; - int rank = A->getRank(); // =================================== 作业 =================================== // TODO:修改 output_dim,返回正确的 transpose 后的 shape // REF: https://onnx.ai/onnx/operators/onnx__Transpose.html#transpose-21 // =================================== 作业 =================================== - return std::nullopt; + const int rank = static_cast(A->getRank()); + IT_ASSERT(static_cast(transposePermute.size()) == rank); + + std::vector seen(rank, 0); + for (int i = 0; i < rank; ++i) + { + int p = transposePermute[i]; + IT_ASSERT(p >= 0 && p < rank); + IT_ASSERT(seen[p] == 0); + seen[p] = 1; + } + + Shape output_dim(rank); + for (int i = 0; i < rank; ++i) + output_dim[i] = input_dim[transposePermute[i]]; + + return {{output_dim}}; } std::string TransposeObj::toString() const From 351e8d00dcaeaddbe16fc904c79e854af8f5174c Mon Sep 17 00:00:00 2001 From: Xiaobanli-new <61371975+Xiaobanli-new@users.noreply.github.com> Date: Tue, 3 Feb 2026 21:45:20 +0800 Subject: [PATCH 05/17] =?UTF-8?q?=E4=B9=A0=E9=A2=98=E4=B8=89?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/operators/unary.cc | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/src/operators/unary.cc b/src/operators/unary.cc index 3daad36..a957ae8 100644 --- a/src/operators/unary.cc +++ b/src/operators/unary.cc @@ -39,7 +39,8 @@ namespace infini // TODO:返回经过 clip 操作后的 shape // REF: https://onnx.ai/onnx/operators/onnx__Clip.html#clip-13 // =================================== 作业 =================================== - return std::nullopt; + const auto A = inputs[0]; + return {{A->getDims()}}; } std::string ClipObj::toString() const @@ -66,7 +67,8 @@ namespace infini // REF_FILE: src/core/operator.cc // REF: https://onnx.ai/onnx/operators/onnx__Cast.html#cast-21 // =================================== 作业 =================================== - return {}; + (void)inputs; + return {getOutputDataType()}; } optional> CastObj::inferShape(const TensorVec &inputs) @@ -75,7 +77,8 @@ namespace infini // TODO:返回经过 cast 操作后的 shape // REF: https://onnx.ai/onnx/operators/onnx__Cast.html#cast-21 // =================================== 作业 =================================== - return std::nullopt; + const auto A = inputs[0]; + return {{A->getDims()}}; } std::string CastObj::toString() const From 4704fb6ae2bac14744fedeabb362b3b5e2331c85 Mon Sep 17 00:00:00 2001 From: Xiaobanli-new <61371975+Xiaobanli-new@users.noreply.github.com> Date: Tue, 3 Feb 2026 21:47:40 +0800 Subject: [PATCH 06/17] =?UTF-8?q?=E4=BD=9C=E4=B8=9A=E4=BA=94?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/operators/concat.cc | 79 ++++++++++++++++++++++++++--------------- 1 file changed, 50 insertions(+), 29 deletions(-) diff --git 
a/src/operators/concat.cc b/src/operators/concat.cc index d196330..9c4e0e6 100644 --- a/src/operators/concat.cc +++ b/src/operators/concat.cc @@ -1,38 +1,59 @@ #include "operators/concat.h" #include "utils/operator_utils.h" -namespace infini { -ConcatObj::ConcatObj(GraphObj *graph, TensorVec inputs, Tensor output, int _dim) - : OperatorObj(OpType::Concat, inputs, {output}) { - int rank = inputs[0]->getRank(); - dim = get_real_axis(_dim, rank); - IT_ASSERT(checkValid(graph)); -} +namespace infini +{ + ConcatObj::ConcatObj(GraphObj *graph, TensorVec inputs, Tensor output, int _dim) + : OperatorObj(OpType::Concat, inputs, {output}) + { + int rank = inputs[0]->getRank(); + dim = get_real_axis(_dim, rank); + IT_ASSERT(checkValid(graph)); + } -optional> ConcatObj::inferShape(const TensorVec &inputs) { - Shape dims = inputs[0]->getDims(); - auto rank = inputs[0]->getRank(); + optional> ConcatObj::inferShape(const TensorVec &inputs) + { + Shape dims = inputs[0]->getDims(); + auto rank = inputs[0]->getRank(); - // =================================== 作业 =================================== - // TODO:修改 dims,返回正确的 concat 后的 shape - // REF: https://onnx.ai/onnx/operators/onnx__Concat.html#concat-13 - // =================================== 作业 =================================== + // =================================== 作业 =================================== + // TODO:修改 dims,返回正确的 concat 后的 shape + // REF: https://onnx.ai/onnx/operators/onnx__Concat.html#concat-13 + // =================================== 作业 =================================== + IT_ASSERT(!inputs.empty()); + const int rank = static_cast(inputs[0]->getRank()); + IT_ASSERT(dim >= 0 && dim < rank); - return {{dims}}; -} + int sumDim = dims[dim]; + for (size_t i = 1; i < inputs.size(); ++i) + { + IT_ASSERT(static_cast(inputs[i]->getRank()) == rank); + const auto &cur = inputs[i]->getDims(); + for (int r = 0; r < rank; ++r) + { + if (r == dim) + continue; + IT_ASSERT(cur[r] == dims[r], "Concat dims mismatch on non-concat axis"); + } + sumDim += cur[dim]; + } + dims[dim] = sumDim; + return {{dims}}; + } -std::string ConcatObj::toString() const { - std::ostringstream os; - os << "Concat[" << getGuid() << "]"; - os << "("; - for (auto input : inputs) - os << vecToString(input->getDims()) << ","; - os << "dim=" << dim << ","; - os << "input="; - for (auto input : inputs) - os << input->getGuid() << ","; - os << "output=" << outputs[0]->getGuid() << ")"; - return os.str(); -} + std::string ConcatObj::toString() const + { + std::ostringstream os; + os << "Concat[" << getGuid() << "]"; + os << "("; + for (auto input : inputs) + os << vecToString(input->getDims()) << ","; + os << "dim=" << dim << ","; + os << "input="; + for (auto input : inputs) + os << input->getGuid() << ","; + os << "output=" << outputs[0]->getGuid() << ")"; + return os.str(); + } } // namespace infini From 5307d0ec136ee9549cc3166628320c0c243e1d1a Mon Sep 17 00:00:00 2001 From: Xiaobanli-new <61371975+Xiaobanli-new@users.noreply.github.com> Date: Tue, 3 Feb 2026 21:49:45 +0800 Subject: [PATCH 07/17] =?UTF-8?q?=E4=B9=A0=E9=A2=98=E5=85=AD?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/utils/operator_utils.cc | 137 ++++++++++++++++++++++-------------- 1 file changed, 84 insertions(+), 53 deletions(-) diff --git a/src/utils/operator_utils.cc b/src/utils/operator_utils.cc index edbd2c8..dc1915c 100644 --- a/src/utils/operator_utils.cc +++ b/src/utils/operator_utils.cc @@ -1,69 +1,100 @@ #include "utils/operator_utils.h" #include 
"core/runtime.h" -namespace infini { +namespace infini +{ -Shape infer_broadcast(const Shape &A, const Shape &B) { + Shape infer_broadcast(const Shape &A, const Shape &B) + { - // =================================== 作业 =================================== - // TODO:对 A 和 B 进行双向广播,返回广播后的形状。 - // REF: https://github.com/onnx/onnx/blob/main/docs/Broadcasting.md - // =================================== 作业 =================================== - - return {}; -} + // =================================== 作业 =================================== + // TODO:对 A 和 B 进行双向广播,返回广播后的形状。 + // REF: https://github.com/onnx/onnx/blob/main/docs/Broadcasting.md + // =================================== 作业 =================================== + const size_t rankA = A.size(); + const size_t rankB = B.size(); + const size_t rank = std::max(rankA, rankB); + Shape out(rank, 1); -int get_real_axis(const int &axis, const int &rank) { - IT_ASSERT(rank >= 1); - IT_ASSERT(axis >= -rank && axis <= (rank - 1)); - int newAxis; - if (axis < 0) { - newAxis = rank + axis; - } else { - newAxis = axis; + for (size_t i = 0; i < rank; ++i) + { + const int dimA = (i < rank - rankA) ? 1 : A[i - (rank - rankA)]; + const int dimB = (i < rank - rankB) ? 1 : B[i - (rank - rankB)]; + IT_ASSERT(dimA >= 0 && dimB >= 0); + if (dimA == dimB) + out[i] = dimA; + else if (dimA == 1) + out[i] = dimB; + else if (dimB == 1) + out[i] = dimA; + else + IT_ASSERT(false, "Broadcast shape mismatch"); + } + return out; } - return newAxis; -} -Shape locate_index(size_t inputN, const Shape &shape) { - Shape ans(shape.size()); - auto i = ans.rbegin(); - auto j = shape.rbegin(), ej = shape.rend(); - while (j != ej) { - auto div = std::div(inputN, *j++); - *i++ = div.rem; - inputN = div.quot; + int get_real_axis(const int &axis, const int &rank) + { + IT_ASSERT(rank >= 1); + IT_ASSERT(axis >= -rank && axis <= (rank - 1)); + int newAxis; + if (axis < 0) + { + newAxis = rank + axis; + } + else + { + newAxis = axis; + } + return newAxis; } - return ans; -} -size_t delocate_index(const Shape &shapeIndex, const Shape &shape, - const Shape &stride) { - size_t ans = 0; - Shape index(shapeIndex.size()); - IT_ASSERT(shapeIndex.size() == shape.size()); - IT_ASSERT(shape.size() == stride.size()); - for (size_t i = 0; i < shape.size(); ++i) { - index[i] = shapeIndex[i] % shape[i]; - ans += index[i] * stride[i]; + Shape locate_index(size_t inputN, const Shape &shape) + { + Shape ans(shape.size()); + auto i = ans.rbegin(); + auto j = shape.rbegin(), ej = shape.rend(); + while (j != ej) + { + auto div = std::div(inputN, *j++); + *i++ = div.rem; + inputN = div.quot; + } + return ans; } - return ans; -} -std::string device_to_str(Device device) { - std::string deviceStr; - switch (device) { - case Device::CPU: - return "CPU"; - default: - IT_TODO_HALT(); + size_t delocate_index(const Shape &shapeIndex, const Shape &shape, + const Shape &stride) + { + size_t ans = 0; + Shape index(shapeIndex.size()); + IT_ASSERT(shapeIndex.size() == shape.size()); + IT_ASSERT(shape.size() == stride.size()); + for (size_t i = 0; i < shape.size(); ++i) + { + index[i] = shapeIndex[i] % shape[i]; + ans += index[i] * stride[i]; + } + return ans; } -} -std::string get_kernel_attrs_str(const KernelAttrs &kernelAttrs) { - std::string deviceStr = device_to_str(std::get<0>(kernelAttrs)); - std::string opStr = OpType(std::get<1>(kernelAttrs)).toString(); - return deviceStr + ", " + opStr; -} + std::string device_to_str(Device device) + { + std::string deviceStr; + switch (device) + { + case Device::CPU: + 
return "CPU"; + default: + IT_TODO_HALT(); + } + } + + std::string get_kernel_attrs_str(const KernelAttrs &kernelAttrs) + { + std::string deviceStr = device_to_str(std::get<0>(kernelAttrs)); + std::string opStr = OpType(std::get<1>(kernelAttrs)).toString(); + return deviceStr + ", " + opStr; + } } // namespace infini From d7423b5a2f44f86481f9fbdb25c8abd45b8c01b2 Mon Sep 17 00:00:00 2001 From: Xiaobanli-new <61371975+Xiaobanli-new@users.noreply.github.com> Date: Tue, 3 Feb 2026 21:51:08 +0800 Subject: [PATCH 08/17] =?UTF-8?q?=E4=B9=A0=E9=A2=98=E4=B8=83?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Implement shape inference for matmul operation --- src/operators/matmul.cc | 34 +++++++++++++++++++++++++++++++--- 1 file changed, 31 insertions(+), 3 deletions(-) diff --git a/src/operators/matmul.cc b/src/operators/matmul.cc index 7a16ca2..e15e037 100644 --- a/src/operators/matmul.cc +++ b/src/operators/matmul.cc @@ -1,5 +1,5 @@ #include "operators/matmul.h" - +#include "utils/operator_utils.h" namespace infini { @@ -27,7 +27,35 @@ namespace infini // TODO:返回经过 matmul 操作后的 shape // REF: https://github.com/onnx/onnx/blob/main/docs/Operators.md#gemm // =================================== 作业 =================================== - return std::nullopt; + IT_ASSERT(inputs.size() == 2); + const auto A = inputs[0]; + const auto B = inputs[1]; + + const auto &aDims = A->getDims(); + const auto &bDims = B->getDims(); + IT_ASSERT(aDims.size() >= 2); + IT_ASSERT(bDims.size() >= 2); + + // Batch dims: leading dims except last 2 + Shape aBatch(aDims.begin(), aDims.end() - 2); + Shape bBatch(bDims.begin(), bDims.end() - 2); + Shape outBatch = infer_broadcast(aBatch, bBatch); + + // Matrix dims (row-major): A(..., M, K) * B(..., K, N) = C(..., M, N) + const int aM = transA ? aDims[aDims.size() - 1] : aDims[aDims.size() - 2]; + const int aK = transA ? aDims[aDims.size() - 2] : aDims[aDims.size() - 1]; + const int bK = transB ? bDims[bDims.size() - 1] : bDims[bDims.size() - 2]; + const int bN = transB ? 
bDims[bDims.size() - 2] : bDims[bDims.size() - 1]; + IT_ASSERT(aK == bK, "Matmul K dimension mismatch"); + + m = aM; + n = bN; + k = aK; + + Shape out = outBatch; + out.push_back(m); + out.push_back(n); + return {{out}}; } -} // namespace infini \ No newline at end of file +} // namespace infini From f2114d771c8a8cd6f28033b98d11c9c66264ea06 Mon Sep 17 00:00:00 2001 From: Xiaobanli-new <61371975+Xiaobanli-new@users.noreply.github.com> Date: Tue, 3 Feb 2026 21:55:07 +0800 Subject: [PATCH 09/17] =?UTF-8?q?=E4=B9=A0=E9=A2=98=E5=85=AB?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Fix missing newline at end of file in graph.cc From 6041ff65b0165f28db9dde17826f6cfb3c911643 Mon Sep 17 00:00:00 2001 From: Xiaobanli-new <61371975+Xiaobanli-new@users.noreply.github.com> Date: Tue, 3 Feb 2026 22:08:29 +0800 Subject: [PATCH 10/17] =?UTF-8?q?=E4=B9=A0=E9=A2=98=E5=85=AB?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Add newline at the end of graph.cc file --- src/core/graph.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/core/graph.cc b/src/core/graph.cc index 1a1ae0d..bb97f3d 100644 --- a/src/core/graph.cc +++ b/src/core/graph.cc @@ -445,4 +445,4 @@ namespace infini return true; } -} // namespace infini +} From 5ceb5398ea8652dc080ac296653e9bb2ede37d71 Mon Sep 17 00:00:00 2001 From: Xiaobanli-new <61371975+Xiaobanli-new@users.noreply.github.com> Date: Tue, 3 Feb 2026 23:00:50 +0800 Subject: [PATCH 11/17] =?UTF-8?q?=E4=B9=A0=E9=A2=98=E4=B8=80?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- include/core/allocator.h | 4 ---- 1 file changed, 4 deletions(-) diff --git a/include/core/allocator.h b/include/core/allocator.h index ce7ef3e..bced4fe 100644 --- a/include/core/allocator.h +++ b/include/core/allocator.h @@ -24,10 +24,6 @@ namespace infini // pointer to the memory actually allocated void *ptr; - // =================================== 作业 =================================== - // TODO:可能需要设计一个数据结构来存储free block,以便于管理和合并 - // HINT: 可以使用一个 map 来存储 free block,key 为 block 的起始/结尾地址,value 为 block 的大小 - // =================================== 作业 =================================== std::map freeBlocks; void addFreeBlock(size_t addr, size_t size); std::map::iterator findFreeBlock(size_t size); From 339cc12f5179f8ff417885e8d95c0dd6cd0cd7db Mon Sep 17 00:00:00 2001 From: Xiaobanli-new <61371975+Xiaobanli-new@users.noreply.github.com> Date: Tue, 3 Feb 2026 23:02:10 +0800 Subject: [PATCH 12/17] Fix missing newline at end of graph.cc Add missing newline at end of file for graph.cc --- src/core/graph.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/core/graph.cc b/src/core/graph.cc index bb97f3d..1a1ae0d 100644 --- a/src/core/graph.cc +++ b/src/core/graph.cc @@ -445,4 +445,4 @@ namespace infini return true; } -} +} // namespace infini From b6b6339cb7043cb6128719b89e61ab0099709fb1 Mon Sep 17 00:00:00 2001 From: Xiaobanli-new <61371975+Xiaobanli-new@users.noreply.github.com> Date: Tue, 3 Feb 2026 23:02:46 +0800 Subject: [PATCH 13/17] Remove TODO comments in allocator.cc Removed TODO comments for memory allocation and deallocation. 
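The allocator work is split across patches 01, 02 and 13, so a single self-contained sketch of the same strategy may help when reviewing. This is an illustration under stated assumptions, not the repository's Allocator class: the name OffsetAllocator, the fixed 64-byte alignment and the plain assert() calls stand in for the runtime-provided alignment and IT_ASSERT. The free list is a std::map keyed by block offset; alloc() does a first-fit scan and splits the chosen block, free() merges the freed range with both neighbours, and a trailing loop shrinks the high-water mark peak whenever the topmost free block ends exactly at it.

#include <cassert>
#include <cstddef>
#include <cstdio>
#include <iterator>
#include <map>

// Simplified stand-in for the patch's Allocator: it manages offsets only,
// no real memory is touched.
class OffsetAllocator {
    std::map<size_t, size_t> freeBlocks; // offset -> size, ordered by offset
    size_t peak = 0;                     // high-water mark of the arena
    size_t used = 0;
    static constexpr size_t kAlign = 64; // illustrative alignment

    static size_t align(size_t n) { return (n + kAlign - 1) / kAlign * kAlign; }

  public:
    size_t alloc(size_t size) {
        size = align(size);
        // First fit: take the first free block that is large enough, split it.
        for (auto it = freeBlocks.begin(); it != freeBlocks.end(); ++it) {
            if (it->second >= size) {
                size_t addr = it->first, remain = it->second - size;
                freeBlocks.erase(it);
                if (remain > 0)
                    freeBlocks.emplace(addr + size, remain);
                used += size;
                return addr;
            }
        }
        // No hole fits: grow the arena at the top.
        size_t addr = peak;
        peak += size;
        used += size;
        return addr;
    }

    void free(size_t addr, size_t size) {
        size = align(size);
        assert(used >= size);
        used -= size;
        // Merge with the left neighbour if it ends exactly at addr.
        auto it = freeBlocks.lower_bound(addr);
        if (it != freeBlocks.begin()) {
            auto left = std::prev(it);
            if (left->first + left->second == addr) {
                addr = left->first;
                size += left->second;
                freeBlocks.erase(left);
            }
        }
        // Merge with the right neighbour if it starts exactly at addr + size.
        it = freeBlocks.lower_bound(addr);
        if (it != freeBlocks.end() && addr + size == it->first) {
            size += it->second;
            freeBlocks.erase(it);
        }
        freeBlocks.emplace(addr, size);
        // Shrink peak while the topmost free block touches it.
        while (!freeBlocks.empty()) {
            auto top = std::prev(freeBlocks.end());
            if (top->first + top->second != peak)
                break;
            peak = top->first;
            freeBlocks.erase(top);
        }
    }

    size_t peakBytes() const { return peak; }
};

int main() {
    OffsetAllocator a;
    size_t x = a.alloc(100), y = a.alloc(200);
    a.free(x, 100);
    size_t z = a.alloc(50); // reuses the hole left by x (first fit)
    std::printf("x=%zu y=%zu z=%zu peak=%zu\n", x, y, z, a.peakBytes());
    a.free(y, 200);
    a.free(z, 50);
    std::printf("peak after freeing everything: %zu\n", a.peakBytes()); // 0
    return 0;
}

Running the sketch shows z landing on x's old offset and peak dropping back to 0 once every block is freed, which is the behaviour the series is after: holes are reused before the arena grows, and space freed at the top is returned to peak.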
--- src/core/allocator.cc | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/src/core/allocator.cc b/src/core/allocator.cc index 0a695a2..95a81ee 100644 --- a/src/core/allocator.cc +++ b/src/core/allocator.cc @@ -28,7 +28,6 @@ namespace infini IT_ASSERT(this->ptr == nullptr); // pad the size to the multiple of alignment size = this->getAlignedSize(size); - // =================================== 作业 =================================== // TODO: 设计一个算法来分配内存,返回起始地址偏移量 // =================================== 作业 =================================== @@ -62,13 +61,11 @@ namespace infini IT_ASSERT(this->ptr == nullptr); size = getAlignedSize(size); - // =================================== 作业 =================================== - // TODO: 设计一个算法来回收内存 - // =================================== 作业 =================================== IT_ASSERT(used >= size); used -= size; addFreeBlock(addr, size); } + std::map::iterator Allocator::findFreeBlock(size_t size) { // first-fit: 找到第一个 size 足够的空闲块 From 26ec13c85752bc8ca4da96856a58b65145a18d17 Mon Sep 17 00:00:00 2001 From: Xiaobanli-new <61371975+Xiaobanli-new@users.noreply.github.com> Date: Tue, 3 Feb 2026 23:03:45 +0800 Subject: [PATCH 14/17] Fix shape inference in ConcatObj::inferShape Updated inferShape method to correctly compute the concatenated shape. --- src/operators/concat.cc | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/operators/concat.cc b/src/operators/concat.cc index 9c4e0e6..62909b0 100644 --- a/src/operators/concat.cc +++ b/src/operators/concat.cc @@ -14,12 +14,11 @@ namespace infini optional> ConcatObj::inferShape(const TensorVec &inputs) { Shape dims = inputs[0]->getDims(); - auto rank = inputs[0]->getRank(); - // =================================== 作业 =================================== // TODO:修改 dims,返回正确的 concat 后的 shape // REF: https://onnx.ai/onnx/operators/onnx__Concat.html#concat-13 // =================================== 作业 =================================== + IT_ASSERT(!inputs.empty()); const int rank = static_cast(inputs[0]->getRank()); IT_ASSERT(dim >= 0 && dim < rank); @@ -38,6 +37,7 @@ namespace infini sumDim += cur[dim]; } dims[dim] = sumDim; + return {{dims}}; } From bc26b8d8e037a8906d482f3fc689ed87ea6109ca Mon Sep 17 00:00:00 2001 From: Xiaobanli-new <61371975+Xiaobanli-new@users.noreply.github.com> Date: Tue, 3 Feb 2026 23:04:17 +0800 Subject: [PATCH 15/17] Fix formatting and comments in matmul.cc Added a missing newline at the end of the file and updated comments. 
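Patches 07, 08 and 15 together define how MatmulObj::inferShape combines transA/transB with broadcast batch dimensions; the standalone sketch below restates that rule so it can be checked in isolation. The helper names broadcastBatch and matmulShape, the int-based Shape alias and the exception-based error handling are assumptions for illustration, not the repository API. broadcastBatch applies the ONNX multidirectional-broadcasting rule to the leading dimensions; matmulShape then reads M and K from A's last two dimensions (swapped when transA is set), K and N from B's (swapped when transB is set), checks that the two K values agree, and appends M and N to the broadcast batch shape.

#include <algorithm>
#include <cstdio>
#include <stdexcept>
#include <vector>

using Shape = std::vector<int>;

// ONNX-style multidirectional broadcasting of two batch shapes:
// align right, pad with 1, and require equal extents or 1 on every axis.
Shape broadcastBatch(const Shape &a, const Shape &b) {
    size_t rank = std::max(a.size(), b.size());
    Shape out(rank, 1);
    for (size_t i = 0; i < rank; ++i) {
        int da = i < rank - a.size() ? 1 : a[i - (rank - a.size())];
        int db = i < rank - b.size() ? 1 : b[i - (rank - b.size())];
        if (da != db && da != 1 && db != 1)
            throw std::runtime_error("batch shapes are not broadcastable");
        out[i] = std::max(da, db);
    }
    return out;
}

// Output shape of A x B, row-major, where the last two dims are the matrix
// dims and transA/transB swap them, as in the Matmul operator of the series.
Shape matmulShape(const Shape &a, const Shape &b, bool transA, bool transB) {
    if (a.size() < 2 || b.size() < 2)
        throw std::runtime_error("matmul needs rank >= 2");
    Shape out = broadcastBatch(Shape(a.begin(), a.end() - 2),
                               Shape(b.begin(), b.end() - 2));
    int m = transA ? a.back() : a[a.size() - 2];
    int ka = transA ? a[a.size() - 2] : a.back();
    int kb = transB ? b.back() : b[b.size() - 2];
    int n = transB ? b[b.size() - 2] : b.back();
    if (ka != kb)
        throw std::runtime_error("contraction dimensions do not match");
    out.push_back(m);
    out.push_back(n);
    return out;
}

int main() {
    // A: (2, 1, 3, 4), B: (5, 4, 6) -> batch (2, 5), matrix (3, 6)
    Shape c = matmulShape({2, 1, 3, 4}, {5, 4, 6}, false, false);
    for (int d : c)
        std::printf("%d ", d); // prints: 2 5 3 6
    std::printf("\n");
    return 0;
}

For A of shape (2, 1, 3, 4) and B of shape (5, 4, 6) this prints 2 5 3 6, which is the same answer the infer_broadcast helper from patch 07 plus the matmul code from patch 08 should produce.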
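Patch 14 above fixes ConcatObj::inferShape; the rule itself is small enough to restate as a hedged standalone sketch (concatShape and the int-based Shape alias are illustrative names, not the repository's types): all inputs must share the same rank and the same extent on every axis except the concatenation axis, and the output sums the extents along that axis.

#include <cstdio>
#include <stdexcept>
#include <vector>

using Shape = std::vector<int>;

// Concat output shape: identical on every axis except `axis`,
// whose extent is the sum over the inputs.
Shape concatShape(const std::vector<Shape> &inputs, int axis) {
    if (inputs.empty())
        throw std::runtime_error("concat needs at least one input");
    Shape out = inputs[0];
    if (axis < 0 || axis >= static_cast<int>(out.size()))
        throw std::runtime_error("axis out of range");
    for (size_t i = 1; i < inputs.size(); ++i) {
        if (inputs[i].size() != out.size())
            throw std::runtime_error("rank mismatch");
        for (size_t d = 0; d < out.size(); ++d) {
            if (static_cast<int>(d) == axis)
                out[d] += inputs[i][d];
            else if (inputs[i][d] != out[d])
                throw std::runtime_error("mismatch off the concat axis");
        }
    }
    return out;
}

int main() {
    Shape s = concatShape({{2, 3, 4}, {2, 5, 4}}, 1);
    std::printf("%d %d %d\n", s[0], s[1], s[2]); // prints: 2 8 4
    return 0;
}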
--- src/operators/matmul.cc | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/operators/matmul.cc b/src/operators/matmul.cc index e15e037..eca81ab 100644 --- a/src/operators/matmul.cc +++ b/src/operators/matmul.cc @@ -1,5 +1,6 @@ #include "operators/matmul.h" #include "utils/operator_utils.h" + namespace infini { @@ -27,6 +28,7 @@ namespace infini // TODO:返回经过 matmul 操作后的 shape // REF: https://github.com/onnx/onnx/blob/main/docs/Operators.md#gemm // =================================== 作业 =================================== + IT_ASSERT(inputs.size() == 2); const auto A = inputs[0]; const auto B = inputs[1]; From a72ca8919a75b0c979769924e57dc91e6acaeaaf Mon Sep 17 00:00:00 2001 From: Xiaobanli-new <61371975+Xiaobanli-new@users.noreply.github.com> Date: Tue, 3 Feb 2026 23:04:43 +0800 Subject: [PATCH 16/17] Modify transposePermute to reverse order of indices Update transpose function to reverse dimensions --- src/operators/transpose.cc | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/src/operators/transpose.cc b/src/operators/transpose.cc index 8f367cd..343dd5d 100644 --- a/src/operators/transpose.cc +++ b/src/operators/transpose.cc @@ -9,10 +9,9 @@ namespace infini auto rank = input->getRank(); if (permute.empty()) { + transposePermute.resize(rank); for (size_t i = 0; i < rank; ++i) - { - transposePermute[i] = i; - } + transposePermute[i] = static_cast(rank - 1 - i); } else { @@ -26,7 +25,6 @@ namespace infini { const auto A = inputs[0]; auto input_dim = A->getDims(); - // =================================== 作业 =================================== // TODO:修改 output_dim,返回正确的 transpose 后的 shape // REF: https://onnx.ai/onnx/operators/onnx__Transpose.html#transpose-21 From 40272e2ded17867d33eec40c7cbe4c35d332a44c Mon Sep 17 00:00:00 2001 From: Xiaobanli-new <61371975+Xiaobanli-new@users.noreply.github.com> Date: Tue, 3 Feb 2026 23:05:13 +0800 Subject: [PATCH 17/17] Refactor inferDataType and inferShape methods Removed TODO comments and added clarification about shape. --- src/operators/unary.cc | 10 +--------- 1 file changed, 1 insertion(+), 9 deletions(-) diff --git a/src/operators/unary.cc b/src/operators/unary.cc index a957ae8..df2544b 100644 --- a/src/operators/unary.cc +++ b/src/operators/unary.cc @@ -62,21 +62,13 @@ namespace infini vector CastObj::inferDataType(const TensorVec &inputs) const { - // =================================== 作业 =================================== - // TODO:返回经过 cast 操作后, 输出 tensor 的数目和数据类型 - // REF_FILE: src/core/operator.cc - // REF: https://onnx.ai/onnx/operators/onnx__Cast.html#cast-21 - // =================================== 作业 =================================== (void)inputs; return {getOutputDataType()}; } optional> CastObj::inferShape(const TensorVec &inputs) { - // =================================== 作业 =================================== - // TODO:返回经过 cast 操作后的 shape - // REF: https://onnx.ai/onnx/operators/onnx__Cast.html#cast-21 - // =================================== 作业 =================================== + // Cast 不改变形状 const auto A = inputs[0]; return {{A->getDims()}}; }
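Two permutation facts drive patches 03, 04 and 16: ONNX Transpose reverses the dimensions when perm is omitted, and the graph optimization only removes a transpose pair when the two permutations are inverses of each other, or folds a transpose into matmul when it swaps exactly the last two axes. The helpers below sketch those checks under the assumption that their inputs are valid permutations of 0..rank-1; defaultPerm, applyPerm, isInversePerm and swapsLastTwoOnly are made-up names for illustration, not functions from the repository.

#include <cstdio>
#include <vector>

using Perm = std::vector<int>;
using Shape = std::vector<int>;

// ONNX Transpose: when `perm` is omitted, the dimensions are reversed.
Perm defaultPerm(int rank) {
    Perm p(rank);
    for (int i = 0; i < rank; ++i)
        p[i] = rank - 1 - i;
    return p;
}

// output_dim[i] = input_dim[perm[i]]
Shape applyPerm(const Shape &dims, const Perm &perm) {
    Shape out(perm.size());
    for (size_t i = 0; i < perm.size(); ++i)
        out[i] = dims[perm[i]];
    return out;
}

// True when applying p2 after p1 restores the original order,
// i.e. the two transposes cancel and both nodes can be removed.
bool isInversePerm(const Perm &p1, const Perm &p2) {
    if (p1.size() != p2.size())
        return false;
    for (size_t i = 0; i < p1.size(); ++i)
        if (p2[p1[i]] != static_cast<int>(i))
            return false;
    return true;
}

// True when the permutation only swaps the last two axes, which is the
// case that can be folded into matmul's transA / transB attribute.
bool swapsLastTwoOnly(const Perm &p) {
    int r = static_cast<int>(p.size());
    if (r < 2)
        return false;
    for (int i = 0; i < r - 2; ++i)
        if (p[i] != i)
            return false;
    return p[r - 2] == r - 1 && p[r - 1] == r - 2;
}

int main() {
    Shape in = {2, 3, 4, 5};
    Perm p = defaultPerm(4);      // {3, 2, 1, 0}
    Shape out = applyPerm(in, p); // {5, 4, 3, 2}
    std::printf("reversed shape: %d %d %d %d\n", out[0], out[1], out[2], out[3]);
    std::printf("default perm is its own inverse: %s\n",
                isInversePerm(p, p) ? "yes" : "no");
    std::printf("{0,1,3,2} swaps only the last two axes: %s\n",
                swapsLastTwoOnly({0, 1, 3, 2}) ? "yes" : "no");
    return 0;
}

isInversePerm checks p2[p1[i]] == i, which is exactly the condition under which transposing with p1 and then with p2 restores the original layout, so both nodes can be dropped and their consumers rewired to the first transpose's input.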
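The dataMalloc change in patch 03 plans memory by counting, for every tensor, how many operators still need to read it: outputs are allocated when their producer is reached in topological order, and each input is released right after its last consumer runs, while tensors with no consumer stay alive. The sketch below isolates only that liveness bookkeeping with made-up names (Node, peakLiveBytes and string tensor ids are assumptions, not the repository's types); it deliberately leaves out the offset allocator and just reports the peak number of simultaneously live bytes, which is the quantity the real planner tries to keep small.

#include <algorithm>
#include <cstdio>
#include <map>
#include <set>
#include <string>
#include <vector>

// One graph node: the tensors it reads and the tensors it writes.
struct Node {
    std::vector<std::string> inputs;
    std::vector<std::string> outputs;
};

// Peak of the "sum of live tensor sizes" when every tensor is released as
// soon as its last consumer has run; tensors that nothing reads (e.g. graph
// outputs) are never released.  Graph inputs are materialized up front.
size_t peakLiveBytes(const std::vector<Node> &topoOrder,
                     const std::map<std::string, size_t> &bytes,
                     const std::set<std::string> &graphInputs) {
    std::map<std::string, int> remainingUses;
    for (const auto &n : topoOrder)
        for (const auto &t : n.inputs)
            ++remainingUses[t];

    std::set<std::string> live;
    size_t liveBytes = 0, peak = 0;
    auto allocate = [&](const std::string &t) {
        if (live.insert(t).second)
            liveBytes += bytes.at(t);
    };
    auto release = [&](const std::string &t) {
        if (live.erase(t))
            liveBytes -= bytes.at(t);
    };

    for (const auto &t : graphInputs)
        allocate(t);
    peak = liveBytes;

    for (const auto &n : topoOrder) {
        for (const auto &t : n.outputs)
            allocate(t);
        peak = std::max(peak, liveBytes);
        for (const auto &t : n.inputs)
            if (--remainingUses[t] == 0) // this node was the last consumer
                release(t);
    }
    return peak;
}

int main() {
    // x --relu--> h --matmul(w)--> y   (x, w are graph inputs; y is the output)
    std::vector<Node> g = {{{"x"}, {"h"}}, {{"h", "w"}, {"y"}}};
    std::map<std::string, size_t> bytes = {
        {"x", 400}, {"h", 400}, {"w", 800}, {"y", 200}};
    std::printf("peak live bytes: %zu\n", peakLiveBytes(g, bytes, {"x", "w"}));
    return 0;
}

In this two-node example x is released after the relu, its only consumer, so the peak is 1600 bytes rather than the 1800-byte sum of all four tensors, which is the same saving the allocator.free calls inside dataMalloc are meant to achieve.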