Skip to content
Merged
208 changes: 197 additions & 11 deletions durabletask/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,10 @@
from dataclasses import dataclass
from datetime import datetime, timezone
from enum import Enum
from typing import Any, Optional, Sequence, TypeVar, Union
from typing import Any, List, Optional, Sequence, TypeVar, Union

import grpc
from google.protobuf import wrappers_pb2
from google.protobuf import wrappers_pb2 as pb2

from durabletask.entities import EntityInstanceId
from durabletask.entities.entity_metadata import EntityMetadata
Expand Down Expand Up @@ -57,6 +57,18 @@ def raise_if_failed(self):
self.failure_details)


@dataclass
class PurgeInstancesResult:
    """Summary of a purge request.

    The client fills this from a purge response: the first field from the
    response's ``deletedInstanceCount`` and the second from the unwrapped
    ``isComplete`` value.
    """

    # Count of instances reported as deleted by the purge response.
    deleted_instance_count: int
    # Completion flag carried by the purge response.
    is_complete: bool


@dataclass
class CleanEntityStorageResult:
    """Totals produced by an entity-storage cleanup.

    The client accumulates both counters across every page of the cleanup
    before building this result.
    """

    # Sum of the ``emptyEntitiesRemoved`` values reported by the backend.
    empty_entities_removed: int
    # Sum of the ``orphanedLocksReleased`` values reported by the backend.
    orphaned_locks_released: int


class OrchestrationFailedError(Exception):
def __init__(self, message: str, failure_details: task.FailureDetails):
super().__init__(message)
Expand All @@ -73,6 +85,12 @@ def new_orchestration_state(instance_id: str, res: pb.GetInstanceResponse) -> Op

state = res.orchestrationState

new_state = parse_orchestration_state(state)
new_state.instance_id = instance_id # Override instance_id with the one from the request, to match old behavior
return new_state


def parse_orchestration_state(state: pb.OrchestrationState) -> OrchestrationState:
failure_details = None
if state.failureDetails.errorMessage != '' or state.failureDetails.errorType != '':
failure_details = task.FailureDetails(
Expand All @@ -81,7 +99,7 @@ def new_orchestration_state(instance_id: str, res: pb.GetInstanceResponse) -> Op
state.failureDetails.stackTrace.value if not helpers.is_empty(state.failureDetails.stackTrace) else None)

