-
Notifications
You must be signed in to change notification settings - Fork 102
呢哇 #12
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
呢哇 #12
Changes from all commits
c03c42c
3ecf247
5ddc1fe
a7a23f0
351e8d0
4704fb6
5307d0e
d7423b5
f2114d7
6041ff6
5ceb539
339cc12
b6b6339
26ec13c
bc26b8d
a72ca89
40272e2
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,4 +1,7 @@ | ||
| #include "core/graph.h" | ||
| #include "core/blob.h" | ||
| #include "operators/matmul.h" | ||
| #include "operators/transpose.h" | ||
| #include <algorithm> | ||
| #include <numeric> | ||
| #include <queue> | ||
|
|
@@ -106,6 +109,166 @@ namespace infini | |
| // 1. 去除冗余的算子(例如,两个相邻的算子都是 transpose 算子,且做的是相反的操作,可以将其全部删除) | ||
| // 2. 合并算子(例如,矩阵乘算子中含有属性transA、transB,如果其输入存在transpose,且对最后两个维度做交换,就可以将transpose融入到矩阵乘算子的属性中去) | ||
| // =================================== 作业 =================================== | ||
|
|
||
| IT_ASSERT(topo_sort() == true); | ||
|
|
||
| auto isSwapLast2Permute = [](const std::vector<int> &perm) -> bool | ||
| { | ||
| const int r = static_cast<int>(perm.size()); | ||
| if (r < 2) | ||
| return false; | ||
| for (int i = 0; i < r - 2; ++i) | ||
| if (perm[i] != i) | ||
| return false; | ||
| return perm[r - 2] == r - 1 && perm[r - 1] == r - 2; | ||
| }; | ||
|
|
||
| auto isInversePermute = [](const std::vector<int> &p1, | ||
| const std::vector<int> &p2) -> bool | ||
| { | ||
| if (p1.size() != p2.size()) | ||
| return false; | ||
| const int r = static_cast<int>(p1.size()); | ||
| std::vector<int> inv(r, -1); | ||
| for (int i = 0; i < r; ++i) | ||
| { | ||
| const int v = p1[i]; | ||
| if (v < 0 || v >= r || inv[v] != -1) | ||
| return false; | ||
| inv[v] = i; | ||
| } | ||
| return inv == p2; | ||
| }; | ||
|
|
||
| bool changed = true; | ||
| while (changed) | ||
| { | ||
| changed = false; | ||
|
|
||
| // 规则 1:消除连续 transpose(perm 互逆) | ||
| for (size_t i = 0; i < ops.size(); ++i) | ||
| { | ||
| auto op1 = ops[i]; | ||
| if (op1->getOpType() != OpType::Transpose) | ||
| continue; | ||
| auto t1 = as<TransposeObj>(op1); | ||
| auto out1 = t1->getOutput(); | ||
| auto targets = out1->getTargets(); | ||
| if (targets.size() != 1) | ||
| continue; | ||
| auto op2 = targets[0]; | ||
| if (!op2 || op2->getOpType() != OpType::Transpose) | ||
| continue; | ||
| auto t2 = as<TransposeObj>(op2); | ||
| if (t2->getInputs(0) != out1) | ||
| continue; | ||
| if (!isInversePermute(t1->getPermute(), t2->getPermute())) | ||
| continue; | ||
|
|
||
| auto in = t1->getInputs(0); | ||
| auto out2 = t2->getOutput(); | ||
| for (auto &consumer : out2->getTargets()) | ||
| consumer->replaceInput(out2, in); | ||
|
|
||
| ops.erase(std::remove(ops.begin(), ops.end(), op1), ops.end()); | ||
| ops.erase(std::remove(ops.begin(), ops.end(), op2), ops.end()); | ||
| changed = true; | ||
| break; | ||
| } | ||
| if (changed) | ||
| continue; | ||
|
|
||
| // 规则 2:将 transpose(交换最后两维) 融合到 matmul 的 transA/transB | ||
| std::unordered_set<OperatorObj *> toRemove; | ||
| for (auto &op : ops) | ||
| { | ||
| if (op->getOpType() != OpType::MatMul) | ||
| continue; | ||
| auto mm = as<MatmulObj>(op); | ||
| for (int inputIdx = 0; inputIdx < 2; ++inputIdx) | ||
| { | ||
| auto in = mm->getInputs(inputIdx); | ||
| auto src = in->getSource(); | ||
| if (!src || src->getOpType() != OpType::Transpose) | ||
| continue; | ||
| auto tr = as<TransposeObj>(src); | ||
| if (tr->getOutput() != in) | ||
| continue; | ||
| if (!isSwapLast2Permute(tr->getPermute())) | ||
| continue; | ||
|
|
||
| auto trIn = tr->getInputs(0); | ||
| mm->replaceInput(in, trIn); | ||
| if (inputIdx == 0) | ||
| mm->setTransA(!mm->getTransA()); | ||
| else | ||
| mm->setTransB(!mm->getTransB()); | ||
| toRemove.insert(src.get()); | ||
| changed = true; | ||
|
Comment on lines
+194
to
+207
|
||
| } | ||
| } | ||
| if (!toRemove.empty()) | ||
| { | ||
| ops.erase(std::remove_if(ops.begin(), ops.end(), | ||
| [&](const Operator &op) | ||
| { return toRemove.count(op.get()) != 0; }), | ||
| ops.end()); | ||
| } | ||
| } | ||
|
|
||
| // 清理不再被任何算子引用的张量 | ||
| { | ||
| std::unordered_set<TensorObj *> referenced; | ||
| referenced.reserve(tensors.size()); | ||
| for (auto &op : ops) | ||
| { | ||
| for (auto &t : op->getInputs()) | ||
| referenced.insert(t.get()); | ||
| for (auto &t : op->getOutputs()) | ||
| referenced.insert(t.get()); | ||
| } | ||
| TensorVec kept; | ||
| kept.reserve(tensors.size()); | ||
| for (auto &t : tensors) | ||
| if (referenced.count(t.get()) != 0) | ||
| kept.emplace_back(t); | ||
| tensors = std::move(kept); | ||
| } | ||
|
|
||
| // 重新构建 pred/succ 与 tensor source/target | ||
| for (auto &t : tensors) | ||
| { | ||
| t->targets.clear(); | ||
| t->source.reset(); | ||
| } | ||
| for (auto &op : ops) | ||
| { | ||
| op->predecessors.clear(); | ||
| op->successors.clear(); | ||
| } | ||
| for (auto &op : ops) | ||
| { | ||
| for (auto &input : op->getInputs()) | ||
| { | ||
| if (input) | ||
| { | ||
| input->addTarget(op); | ||
| if (auto pred = input->getSource()) | ||
| { | ||
| pred->addSuccessors(op); | ||
| op->addPredecessors(pred); | ||
| } | ||
| } | ||
| } | ||
| for (auto &output : op->getOutputs()) | ||
| { | ||
| if (output) | ||
| output->setSource(op); | ||
| } | ||
| } | ||
|
|
||
| sorted = false; | ||
| IT_ASSERT(topo_sort() == true); | ||
| } | ||
|
|
||
| Tensor GraphObj::getTensor(int fuid) const | ||
|
|
@@ -148,10 +311,65 @@ namespace infini | |
| // topological sorting first | ||
| IT_ASSERT(topo_sort() == true); | ||
|
|
||
| // =================================== 作业 =================================== | ||
| // TODO:利用 allocator 给计算图分配内存 | ||
| // HINT: 获取分配好的内存指针后,可以调用 tensor 的 setDataBlob 函数给 tensor 绑定内存 | ||
| // =================================== 作业 =================================== | ||
| std::unordered_map<TensorObj *, int> remainingUses; | ||
| std::unordered_map<TensorObj *, size_t> bytes; | ||
| std::unordered_map<TensorObj *, size_t> offsets; | ||
| remainingUses.reserve(tensors.size()); | ||
| bytes.reserve(tensors.size()); | ||
| offsets.reserve(tensors.size()); | ||
|
|
||
| std::unordered_set<TensorObj *> keepAlive; | ||
| keepAlive.reserve(tensors.size()); | ||
| for (auto &t : tensors) | ||
| { | ||
| bytes[t.get()] = t->getBytes(); | ||
| remainingUses[t.get()] = static_cast<int>(t->getTargets().size()); | ||
| if (t->getTargets().empty()) | ||
| keepAlive.insert(t.get()); | ||
| } | ||
|
|
||
| auto ensureAlloc = [&](const Tensor &t) | ||
| { | ||
| auto *p = t.get(); | ||
| if (offsets.find(p) == offsets.end()) | ||
| offsets[p] = allocator.alloc(bytes[p]); | ||
| }; | ||
|
|
||
| // 输入张量:dataMalloc 后会 setData | ||
| for (auto &t : getInputs()) | ||
| ensureAlloc(t); | ||
|
|
||
| // 遍历 op:分配输出、回收“已完成最后一次使用”的输入 | ||
| for (auto &op : ops) | ||
| { | ||
| for (auto &out : op->getOutputs()) | ||
| ensureAlloc(out); | ||
|
|
||
| for (auto &in : op->getInputs()) | ||
| { | ||
| auto *p = in.get(); | ||
| auto it = remainingUses.find(p); | ||
| if (it == remainingUses.end()) | ||
| continue; | ||
| if (it->second > 0) | ||
| --(it->second); | ||
| if (it->second == 0 && keepAlive.count(p) == 0) | ||
| { | ||
| auto offIt = offsets.find(p); | ||
| if (offIt != offsets.end()) | ||
| allocator.free(offIt->second, bytes[p]); | ||
| } | ||
| } | ||
| } | ||
|
|
||
| void *base = allocator.getPtr(); | ||
| for (auto &t : tensors) | ||
| { | ||
| auto it = offsets.find(t.get()); | ||
| IT_ASSERT(it != offsets.end(), "Tensor not allocated in dataMalloc"); | ||
| void *ptr = static_cast<void *>(static_cast<char *>(base) + it->second); | ||
| t->setDataBlob(make_ref<BlobObj>(runtime, ptr)); | ||
| } | ||
|
|
||
| allocator.info(); | ||
| } | ||
|
|
@@ -227,4 +445,4 @@ namespace infini | |
| return true; | ||
| } | ||
|
|
||
| } // namespace infini | ||
| } // namespace infini | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The findFreeBlock function uses a linear search (first-fit) through all free blocks. For workloads with many allocations and deallocations, this could become a performance bottleneck as it has O(n) complexity where n is the number of free blocks. Consider using a more efficient data structure, such as maintaining free blocks sorted by size (in addition to address) or using a best-fit strategy with size-indexed structures for better performance. However, if the number of tensors is small (typical in most ML graphs), this may not be a practical concern.