Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 58 additions & 0 deletions src/std/trait/init.luau
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
--- A `Trait` carries a trait's default-method implementations (`D`). It does
--- *not* carry the required-method contract: the contract is supplied directly
--- at each `impl` call (see below). Keeping the contract out of the `Trait`
--- value - and out of every implementer's type - is the key fix over earlier
--- attempts, where intersecting a `Requires` contract into a class' public
--- type re-declared methods the class already had. Luau then saw each such
--- method as an *overloaded* function (`(A) -> R & (B) -> R`) and refused to
--- call it: "Calling function ... is ambiguous". Here the contract is only
--- ever a *constraint* on `impl`'s argument; it never reaches a class type.
export type Trait<D> = {
read defaults: D;
}

--- Defines a trait from its table of default methods. `D` is inferred from
--- `defaults`, so no turbofish is needed:
---
--- return trait.define(defaults)
local function define<D>(defaults: D): Trait<D>
return table.freeze({
defaults = defaults;
})
end

--- Implements `t` for `methods` (a class' `__index` method table): copies the
--- trait's default methods into `methods` and returns it retyped as `M & D`.
--- Defaults never overwrite a method the class already defines - the
--- implementer's own definitions always win.
---
--- The parameter type `methods: M & R` is what enforces the required-method
--- contract: `R` is the trait's `Requires` type, and if `methods` is missing a
--- required method it is not assignable to `M & R`, which is a type error *at
--- this call*. `M` and `R` must both be supplied via turbofish (`M` cannot be
--- recovered from the `M & R` intersection of a single value; `R` is the
--- contract being asserted). `D` is inferred from `t`:
---
--- local Impl = trait.impl<<typeof(methods), MyTrait.Requires>>(methods, MyTrait)
---
--- For a generic trait, pass the contract instantiated at the concrete element
--- type for a fully precise check:
---
--- local Impl = trait.impl<<typeof(methods), Iter.Requires<number>>>(methods, IterTrait)
local function impl<M, R, D>(methods: M & R, t: Trait<D>): M & D
local target: any = methods
local defaults: any = t.defaults

for key, default_fn in defaults do
if target[key] == nil then
target[key] = default_fn
end
end

return target
end

return table.freeze({
define = define;
impl = impl;
})
75 changes: 75 additions & 0 deletions tests/std/trait/.check.luau
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
--!nolint LocalUnused

local trait = require("../../../src/std/trait")

-- A trait, defined the convention way: the `Requires` / `Defaults` / `For`
-- triple, and `define` with no turbofish.
type DisplayRequires<Self> = {
display: (self: Self) -> string;
}

local display_defaults = {}

-- Default method: `self` is `DisplayFor<Self>` because the body calls a trait
-- method (`self:display()`).
function display_defaults.print_display<Self>(self: DisplayFor<Self>): ()
print(self:display())
end

type DisplayDefaults = typeof(display_defaults)
type DisplayFor<Self> = DisplayRequires<Self> & DisplayDefaults

local Display = trait.define(display_defaults)

-- `define` infers `D` from the defaults table; the result is `Trait<D>`.
local display_trait: trait.Trait<DisplayDefaults> = Display

-- An implementer, built the convention way: a method table whose name doubles
-- as the type name, `__index :: {}`, public type
-- `setmetatable<...> & Trait.For<...>`, and a constructor ending in `:: T`.
local Point = {}
local PointPrototype = table.freeze({
__index = Point :: {};
})

function Point.display(self: Point): string
return `({ self.x }, { self.y })`
end

trait.impl<<typeof(Point), DisplayRequires<Point>>>(Point, Display)

type Point = setmetatable<{
read x: number;
read y: number;
}, typeof(PointPrototype)> & DisplayFor<Point>

local function new_point(x: number, y: number): Point
return table.freeze(setmetatable({
x = x;
y = y;
}, PointPrototype)) :: Point
end

-- The implementer exposes both the required method and the copied-in default
-- method off its single public type.
local origin: Point = new_point(0, 0)
local _origin_display: string = origin:display()
origin:print_display()

-- A class missing the required `display` is not assignable to `M & Requires`,
-- so the `impl` call would be a type error:
local bad_methods = {}
function bad_methods.unrelated(self: typeof(bad_methods)): number
return 0
end

-- trait.impl<<typeof(bad_methods), DisplayRequires<typeof(bad_methods)>>>(bad_methods, Display) -- should error

-- A wrongly-typed required method (returns `number`, not `string`) fails the
-- same contract check:
local wrong_methods = {}
function wrong_methods.display(self: typeof(wrong_methods)): number
return 0
end

-- trait.impl<<typeof(wrong_methods), DisplayRequires<typeof(wrong_methods)>>>(wrong_methods, Display) -- should error
98 changes: 98 additions & 0 deletions tests/std/trait/.spec.luau
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
local test = require("@std/test")

local trait = require("../../../src/std/trait")

--- A small trait, defined the convention way: one required method (`name`) and
--- one default method (`greeting`) that dispatches through it. `greeting`'s
--- `self` is `GreetFor<Self>` because the body calls a trait method.
type GreetRequires<Self> = {
name: (self: Self) -> string;
}

local greet_defaults = {}

function greet_defaults.greeting<Self>(self: GreetFor<Self>): string
return `Hello, { self:name() }`
end

type GreetDefaults = typeof(greet_defaults)
type GreetFor<Self> = GreetRequires<Self> & GreetDefaults

local Greet = trait.define(greet_defaults)

test.suite("std.trait", function(suite)
suite:case(".define exposes the defaults table", function(asserts)
asserts.eq(Greet.defaults, greet_defaults)
asserts.eq(type(Greet.defaults.greeting), "function")
end)

suite:case(".define returns a frozen trait", function(asserts)
asserts.eq(table.isfrozen(Greet :: any), true)
end)

suite:case(".impl copies default methods into the method table", function(asserts)
local methods = {}
function methods.name(self: typeof(methods)): string
return "world"
end

trait.impl<<typeof(methods), GreetRequires<typeof(methods)>>>(methods, Greet)

asserts.eq(type((methods :: any).greeting), "function")
end)

suite:case(".impl returns the same method table", function(asserts)
local methods = {}
function methods.name(self: typeof(methods)): string
return "world"
end

local returned = trait.impl<<typeof(methods), GreetRequires<typeof(methods)>>>(methods, Greet)

asserts.eq(returned, methods)
end)

suite:case(".impl does not overwrite a method the class already defines", function(asserts)
local methods = {}
function methods.name(self: typeof(methods)): string
return "world"
end
-- the class defines its own `greeting`, shadowing the trait default
function methods.greeting(self: typeof(methods)): string
return "custom"
end

trait.impl<<typeof(methods), GreetRequires<typeof(methods)>>>(methods, Greet)

asserts.eq((methods :: any):greeting(), "custom")
end)

suite:case("a default method dispatches through the implementer's required method", function(asserts)
-- A convention-shaped implementer: a method table whose name doubles as
-- the type name, with `__index :: {}` so methods have a single type
-- source, and a public type `setmetatable<...> & For<...>`.
local Person = {}
local PersonPrototype = table.freeze({
__index = Person :: {};
})

function Person.name(self: Person): string
return self.who
end

trait.impl<<typeof(Person), GreetRequires<Person>>>(Person, Greet)

type Person = setmetatable<{
read who: string;
}, typeof(PersonPrototype)> & GreetFor<Person>

local alice: Person = table.freeze(setmetatable({
who = "Alice";
}, PersonPrototype)) :: Person

asserts.eq(alice:name(), "Alice")
asserts.eq(alice:greeting(), "Hello, Alice")
end)
end)

return nil