diff --git a/csrc/fusion.cpp b/csrc/fusion.cpp index 8f76a1b57d0..91cf035aafc 100644 --- a/csrc/fusion.cpp +++ b/csrc/fusion.cpp @@ -107,46 +107,90 @@ bool Fusion::sameDefinition(const Fusion& other) const { void Fusion::swap(Fusion& a, Fusion& b) noexcept { FUSER_PERF_SCOPE("Fusion swap"); - // We need to be careful to call IrContainer swap not unique_ptr swap, which - // will only swap the ptrs NOT the contents. - IrContainer::swap(*(a.ir_container()), *(b.ir_container())); + if (&a == &b) { + return; + } - // After swapping container contents, per-Fusion tracking keys point to the - // wrong Fusions. Rename: a's container had b's entries, b's had a's. - a.ir_container()->transferStatementOwnership(&b, &a); - b.ir_container()->transferStatementOwnership(&a, &b); + // Collect statements owned by each Fusion BEFORE swap + std::vector a_owned_vals, b_owned_vals; + std::vector a_owned_exprs, b_owned_exprs; if (a.ir_container_) { - for (auto val : a.vals()) { - val->ir_container_ = &a; - } - for (auto expr : a.deterministic_exprs()) { - expr->ir_container_ = &a; - } + const auto& av = a.ir_container_->valsOwnedBy(&a); + const auto& ae = a.ir_container_->exprsOwnedBy(&a); + a_owned_vals.assign(av.begin(), av.end()); + a_owned_exprs.assign(ae.begin(), ae.end()); } if (b.ir_container_) { - for (auto val : b.vals()) { - val->ir_container_ = &b; - } - for (auto expr : b.deterministic_exprs()) { - expr->ir_container_ = &b; - } + const auto& bv = b.ir_container_->valsOwnedBy(&b); + const auto& be = b.ir_container_->exprsOwnedBy(&b); + b_owned_vals.assign(bv.begin(), bv.end()); + b_owned_exprs.assign(be.begin(), be.end()); + } + + // Transfer Fusion registrations between containers before pointer swap. + // After swap, a will own b's container and b will own a's container. + if (a.ir_container_ && b.ir_container_ && + a.ir_container_.get() != b.ir_container_.get()) { + a.ir_container_->transferFusion(&a, &b); + b.ir_container_->transferFusion(&b, &a); } + // Swap container pointers + std::swap(a.ir_container_, b.ir_container_); + + // Swap all Fusion-level members std::swap(a.inputs_, b.inputs_); std::swap(a.outputs_, b.outputs_); - std::swap(a.io_alias_, b.io_alias_); - - // Swap per-Fusion special values (Phase 2) + std::swap(a.all_tv_uses_valid_, b.all_tv_uses_valid_); + std::swap(a.is_during_update_uses_, b.is_during_update_uses_); + std::swap(a.managed_data_, b.managed_data_); + std::swap(a.managed_named_data_, b.managed_named_data_); + std::swap(a.expected_dynamic_smem_bytes_, b.expected_dynamic_smem_bytes_); + std::swap(a.all_tvs_ptr_, b.all_tvs_ptr_); std::swap(a.zero_val_, b.zero_val_); std::swap(a.one_val_, b.one_val_); std::swap(a.true_val_, b.true_val_); std::swap(a.false_val_, b.false_val_); std::swap(a.magic_zero_val_, b.magic_zero_val_); - std::swap(a.axioms_, b.axioms_); std::swap(a.metadata_, b.metadata_); + std::swap(a.val_type_name_map_, b.val_type_name_map_); + std::swap(a.expr_name_counter_, b.expr_name_counter_); + + // Update Statement::ir_container_ pointers: a's old statements now belong + // to b, and b's old statements now belong to a + for (auto* val : a_owned_vals) { + val->ir_container_ = &b; + } + for (auto* expr : a_owned_exprs) { + expr->ir_container_ = &b; + } + for (auto* val : b_owned_vals) { + val->ir_container_ = &a; + } + for (auto* expr : b_owned_exprs) { + expr->ir_container_ = &a; + } + + // Update per-Fusion tracking keys in containers + if (a.ir_container_ && b.ir_container_) { + if (a.ir_container_.get() == b.ir_container_.get()) { + // Same container: directly swap per-Fusion tracking entries + auto* c = a.ir_container_.get(); + std::swap(c->per_fusion_vals_[&a], c->per_fusion_vals_[&b]); + std::swap(c->per_fusion_exprs_[&a], c->per_fusion_exprs_[&b]); + } else { + // Different containers: rename tracking keys to match new owners + a.ir_container_->transferStatementOwnership(&b, &a); + b.ir_container_->transferStatementOwnership(&a, &b); + } + } else if (a.ir_container_) { + a.ir_container_->transferStatementOwnership(&b, &a); + } else if (b.ir_container_) { + b.ir_container_->transferStatementOwnership(&a, &b); + } } std::unique_ptr Fusion::segment( @@ -158,10 +202,30 @@ std::unique_ptr Fusion::segment( IrCloner Fusion::copy(const Fusion* from, Fusion* to) { to->clear(); - auto ir_cloner = - IrContainer::copy(from->ir_container(), to->ir_container(), to); + IrCloner ir_cloner(to); + + // Clone from's vals in insertion order + for (auto val : from->deterministic_vals()) { + ir_cloner.clone(val); + } + + // Sync per-Fusion name counters from source to dest. + // During cloning, registerVal increments the dest Fusion's counter for each + // val, then IrBuilder::clone overrides the name with setName(src->name()). + // If source names are non-sequential (e.g., {0..10, 22..27} from segmenter + // creating intermediate TVs), the dest counter ends up at N (number of vals) + // instead of max(name)+1. Copying the source's counter state ensures new + // vals created post-copy won't collide with existing names. + to->val_type_name_map_ = from->val_type_name_map_; + to->expr_name_counter_ = from->expr_name_counter_; + + // Wire up definitions and uses on cloned vals + for (auto val : from->vals()) { + ir_cloner.clone(val)->setDefinition(ir_cloner.clone(val->definition_)); + ir_cloner.clone(val)->setUses(ir_cloner.clone(val->uses_)); + } - // Remap cached special val pointers through the cloner + // Remap cached special val pointers if (from->zero_val_) { to->zero_val_ = ir_cloner.clone(from->zero_val_); } @@ -179,11 +243,6 @@ IrCloner Fusion::copy(const Fusion* from, Fusion* to) { ir_cloner.clone(from->magic_zero_val_)->as(); } - for (auto val : from->vals()) { - ir_cloner.clone(val)->setDefinition(ir_cloner.clone(val->definition_)); - ir_cloner.clone(val)->setUses(ir_cloner.clone(val->uses_)); - } - to->inputs_ = ir_cloner.clone(from->inputs_); to->outputs_ = ir_cloner.clone(from->outputs_); for (auto inp : to->inputs_) { @@ -193,7 +252,6 @@ IrCloner Fusion::copy(const Fusion* from, Fusion* to) { out->setIsFusionOutput(true); } - // TODO: put this into ir_cloner instead for (Val* out : from->outputs_) { const AliasInfo& alias = from->io_alias_.get(out); if (alias.type == AllocationType::New) { @@ -206,14 +264,12 @@ IrCloner Fusion::copy(const Fusion* from, Fusion* to) { } to->all_tv_uses_valid_ = from->all_tv_uses_valid_; - // This should never be true on copy, but copying for completeness. to->is_during_update_uses_ = from->is_during_update_uses_; for (const auto& i : from->managed_data_) { if (i.first.has_value()) { to->managed_data_.emplace_back(i.second(ir_cloner, i.first), i.second); } else { - // Don't clone managed data if it has been reset to->managed_data_.emplace_back(i.first, i.second); } } @@ -256,9 +312,10 @@ Fusion::Fusion() : ir_container_(std::make_shared()) { ir_container_->addFusion(this); } -// Copy constructor -Fusion::Fusion(const Fusion& other) : Fusion() { +// Copy constructor -- shares the source's container +Fusion::Fusion(const Fusion& other) : ir_container_(other.ir_container_) { FUSER_PERF_SCOPE("Fusion copy"); + ir_container_->addFusion(this); Fusion::copy(&other, this); } @@ -278,6 +335,9 @@ Fusion& Fusion::operator=(const Fusion& other) { Fusion& Fusion::operator=(Fusion&& other) noexcept { FUSER_PERF_SCOPE("Fusion move assign"); + if (this == &other) { + return *this; + } clear(); swap(*this, other); return *this; @@ -317,6 +377,9 @@ void Fusion::clear() noexcept { axioms_.reset(); metadata_.clear(); + val_type_name_map_.clear(); + expr_name_counter_ = 0; + invalidateTvsAndUses(); is_during_update_uses_ = false; @@ -932,7 +995,7 @@ void Fusion::registerVal(Val* val) { c->vals_up_.emplace_back(val); c->vals_.insert(val); c->per_fusion_vals_[this].insert(val); - val->setName(IrContainerPasskey(), c->getValName(val->vtype())); + val->setName(IrContainerPasskey(), getValName(val->vtype())); } void Fusion::registerExpr(Expr* expr) { @@ -949,7 +1012,7 @@ void Fusion::registerExpr(Expr* expr) { c->exprs_up_.emplace_back(expr); c->exprs_.insert(expr); c->per_fusion_exprs_[this].insert(expr); - expr->setName(IrContainerPasskey(), c->getExprName()); + expr->setName(IrContainerPasskey(), getExprName()); for (Val* input : expr->inputs()) { assertInContainer(input, "Input to expr is invalid, "); diff --git a/csrc/fusion.h b/csrc/fusion.h index d244517a703..69c25b3049b 100644 --- a/csrc/fusion.h +++ b/csrc/fusion.h @@ -647,6 +647,24 @@ class NVF_API Fusion : public PolymorphicBase { std::unique_ptr> axioms_; std::unordered_map> metadata_; + + // Per-Fusion name counters. Each Fusion independently tracks name assignment + // so that cloned Fusions get matching names (T0→T0) regardless of whether + // they share an IrContainer. This is required by downstream consumers that + // use tv->name() as a map key (alias_memory, GreedyParams, etc.). + std::unordered_map val_type_name_map_; + StmtNameType expr_name_counter_ = 0; + + StmtNameType getValName(ValType vtype) { + if (val_type_name_map_.find(vtype) == val_type_name_map_.end()) { + val_type_name_map_[vtype] = 0; + } + return val_type_name_map_[vtype]++; + } + + StmtNameType getExprName() { + return expr_name_counter_++; + } }; // Template implementations for Fusion::manage() that use IrCloner