Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -2,3 +2,8 @@ lint:
@ruff check . --fix
@ruff format .
.PHONY: lint

test:
python -m coverage run -m pytest tests/ -v && \
python -m coverage report
.PHONY: test
2 changes: 2 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -31,12 +31,14 @@ dependencies = [
"snet-contracts==1.0.1",
"lighthouseweb3~=0.1.4",
"py-multihash~=3.0",
"pydantic~=2.11",
"pydantic-settings~=2.13"
]

[tool.poetry.group.dev.dependencies]
ruff = "^0.11"
pytest = "^8.3"
coverage = "^7.13"

[tool.ruff]
line-length = 100
Expand Down
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -9,4 +9,5 @@ ipfshttpclient==0.4.13.2
snet-contracts==1.0.1
lighthouseweb3~=0.1.4
py-multihash~=3.0
pydantic~=2.11
pydantic-settings~=2.13
123 changes: 72 additions & 51 deletions snet/sdk/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,16 @@
import sys
import warnings
from enum import Enum
from pathlib import Path
from typing import Union

import google.protobuf.internal.api_implementation
from google.protobuf import symbol_database as _symbol_database

from snet.sdk.storage_provider.service_metadata import MPEServiceMetadata
from snet.sdk.registry.models import StorageType
from snet.sdk.registry.organization_metadata import OrganizationMetadata
from snet.sdk.registry.registry_contract import RegistryContract
from snet.sdk.registry.service_metadata import MPEServiceMetadata, ServiceMetadata

with warnings.catch_warnings():
# Suppress the eth-typing package`s warnings related to some new networks
Expand All @@ -18,9 +23,6 @@
UserWarning,
)

import web3

from snet.contracts import get_contract_object
from snet.sdk.account import Account
from snet.sdk.config import config
from snet.sdk.client_lib_generator import ClientLibGenerator
Expand All @@ -34,12 +36,11 @@
PaymentStrategy,
)
from snet.sdk.service_client import ServiceClient
from snet.sdk.storage_provider.storage_provider import StorageProvider
from snet.sdk.custom_typing import ModuleName, ServiceStub
from snet.sdk.registry.storage_provider import StorageProvider
from snet.sdk.types import ModuleName, ServiceStub
from snet.sdk.utils.utils import (
bytes32_to_str,
find_file_by_keyword,
type_converter,
get_we3_object,
)

google.protobuf.internal.api_implementation.Type = lambda: "python"
Expand All @@ -55,29 +56,13 @@


class SnetSDK:
"""Base Snet SDK"""

def __init__(self):
self.web3 = web3.Web3(web3.HTTPProvider(config.ETH_RPC_ENDPOINT))

mpe_contract_address = config.MPE_CONTRACT_ADDRESS
if not mpe_contract_address:
self.mpe_contract = MPEContract(self.web3)
else:
self.mpe_contract = MPEContract(self.web3, mpe_contract_address)

registry_contract_address = config.REGISTRY_CONTRACT_ADDRESS
if registry_contract_address is None:
self.registry_contract = get_contract_object(self.web3, "Registry")
else:
self.registry_contract = get_contract_object(
self.web3, "Registry", registry_contract_address
)

self.metadata_provider = StorageProvider(self.registry_contract)

self.account = Account(self.web3, self.mpe_contract)
self.payment_channel_provider = PaymentChannelProvider(self.web3, self.mpe_contract)
self.w3 = get_we3_object()
self.mpe_contract = MPEContract()
self.registry_contract = RegistryContract()
self.storage_provider = StorageProvider()
self.payment_channel_provider = PaymentChannelProvider(self.mpe_contract)
self.account = Account(self.mpe_contract.contract.address)

def create_service_client(
self,
Expand All @@ -92,14 +77,14 @@

# Create and instance of the Config object,
# so we can create an instance of ClientLibGenerator
lib_generator = ClientLibGenerator(self.metadata_provider, org_id, service_id)
lib_generator = ClientLibGenerator(self.storage_provider, org_id, service_id)

# Download the proto file and generate stubs if needed
force_update = config.FORCE_UPDATE
if force_update:
lib_generator.generate_client_library()
else:
path_to_pb_files = lib_generator.protodir
path_to_pb_files = lib_generator.proto_dir
pb_2_file_name = find_file_by_keyword(
path_to_pb_files, keyword="pb2.py", exclude=["training"]
)
Expand All @@ -118,7 +103,7 @@
if payment_strategy is None:
payment_strategy = payment_strategy_type.value()

service_metadata = self.metadata_provider.enhance_service_metadata(org_id, service_id)
service_metadata = self._enhance_service_metadata(org_id, service_id)
group = self._get_service_group_details(service_metadata, group_name)

service_stubs = self.get_service_stub(lib_generator)
Expand All @@ -134,16 +119,29 @@
options,
self.mpe_contract,
self.account,
self.web3,
self.w3,
pb2_module,
self.payment_channel_provider,
lib_generator.protodir,
lib_generator.proto_dir,
lib_generator.training_added(),
)
return _service_client

