From 552b17fb269b209b3f5485f31e205b863fbd3304 Mon Sep 17 00:00:00 2001
From: Your Name
Date: Mon, 2 Feb 2026 18:54:45 +0800
Subject: [PATCH 1/3] 1

---
 .../TinyInfiniTensor/src/core/graph.cc        | 151 ++++++++++
 include/core/allocator.h                      |   8 +-
 src/core/allocator.cc                         |  93 +++++-
 src/core/graph.cc                             | 272 ++++++++++++++++--
 src/operators/concat.cc                       |  13 +-
 src/operators/matmul.cc                       |  76 ++++-
 src/operators/transpose.cc                    |   8 +-
 src/operators/unary.cc                        |  10 +-
 src/utils/operator_utils.cc                   |  30 +-
 9 files changed, 625 insertions(+), 36 deletions(-)
 create mode 100644 home/ros/lectures/ai-complier/TinyInfiniTensor/src/core/graph.cc

diff --git a/home/ros/lectures/ai-complier/TinyInfiniTensor/src/core/graph.cc b/home/ros/lectures/ai-complier/TinyInfiniTensor/src/core/graph.cc
new file mode 100644
index 0000000..d5b991a
--- /dev/null
+++ b/home/ros/lectures/ai-complier/TinyInfiniTensor/src/core/graph.cc
@@ -0,0 +1,151 @@
+void GraphObj::optimize()
+{
+    // Based on the test case, we need to implement optimizations to merge/remove
+    // transpose operations and combine them with matmul operations
+
+    bool changed = true;
+    while (changed) {
+        changed = false;
+
+        // Look for consecutive transpose operations that cancel each other out
+        for (size_t i = 0; i < ops.size(); i++) {
+            auto op1 = ops[i];
+            if (op1->getOpType() != OpType::Transpose) continue;
+
+            auto transpose1 = dynamic_pointer_cast<TransposeObj>(op1);
+            if (transpose1->getOutputs().empty()) continue;
+
+            auto mid_tensor = transpose1->getOutputs()[0];
+            auto consumers = mid_tensor->getTargets();
+
+            // If this tensor only feeds another transpose operation
+            if (consumers.size() == 1 && consumers[0]->getOpType() == OpType::Transpose) {
+                auto transpose2 = dynamic_pointer_cast<TransposeObj>(consumers[0]);
+
+                // Check whether the transposes cancel each other out (are inverses)
+                auto perm1 = transpose1->getPermute();
+                auto perm2 = transpose2->getPermute();
+
+                bool is_inverse = true;
+                if (perm1.size() != perm2.size()) {
+                    is_inverse = false;
+                } else {
+                    for (size_t idx = 0; idx < perm1.size(); idx++) {
+                        if (perm2[perm1[idx]] != (int)idx) {
+                            is_inverse = false;
+                            break;
+                        }
+                    }
+                }
+
+                if (is_inverse) {
+                    // Remove both transposes by connecting the input directly to the final consumers
+
+                    // Get the original input tensor and the final output tensor
+                    auto input_tensor = transpose1->getInputs()[0];
+                    auto output_tensor = transpose2->getOutputs()[0];
+
+                    // Update all operators that consume the output of transpose2
+                    // to instead consume the input of transpose1
+                    auto final_consumers = output_tensor->getTargets();
+                    for (auto consumer : final_consumers) {
+                        for (size_t j = 0; j < consumer->getInputs().size(); j++) {
+                            if (consumer->getInputs(j) == output_tensor) {
+                                consumer->replaceInput(output_tensor, input_tensor);
+                            }
+                        }
+                    }
+
+                    // Update tensor graph connections
+                    input_tensor->removeTarget(op1);
+                    mid_tensor->removeTarget(transpose2);
+                    mid_tensor->removeSource();
+                    output_tensor->removeSource();
+
+                    // Remove the two transpose operations
+                    ops.erase(std::remove(ops.begin(), ops.end(), op1), ops.end());
+                    ops.erase(std::remove(ops.begin(), ops.end(), transpose2), ops.end());
+
+                    // Remove the intermediate tensors
+                    tensors.erase(std::remove(tensors.begin(), tensors.end(), mid_tensor), tensors.end());
+                    tensors.erase(std::remove(tensors.begin(), tensors.end(), output_tensor), tensors.end());
+
+                    changed = true;
+                    break;
+                }
+            }
+        }
+
+        if (changed) continue;
+
+        // Look for a transpose followed by a matmul that can be fused
+        for (size_t i = 0; i < ops.size(); i++) {
+            auto op = ops[i];
+            if (op->getOpType() != OpType::Matmul) continue;
+
+            auto matmul = dynamic_pointer_cast<MatmulObj>(op);
+
+            // Check the first input for transpose fusion
+            if (!matmul->getInputs(0)->getSource()) continue;
+            auto prev_op = matmul->getInputs(0)->getSource();
+            if (prev_op->getOpType() == OpType::Transpose) {
+                auto transpose = dynamic_pointer_cast<TransposeObj>(prev_op);
+                // Check whether the transpose swaps the last 2 dimensions (the common case for matmul)
+                auto perm = transpose->getPermute();
+                if (perm.size() >= 2 &&
+                   perm[perm.size()-2] == (int)perm.size()-1 &&
+                   perm[perm.size()-1] == (int)perm.size()-2) {
+                    // Fuse: set the transA flag and update the matmul
+                    matmul->setTransA(true);
+
+                    // Update the matmul to take input from the transpose's input
+                    auto trans_input = transpose->getInputs()[0];
+                    matmul->replaceInput(matmul->getInputs(0), trans_input);
+
+                    // Update tensor connections
+                    matmul->getInputs(0)->removeTarget(op);
+                    trans_input->addTarget(op);
+
+                    // Remove the transpose and the intermediate tensor
+                    ops.erase(std::remove(ops.begin(), ops.end(), prev_op), ops.end());
+                    tensors.erase(std::remove(tensors.begin(), tensors.end(), matmul->getInputs(0)), tensors.end());
+
+                    changed = true;
+                    break;
+                }
+            }
+
+            if (changed) continue;
+
+            // Check the second input for transpose fusion
+            if (!matmul->getInputs(1)->getSource()) continue;
+            prev_op = matmul->getInputs(1)->getSource();
+            if (prev_op->getOpType() == OpType::Transpose) {
+                auto transpose = dynamic_pointer_cast<TransposeObj>(prev_op);
+                // Check whether the transpose swaps the last 2 dimensions (the common case for matmul)
+                auto perm = transpose->getPermute();
+                if (perm.size() >= 2 &&
+                   perm[perm.size()-2] == (int)perm.size()-1 &&
+                   perm[perm.size()-1] == (int)perm.size()-2) {
+                    // Fuse: set the transB flag and update the matmul
+                    matmul->setTransB(true);
+
+                    // Update the matmul to take input from the transpose's input
+                    auto trans_input = transpose->getInputs()[0];
+                    matmul->replaceInput(matmul->getInputs(1), trans_input);
+
+                    // Update tensor connections
+                    matmul->getInputs(1)->removeTarget(op);
+                    trans_input->addTarget(op);
+
+                    // Remove the transpose and the intermediate tensor
+                    ops.erase(std::remove(ops.begin(), ops.end(), prev_op), ops.end());
+                    tensors.erase(std::remove(tensors.begin(), tensors.end(), matmul->getInputs(1)), tensors.end());
+
+                    changed = true;
+                    break;
+                }
+            }
+        }
+    }
+}
\ No newline at end of file
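Note: this accidentally committed scratch file (removed again in PATCH 2/3) and the final optimize() in src/core/graph.cc below both rest on the same test: a second transpose undoes the first exactly when perm2[perm1[i]] == i for every axis i. A minimal standalone sketch of that check, with illustrative names that are not part of the applied diff:

    #include <cassert>
    #include <cstddef>
    #include <vector>

    // True iff applying perm1 and then perm2 restores the original axis order,
    // i.e. perm2 is the inverse permutation of perm1.
    bool isInversePair(const std::vector<int> &perm1, const std::vector<int> &perm2) {
        if (perm1.size() != perm2.size())
            return false;
        for (std::size_t i = 0; i < perm1.size(); ++i)
            if (perm2[perm1[i]] != (int)i)
                return false;
        return true;
    }

    int main() {
        assert(isInversePair({2, 0, 1}, {1, 2, 0}));  // rotate axes, then rotate back
        assert(!isInversePair({2, 0, 1}, {2, 0, 1})); // the same rotation twice is not identity
        assert(isInversePair({0, 2, 1}, {0, 2, 1}));  // swapping the last two axes is its own inverse
        return 0;
    }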
diff --git a/include/core/allocator.h b/include/core/allocator.h
index 002601d..5ac5240 100644
--- a/include/core/allocator.h
+++ b/include/core/allocator.h
@@ -27,7 +27,11 @@ namespace infini {
     // =================================== Assignment ===================================
     // TODO: a data structure may be needed to store the free blocks, to simplify management and merging
    // HINT: a map can be used to store the free blocks, with the block start/end address as key and the block size as value
    // =================================== Assignment ===================================
-
+
+    // Use a map to store the free blocks: key = start offset, value = block size.
+    // The map keeps blocks sorted by address, which makes lookup and merging of adjacent blocks easy.
+    std::map<size_t, size_t> free_blocks;
+
 public:
     Allocator(Runtime runtime);

@@ -56,4 +60,4 @@ namespace infini {
     // return: size of the aligned memory block
     size_t getAlignedSize(size_t size);
   };
-}
+}
\ No newline at end of file
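Note: the declaration above is the whole data structure, an address-ordered map from start offset to block size. First fit then just walks the map in address order. A standalone sketch of the idea (simplified: no alignment and no grow-the-last-block case; names are illustrative, not part of the applied diff):

    #include <cstddef>
    #include <cstdio>
    #include <map>

    std::map<std::size_t, std::size_t> freeBlocks; // start offset -> block size
    std::size_t heapEnd = 0;                       // plays the role of Allocator::peak

    std::size_t allocOffset(std::size_t size) {
        // First fit: take the first (lowest-address) block that is large enough.
        for (auto it = freeBlocks.begin(); it != freeBlocks.end(); ++it) {
            std::size_t addr = it->first, cap = it->second;
            if (cap < size)
                continue;
            freeBlocks.erase(it);
            if (cap > size)
                freeBlocks[addr + size] = cap - size; // return the unused tail to the free list
            return addr;
        }
        std::size_t addr = heapEnd; // nothing fits: extend the arena
        heapEnd += size;
        return addr;
    }

    int main() {
        std::size_t a = allocOffset(64), b = allocOffset(32);
        std::printf("a=%zu b=%zu peak=%zu\n", a, b, heapEnd); // a=0 b=64 peak=96
        return 0;
    }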
diff --git a/src/core/allocator.cc b/src/core/allocator.cc
index ff593ae..7a2b0ec 100644
--- a/src/core/allocator.cc
+++ b/src/core/allocator.cc
@@ -32,18 +32,103 @@ namespace infini
        // =================================== Assignment ===================================
        // TODO: design an algorithm to allocate memory and return the start offset
        // =================================== Assignment ===================================
-
-        return 0;
+
+        // First-fit strategy: walk all free blocks and take the first one that is large enough
+        for (auto it = free_blocks.begin(); it != free_blocks.end(); ++it)
+        {
+            size_t block_addr = it->first;
+            size_t block_size = it->second;
+
+            // Case 1: the free block is large enough, use it directly
+            if (block_size >= size)
+            {
+                // Remove the block from the free list
+                free_blocks.erase(it);
+
+                // If there is space left over, put the remainder back on the free list
+                if (block_size > size)
+                {
+                    free_blocks[block_addr + size] = block_size - size;
+                }
+
+                // Update used memory
+                this->used += size;
+
+                return block_addr;
+            }
+            // Case 2: the block sits at the end of memory (block_addr + block_size == peak),
+            // so it can be used and extended even though it is too small on its own
+            else if (block_addr + block_size == this->peak)
+            {
+                // Remove the block from the free list
+                free_blocks.erase(it);
+
+                // How much extra space is still needed
+                size_t extra_needed = size - block_size;
+
+                // Extend the peak
+                this->peak += extra_needed;
+
+                // Update used memory
+                this->used += size;
+
+                return block_addr;
+            }
+        }
+
+        // No suitable free block found: allocate fresh space at the end of memory
+        size_t block_addr = this->peak;
+        this->used += size;
+        this->peak += size;
+
+        return block_addr;
    }

    void Allocator::free(size_t addr, size_t size)
    {
        IT_ASSERT(this->ptr == nullptr);
        size = getAlignedSize(size);
-
+
        // =================================== Assignment ===================================
        // TODO: design an algorithm to reclaim memory
        // =================================== Assignment ===================================
+
+        // Update used memory
+        this->used -= size;
+
+        // Put the released block on the free list
+        free_blocks[addr] = size;
+
+        // Get an iterator to the block just released
+        auto current = free_blocks.find(addr);
+
+        // Try to merge with the following block
+        auto next = std::next(current);
+        if (next != free_blocks.end())
+        {
+            // The current block ends exactly where the next block starts
+            if (current->first + current->second == next->first)
+            {
+                // Merge: grow the current block
+                current->second += next->second;
+                // Drop the next block
+                free_blocks.erase(next);
+            }
+        }
+
+        // Try to merge with the preceding block
+        if (current != free_blocks.begin())
+        {
+            auto prev = std::prev(current);
+            // The previous block ends exactly where the current block starts
+            if (prev->first + prev->second == current->first)
+            {
+                // Merge: grow the previous block
+                prev->second += current->second;
+                // Drop the current block
+                free_blocks.erase(current);
+            }
+        }
    }

    void *Allocator::getPtr()
@@ -66,4 +151,4 @@ namespace infini
        std::cout << "Used memory: " << this->used
                  << ", peak memory: " << this->peak << std::endl;
    }
-}
+}
\ No newline at end of file
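Note: the free path first inserts the block and then merges with its successor and predecessor whenever the address ranges touch, so the free list never holds two adjacent entries. Continuing the standalone sketch above (again illustrative; the real free() also maintains `used` and asserts the arena pointer is still unbound):

    #include <cstddef>
    #include <cstdio>
    #include <iterator>
    #include <map>

    std::map<std::size_t, std::size_t> freeBlocks; // start offset -> block size

    void freeOffset(std::size_t addr, std::size_t size) {
        auto cur = freeBlocks.emplace(addr, size).first;
        // Merge with the following block if this block ends where it starts.
        auto next = std::next(cur);
        if (next != freeBlocks.end() && cur->first + cur->second == next->first) {
            cur->second += next->second;
            freeBlocks.erase(next);
        }
        // Merge with the preceding block if that block ends where this one starts.
        if (cur != freeBlocks.begin()) {
            auto prev = std::prev(cur);
            if (prev->first + prev->second == cur->first) {
                prev->second += cur->second;
                freeBlocks.erase(cur);
            }
        }
    }

    int main() {
        freeOffset(0, 64);  // free list: [0,64)
        freeOffset(96, 32); // free list: [0,64) [96,128)
        freeOffset(64, 32); // bridges the gap: a single block [0,128)
        for (auto &kv : freeBlocks)
            std::printf("block at %zu, size %zu\n", kv.first, kv.second); // block at 0, size 128
        return 0;
    }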
diff --git a/src/core/graph.cc b/src/core/graph.cc
index 3a90637..71bbf1a 100644
--- a/src/core/graph.cc
+++ b/src/core/graph.cc
@@ -2,6 +2,8 @@
 #include <algorithm>
 #include <numeric>
 #include <queue>
+#include "operators/matmul.h"
+#include "operators/transpose.h"

 namespace infini
 {
@@ -98,15 +100,213 @@ namespace infini
        return this->sorted = true;
    }

-    void GraphObj::optimize()
+void GraphObj::optimize()
+{
+    // =================================== Assignment ===================================
+    // TODO: design an algorithm that implements the graph-optimization rules below:
+    // 1. Remove redundant operators (e.g., two adjacent transpose operators that perform
+    //    inverse permutations can both be deleted)
+    // 2. Fuse operators (e.g., MatMul carries the attributes transA and transB; if one of
+    //    its inputs is produced by a transpose that swaps the last two dimensions, the
+    //    transpose can be folded into the MatMul's attributes)
+    // =================================== Assignment ===================================
+
+    bool changed = true;
+    while (changed)
    {
-        // =================================== Assignment ===================================
-        // TODO: design an algorithm that implements the graph-optimization rules below:
-        // 1. Remove redundant operators (e.g., two adjacent transpose operators that perform
-        //    inverse permutations can both be deleted)
-        // 2. Fuse operators (e.g., MatMul carries the attributes transA and transB; if one of
-        //    its inputs is produced by a transpose that swaps the last two dimensions, the
-        //    transpose can be folded into the MatMul's attributes)
-        // =================================== Assignment ===================================
+        changed = false;
+
+        // === Rule 1: remove redundant transposes (inverse pairs cancel) ===
+        for (auto it = ops.begin(); it != ops.end(); ++it)
+        {
+            auto op = *it;
+
+            if (op->getOpType() == OpType::Transpose)
+            {
+                auto trans2 = std::static_pointer_cast<TransposeObj>(op);
+                auto inputTensor = trans2->getInputs(0);
+                auto srcOp = inputTensor->getSource();
+
+                if (srcOp && srcOp->getOpType() == OpType::Transpose)
+                {
+                    auto trans1 = std::static_pointer_cast<TransposeObj>(srcOp);
+                    auto perm1 = trans1->getPermute();
+                    auto perm2 = trans2->getPermute();
+
+                    // The two permutations are inverses iff perm2[perm1[i]] == i
+                    bool isInverse = true;
+                    if (perm1.size() != perm2.size())
+                    {
+                        isInverse = false;
+                    }
+                    else
+                    {
+                        for (size_t i = 0; i < perm1.size(); ++i)
+                        {
+                            if (perm2[perm1[i]] != (int)i)
+                            {
+                                isInverse = false;
+                                break;
+                            }
+                        }
+                    }
+
+                    if (isInverse)
+                    {
+                        // Grab the original input and the final output
+                        auto originalInput = trans1->getInputs(0);
+                        auto outputTensor = trans2->getOutput();
+
+                        // Save all successors of the output
+                        auto targets = outputTensor->getTargets();
+
+                        // Rewire each successor's input from outputTensor to originalInput
+                        for (auto &succ : targets)
+                        {
+                            succ->replaceInput(outputTensor, originalInput);
+                            // Update the tensor's target links
+                            originalInput->addTarget(succ);
+
+                            // Update operator links
+                            succ->removePredecessors(trans2);
+                            trans2->removeSuccessors(succ);
+                        }
+
+                        // Detach trans2 from its input
+                        inputTensor->removeTarget(trans2);
+
+                        // Detach trans1
+                        originalInput->removeTarget(trans1);
+                        inputTensor->setSource(nullptr);
+
+                        // Remove both operators from the graph
+                        removeOperator(trans2);
+                        removeOperator(trans1);
+
+                        // Remove the intermediate tensors
+                        removeTensor(inputTensor);
+                        removeTensor(outputTensor);
+
+                        changed = true;
+
+                        // Iterators may now be invalid, so restart the scan
+                        break;
+                    }
+                }
+            }
+        }
+
+        if (changed) continue; // something changed: start the next round
+
+        // === Rule 2: fold transposes into MatMul ===
+        for (auto it = ops.begin(); it != ops.end(); ++it)
+        {
+            auto op = *it;
+
+            if (op->getOpType() == OpType::MatMul)
+            {
+                auto matmul = std::static_pointer_cast<MatmulObj>(op);
+                bool modifiedMatmul = false;
+
+                for (int i = 0; i < 2; ++i)
+                {
+                    auto input = matmul->getInputs(i);
+                    auto srcOp = input->getSource();
+
+                    if (srcOp && srcOp->getOpType() == OpType::Transpose)
+                    {
+                        auto trans = std::static_pointer_cast<TransposeObj>(srcOp);
+                        auto perm = trans->getPermute();
+                        int rank = perm.size();
+
+                        // Is this a transpose that only swaps the last two dimensions?
+                        if (rank >= 2)
+                        {
+                            bool isLastTwoSwap = true;
+                            for (int j = 0; j < rank - 2; ++j)
+                            {
+                                if (perm[j] != j)
+                                {
+                                    isLastTwoSwap = false;
+                                    break;
+                                }
+                            }
+                            if (isLastTwoSwap && perm[rank-2] == rank-1 && perm[rank-1] == rank-2)
+                            {
+                                // Toggle the corresponding MatMul attribute
+                                if (i == 0)
+                                {
+                                    matmul->setTransA(!matmul->getTransA());
+                                }
+                                else
+                                {
+                                    matmul->setTransB(!matmul->getTransB());
+                                }
+
+                                // Connect the MatMul directly to the transpose's input
+                                auto realInput = trans->getInputs(0);
+                                matmul->replaceInput(input, realInput);
+
+                                // Update tensor links
+                                realInput->addTarget(matmul);
+                                input->removeTarget(matmul);
+
+                                // Update operator links
+                                matmul->removePredecessors(trans);
+                                trans->removeSuccessors(matmul);
+
+                                // Detach the transpose from the rest of the graph
+                                realInput->removeTarget(trans);
+                                input->setSource(nullptr);
+
+                                // Remove the transpose operator and its output tensor
+                                removeOperator(trans);
+                                removeTensor(input);
+
+                                modifiedMatmul = true;
+                                changed = true;
+
+                                // Leave the inner loop
+                                break;
+                            }
+                        }
+                    }
+                }
+
+                // If the MatMul was modified, its output shape may need recomputing
+                if (modifiedMatmul)
+                {
+                    // Recompute the shape
+                    auto newShapes = matmul->inferShape(matmul->getInputs());
+                    if (newShapes)
+                    {
+                        auto outputs = matmul->getOutputs();
+                        for (size_t i = 0; i < newShapes->size() && i < outputs.size(); ++i)
+                        {
+                            outputs[i]->setShape((*newShapes)[i]);
+                        }
+                    }
+
+                    // Restart the scan
+                    break;
+                }
+            }
+        }
+
+        // If anything changed in this round, the topological order is stale
+        if (changed)
+        {
+            sorted = false;
+        }
    }
+
+    // Re-run topological sorting
+    if (!sorted)
+    {
+        topo_sort();
+    }
+}

    Tensor GraphObj::getTensor(int fuid) const
    {
@@ -142,18 +342,58 @@ namespace infini
            }
        }
    }
-
-    void GraphObj::dataMalloc()
+
+    void GraphObj::dataMalloc()
    {
-        // topological sorting first
        IT_ASSERT(topo_sort() == true);
-
-        // =================================== Assignment ===================================
-        // TODO: use the allocator to allocate memory for the computation graph
-        // HINT: after obtaining the allocated pointer, call the tensor's setDataBlob
-        //       function to bind the memory to the tensor
-        // =================================== Assignment ===================================
-
-        allocator.info();
+
+        std::unordered_map<int, size_t> tensor_offsets;
+        // Reference counts: how many operators still need to read each tensor
+        std::unordered_map<int, int> ref_counts;
+        for (auto &t : tensors) {
+            ref_counts[t->getFuid()] = t->getTargets().size();
+        }
+
+        // 1. Allocate memory for the graph inputs (tensors without a source)
+        for (auto &t : tensors) {
+            if (!t->getSource()) {
+                tensor_offsets[t->getFuid()] = allocator.alloc(t->getBytes());
+            }
+        }
+
+        // 2. Allocate intermediate results in topological order
+        auto graph_outputs = this->getOutputs();
+        for (auto &op : ops) {
+            // First allocate space for the outputs
+            for (auto &out : op->getOutputs()) {
+                tensor_offsets[out->getFuid()] = allocator.alloc(out->getBytes());
+            }
+
+            // After the inputs have been consumed, decrement their reference
+            // counts and try to release them
+            for (auto &in : op->getInputs()) {
+                int fuid = in->getFuid();
+                ref_counts[fuid]--;
+                // If the count drops to zero and the tensor is not a graph
+                // output, its address range can be recycled
+                if (ref_counts[fuid] == 0) {
+                    bool is_output = false;
+                    for (auto &out_t : graph_outputs)
+                        if (out_t->getFuid() == fuid) is_output = true;
+
+                    if (!is_output) {
+                        allocator.free(tensor_offsets[fuid], in->getBytes());
+                    }
+                }
+            }
+        }
+
+        // 3. Bind physical addresses
+        void *base_ptr = allocator.getPtr();
+        for (auto &t : tensors) {
+            if (tensor_offsets.count(t->getFuid())) {
+                auto offset = tensor_offsets[t->getFuid()];
+                auto blob = make_ref<BlobObj>(runtime, (char *)base_ptr + offset);
+                t->setDataBlob(blob);
+            }
+        }
+
+        allocator.info();
    }

    Tensor GraphObj::addTensor(Shape dim, DataType dtype)
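Note: dataMalloc above reuses offsets through simple liveness: each tensor starts with a reference count equal to its number of consumers, every executed operator decrements the counts of its inputs, and a tensor that reaches zero and is not a graph output goes back to the allocator. A toy standalone model of that schedule (made-up ids and topology; the real code keys on tensor fuids and byte sizes):

    #include <cstddef>
    #include <cstdio>
    #include <map>
    #include <set>
    #include <vector>

    struct Op { std::vector<int> in, out; }; // tensor ids only; sizes omitted

    int main() {
        // Chain: t0 -> op0 -> t1 -> op1 -> t2, where t2 is the graph output.
        std::vector<Op> ops = {{{0}, {1}}, {{1}, {2}}};
        std::map<int, int> refs = {{0, 1}, {1, 1}, {2, 0}}; // consumers per tensor
        std::set<int> graphOutputs = {2};

        for (std::size_t i = 0; i < ops.size(); ++i) {
            for (int t : ops[i].out)
                std::printf("op%zu: allocate t%d\n", i, t);
            for (int t : ops[i].in)
                if (--refs[t] == 0 && !graphOutputs.count(t))
                    std::printf("op%zu: free t%d (last reader has run)\n", i, t);
        }
        // t0 is freed after op0, so its offset can be reused for t2; t2 stays alive.
        return 0;
    }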
diff --git a/src/operators/concat.cc b/src/operators/concat.cc
index d196330..689aa43 100644
--- a/src/operators/concat.cc
+++ b/src/operators/concat.cc
@@ -3,9 +3,8 @@
 namespace infini {
 ConcatObj::ConcatObj(GraphObj *graph, TensorVec inputs, Tensor output, int _dim)
-    : OperatorObj(OpType::Concat, inputs, {output}) {
-    int rank = inputs[0]->getRank();
-    dim = get_real_axis(_dim, rank);
+    : OperatorObj(OpType::Concat, inputs, {output}),
+      dim(get_real_axis(_dim, inputs[0]->getRank())) { // set dim in the initializer list
     IT_ASSERT(checkValid(graph));
 }

@@ -17,6 +16,12 @@ optional<vector<Shape>> ConcatObj::inferShape(const TensorVec &inputs) {
     // TODO: modify dims and return the correct shape after concat
     // REF: https://onnx.ai/onnx/operators/onnx__Concat.html#concat-13
     // =================================== Assignment ===================================
+
+    // Accumulate the sizes of all inputs along the concat dimension
+    dims[dim] = 0;
+    for (size_t i = 0; i < inputs.size(); ++i) {
+        dims[dim] += inputs[i]->getDims()[dim];
+    }

     return {{dims}};
 }
@@ -35,4 +40,4 @@ std::string ConcatObj::toString() const {
     return os.str();
 }

-} // namespace infini
+} // namespace infini
\ No newline at end of file
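Note: the inferred concat shape copies the first input, and only the concat axis changes, to the sum of the inputs' extents on that axis. A quick standalone check with made-up shapes:

    #include <cstdio>
    #include <vector>

    int main() {
        // Two inputs concatenated on axis 1; all other axes must already agree.
        std::vector<std::vector<int>> ins = {{2, 3, 4}, {2, 5, 4}};
        int axis = 1;
        std::vector<int> out = ins[0]; // copy the first input's shape
        out[axis] = 0;
        for (auto &s : ins)
            out[axis] += s[axis]; // sum the extents along the concat axis
        std::printf("{%d, %d, %d}\n", out[0], out[1], out[2]); // {2, 8, 4}
        return 0;
    }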
diff --git a/src/operators/matmul.cc b/src/operators/matmul.cc
index 7a16ca2..0fc30ce 100644
--- a/src/operators/matmul.cc
+++ b/src/operators/matmul.cc
@@ -27,7 +27,81 @@ namespace infini
        // TODO: return the shape produced by the matmul operation
        // REF: https://github.com/onnx/onnx/blob/main/docs/Operators.md#gemm
        // =================================== Assignment ===================================
-        return std::nullopt;
+
+        const auto A = inputs[0];
+        const auto B = inputs[1];
+
+        Shape shapeA = A->getDims();
+        Shape shapeB = B->getDims();
+
+        int rankA = shapeA.size();
+        int rankB = shapeB.size();
+
+        // Inputs must be at least 2-D
+        IT_ASSERT(rankA >= 2 && rankB >= 2, "MatMul inputs must be at least 2D");
+
+        // Matrix dimensions
+        int M, N, K_A, K_B;
+
+        if (transA) {
+            M = shapeA[rankA - 1];
+            K_A = shapeA[rankA - 2];
+        } else {
+            M = shapeA[rankA - 2];
+            K_A = shapeA[rankA - 1];
+        }
+
+        if (transB) {
+            N = shapeB[rankB - 2];
+            K_B = shapeB[rankB - 1];
+        } else {
+            N = shapeB[rankB - 1];
+            K_B = shapeB[rankB - 2];
+        }
+
+        // The contraction (K) dimensions must match
+        IT_ASSERT(K_A == K_B, "Inner dimensions do not match for matrix multiplication");
+
+        // Record m, n, k in the member variables used for computation
+        m = M;
+        n = N;
+        k = K_A; // equal to K_B
+
+        // Handle batch dimensions
+        Shape batchDims;
+
+        if (rankA > 2 || rankB > 2) {
+            // Strip the trailing two dimensions to get the batch dimensions
+            Shape batchA(shapeA.begin(), shapeA.end() - 2);
+            Shape batchB(shapeB.begin(), shapeB.end() - 2);
+
+            // Broadcast the batch dimensions
+            int maxBatchRank = std::max(batchA.size(), batchB.size());
+
+            // Left-pad with 1s so both shapes have the same length
+            batchA.insert(batchA.begin(), maxBatchRank - batchA.size(), 1);
+            batchB.insert(batchB.begin(), maxBatchRank - batchB.size(), 1);
+
+            // Broadcast element-wise
+            for (int i = 0; i < maxBatchRank; i++) {
+                int dimA = batchA[i];
+                int dimB = batchB[i];
+
+                if (dimA != dimB && dimA != 1 && dimB != 1) {
+                    IT_ASSERT(false, "Incompatible batch dimensions for matrix multiplication");
+                }
+
+                int broadcastDim = std::max(dimA, dimB);
+                batchDims.push_back(broadcastDim);
+            }
+        }
+
+        // Output shape: batch dimensions + [M, N]
+        Shape outputDims = batchDims;
+        outputDims.push_back(M);
+        outputDims.push_back(N);
+
+        return {{outputDims}};
    }

} // namespace infini
\ No newline at end of file
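Note: with transA/transB the contraction side flips: A contributes M and K (swapped under transA), B contributes K and N (swapped under transB), and the leading batch dimensions broadcast element-wise after left-padding with 1s. A standalone sketch of exactly this bookkeeping (shapes are made up; not part of the applied diff):

    #include <algorithm>
    #include <cassert>
    #include <cstddef>
    #include <vector>

    // Output shape of MatMul(A, B) honoring transA/transB, NumPy-style batch broadcast.
    std::vector<int> matmulShape(std::vector<int> A, std::vector<int> B,
                                 bool transA, bool transB) {
        int rA = A.size(), rB = B.size();
        int M = transA ? A[rA - 1] : A[rA - 2];
        int K = transA ? A[rA - 2] : A[rA - 1];
        int K2 = transB ? B[rB - 1] : B[rB - 2];
        int N = transB ? B[rB - 2] : B[rB - 1];
        assert(K == K2 && "inner dimensions must match");

        std::vector<int> bA(A.begin(), A.end() - 2), bB(B.begin(), B.end() - 2);
        std::size_t rank = std::max(bA.size(), bB.size());
        bA.insert(bA.begin(), rank - bA.size(), 1); // left-pad with 1s, then
        bB.insert(bB.begin(), rank - bB.size(), 1); // broadcast element-wise
        std::vector<int> out;
        for (std::size_t i = 0; i < rank; ++i) {
            assert(bA[i] == bB[i] || bA[i] == 1 || bB[i] == 1);
            out.push_back(std::max(bA[i], bB[i]));
        }
        out.push_back(M);
        out.push_back(N);
        return out;
    }

    int main() {
        // (2,1,3,4) x (5,4,6): batches (2,1) and (5) broadcast to (2,5), result (2,5,3,6)
        assert((matmulShape({2, 1, 3, 4}, {5, 4, 6}, false, false)
                == std::vector<int>{2, 5, 3, 6}));
        // transB: B supplied as (N,K) = (6,4)
        assert((matmulShape({3, 4}, {6, 4}, false, true) == std::vector<int>{3, 6}));
        return 0;
    }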
diff --git a/src/operators/transpose.cc b/src/operators/transpose.cc
index faab2b6..54ee368 100644
--- a/src/operators/transpose.cc
+++ b/src/operators/transpose.cc
@@ -33,8 +33,12 @@ namespace infini
        // TODO: fill in output_dim and return the correct shape after the transpose
        // REF: https://onnx.ai/onnx/operators/onnx__Transpose.html#transpose-21
        // =================================== Assignment ===================================
-
-        return std::nullopt;
+        output_dim.resize(rank);
+        for (int i = 0; i < rank; ++i) {
+            output_dim[i] = input_dim[transposePermute[i]];
+        }
+
+        return {{output_dim}};
    }

    std::string TransposeObj::toString() const
diff --git a/src/operators/unary.cc b/src/operators/unary.cc
index 3daad36..b3c2767 100644
--- a/src/operators/unary.cc
+++ b/src/operators/unary.cc
@@ -39,7 +39,8 @@ namespace infini
        // TODO: return the shape after the clip operation
        // REF: https://onnx.ai/onnx/operators/onnx__Clip.html#clip-13
        // =================================== Assignment ===================================
-        return std::nullopt;
+        const auto A = inputs[0];
+        return {{A->getDims()}};
    }

    std::string ClipObj::toString() const
@@ -66,7 +67,7 @@ namespace infini
        // REF_FILE: src/core/operator.cc
        // REF: https://onnx.ai/onnx/operators/onnx__Cast.html#cast-21
        // =================================== Assignment ===================================
-        return {};
+        return {getOutputDataType()};
    }

    optional<vector<Shape>> CastObj::inferShape(const TensorVec &inputs)
@@ -75,7 +76,8 @@ namespace infini
        // TODO: return the shape after the cast operation
        // REF: https://onnx.ai/onnx/operators/onnx__Cast.html#cast-21
        // =================================== Assignment ===================================
-        return std::nullopt;
+        const auto A = inputs[0];
+        return {{A->getDims()}};
    }

    std::string CastObj::toString() const
@@ -145,4 +147,4 @@ namespace infini
            IT_TODO_HALT();
        }
    }
-}; // namespace infini
+}; // namespace infini
\ No newline at end of file
diff --git a/src/utils/operator_utils.cc b/src/utils/operator_utils.cc
index edbd2c8..9599efd 100644
--- a/src/utils/operator_utils.cc
+++ b/src/utils/operator_utils.cc
@@ -4,13 +4,37 @@ namespace infini {

 Shape infer_broadcast(const Shape &A, const Shape &B) {
-    // =================================== Assignment ===================================
     // TODO: broadcast A and B bidirectionally and return the broadcast shape
     // REF: https://github.com/onnx/onnx/blob/main/docs/Broadcasting.md
     // =================================== Assignment ===================================
-    return {};
+    // Ranks of the two shapes
+    size_t rankA = A.size();
+    size_t rankB = B.size();
+
+    // The result has the larger rank
+    size_t maxRank = std::max(rankA, rankB);
+
+    // Prepare the result shape
+    Shape result(maxRank);
+
+    // Compare dimensions starting from the rightmost one
+    for (int i = 1; i <= (int)maxRank; ++i) {
+        int dimA = (i <= (int)rankA) ? A[rankA - i] : 1;
+        int dimB = (i <= (int)rankB) ? B[rankB - i] : 1;
+
+        // Check that the dimensions are compatible
+        if (dimA != dimB && dimA != 1 && dimB != 1) {
+            // Incompatible dimensions
+            IT_ASSERT(false, "Incompatible dimensions for broadcasting");
+        }
+
+        // Pick the broadcast extent (the larger one, unless one of them is 1)
+        result[maxRank - i] = std::max(dimA, dimB);
+    }
+
+    return result;
 }

 int get_real_axis(const int &axis, const int &rank) {
@@ -66,4 +90,4 @@ std::string get_kernel_attrs_str(const KernelAttrs &kernelAttrs) {
     return deviceStr + ", " + opStr;
 }

-} // namespace infini
+} // namespace infini
\ No newline at end of file
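Note: infer_broadcast is plain ONNX/NumPy-style bidirectional broadcasting: right-align the two shapes, treat missing dimensions as 1, and take the larger extent whenever either side is 1. A standalone check (arbitrary example shapes; illustrative, not part of the applied diff):

    #include <algorithm>
    #include <cassert>
    #include <cstddef>
    #include <vector>

    std::vector<int> broadcast(const std::vector<int> &A, const std::vector<int> &B) {
        std::size_t rank = std::max(A.size(), B.size());
        std::vector<int> out(rank);
        for (std::size_t i = 1; i <= rank; ++i) { // walk from the trailing dimension
            int a = i <= A.size() ? A[A.size() - i] : 1; // missing dims behave as 1
            int b = i <= B.size() ? B[B.size() - i] : 1;
            assert((a == b || a == 1 || b == 1) && "incompatible dimensions");
            out[rank - i] = std::max(a, b);
        }
        return out;
    }

    int main() {
        assert((broadcast({5, 1, 4}, {3, 1}) == std::vector<int>{5, 3, 4}));
        assert((broadcast({1}, {2, 3}) == std::vector<int>{2, 3}));
        return 0;
    }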
From 114f5f051f5cf796481d05a70cc61b0bc63479b5 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E6=9B=B9=E7=A1=95=E5=AE=87?=
Date: Mon, 2 Feb 2026 19:05:10 +0800
Subject: [PATCH 2/3] Delete home/ros/lectures/ai-complier/TinyInfiniTensor/src/core directory

---
 .../TinyInfiniTensor/src/core/graph.cc        | 151 ------------------
 1 file changed, 151 deletions(-)
 delete mode 100644 home/ros/lectures/ai-complier/TinyInfiniTensor/src/core/graph.cc

diff --git a/home/ros/lectures/ai-complier/TinyInfiniTensor/src/core/graph.cc b/home/ros/lectures/ai-complier/TinyInfiniTensor/src/core/graph.cc
deleted file mode 100644
index d5b991a..0000000
--- a/home/ros/lectures/ai-complier/TinyInfiniTensor/src/core/graph.cc
+++ /dev/null
@@ -1,151 +0,0 @@
-void GraphObj::optimize()
-{
-    // Based on the test case, we need to implement optimizations to merge/remove
-    // transpose operations and combine them with matmul operations
-
-    bool changed = true;
-    while (changed) {
-        changed = false;
-
-        // Look for consecutive transpose operations that cancel each other out
-        for (size_t i = 0; i < ops.size(); i++) {
-            auto op1 = ops[i];
-            if (op1->getOpType() != OpType::Transpose) continue;
-
-            auto transpose1 = dynamic_pointer_cast<TransposeObj>(op1);
-            if (transpose1->getOutputs().empty()) continue;
-
-            auto mid_tensor = transpose1->getOutputs()[0];
-            auto consumers = mid_tensor->getTargets();
-
-            // If this tensor only feeds another transpose operation
-            if (consumers.size() == 1 && consumers[0]->getOpType() == OpType::Transpose) {
-                auto transpose2 = dynamic_pointer_cast<TransposeObj>(consumers[0]);
-
-                // Check whether the transposes cancel each other out (are inverses)
-                auto perm1 = transpose1->getPermute();
-                auto perm2 = transpose2->getPermute();
-
-                bool is_inverse = true;
-                if (perm1.size() != perm2.size()) {
-                    is_inverse = false;
-                } else {
-                    for (size_t idx = 0; idx < perm1.size(); idx++) {
-                        if (perm2[perm1[idx]] != (int)idx) {
-                            is_inverse = false;
-                            break;
-                        }
-                    }
-                }
-
-                if (is_inverse) {
-                    // Remove both transposes by connecting the input directly to the final consumers
-
-                    // Get the original input tensor and the final output tensor
-                    auto input_tensor = transpose1->getInputs()[0];
-                    auto output_tensor = transpose2->getOutputs()[0];
-
-                    // Update all operators that consume the output of transpose2
-                    // to instead consume the input of transpose1
-                    auto final_consumers = output_tensor->getTargets();
-                    for (auto consumer : final_consumers) {
-                        for (size_t j = 0; j < consumer->getInputs().size(); j++) {
-                            if (consumer->getInputs(j) == output_tensor) {
-                                consumer->replaceInput(output_tensor, input_tensor);
-                            }
-                        }
-                    }
-
-                    // Update tensor graph connections
-                    input_tensor->removeTarget(op1);
-                    mid_tensor->removeTarget(transpose2);
-                    mid_tensor->removeSource();
-                    output_tensor->removeSource();
-
-                    // Remove the two transpose operations
-                    ops.erase(std::remove(ops.begin(), ops.end(), op1), ops.end());
-                    ops.erase(std::remove(ops.begin(), ops.end(), transpose2), ops.end());
-
-                    // Remove the intermediate tensors
-                    tensors.erase(std::remove(tensors.begin(), tensors.end(), mid_tensor), tensors.end());
-                    tensors.erase(std::remove(tensors.begin(), tensors.end(), output_tensor), tensors.end());
-
-                    changed = true;
-                    break;
-                }
-            }
-        }
-
-        if (changed) continue;
-
-        // Look for a transpose followed by a matmul that can be fused
-        for (size_t i = 0; i < ops.size(); i++) {
-            auto op = ops[i];
-            if (op->getOpType() != OpType::Matmul) continue;
-
-            auto matmul = dynamic_pointer_cast<MatmulObj>(op);
-
-            // Check the first input for transpose fusion
-            if (!matmul->getInputs(0)->getSource()) continue;
-            auto prev_op = matmul->getInputs(0)->getSource();
-            if (prev_op->getOpType() == OpType::Transpose) {
-                auto transpose = dynamic_pointer_cast<TransposeObj>(prev_op);
-                // Check whether the transpose swaps the last 2 dimensions (the common case for matmul)
-                auto perm = transpose->getPermute();
-                if (perm.size() >= 2 &&
-                   perm[perm.size()-2] == (int)perm.size()-1 &&
-                   perm[perm.size()-1] == (int)perm.size()-2) {
-                    // Fuse: set the transA flag and update the matmul
-                    matmul->setTransA(true);
-
-                    // Update the matmul to take input from the transpose's input
-                    auto trans_input = transpose->getInputs()[0];
-                    matmul->replaceInput(matmul->getInputs(0), trans_input);
-
-                    // Update tensor connections
-                    matmul->getInputs(0)->removeTarget(op);
-                    trans_input->addTarget(op);
-
-                    // Remove the transpose and the intermediate tensor
-                    ops.erase(std::remove(ops.begin(), ops.end(), prev_op), ops.end());
-                    tensors.erase(std::remove(tensors.begin(), tensors.end(), matmul->getInputs(0)), tensors.end());
-
-                    changed = true;
-                    break;
-                }
-            }
-
-            if (changed) continue;
-
-            // Check the second input for transpose fusion
-            if (!matmul->getInputs(1)->getSource()) continue;
-            prev_op = matmul->getInputs(1)->getSource();
-            if (prev_op->getOpType() == OpType::Transpose) {
-                auto transpose = dynamic_pointer_cast<TransposeObj>(prev_op);
-                // Check whether the transpose swaps the last 2 dimensions (the common case for matmul)
-                auto perm = transpose->getPermute();
-                if (perm.size() >= 2 &&
-                   perm[perm.size()-2] == (int)perm.size()-1 &&
-                   perm[perm.size()-1] == (int)perm.size()-2) {
-                    // Fuse: set the transB flag and update the matmul
-                    matmul->setTransB(true);
-
-                    // Update the matmul to take input from the transpose's input
-                    auto trans_input = transpose->getInputs()[0];
-                    matmul->replaceInput(matmul->getInputs(1), trans_input);
-
-                    // Update tensor connections
-                    matmul->getInputs(1)->removeTarget(op);
-                    trans_input->addTarget(op);
-
-                    // Remove the transpose and the intermediate tensor
-                    ops.erase(std::remove(ops.begin(), ops.end(), prev_op), ops.end());
-                    tensors.erase(std::remove(tensors.begin(), tensors.end(), matmul->getInputs(1)), tensors.end());
-
-                    changed = true;
-                    break;
-                }
-            }
-        }
-    }
-}
\ No newline at end of file
From 9f944d479f0c7ce1758953f8ddc0764ff8f0e524 Mon Sep 17 00:00:00 2001
From: Your Name
Date: Mon, 2 Feb 2026 19:05:45 +0800
Subject: [PATCH 3/3] 2

---
 src/core/graph.cc | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/core/graph.cc b/src/core/graph.cc
index 71bbf1a..051d7d8 100644
--- a/src/core/graph.cc
+++ b/src/core/graph.cc
@@ -252,7 +252,7 @@ void GraphObj::optimize()
                                 realInput->addTarget(matmul);
                                 input->removeTarget(matmul);
 
-                                // Update operator links
+                                // Update operator  links
                                 matmul->removePredecessors(trans);
                                 trans->removeSuccessors(matmul);