diff --git a/src/datacustomcode/__init__.py b/src/datacustomcode/__init__.py
index 00cfae3..fdb0679 100644
--- a/src/datacustomcode/__init__.py
+++ b/src/datacustomcode/__init__.py
@@ -17,11 +17,13 @@
 from datacustomcode.credentials import AuthType, Credentials
 from datacustomcode.io.reader.query_api import QueryAPIDataCloudReader
 from datacustomcode.io.writer.print import PrintDataCloudWriter
+from datacustomcode.proxy.client.local_proxy_client import LocalProxyClientProvider

 __all__ = [
     "AuthType",
     "Client",
     "Credentials",
+    "LocalProxyClientProvider",
     "PrintDataCloudWriter",
     "QueryAPIDataCloudReader",
 ]
diff --git a/src/datacustomcode/client.py b/src/datacustomcode/client.py
index 8ba974b..d1a1138 100644
--- a/src/datacustomcode/client.py
+++ b/src/datacustomcode/client.py
@@ -33,6 +33,7 @@
     from datacustomcode.io.reader.base import BaseDataCloudReader
     from datacustomcode.io.writer.base import BaseDataCloudWriter, WriteMode
+    from datacustomcode.proxy.client.base import BaseProxyClient
     from datacustomcode.spark.base import BaseSparkSessionProvider
@@ -106,17 +107,20 @@ class Client:
     _reader: BaseDataCloudReader
     _writer: BaseDataCloudWriter
     _file: DefaultFindFilePath
+    _proxy: BaseProxyClient
     _data_layer_history: dict[DataCloudObjectType, set[str]]

     def __new__(
         cls,
         reader: Optional[BaseDataCloudReader] = None,
         writer: Optional["BaseDataCloudWriter"] = None,
+        proxy: Optional[BaseProxyClient] = None,
         spark_provider: Optional["BaseSparkSessionProvider"] = None,
     ) -> Client:
         if cls._instance is None:
             cls._instance = super().__new__(cls)

+            spark = None
             # Initialize Readers and Writers from config
             # and/or provided reader and writer
             if reader is None or writer is None:
@@ -135,6 +139,22 @@ def __new__(
                     provider = DefaultSparkSessionProvider()

                 spark = provider.get_session(config.spark_config)
+            elif (
+                proxy is None
+                and config.proxy_config is not None
+                and config.spark_config is not None
+            ):
+                # Both reader and writer provided; we still need spark for proxy init
+                provider = (
+                    spark_provider
+                    if spark_provider is not None
+                    else (
+                        config.spark_provider_config.to_object()
+                        if config.spark_provider_config is not None
+                        else DefaultSparkSessionProvider()
+                    )
+                )
+                spark = provider.get_session(config.spark_config)

             if config.reader_config is None and reader is None:
                 raise ValueError(
@@ -143,9 +163,28 @@ def __new__(
                 )
             elif reader is None or (
                 config.reader_config is not None and config.reader_config.force
             ):
-                reader_init = config.reader_config.to_object(spark)  # type: ignore
+                assert (
+                    spark is not None
+                )  # set in "reader is None or writer is None" branch
+                assert config.reader_config is not None  # ensured by branch condition
+                reader_init = config.reader_config.to_object(spark)
             else:
                 reader_init = reader

+            # Resolve the proxy client. A provided instance always wins over
+            # config, so an explicit argument is never silently ignored.
+            if proxy is not None:
+                proxy_init = proxy
+            elif config.proxy_config is None:
+                raise ValueError(
+                    "Proxy config is required when proxy is not provided"
+                )
+            else:
+                assert (
+                    spark is not None
+                )  # set above whenever a component is built from config
+                proxy_init = config.proxy_config.to_object(spark)

             if config.writer_config is None and writer is None:
                 raise ValueError(
                     "Writer config is required when writer is not provided"
@@ -153,12 +192,15 @@ def __new__(
                 )
             elif writer is None or (
                 config.writer_config is not None and config.writer_config.force
             ):
-                writer_init = config.writer_config.to_object(spark)  # type: ignore
+                assert spark is not None  # set when reader or writer from config
+                assert config.writer_config is not None  # ensured by branch condition
+                writer_init = config.writer_config.to_object(spark)
             else:
                 writer_init = writer

             cls._instance._reader = reader_init
             cls._instance._writer = writer_init
             cls._instance._file = DefaultFindFilePath()
+            cls._instance._proxy = proxy_init
             cls._instance._data_layer_history = {
                 DataCloudObjectType.DLO: set(),
                 DataCloudObjectType.DMO: set(),
@@ -217,6 +259,9 @@ def write_to_dmo(
         self._validate_data_layer_history_does_not_contain(DataCloudObjectType.DLO)
         return self._writer.write_to_dmo(name, dataframe, write_mode, **kwargs)

+    def call_llm_gateway(
+        self, llm_model_id: str, prompt: str, max_tokens: int
+    ) -> str:
+        """Call the configured LLM gateway through the proxy client."""
+        return self._proxy.call_llm_gateway(llm_model_id, prompt, max_tokens)
+
     def find_file_path(self, file_name: str) -> Path:
         """Return a file path"""
diff --git a/src/datacustomcode/config.py b/src/datacustomcode/config.py
index 8e18551..b1edfc4 100644
--- a/src/datacustomcode/config.py
+++ b/src/datacustomcode/config.py
@@ -38,6 +38,7 @@
 from datacustomcode.io.base import BaseDataAccessLayer
 from datacustomcode.io.reader.base import BaseDataCloudReader  # noqa: TCH001
 from datacustomcode.io.writer.base import BaseDataCloudWriter  # noqa: TCH001
+from datacustomcode.proxy.client.base import BaseProxyClient  # noqa: TCH001
 from datacustomcode.spark.base import BaseSparkSessionProvider

 DEFAULT_CONFIG_NAME = "config.yaml"
@@ -109,6 +110,7 @@ def to_object(self) -> _P:
 class ClientConfig(BaseModel):
     reader_config: Union[AccessLayerObjectConfig[BaseDataCloudReader], None] = None
     writer_config: Union[AccessLayerObjectConfig[BaseDataCloudWriter], None] = None
+    proxy_config: Union[AccessLayerObjectConfig[BaseProxyClient], None] = None
     spark_config: Union[SparkConfig, None] = None
     spark_provider_config: Union[
         SparkProviderConfig[BaseSparkSessionProvider], None
@@ -136,6 +138,7 @@ def merge(

         self.reader_config = merge(self.reader_config, other.reader_config)
         self.writer_config = merge(self.writer_config, other.writer_config)
+        self.proxy_config = merge(self.proxy_config, other.proxy_config)
         self.spark_config = merge(self.spark_config, other.spark_config)
         self.spark_provider_config = merge(
             self.spark_provider_config, other.spark_provider_config
diff --git a/src/datacustomcode/config.yaml b/src/datacustomcode/config.yaml
index 0ed02db..0267b6f 100644
--- a/src/datacustomcode/config.yaml
+++ b/src/datacustomcode/config.yaml
@@ -17,3 +17,8 @@ spark_config:
   spark.submit.deployMode: client
   spark.sql.execution.arrow.pyspark.enabled: 'true'
   spark.driver.extraJavaOptions: -Djava.security.manager=allow
+
+proxy_config:
+  type_config_name: LocalProxyClientProvider
+  options:
+    credentials_profile: default
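Taken together, the wiring above gives two ways to obtain a proxy-enabled client. A minimal sketch of both paths; `my_reader` and `my_writer` are placeholder instances, not part of this change:

    from datacustomcode import Client, LocalProxyClientProvider

    # Path 1: build everything from config.yaml (reader_config, writer_config,
    # proxy_config, and spark_config must all be present there).
    client = Client()

    # Path 2: pass instances explicitly; the provided proxy wins over config,
    # and no Spark session is started just for it.
    client = Client(
        reader=my_reader, writer=my_writer, proxy=LocalProxyClientProvider()
    )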
diff --git a/src/datacustomcode/proxy/__init__.py b/src/datacustomcode/proxy/__init__.py
new file mode 100644
index 0000000..93988ff
--- /dev/null
+++ b/src/datacustomcode/proxy/__init__.py
@@ -0,0 +1,14 @@
+# Copyright (c) 2025, Salesforce, Inc.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
diff --git a/src/datacustomcode/proxy/base.py b/src/datacustomcode/proxy/base.py
new file mode 100644
index 0000000..cba92f6
--- /dev/null
+++ b/src/datacustomcode/proxy/base.py
@@ -0,0 +1,24 @@
+# Copyright (c) 2025, Salesforce, Inc.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from __future__ import annotations
+
+from abc import ABC
+
+from datacustomcode.mixin import UserExtendableNamedConfigMixin
+
+
+class BaseDataAccessLayer(ABC, UserExtendableNamedConfigMixin):
+    def __init__(self):
+        pass
diff --git a/src/datacustomcode/proxy/client/__init__.py b/src/datacustomcode/proxy/client/__init__.py
new file mode 100644
index 0000000..93988ff
--- /dev/null
+++ b/src/datacustomcode/proxy/client/__init__.py
@@ -0,0 +1,14 @@
+# Copyright (c) 2025, Salesforce, Inc.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
diff --git a/src/datacustomcode/proxy/client/base.py b/src/datacustomcode/proxy/client/base.py
new file mode 100644
index 0000000..3c4a56b
--- /dev/null
+++ b/src/datacustomcode/proxy/client/base.py
@@ -0,0 +1,28 @@
+# Copyright (c) 2025, Salesforce, Inc.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from __future__ import annotations
+
+from abc import abstractmethod
+
+from datacustomcode.io.base import BaseDataAccessLayer
+
+
+class BaseProxyClient(BaseDataAccessLayer):
+    def __init__(self, spark=None, **kwargs):
+        # The Spark session is optional for proxy clients; only run the
+        # Spark-backed base initializer when a session is supplied.
+        if spark is not None:
+            super().__init__(spark)
+
+    @abstractmethod
+    def call_llm_gateway(
+        self, llm_model_id: str, prompt: str, max_tokens: int
+    ) -> str: ...
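Concrete proxy clients only have to implement `call_llm_gateway`. As an illustration of the extension point, a hypothetical HTTP-backed client; the endpoint, payload shape, and response key are assumptions for the sketch, not part of this change:

    import requests

    from datacustomcode.proxy.client.base import BaseProxyClient


    class HTTPProxyClient(BaseProxyClient):
        """Hypothetical proxy client that forwards prompts to an HTTP gateway."""

        CONFIG_NAME = "HTTPProxyClient"

        def __init__(
            self, spark=None, endpoint: str = "http://localhost:8080/llm", **kwargs
        ):
            super().__init__(spark, **kwargs)
            self.endpoint = endpoint

        def call_llm_gateway(
            self, llm_model_id: str, prompt: str, max_tokens: int
        ) -> str:
            # Payload and response shapes here are illustrative assumptions.
            response = requests.post(
                self.endpoint,
                json={"model": llm_model_id, "prompt": prompt, "max_tokens": max_tokens},
                timeout=30,
            )
            response.raise_for_status()
            return response.json()["generation"]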
diff --git a/src/datacustomcode/proxy/client/local_proxy_client.py b/src/datacustomcode/proxy/client/local_proxy_client.py
new file mode 100644
index 0000000..2c2f962
--- /dev/null
+++ b/src/datacustomcode/proxy/client/local_proxy_client.py
@@ -0,0 +1,26 @@
+# Copyright (c) 2025, Salesforce, Inc.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from __future__ import annotations
+
+from datacustomcode.proxy.client.base import BaseProxyClient
+
+
+class LocalProxyClientProvider(BaseProxyClient):
+    """Default proxy client that returns a canned response for local runs."""
+
+    CONFIG_NAME = "LocalProxyClientProvider"
+
+    def call_llm_gateway(
+        self, llm_model_id: str, prompt: str, max_tokens: int
+    ) -> str:
+        return f"Hello, thanks for using {llm_model_id}. So many tokens: {max_tokens}"
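With the stub in place, the new `Client.call_llm_gateway` path can be exercised end to end without network access. A quick sketch, with `my_reader` and `my_writer` again as placeholders:

    from datacustomcode import Client, LocalProxyClientProvider

    client = Client(
        reader=my_reader, writer=my_writer, proxy=LocalProxyClientProvider()
    )
    print(client.call_llm_gateway("fake-model", "Say hi", max_tokens=32))
    # -> Hello, thanks for using fake-model. So many tokens: 32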
diff --git a/src/datacustomcode/run.py b/src/datacustomcode/run.py
index e70e329..0e4e0ff 100644
--- a/src/datacustomcode/run.py
+++ b/src/datacustomcode/run.py
@@ -21,6 +21,7 @@
 from typing import List, Union

 from datacustomcode.config import config
+from datacustomcode.scan import get_package_type


 def _set_config_option(config_obj, key: str, value: str) -> None:
@@ -60,6 +61,8 @@ def run_entrypoint(
             f"config.json not found at {config_json_path}. config.json is required."
         )

+    package_type = get_package_type(entrypoint_dir)
+
     try:
         with open(config_json_path, "r") as f:
             config_json = json.load(f)
@@ -68,21 +71,23 @@ def run_entrypoint(
             f"config.json at {config_json_path} is not valid JSON"
         ) from err

-    # Require dataspace to be present in config.json
-    dataspace = config_json.get("dataspace")
-    if not dataspace:
-        raise ValueError(
-            f"config.json at {config_json_path} is missing required field 'dataspace'. "
-            f"Please ensure config.json contains a 'dataspace' field."
-        )
-
-    # Load config file first
-    if config_file:
-        config.load(config_file)
-
-    # Add dataspace to reader and writer config options
-    _set_config_option(config.reader_config, "dataspace", dataspace)
-    _set_config_option(config.writer_config, "dataspace", dataspace)
+    # Load config file first so later option overrides apply to it,
+    # regardless of package type.
+    if config_file:
+        config.load(config_file)
+
+    if package_type == "script":
+        # Require dataspace to be present in config.json
+        dataspace = config_json.get("dataspace")
+        if not dataspace:
+            raise ValueError(
+                f"config.json at {config_json_path} is missing required "
+                f"field 'dataspace'. "
+                f"Please ensure config.json contains a 'dataspace' field."
+            )
+
+        # Add dataspace to reader and writer config options
+        _set_config_option(config.reader_config, "dataspace", dataspace)
+        _set_config_option(config.writer_config, "dataspace", dataspace)

     if profile != "default":
         _set_config_option(config.reader_config, "credentials_profile", profile)
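For script packages, the gate above still requires `dataspace` in config.json. A minimal example of a file that passes the check (the value is illustrative):

    {
        "dataspace": "default"
    }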
diff --git a/tests/test_client.py b/tests/test_client.py
index 5d97d58..4e7b99e 100644
--- a/tests/test_client.py
+++ b/tests/test_client.py
@@ -17,6 +17,7 @@
 )
 from datacustomcode.io.reader.base import BaseDataCloudReader
 from datacustomcode.io.writer.base import BaseDataCloudWriter, WriteMode
+from datacustomcode.proxy.client.base import BaseProxyClient


 class MockDataCloudReader(BaseDataCloudReader):
@@ -75,6 +76,13 @@ def mock_config(mock_spark):
     )


+@pytest.fixture
+def mock_proxy():
+    """Mock proxy client to avoid starting Spark when reader/writer are provided."""
+    proxy = MagicMock(spec=BaseProxyClient)
+    return proxy
+
+
 @pytest.fixture
 def reset_client():
     """Reset the Client singleton between tests."""
@@ -85,12 +93,12 @@ def reset_client():


 class TestClient:
-    def test_singleton_pattern(self, reset_client, mock_spark):
+    def test_singleton_pattern(self, reset_client, mock_spark, mock_proxy):
         """Test that Client behaves as a singleton."""
         reader = MockDataCloudReader(mock_spark)
         writer = MockDataCloudWriter(mock_spark)

-        client1 = Client(reader=reader, writer=writer)
+        client1 = Client(reader=reader, writer=writer, proxy=mock_proxy)
         client2 = Client()

         assert client1 is client2
@@ -136,38 +144,38 @@ def test_initialization_with_config(self, mock_config, reset_client, mock_spark)
         assert client._reader is mock_reader
         assert client._writer is mock_writer

-    def test_read_dlo(self, reset_client, mock_spark):
+    def test_read_dlo(self, reset_client, mock_spark, mock_proxy):
         reader = MagicMock(spec=BaseDataCloudReader)
         writer = MagicMock(spec=BaseDataCloudWriter)
         mock_df = MagicMock(spec=DataFrame)
         reader.read_dlo.return_value = mock_df

-        client = Client(reader=reader, writer=writer)
+        client = Client(reader=reader, writer=writer, proxy=mock_proxy)
         result = client.read_dlo("test_dlo")

         reader.read_dlo.assert_called_once_with("test_dlo")
         assert result is mock_df
         assert "test_dlo" in client._data_layer_history[DataCloudObjectType.DLO]

-    def test_read_dmo(self, reset_client, mock_spark):
+    def test_read_dmo(self, reset_client, mock_spark, mock_proxy):
         reader = MagicMock(spec=BaseDataCloudReader)
         writer = MagicMock(spec=BaseDataCloudWriter)
         mock_df = MagicMock(spec=DataFrame)
         reader.read_dmo.return_value = mock_df

-        client = Client(reader=reader, writer=writer)
+        client = Client(reader=reader, writer=writer, proxy=mock_proxy)
         result = client.read_dmo("test_dmo")

         reader.read_dmo.assert_called_once_with("test_dmo")
         assert result is mock_df
         assert "test_dmo" in client._data_layer_history[DataCloudObjectType.DMO]

-    def test_write_to_dlo(self, reset_client, mock_spark):
+    def test_write_to_dlo(self, reset_client, mock_spark, mock_proxy):
         reader = MagicMock(spec=BaseDataCloudReader)
         writer = MagicMock(spec=BaseDataCloudWriter)
         mock_df = MagicMock(spec=DataFrame)

-        client = Client(reader=reader, writer=writer)
+        client = Client(reader=reader, writer=writer, proxy=mock_proxy)
         client._record_dlo_access("some_dlo")
         client.write_to_dlo("test_dlo", mock_df, WriteMode.APPEND, extra_param=True)

@@ -176,12 +184,12 @@
             "test_dlo", mock_df, WriteMode.APPEND, extra_param=True
         )

-    def test_write_to_dmo(self, reset_client, mock_spark):
+    def test_write_to_dmo(self, reset_client, mock_spark, mock_proxy):
         reader = MagicMock(spec=BaseDataCloudReader)
         writer = MagicMock(spec=BaseDataCloudWriter)
         mock_df = MagicMock(spec=DataFrame)

-        client = Client(reader=reader, writer=writer)
+        client = Client(reader=reader, writer=writer, proxy=mock_proxy)
         client._record_dmo_access("some_dmo")
         client.write_to_dmo("test_dmo", mock_df, WriteMode.OVERWRITE, extra_param=True)

@@ -190,13 +198,13 @@
             "test_dmo", mock_df, WriteMode.OVERWRITE, extra_param=True
         )

-    def test_mixed_dlo_dmo_raises_exception(self, reset_client, mock_spark):
+    def test_mixed_dlo_dmo_raises_exception(self, reset_client, mock_spark, mock_proxy):
         """Test that mixing DLOs and DMOs raises an exception."""
         reader = MagicMock(spec=BaseDataCloudReader)
         writer = MagicMock(spec=BaseDataCloudWriter)
         mock_df = MagicMock(spec=DataFrame)

-        client = Client(reader=reader, writer=writer)
+        client = Client(reader=reader, writer=writer, proxy=mock_proxy)
         client._record_dlo_access("test_dlo")

         with pytest.raises(DataCloudAccessLayerException) as exc_info:
@@ -204,13 +212,13 @@
         assert "test_dlo" in str(exc_info.value)

-    def test_mixed_dmo_dlo_raises_exception(self, reset_client, mock_spark):
+    def test_mixed_dmo_dlo_raises_exception(self, reset_client, mock_spark, mock_proxy):
         """Test that mixing DMOs and DLOs raises an exception (converse case)."""
         reader = MagicMock(spec=BaseDataCloudReader)
         writer = MagicMock(spec=BaseDataCloudWriter)
         mock_df = MagicMock(spec=DataFrame)

-        client = Client(reader=reader, writer=writer)
+        client = Client(reader=reader, writer=writer, proxy=mock_proxy)
         client._record_dmo_access("test_dmo")

         with pytest.raises(DataCloudAccessLayerException) as exc_info:
@@ -218,14 +226,14 @@
         assert "test_dmo" in str(exc_info.value)

-    def test_read_pattern_flow(self, reset_client, mock_spark):
+    def test_read_pattern_flow(self, reset_client, mock_spark, mock_proxy):
         """Test a complete flow of reading and writing within the same object type."""
         reader = MagicMock(spec=BaseDataCloudReader)
         writer = MagicMock(spec=BaseDataCloudWriter)
         mock_df = MagicMock(spec=DataFrame)
         reader.read_dlo.return_value = mock_df

-        client = Client(reader=reader, writer=writer)
+        client = Client(reader=reader, writer=writer, proxy=mock_proxy)
         df = client.read_dlo("source_dlo")
         client.write_to_dlo("target_dlo", df, WriteMode.APPEND)
@@ -239,7 +247,7 @@
         # Reset for DMO test
         Client._instance = None
-        client = Client(reader=reader, writer=writer)
+        client = Client(reader=reader, writer=writer, proxy=mock_proxy)
         reader.read_dmo.return_value = mock_df

         df = client.read_dmo("source_dmo")
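The updated tests inject the proxy but never exercise the new call path itself. A delegation test in the same style, inside `TestClient`, could close that gap (a sketch reusing the fixtures above):

    def test_call_llm_gateway_delegates_to_proxy(
        self, reset_client, mock_spark, mock_proxy
    ):
        reader = MagicMock(spec=BaseDataCloudReader)
        writer = MagicMock(spec=BaseDataCloudWriter)
        mock_proxy.call_llm_gateway.return_value = "stub response"

        client = Client(reader=reader, writer=writer, proxy=mock_proxy)
        result = client.call_llm_gateway("some-model", "hello", 16)

        mock_proxy.call_llm_gateway.assert_called_once_with("some-model", "hello", 16)
        assert result == "stub response"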