diff --git a/custom_components/pyscript/__init__.py b/custom_components/pyscript/__init__.py index 79e5fc5..1cf5a6f 100644 --- a/custom_components/pyscript/__init__.py +++ b/custom_components/pyscript/__init__.py @@ -1,6 +1,7 @@ """Component to allow running Python scripts.""" import asyncio +from collections.abc import Awaitable, Callable import glob import json import logging @@ -8,7 +9,7 @@ import shutil import time import traceback -from typing import Any, Callable, Dict, List, Set, Union +from typing import Any import voluptuous as vol from watchdog.events import DirModifiedEvent, FileSystemEvent, FileSystemEventHandler @@ -22,7 +23,7 @@ EVENT_STATE_CHANGED, SERVICE_RELOAD, ) -from homeassistant.core import Event as HAEvent, HomeAssistant, ServiceCall +from homeassistant.core import Event as HAEvent, HomeAssistant, ServiceCall, SupportsResponse from homeassistant.exceptions import HomeAssistantError import homeassistant.helpers.config_validation as cv from homeassistant.helpers.restore_state import DATA_RESTORE_STATE @@ -40,7 +41,6 @@ REQUIREMENTS_FILE, SERVICE_GENERATE_STUBS, SERVICE_JUPYTER_KERNEL_START, - SERVICE_RESPONSE_ONLY, UNSUB_LISTENERS, WATCHDOG_TASK, ) @@ -98,7 +98,7 @@ async def update_yaml_config(hass: HomeAssistant, config_entry: ConfigEntry) -> conf = await async_hass_config_yaml(hass) except HomeAssistantError as err: _LOGGER.error(err) - return + return False config = PYSCRIPT_SCHEMA(conf.get(DOMAIN, {})) @@ -128,7 +128,7 @@ async def update_yaml_config(hass: HomeAssistant, config_entry: ConfigEntry) -> return False -def start_global_contexts(global_ctx_only: str = None) -> None: +def start_global_contexts(global_ctx_only: str | None = None) -> None: """Start all the file and apps global contexts.""" start_list = [] for global_ctx_name, global_ctx in GlobalContextMgr.items(): @@ -145,7 +145,9 @@ def start_global_contexts(global_ctx_only: str = None) -> None: async def watchdog_start( - hass: HomeAssistant, pyscript_folder: str, reload_scripts_handler: Callable[[None], None] + hass: HomeAssistant, + pyscript_folder: str, + reload_scripts_handler: Callable[[ServiceCall], Awaitable[None]], ) -> None: """Start watchdog thread to look for changed files in pyscript_folder.""" if WATCHDOG_TASK in hass.data[DOMAIN]: @@ -201,7 +203,7 @@ def on_deleted(self, event: FileSystemEvent) -> None: self.process(event) async def task_watchdog(watchdog_q: asyncio.Queue) -> None: - def check_event(event, do_reload: bool) -> bool: + def check_event(event: FileSystemEvent, do_reload: bool) -> bool: """Check if event should trigger a reload.""" if event.is_directory: # don't reload if it's just a directory modified @@ -230,7 +232,7 @@ def check_event(event, do_reload: bool) -> bool: do_reload = check_event( await asyncio.wait_for(watchdog_q.get(), timeout=0.05), do_reload ) - except asyncio.TimeoutError: + except TimeoutError: break if do_reload: await reload_scripts_handler(None) @@ -304,14 +306,14 @@ async def reload_scripts_handler(call: ServiceCall) -> None: hass.services.async_register(DOMAIN, SERVICE_RELOAD, reload_scripts_handler) - async def generate_stubs_service(call: ServiceCall) -> Dict[str, Any]: + async def generate_stubs_service(call: ServiceCall) -> dict[str, Any]: """Generate pyscript IDE stub files.""" generator = StubsGenerator(hass) generated_body = await generator.build() stubs_path = os.path.join(hass.config.path(FOLDER), "modules", "stubs") - def write_stubs(path) -> dict[str, Any]: + def write_stubs(path: str) -> dict[str, Any]: res: dict[str, Any] = {} try: os.makedirs(path, exist_ok=True) @@ -342,7 +344,7 @@ def write_stubs(path) -> dict[str, Any]: return result hass.services.async_register( - DOMAIN, SERVICE_GENERATE_STUBS, generate_stubs_service, supports_response=SERVICE_RESPONSE_ONLY + DOMAIN, SERVICE_GENERATE_STUBS, generate_stubs_service, supports_response=SupportsResponse.ONLY ) async def jupyter_kernel_start(call: ServiceCall) -> None: @@ -443,7 +445,7 @@ async def async_unload_entry(hass: HomeAssistant, config_entry: ConfigEntry) -> return True -async def unload_scripts(global_ctx_only: str = None, unload_all: bool = False) -> None: +async def unload_scripts(global_ctx_only: str | None = None, unload_all: bool = False) -> None: """Unload all scripts from GlobalContextMgr with given name prefixes.""" ctx_delete = {} for global_ctx_name, global_ctx in GlobalContextMgr.items(): @@ -462,7 +464,9 @@ async def unload_scripts(global_ctx_only: str = None, unload_all: bool = False) @bind_hass -async def load_scripts(hass: HomeAssistant, config_data: Dict[str, Any], global_ctx_only: str = None): +async def load_scripts( + hass: HomeAssistant, config_data: dict[str, Any], global_ctx_only: str | None = None +) -> None: """Load all python scripts in FOLDER.""" class SourceFile: @@ -496,8 +500,8 @@ def __init__( pyscript_dir = hass.config.path(FOLDER) def glob_read_files( - load_paths: List[Set[Union[str, bool]]], apps_config: Dict[str, Any] - ) -> Dict[str, SourceFile]: + load_paths: list[tuple[str, str, bool, bool]], apps_config: dict[str, Any] + ) -> dict[str, SourceFile]: """Expand globs and read all the source files.""" ctx2source = {} for path, match, check_config, autoload in load_paths: diff --git a/custom_components/pyscript/const.py b/custom_components/pyscript/const.py index 5800f78..d4d8981 100644 --- a/custom_components/pyscript/const.py +++ b/custom_components/pyscript/const.py @@ -1,20 +1,5 @@ """Define pyscript-wide constants.""" -# -# 2023.7 supports service response; handle older versions by defaulting enum -# Should eventually deprecate this and just use SupportsResponse import -# -try: - from homeassistant.core import SupportsResponse - - SERVICE_RESPONSE_NONE = SupportsResponse.NONE - SERVICE_RESPONSE_OPTIONAL = SupportsResponse.OPTIONAL - SERVICE_RESPONSE_ONLY = SupportsResponse.ONLY -except ImportError: - SERVICE_RESPONSE_NONE = None - SERVICE_RESPONSE_OPTIONAL = None - SERVICE_RESPONSE_ONLY = None - DOMAIN = "pyscript" CONFIG_ENTRY = "config_entry" diff --git a/custom_components/pyscript/eval.py b/custom_components/pyscript/eval.py index 780a1f8..e463f13 100644 --- a/custom_components/pyscript/eval.py +++ b/custom_components/pyscript/eval.py @@ -13,11 +13,13 @@ import sys import time import traceback +from typing import TYPE_CHECKING, Any import weakref import yaml from homeassistant.const import SERVICE_RELOAD +from homeassistant.core import SupportsResponse from homeassistant.helpers.service import async_set_service_schema from .const import ( @@ -27,11 +29,15 @@ DOMAIN, LOGGER_PATH, SERVICE_JUPYTER_KERNEL_START, - SERVICE_RESPONSE_NONE, ) from .function import Function from .state import State +if TYPE_CHECKING: + from .global_ctx import GlobalContext + +type SymTable = dict[str, Any] + _LOGGER = logging.getLogger(LOGGER_PATH + ".eval") # @@ -315,7 +321,14 @@ def getattr(self): class EvalFunc: """Class for a callable pyscript function.""" - def __init__(self, func_def, code_list, code_str, global_ctx, async_func=False): + def __init__( + self, + func_def: ast.FunctionDef, + code_list: list[str], + code_str: str, + global_ctx: "GlobalContext", + async_func: bool = False, + ) -> None: """Initialize a function calling context.""" self.func_def = func_def self.name = func_def.name @@ -533,7 +546,7 @@ async def do_service_call(func, ast_ctx, data): domain, name, pyscript_service_factory(func_name, self), - dec_kwargs.get("supports_response", SERVICE_RESPONSE_NONE), + dec_kwargs.get("supports_response", SupportsResponse.NONE), ) async_set_service_schema(Function.hass, domain, name, service_desc) self.trigger_service.add(srv_name) @@ -838,12 +851,12 @@ async def check_for_closure(self, arg): class EvalFuncVar: """Class for a callable pyscript function.""" - def __init__(self, func): + def __init__(self, func: EvalFunc) -> None: """Initialize instance with given EvalFunc function.""" self.func = func - self.ast_ctx = None + self.ast_ctx: AstEval | None = None - def get_func(self): + def get_func(self) -> EvalFunc: """Return the EvalFunc function.""" return self.func @@ -895,7 +908,7 @@ async def __call__(self, *args, **kwargs): class EvalFuncVarClassInst(EvalFuncVar): """Class for a callable pyscript class instance function.""" - def __init__(self, func, ast_ctx, class_inst_weak): + def __init__(self, func: EvalFunc, ast_ctx: "AstEval", class_inst_weak: weakref.ReferenceType) -> None: """Initialize instance with given EvalFunc function.""" super().__init__(func) self.ast_ctx = ast_ctx @@ -913,25 +926,25 @@ async def __call__(self, *args, **kwargs): class AstEval: """Python interpreter AST object evaluator.""" - def __init__(self, name, global_ctx, logger_name=None): + def __init__(self, name: str, global_ctx: "GlobalContext", logger_name: str | None = None) -> None: """Initialize an interpreter execution context.""" self.name = name self.str = None self.ast = None self.global_ctx = global_ctx - self.global_sym_table = global_ctx.get_global_sym_table() if global_ctx else {} - self.sym_table_stack = [] + self.global_sym_table: SymTable = global_ctx.get_global_sym_table() if global_ctx else {} + self.sym_table_stack: list[SymTable] = [] self.sym_table = self.global_sym_table - self.local_sym_table = {} - self.user_locals = {} - self.curr_func = None + self.local_sym_table: SymTable = {} + self.user_locals: SymTable = {} + self.curr_func: EvalFunc | None = None self.filename = name - self.code_str = None - self.code_list = None - self.exception = None - self.exception_obj = None - self.exception_long = None - self.exception_curr = None + self.code_str: str | None = None + self.code_list: list[str] | None = None + self.exception: str | None = None + self.exception_obj: Exception | None = None + self.exception_long: str | None = None + self.exception_curr: Exception | None = None self.lineno = 1 self.col_offset = 0 self.logger_handlers = set() @@ -2159,7 +2172,7 @@ async def get_names(self, this_ast=None, nonlocal_names=None, global_names=None, await self.get_names_set(this_ast, names, nonlocal_names, global_names, local_names) return names - def parse(self, code_str, filename=None, mode="exec"): + def parse(self, code_str: str, filename: str | None = None, mode: str = "exec") -> bool: """Parse the code_str source code into an AST tree.""" self.exception = None self.exception_obj = None @@ -2298,7 +2311,7 @@ def completions(self, root): for attr in var.__dict__: if attr.lower().startswith(attr_root) and (attr_root != "" or attr[0:1] != "_"): words.add(f"{name}.{attr}") - except Exception: + except Exception: # noqa: S110 pass for keyw in set(keyword.kwlist) - {"yield"}: if keyw.lower().startswith(root): @@ -2313,7 +2326,7 @@ def completions(self, root): words.add(name) return words - async def eval(self, new_state_vars=None, merge_local=False): + async def eval(self, new_state_vars: dict[str, Any] | None = None, merge_local: bool = False) -> None: """Execute parsed code, with the optional state variables added to the scope.""" self.exception = None self.exception_obj = None diff --git a/custom_components/pyscript/function.py b/custom_components/pyscript/function.py index 02388a3..0c55e52 100644 --- a/custom_components/pyscript/function.py +++ b/custom_components/pyscript/function.py @@ -1,12 +1,15 @@ """Function call handling.""" import asyncio +from asyncio import Task +from collections.abc import Callable import logging import traceback +from typing import ClassVar -from homeassistant.core import Context +from homeassistant.core import Context, SupportsResponse -from .const import LOGGER_PATH, SERVICE_RESPONSE_NONE, SERVICE_RESPONSE_ONLY +from .const import LOGGER_PATH _LOGGER = logging.getLogger(LOGGER_PATH + ".function") @@ -22,34 +25,34 @@ class Function: # # Mappings of tasks ids <-> task names # - unique_task2name = {} - unique_name2task = {} + unique_task2name: ClassVar[dict[Task, set[str]]] = {} + unique_name2task: ClassVar[dict[str, Task]] = {} # # Mappings of task id to hass contexts - task2context = {} + task2context: ClassVar[dict[Task, Context]] = {} # # Set of tasks that are running # - our_tasks = set() + our_tasks: ClassVar[set[Task]] = set() # # Done callbacks for each task # - task2cb = {} + task2cb: ClassVar[dict[Task, dict]] = {} # # initial list of available functions # - functions = {} + functions: ClassVar[dict[str, Callable]] = {} # # Functions that take the AstEval context as a first argument, # which is needed by a handful of special functions that need the # ast context # - ast_functions = {} + ast_functions: ClassVar[dict[str, Callable]] = {} # # task id of the task that cancels and waits for other tasks, @@ -68,13 +71,13 @@ class Function: # registers the service call before the old one is removed, so we only # remove the service registration when the reference count goes to zero # - service_cnt = {} + service_cnt: ClassVar[dict[str, int]] = {} # # save the global_ctx name where a service is registered so we can raise # an exception if it gets registered by a different global_ctx. # - service2global_ctx = {} + service2global_ctx: ClassVar[dict[str, str]] = {} def __init__(self): """Warn on Function instantiation.""" @@ -413,25 +416,16 @@ async def service_call(*args, **kwargs): @classmethod async def hass_services_async_call(cls, domain, service, kwargs, **hass_args): """Call a hass async service.""" - if SERVICE_RESPONSE_ONLY is None: - # backwards compatibility < 2023.7 - await cls.hass.services.async_call(domain, service, kwargs, **hass_args) - else: - # allow service responses >= 2023.7 - if ( - "return_response" in hass_args - and hass_args["return_response"] - and "blocking" not in hass_args - ): + if "return_response" in hass_args and hass_args["return_response"] and "blocking" not in hass_args: + hass_args["blocking"] = True + elif ( + "return_response" not in hass_args + and cls.hass.services.supports_response(domain, service) == SupportsResponse.ONLY + ): + hass_args["return_response"] = True + if "blocking" not in hass_args: hass_args["blocking"] = True - elif ( - "return_response" not in hass_args - and cls.hass.services.supports_response(domain, service) == SERVICE_RESPONSE_ONLY - ): - hass_args["return_response"] = True - if "blocking" not in hass_args: - hass_args["blocking"] = True - return await cls.hass.services.async_call(domain, service, kwargs, **hass_args) + return await cls.hass.services.async_call(domain, service, kwargs, **hass_args) @classmethod async def run_coro(cls, coro, ast_ctx=None): @@ -439,7 +433,7 @@ async def run_coro(cls, coro, ast_ctx=None): # # Add a placeholder for the new task so we know it's one we started # - task: asyncio.Task = None + task: asyncio.Task | None = None try: task = asyncio.current_task() cls.our_tasks.add(task) @@ -474,7 +468,7 @@ def create_task(cls, coro, ast_ctx=None): @classmethod def service_register( - cls, global_ctx_name, domain, service, callback, supports_response=SERVICE_RESPONSE_NONE + cls, global_ctx_name, domain, service, callback, supports_response=SupportsResponse.NONE ): """Register a new service callback.""" key = f"{domain}.{service}" @@ -487,12 +481,7 @@ def service_register( f"{global_ctx_name}: can't register service {key}; already defined in {cls.service2global_ctx[key]}" ) cls.service_cnt[key] += 1 - if SERVICE_RESPONSE_ONLY is None: - # backwards compatibility < 2023.7 - cls.hass.services.async_register(domain, service, callback) - else: - # allow service responses >= 2023.7 - cls.hass.services.async_register(domain, service, callback, supports_response=supports_response) + cls.hass.services.async_register(domain, service, callback, supports_response=supports_response) @classmethod def service_remove(cls, global_ctx_name, domain, service): diff --git a/custom_components/pyscript/global_ctx.py b/custom_components/pyscript/global_ctx.py index 3d382ed..746a5c3 100644 --- a/custom_components/pyscript/global_ctx.py +++ b/custom_components/pyscript/global_ctx.py @@ -1,14 +1,15 @@ """Global context handling.""" +from collections.abc import Awaitable, Callable import logging import os from types import ModuleType -from typing import Any, Callable, Dict, List, Optional, Set, Union +from typing import Any, ClassVar from homeassistant.config_entries import ConfigEntry from .const import CONF_HASS_IS_GLOBAL, CONFIG_ENTRY, DOMAIN, FOLDER, LOGGER_PATH -from .eval import AstEval, EvalFunc +from .eval import AstEval, EvalFunc, SymTable from .function import Function from .trigger import TrigInfo @@ -20,29 +21,29 @@ class GlobalContext: def __init__( self, - name, - global_sym_table: Dict[str, Any] = None, - manager=None, - rel_import_path: str = None, - app_config: Dict[str, Any] = None, - source: str = None, - mtime: float = None, + name: str, + global_sym_table: SymTable | None = None, + manager: type["GlobalContextMgr"] | None = None, + rel_import_path: str | None = None, + app_config: dict[str, Any] | None = None, + source: str | None = None, + mtime: float | None = None, ) -> None: """Initialize GlobalContext.""" self.name: str = name - self.global_sym_table: Dict[str, Any] = global_sym_table if global_sym_table else {} - self.triggers: Set[EvalFunc] = set() - self.triggers_delay_start: Set[EvalFunc] = set() + self.global_sym_table: SymTable = global_sym_table if global_sym_table else {} + self.triggers: set[EvalFunc] = set() + self.triggers_delay_start: set[EvalFunc] = set() self.logger: logging.Logger = logging.getLogger(LOGGER_PATH + "." + name) - self.manager: GlobalContextMgr = manager + self.manager = manager self.auto_start: bool = False - self.module: ModuleType = None + self.module: ModuleType | None = None self.rel_import_path: str = rel_import_path self.source: str = source - self.file_path: str = None + self.file_path: str | None = None self.mtime: float = mtime - self.app_config: Dict[str, Any] = app_config - self.imports: Set[str] = set() + self.app_config = app_config + self.imports: set[str] = set() config_entry: ConfigEntry = Function.hass.data.get(DOMAIN, {}).get(CONFIG_ENTRY, {}) if config_entry.data.get(CONF_HASS_IS_GLOBAL, False): # @@ -87,11 +88,11 @@ def get_name(self) -> str: """Return the global context name.""" return self.name - def set_logger_name(self, name) -> None: + def set_logger_name(self, name: str) -> None: """Set the global context logging name.""" self.logger = logging.getLogger(LOGGER_PATH + "." + name) - def get_global_sym_table(self) -> Dict[str, Any]: + def get_global_sym_table(self) -> dict[str, Any]: """Return the global symbol table.""" return self.global_sym_table @@ -99,7 +100,7 @@ def get_source(self) -> str: """Return the source code.""" return self.source - def get_app_config(self) -> Dict[str, Any]: + def get_app_config(self) -> dict[str, Any]: """Return the app config.""" return self.app_config @@ -111,22 +112,24 @@ def get_file_path(self) -> str: """Return the file path.""" return self.file_path - def get_imports(self) -> Set[str]: + def get_imports(self) -> set[str]: """Return the imports.""" return self.imports - def get_trig_info(self, name: str, trig_args: Dict[str, Any]) -> TrigInfo: + def get_trig_info(self, name: str, trig_args: dict[str, Any]) -> TrigInfo: """Return a new trigger info instance with the given args.""" return TrigInfo(name, trig_args, self) - async def module_import(self, module_name: str, import_level: int) -> List[Optional[str]]: + async def module_import( + self, module_name: str, import_level: int + ) -> tuple[ModuleType | None, AstEval | None]: """Import a pyscript module from the pyscript/modules or apps folder.""" pyscript_dir = Function.hass.config.path(FOLDER) module_path = module_name.replace(".", "/") file_paths = [] - def find_first_file(file_paths: List[Set[str]]) -> List[Optional[Union[str, ModuleType]]]: + def find_first_file(file_paths: list[list[str]]) -> list[str] | None: for ctx_name, path, rel_path in file_paths: abs_path = os.path.join(pyscript_dir, path) if os.path.isfile(abs_path): @@ -173,14 +176,14 @@ def find_first_file(file_paths: List[Set[str]]) -> List[Optional[Union[str, Modu mod_ctx = self.manager.get(ctx_name) if mod_ctx and mod_ctx.module: self.imports.add(mod_ctx.get_name()) - return [mod_ctx.module, None] + return mod_ctx.module, None # # not loaded already, so try to find and import it # file_info = await Function.hass.async_add_executor_job(find_first_file, file_paths) if not file_info: - return [None, None] + return None, None [ctx_name, file_path, rel_import_path] = file_info @@ -197,10 +200,10 @@ def find_first_file(file_paths: List[Set[str]]) -> List[Optional[Union[str, Modu ctx_name, file_path, ) - return [None, error_ctx] + return None, error_ctx global_ctx.module = mod self.imports.add(ctx_name) - return [mod, None] + return mod, None class GlobalContextMgr: @@ -209,12 +212,12 @@ class GlobalContextMgr: # # map of context names to contexts # - contexts = {} + contexts: ClassVar[dict[str, GlobalContext]] = {} # # sequence number for sessions # - name_seq = 0 + name_seq: ClassVar[int] = 0 def __init__(self) -> None: """Report an error if GlobalContextMgr in instantiated.""" @@ -224,29 +227,29 @@ def __init__(self) -> None: def init(cls) -> None: """Initialize GlobalContextMgr.""" - def get_global_ctx_factory(ast_ctx: AstEval) -> Callable[[], str]: + def get_global_ctx_factory(ast_ctx: AstEval) -> Callable[[], Awaitable[str]]: """Generate a pyscript.get_global_ctx() function with given ast_ctx.""" - async def get_global_ctx(): + async def get_global_ctx() -> str: return ast_ctx.get_global_ctx_name() return get_global_ctx - def list_global_ctx_factory(ast_ctx: AstEval) -> Callable[[], List[str]]: + def list_global_ctx_factory(ast_ctx: AstEval) -> Callable[[], Awaitable[list[str]]]: """Generate a pyscript.list_global_ctx() function with given ast_ctx.""" - async def list_global_ctx(): + async def list_global_ctx() -> list[str]: ctx_names = set(cls.contexts.keys()) curr_ctx_name = ast_ctx.get_global_ctx_name() ctx_names.discard(curr_ctx_name) - return [curr_ctx_name] + sorted(sorted(ctx_names)) + return [curr_ctx_name, *sorted(ctx_names)] return list_global_ctx - def set_global_ctx_factory(ast_ctx: AstEval) -> Callable[[str], None]: + def set_global_ctx_factory(ast_ctx: AstEval) -> Callable[[str], Awaitable[None]]: """Generate a pyscript.set_global_ctx() function with given ast_ctx.""" - async def set_global_ctx(name): + async def set_global_ctx(name: str) -> None: global_ctx = cls.get(name) if global_ctx is None: raise NameError(f"global context '{name}' does not exist") @@ -264,7 +267,7 @@ async def set_global_ctx(name): Function.register_ast(ast_funcs) @classmethod - def get(cls, name: str) -> Optional[str]: + def get(cls, name: str) -> GlobalContext | None: """Return the GlobalContext given a name.""" return cls.contexts.get(name, None) @@ -274,7 +277,7 @@ def set(cls, name: str, global_ctx: GlobalContext) -> None: cls.contexts[name] = global_ctx @classmethod - def items(cls) -> List[Set[Union[str, GlobalContext]]]: + def items(cls) -> list[tuple[str, GlobalContext]]: """Return all the global context items.""" return sorted(cls.contexts.items()) @@ -297,14 +300,14 @@ def new_name(cls, root: str) -> str: @classmethod async def load_file( - cls, global_ctx: GlobalContext, file_path: str, source: str = None, reload: bool = False - ) -> Set[Union[bool, AstEval]]: + cls, global_ctx: GlobalContext, file_path: str, source: str | None = None, reload: bool = False + ) -> tuple[bool, AstEval | None]: """Load, parse and run the given script file; returns error ast_ctx on error, or None if ok.""" mtime = None if source is None: - def read_file(path: str) -> Set[Union[str, float]]: + def read_file(path: str) -> tuple[str | None, float]: try: with open(path, encoding="utf-8") as file_desc: source = file_desc.read() diff --git a/custom_components/pyscript/jupyter_kernel.py b/custom_components/pyscript/jupyter_kernel.py index 9b63763..a402783 100644 --- a/custom_components/pyscript/jupyter_kernel.py +++ b/custom_components/pyscript/jupyter_kernel.py @@ -97,7 +97,7 @@ async def handshake(self): # _LOGGER.debug(f"handshake: got initial greeting {greeting}") await self.write_bytes(b"\x03") _ = await self.read_bytes(1) - await self.write_bytes(b"\x00" + "NULL".encode() + b"\x00" * 16 + b"\x00" + b"\x00" * 31) + await self.write_bytes(b"\x00" + b"NULL" + b"\x00" * 16 + b"\x00" + b"\x00" * 31) _ = await self.read_bytes(53) # _LOGGER.debug(f"handshake: got rest of greeting {greeting}") params = [["Socket-Type", self.type]] diff --git a/custom_components/pyscript/requirements.py b/custom_components/pyscript/requirements.py index 7df00e5..1d51700 100644 --- a/custom_components/pyscript/requirements.py +++ b/custom_components/pyscript/requirements.py @@ -1,9 +1,9 @@ """Requirements helpers for pyscript.""" import glob +from importlib.metadata import PackageNotFoundError, version as installed_version import logging import os -import sys from homeassistant.loader import bind_hass from homeassistant.requirements import async_process_requirements @@ -21,17 +21,6 @@ UNPINNED_VERSION, ) -if sys.version_info[:2] >= (3, 8): - from importlib.metadata import ( # pylint: disable=no-name-in-module,import-error - PackageNotFoundError, - version as installed_version, - ) -else: - from importlib_metadata import ( # pylint: disable=import-error - PackageNotFoundError, - version as installed_version, - ) - _LOGGER = logging.getLogger(LOGGER_PATH) @@ -75,7 +64,7 @@ def process_all_requirements(pyscript_folder, requirements_paths, requirements_f all_requirements_to_process = {} for root in requirements_paths: for requirements_path in glob.glob(os.path.join(pyscript_folder, root, requirements_file)): - with open(requirements_path, "r", encoding="utf-8") as requirements_fp: + with open(requirements_path, encoding="utf-8") as requirements_fp: all_requirements_to_process[requirements_path] = requirements_fp.readlines() all_requirements_to_install = {} @@ -217,10 +206,8 @@ async def install_requirements(hass, config_entry, pyscript_folder): if all_requirements and not config_entry.data.get(CONF_ALLOW_ALL_IMPORTS, False): _LOGGER.error( - ( - "Requirements detected but 'allow_all_imports' is set to False, set " - "'allow_all_imports' to True if you want packages to be installed" - ) + "Requirements detected but 'allow_all_imports' is set to False, set " + "'allow_all_imports' to True if you want packages to be installed" ) return @@ -234,10 +221,7 @@ async def install_requirements(hass, config_entry, pyscript_folder): # defer to what is installed if version_to_install == UNPINNED_VERSION: _LOGGER.debug( - ( - "Skipping unpinned version of package '%s' because version '%s' is " - "already installed" - ), + "Skipping unpinned version of package '%s' because version '%s' is already installed", package, pkg_installed_version, ) diff --git a/custom_components/pyscript/state.py b/custom_components/pyscript/state.py index 102a06f..68bfd3e 100644 --- a/custom_components/pyscript/state.py +++ b/custom_components/pyscript/state.py @@ -3,9 +3,10 @@ import asyncio from datetime import datetime import logging +from typing import Any, ClassVar, Self from homeassistant.const import STATE_UNAVAILABLE, STATE_UNKNOWN -from homeassistant.core import Context +from homeassistant.core import Context, HomeAssistant, State as CoreState from homeassistant.helpers.restore_state import DATA_RESTORE_STATE from homeassistant.helpers.service import async_get_all_descriptions from homeassistant.helpers.template import ( @@ -30,7 +31,7 @@ class StateVal(str): """Class for representing the value and attributes of a state variable.""" - def __new__(cls, state): + def __new__(cls, state: CoreState) -> Self: """Create a new instance given a state variable.""" new_var = super().__new__(cls, state.state) new_var.__dict__ = state.attributes.copy() @@ -89,12 +90,12 @@ class State: # # Global hass instance # - hass = None + hass: HomeAssistant | None = None # # notify message queues by variable # - notify = {} + notify: ClassVar[dict[str, dict[asyncio.Queue, list[str]]]] = {} # # Last value of state variable notifications. We maintain this @@ -102,22 +103,22 @@ class State: # rather than fetching the current value, which is subject to # race conditions when multiple state variables are set quickly. # - notify_var_last = {} + notify_var_last: ClassVar[dict[str, StateVal | None]] = {} # # pyscript yaml configuration # - pyscript_config = {} + pyscript_config: ClassVar[dict[str, Any]] = {} # # pyscript vars which have already been registered as persisted # - persisted_vars = {} + persisted_vars: ClassVar[dict[str, PyscriptEntity]] = {} # # other parameters of all services that have "entity_id" as a parameter # - service2args = {} + service2args: ClassVar[dict[str, dict[str, Any]]] = {} def __init__(self): """Warn on State instantiation.""" @@ -142,7 +143,7 @@ async def get_service_params(cls): cls.service2args[domain][service].discard("entity_id") @classmethod - async def notify_add(cls, var_names, queue): + async def notify_add(cls, var_names: set[str], queue: asyncio.Queue) -> bool: """Register to notify state variables changes to be sent to queue.""" added = False @@ -158,7 +159,7 @@ async def notify_add(cls, var_names, queue): return added @classmethod - def notify_del(cls, var_names, queue): + def notify_del(cls, var_names: set[str], queue: asyncio.Queue) -> None: """Unregister notify of state variables changes for given queue.""" for var_name in var_names if isinstance(var_names, set) else {var_names}: @@ -171,7 +172,7 @@ def notify_del(cls, var_names, queue): del cls.notify[state_var_name][queue] @classmethod - async def update(cls, new_vars, func_args): + async def update(cls, new_vars: dict[str, Any], func_args: dict[str, Any]) -> None: """Deliver all notifications for state variable changes.""" notify = {} diff --git a/custom_components/pyscript/stubs/generator.py b/custom_components/pyscript/stubs/generator.py index efadfea..8f5aa36 100644 --- a/custom_components/pyscript/stubs/generator.py +++ b/custom_components/pyscript/stubs/generator.py @@ -63,13 +63,13 @@ async def build(self) -> str: module_body: list[ast.stmt] = [] - imports = { + base_imports = { "typing": ["Any", "Literal"], "datetime": ["datetime"], "pyscript_builtins": [_STATE_CLASS], } - for module, imports in imports.items(): + for module, imports in base_imports.items(): module_body.append( ast.ImportFrom( module=module, @@ -127,7 +127,7 @@ async def build(self) -> str: ast.fix_missing_locations(module) return ast.unparse(module) - def _get_or_create_class(self, domain_id: str, base_class: str = None) -> ast.ClassDef: + def _get_or_create_class(self, domain_id: str, base_class: str | None = None) -> ast.ClassDef: cls = self._classes.get(domain_id) if cls is None: cls = ast.ClassDef( @@ -199,7 +199,6 @@ def process_fields(fields: dict[str, Any]) -> list[_ServiceField]: descriptions = await async_get_all_descriptions(self._hass) for domain_id, services in descriptions.items(): - domain_class = self._get_or_create_class(domain_id) for service_id, payload in services.items(): if not self._is_identifier(service_id, f"{domain_id}.{service_id}"): @@ -250,8 +249,9 @@ async def _create_service_function( required=True, default=None, description="Entity ID", - ) - ] + field_nodes + ), + *field_nodes, + ] elif def_type == "entity": args.append(ast.arg(arg="self")) diff --git a/custom_components/pyscript/stubs/pyscript_builtins.py b/custom_components/pyscript/stubs/pyscript_builtins.py index b241283..56e74f2 100644 --- a/custom_components/pyscript/stubs/pyscript_builtins.py +++ b/custom_components/pyscript/stubs/pyscript_builtins.py @@ -61,12 +61,12 @@ def state_active(str_expr: str) -> Callable[..., Any]: ... -def time_trigger(*time_spec: str | None, **kwargs) -> Callable[..., Any]: +def time_trigger(*time_spec: str | None, kwargs: dict | None = None) -> Callable[..., Any]: """Schedule the function using time specifications. Args: - *time_spec: Time expressions such as ``startup``, ``shutdown``, ``once()``, ``period()``, or ``cron()``. - **kwargs: Optional trigger keywords merged into each invocation. + time_spec: Time expressions such as ``startup``, ``shutdown``, ``once()``, ``period()``, or ``cron()``. + kwargs: Optional trigger keywords merged into each invocation. """ ... @@ -81,7 +81,9 @@ def task_unique(name: str, kill_me: bool = False) -> Callable[..., Any]: ... -def event_trigger(*event_type: str, str_expr: str = None, **kwargs) -> Callable[..., Any]: +def event_trigger( + *event_type: str, str_expr: str | None = None, kwargs: dict | None = None +) -> Callable[..., Any]: """Trigger when a Home Assistant event matches the criteria. Args: @@ -106,7 +108,7 @@ def time_active(*time_spec: str, hold_off: int | float | None = None) -> Callabl def mqtt_trigger( - topic: str, str_expr: str | None = None, encoding: str = "utf-8", **kwargs + topic: str, str_expr: str | None = None, encoding: str = "utf-8", kwargs: dict | None = None ) -> Callable[..., Any]: """Trigger when a subscribed MQTT message matches the specification. diff --git a/tests/test_jupyter.py b/tests/test_jupyter.py index 5148d82..bae42ff 100644 --- a/tests/test_jupyter.py +++ b/tests/test_jupyter.py @@ -280,9 +280,9 @@ async def test_jupyter_kernel_msgs(hass, caplog, socket_enabled): # for i in range(5): if i & 1: - msg = (f"hello {i} " * 40).encode("utf-8") + msg = (f"hello {i} " * 40).encode() else: - msg = f"hello {i}".encode("utf-8") + msg = f"hello {i}".encode() await sock["hb_port"].send(msg) await sock["iopub_port"].send(msg) await sock["stdin_port"].send(msg) @@ -422,9 +422,9 @@ async def test_jupyter_kernel_port_close(hass, caplog, socket_enabled): # for i in range(5): if i & 1: - msg = (f"hello {i} " * 40).encode("utf-8") + msg = (f"hello {i} " * 40).encode() else: - msg = f"hello {i}".encode("utf-8") + msg = f"hello {i}".encode() await sock["hb_port"].send(msg) await sock["iopub_port"].send(msg) await sock["stdin_port"].send(msg) @@ -445,9 +445,9 @@ async def test_jupyter_kernel_port_close(hass, caplog, socket_enabled): # for i in range(5): if i & 1: - msg = (f"hello {i} " * 40).encode("utf-8") + msg = (f"hello {i} " * 40).encode() else: - msg = f"hello {i}".encode("utf-8") + msg = f"hello {i}".encode() await sock["hb_port"].send(msg) await sock["iopub_port"].send(msg) await sock["stdin_port"].send(msg) diff --git a/tests/test_state.py b/tests/test_state.py index beb3ca7..f6be05a 100644 --- a/tests/test_state.py +++ b/tests/test_state.py @@ -1,6 +1,6 @@ """Test pyscripts test module.""" -from datetime import datetime, timezone +from datetime import UTC, datetime from unittest.mock import patch import pytest @@ -73,7 +73,7 @@ def test_state_val_conversions(): assert round_state.as_round(precision=2) == pytest.approx(3.14) datetime_state = StateVal(HassState("test.datetime", "2024-03-05T06:07:08+00:00")) - assert datetime_state.as_datetime() == datetime(2024, 3, 5, 6, 7, 8, tzinfo=timezone.utc) + assert datetime_state.as_datetime() == datetime(2024, 3, 5, 6, 7, 8, tzinfo=UTC) invalid_state = StateVal(HassState("test.invalid", "invalid")) with pytest.raises(ValueError): @@ -93,7 +93,7 @@ def test_state_val_conversions(): assert invalid_state.as_round(default=0) == 0 - fallback_datetime = datetime(1999, 1, 2, 3, 4, 5, tzinfo=timezone.utc) + fallback_datetime = datetime(1999, 1, 2, 3, 4, 5, tzinfo=UTC) assert invalid_state.as_datetime(default=fallback_datetime) == fallback_datetime unknown_state = StateVal(HassState("test.unknown", STATE_UNKNOWN)) diff --git a/tests/test_stubs.py b/tests/test_stubs.py index db7d881..8b41040 100644 --- a/tests/test_stubs.py +++ b/tests/test_stubs.py @@ -10,6 +10,7 @@ import pytest from custom_components.pyscript.const import DOMAIN, FOLDER, SERVICE_GENERATE_STUBS +from homeassistant.core import HomeAssistant from tests.test_init import setup_script @@ -47,7 +48,7 @@ def ready(): ) monkeypatch.setattr("custom_components.pyscript.stubs.generator.er.async_get", lambda _: dummy_registry) - async def fake_service_descriptions(_hass) -> dict[str, dict[str, dict[str, Any]]]: + async def fake_service_descriptions(_hass: HomeAssistant) -> dict[str, dict[str, dict[str, Any]]]: return { "light": { "blink": { diff --git a/tests/test_unit_eval.py b/tests/test_unit_eval.py index 527e479..dad0c79 100644 --- a/tests/test_unit_eval.py +++ b/tests/test_unit_eval.py @@ -183,14 +183,14 @@ class Color(Enum): """ from enum import Enum -class HomeState(Enum): +class HState(Enum): HOME = "home" AWAY = "away" def name_and_value(self): return f"{self.name}:{self.value}" -[HomeState.HOME.name_and_value(), HomeState.AWAY.name_and_value()] +[HState.HOME.name_and_value(), HState.AWAY.name_and_value()] """, ["HOME:home", "AWAY:away"], ],