Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 43 additions & 7 deletions packages/reflex-base/src/reflex_base/components/memo.py
Original file line number Diff line number Diff line change
Expand Up @@ -278,6 +278,9 @@ class MemoComponentDefinition(MemoDefinition):

export_name: str
_component: _LazyBody[Component]
_runtime_param_values: dict[str, Any] = dataclasses.field(
default_factory=dict, repr=False, compare=False
)
# For passthrough wrappers built by the auto-memoize plugin: the
# ``Bare``-wrapped ``{children}`` placeholder used when rendering the memo
# body. The ``component`` keeps its ORIGINAL children so compile-time
Expand Down Expand Up @@ -724,7 +727,11 @@ def _rest_placeholder(name: str) -> RestProp:
return RestProp(_js_expr=name, _var_type=dict[str, Any])


def _var_placeholder(name: str, annotation: Any) -> Var:
def _var_placeholder(
name: str,
annotation: Any,
runtime_value: Any | None = None,
) -> Var:
"""Create a placeholder Var for a memo parameter.

Args:
Expand All @@ -734,6 +741,11 @@ def _var_placeholder(name: str, annotation: Any) -> Var:
Returns:
The placeholder Var.
"""
if _annotation_inner_type(annotation) is Any and runtime_value is not None:
runtime_type = (
runtime_value._var_type if isinstance(runtime_value, Var) else type(runtime_value)
)
return Var(_js_expr=name, _var_type=runtime_type).guess_type()
return Var(_js_expr=name, _var_type=_annotation_inner_type(annotation)).guess_type()


Expand Down Expand Up @@ -1001,6 +1013,7 @@ def finalize(
def _evaluate_memo_function(
fn: Callable[..., Any],
params: tuple[MemoParam, ...],
runtime_values: Mapping[str, Any] | None = None,
) -> Any:
"""Evaluate a memo function with placeholder vars.

Expand All @@ -1015,7 +1028,14 @@ def _evaluate_memo_function(
keyword_args = {}

for param in params:
placeholder = param.make_placeholder()
if param.kind is MemoParamKind.VALUE:
placeholder = _var_placeholder(
param.placeholder_name,
param.annotation,
runtime_values.get(param.name) if runtime_values is not None else None,
)
else:
placeholder = param.make_placeholder()
if param.parameter_kind in (
inspect.Parameter.POSITIONAL_ONLY,
inspect.Parameter.POSITIONAL_OR_KEYWORD,
Expand Down Expand Up @@ -1267,7 +1287,9 @@ def _build_args_function(


def _evaluate_component_body(
fn: Callable[..., Any], params: tuple[MemoParam, ...]
fn: Callable[..., Any],
params: tuple[MemoParam, ...],
runtime_values: Mapping[str, Any] | None = None,
) -> Component:
"""Run a component memo's body and return its compiled component.

Expand All @@ -1281,7 +1303,9 @@ def _evaluate_component_body(
Raises:
TypeError: If the body does not return a component.
"""
body = _normalize_component_return(_evaluate_memo_function(fn, params))
body = _normalize_component_return(
_evaluate_memo_function(fn, params, runtime_values)
)
if body is None:
msg = (
f"Component-returning `@rx.memo` `{fn.__name__}` must return an "
Expand Down Expand Up @@ -1325,12 +1349,16 @@ def _create_component_definition(
TypeError: If the function does not return a component.
"""
params = _analyze_params(fn, for_component=True)
runtime_param_values: dict[str, Any] = {}
return MemoComponentDefinition(
fn=fn,
python_name=fn.__name__,
params=params,
export_name=format.to_title_case(fn.__name__),
_component=_LazyBody.ready(_evaluate_component_body(fn, params)),
_component=_LazyBody(
lambda: _evaluate_component_body(fn, params, runtime_param_values)
),
_runtime_param_values=runtime_param_values,
)


Expand Down Expand Up @@ -1593,8 +1621,14 @@ def __call__(self, *children: Any, **props: Any) -> MemoComponent:

# Reading ``component`` materializes the deferred body, so ``type(...)``
# reflects the real wrapped class rather than the placeholder.
definition._runtime_param_values.clear()
definition._runtime_param_values.update(explicit_values)
try:
component_type = type(definition.component)
finally:
definition._runtime_param_values.clear()
return _get_memo_component_class(
definition.export_name, type(definition.component)
definition.export_name, component_type
)._create(
children=list(children),
memo_definition=definition,
Expand Down Expand Up @@ -1881,15 +1915,17 @@ def memo(fn: Callable[..., Any]) -> _MemoComponentWrapper | _MemoFunctionWrapper
# where the name resolves to ``wrapper`` (already bound by first use).
definition: MemoComponentDefinition | MemoFunctionDefinition
if is_component:
runtime_param_values: dict[str, Any] = {}
definition = MemoComponentDefinition(
fn=fn,
python_name=fn.__name__,
params=params,
export_name=format.to_title_case(fn.__name__),
_component=_LazyBody(
lambda: _evaluate_component_body(fn, params),
lambda: _evaluate_component_body(fn, params, runtime_param_values),
placeholder=Fragment.create(),
),
_runtime_param_values=runtime_param_values,
)
wrapper = _create_component_wrapper(definition)
else:
Expand Down
36 changes: 36 additions & 0 deletions tests/units/components/test_memo.py
Original file line number Diff line number Diff line change
Expand Up @@ -520,6 +520,42 @@ def soft_missing(value) -> rx.Component:
assert "`value`" in kwargs["reason"]


def test_memo_uses_first_call_value_type_for_missing_param_annotation():
"""Component memos should infer missing parameter types from the first call."""

@rx.memo
def user_card(user) -> rx.Component:
return rx.box(
rx.heading(user["name"]),
rx.text(user["email"]),
)
Comment thread
harsh21234i marked this conversation as resolved.

component = user_card(
user={"name": "Ada", "email": "ada@example.com"},
)

assert isinstance(component, MemoComponent)


def test_memo_uses_var_runtime_value_type_for_missing_param_annotation():
"""Component memos should infer missing parameter types from runtime Vars."""

@rx.memo
def user_card(user) -> rx.Component:
return rx.box(
rx.heading(user["name"]),
rx.text(user["email"]),
)

component = user_card(
user=Var.create({"name": "Ada", "email": "ada@example.com"}),
)

assert isinstance(component, MemoComponent)
assert isinstance(component.user, Var)
assert component.user._var_type is dict


def test_memo_warns_on_missing_return_annotation():
"""A missing return annotation should default to ``rx.Component`` with a warning."""
with patch.object(console, "deprecate") as mock_deprecate:
Comment thread
harsh21234i marked this conversation as resolved.
Expand Down