Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
139 changes: 101 additions & 38 deletions csrc/fusion.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -107,46 +107,90 @@ bool Fusion::sameDefinition(const Fusion& other) const {
void Fusion::swap(Fusion& a, Fusion& b) noexcept {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

noexcept swap performs allocating operations

swap is marked noexcept but performs operations that can throw std::bad_alloc:

  • std::vector construction and .assign() on lines 115-128
  • transferFusion does unordered_set::insert (line 135-136)
  • per_fusion_vals_[&a] with operator[] can allocate (line 182)
  • transferStatementOwnership does unordered_set::insert and unordered_map::operator[] (lines 186-192)

If any of these throw, std::terminate will be called. This is a pre-existing pattern from the old swap, but the new implementation adds more allocating operations (the std::vector snapshots). Consider either removing noexcept or wrapping in try/catch with a fallback.

FUSER_PERF_SCOPE("Fusion swap");

// We need to be careful to call IrContainer swap not unique_ptr swap, which
// will only swap the ptrs NOT the contents.
IrContainer::swap(*(a.ir_container()), *(b.ir_container()));
if (&a == &b) {
return;
}

// After swapping container contents, per-Fusion tracking keys point to the
// wrong Fusions. Rename: a's container had b's entries, b's had a's.
a.ir_container()->transferStatementOwnership(&b, &a);
b.ir_container()->transferStatementOwnership(&a, &b);
// Collect statements owned by each Fusion BEFORE swap
std::vector<Val*> a_owned_vals, b_owned_vals;
std::vector<Expr*> a_owned_exprs, b_owned_exprs;

if (a.ir_container_) {
for (auto val : a.vals()) {
val->ir_container_ = &a;
}
for (auto expr : a.deterministic_exprs()) {
expr->ir_container_ = &a;
}
const auto& av = a.ir_container_->valsOwnedBy(&a);
const auto& ae = a.ir_container_->exprsOwnedBy(&a);
a_owned_vals.assign(av.begin(), av.end());
a_owned_exprs.assign(ae.begin(), ae.end());
}
if (b.ir_container_) {
for (auto val : b.vals()) {
val->ir_container_ = &b;
}
for (auto expr : b.deterministic_exprs()) {
expr->ir_container_ = &b;
}
const auto& bv = b.ir_container_->valsOwnedBy(&b);
const auto& be = b.ir_container_->exprsOwnedBy(&b);
b_owned_vals.assign(bv.begin(), bv.end());
b_owned_exprs.assign(be.begin(), be.end());
}

// Transfer Fusion registrations between containers before pointer swap.
// After swap, a will own b's container and b will own a's container.
if (a.ir_container_ && b.ir_container_ &&
a.ir_container_.get() != b.ir_container_.get()) {
a.ir_container_->transferFusion(&a, &b);
b.ir_container_->transferFusion(&b, &a);
}
Comment on lines +131 to 137
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Missing sharing_fusions_ update when one container is null

When only one Fusion has a container (e.g., a has a container and b doesn't), transferFusion is skipped because the guard on line 133 requires both containers to be non-null. After the container pointer swap on line 140, the container's sharing_fusions_ set will still reference the old Fusion pointer instead of the new owner.

For example, if a has container C and b has null:

  1. transferFusion is skipped (line 133 fails)
  2. After swap: a.ir_container_ = null, b.ir_container_ = C
  3. C's sharing_fusions_ still contains &a instead of &b
  4. inContainer() checks would fail for b's statements

While this is likely unreachable in practice (all constructors initialize ir_container_), the defensive code path on lines 189-192 for transferStatementOwnership suggests this case was considered. If so, it should also handle transferFusion:

  if (a.ir_container_ && b.ir_container_ &&
      a.ir_container_.get() != b.ir_container_.get()) {
    a.ir_container_->transferFusion(&a, &b);
    b.ir_container_->transferFusion(&b, &a);
  } else if (a.ir_container_ && !b.ir_container_) {
    a.ir_container_->transferFusion(&a, &b);
  } else if (b.ir_container_ && !a.ir_container_) {
    b.ir_container_->transferFusion(&b, &a);
  }


// Swap container pointers
std::swap(a.ir_container_, b.ir_container_);

// Swap all Fusion-level members
std::swap(a.inputs_, b.inputs_);
std::swap(a.outputs_, b.outputs_);

std::swap(a.io_alias_, b.io_alias_);

// Swap per-Fusion special values (Phase 2)
std::swap(a.all_tv_uses_valid_, b.all_tv_uses_valid_);
std::swap(a.is_during_update_uses_, b.is_during_update_uses_);
std::swap(a.managed_data_, b.managed_data_);
std::swap(a.managed_named_data_, b.managed_named_data_);
std::swap(a.expected_dynamic_smem_bytes_, b.expected_dynamic_smem_bytes_);
std::swap(a.all_tvs_ptr_, b.all_tvs_ptr_);
std::swap(a.zero_val_, b.zero_val_);
std::swap(a.one_val_, b.one_val_);
std::swap(a.true_val_, b.true_val_);
std::swap(a.false_val_, b.false_val_);
std::swap(a.magic_zero_val_, b.magic_zero_val_);

std::swap(a.axioms_, b.axioms_);
std::swap(a.metadata_, b.metadata_);
std::swap(a.val_type_name_map_, b.val_type_name_map_);
std::swap(a.expr_name_counter_, b.expr_name_counter_);

// Update Statement::ir_container_ pointers: a's old statements now belong
// to b, and b's old statements now belong to a
for (auto* val : a_owned_vals) {
val->ir_container_ = &b;
}
for (auto* expr : a_owned_exprs) {
expr->ir_container_ = &b;
}
for (auto* val : b_owned_vals) {
val->ir_container_ = &a;
}
for (auto* expr : b_owned_exprs) {
expr->ir_container_ = &a;
}

// Update per-Fusion tracking keys in containers
if (a.ir_container_ && b.ir_container_) {
if (a.ir_container_.get() == b.ir_container_.get()) {
// Same container: directly swap per-Fusion tracking entries
auto* c = a.ir_container_.get();
std::swap(c->per_fusion_vals_[&a], c->per_fusion_vals_[&b]);
std::swap(c->per_fusion_exprs_[&a], c->per_fusion_exprs_[&b]);
} else {
// Different containers: rename tracking keys to match new owners
a.ir_container_->transferStatementOwnership(&b, &a);
b.ir_container_->transferStatementOwnership(&a, &b);
}
} else if (a.ir_container_) {
a.ir_container_->transferStatementOwnership(&b, &a);
} else if (b.ir_container_) {
b.ir_container_->transferStatementOwnership(&a, &b);
}
}

std::unique_ptr<SegmentedFusion> Fusion::segment(
Expand All @@ -158,10 +202,30 @@ std::unique_ptr<SegmentedFusion> Fusion::segment(
IrCloner Fusion::copy(const Fusion* from, Fusion* to) {
to->clear();

auto ir_cloner =
IrContainer::copy(from->ir_container(), to->ir_container(), to);
IrCloner ir_cloner(to);

// Clone from's vals in insertion order
for (auto val : from->deterministic_vals()) {
ir_cloner.clone(val);
}

// Sync per-Fusion name counters from source to dest.
// During cloning, registerVal increments the dest Fusion's counter for each
// val, then IrBuilder::clone overrides the name with setName(src->name()).
// If source names are non-sequential (e.g., {0..10, 22..27} from segmenter
// creating intermediate TVs), the dest counter ends up at N (number of vals)
// instead of max(name)+1. Copying the source's counter state ensures new
// vals created post-copy won't collide with existing names.
to->val_type_name_map_ = from->val_type_name_map_;
to->expr_name_counter_ = from->expr_name_counter_;

// Wire up definitions and uses on cloned vals
for (auto val : from->vals()) {
ir_cloner.clone(val)->setDefinition(ir_cloner.clone(val->definition_));
ir_cloner.clone(val)->setUses(ir_cloner.clone(val->uses_));
}

// Remap cached special val pointers through the cloner
// Remap cached special val pointers
if (from->zero_val_) {
to->zero_val_ = ir_cloner.clone(from->zero_val_);
}
Expand All @@ -179,11 +243,6 @@ IrCloner Fusion::copy(const Fusion* from, Fusion* to) {
ir_cloner.clone(from->magic_zero_val_)->as<NamedScalar>();
}

for (auto val : from->vals()) {
ir_cloner.clone(val)->setDefinition(ir_cloner.clone(val->definition_));
ir_cloner.clone(val)->setUses(ir_cloner.clone(val->uses_));
}

to->inputs_ = ir_cloner.clone(from->inputs_);
to->outputs_ = ir_cloner.clone(from->outputs_);
for (auto inp : to->inputs_) {
Expand All @@ -193,7 +252,6 @@ IrCloner Fusion::copy(const Fusion* from, Fusion* to) {
out->setIsFusionOutput(true);
}

// TODO: put this into ir_cloner instead
for (Val* out : from->outputs_) {
const AliasInfo& alias = from->io_alias_.get(out);
if (alias.type == AllocationType::New) {
Expand All @@ -206,14 +264,12 @@ IrCloner Fusion::copy(const Fusion* from, Fusion* to) {
}

to->all_tv_uses_valid_ = from->all_tv_uses_valid_;
// This should never be true on copy, but copying for completeness.
to->is_during_update_uses_ = from->is_during_update_uses_;

for (const auto& i : from->managed_data_) {
if (i.first.has_value()) {
to->managed_data_.emplace_back(i.second(ir_cloner, i.first), i.second);
} else {
// Don't clone managed data if it has been reset
to->managed_data_.emplace_back(i.first, i.second);
}
}
Expand Down Expand Up @@ -256,9 +312,10 @@ Fusion::Fusion() : ir_container_(std::make_shared<IrContainer>()) {
ir_container_->addFusion(this);
}

// Copy constructor
Fusion::Fusion(const Fusion& other) : Fusion() {
// Copy constructor -- shares the source's container
Fusion::Fusion(const Fusion& other) : ir_container_(other.ir_container_) {
FUSER_PERF_SCOPE("Fusion copy");
ir_container_->addFusion(this);
Fusion::copy(&other, this);
}

Expand All @@ -278,6 +335,9 @@ Fusion& Fusion::operator=(const Fusion& other) {

Fusion& Fusion::operator=(Fusion&& other) noexcept {
FUSER_PERF_SCOPE("Fusion move assign");
if (this == &other) {
return *this;
}
clear();
swap(*this, other);
return *this;
Expand Down Expand Up @@ -317,6 +377,9 @@ void Fusion::clear() noexcept {
axioms_.reset();
metadata_.clear();

val_type_name_map_.clear();
expr_name_counter_ = 0;

invalidateTvsAndUses();

is_during_update_uses_ = false;
Expand Down Expand Up @@ -932,7 +995,7 @@ void Fusion::registerVal(Val* val) {
c->vals_up_.emplace_back(val);
c->vals_.insert(val);
c->per_fusion_vals_[this].insert(val);
val->setName(IrContainerPasskey(), c->getValName(val->vtype()));
val->setName(IrContainerPasskey(), getValName(val->vtype()));
}

void Fusion::registerExpr(Expr* expr) {
Expand All @@ -949,7 +1012,7 @@ void Fusion::registerExpr(Expr* expr) {
c->exprs_up_.emplace_back(expr);
c->exprs_.insert(expr);
c->per_fusion_exprs_[this].insert(expr);
expr->setName(IrContainerPasskey(), c->getExprName());
expr->setName(IrContainerPasskey(), getExprName());

for (Val* input : expr->inputs()) {
assertInContainer(input, "Input to expr is invalid, ");
Expand Down
18 changes: 18 additions & 0 deletions csrc/fusion.h
Original file line number Diff line number Diff line change
Expand Up @@ -647,6 +647,24 @@ class NVF_API Fusion : public PolymorphicBase {
std::unique_ptr<std::vector<Val*>> axioms_;

std::unordered_map<Val*, std::pair<Val*, Expr*>> metadata_;

// Per-Fusion name counters. Each Fusion independently tracks name assignment
// so that cloned Fusions get matching names (T0→T0) regardless of whether
// they share an IrContainer. This is required by downstream consumers that
// use tv->name() as a map key (alias_memory, GreedyParams, etc.).
std::unordered_map<ValType, StmtNameType> val_type_name_map_;
StmtNameType expr_name_counter_ = 0;

StmtNameType getValName(ValType vtype) {
if (val_type_name_map_.find(vtype) == val_type_name_map_.end()) {
val_type_name_map_[vtype] = 0;
}
return val_type_name_map_[vtype]++;
}
Comment on lines +658 to +663
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Consider using try_emplace or direct operator[]

Minor style nit: the find + operator[] pattern can be simplified. Since operator[] on unordered_map default-constructs the value (0 for StmtNameType) if the key is missing, the find check is redundant:

Suggested change
StmtNameType getValName(ValType vtype) {
if (val_type_name_map_.find(vtype) == val_type_name_map_.end()) {
val_type_name_map_[vtype] = 0;
}
return val_type_name_map_[vtype]++;
}
StmtNameType getValName(ValType vtype) {
return val_type_name_map_[vtype]++;
}

This produces identical behavior because StmtNameType (which is unsigned int) is value-initialized to 0 when default-constructed by operator[].

Note: If this suggestion doesn't match your team's coding style, reply to this and let me know. I'll remember it for next time!


StmtNameType getExprName() {
return expr_name_counter_++;
}
};

// Template implementations for Fusion::manage<T>() that use IrCloner
Expand Down