Skip to content

Commit 28c6abb

Browse files
authored
Merge pull request #113 from singnet/SPS-14
Refactor and service managing
2 parents 7fc2a36 + 476873f commit 28c6abb

25 files changed

Lines changed: 876 additions & 334 deletions

Makefile

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,3 +2,8 @@ lint:
22
@ruff check . --fix
33
@ruff format .
44
.PHONY: lint
5+
6+
test:
7+
python -m coverage run -m pytest tests/ -v && \
8+
python -m coverage report
9+
.PHONY: test

pyproject.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,12 +31,14 @@ dependencies = [
3131
"snet-contracts==1.0.1",
3232
"lighthouseweb3~=0.1.4",
3333
"py-multihash~=3.0",
34+
"pydantic~=2.11",
3435
"pydantic-settings~=2.13"
3536
]
3637

3738
[tool.poetry.group.dev.dependencies]
3839
ruff = "^0.11"
3940
pytest = "^8.3"
41+
coverage = "^7.13"
4042

4143
[tool.ruff]
4244
line-length = 100

requirements.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,4 +9,5 @@ ipfshttpclient==0.4.13.2
99
snet-contracts==1.0.1
1010
lighthouseweb3~=0.1.4
1111
py-multihash~=3.0
12+
pydantic~=2.11
1213
pydantic-settings~=2.13

snet/sdk/__init__.py

Lines changed: 72 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,16 @@
33
import sys
44
import warnings
55
from enum import Enum
6+
from pathlib import Path
7+
from typing import Union
68

79
import google.protobuf.internal.api_implementation
810
from google.protobuf import symbol_database as _symbol_database
911

10-
from snet.sdk.storage_provider.service_metadata import MPEServiceMetadata
12+
from snet.sdk.registry.models import StorageType
13+
from snet.sdk.registry.organization_metadata import OrganizationMetadata
14+
from snet.sdk.registry.registry_contract import RegistryContract
15+
from snet.sdk.registry.service_metadata import MPEServiceMetadata, ServiceMetadata
1116

1217
with warnings.catch_warnings():
1318
# Suppress the eth-typing package`s warnings related to some new networks
@@ -18,9 +23,6 @@
1823
UserWarning,
1924
)
2025

21-
import web3
22-
23-
from snet.contracts import get_contract_object
2426
from snet.sdk.account import Account
2527
from snet.sdk.config import config
2628
from snet.sdk.client_lib_generator import ClientLibGenerator
@@ -34,12 +36,11 @@
3436
PaymentStrategy,
3537
)
3638
from snet.sdk.service_client import ServiceClient
37-
from snet.sdk.storage_provider.storage_provider import StorageProvider
38-
from snet.sdk.custom_typing import ModuleName, ServiceStub
39+
from snet.sdk.registry.storage_provider import StorageProvider
40+
from snet.sdk.types import ModuleName, ServiceStub
3941
from snet.sdk.utils.utils import (
40-
bytes32_to_str,
4142
find_file_by_keyword,
42-
type_converter,
43+
get_we3_object,
4344
)
4445

4546
google.protobuf.internal.api_implementation.Type = lambda: "python"
@@ -55,29 +56,13 @@ class PaymentStrategyType(Enum):
5556

5657

5758
class SnetSDK:
58-
"""Base Snet SDK"""
59-
6059
def __init__(self):
61-
self.web3 = web3.Web3(web3.HTTPProvider(config.ETH_RPC_ENDPOINT))
62-
63-
mpe_contract_address = config.MPE_CONTRACT_ADDRESS
64-
if not mpe_contract_address:
65-
self.mpe_contract = MPEContract(self.web3)
66-
else:
67-
self.mpe_contract = MPEContract(self.web3, mpe_contract_address)
68-
69-
registry_contract_address = config.REGISTRY_CONTRACT_ADDRESS
70-
if registry_contract_address is None:
71-
self.registry_contract = get_contract_object(self.web3, "Registry")
72-
else:
73-
self.registry_contract = get_contract_object(
74-
self.web3, "Registry", registry_contract_address
75-
)
76-
77-
self.metadata_provider = StorageProvider(self.registry_contract)
78-
79-
self.account = Account(self.web3, self.mpe_contract)
80-
self.payment_channel_provider = PaymentChannelProvider(self.web3, self.mpe_contract)
60+
self.w3 = get_we3_object()
61+
self.mpe_contract = MPEContract()
62+
self.registry_contract = RegistryContract()
63+
self.storage_provider = StorageProvider()
64+
self.payment_channel_provider = PaymentChannelProvider(self.mpe_contract)
65+
self.account = Account(self.mpe_contract.contract.address)
8166

