diff --git a/README.md b/README.md index 1058e61..cac9dda 100644 --- a/README.md +++ b/README.md @@ -12,7 +12,10 @@ Use of this project with Salesforce is subject to the [TERMS OF USE](./TERMS_OF_ - JDK 17 - Docker support like [Docker Desktop](https://docs.docker.com/desktop/) - A salesforce org with some DLOs or DMOs with data and this feature enabled (it is not GA) -- An [External Client App](#creating-an-external-client-app) +- **One of the following** for authentication: + - A Salesforce org already authenticated via the [Salesforce CLI](https://developer.salesforce.com/tools/salesforcecli) + (simplest — no External Client App needed) + - An [External Client App](#creating-an-external-client-app) configured with OAuth settings ## Installation The SDK can be downloaded directly from PyPI with `pip`: @@ -65,6 +68,13 @@ datacustomcode configure datacustomcode run ./payload/entrypoint.py ``` +> [!TIP] +> **Already using the Salesforce CLI?** If you have authenticated an org with `sf org login web +> --alias myorg`, you can skip `datacustomcode configure` entirely: +> ```zsh +> datacustomcode run ./payload/entrypoint.py --sf-cli-org myorg +> ``` + > [!IMPORTANT] > The example entrypoint.py requires a `Account_std__dll` DLO to be present. And in order to deploy the script (next step), the output DLO (which is `Account_std_copy__dll` in the example entrypoint.py) also needs to exist and be in the same dataspace as `Account_std__dll`. 
@@ -183,17 +193,19 @@ Options: - `--auth-type TEXT`: Authentication method (default: `oauth_tokens`) - `oauth_tokens` - OAuth tokens with refresh_token - `client_credentials` - Server-to-server using client_id/secret only -- `--login-url TEXT`: Salesforce login URL -For OAuth Tokens authentication: -- `--client-id TEXT`: External Client App Client ID -- `--client-secret TEXT`: External Client App Client Secret -- `--refresh-token TEXT`: OAuth refresh token (see [Obtaining Refresh Token](#obtaining-refresh-token-and-core-token)) -- `--core-token TEXT`: (Optional) OAuth core/access token - if not provided, it will be obtained using the refresh token +You will be prompted for the following depending on auth type: + +*Common to all auth types:* +- **Login URL**: Salesforce login URL +- **Client ID**: External Client App Client ID + +*For OAuth Tokens authentication:* +- **Client Secret**: External Client App Client Secret +- **Redirect URI**: OAuth redirect URI -For Client Credentials authentication (server-to-server): -- `--client-id TEXT`: External Client App Client ID -- `--client-secret TEXT`: External Client App Client Secret +*For Client Credentials authentication:* +- **Client Secret**: External Client App Client Secret ##### Using Environment Variables (Alternative) @@ -255,6 +267,9 @@ Options: - `--config-file TEXT`: Path to configuration file - `--dependencies TEXT`: Additional dependencies (can be specified multiple times) - `--profile TEXT`: Credential profile name (default: "default") +- `--sf-cli-org TEXT`: Salesforce CLI org alias or username (e.g. `dev1`). Fetches + credentials via `sf org display` — no `datacustomcode configure` step needed. + Takes precedence over `--profile` if both are supplied. 
#### `datacustomcode zip` @@ -277,7 +292,7 @@ Options: - `--version TEXT`: Version of the transformation job (default: "0.0.1") - `--description TEXT`: Description of the transformation job (default: "") - `--network TEXT`: docker network (default: "default") -- `--cpu-size TEXT`: CPU size for the deployment (default: "CPU_XL"). Available options: CPU_L(Large), CPU_XL(Extra Large), CPU_2XL(2X Large), CPU_4XL(4X Large) +- `--cpu-size TEXT`: CPU size for the deployment (default: `CPU_2XL`). Available options: CPU_L(Large), CPU_XL(Extra Large), CPU_2XL(2X Large), CPU_4XL(4X Large) ## Docker usage @@ -365,6 +380,54 @@ You can read more about Jupyter Notebooks here: https://jupyter.org/ You now have all fields necessary for the `datacustomcode configure` command. +### Using the Salesforce CLI for authentication + +The [Salesforce CLI](https://developer.salesforce.com/tools/salesforcecli) (`sf`) lets you authenticate an org once and then reference it by alias across tools — including this SDK via `--sf-cli-org`. + +#### Installing the Salesforce CLI + +Follow the [official install guide](https://developer.salesforce.com/docs/atlas.en-us.sfdx_setup.meta/sfdx_setup/sfdx_setup_install_cli.htm), or use a package manager: + +```zsh +# macOS (Homebrew) +brew install sf + +# npm (all platforms) +npm install --global @salesforce/cli +``` + +Verify the install: +```zsh +sf --version +``` + +#### Authenticating an org + +**Browser-based (recommended for developer orgs and sandboxes):** +```zsh +# Production / Developer Edition +sf org login web --alias myorg + +# Sandbox +sf org login web --alias mysandbox --instance-url https://test.salesforce.com + +# Custom domain +sf org login web --alias myorg --instance-url https://mycompany.my.salesforce.com +``` + +Each command opens a browser tab. After you log in and approve access, the CLI stores the session locally. 
+ +**Verify the stored org and confirm the alias:** +```zsh +sf org list +sf org display --target-org myorg +``` + +Once authenticated, pass the alias directly to `datacustomcode run`: +```zsh +datacustomcode run ./payload/entrypoint.py --sf-cli-org myorg +``` + ### Obtaining Refresh Token and Core Token If you're using OAuth Tokens authentication, the initial configure will retrieve and store tokens. Run `datacustomcode auth` to refresh these when they expire. diff --git a/src/datacustomcode/cli.py b/src/datacustomcode/cli.py index a689140..c12e1c9 100644 --- a/src/datacustomcode/cli.py +++ b/src/datacustomcode/cli.py @@ -16,7 +16,11 @@ import json import os import sys -from typing import List, Union +from typing import ( + List, + Optional, + Union, +) import click from loguru import logger @@ -294,12 +298,20 @@ def scan(filename: str, config: str, dry_run: bool, no_requirements: bool): @click.option("--config-file", default=None) @click.option("--dependencies", default=[], multiple=True) @click.option("--profile", default="default") +@click.option( + "--sf-cli-org", + default=None, + help="SF CLI org alias or username. 
Fetches credentials via `sf org display`.", +) def run( entrypoint: str, config_file: Union[str, None], dependencies: List[str], profile: str, + sf_cli_org: Optional[str], ): from datacustomcode.run import run_entrypoint - run_entrypoint(entrypoint, config_file, dependencies, profile) + run_entrypoint( + entrypoint, config_file, dependencies, profile, sf_cli_org=sf_cli_org + ) diff --git a/src/datacustomcode/io/reader/query_api.py b/src/datacustomcode/io/reader/query_api.py index f4adf89..98d2596 100644 --- a/src/datacustomcode/io/reader/query_api.py +++ b/src/datacustomcode/io/reader/query_api.py @@ -22,50 +22,21 @@ Union, ) -import pandas.api.types as pd_types -from pyspark.sql.types import ( - BooleanType, - DoubleType, - LongType, - StringType, - StructField, - StructType, - TimestampType, -) from salesforcecdpconnector.connection import SalesforceCDPConnection from datacustomcode.credentials import AuthType, Credentials from datacustomcode.io.reader.base import BaseDataCloudReader +from datacustomcode.io.reader.sf_cli import SFCLIDataCloudReader +from datacustomcode.io.reader.utils import _pandas_to_spark_schema if TYPE_CHECKING: - import pandas from pyspark.sql import DataFrame as PySparkDataFrame, SparkSession - from pyspark.sql.types import AtomicType + from pyspark.sql.types import AtomicType, StructType logger = logging.getLogger(__name__) SQL_QUERY_TEMPLATE: Final = "SELECT * FROM {} LIMIT {}" -PANDAS_TYPE_MAPPING = { - "object": StringType(), - "int64": LongType(), - "float64": DoubleType(), - "bool": BooleanType(), -} - - -def _pandas_to_spark_schema( - pandas_df: pandas.DataFrame, nullable: bool = True -) -> StructType: - fields = [] - for column, dtype in pandas_df.dtypes.items(): - spark_type: AtomicType - if pd_types.is_datetime64_any_dtype(dtype): - spark_type = TimestampType() - else: - spark_type = PANDAS_TYPE_MAPPING.get(str(dtype), StringType()) - fields.append(StructField(column, spark_type, nullable)) - return StructType(fields) def 
create_cdp_connection( @@ -136,6 +107,7 @@ class QueryAPIDataCloudReader(BaseDataCloudReader): Supports multiple authentication methods: - OAuth Tokens (default, needs client_id/secret with refresh_token) - Client Credentials (server-to-server, needs client_id/secret only) + - SF CLI (uses ``sf org display`` access token via the REST API directly) Supports dataspace configuration for querying data within specific dataspaces. When a dataspace is provided (and not "default"), queries are executed within @@ -149,6 +121,7 @@ def __init__( spark: SparkSession, credentials_profile: str = "default", dataspace: Optional[str] = None, + sf_cli_org: Optional[str] = None, ) -> None: """Initialize QueryAPIDataCloudReader. @@ -160,14 +133,30 @@ def __init__( dataspace: Optional dataspace identifier. If provided and not "default", the connection will be configured for the specified dataspace. When None or "default", uses the default dataspace. + sf_cli_org: Optional SF CLI org alias or username. When set, the + reader delegates to :class:`SFCLIDataCloudReader` which calls + the Data Cloud REST API directly using the token obtained from + ``sf org display``, bypassing the CDP token-exchange flow. 
""" self.spark = spark - credentials = Credentials.from_available(profile=credentials_profile) - logger.debug( - "Initializing QueryAPIDataCloudReader with " - f"auth_type={credentials.auth_type.value}" - ) - self._conn = create_cdp_connection(credentials, dataspace) + if sf_cli_org: + logger.debug( + f"Initializing QueryAPIDataCloudReader with SF CLI org '{sf_cli_org}'" + ) + self._sf_cli_reader: Optional[SFCLIDataCloudReader] = SFCLIDataCloudReader( + spark=spark, + sf_cli_org=sf_cli_org, + dataspace=dataspace, + ) + self._conn = None + else: + self._sf_cli_reader = None + credentials = Credentials.from_available(profile=credentials_profile) + logger.debug( + "Initializing QueryAPIDataCloudReader with " + f"auth_type={credentials.auth_type.value}" + ) + self._conn = create_cdp_connection(credentials, dataspace) def read_dlo( self, @@ -186,8 +175,15 @@ def read_dlo( Returns: PySparkDataFrame: The PySpark DataFrame. """ + sf_cli_reader: Optional[SFCLIDataCloudReader] = getattr( + self, "_sf_cli_reader", None + ) + if sf_cli_reader is not None: + return sf_cli_reader.read_dlo(name, schema, row_limit) + query = SQL_QUERY_TEMPLATE.format(name, row_limit) + assert self._conn is not None pandas_df = self._conn.get_pandas_dataframe(query) # Convert pandas DataFrame to Spark DataFrame @@ -214,8 +210,15 @@ def read_dmo( Returns: PySparkDataFrame: The PySpark DataFrame. 
""" + sf_cli_reader: Optional[SFCLIDataCloudReader] = getattr( + self, "_sf_cli_reader", None + ) + if sf_cli_reader is not None: + return sf_cli_reader.read_dmo(name, schema, row_limit) + query = SQL_QUERY_TEMPLATE.format(name, row_limit) + assert self._conn is not None pandas_df = self._conn.get_pandas_dataframe(query) # Convert pandas DataFrame to Spark DataFrame diff --git a/src/datacustomcode/io/reader/sf_cli.py b/src/datacustomcode/io/reader/sf_cli.py new file mode 100644 index 0000000..49a5838 --- /dev/null +++ b/src/datacustomcode/io/reader/sf_cli.py @@ -0,0 +1,229 @@ +# Copyright (c) 2025, Salesforce, Inc. +# SPDX-License-Identifier: Apache-2 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from __future__ import annotations + +import json +import logging +import subprocess +from typing import ( + TYPE_CHECKING, + Final, + Optional, + Union, +) + +import pandas as pd +import requests + +from datacustomcode.io.reader.base import BaseDataCloudReader +from datacustomcode.io.reader.utils import _pandas_to_spark_schema + +if TYPE_CHECKING: + from pyspark.sql import DataFrame as PySparkDataFrame, SparkSession + from pyspark.sql.types import AtomicType, StructType + +logger = logging.getLogger(__name__) + +API_VERSION: Final = "v66.0" + + +class SFCLIDataCloudReader(BaseDataCloudReader): + """DataCloud reader that authenticates via the Salesforce CLI. 
+ + Uses ``sf org display`` to obtain a fresh access token and queries + Data Cloud through the REST API directly + (``/services/data/{version}/ssot/query-sql``), bypassing the CDP + token-exchange flow that requires special OAuth scopes. + """ + + CONFIG_NAME = "SFCLIDataCloudReader" + + def __init__( + self, + spark: SparkSession, + sf_cli_org: str, + dataspace: Optional[str] = None, + ) -> None: + """Initialize SFCLIDataCloudReader. + + Args: + spark: SparkSession instance for creating DataFrames. + sf_cli_org: Salesforce org alias or username as known to the SF CLI + (e.g. the alias given to ``sf org login web --alias dev1``). + dataspace: Optional dataspace identifier. If ``None`` or + ``"default"`` the query runs against the default dataspace. + """ + self.spark = spark + self.sf_cli_org = sf_cli_org + self.dataspace = ( + dataspace if dataspace and dataspace != "default" else "default" + ) + logger.debug(f"Initialized SFCLIDataCloudReader for org '{sf_cli_org}'") + + def _get_token(self) -> tuple[str, str]: + """Fetch a fresh access token and instance URL from the SF CLI. + + Returns: + ``(access_token, instance_url)`` + + Raises: + RuntimeError: If the ``sf`` command is not on PATH, times out, or + returns an error. + """ + try: + result = subprocess.run( + ["sf", "org", "display", "--target-org", self.sf_cli_org, "--json"], + capture_output=True, + text=True, + check=True, + timeout=30, + ) + except FileNotFoundError as exc: + raise RuntimeError( + "The 'sf' command was not found. 
" + "Please install Salesforce CLI: https://developer.salesforce.com/tools/salesforcecli" + ) from exc + except subprocess.TimeoutExpired as exc: + raise RuntimeError( + f"'sf org display' timed out for org '{self.sf_cli_org}'" + ) from exc + except subprocess.CalledProcessError as exc: + raise RuntimeError( + f"'sf org display' failed for org '{self.sf_cli_org}'.\n" + f"Ensure the org is authenticated via 'sf org login web'.\n" + f"stderr: {exc.stderr.strip()}" + ) from exc + + try: + data = json.loads(result.stdout) + except json.JSONDecodeError as exc: + raise RuntimeError( + f"Failed to parse 'sf org display' output: {exc}" + ) from exc + + if data.get("status") != 0: + raise RuntimeError( + f"SF CLI error for org '{self.sf_cli_org}': " + f"{data.get('message', 'unknown error')}" + ) + + org_result = data.get("result", {}) + access_token = org_result.get("accessToken") + instance_url = org_result.get("instanceUrl") + + if not access_token or not instance_url: + raise RuntimeError( + f"'sf org display' did not return an access token or instance URL " + f"for org '{self.sf_cli_org}'" + ) + + logger.debug(f"Fetched token from SF CLI for org '{self.sf_cli_org}'") + return access_token, instance_url + + def _execute_query(self, sql: str, row_limit: int) -> pd.DataFrame: + """Execute *sql* against the Data Cloud REST endpoint. + + Args: + sql: Base SQL query (no ``LIMIT`` clause). + row_limit: Maximum rows to return. + + Returns: + Pandas DataFrame with query results. + + Raises: + RuntimeError: On HTTP errors or unexpected response shapes. 
+ """ + access_token, instance_url = self._get_token() + + url = f"{instance_url}/services/data/{API_VERSION}/ssot/query-sql" + headers = {"Authorization": f"Bearer {access_token}"} + params = {"dataspace": self.dataspace} + body = {"sql": f"{sql} LIMIT {row_limit}"} + + logger.debug(f"Executing Data Cloud query: {body['sql']}") + + try: + response = requests.post( + url, + json=body, + params=params, + headers=headers, + timeout=120, + ) + except requests.RequestException as exc: + raise RuntimeError(f"Data Cloud query request failed: {exc}") from exc + + if response.status_code >= 300: + error_msg = response.text + try: + error_data = response.json() + if isinstance(error_data, list) and error_data: + error_msg = error_data[0].get("message", error_msg) + except (json.JSONDecodeError, KeyError): + pass + raise RuntimeError( + f"Data Cloud query failed (HTTP {response.status_code}): {error_msg}" + ) + + result = response.json() + metadata = result.get("metadata", []) + column_names = [col.get("name") for col in metadata] + rows = result.get("data", []) + + if not rows: + return pd.DataFrame(columns=column_names) + return pd.DataFrame(rows, columns=column_names) + + def read_dlo( + self, + name: str, + schema: Union[AtomicType, StructType, str, None] = None, + row_limit: int = 1000, + ) -> PySparkDataFrame: + """Read a Data Lake Object (DLO) from Data Cloud. + + Args: + name: DLO name. + schema: Optional explicit schema. + row_limit: Maximum rows to fetch. + + Returns: + PySpark DataFrame. + """ + pandas_df = self._execute_query(f"SELECT * FROM {name}", row_limit) + if not schema: + schema = _pandas_to_spark_schema(pandas_df) + return self.spark.createDataFrame(pandas_df, schema) + + def read_dmo( + self, + name: str, + schema: Union[AtomicType, StructType, str, None] = None, + row_limit: int = 1000, + ) -> PySparkDataFrame: + """Read a Data Model Object (DMO) from Data Cloud. + + Args: + name: DMO name. + schema: Optional explicit schema. 
+ row_limit: Maximum rows to fetch. + + Returns: + PySpark DataFrame. + """ + pandas_df = self._execute_query(f"SELECT * FROM {name}", row_limit) + if not schema: + schema = _pandas_to_spark_schema(pandas_df) + return self.spark.createDataFrame(pandas_df, schema) diff --git a/src/datacustomcode/io/reader/utils.py b/src/datacustomcode/io/reader/utils.py new file mode 100644 index 0000000..737a76c --- /dev/null +++ b/src/datacustomcode/io/reader/utils.py @@ -0,0 +1,53 @@ +# Copyright (c) 2025, Salesforce, Inc. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+from __future__ import annotations + +from typing import TYPE_CHECKING + +import pandas.api.types as pd_types +from pyspark.sql.types import ( + BooleanType, + DoubleType, + LongType, + StringType, + StructField, + StructType, + TimestampType, +) + +if TYPE_CHECKING: + import pandas + from pyspark.sql.types import AtomicType + +PANDAS_TYPE_MAPPING = { + "object": StringType(), + "int64": LongType(), + "float64": DoubleType(), + "bool": BooleanType(), +} + + +def _pandas_to_spark_schema( + pandas_df: pandas.DataFrame, nullable: bool = True +) -> StructType: + fields = [] + for column, dtype in pandas_df.dtypes.items(): + spark_type: AtomicType + if pd_types.is_datetime64_any_dtype(dtype): + spark_type = TimestampType() + else: + spark_type = PANDAS_TYPE_MAPPING.get(str(dtype), StringType()) + fields.append(StructField(column, spark_type, nullable)) + return StructType(fields) diff --git a/src/datacustomcode/io/writer/print.py b/src/datacustomcode/io/writer/print.py index 4eaa1ee..5645f7a 100644 --- a/src/datacustomcode/io/writer/print.py +++ b/src/datacustomcode/io/writer/print.py @@ -45,6 +45,7 @@ def __init__( reader: Optional[QueryAPIDataCloudReader] = None, credentials_profile: str = "default", dataspace: Optional[str] = None, + sf_cli_org: Optional[str] = None, ) -> None: """Initialize PrintDataCloudWriter. @@ -57,20 +58,17 @@ def __init__( The profile determines which credentials to load and which authentication method to use. dataspace: Optional dataspace identifier for multi-tenant queries. + sf_cli_org: Optional SF CLI org alias or username. If provided, + credentials are fetched via `sf org display`. 
""" super().__init__(spark) if reader is None: - if dataspace is not None: - self.reader = QueryAPIDataCloudReader( - self.spark, - credentials_profile=credentials_profile, - dataspace=dataspace, - ) - else: - self.reader = QueryAPIDataCloudReader( - self.spark, - credentials_profile=credentials_profile, - ) + self.reader = QueryAPIDataCloudReader( + self.spark, + credentials_profile=credentials_profile, + dataspace=dataspace, + sf_cli_org=sf_cli_org, + ) else: self.reader = reader diff --git a/src/datacustomcode/run.py b/src/datacustomcode/run.py index 0e4e0ff..bb9abed 100644 --- a/src/datacustomcode/run.py +++ b/src/datacustomcode/run.py @@ -18,7 +18,11 @@ from pathlib import Path import runpy import sys -from typing import List, Union +from typing import ( + List, + Optional, + Union, +) from datacustomcode.config import config from datacustomcode.scan import get_package_type @@ -41,6 +45,7 @@ def run_entrypoint( config_file: Union[str, None], dependencies: List[str], profile: str, + sf_cli_org: Optional[str] = None, ) -> None: """Run the entrypoint script with the given config and dependencies. @@ -49,6 +54,8 @@ def run_entrypoint( config_file: The config file to use. dependencies: The dependencies to import. profile: The credentials profile to use. + sf_cli_org: Optional SF CLI org alias or username. If provided, credentials + are fetched via `sf org display` instead of from credentials.ini. 
""" add_py_folder(entrypoint) @@ -89,7 +96,10 @@ def run_entrypoint( _set_config_option(config.reader_config, "dataspace", dataspace) _set_config_option(config.writer_config, "dataspace", dataspace) - if profile != "default": + if sf_cli_org: + _set_config_option(config.reader_config, "sf_cli_org", sf_cli_org) + _set_config_option(config.writer_config, "sf_cli_org", sf_cli_org) + elif profile != "default": _set_config_option(config.reader_config, "credentials_profile", profile) _set_config_option(config.writer_config, "credentials_profile", profile) for dependency in dependencies: diff --git a/tests/io/reader/test_query_api.py b/tests/io/reader/test_query_api.py index b649354..6bb8b5a 100644 --- a/tests/io/reader/test_query_api.py +++ b/tests/io/reader/test_query_api.py @@ -21,8 +21,8 @@ from datacustomcode.io.reader.query_api import ( SQL_QUERY_TEMPLATE, QueryAPIDataCloudReader, - _pandas_to_spark_schema, ) +from datacustomcode.io.reader.utils import _pandas_to_spark_schema class TestPandasToSparkSchema: diff --git a/tests/io/reader/test_sf_cli.py b/tests/io/reader/test_sf_cli.py new file mode 100644 index 0000000..3a94b0b --- /dev/null +++ b/tests/io/reader/test_sf_cli.py @@ -0,0 +1,381 @@ +from __future__ import annotations + +import json +import subprocess +from unittest.mock import MagicMock, patch + +import pandas as pd +import pytest + +from datacustomcode.io.reader.sf_cli import API_VERSION, SFCLIDataCloudReader + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _make_reader( + sf_cli_org: str = "dev1", dataspace: str | None = None +) -> SFCLIDataCloudReader: + spark = MagicMock() + spark.createDataFrame.return_value = MagicMock() + return SFCLIDataCloudReader(spark=spark, sf_cli_org=sf_cli_org, dataspace=dataspace) + + +def _sf_display_output( + access_token: str = "tok", instance_url: str = "https://example.my.salesforce.com" +) 
-> str: + return json.dumps( + { + "status": 0, + "result": {"accessToken": access_token, "instanceUrl": instance_url}, + } + ) + + +# --------------------------------------------------------------------------- +# __init__ +# --------------------------------------------------------------------------- + + +class TestSFCLIDataCloudReaderInit: + def test_stores_org(self): + reader = _make_reader(sf_cli_org="my-org") + assert reader.sf_cli_org == "my-org" + + def test_dataspace_none_becomes_default(self): + reader = _make_reader(dataspace=None) + assert reader.dataspace == "default" + + def test_dataspace_string_default_stays_default(self): + reader = _make_reader(dataspace="default") + assert reader.dataspace == "default" + + def test_custom_dataspace_preserved(self): + reader = _make_reader(dataspace="myspace") + assert reader.dataspace == "myspace" + + +# --------------------------------------------------------------------------- +# _get_token +# --------------------------------------------------------------------------- + + +class TestGetToken: + @pytest.fixture + def reader(self): + return _make_reader() + + def _run_result(self, stdout: str) -> MagicMock: + result = MagicMock() + result.stdout = stdout + return result + + def test_returns_token_and_instance_url(self, reader): + with patch( + "subprocess.run", + return_value=self._run_result( + _sf_display_output("mytoken", "https://org.salesforce.com") + ), + ) as mock_run: + token, url = reader._get_token() + + assert token == "mytoken" + assert url == "https://org.salesforce.com" + mock_run.assert_called_once_with( + ["sf", "org", "display", "--target-org", "dev1", "--json"], + capture_output=True, + text=True, + check=True, + timeout=30, + ) + + def test_file_not_found_raises_runtime_error(self, reader): + with patch("subprocess.run", side_effect=FileNotFoundError): + with pytest.raises(RuntimeError, match="'sf' command was not found"): + reader._get_token() + + def test_timeout_raises_runtime_error(self, 
reader): + with patch( + "subprocess.run", + side_effect=subprocess.TimeoutExpired(cmd="sf", timeout=30), + ): + with pytest.raises(RuntimeError, match="timed out"): + reader._get_token() + + def test_called_process_error_raises_runtime_error(self, reader): + exc = subprocess.CalledProcessError(returncode=1, cmd="sf", stderr="auth error") + with patch("subprocess.run", side_effect=exc): + with pytest.raises(RuntimeError, match="failed for org"): + reader._get_token() + + def test_called_process_error_includes_stderr(self, reader): + exc = subprocess.CalledProcessError( + returncode=1, cmd="sf", stderr="not authenticated" + ) + with patch("subprocess.run", side_effect=exc): + with pytest.raises(RuntimeError, match="not authenticated"): + reader._get_token() + + def test_invalid_json_raises_runtime_error(self, reader): + result = MagicMock() + result.stdout = "not valid json{" + with patch("subprocess.run", return_value=result): + with pytest.raises(RuntimeError, match="Failed to parse"): + reader._get_token() + + def test_nonzero_status_raises_runtime_error(self, reader): + payload = json.dumps({"status": 1, "message": "something went wrong"}) + result = MagicMock() + result.stdout = payload + with patch("subprocess.run", return_value=result): + with pytest.raises(RuntimeError, match="something went wrong"): + reader._get_token() + + def test_nonzero_status_without_message_uses_unknown_error(self, reader): + payload = json.dumps({"status": 1}) + result = MagicMock() + result.stdout = payload + with patch("subprocess.run", return_value=result): + with pytest.raises(RuntimeError, match="unknown error"): + reader._get_token() + + def test_missing_access_token_raises_runtime_error(self, reader): + payload = json.dumps( + {"status": 0, "result": {"instanceUrl": "https://x.salesforce.com"}} + ) + result = MagicMock() + result.stdout = payload + with patch("subprocess.run", return_value=result): + with pytest.raises(RuntimeError, match="access token or instance URL"): + 
reader._get_token() + + def test_missing_instance_url_raises_runtime_error(self, reader): + payload = json.dumps({"status": 0, "result": {"accessToken": "tok"}}) + result = MagicMock() + result.stdout = payload + with patch("subprocess.run", return_value=result): + with pytest.raises(RuntimeError, match="access token or instance URL"): + reader._get_token() + + +# --------------------------------------------------------------------------- +# _execute_query +# --------------------------------------------------------------------------- + + +class TestExecuteQuery: + @pytest.fixture + def reader(self): + return _make_reader() + + @pytest.fixture + def mock_token(self, reader): + with patch.object( + reader, "_get_token", return_value=("mytoken", "https://org.salesforce.com") + ): + yield + + def _mock_response( + self, status_code: int = 200, json_body: dict | None = None, text: str = "" + ) -> MagicMock: + response = MagicMock() + response.status_code = status_code + response.text = text + response.json.return_value = json_body or {} + return response + + def test_posts_to_correct_url(self, reader, mock_token): + api_response = {"metadata": [{"name": "col"}], "data": [["v"]]} + with patch( + "requests.post", return_value=self._mock_response(json_body=api_response) + ) as mock_post: + reader._execute_query("SELECT * FROM foo", 100) + + url = mock_post.call_args[0][0] + assert ( + url + == f"https://org.salesforce.com/services/data/{API_VERSION}/ssot/query-sql" + ) + + def test_passes_bearer_token_header(self, reader, mock_token): + api_response = {"metadata": [], "data": []} + with patch( + "requests.post", return_value=self._mock_response(json_body=api_response) + ) as mock_post: + reader._execute_query("SELECT * FROM foo", 10) + + headers = mock_post.call_args.kwargs["headers"] + assert headers["Authorization"] == "Bearer mytoken" + + def test_passes_dataspace_param(self, reader, mock_token): + api_response = {"metadata": [], "data": []} + with patch( + 
"requests.post", return_value=self._mock_response(json_body=api_response) + ) as mock_post: + reader._execute_query("SELECT * FROM foo", 10) + + params = mock_post.call_args.kwargs["params"] + assert params["dataspace"] == "default" + + def test_appends_limit_to_sql(self, reader, mock_token): + api_response = {"metadata": [], "data": []} + with patch( + "requests.post", return_value=self._mock_response(json_body=api_response) + ) as mock_post: + reader._execute_query("SELECT * FROM foo", 42) + + body = mock_post.call_args.kwargs["json"] + assert body["sql"] == "SELECT * FROM foo LIMIT 42" + + def test_returns_dataframe_with_rows(self, reader, mock_token): + api_response = { + "metadata": [{"name": "id"}, {"name": "name"}], + "data": [[1, "alice"], [2, "bob"]], + } + with patch( + "requests.post", return_value=self._mock_response(json_body=api_response) + ): + df = reader._execute_query("SELECT * FROM foo", 100) + + assert list(df.columns) == ["id", "name"] + assert len(df) == 2 + + def test_returns_empty_dataframe_when_no_rows(self, reader, mock_token): + api_response = {"metadata": [{"name": "id"}, {"name": "name"}], "data": []} + with patch( + "requests.post", return_value=self._mock_response(json_body=api_response) + ): + df = reader._execute_query("SELECT * FROM foo", 100) + + assert list(df.columns) == ["id", "name"] + assert len(df) == 0 + + def test_http_error_raises_runtime_error(self, reader, mock_token): + with patch( + "requests.post", + return_value=self._mock_response(status_code=401, text="Unauthorized"), + ): + with pytest.raises(RuntimeError, match="HTTP 401"): + reader._execute_query("SELECT * FROM foo", 10) + + def test_http_error_uses_json_message_when_available(self, reader, mock_token): + error_body = [{"message": "insufficient privileges"}] + response = self._mock_response(status_code=403, text="Forbidden") + response.json.return_value = error_body + with patch("requests.post", return_value=response): + with pytest.raises(RuntimeError, 
match="insufficient privileges"): + reader._execute_query("SELECT * FROM foo", 10) + + def test_http_error_falls_back_to_text_when_json_not_list(self, reader, mock_token): + response = self._mock_response(status_code=500, text="Internal Server Error") + response.json.return_value = {"error": "oops"} # dict, not list + with patch("requests.post", return_value=response): + with pytest.raises(RuntimeError, match="Internal Server Error"): + reader._execute_query("SELECT * FROM foo", 10) + + def test_request_exception_raises_runtime_error(self, reader, mock_token): + import requests as req_lib + + with patch( + "requests.post", side_effect=req_lib.RequestException("connection refused") + ): + with pytest.raises(RuntimeError, match="Data Cloud query request failed"): + reader._execute_query("SELECT * FROM foo", 10) + + def test_custom_dataspace_passed_as_param(self): + reader = _make_reader(dataspace="myspace") + with patch.object( + reader, "_get_token", return_value=("tok", "https://org.salesforce.com") + ): + api_response = {"metadata": [], "data": []} + with patch( + "requests.post", + return_value=self._mock_response(json_body=api_response), + ) as mock_post: + reader._execute_query("SELECT * FROM foo", 10) + + params = mock_post.call_args.kwargs["params"] + assert params["dataspace"] == "myspace" + + +# --------------------------------------------------------------------------- +# read_dlo / read_dmo +# --------------------------------------------------------------------------- + + +class TestReadDloAndDmo: + @pytest.fixture + def reader(self): + return _make_reader() + + @pytest.fixture + def sample_df(self): + return pd.DataFrame({"id": [1, 2], "name": ["a", "b"]}) + + @pytest.mark.parametrize( + "method,obj_name", + [ + ("read_dlo", "MyDLO__dll"), + ("read_dmo", "MyDMO__dlm"), + ], + ) + def test_executes_select_star_query(self, reader, sample_df, method, obj_name): + with patch.object( + reader, "_execute_query", return_value=sample_df + ) as mock_exec: + 
getattr(reader, method)(obj_name) + + mock_exec.assert_called_once_with(f"SELECT * FROM {obj_name}", 1000) + + @pytest.mark.parametrize("method", ["read_dlo", "read_dmo"]) + def test_custom_row_limit(self, reader, sample_df, method): + with patch.object( + reader, "_execute_query", return_value=sample_df + ) as mock_exec: + getattr(reader, method)("SomeObj", row_limit=50) + + _, row_limit_arg = mock_exec.call_args[0] + assert row_limit_arg == 50 + + @pytest.mark.parametrize("method", ["read_dlo", "read_dmo"]) + def test_auto_infers_schema_when_none_given(self, reader, sample_df, method): + from pyspark.sql.types import StructType + + with patch.object(reader, "_execute_query", return_value=sample_df): + getattr(reader, method)("SomeObj") + + _, schema_arg = reader.spark.createDataFrame.call_args[0] + assert isinstance(schema_arg, StructType) + + @pytest.mark.parametrize("method", ["read_dlo", "read_dmo"]) + def test_uses_provided_schema(self, reader, sample_df, method): + from pyspark.sql.types import ( + LongType, + StringType, + StructField, + StructType, + ) + + custom_schema = StructType( + [ + StructField("id", LongType(), True), + StructField("name", StringType(), True), + ] + ) + + with patch.object(reader, "_execute_query", return_value=sample_df): + getattr(reader, method)("SomeObj", schema=custom_schema) + + _, schema_arg = reader.spark.createDataFrame.call_args[0] + assert schema_arg is custom_schema + + @pytest.mark.parametrize("method", ["read_dlo", "read_dmo"]) + def test_returns_spark_dataframe(self, reader, sample_df, method): + expected = MagicMock() + reader.spark.createDataFrame.return_value = expected + + with patch.object(reader, "_execute_query", return_value=sample_df): + result = getattr(reader, method)("SomeObj") + + assert result is expected