def _enhance_service_metadata(self, org_id, service_id):
service_metadata = self.get_service_metadata(org_id, service_id)
org_metadata = self.get_organization_metadata(org_id)

org_group_map = {}
for group in org_metadata.groups:
org_group_map[group.group_name] = group

for group in service_metadata.groups:
group.payment = org_group_map[group.group_name].payment

return service_metadata

def get_service_stub(self, lib_generator: ClientLibGenerator) -> list[ServiceStub]:
path_to_pb_files = str(lib_generator.protodir)
path_to_pb_files = str(lib_generator.proto_dir)
module_name = self.get_module_by_keyword("pb2_grpc.py", lib_generator)
sys.path.append(path_to_pb_files)
try:
Expand All @@ -159,13 +157,18 @@
raise Exception(f"Error importing module: {e}")

def get_module_by_keyword(self, keyword: str, lib_generator: ClientLibGenerator) -> ModuleName:
path_to_pb_files = lib_generator.protodir
path_to_pb_files = lib_generator.proto_dir
file_name = find_file_by_keyword(path_to_pb_files, keyword, exclude=["training"])
module_name = os.path.splitext(file_name)[0]
return ModuleName(module_name)

def get_service_metadata(self, org_id, service_id):
return self.metadata_provider.fetch_service_metadata(org_id, service_id)
service = self.registry_contract.get_service(org_id, service_id)
return self.storage_provider.fetch_service_metadata(service.metadata_uri)

def get_organization_metadata(self, org_id: str) -> OrganizationMetadata:
org = self.registry_contract.get_org(org_id)
return self.storage_provider.fetch_org_metadata(org.metadata_uri)

def _get_first_group(self, service_metadata: MPEServiceMetadata) -> dict:
return service_metadata["groups"][0]
Expand All @@ -176,7 +179,8 @@
for group in service_metadata["groups"]:
if group["group_name"] == group_name:
return group
return {}
# TODO: configure exceptions

Check notice on line 182 in snet/sdk/__init__.py

View check run for this annotation

snet-sonarqube-app / SonarQube Code Analysis

snet/sdk/__init__.py#L182

Complete the task associated to this "TODO" comment.
raise Exception()

Check warning on line 183 in snet/sdk/__init__.py

View check run for this annotation

snet-sonarqube-app / SonarQube Code Analysis

snet/sdk/__init__.py#L183

Replace this generic exception class with a more specific one.

def _get_service_group_details(
self, service_metadata: MPEServiceMetadata, group_name: str
Expand All @@ -190,17 +194,34 @@
return self._get_group_by_group_name(service_metadata, group_name)

def get_organization_list(self) -> list:
org_list = self.registry_contract.functions.listOrganizations().call()
organization_list = []
for idx, org_id in enumerate(org_list):
organization_list.append(bytes32_to_str(org_id))
return organization_list
return self.registry_contract.list_orgs()

def get_services_list(self, org_id: str) -> list:
found, org_service_list = self.registry_contract.functions.listServicesForOrganization(
type_converter("bytes32")(org_id)
).call()
if not found:
raise Exception(f"Organization with id={org_id} doesn't exist!")
org_service_list = list(map(bytes32_to_str, org_service_list))
return org_service_list
return self.registry_contract.list_service_for_org(org_id)

def publish_service_comprehensively(
self,
org_id: str,
service_id: str,
metadata: ServiceMetadata,
proto_dir: Union[str, Path],
storage_type: StorageType = StorageType.IPFS,
) -> bool:
"""
1. publish .proto files as .tar.gz archive into the storage
2. add other fields to the service metadata
3. validate service metadata
4. publish service metadata into the storage
5. publish service into Registry contract
"""
proto_uri = self.storage_provider.publish_proto(proto_dir, storage_type)

metadata.service_api_source = str(proto_uri)
metadata.mpe_address = self.mpe_contract.contract.address

metadata_uri = self.storage_provider.publish_service_metadata(metadata, storage_type)
receipt = self.registry_contract.create_service(
self.account, org_id, service_id, metadata_uri
)

return receipt["status"] != 0
Loading
Loading