8267
def create_service_client(
8368
self,
@@ -92,14 +77,14 @@ def create_service_client(
9277

9378
# Create and instance of the Config object,
9479
# so we can create an instance of ClientLibGenerator
95-
lib_generator = ClientLibGenerator(self.metadata_provider, org_id, service_id)
80+
lib_generator = ClientLibGenerator(self.storage_provider, org_id, service_id)
9681

9782
# Download the proto file and generate stubs if needed
9883
force_update = config.FORCE_UPDATE
9984
if force_update:
10085
lib_generator.generate_client_library()
10186
else:
102-
path_to_pb_files = lib_generator.protodir
87+
path_to_pb_files = lib_generator.proto_dir
10388
pb_2_file_name = find_file_by_keyword(
10489
path_to_pb_files, keyword="pb2.py", exclude=["training"]
10590
)
@@ -118,7 +103,7 @@ def create_service_client(
118103
if payment_strategy is None:
119104
payment_strategy = payment_strategy_type.value()
120105

121-
service_metadata = self.metadata_provider.enhance_service_metadata(org_id, service_id)
106+
service_metadata = self._enhance_service_metadata(org_id, service_id)
122107
group = self._get_service_group_details(service_metadata, group_name)
123108

124109
service_stubs = self.get_service_stub(lib_generator)
@@ -134,16 +119,29 @@ def create_service_client(
134119
options,
135120
self.mpe_contract,
136121
self.account,
137-
self.web3,
122+
self.w3,
138123
pb2_module,
139124
self.payment_channel_provider,
140-
lib_generator.protodir,
125+
lib_generator.proto_dir,
141126
lib_generator.training_added(),
142127
)
143128
return _service_client
144129

130+
def _enhance_service_metadata(self, org_id, service_id):
131+
service_metadata = self.get_service_metadata(org_id, service_id)
132+
org_metadata = self.get_organization_metadata(org_id)
133+
134+
org_group_map = {}
135+
for group in org_metadata.groups:
136+
org_group_map[group.group_name] = group
137+
138+
for group in service_metadata.groups:
139+
group.payment = org_group_map[group.group_name].payment
140+
141+
return service_metadata
142+
145143
def get_service_stub(self, lib_generator: ClientLibGenerator) -> list[ServiceStub]:
146-
path_to_pb_files = str(lib_generator.protodir)
144+
path_to_pb_files = str(lib_generator.proto_dir)
147145
module_name = self.get_module_by_keyword("pb2_grpc.py", lib_generator)
148146
sys.path.append(path_to_pb_files)
149147
try:
@@ -159,13 +157,18 @@ def get_service_stub(self, lib_generator: ClientLibGenerator) -> list[ServiceStu
159157
raise Exception(f"Error importing module: {e}")
160158

161159
def get_module_by_keyword(self, keyword: str, lib_generator: ClientLibGenerator) -> ModuleName:
162-
path_to_pb_files = lib_generator.protodir
160+
path_to_pb_files = lib_generator.proto_dir
163161
file_name = find_file_by_keyword(path_to_pb_files, keyword, exclude=["training"])
164162
module_name = os.path.splitext(file_name)[0]
165163
return ModuleName(module_name)
166164

167165
def get_service_metadata(self, org_id, service_id):
168-
return self.metadata_provider.fetch_service_metadata(org_id, service_id)
166+
service = self.registry_contract.get_service(org_id, service_id)
167+
return self.storage_provider.fetch_service_metadata(service.metadata_uri)
168+
169+
def get_organization_metadata(self, org_id: str) -> OrganizationMetadata:
170+
org = self.registry_contract.get_org(org_id)
171+
return self.storage_provider.fetch_org_metadata(org.metadata_uri)
169172

170173
def _get_first_group(self, service_metadata: MPEServiceMetadata) -> dict:
171174
return service_metadata["groups"][0]
@@ -176,7 +179,8 @@ def _get_group_by_group_name(
176179
for group in service_metadata["groups"]:
177180
if group["group_name"] == group_name:
178181
return group
179-
return {}
182+
# TODO: configure exceptions
183+
raise Exception()
180184

181185
def _get_service_group_details(
182186
self, service_metadata: MPEServiceMetadata, group_name: str
@@ -190,17 +194,34 @@ def _get_service_group_details(
190194
return self._get_group_by_group_name(service_metadata, group_name)
191195

192196
def get_organization_list(self) -> list:
193-
org_list = self.registry_contract.functions.listOrganizations().call()
194-
organization_list = []
195-
for idx, org_id in enumerate(org_list):
196-
organization_list.append(bytes32_to_str(org_id))
197-
return organization_list
197+
return self.registry_contract.list_orgs()
198198

199199
def get_services_list(self, org_id: str) -> list:
200-
found, org_service_list = self.registry_contract.functions.listServicesForOrganization(
201-
type_converter("bytes32")(org_id)
202-
).call()
203-
if not found:
204-
raise Exception(f"Organization with id={org_id} doesn't exist!")
205-
org_service_list = list(map(bytes32_to_str, org_service_list))
206-
return org_service_list
200+
return self.registry_contract.list_service_for_org(org_id)
201+
202+
def publish_service_comprehensively(
203+
self,
204+
org_id: str,
205+
service_id: str,
206+
metadata: ServiceMetadata,
207+
proto_dir: Union[str, Path],
208+
storage_type: StorageType = StorageType.IPFS,
209+
) -> bool:
210+
"""
211+
1. publish .proto files as .tar.gz archive into the storage
212+
2. add other fields to the service metadata
213+
3. validate service metadata
214+
4. publish service metadata into the storage
215+
5. publish service into Registry contract
216+
"""
217+
proto_uri = self.storage_provider.publish_proto(proto_dir, storage_type)
218+
219+
metadata.service_api_source = str(proto_uri)
220+
metadata.mpe_address = self.mpe_contract.contract.address
221+
222+
metadata_uri = self.storage_provider.publish_service_metadata(metadata, storage_type)
223+
receipt = self.registry_contract.create_service(
224+
self.account, org_id, service_id, metadata_uri
225+
)
226+
227+
return receipt["status"] != 0

0 commit comments

Comments
 (0)