-
Notifications
You must be signed in to change notification settings - Fork 78
[IR Container] Phase 2.5 Copy-Move Semantics #5964
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: md/phase2-per-fusion
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -107,46 +107,90 @@ bool Fusion::sameDefinition(const Fusion& other) const { | |
| void Fusion::swap(Fusion& a, Fusion& b) noexcept { | ||
| FUSER_PERF_SCOPE("Fusion swap"); | ||
|
|
||
| // We need to be careful to call IrContainer swap not unique_ptr swap, which | ||
| // will only swap the ptrs NOT the contents. | ||
| IrContainer::swap(*(a.ir_container()), *(b.ir_container())); | ||
| if (&a == &b) { | ||
| return; | ||
| } | ||
|
|
||
| // After swapping container contents, per-Fusion tracking keys point to the | ||
| // wrong Fusions. Rename: a's container had b's entries, b's had a's. | ||
| a.ir_container()->transferStatementOwnership(&b, &a); | ||
| b.ir_container()->transferStatementOwnership(&a, &b); | ||
| // Collect statements owned by each Fusion BEFORE swap | ||
| std::vector<Val*> a_owned_vals, b_owned_vals; | ||
| std::vector<Expr*> a_owned_exprs, b_owned_exprs; | ||
|
|
||
| if (a.ir_container_) { | ||
| for (auto val : a.vals()) { | ||
| val->ir_container_ = &a; | ||
| } | ||
| for (auto expr : a.deterministic_exprs()) { | ||
| expr->ir_container_ = &a; | ||
| } | ||
| const auto& av = a.ir_container_->valsOwnedBy(&a); | ||
| const auto& ae = a.ir_container_->exprsOwnedBy(&a); | ||
| a_owned_vals.assign(av.begin(), av.end()); | ||
| a_owned_exprs.assign(ae.begin(), ae.end()); | ||
| } | ||
| if (b.ir_container_) { | ||
| for (auto val : b.vals()) { | ||
| val->ir_container_ = &b; | ||
| } | ||
| for (auto expr : b.deterministic_exprs()) { | ||
| expr->ir_container_ = &b; | ||
| } | ||
| const auto& bv = b.ir_container_->valsOwnedBy(&b); | ||
| const auto& be = b.ir_container_->exprsOwnedBy(&b); | ||
| b_owned_vals.assign(bv.begin(), bv.end()); | ||
| b_owned_exprs.assign(be.begin(), be.end()); | ||
| } | ||
|
|
||
| // Transfer Fusion registrations between containers before pointer swap. | ||
| // After swap, a will own b's container and b will own a's container. | ||
| if (a.ir_container_ && b.ir_container_ && | ||
| a.ir_container_.get() != b.ir_container_.get()) { | ||
| a.ir_container_->transferFusion(&a, &b); | ||
| b.ir_container_->transferFusion(&b, &a); | ||
| } | ||
|
Comment on lines
+131
to
137
Contributor
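There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Missing `transferFusion` when only one Fusion has a container. When only one Fusion has a container (e.g., `b.ir_container_` is null), the guard above requires both containers to be non-null, so `transferFusion` is skipped entirely. For example, if `a` has a container and `b` does not, after the pointer swap `b` holds `a`'s old container, but that container still has `&a` registered rather than `&b`.
While this is likely unreachable in practice (all constructors initialize `ir_container_`), the single-container case should still be handled or asserted against. |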
||
|
|
||
| // Swap container pointers | ||
| std::swap(a.ir_container_, b.ir_container_); | ||
|
|
||
| // Swap all Fusion-level members | ||
| std::swap(a.inputs_, b.inputs_); | ||
| std::swap(a.outputs_, b.outputs_); | ||
|
|
||
| std::swap(a.io_alias_, b.io_alias_); | ||
|
|
||
| // Swap per-Fusion special values (Phase 2) | ||
| std::swap(a.all_tv_uses_valid_, b.all_tv_uses_valid_); | ||
| std::swap(a.is_during_update_uses_, b.is_during_update_uses_); | ||
| std::swap(a.managed_data_, b.managed_data_); | ||
| std::swap(a.managed_named_data_, b.managed_named_data_); | ||
| std::swap(a.expected_dynamic_smem_bytes_, b.expected_dynamic_smem_bytes_); | ||
| std::swap(a.all_tvs_ptr_, b.all_tvs_ptr_); | ||
| std::swap(a.zero_val_, b.zero_val_); | ||
| std::swap(a.one_val_, b.one_val_); | ||
| std::swap(a.true_val_, b.true_val_); | ||
| std::swap(a.false_val_, b.false_val_); | ||
| std::swap(a.magic_zero_val_, b.magic_zero_val_); | ||
|
|
||
| std::swap(a.axioms_, b.axioms_); | ||
| std::swap(a.metadata_, b.metadata_); | ||
| std::swap(a.val_type_name_map_, b.val_type_name_map_); | ||
| std::swap(a.expr_name_counter_, b.expr_name_counter_); | ||
|
|
||
| // Update Statement::ir_container_ pointers: a's old statements now belong | ||
| // to b, and b's old statements now belong to a | ||
| for (auto* val : a_owned_vals) { | ||
| val->ir_container_ = &b; | ||
| } | ||
| for (auto* expr : a_owned_exprs) { | ||
| expr->ir_container_ = &b; | ||
| } | ||
| for (auto* val : b_owned_vals) { | ||
| val->ir_container_ = &a; | ||
| } | ||
| for (auto* expr : b_owned_exprs) { | ||
| expr->ir_container_ = &a; | ||
| } | ||
|
|
||
| // Update per-Fusion tracking keys in containers | ||
| if (a.ir_container_ && b.ir_container_) { | ||
| if (a.ir_container_.get() == b.ir_container_.get()) { | ||
| // Same container: directly swap per-Fusion tracking entries | ||
| auto* c = a.ir_container_.get(); | ||
| std::swap(c->per_fusion_vals_[&a], c->per_fusion_vals_[&b]); | ||
| std::swap(c->per_fusion_exprs_[&a], c->per_fusion_exprs_[&b]); | ||
| } else { | ||
| // Different containers: rename tracking keys to match new owners | ||
| a.ir_container_->transferStatementOwnership(&b, &a); | ||
| b.ir_container_->transferStatementOwnership(&a, &b); | ||
| } | ||
| } else if (a.ir_container_) { | ||
| a.ir_container_->transferStatementOwnership(&b, &a); | ||
| } else if (b.ir_container_) { | ||
| b.ir_container_->transferStatementOwnership(&a, &b); | ||
| } | ||
| } | ||
|
|
||
| std::unique_ptr<SegmentedFusion> Fusion::segment( | ||
|
|
@@ -158,10 +202,30 @@ std::unique_ptr<SegmentedFusion> Fusion::segment( | |
| IrCloner Fusion::copy(const Fusion* from, Fusion* to) { | ||
| to->clear(); | ||
|
|
||
| auto ir_cloner = | ||
| IrContainer::copy(from->ir_container(), to->ir_container(), to); | ||
| IrCloner ir_cloner(to); | ||
|
|
||
| // Clone from's vals in insertion order | ||
| for (auto val : from->deterministic_vals()) { | ||
| ir_cloner.clone(val); | ||
| } | ||
|
|
||
| // Sync per-Fusion name counters from source to dest. | ||
| // During cloning, registerVal increments the dest Fusion's counter for each | ||
| // val, then IrBuilder::clone overrides the name with setName(src->name()). | ||
| // If source names are non-sequential (e.g., {0..10, 22..27} from segmenter | ||
| // creating intermediate TVs), the dest counter ends up at N (number of vals) | ||
| // instead of max(name)+1. Copying the source's counter state ensures new | ||
| // vals created post-copy won't collide with existing names. | ||
| to->val_type_name_map_ = from->val_type_name_map_; | ||
| to->expr_name_counter_ = from->expr_name_counter_; | ||
|
|
||
| // Wire up definitions and uses on cloned vals | ||
| for (auto val : from->vals()) { | ||
| ir_cloner.clone(val)->setDefinition(ir_cloner.clone(val->definition_)); | ||
| ir_cloner.clone(val)->setUses(ir_cloner.clone(val->uses_)); | ||
| } | ||
|
|
||
| // Remap cached special val pointers through the cloner | ||
| // Remap cached special val pointers | ||
| if (from->zero_val_) { | ||
| to->zero_val_ = ir_cloner.clone(from->zero_val_); | ||
| } | ||
|
|
@@ -179,11 +243,6 @@ IrCloner Fusion::copy(const Fusion* from, Fusion* to) { | |
| ir_cloner.clone(from->magic_zero_val_)->as<NamedScalar>(); | ||
| } | ||
|
|
||
| for (auto val : from->vals()) { | ||
| ir_cloner.clone(val)->setDefinition(ir_cloner.clone(val->definition_)); | ||
| ir_cloner.clone(val)->setUses(ir_cloner.clone(val->uses_)); | ||
| } | ||
|
|
||
| to->inputs_ = ir_cloner.clone(from->inputs_); | ||
| to->outputs_ = ir_cloner.clone(from->outputs_); | ||
| for (auto inp : to->inputs_) { | ||
|
|
@@ -193,7 +252,6 @@ IrCloner Fusion::copy(const Fusion* from, Fusion* to) { | |
| out->setIsFusionOutput(true); | ||
| } | ||
|
|
||
| // TODO: put this into ir_cloner instead | ||
| for (Val* out : from->outputs_) { | ||
| const AliasInfo& alias = from->io_alias_.get(out); | ||
| if (alias.type == AllocationType::New) { | ||
|
|
@@ -206,14 +264,12 @@ IrCloner Fusion::copy(const Fusion* from, Fusion* to) { | |
| } | ||
|
|
||
| to->all_tv_uses_valid_ = from->all_tv_uses_valid_; | ||
| // This should never be true on copy, but copying for completeness. | ||
| to->is_during_update_uses_ = from->is_during_update_uses_; | ||
|
|
||
| for (const auto& i : from->managed_data_) { | ||
| if (i.first.has_value()) { | ||
| to->managed_data_.emplace_back(i.second(ir_cloner, i.first), i.second); | ||
| } else { | ||
| // Don't clone managed data if it has been reset | ||
| to->managed_data_.emplace_back(i.first, i.second); | ||
| } | ||
| } | ||
|
|
@@ -256,9 +312,10 @@ Fusion::Fusion() : ir_container_(std::make_shared<IrContainer>()) { | |
| ir_container_->addFusion(this); | ||
| } | ||
|
|
||
| // Copy constructor | ||
| Fusion::Fusion(const Fusion& other) : Fusion() { | ||
| // Copy constructor -- shares the source's container | ||
| Fusion::Fusion(const Fusion& other) : ir_container_(other.ir_container_) { | ||
| FUSER_PERF_SCOPE("Fusion copy"); | ||
| ir_container_->addFusion(this); | ||
| Fusion::copy(&other, this); | ||
| } | ||
|
|
||
|
|
@@ -278,6 +335,9 @@ Fusion& Fusion::operator=(const Fusion& other) { | |
|
|
||
| Fusion& Fusion::operator=(Fusion&& other) noexcept { | ||
| FUSER_PERF_SCOPE("Fusion move assign"); | ||
| if (this == &other) { | ||
| return *this; | ||
| } | ||
| clear(); | ||
| swap(*this, other); | ||
| return *this; | ||
|
|
@@ -317,6 +377,9 @@ void Fusion::clear() noexcept { | |
| axioms_.reset(); | ||
| metadata_.clear(); | ||
|
|
||
| val_type_name_map_.clear(); | ||
| expr_name_counter_ = 0; | ||
|
|
||
| invalidateTvsAndUses(); | ||
|
|
||
| is_during_update_uses_ = false; | ||
|
|
@@ -932,7 +995,7 @@ void Fusion::registerVal(Val* val) { | |
| c->vals_up_.emplace_back(val); | ||
| c->vals_.insert(val); | ||
| c->per_fusion_vals_[this].insert(val); | ||
| val->setName(IrContainerPasskey(), c->getValName(val->vtype())); | ||
| val->setName(IrContainerPasskey(), getValName(val->vtype())); | ||
| } | ||
|
|
||
| void Fusion::registerExpr(Expr* expr) { | ||
|
|
@@ -949,7 +1012,7 @@ void Fusion::registerExpr(Expr* expr) { | |
| c->exprs_up_.emplace_back(expr); | ||
| c->exprs_.insert(expr); | ||
| c->per_fusion_exprs_[this].insert(expr); | ||
| expr->setName(IrContainerPasskey(), c->getExprName()); | ||
| expr->setName(IrContainerPasskey(), getExprName()); | ||
|
|
||
| for (Val* input : expr->inputs()) { | ||
| assertInContainer(input, "Input to expr is invalid, "); | ||
|
|
||
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -647,6 +647,24 @@ class NVF_API Fusion : public PolymorphicBase { | |||||||||||||||||||
| std::unique_ptr<std::vector<Val*>> axioms_; | ||||||||||||||||||||
|
|
||||||||||||||||||||
| std::unordered_map<Val*, std::pair<Val*, Expr*>> metadata_; | ||||||||||||||||||||
|
|
||||||||||||||||||||
| // Per-Fusion name counters. Each Fusion independently tracks name assignment | ||||||||||||||||||||
| // so that cloned Fusions get matching names (T0→T0) regardless of whether | ||||||||||||||||||||
| // they share an IrContainer. This is required by downstream consumers that | ||||||||||||||||||||
| // use tv->name() as a map key (alias_memory, GreedyParams, etc.). | ||||||||||||||||||||
| std::unordered_map<ValType, StmtNameType> val_type_name_map_; | ||||||||||||||||||||
| StmtNameType expr_name_counter_ = 0; | ||||||||||||||||||||
|
|
||||||||||||||||||||
| StmtNameType getValName(ValType vtype) { | ||||||||||||||||||||
| if (val_type_name_map_.find(vtype) == val_type_name_map_.end()) { | ||||||||||||||||||||
| val_type_name_map_[vtype] = 0; | ||||||||||||||||||||
| } | ||||||||||||||||||||
| return val_type_name_map_[vtype]++; | ||||||||||||||||||||
| } | ||||||||||||||||||||
|
Comment on lines
+658
to
+663
Contributor
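There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Consider using `operator[]` directly. Minor style nit: the explicit `find` check before the increment is redundant.
Suggested change: `return val_type_name_map_[vtype]++;`
This produces identical behavior because `std::unordered_map::operator[]` value-initializes the mapped value (to 0) when the key is absent. Note: If this suggestion doesn't match your team's coding style, reply to this and let me know. I'll remember it for next time! |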
||||||||||||||||||||
|
|
||||||||||||||||||||
| StmtNameType getExprName() { | ||||||||||||||||||||
| return expr_name_counter_++; | ||||||||||||||||||||
| } | ||||||||||||||||||||
| }; | ||||||||||||||||||||
|
|
||||||||||||||||||||
| // Template implementations for Fusion::manage<T>() that use IrCloner | ||||||||||||||||||||
|
|
||||||||||||||||||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
`noexcept` swap performs allocating operations

`swap` is marked `noexcept` but performs operations that can throw `std::bad_alloc`:
- `std::vector` construction and `.assign()` on lines 115-128
- `transferFusion` does `unordered_set::insert` (line 135-136)
- `per_fusion_vals_[&a]` with `operator[]` can allocate (line 182)
- `transferStatementOwnership` does `unordered_set::insert` and `unordered_map::operator[]` (lines 186-192)

If any of these throw, `std::terminate` will be called. This is a pre-existing pattern from the old swap, but the new implementation adds more allocating operations (the `std::vector` snapshots). Consider either removing `noexcept` or wrapping in try/catch with a fallback.