From 6b33c0512bb719e8a445efb942278e8a3bed0e32 Mon Sep 17 00:00:00 2001 From: Yuseok Jo Date: Sat, 23 May 2026 23:55:00 +0900 Subject: [PATCH 1/4] Fix SFTPToS3Operator using s3_conn_id instead of aws_conn_id --- .../amazon/aws/transfers/sftp_to_s3.py | 22 +++++++-- .../amazon/aws/transfers/test_sftp_to_s3.py | 46 +++++++++++++++++-- 2 files changed, 59 insertions(+), 9 deletions(-) diff --git a/providers/amazon/src/airflow/providers/amazon/aws/transfers/sftp_to_s3.py b/providers/amazon/src/airflow/providers/amazon/aws/transfers/sftp_to_s3.py index 4897ccca25c91..124ea592a59ab 100644 --- a/providers/amazon/src/airflow/providers/amazon/aws/transfers/sftp_to_s3.py +++ b/providers/amazon/src/airflow/providers/amazon/aws/transfers/sftp_to_s3.py @@ -17,6 +17,7 @@ # under the License. from __future__ import annotations +import warnings from collections.abc import Sequence from tempfile import NamedTemporaryFile from typing import TYPE_CHECKING @@ -44,8 +45,11 @@ class SFTPToS3Operator(BaseOperator): Connection. :param sftp_path: The sftp remote path. This is the specified file path for downloading the file from the SFTP server. - :param s3_conn_id: The s3 connection id. The name or identifier for - establishing a connection to S3 + :param aws_conn_id: The Airflow connection used for AWS credentials. + If this is None or empty then the default boto3 behaviour is used. If + running Airflow in a distributed manner and aws_conn_id is None or + empty, then default boto3 configuration would be used (and must be + maintained on each worker node). :param s3_bucket: The targeted s3 bucket. This is the S3 bucket to where the file is uploaded. :param s3_key: The targeted s3 key. This is the specified path for @@ -66,18 +70,26 @@ def __init__( sftp_path: str, sftp_conn_id: str = "ssh_default", sftp_remote_host: str = "", - s3_conn_id: str = "aws_default", + aws_conn_id: str = "aws_default", + s3_conn_id: str | None = None, use_temp_file: bool = True, fail_on_file_not_exist: bool = True, **kwargs, ) -> None: super().__init__(**kwargs) + if s3_conn_id is not None: + warnings.warn( + "The s3_conn_id parameter is deprecated. Use aws_conn_id instead.", + DeprecationWarning, + stacklevel=2, + ) + aws_conn_id = s3_conn_id self.sftp_conn_id = sftp_conn_id self.sftp_path = sftp_path self.sftp_remote_host = sftp_remote_host self.s3_bucket = s3_bucket self.s3_key = s3_key - self.s3_conn_id = s3_conn_id + self.aws_conn_id = aws_conn_id self.use_temp_file = use_temp_file self.fail_on_file_not_exist = fail_on_file_not_exist @@ -92,7 +104,7 @@ def execute(self, context: Context) -> None: # SSHHook will handle a None/"" sftp_remote_host ssh_hook = SSHHook(ssh_conn_id=self.sftp_conn_id, remote_host=self.sftp_remote_host) - s3_hook = S3Hook(self.s3_conn_id) + s3_hook = S3Hook(self.aws_conn_id) sftp_client = ssh_hook.get_conn().open_sftp() diff --git a/providers/amazon/tests/unit/amazon/aws/transfers/test_sftp_to_s3.py b/providers/amazon/tests/unit/amazon/aws/transfers/test_sftp_to_s3.py index feb85e33a3c17..f1d484cef666b 100644 --- a/providers/amazon/tests/unit/amazon/aws/transfers/test_sftp_to_s3.py +++ b/providers/amazon/tests/unit/amazon/aws/transfers/test_sftp_to_s3.py @@ -17,6 +17,8 @@ # under the License. from __future__ import annotations +import warnings + import boto3 import pytest from moto import mock_aws @@ -99,7 +101,7 @@ def test_sftp_to_s3_operation(self, use_temp_file): s3_key=S3_KEY, sftp_path=SFTP_PATH, sftp_conn_id=SFTP_CONN_ID, - s3_conn_id=S3_CONN_ID, + aws_conn_id=S3_CONN_ID, use_temp_file=use_temp_file, task_id="test_sftp_to_s3", dag=self.dag, @@ -137,7 +139,7 @@ def test_sftp_to_s3_fail_on_file_not_exist(self, fail_on_file_not_exist): s3_key=self.s3_key, sftp_path="/tmp/wrong_path.txt", sftp_conn_id=SFTP_CONN_ID, - s3_conn_id=S3_CONN_ID, + aws_conn_id=S3_CONN_ID, fail_on_file_not_exist=fail_on_file_not_exist, task_id="test_sftp_to_s3", dag=self.dag, @@ -148,7 +150,7 @@ def test_sftp_to_s3_fail_on_file_not_exist(self, fail_on_file_not_exist): s3_key=self.s3_key, sftp_path=self.sftp_path, sftp_conn_id=SFTP_CONN_ID, - s3_conn_id=S3_CONN_ID, + aws_conn_id=S3_CONN_ID, fail_on_file_not_exist=fail_on_file_not_exist, task_id="test_sftp_to_s3", dag=self.dag, @@ -191,7 +193,7 @@ def test_sftp_to_s3_sftp_remote_host(self): sftp_path=SFTP_PATH, sftp_conn_id=SFTP_CONN_ID, sftp_remote_host="localhost", - s3_conn_id=S3_CONN_ID, + aws_conn_id=S3_CONN_ID, task_id="test_sftp_to_s3_remote_host", dag=self.dag, ) @@ -208,3 +210,39 @@ def test_sftp_to_s3_sftp_remote_host(self): conn.delete_object(Bucket=self.s3_bucket, Key=self.s3_key) conn.delete_bucket(Bucket=self.s3_bucket) assert not s3_hook.check_for_bucket(self.s3_bucket) + + +class TestSFTPToS3OperatorInit: + """Unit tests for SFTPToS3Operator.__init__ that do not require an SSH server.""" + + def test_s3_conn_id_deprecated(self): + """s3_conn_id is a deprecated alias for aws_conn_id and must raise DeprecationWarning.""" + with warnings.catch_warnings(record=True) as caught: + warnings.simplefilter("always") + op = SFTPToS3Operator( + task_id="test_deprecated", + s3_bucket=BUCKET, + s3_key=S3_KEY, + sftp_path=SFTP_PATH, + sftp_conn_id=SFTP_CONN_ID, + s3_conn_id="my_legacy_conn", + ) + deprecation_warnings = [w for w in caught if issubclass(w.category, DeprecationWarning)] + assert len(deprecation_warnings) == 1 + assert "s3_conn_id" in str(deprecation_warnings[0].message) + assert op.aws_conn_id == "my_legacy_conn" + + def test_aws_conn_id_default(self): + """aws_conn_id defaults to 'aws_default' and no DeprecationWarning is raised.""" + with warnings.catch_warnings(record=True) as caught: + warnings.simplefilter("always") + op = SFTPToS3Operator( + task_id="test_default", + s3_bucket=BUCKET, + s3_key=S3_KEY, + sftp_path=SFTP_PATH, + sftp_conn_id=SFTP_CONN_ID, + ) + deprecation_warnings = [w for w in caught if issubclass(w.category, DeprecationWarning)] + assert not deprecation_warnings + assert op.aws_conn_id == "aws_default" From 68a5c19bc90ed4e7c28aeec9a1cfbd06dc590649 Mon Sep 17 00:00:00 2001 From: Yuseok Jo Date: Sun, 24 May 2026 00:08:42 +0900 Subject: [PATCH 2/4] Add replace/encrypt/gzip/acl_policy options to SFTPToS3Operator --- .../amazon/aws/transfers/sftp_to_s3.py | 32 +++++++++++++++++-- .../amazon/aws/transfers/test_sftp_to_s3.py | 30 +++++++++++++++++ 2 files changed, 59 insertions(+), 3 deletions(-) diff --git a/providers/amazon/src/airflow/providers/amazon/aws/transfers/sftp_to_s3.py b/providers/amazon/src/airflow/providers/amazon/aws/transfers/sftp_to_s3.py index 124ea592a59ab..8efe9e71dd8c7 100644 --- a/providers/amazon/src/airflow/providers/amazon/aws/transfers/sftp_to_s3.py +++ b/providers/amazon/src/airflow/providers/amazon/aws/transfers/sftp_to_s3.py @@ -58,6 +58,10 @@ class SFTPToS3Operator(BaseOperator): if False streams file from SFTP to S3. :param fail_on_file_not_exist: If True, operator fails when file does not exist, if False, operator will not fail and skips transfer. Default is True. + :param replace: If True, overwrite the S3 key if it already exists. + :param encrypt: If True, the file will be encrypted on the server-side by S3. + :param gzip: If True, the file will be compressed locally before upload. + :param acl_policy: Canned ACL policy for the file being uploaded to S3. """ template_fields: Sequence[str] = ("s3_key", "sftp_path", "s3_bucket") @@ -74,6 +78,10 @@ def __init__( s3_conn_id: str | None = None, use_temp_file: bool = True, fail_on_file_not_exist: bool = True, + replace: bool = False, + encrypt: bool = False, + gzip: bool = False, + acl_policy: str | None = None, **kwargs, ) -> None: super().__init__(**kwargs) @@ -92,6 +100,10 @@ def __init__( self.aws_conn_id = aws_conn_id self.use_temp_file = use_temp_file self.fail_on_file_not_exist = fail_on_file_not_exist + self.replace = replace + self.encrypt = encrypt + self.gzip = gzip + self.acl_policy = acl_policy @staticmethod def get_s3_key(s3_key: str) -> str: @@ -119,8 +131,22 @@ def execute(self, context: Context) -> None: if self.use_temp_file: with NamedTemporaryFile("w") as f: sftp_client.get(self.sftp_path, f.name) - - s3_hook.load_file(filename=f.name, key=self.s3_key, bucket_name=self.s3_bucket, replace=True) + s3_hook.load_file( + filename=f.name, + key=self.s3_key, + bucket_name=self.s3_bucket, + replace=self.replace, + encrypt=self.encrypt, + gzip=self.gzip, + acl_policy=self.acl_policy, + ) else: + extra_args: dict = {} + if self.encrypt: + extra_args["ServerSideEncryption"] = "AES256" + if self.acl_policy: + extra_args["ACL"] = self.acl_policy with sftp_client.file(self.sftp_path, mode="rb") as data: - s3_hook.get_conn().upload_fileobj(data, self.s3_bucket, self.s3_key, Callback=self.log.info) + s3_hook.get_conn().upload_fileobj( + data, self.s3_bucket, self.s3_key, ExtraArgs=extra_args or None, Callback=self.log.info + ) diff --git a/providers/amazon/tests/unit/amazon/aws/transfers/test_sftp_to_s3.py b/providers/amazon/tests/unit/amazon/aws/transfers/test_sftp_to_s3.py index f1d484cef666b..779d43de1bf1c 100644 --- a/providers/amazon/tests/unit/amazon/aws/transfers/test_sftp_to_s3.py +++ b/providers/amazon/tests/unit/amazon/aws/transfers/test_sftp_to_s3.py @@ -246,3 +246,33 @@ def test_aws_conn_id_default(self): deprecation_warnings = [w for w in caught if issubclass(w.category, DeprecationWarning)] assert not deprecation_warnings assert op.aws_conn_id == "aws_default" + + @pytest.mark.parametrize( + ("kwargs", "expected"), + [ + ({}, {"replace": False, "encrypt": False, "gzip": False, "acl_policy": None}), + ( + {"replace": True, "encrypt": True, "gzip": True, "acl_policy": "bucket-owner-full-control"}, + { + "replace": True, + "encrypt": True, + "gzip": True, + "acl_policy": "bucket-owner-full-control", + }, + ), + ], + ) + def test_s3_upload_options(self, kwargs, expected): + """replace/encrypt/gzip/acl_policy are stored and default to False/None.""" + op = SFTPToS3Operator( + task_id="test_options", + s3_bucket=BUCKET, + s3_key=S3_KEY, + sftp_path=SFTP_PATH, + sftp_conn_id=SFTP_CONN_ID, + **kwargs, + ) + assert op.replace == expected["replace"] + assert op.encrypt == expected["encrypt"] + assert op.gzip == expected["gzip"] + assert op.acl_policy == expected["acl_policy"] From f96ce334a5cb589056de2989372892bd87a44991 Mon Sep 17 00:00:00 2001 From: Yuseok Jo Date: Sun, 24 May 2026 00:17:22 +0900 Subject: [PATCH 3/4] Add multi-file transfer support to SFTPToS3, S3ToSFTP, S3ToFTP operators --- .../amazon/aws/transfers/s3_to_ftp.py | 71 +++++++++--- .../amazon/aws/transfers/s3_to_sftp.py | 81 ++++++++++++-- .../amazon/aws/transfers/sftp_to_s3.py | 101 ++++++++++++++---- .../amazon/aws/transfers/test_s3_to_ftp.py | 28 +++++ .../amazon/aws/transfers/test_s3_to_sftp.py | 27 +++++ .../amazon/aws/transfers/test_sftp_to_s3.py | 23 ++++ 6 files changed, 291 insertions(+), 40 deletions(-) diff --git a/providers/amazon/src/airflow/providers/amazon/aws/transfers/s3_to_ftp.py b/providers/amazon/src/airflow/providers/amazon/aws/transfers/s3_to_ftp.py index 2a0a4fb91e8e4..0232e5fdca665 100644 --- a/providers/amazon/src/airflow/providers/amazon/aws/transfers/s3_to_ftp.py +++ b/providers/amazon/src/airflow/providers/amazon/aws/transfers/s3_to_ftp.py @@ -39,16 +39,25 @@ class S3ToFTPOperator(BaseOperator): :param s3_bucket: The targeted s3 bucket. This is the S3 bucket from where the file is downloaded. - :param s3_key: The targeted s3 key. This is the specified file path for - downloading the file from S3. - :param ftp_path: The ftp remote path. This is the specified file path for - uploading file to the FTP server. + :param s3_key: The targeted s3 key. For a single file it must include the file + path. For multiple files it is the key prefix (directory) and must end with + ``"/"``. + :param s3_filenames: Only used if you want to move multiple files. You can pass + a list with exact key suffixes present under the s3_key prefix, or a string + prefix that all filenames must match. Use ``"*"`` to move all objects under + the s3_key prefix. + :param ftp_path: The ftp remote path. For a single file it must include the file + path. For multiple files it is the destination directory path and must end + with ``"/"``. + :param ftp_filenames: Only used if you want to move multiple files and name them + differently at the destination. It can be a list of filenames or a string + prefix that replaces the s3 prefix. :param aws_conn_id: reference to a specific AWS connection :param ftp_conn_id: The ftp connection id. The name or identifier for establishing a connection to the FTP server. """ - template_fields: Sequence[str] = ("s3_bucket", "s3_key", "ftp_path") + template_fields: Sequence[str] = ("s3_bucket", "s3_key", "ftp_path", "s3_filenames", "ftp_filenames") def __init__( self, @@ -56,6 +65,8 @@ def __init__( s3_bucket, s3_key, ftp_path, + s3_filenames: str | list[str] | None = None, + ftp_filenames: str | list[str] | None = None, aws_conn_id="aws_default", ftp_conn_id="ftp_default", **kwargs, @@ -64,18 +75,54 @@ def __init__( self.s3_bucket = s3_bucket self.s3_key = s3_key self.ftp_path = ftp_path + self.s3_filenames = s3_filenames + self.ftp_filenames = ftp_filenames self.aws_conn_id = aws_conn_id self.ftp_conn_id = ftp_conn_id + def _download_from_s3(self, s3_hook: S3Hook, ftp_hook: FTPHook, s3_key: str, ftp_path: str) -> None: + s3_obj = s3_hook.get_key(s3_key, self.s3_bucket) + with NamedTemporaryFile() as local_tmp_file: + self.log.info("Downloading file from %s", s3_key) + s3_obj.download_fileobj(local_tmp_file) + local_tmp_file.seek(0) + ftp_hook.store_file(ftp_path, local_tmp_file.name) + self.log.info("File stored in %s", ftp_path) + def execute(self, context: Context): s3_hook = S3Hook(self.aws_conn_id) ftp_hook = FTPHook(ftp_conn_id=self.ftp_conn_id) - s3_obj = s3_hook.get_key(self.s3_key, self.s3_bucket) + if self.s3_filenames: + if isinstance(self.s3_filenames, str): + self.log.info("Getting files in s3://%s/%s", self.s3_bucket, self.s3_key) + all_keys = s3_hook.list_keys(bucket_name=self.s3_bucket, prefix=self.s3_key) or [] + filenames = [k[len(self.s3_key) :] for k in all_keys] + if self.s3_filenames == "*": + files = filenames + else: + s3_prefix: str = self.s3_filenames + files = [f for f in filenames if s3_prefix in f] - with NamedTemporaryFile() as local_tmp_file: - self.log.info("Downloading file from %s", self.s3_key) - s3_obj.download_fileobj(local_tmp_file) - local_tmp_file.seek(0) - ftp_hook.store_file(self.ftp_path, local_tmp_file.name) - self.log.info("File stored in %s", {self.ftp_path}) + for file in files: + self.log.info("Moving file %s", file) + if self.ftp_filenames and isinstance(self.ftp_filenames, str): + ftp_filename = file.replace(self.s3_filenames, self.ftp_filenames) + else: + ftp_filename = file + self._download_from_s3( + s3_hook, ftp_hook, self.s3_key + file, self.ftp_path + ftp_filename + ) + else: + if self.ftp_filenames: + for s3_file, ftp_file in zip(self.s3_filenames, self.ftp_filenames): + self._download_from_s3( + s3_hook, ftp_hook, self.s3_key + s3_file, self.ftp_path + ftp_file + ) + else: + for s3_file in self.s3_filenames: + self._download_from_s3( + s3_hook, ftp_hook, self.s3_key + s3_file, self.ftp_path + s3_file + ) + else: + self._download_from_s3(s3_hook, ftp_hook, self.s3_key, self.ftp_path) diff --git a/providers/amazon/src/airflow/providers/amazon/aws/transfers/s3_to_sftp.py b/providers/amazon/src/airflow/providers/amazon/aws/transfers/s3_to_sftp.py index 87d8454af963b..26c662c91ca13 100644 --- a/providers/amazon/src/airflow/providers/amazon/aws/transfers/s3_to_sftp.py +++ b/providers/amazon/src/airflow/providers/amazon/aws/transfers/s3_to_sftp.py @@ -27,6 +27,8 @@ from airflow.providers.ssh.hooks.ssh import SSHHook if TYPE_CHECKING: + import paramiko + from airflow.sdk import Context @@ -40,8 +42,9 @@ class S3ToSFTPOperator(BaseOperator): :param sftp_conn_id: The sftp connection id. The name or identifier for establishing a connection to the SFTP server. - :param sftp_path: The sftp remote path. This is the specified file path for - uploading file to the SFTP server. + :param sftp_path: The sftp remote path. For a single file it must include the + file path. For multiple files it is the destination directory path and must + end with ``"/"``. :param sftp_remote_host: The remote host of the SFTP server. Overrides host in Connection. :param aws_conn_id: The Airflow connection used for AWS credentials. @@ -51,14 +54,22 @@ class S3ToSFTPOperator(BaseOperator): maintained on each worker node). :param s3_bucket: The targeted s3 bucket. This is the S3 bucket from where the file is downloaded. - :param s3_key: The targeted s3 key. This is the specified file path for - downloading the file from S3. + :param s3_key: The targeted s3 key. For a single file it must include the file + path. For multiple files it is the key prefix (directory) and must end with + ``"/"``. + :param s3_filenames: Only used if you want to move multiple files. You can pass + a list with exact key suffixes present under the s3_key prefix, or a string + prefix that all filenames must match. Use ``"*"`` to move all objects under + the s3_key prefix. + :param sftp_filenames: Only used if you want to move multiple files and name them + differently at the destination. It can be a list of filenames or a string + prefix that replaces the s3 prefix. :param confirm: specify if the SFTP operation should be confirmed, defaults to True. When True, a stat will be performed on the remote file after upload to verify the file size matches and confirm successful transfer. """ - template_fields: Sequence[str] = ("s3_key", "sftp_path", "s3_bucket") + template_fields: Sequence[str] = ("s3_key", "sftp_path", "s3_bucket", "s3_filenames", "sftp_filenames") def __init__( self, @@ -69,6 +80,8 @@ def __init__( sftp_conn_id: str = "ssh_default", sftp_remote_host: str = "", aws_conn_id: str | None = "aws_default", + s3_filenames: str | list[str] | None = None, + sftp_filenames: str | list[str] | None = None, confirm: bool = True, **kwargs, ) -> None: @@ -79,6 +92,8 @@ def __init__( self.s3_key = s3_key self.sftp_remote_host = sftp_remote_host self.aws_conn_id = aws_conn_id + self.s3_filenames = s3_filenames + self.sftp_filenames = sftp_filenames self.confirm = confirm @staticmethod @@ -87,6 +102,17 @@ def get_s3_key(s3_key: str) -> str: parsed_s3_key = urlsplit(s3_key) return parsed_s3_key.path.lstrip("/") + def _download_from_s3( + self, + sftp_client: paramiko.SFTPClient, + s3_client, + s3_key: str, + sftp_path: str, + ) -> None: + with NamedTemporaryFile("w") as f: + s3_client.download_file(self.s3_bucket, s3_key, f.name) + sftp_client.put(f.name, sftp_path, confirm=self.confirm) + def execute(self, context: Context) -> None: self.s3_key = self.get_s3_key(self.s3_key) @@ -97,6 +123,45 @@ def execute(self, context: Context) -> None: s3_client = s3_hook.get_conn() sftp_client = ssh_hook.get_conn().open_sftp() - with NamedTemporaryFile("w") as f: - s3_client.download_file(self.s3_bucket, self.s3_key, f.name) - sftp_client.put(f.name, self.sftp_path, confirm=self.confirm) + if self.s3_filenames: + if isinstance(self.s3_filenames, str): + self.log.info("Getting files in s3://%s/%s", self.s3_bucket, self.s3_key) + all_keys = s3_hook.list_keys(bucket_name=self.s3_bucket, prefix=self.s3_key) or [] + filenames = [k[len(self.s3_key) :] for k in all_keys] + if self.s3_filenames == "*": + files = filenames + else: + s3_prefix: str = self.s3_filenames + files = [f for f in filenames if s3_prefix in f] + + for file in files: + self.log.info("Moving file %s", file) + if self.sftp_filenames and isinstance(self.sftp_filenames, str): + sftp_filename = file.replace(self.s3_filenames, self.sftp_filenames) + else: + sftp_filename = file + self._download_from_s3( + sftp_client, + s3_client, + self.s3_key + file, + self.sftp_path + sftp_filename, + ) + else: + if self.sftp_filenames: + for s3_file, sftp_file in zip(self.s3_filenames, self.sftp_filenames): + self._download_from_s3( + sftp_client, + s3_client, + self.s3_key + s3_file, + self.sftp_path + sftp_file, + ) + else: + for s3_file in self.s3_filenames: + self._download_from_s3( + sftp_client, + s3_client, + self.s3_key + s3_file, + self.sftp_path + s3_file, + ) + else: + self._download_from_s3(sftp_client, s3_client, self.s3_key, self.sftp_path) diff --git a/providers/amazon/src/airflow/providers/amazon/aws/transfers/sftp_to_s3.py b/providers/amazon/src/airflow/providers/amazon/aws/transfers/sftp_to_s3.py index 8efe9e71dd8c7..d351270b79dc0 100644 --- a/providers/amazon/src/airflow/providers/amazon/aws/transfers/sftp_to_s3.py +++ b/providers/amazon/src/airflow/providers/amazon/aws/transfers/sftp_to_s3.py @@ -28,6 +28,8 @@ from airflow.providers.ssh.hooks.ssh import SSHHook if TYPE_CHECKING: + import paramiko + from airflow.sdk import Context @@ -43,8 +45,12 @@ class SFTPToS3Operator(BaseOperator): establishing a connection to the SFTP server. :param sftp_remote_host: The remote host of the SFTP server. Overrides host in Connection. - :param sftp_path: The sftp remote path. This is the specified file path - for downloading the file from the SFTP server. + :param sftp_path: The sftp remote path. For a single file it must include the + file path. For multiple files it is the directory path where the files are + located. + :param sftp_filenames: Only used if you want to move multiple files. You can pass + a list with exact filenames present in the sftp path, or a prefix that all + files must match. Use ``"*"`` to move all files within the sftp path. :param aws_conn_id: The Airflow connection used for AWS credentials. If this is None or empty then the default boto3 behaviour is used. If running Airflow in a distributed manner and aws_conn_id is None or @@ -52,8 +58,11 @@ class SFTPToS3Operator(BaseOperator): maintained on each worker node). :param s3_bucket: The targeted s3 bucket. This is the S3 bucket to where the file is uploaded. - :param s3_key: The targeted s3 key. This is the specified path for - uploading the file to S3. + :param s3_key: The targeted s3 key. For a single file it must include the file + path. For multiple files it must end with ``"/"``. + :param s3_filenames: Only used if you want to move multiple files and name them + differently from the originals on the SFTP server. It can be a list of + filenames or a string prefix that replaces the sftp prefix. :param use_temp_file: If True, copies file first to local, if False streams file from SFTP to S3. :param fail_on_file_not_exist: If True, operator fails when file does not exist, @@ -64,7 +73,7 @@ class SFTPToS3Operator(BaseOperator): :param acl_policy: Canned ACL policy for the file being uploaded to S3. """ - template_fields: Sequence[str] = ("s3_key", "sftp_path", "s3_bucket") + template_fields: Sequence[str] = ("s3_key", "sftp_path", "s3_bucket", "sftp_filenames", "s3_filenames") def __init__( self, @@ -76,6 +85,8 @@ def __init__( sftp_remote_host: str = "", aws_conn_id: str = "aws_default", s3_conn_id: str | None = None, + sftp_filenames: str | list[str] | None = None, + s3_filenames: str | list[str] | None = None, use_temp_file: bool = True, fail_on_file_not_exist: bool = True, replace: bool = False, @@ -98,6 +109,8 @@ def __init__( self.s3_bucket = s3_bucket self.s3_key = s3_key self.aws_conn_id = aws_conn_id + self.sftp_filenames = sftp_filenames + self.s3_filenames = s3_filenames self.use_temp_file = use_temp_file self.fail_on_file_not_exist = fail_on_file_not_exist self.replace = replace @@ -111,29 +124,27 @@ def get_s3_key(s3_key: str) -> str: parsed_s3_key = urlsplit(s3_key) return parsed_s3_key.path.lstrip("/") - def execute(self, context: Context) -> None: - self.s3_key = self.get_s3_key(self.s3_key) - - # SSHHook will handle a None/"" sftp_remote_host - ssh_hook = SSHHook(ssh_conn_id=self.sftp_conn_id, remote_host=self.sftp_remote_host) - s3_hook = S3Hook(self.aws_conn_id) - - sftp_client = ssh_hook.get_conn().open_sftp() - + def _upload_to_s3( + self, + sftp_client: paramiko.SFTPClient, + s3_hook: S3Hook, + sftp_path: str, + s3_key: str, + ) -> None: try: - sftp_client.stat(self.sftp_path) + sftp_client.stat(sftp_path) except FileNotFoundError: if self.fail_on_file_not_exist: raise - self.log.info("File %s not found on SFTP server. Skipping transfer.", self.sftp_path) + self.log.info("File %s not found on SFTP server. Skipping transfer.", sftp_path) return if self.use_temp_file: with NamedTemporaryFile("w") as f: - sftp_client.get(self.sftp_path, f.name) + sftp_client.get(sftp_path, f.name) s3_hook.load_file( filename=f.name, - key=self.s3_key, + key=s3_key, bucket_name=self.s3_bucket, replace=self.replace, encrypt=self.encrypt, @@ -146,7 +157,57 @@ def execute(self, context: Context) -> None: extra_args["ServerSideEncryption"] = "AES256" if self.acl_policy: extra_args["ACL"] = self.acl_policy - with sftp_client.file(self.sftp_path, mode="rb") as data: + with sftp_client.file(sftp_path, mode="rb") as data: s3_hook.get_conn().upload_fileobj( - data, self.s3_bucket, self.s3_key, ExtraArgs=extra_args or None, Callback=self.log.info + data, self.s3_bucket, s3_key, ExtraArgs=extra_args or None, Callback=self.log.info ) + + def execute(self, context: Context) -> None: + self.s3_key = self.get_s3_key(self.s3_key) + + # SSHHook will handle a None/"" sftp_remote_host + ssh_hook = SSHHook(ssh_conn_id=self.sftp_conn_id, remote_host=self.sftp_remote_host) + s3_hook = S3Hook(self.aws_conn_id) + sftp_client = ssh_hook.get_conn().open_sftp() + + if self.sftp_filenames: + if isinstance(self.sftp_filenames, str): + self.log.info("Getting files in %s", self.sftp_path) + list_dir = sftp_client.listdir(self.sftp_path) + if self.sftp_filenames == "*": + files = list_dir + else: + sftp_prefix: str = self.sftp_filenames + files = [f for f in list_dir if sftp_prefix in f] + + for file in files: + self.log.info("Moving file %s", file) + if self.s3_filenames and isinstance(self.s3_filenames, str): + s3_filename = file.replace(self.sftp_filenames, self.s3_filenames) + else: + s3_filename = file + self._upload_to_s3( + sftp_client, + s3_hook, + f"{self.sftp_path}/{file}", + f"{self.s3_key}{s3_filename}", + ) + else: + if self.s3_filenames: + for sftp_file, s3_file in zip(self.sftp_filenames, self.s3_filenames): + self._upload_to_s3( + sftp_client, + s3_hook, + self.sftp_path + sftp_file, + self.s3_key + s3_file, + ) + else: + for sftp_file in self.sftp_filenames: + self._upload_to_s3( + sftp_client, + s3_hook, + self.sftp_path + sftp_file, + self.s3_key + sftp_file, + ) + else: + self._upload_to_s3(sftp_client, s3_hook, self.sftp_path, self.s3_key) diff --git a/providers/amazon/tests/unit/amazon/aws/transfers/test_s3_to_ftp.py b/providers/amazon/tests/unit/amazon/aws/transfers/test_s3_to_ftp.py index 6308d34ac020a..82be1cf0412e9 100644 --- a/providers/amazon/tests/unit/amazon/aws/transfers/test_s3_to_ftp.py +++ b/providers/amazon/tests/unit/amazon/aws/transfers/test_s3_to_ftp.py @@ -19,6 +19,8 @@ from unittest import mock +import pytest + from airflow.providers.amazon.aws.transfers.s3_to_ftp import S3ToFTPOperator TASK_ID = "test_s3_to_ftp" @@ -42,3 +44,29 @@ def test_execute(self, mock_local_tmp_file, mock_s3_hook_get_key, mock_ftp_hook_ mock_local_tmp_file_value = mock_local_tmp_file.return_value.__enter__.return_value mock_s3_hook_get_key.return_value.download_fileobj.assert_called_once_with(mock_local_tmp_file_value) mock_ftp_hook_store_file.assert_called_once_with(operator.ftp_path, mock_local_tmp_file_value.name) + + +class TestS3ToFTPOperatorInit: + """Unit tests for S3ToFTPOperator.__init__ that do not require an FTP server.""" + + @pytest.mark.parametrize( + ("s3_filenames", "ftp_filenames"), + [ + (None, None), + ("*", None), + ("prefix_", "renamed_"), + (["a.csv", "b.csv"], ["x.csv", "y.csv"]), + ], + ) + def test_multi_file_params(self, s3_filenames, ftp_filenames): + """s3_filenames and ftp_filenames are stored correctly.""" + op = S3ToFTPOperator( + task_id="test_multi", + s3_bucket=BUCKET, + s3_key=S3_KEY, + ftp_path=FTP_PATH, + s3_filenames=s3_filenames, + ftp_filenames=ftp_filenames, + ) + assert op.s3_filenames == s3_filenames + assert op.ftp_filenames == ftp_filenames diff --git a/providers/amazon/tests/unit/amazon/aws/transfers/test_s3_to_sftp.py b/providers/amazon/tests/unit/amazon/aws/transfers/test_s3_to_sftp.py index 257b898922ccd..de8ae1892b05c 100644 --- a/providers/amazon/tests/unit/amazon/aws/transfers/test_s3_to_sftp.py +++ b/providers/amazon/tests/unit/amazon/aws/transfers/test_s3_to_sftp.py @@ -313,3 +313,30 @@ def test_s3_to_sftp_operator_sftp_remote_host(self): def teardown_method(self): self.delete_remote_resource() + + +class TestS3ToSFTPOperatorInit: + """Unit tests for S3ToSFTPOperator.__init__ that do not require an SSH server.""" + + @pytest.mark.parametrize( + ("s3_filenames", "sftp_filenames"), + [ + (None, None), + ("*", None), + ("prefix_", "renamed_"), + (["a.csv", "b.csv"], ["x.csv", "y.csv"]), + ], + ) + def test_multi_file_params(self, s3_filenames, sftp_filenames): + """s3_filenames and sftp_filenames are stored correctly.""" + op = S3ToSFTPOperator( + task_id="test_multi", + s3_bucket=BUCKET, + s3_key=S3_KEY, + sftp_path=SFTP_PATH, + sftp_conn_id=SFTP_CONN_ID, + s3_filenames=s3_filenames, + sftp_filenames=sftp_filenames, + ) + assert op.s3_filenames == s3_filenames + assert op.sftp_filenames == sftp_filenames diff --git a/providers/amazon/tests/unit/amazon/aws/transfers/test_sftp_to_s3.py b/providers/amazon/tests/unit/amazon/aws/transfers/test_sftp_to_s3.py index 779d43de1bf1c..903175b7ef67e 100644 --- a/providers/amazon/tests/unit/amazon/aws/transfers/test_sftp_to_s3.py +++ b/providers/amazon/tests/unit/amazon/aws/transfers/test_sftp_to_s3.py @@ -276,3 +276,26 @@ def test_s3_upload_options(self, kwargs, expected): assert op.encrypt == expected["encrypt"] assert op.gzip == expected["gzip"] assert op.acl_policy == expected["acl_policy"] + + @pytest.mark.parametrize( + ("sftp_filenames", "s3_filenames"), + [ + (None, None), + ("*", None), + ("prefix_", "renamed_"), + (["a.csv", "b.csv"], ["x.csv", "y.csv"]), + ], + ) + def test_multi_file_params(self, sftp_filenames, s3_filenames): + """sftp_filenames and s3_filenames are stored correctly.""" + op = SFTPToS3Operator( + task_id="test_multi", + s3_bucket=BUCKET, + s3_key=S3_KEY, + sftp_path=SFTP_PATH, + sftp_conn_id=SFTP_CONN_ID, + sftp_filenames=sftp_filenames, + s3_filenames=s3_filenames, + ) + assert op.sftp_filenames == sftp_filenames + assert op.s3_filenames == s3_filenames From ab477a8b1b9ffe9d24da396e154518997a0a215b Mon Sep 17 00:00:00 2001 From: Yuseok Jo Date: Sun, 24 May 2026 00:24:28 +0900 Subject: [PATCH 4/4] Add fail_on_file_not_exist to FTPToS3, S3ToSFTP, S3ToFTP operators --- .../amazon/aws/transfers/ftp_to_s3.py | 41 ++++++++++++------- .../amazon/aws/transfers/s3_to_ftp.py | 9 ++++ .../amazon/aws/transfers/s3_to_sftp.py | 23 +++++++---- .../amazon/aws/transfers/test_ftp_to_s3.py | 35 ++++++++++++++++ .../amazon/aws/transfers/test_s3_to_ftp.py | 28 +++++++++++++ .../amazon/aws/transfers/test_s3_to_sftp.py | 35 ++++++++++++++++ 6 files changed, 148 insertions(+), 23 deletions(-) diff --git a/providers/amazon/src/airflow/providers/amazon/aws/transfers/ftp_to_s3.py b/providers/amazon/src/airflow/providers/amazon/aws/transfers/ftp_to_s3.py index 251c16a5e26b6..5dd933a197bc7 100644 --- a/providers/amazon/src/airflow/providers/amazon/aws/transfers/ftp_to_s3.py +++ b/providers/amazon/src/airflow/providers/amazon/aws/transfers/ftp_to_s3.py @@ -17,6 +17,7 @@ # under the License. from __future__ import annotations +import ftplib from collections.abc import Sequence from tempfile import NamedTemporaryFile from typing import TYPE_CHECKING @@ -60,6 +61,9 @@ class FTPToS3Operator(BaseOperator): :param gzip: If True, the file will be compressed locally :param acl_policy: String specifying the canned ACL policy for the file being uploaded to the S3 bucket. + :param fail_on_file_not_exist: If True, operator fails when a source file does not + exist on the FTP server. If False, the operator logs a warning and skips the + transfer. Default is True. """ template_fields: Sequence[str] = ("ftp_path", "s3_bucket", "s3_key", "ftp_filenames", "s3_filenames") @@ -78,6 +82,7 @@ def __init__( encrypt: bool = False, gzip: bool = False, acl_policy: str | None = None, + fail_on_file_not_exist: bool = True, **kwargs, ): super().__init__(**kwargs) @@ -92,25 +97,31 @@ def __init__( self.encrypt = encrypt self.gzip = gzip self.acl_policy = acl_policy + self.fail_on_file_not_exist = fail_on_file_not_exist self.s3_hook: S3Hook | None = None self.ftp_hook: FTPHook | None = None def __upload_to_s3_from_ftp(self, remote_filename, s3_file_key): - with NamedTemporaryFile() as local_tmp_file: - self.ftp_hook.retrieve_file( - remote_full_path=remote_filename, local_full_path_or_buffer=local_tmp_file.name - ) - - self.s3_hook.load_file( - filename=local_tmp_file.name, - key=s3_file_key, - bucket_name=self.s3_bucket, - replace=self.replace, - encrypt=self.encrypt, - gzip=self.gzip, - acl_policy=self.acl_policy, - ) - self.log.info("File upload to %s", s3_file_key) + try: + with NamedTemporaryFile() as local_tmp_file: + self.ftp_hook.retrieve_file( + remote_full_path=remote_filename, local_full_path_or_buffer=local_tmp_file.name + ) + self.s3_hook.load_file( + filename=local_tmp_file.name, + key=s3_file_key, + bucket_name=self.s3_bucket, + replace=self.replace, + encrypt=self.encrypt, + gzip=self.gzip, + acl_policy=self.acl_policy, + ) + self.log.info("File upload to %s", s3_file_key) + except ftplib.error_perm as e: + if "550" in str(e) and not self.fail_on_file_not_exist: + self.log.info("File %s not found on FTP server. Skipping transfer.", remote_filename) + return + raise def execute(self, context: Context): self.ftp_hook = FTPHook(ftp_conn_id=self.ftp_conn_id) diff --git a/providers/amazon/src/airflow/providers/amazon/aws/transfers/s3_to_ftp.py b/providers/amazon/src/airflow/providers/amazon/aws/transfers/s3_to_ftp.py index 0232e5fdca665..ad532e9ff1d66 100644 --- a/providers/amazon/src/airflow/providers/amazon/aws/transfers/s3_to_ftp.py +++ b/providers/amazon/src/airflow/providers/amazon/aws/transfers/s3_to_ftp.py @@ -55,6 +55,8 @@ class S3ToFTPOperator(BaseOperator): :param aws_conn_id: reference to a specific AWS connection :param ftp_conn_id: The ftp connection id. The name or identifier for establishing a connection to the FTP server. + :param fail_on_file_not_exist: If True, operator fails when a source S3 key does not + exist. If False, the operator logs a warning and skips the transfer. Default is True. """ template_fields: Sequence[str] = ("s3_bucket", "s3_key", "ftp_path", "s3_filenames", "ftp_filenames") @@ -69,6 +71,7 @@ def __init__( ftp_filenames: str | list[str] | None = None, aws_conn_id="aws_default", ftp_conn_id="ftp_default", + fail_on_file_not_exist: bool = True, **kwargs, ) -> None: super().__init__(**kwargs) @@ -79,8 +82,14 @@ def __init__( self.ftp_filenames = ftp_filenames self.aws_conn_id = aws_conn_id self.ftp_conn_id = ftp_conn_id + self.fail_on_file_not_exist = fail_on_file_not_exist def _download_from_s3(self, s3_hook: S3Hook, ftp_hook: FTPHook, s3_key: str, ftp_path: str) -> None: + if not s3_hook.check_for_key(s3_key, self.s3_bucket): + if self.fail_on_file_not_exist: + raise FileNotFoundError(f"Key {s3_key!r} not found in S3 bucket {self.s3_bucket!r}") + self.log.info("Key %s not found in S3. Skipping transfer.", s3_key) + return s3_obj = s3_hook.get_key(s3_key, self.s3_bucket) with NamedTemporaryFile() as local_tmp_file: self.log.info("Downloading file from %s", s3_key) diff --git a/providers/amazon/src/airflow/providers/amazon/aws/transfers/s3_to_sftp.py b/providers/amazon/src/airflow/providers/amazon/aws/transfers/s3_to_sftp.py index 26c662c91ca13..b44d6fd82b9b1 100644 --- a/providers/amazon/src/airflow/providers/amazon/aws/transfers/s3_to_sftp.py +++ b/providers/amazon/src/airflow/providers/amazon/aws/transfers/s3_to_sftp.py @@ -67,6 +67,8 @@ class S3ToSFTPOperator(BaseOperator): :param confirm: specify if the SFTP operation should be confirmed, defaults to True. When True, a stat will be performed on the remote file after upload to verify the file size matches and confirm successful transfer. + :param fail_on_file_not_exist: If True, operator fails when a source S3 key does not + exist. If False, the operator logs a warning and skips the transfer. Default is True. """ template_fields: Sequence[str] = ("s3_key", "sftp_path", "s3_bucket", "s3_filenames", "sftp_filenames") @@ -83,6 +85,7 @@ def __init__( s3_filenames: str | list[str] | None = None, sftp_filenames: str | list[str] | None = None, confirm: bool = True, + fail_on_file_not_exist: bool = True, **kwargs, ) -> None: super().__init__(**kwargs) @@ -95,6 +98,7 @@ def __init__( self.s3_filenames = s3_filenames self.sftp_filenames = sftp_filenames self.confirm = confirm + self.fail_on_file_not_exist = fail_on_file_not_exist @staticmethod def get_s3_key(s3_key: str) -> str: @@ -105,12 +109,17 @@ def get_s3_key(s3_key: str) -> str: def _download_from_s3( self, sftp_client: paramiko.SFTPClient, - s3_client, + s3_hook: S3Hook, s3_key: str, sftp_path: str, ) -> None: + if not s3_hook.check_for_key(s3_key, self.s3_bucket): + if self.fail_on_file_not_exist: + raise FileNotFoundError(f"Key {s3_key!r} not found in S3 bucket {self.s3_bucket!r}") + self.log.info("Key %s not found in S3. Skipping transfer.", s3_key) + return with NamedTemporaryFile("w") as f: - s3_client.download_file(self.s3_bucket, s3_key, f.name) + s3_hook.get_conn().download_file(self.s3_bucket, s3_key, f.name) sftp_client.put(f.name, sftp_path, confirm=self.confirm) def execute(self, context: Context) -> None: @@ -119,8 +128,6 @@ def execute(self, context: Context) -> None: # SSHHook will handle a None/"" sftp_remote_host ssh_hook = SSHHook(ssh_conn_id=self.sftp_conn_id, remote_host=self.sftp_remote_host) s3_hook = S3Hook(self.aws_conn_id) - - s3_client = s3_hook.get_conn() sftp_client = ssh_hook.get_conn().open_sftp() if self.s3_filenames: @@ -142,7 +149,7 @@ def execute(self, context: Context) -> None: sftp_filename = file self._download_from_s3( sftp_client, - s3_client, + s3_hook, self.s3_key + file, self.sftp_path + sftp_filename, ) @@ -151,7 +158,7 @@ def execute(self, context: Context) -> None: for s3_file, sftp_file in zip(self.s3_filenames, self.sftp_filenames): self._download_from_s3( sftp_client, - s3_client, + s3_hook, self.s3_key + s3_file, self.sftp_path + sftp_file, ) @@ -159,9 +166,9 @@ def execute(self, context: Context) -> None: for s3_file in self.s3_filenames: self._download_from_s3( sftp_client, - s3_client, + s3_hook, self.s3_key + s3_file, self.sftp_path + s3_file, ) else: - self._download_from_s3(sftp_client, s3_client, self.s3_key, self.sftp_path) + self._download_from_s3(sftp_client, s3_hook, self.s3_key, self.sftp_path) diff --git a/providers/amazon/tests/unit/amazon/aws/transfers/test_ftp_to_s3.py b/providers/amazon/tests/unit/amazon/aws/transfers/test_ftp_to_s3.py index 757a396464178..102969b33c8f9 100644 --- a/providers/amazon/tests/unit/amazon/aws/transfers/test_ftp_to_s3.py +++ b/providers/amazon/tests/unit/amazon/aws/transfers/test_ftp_to_s3.py @@ -17,7 +17,11 @@ # under the License. from __future__ import annotations +import ftplib from unittest import mock +from unittest.mock import MagicMock, patch + +import pytest from airflow.providers.amazon.aws.transfers.ftp_to_s3 import FTPToS3Operator @@ -128,3 +132,34 @@ def test_execute_multiple_files_prefix( operator.execute(None) mock_ftp_hook_list_directory.assert_called_once_with(path=FTP_PATH_MULTIPLE) + + +class TestFTPToS3OperatorInit: + """Unit tests for FTPToS3Operator.__init__ that do not require an FTP server.""" + + def test_fail_on_file_not_exist_default(self): + """fail_on_file_not_exist defaults to True.""" + op = FTPToS3Operator(task_id="test_fail_default", s3_bucket=BUCKET, s3_key=S3_KEY, ftp_path=FTP_PATH) + assert op.fail_on_file_not_exist is True + + @pytest.mark.parametrize("fail_on_file_not_exist", [True, False]) + def test_fail_on_file_not_exist_skip(self, fail_on_file_not_exist): + """When FTP file is missing (error_perm 550): raise if True, skip if False.""" + op = FTPToS3Operator( + task_id="test_skip", + s3_bucket=BUCKET, + s3_key=S3_KEY, + ftp_path=FTP_PATH, + fail_on_file_not_exist=fail_on_file_not_exist, + ) + op.ftp_hook = MagicMock() + op.s3_hook = MagicMock() + op.ftp_hook.retrieve_file.side_effect = ftplib.error_perm("550 No such file or directory") + + if fail_on_file_not_exist: + with pytest.raises(ftplib.error_perm): + op._FTPToS3Operator__upload_to_s3_from_ftp(FTP_PATH, S3_KEY) + else: + with patch.object(op.log, "info") as mock_log: + op._FTPToS3Operator__upload_to_s3_from_ftp(FTP_PATH, S3_KEY) + mock_log.assert_called_once() diff --git a/providers/amazon/tests/unit/amazon/aws/transfers/test_s3_to_ftp.py b/providers/amazon/tests/unit/amazon/aws/transfers/test_s3_to_ftp.py index 82be1cf0412e9..c40ce14c60718 100644 --- a/providers/amazon/tests/unit/amazon/aws/transfers/test_s3_to_ftp.py +++ b/providers/amazon/tests/unit/amazon/aws/transfers/test_s3_to_ftp.py @@ -70,3 +70,31 @@ def test_multi_file_params(self, s3_filenames, ftp_filenames): ) assert op.s3_filenames == s3_filenames assert op.ftp_filenames == ftp_filenames + + def test_fail_on_file_not_exist_default(self): + """fail_on_file_not_exist defaults to True.""" + op = S3ToFTPOperator(task_id="test_fail_default", s3_bucket=BUCKET, s3_key=S3_KEY, ftp_path=FTP_PATH) + assert op.fail_on_file_not_exist is True + + @pytest.mark.parametrize("fail_on_file_not_exist", [True, False]) + def test_fail_on_file_not_exist_skip(self, fail_on_file_not_exist): + """When key is missing: raise FileNotFoundError if True, skip if False.""" + from unittest.mock import MagicMock, patch + + op = S3ToFTPOperator( + task_id="test_skip", + s3_bucket=BUCKET, + s3_key=S3_KEY, + ftp_path=FTP_PATH, + fail_on_file_not_exist=fail_on_file_not_exist, + ) + mock_s3_hook = MagicMock() + mock_s3_hook.check_for_key.return_value = False + + if fail_on_file_not_exist: + with pytest.raises(FileNotFoundError): + op._download_from_s3(mock_s3_hook, MagicMock(), S3_KEY, FTP_PATH) + else: + with patch.object(op.log, "info") as mock_log: + op._download_from_s3(mock_s3_hook, MagicMock(), S3_KEY, FTP_PATH) + mock_log.assert_called_once() diff --git a/providers/amazon/tests/unit/amazon/aws/transfers/test_s3_to_sftp.py b/providers/amazon/tests/unit/amazon/aws/transfers/test_s3_to_sftp.py index de8ae1892b05c..867c7336ae920 100644 --- a/providers/amazon/tests/unit/amazon/aws/transfers/test_s3_to_sftp.py +++ b/providers/amazon/tests/unit/amazon/aws/transfers/test_s3_to_sftp.py @@ -340,3 +340,38 @@ def test_multi_file_params(self, s3_filenames, sftp_filenames): ) assert op.s3_filenames == s3_filenames assert op.sftp_filenames == sftp_filenames + + def test_fail_on_file_not_exist_default(self): + """fail_on_file_not_exist defaults to True.""" + op = S3ToSFTPOperator( + task_id="test_fail_default", + s3_bucket=BUCKET, + s3_key=S3_KEY, + sftp_path=SFTP_PATH, + sftp_conn_id=SFTP_CONN_ID, + ) + assert op.fail_on_file_not_exist is True + + @pytest.mark.parametrize("fail_on_file_not_exist", [True, False]) + def test_fail_on_file_not_exist_skip(self, fail_on_file_not_exist): + """When key is missing: raise FileNotFoundError if True, skip if False.""" + from unittest.mock import MagicMock, patch + + op = S3ToSFTPOperator( + task_id="test_skip", + s3_bucket=BUCKET, + s3_key=S3_KEY, + sftp_path=SFTP_PATH, + sftp_conn_id=SFTP_CONN_ID, + fail_on_file_not_exist=fail_on_file_not_exist, + ) + mock_s3_hook = MagicMock() + mock_s3_hook.check_for_key.return_value = False + + if fail_on_file_not_exist: + with pytest.raises(FileNotFoundError): + op._download_from_s3(MagicMock(), mock_s3_hook, S3_KEY, SFTP_PATH) + else: + with patch.object(op.log, "info") as mock_log: + op._download_from_s3(MagicMock(), mock_s3_hook, S3_KEY, SFTP_PATH) + mock_log.assert_called_once()