return OrchestrationState(
instance_id,
state.instanceId,
state.name,
OrchestrationStatus(state.orchestrationStatus),
state.createdTimestamp.ToDatetime(),
Expand All @@ -93,7 +111,6 @@ def new_orchestration_state(instance_id: str, res: pb.GetInstanceResponse) -> Op


class TaskHubGrpcClient:

def __init__(self, *,
host_address: Optional[str] = None,
metadata: Optional[list[tuple[str, str]]] = None,
Expand Down Expand Up @@ -136,7 +153,7 @@ def schedule_new_orchestration(self, orchestrator: Union[task.Orchestrator[TInpu
req = pb.CreateInstanceRequest(
name=name,
instanceId=instance_id if instance_id else uuid.uuid4().hex,
input=wrappers_pb2.StringValue(value=shared.to_json(input)) if input is not None else None,
input=helpers.get_string_value(shared.to_json(input) if input is not None else None),
scheduledStartTimestamp=helpers.new_timestamp(start_at) if start_at else None,
version=helpers.get_string_value(version if version else self.default_version),
orchestrationIdReusePolicy=reuse_id_policy,
Expand All @@ -152,6 +169,65 @@ def get_orchestration_state(self, instance_id: str, *, fetch_payloads: bool = Tr
res: pb.GetInstanceResponse = self._stub.GetInstance(req)
return new_orchestration_state(req.instanceId, res)

def get_all_orchestration_states(self,
                                 max_instance_count: Optional[int] = None,
                                 fetch_inputs_and_outputs: bool = False) -> List[OrchestrationState]:
    """Query orchestration instances with no time or status filter applied.

    This is a thin wrapper: it forwards to ``get_orchestration_state_by``
    with every filter argument explicitly set to ``None``.

    Args:
        max_instance_count: Forwarded unchanged to the filtered query.
        fetch_inputs_and_outputs: Forwarded unchanged to the filtered query.

    Returns:
        Whatever ``get_orchestration_state_by`` returns for the unfiltered
        query.
    """
    unfiltered_query = {
        "created_time_from": None,
        "created_time_to": None,
        "runtime_status": None,
        "max_instance_count": max_instance_count,
        "fetch_inputs_and_outputs": fetch_inputs_and_outputs,
    }
    return self.get_orchestration_state_by(**unfiltered_query)

def get_orchestration_state_by(self,
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That's a strange method name - is this what we use in other SDKs?

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

At least in .NET, we just have 1 method which takes an optional filter.

https://github.com/microsoft/durabletask-dotnet/blob/main/src/Client/Grpc/GrpcDurableTaskClient.cs#L280

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This matches the Functions Python SDK, but it is probably better to match the other portable SDKs and have a translation layer for Functions + durabletask Python. Will update this in a future commit.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Updated to match .NET - takes a Filter instead of a list of params, and there is one method for each. Also removed _continuation_token from the public API

created_time_from: Optional[datetime] = None,
created_time_to: Optional[datetime] = None,
runtime_status: Optional[List[OrchestrationStatus]] = None,
max_instance_count: Optional[int] = None,
fetch_inputs_and_outputs: bool = False,
_continuation_token: Optional[pb2.StringValue] = None
) -> List[OrchestrationState]:
if max_instance_count is None:
# Some backends do not behave well with max_instance_count = None, so we set to max 32-bit signed value
max_instance_count = (1 << 31) - 1
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit: Is there no int.max in python?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No, int values are unbounded in python


Comment thread
andystaples marked this conversation as resolved.
Outdated
self._logger.info(f"Querying orchestration instances with filters - "
f"created_time_from={created_time_from}, "
f"created_time_to={created_time_to}, "
f"runtime_status={[str(status) for status in runtime_status] if runtime_status else None}, "
f"max_instance_count={max_instance_count}, "
f"fetch_inputs_and_outputs={fetch_inputs_and_outputs}, "
f"continuation_token={_continuation_token.value if _continuation_token else None}")

states = []

while True:
req = pb.QueryInstancesRequest(
query=pb.InstanceQuery(
runtimeStatus=[status.value for status in runtime_status] if runtime_status else None,
createdTimeFrom=helpers.new_timestamp(created_time_from) if created_time_from else None,
createdTimeTo=helpers.new_timestamp(created_time_to) if created_time_to else None,
maxInstanceCount=max_instance_count,
fetchInputsAndOutputs=fetch_inputs_and_outputs,
continuationToken=_continuation_token
)
)
resp: pb.QueryInstancesResponse = self._stub.QueryInstances(req)
states += [parse_orchestration_state(res) for res in resp.orchestrationState]
# Check the value for continuationToken - none or "0" indicates that there are no more results.
if resp.continuationToken and resp.continuationToken.value and resp.continuationToken.value != "0":
self._logger.info(f"Received continuation token with value {resp.continuationToken.value}, fetching next list of instances...")
if _continuation_token and _continuation_token.value and _continuation_token.value == resp.continuationToken.value:
self._logger.warning(f"Received the same continuation token value {resp.continuationToken.value} again, stopping to avoid infinite loop.")
break
_continuation_token = resp.continuationToken
else:
break

states = [state for state in states if state is not None] # Filter out any None values
Comment thread
andystaples marked this conversation as resolved.
Outdated
return states

def wait_for_orchestration_start(self, instance_id: str, *,
fetch_payloads: bool = False,
timeout: int = 60) -> Optional[OrchestrationState]:
Expand Down Expand Up @@ -199,7 +275,8 @@ def raise_orchestration_event(self, instance_id: str, event_name: str, *,
req = pb.RaiseEventRequest(
instanceId=instance_id,
name=event_name,
input=wrappers_pb2.StringValue(value=shared.to_json(data)) if data else None)
input=helpers.get_string_value(shared.to_json(data) if data is not None else None)
)

self._logger.info(f"Raising event '{event_name}' for instance '{instance_id}'.")
self._stub.RaiseEvent(req)
Expand All @@ -209,7 +286,7 @@ def terminate_orchestration(self, instance_id: str, *,
recursive: bool = True):
req = pb.TerminateRequest(
instanceId=instance_id,
output=wrappers_pb2.StringValue(value=shared.to_json(output)) if output else None,
output=helpers.get_string_value(shared.to_json(output) if output is not None else None),
recursive=recursive)

self._logger.info(f"Terminating instance '{instance_id}'.")
Expand All @@ -225,10 +302,32 @@ def resume_orchestration(self, instance_id: str):
self._logger.info(f"Resuming instance '{instance_id}'.")
self._stub.ResumeInstance(req)

def purge_orchestration(self, instance_id: str, recursive: bool = True):
def purge_orchestration(self, instance_id: str, recursive: bool = True) -> PurgeInstancesResult:
req = pb.PurgeInstancesRequest(instanceId=instance_id, recursive=recursive)
self._logger.info(f"Purging instance '{instance_id}'.")
self._stub.PurgeInstances(req)
resp: pb.PurgeInstancesResponse = self._stub.PurgeInstances(req)
return PurgeInstancesResult(resp.deletedInstanceCount, resp.isComplete.value)

def purge_orchestrations_by(self,
                            created_time_from: Optional[datetime] = None,
                            created_time_to: Optional[datetime] = None,
                            runtime_status: Optional[List[OrchestrationStatus]] = None,
                            recursive: bool = False) -> PurgeInstancesResult:
    """Purge orchestration instances selected by a filter instead of an ID.

    Args:
        created_time_from: Optional lower creation-time bound; sent as
            ``createdTimeFrom`` when provided.
        created_time_to: Optional upper creation-time bound; sent as
            ``createdTimeTo`` when provided.
        runtime_status: Optional statuses; their ``.value`` entries are sent
            as ``runtimeStatus`` when the list is non-empty.
        recursive: Passed through as the request's ``recursive`` flag.

    Returns:
        A ``PurgeInstancesResult`` built from the response's
        ``deletedInstanceCount`` and ``isComplete.value``.
    """
    status_names = [str(status) for status in runtime_status] if runtime_status else None
    self._logger.info("Purging orchestrations by filter: "
                      f"created_time_from={created_time_from}, "
                      f"created_time_to={created_time_to}, "
                      f"runtime_status={status_names}, "
                      f"recursive={recursive}")

    # Build the filter piece by piece; absent criteria are sent as None.
    from_ts = helpers.new_timestamp(created_time_from) if created_time_from else None
    to_ts = helpers.new_timestamp(created_time_to) if created_time_to else None
    status_values = [status.value for status in runtime_status] if runtime_status else None
    purge_filter = pb.PurgeInstanceFilter(
        createdTimeFrom=from_ts,
        createdTimeTo=to_ts,
        runtimeStatus=status_values,
    )

    # instanceId is explicitly None: this request targets the filter only.
    request = pb.PurgeInstancesRequest(
        instanceId=None,
        purgeInstanceFilter=purge_filter,
        recursive=recursive,
    )
    response: pb.PurgeInstancesResponse = self._stub.PurgeInstances(request)
    return PurgeInstancesResult(response.deletedInstanceCount, response.isComplete.value)

def signal_entity(self,
entity_instance_id: EntityInstanceId,
Expand All @@ -237,7 +336,7 @@ def signal_entity(self,
req = pb.SignalEntityRequest(
instanceId=str(entity_instance_id),
name=operation_name,
input=wrappers_pb2.StringValue(value=shared.to_json(input)) if input else None,
input=helpers.get_string_value(shared.to_json(input) if input is not None else None),
requestId=str(uuid.uuid4()),
scheduledTime=None,
parentTraceContext=None,
Expand All @@ -256,4 +355,91 @@ def get_entity(self,
if not res.exists:
return None

return EntityMetadata.from_entity_response(res, include_state)
return EntityMetadata.from_entity_metadata(res.entity, include_state)

def get_all_entities(self,
                     include_state: bool = True,
                     include_transient: bool = False,
                     page_size: Optional[int] = None) -> List[EntityMetadata]:
    """Query entities with no ID-prefix or last-modified filter applied.

    This is a thin wrapper: it forwards to ``get_entities_by`` with the three
    filter arguments explicitly set to ``None``.

    Args:
        include_state: Forwarded unchanged to the filtered query.
        include_transient: Forwarded unchanged to the filtered query.
        page_size: Forwarded unchanged to the filtered query.

    Returns:
        Whatever ``get_entities_by`` returns for the unfiltered query.
    """
    unfiltered_query = {
        "instance_id_starts_with": None,
        "last_modified_from": None,
        "last_modified_to": None,
        "include_state": include_state,
        "include_transient": include_transient,
        "page_size": page_size,
    }
    return self.get_entities_by(**unfiltered_query)

def get_entities_by(self,
                    instance_id_starts_with: Optional[str] = None,
                    last_modified_from: Optional[datetime] = None,
                    last_modified_to: Optional[datetime] = None,
                    include_state: bool = True,
                    include_transient: bool = False,
                    page_size: Optional[int] = None,
                    _continuation_token: Optional[pb2.StringValue] = None
                    ) -> List[EntityMetadata]:
    """Query entities matching a filter, following continuation tokens.

    Pages are requested repeatedly until the response carries no usable
    continuation token, and the results of every page are concatenated.

    Args:
        instance_id_starts_with: Optional instance-ID prefix.
        last_modified_from: Optional lower last-modified bound.
        last_modified_to: Optional upper last-modified bound.
        include_state: Sent as ``includeState``; also decides whether state
            is read when converting each returned entity.
        include_transient: Sent as ``includeTransient``.
        page_size: Optional page size, wrapped via ``helpers.get_int_value``.
        _continuation_token: Token sent with the first request.

    Returns:
        The ``EntityMetadata`` objects from all pages, in retrieval order.
    """
    self._logger.info(f"Retrieving entities by filter: "
                      f"instance_id_starts_with={instance_id_starts_with}, "
                      f"last_modified_from={last_modified_from}, "
                      f"last_modified_to={last_modified_to}, "
                      f"include_state={include_state}, "
                      f"include_transient={include_transient}, "
                      f"page_size={page_size}")

    collected = []
    token = _continuation_token

    while True:
        request = pb.QueryEntitiesRequest(
            query=pb.EntityQuery(
                instanceIdStartsWith=helpers.get_string_value(instance_id_starts_with),
                lastModifiedFrom=helpers.new_timestamp(last_modified_from) if last_modified_from else None,
                lastModifiedTo=helpers.new_timestamp(last_modified_to) if last_modified_to else None,
                includeState=include_state,
                includeTransient=include_transient,
                pageSize=helpers.get_int_value(page_size),
                continuationToken=token
            )
        )
        page: pb.QueryEntitiesResponse = self._stub.QueryEntities(request)
        collected.extend(
            EntityMetadata.from_entity_metadata(item, request.query.includeState)
            for item in page.entities
        )

        # A missing, empty or "0" token means there is nothing left to fetch.
        next_token = page.continuationToken
        if not (next_token and next_token.value and next_token.value != "0"):
            break

        self._logger.info(f"Received continuation token with value {next_token.value}, fetching next page of entities...")
        # Getting back the token that was just sent would loop forever.
        if token and token.value and token.value == next_token.value:
            self._logger.warning(f"Received the same continuation token value {next_token.value} again, stopping to avoid infinite loop.")
            break
        token = next_token

    return collected

def clean_entity_storage(self,
                         remove_empty_entities: bool = True,
                         release_orphaned_locks: bool = True,
                         _continuation_token: Optional[pb2.StringValue] = None
                         ) -> CleanEntityStorageResult:
    """Run entity-storage cleanup, following continuation tokens.

    Cleanup requests are issued repeatedly until the response carries no
    usable continuation token, and the per-page counters are summed.

    Args:
        remove_empty_entities: Sent as ``removeEmptyEntities``.
        release_orphaned_locks: Sent as ``releaseOrphanedLocks``.
        _continuation_token: Token sent with the first request.

    Returns:
        A ``CleanEntityStorageResult`` holding the summed counters.
    """
    self._logger.info("Cleaning entity storage")

    removed_total = 0
    released_total = 0
    token = _continuation_token

    while True:
        request = pb.CleanEntityStorageRequest(
            removeEmptyEntities=remove_empty_entities,
            releaseOrphanedLocks=release_orphaned_locks,
            continuationToken=token
        )
        page: pb.CleanEntityStorageResponse = self._stub.CleanEntityStorage(request)
        removed_total += page.emptyEntitiesRemoved
        released_total += page.orphanedLocksReleased

        # A missing, empty or "0" token means there is nothing left to clean.
        next_token = page.continuationToken
        if not (next_token and next_token.value and next_token.value != "0"):
            break

        self._logger.info(f"Received continuation token with value {next_token.value}, cleaning next page...")
        # Getting back the token that was just sent would loop forever.
        if token and token.value and token.value == next_token.value:
            self._logger.warning(f"Received the same continuation token value {next_token.value} again, stopping to avoid infinite loop.")
            break
        token = next_token

    return CleanEntityStorageResult(removed_total, released_total)
14 changes: 9 additions & 5 deletions durabletask/entities/entity_metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,18 +44,22 @@ def __init__(self,

@staticmethod
def from_entity_response(entity_response: pb.GetEntityResponse, includes_state: bool):
return EntityMetadata.from_entity_metadata(entity_response.entity, includes_state)

@staticmethod
def from_entity_metadata(entity: pb.EntityMetadata, includes_state: bool):
try:
entity_id = EntityInstanceId.parse(entity_response.entity.instanceId)
entity_id = EntityInstanceId.parse(entity.instanceId)
except ValueError:
raise ValueError("Invalid entity instance ID in entity response.")
entity_state = None
if includes_state:
entity_state = entity_response.entity.serializedState.value
entity_state = entity.serializedState.value
return EntityMetadata(
id=entity_id,
last_modified=entity_response.entity.lastModifiedTime.ToDatetime(timezone.utc),
backlog_queue_size=entity_response.entity.backlogQueueSize,
locked_by=entity_response.entity.lockedBy.value,
last_modified=entity.lastModifiedTime.ToDatetime(timezone.utc),
backlog_queue_size=entity.backlogQueueSize,
locked_by=entity.lockedBy.value,
includes_state=includes_state,
state=entity_state
)
Expand Down
7 changes: 7 additions & 0 deletions durabletask/internal/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,6 +184,13 @@ def get_string_value(val: Optional[str]) -> Optional[wrappers_pb2.StringValue]:
return wrappers_pb2.StringValue(value=val)


def get_int_value(val: Optional[int]) -> Optional[wrappers_pb2.Int32Value]:
    """Wrap an optional int in a protobuf ``Int32Value``.

    Args:
        val: The integer to wrap, or ``None``.

    Returns:
        ``None`` when ``val`` is ``None``; otherwise an ``Int32Value`` whose
        ``value`` is ``val``.
    """
    if val is not None:
        return wrappers_pb2.Int32Value(value=val)
    return None


def get_string_value_or_empty(val: Optional[str]) -> wrappers_pb2.StringValue:
if val is None:
return wrappers_pb2.StringValue(value="")
Expand Down
Loading
Loading