diff --git a/inbox/events/abstract.py b/inbox/events/abstract.py index bda76ed4b..241676bb7 100644 --- a/inbox/events/abstract.py +++ b/inbox/events/abstract.py @@ -8,7 +8,6 @@ from inbox.models.backends.oauth import token_manager from inbox.models.calendar import Calendar from inbox.models.event import Event -from inbox.models.session import session_scope log = get_logger() @@ -19,10 +18,10 @@ class AbstractEventsProvider(abc.ABC): specified account. """ - def __init__(self, account_id: int, namespace_id: int): - self.account_id = account_id - self.namespace_id = namespace_id - self.log = log.new(account_id=account_id, component="calendar sync") + def __init__(self, account: Account): + self.account = account + self.namespace_id = account.namespace.id + self.log = log.new(account_id=account.id, component="calendar sync") # A hash to store whether a calendar is read-only or not. # This is a bit of a hack because this isn't exposed at the event level @@ -54,34 +53,28 @@ def sync_events( raise NotImplementedError() @abc.abstractmethod - def webhook_notifications_enabled(self, account: Account) -> bool: + def webhook_notifications_enabled(self) -> bool: """ Return True if webhook notifications are enabled for a given account. """ raise NotImplementedError() @abc.abstractmethod - def watch_calendar_list(self, account: Account) -> Optional[datetime.datetime]: + def watch_calendar_list(self) -> Optional[datetime.datetime]: """ Subscribe to webhook notifications for changes to calendar list. - Arguments: - account: The account - Returns: The expiration of the notification channel """ raise NotImplementedError() @abc.abstractmethod - def watch_calendar( - self, account: Account, calendar: Calendar - ) -> Optional[datetime.datetime]: + def watch_calendar(self, calendar: Calendar) -> Optional[datetime.datetime]: """ Subscribe to webhook notifications for changes to events in a calendar. Arguments: - account: The account calendar: The calendar Returns: @@ -102,13 +95,9 @@ def _get_access_token( Returns: The token """ - with session_scope(self.namespace_id) as db_session: - acc = db_session.query(Account).get(self.account_id) - # This will raise OAuthError if OAuth access was revoked. The - # BaseSyncMonitor loop will catch this, clean up, and exit. - return token_manager.get_token( - acc, force_refresh=force_refresh, scopes=scopes - ) + return token_manager.get_token( + self.account, force_refresh=force_refresh, scopes=scopes + ) class CalendarGoneException(Exception): diff --git a/inbox/events/actions/backends/gmail.py b/inbox/events/actions/backends/gmail.py index 1e758e78d..e439c211b 100644 --- a/inbox/events/actions/backends/gmail.py +++ b/inbox/events/actions/backends/gmail.py @@ -8,7 +8,7 @@ def remote_create_event(account, event, db_session, extra_args): - provider = GoogleEventsProvider(account.id, account.namespace.id) + provider = GoogleEventsProvider(account) result = provider.create_remote_event(event, **extra_args) # The events crud API assigns a random uid to an event when creating it. # We need to update it to the value returned by the Google calendar API. @@ -17,12 +17,12 @@ def remote_create_event(account, event, db_session, extra_args): def remote_update_event(account, event, db_session, extra_args): - provider = GoogleEventsProvider(account.id, account.namespace.id) + provider = GoogleEventsProvider(account) provider.update_remote_event(event, **extra_args) def remote_delete_event( account, event_uid, calendar_name, calendar_uid, db_session, extra_args ): - provider = GoogleEventsProvider(account.id, account.namespace.id) + provider = GoogleEventsProvider(account) provider.delete_remote_event(calendar_uid, event_uid, **extra_args) diff --git a/inbox/events/google.py b/inbox/events/google.py index 764c2feff..21abdd2db 100644 --- a/inbox/events/google.py +++ b/inbox/events/google.py @@ -21,7 +21,7 @@ parse_datetime, parse_google_time, ) -from inbox.models import Account, Calendar +from inbox.models import Calendar from inbox.models.backends.oauth import token_manager from inbox.models.event import EVENT_STATUSES, Event @@ -279,15 +279,15 @@ def delete_remote_event(self, calendar_uid, event_uid, **kwargs): # -------- logic for push notification subscriptions -------- # - def _get_access_token_for_push_notifications(self, account, force_refresh=False): - if not self.webhook_notifications_enabled(account): + def _get_access_token_for_push_notifications(self, force_refresh=False): + if not self.webhook_notifications_enabled(): raise OAuthError("Account not enabled for push notifications.") - return token_manager.get_token(account, force_refresh) + return token_manager.get_token(self.account, force_refresh) - def webhook_notifications_enabled(self, account: Account) -> bool: - return account.get_client_info()[0] in WEBHOOK_ENABLED_CLIENT_IDS + def webhook_notifications_enabled(self) -> bool: + return self.account.get_client_info()[0] in WEBHOOK_ENABLED_CLIENT_IDS - def watch_calendar_list(self, account: Account) -> Optional[datetime.datetime]: + def watch_calendar_list(self) -> Optional[datetime.datetime]: """ Subscribe to google push notifications for the calendar list. @@ -300,9 +300,9 @@ def watch_calendar_list(self, account: Account) -> Optional[datetime.datetime]: Returns: The expiration of the notification channel """ - token = self._get_access_token_for_push_notifications(account) + token = self._get_access_token_for_push_notifications() receiving_url = CALENDAR_LIST_WEBHOOK_URL.format( - urllib.parse.quote(account.public_id) + urllib.parse.quote(self.account.public_id) ) one_week = datetime.timedelta(weeks=1) @@ -338,9 +338,7 @@ def watch_calendar_list(self, account: Account) -> Optional[datetime.datetime]: self._handle_watch_errors(r) return None - def watch_calendar( - self, account: Account, calendar: Calendar - ) -> Optional[datetime.datetime]: + def watch_calendar(self, calendar: Calendar) -> Optional[datetime.datetime]: """ Subscribe to google push notifications for a calendar. @@ -355,7 +353,7 @@ def watch_calendar( Returns: The expiration of the notification channel """ - token = self._get_access_token_for_push_notifications(account) + token = self._get_access_token_for_push_notifications() watch_url = WATCH_EVENTS_URL.format(urllib.parse.quote(calendar.uid)) receiving_url = EVENTS_LIST_WEBHOOK_URL.format( urllib.parse.quote(calendar.public_id) diff --git a/inbox/events/microsoft/events_provider.py b/inbox/events/microsoft/events_provider.py index 15357b87f..3406c23e2 100644 --- a/inbox/events/microsoft/events_provider.py +++ b/inbox/events/microsoft/events_provider.py @@ -67,8 +67,8 @@ class MicrosoftEventsProvider(AbstractEventsProvider): - def __init__(self, account_id: int, namespace_id: int): - super().__init__(account_id, namespace_id) + def __init__(self, account: Account): + super().__init__(account) self.client = MicrosoftGraphClient( lambda: self._get_access_token(scopes=MICROSOFT_CALENDAR_SCOPES) @@ -201,7 +201,7 @@ def _get_event_overrides( return exceptions, cancellations - def webhook_notifications_enabled(self, account: Account) -> bool: + def webhook_notifications_enabled(self) -> bool: """ Return True if webhook notifications are enabled for a given account. @@ -225,7 +225,7 @@ def webhook_notifications_enabled(self, account: Account) -> bool: try: dummy_subscription = self.client.subscribe_to_calendar_changes( - webhook_url=CALENDAR_LIST_WEBHOOK_URL.format(account.public_id), + webhook_url=CALENDAR_LIST_WEBHOOK_URL.format(self.account.public_id), secret=config["MICROSOFT_SUBSCRIPTION_SECRET"], ) except MicrosoftGraphClientException as e: @@ -245,18 +245,15 @@ def webhook_notifications_enabled(self, account: Account) -> bool: self._webhook_notifications_enabled = True return True - def watch_calendar_list(self, account: Account) -> Optional[datetime.datetime]: + def watch_calendar_list(self) -> Optional[datetime.datetime]: """ Subscribe to webhook notifications for changes to calendar list. - Arguments: - account: The account - Returns: The expiration of the notification channel """ response = self.client.subscribe_to_calendar_changes( - webhook_url=CALENDAR_LIST_WEBHOOK_URL.format(account.public_id), + webhook_url=CALENDAR_LIST_WEBHOOK_URL.format(self.account.public_id), secret=config["MICROSOFT_SUBSCRIPTION_SECRET"], ) @@ -264,14 +261,11 @@ def watch_calendar_list(self, account: Account) -> Optional[datetime.datetime]: return ciso8601.parse_datetime(expiration).replace(microsecond=0) - def watch_calendar( - self, account: Account, calendar: Calendar - ) -> Optional[datetime.datetime]: + def watch_calendar(self, calendar: Calendar) -> Optional[datetime.datetime]: """ Subscribe to webhook notifications for changes to events in a calendar. Arguments: - account: The account calendar: The calendar Returns: diff --git a/inbox/events/remote_sync.py b/inbox/events/remote_sync.py index a6e42b7d7..19b4b4e44 100644 --- a/inbox/events/remote_sync.py +++ b/inbox/events/remote_sync.py @@ -40,27 +40,26 @@ class EventSync(BaseSyncMonitor): def __init__( self, - email_address: str, - provider_name: str, - account_id: int, - namespace_id: int, + account: Account, provider_class: Type[AbstractEventsProvider], poll_frequency: int = POLL_FREQUENCY, ): - bind_context(self, "eventsync", account_id) - self.provider = provider_class(account_id, namespace_id) + bind_context(self, "eventsync", account.id) + self.provider = provider_class(account) self.log = logger.new( - account_id=account_id, component="calendar sync", provider=provider_name + account_id=account.id, + component="calendar sync", + provider=account.verbose_provider, ) BaseSyncMonitor.__init__( self, - account_id, - namespace_id, - email_address, + account.id, + account.namespace.id, + account.email_address, EVENT_SYNC_FOLDER_ID, EVENT_SYNC_FOLDER_NAME, - provider_name, + account.verbose_provider, poll_frequency=poll_frequency, scope="calendar", ) @@ -239,19 +238,16 @@ def handle_event_updates( class WebhookEventSync(EventSync): - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - with session_scope(self.namespace_id) as db_session: - account = db_session.query(Account).get(self.account_id) - if ( - self.provider.webhook_notifications_enabled(account) - and kwargs.get("poll_frequency") is None - ): - # Run the sync loop more frequently if push notifications are - # enabled. Note that we'll only update the calendar if a - # Webhook was receicved recently, or if we haven't synced for - # too long. - self.poll_frequency = PUSH_NOTIFICATION_POLL_FREQUENCY + def __init__( + self, account: Account, provider_class: Type[AbstractEventsProvider], + ): + super().__init__(account, provider_class) + if self.provider.webhook_notifications_enabled(): + # Run the sync loop more frequently if push notifications are + # enabled. Note that we'll only update the calendar if a + # Webhook was receicved recently, or if we haven't synced for + # too long. + self.poll_frequency = PUSH_NOTIFICATION_POLL_FREQUENCY def sync(self) -> None: """Query a remote provider for updates and persist them to the @@ -295,12 +291,12 @@ def _refresh_webhook_subscriptions(self) -> None: with session_scope(self.namespace_id) as db_session: account = db_session.query(Account).get(self.account_id) - if not self.provider.webhook_notifications_enabled(account): + if not self.provider.webhook_notifications_enabled(): self.log.warning("Webhook notifications disabled") return if account.needs_new_calendar_list_watch(): - calendar_list_expiration = self.provider.watch_calendar_list(account) + calendar_list_expiration = self.provider.watch_calendar_list() if calendar_list_expiration is not None: account.new_calendar_list_watch(calendar_list_expiration) @@ -311,9 +307,7 @@ def _refresh_webhook_subscriptions(self) -> None: ) for calendar in calendars_to_watch: try: - event_list_expiration = self.provider.watch_calendar( - account, calendar - ) + event_list_expiration = self.provider.watch_calendar(calendar) if event_list_expiration is not None: calendar.new_event_watch(event_list_expiration) except CalendarGoneException: diff --git a/inbox/mailsync/service.py b/inbox/mailsync/service.py index 0fa4dd418..5435ac014 100644 --- a/inbox/mailsync/service.py +++ b/inbox/mailsync/service.py @@ -310,27 +310,15 @@ def start_sync(self, account_id): if info.get("events", None) and acc.sync_events: if USE_GOOGLE_PUSH_NOTIFICATIONS and acc.provider == "gmail": event_sync = WebhookEventSync( - acc.email_address, - acc.verbose_provider, - acc.id, - acc.namespace.id, - provider_class=GoogleEventsProvider, + acc, provider_class=GoogleEventsProvider, ) elif acc.provider == "gmail": event_sync = EventSync( - acc.email_address, - acc.verbose_provider, - acc.id, - acc.namespace.id, - provider_class=GoogleEventsProvider, + acc, provider_class=GoogleEventsProvider, ) elif acc.provider == "microsoft": event_sync = WebhookEventSync( - acc.email_address, - acc.verbose_provider, - acc.id, - acc.namespace.id, - provider_class=MicrosoftEventsProvider, + acc, provider_class=MicrosoftEventsProvider, ) self.event_sync_monitors[acc.id] = event_sync event_sync.start() diff --git a/tests/events/microsoft/test_events_provider.py b/tests/events/microsoft/test_events_provider.py index f0ea99372..9f1a8e4e6 100644 --- a/tests/events/microsoft/test_events_provider.py +++ b/tests/events/microsoft/test_events_provider.py @@ -406,8 +406,8 @@ def exception_override_response(): @pytest.fixture -def provider(client): - provider = MicrosoftEventsProvider("fake_account_id", "fake_namespace_id") +def provider(client, outlook_account): + provider = MicrosoftEventsProvider(outlook_account) provider.client = client return provider @@ -436,7 +436,7 @@ def test_sync_calendars_deletion(db, client, outlook_account): db.session.add(deleted_calendar) db.session.commit() - provider = MicrosoftEventsProvider(outlook_account.id, outlook_account.namespace.id) + provider = MicrosoftEventsProvider(outlook_account) provider.client = client deleted_uids, _ = provider.sync_calendars() @@ -503,39 +503,39 @@ def test_sync_events_exception(provider): @responses.activate @pytest.mark.usefixtures("subscribe_responses") -def test_watch_calendar_list(provider, outlook_account): - expiration = provider.watch_calendar_list(outlook_account) +def test_watch_calendar_list(provider): + expiration = provider.watch_calendar_list() assert expiration == datetime.datetime(2022, 11, 24, 18, 31, 12, tzinfo=pytz.UTC) @responses.activate @pytest.mark.usefixtures("subscribe_responses") -def test_watch_calendar(provider, outlook_account): +def test_watch_calendar(provider): calendar = Calendar(uid="fake_calendar_id", public_id="fake_public_id") - expiration = provider.watch_calendar(outlook_account, calendar) + expiration = provider.watch_calendar(calendar) assert expiration == datetime.datetime(2022, 10, 25, 4, 22, 34, tzinfo=pytz.UTC) @responses.activate @pytest.mark.usefixtures("subscribe_response_gone") -def test_watch_calendar_gone(provider, outlook_account): +def test_watch_calendar_gone(provider): calendar = Calendar(uid="fake_calendar_id", public_id="fake_public_id") with pytest.raises(CalendarGoneException): - provider.watch_calendar(outlook_account, calendar) + provider.watch_calendar(calendar) @responses.activate @pytest.mark.usefixtures("subscribe_responses") -def test_webhook_notifications_enabled_avaialble(provider, outlook_account): - assert provider.webhook_notifications_enabled(outlook_account) +def test_webhook_notifications_enabled_avaialble(provider): + assert provider.webhook_notifications_enabled() @responses.activate @pytest.mark.usefixtures("subscribe_response_unavailable") -def test_webhook_notifications_enabled_unavailable(provider, outlook_account): - assert not provider.webhook_notifications_enabled(outlook_account) +def test_webhook_notifications_enabled_unavailable(provider): + assert not provider.webhook_notifications_enabled() @responses.activate @@ -547,11 +547,7 @@ def test_webhook_notifications_enabled_unavailable(provider, outlook_account): ) def test_sync(db, provider, outlook_account): event_sync = WebhookEventSync( - outlook_account.email_address, - outlook_account.verbose_provider, - outlook_account.id, - outlook_account.namespace.id, - provider_class=lambda *args, **kwargs: provider, + outlook_account, provider_class=lambda *args, **kwargs: provider, ) # First sync, initially we just read without subscriptions diff --git a/tests/events/test_google_events.py b/tests/events/test_google_events.py index 1ce200c0e..4d53af8f6 100644 --- a/tests/events/test_google_events.py +++ b/tests/events/test_google_events.py @@ -114,7 +114,7 @@ def test_calendar_parsing(): ), ] - provider = GoogleEventsProvider(1, 1) + provider = GoogleEventsProvider(mock.Mock()) provider._get_raw_calendars = mock.MagicMock(return_value=raw_response) deletes, updates = provider.sync_calendars() assert deletes == expected_deletes @@ -257,7 +257,7 @@ def test_event_parsing(): ), ] - provider = GoogleEventsProvider(1, 1) + provider = GoogleEventsProvider(mock.Mock()) provider.calendars_table = {"uid": False} provider._get_raw_events = mock.MagicMock(return_value=raw_response) updates = provider.sync_events("uid", 1) @@ -307,7 +307,7 @@ def test_event_parsing(): } ] - provider = GoogleEventsProvider(1, 1) + provider = GoogleEventsProvider(mock.Mock()) # This is a read-only calendar provider.calendars_table = {"uid": True} @@ -368,7 +368,7 @@ def test_handle_unparseable_dates(): "summary": "test", } ] - provider = GoogleEventsProvider(1, 1) + provider = GoogleEventsProvider(mock.Mock()) provider._get_raw_events = mock.MagicMock(return_value=raw_response) updates = provider.sync_events("uid", 1) assert len(updates) == 0 @@ -385,7 +385,7 @@ def test_pagination(): second_response._content = json.dumps({"items": ["D", "E"]}).encode() requests.get = mock.Mock(side_effect=[first_response, second_response]) - provider = GoogleEventsProvider(1, 1) + provider = GoogleEventsProvider(mock.Mock()) provider._get_access_token = mock.Mock(return_value="token") items = provider._get_resource_list("https://googleapis.com/testurl") assert items == ["A", "B", "C", "D", "E"] @@ -400,7 +400,7 @@ def test_handle_http_401(): second_response._content = json.dumps({"items": ["A", "B", "C"]}).encode() requests.get = mock.Mock(side_effect=[first_response, second_response]) - provider = GoogleEventsProvider(1, 1) + provider = GoogleEventsProvider(mock.Mock()) provider._get_access_token = mock.Mock(return_value="token") items = provider._get_resource_list("https://googleapis.com/testurl") assert items == ["A", "B", "C"] @@ -433,7 +433,7 @@ def test_handle_quota_exceeded(): second_response._content = json.dumps({"items": ["A", "B", "C"]}).encode() requests.get = mock.Mock(side_effect=[first_response, second_response]) - provider = GoogleEventsProvider(1, 1) + provider = GoogleEventsProvider(mock.Mock()) provider._get_access_token = mock.Mock(return_value="token") items = provider._get_resource_list("https://googleapis.com/testurl") # Check that we slept, then retried. @@ -451,7 +451,7 @@ def test_handle_internal_server_error(): second_response._content = json.dumps({"items": ["A", "B", "C"]}).encode() requests.get = mock.Mock(side_effect=[first_response, second_response]) - provider = GoogleEventsProvider(1, 1) + provider = GoogleEventsProvider(mock.Mock()) provider._get_access_token = mock.Mock(return_value="token") items = provider._get_resource_list("https://googleapis.com/testurl") # Check that we slept, then retried. @@ -480,7 +480,7 @@ def test_handle_api_not_enabled(): ).encode() requests.get = mock.Mock(return_value=response) - provider = GoogleEventsProvider(1, 1) + provider = GoogleEventsProvider(mock.Mock()) provider._get_access_token = mock.Mock(return_value="token") with pytest.raises(AccessNotEnabledError): provider._get_resource_list("https://googleapis.com/testurl") @@ -491,7 +491,7 @@ def test_handle_other_errors(): response.status_code = 403 response._content = b"This is not the JSON you're looking for" requests.get = mock.Mock(return_value=response) - provider = GoogleEventsProvider(1, 1) + provider = GoogleEventsProvider(mock.Mock()) provider._get_access_token = mock.Mock(return_value="token") with pytest.raises(requests.exceptions.HTTPError): provider._get_resource_list("https://googleapis.com/testurl") @@ -499,7 +499,7 @@ def test_handle_other_errors(): response = requests.Response() response.status_code = 404 requests.get = mock.Mock(return_value=response) - provider = GoogleEventsProvider(1, 1) + provider = GoogleEventsProvider(mock.Mock()) provider._get_access_token = mock.Mock(return_value="token") with pytest.raises(requests.exceptions.HTTPError): provider._get_resource_list("https://googleapis.com/testurl") @@ -714,7 +714,7 @@ def test_cancelled_override_creation(): } ] - provider = GoogleEventsProvider(1, 1) + provider = GoogleEventsProvider(mock.Mock()) provider._get_raw_events = mock.MagicMock(return_value=raw_response) updates = provider.sync_events("uid", 1) assert updates[0].cancelled is True diff --git a/tests/events/test_sync.py b/tests/events/test_sync.py index fca47a0e9..38b031e97 100644 --- a/tests/events/test_sync.py +++ b/tests/events/test_sync.py @@ -122,13 +122,7 @@ def event_response_with_delete(calendar_uid, sync_from_time): def test_handle_changes(db, generic_account): namespace_id = generic_account.namespace.id - event_sync = EventSync( - generic_account.email_address, - "google", - generic_account.id, - namespace_id, - provider_class=GoogleEventsProvider, - ) + event_sync = EventSync(generic_account, provider_class=GoogleEventsProvider,) # Sync calendars/events event_sync.provider.sync_calendars = calendar_response