Skip to content

Commit aa4e8ee

Browse files
authored
Merge pull request #836 from superannotateai/fix_methods
fix related pydantic
2 parents 17e0ba5 + ca1e5f6 commit aa4e8ee

File tree

4 files changed

+70
-15
lines changed

4 files changed

+70
-15
lines changed

src/superannotate/lib/core/entities/base.py

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -31,15 +31,21 @@ def _validate_hex_color(v: str) -> str:
3131
HexColor = Annotated[str, AfterValidator(_validate_hex_color)]
3232

3333

34-
def _validate_string_date(v: datetime) -> str:
35-
"""Convert datetime to string format."""
34+
def _validate_string_date(v: Union[datetime, str]) -> str:
35+
"""Convert datetime to string format. For case data output."""
3636
if isinstance(v, str):
37-
return v
38-
return v.isoformat().split("+")[0] + ".000Z"
37+
try:
38+
dt = datetime.fromisoformat(v.replace("Z", "+00:00"))
39+
return dt.strftime("%Y-%m-%dT%H:%M:%S+00:00")
40+
except (ValueError, AttributeError):
41+
return v
42+
elif isinstance(v, datetime):
43+
return v.strftime("%Y-%m-%dT%H:%M:%S+00:00")
44+
return v
3945

4046

41-
def _serialize_string_date(v) -> str:
42-
"""Serialize datetime or string to string format."""
47+
def _serialize_string_date(v: Union[datetime, str]) -> str:
48+
"""Serialize datetime or string to string format. For case data input."""
4349
if isinstance(v, str):
4450
return v
4551
if isinstance(v, datetime):
@@ -127,7 +133,7 @@ def _validate_token(value: str) -> str:
127133
class ConfigEntity(BaseModel):
128134
model_config = ConfigDict(extra="ignore")
129135

130-
API_TOKEN: str = Field(alias="SA_TOKEN")
136+
API_TOKEN: TokenStr = Field(alias="SA_TOKEN")
131137
API_URL: str = Field(alias="SA_URL", default=BACKEND_URL)
132138
LOGGING_LEVEL: Literal[
133139
"NOTSET", "DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"

src/superannotate/lib/core/enums.py

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,10 @@
22
from enum import Enum
33
from types import DynamicClassAttribute
44

5+
from pydantic import GetCoreSchemaHandler
6+
from pydantic_core import core_schema
7+
from pydantic_core import PydanticCustomError
8+
59

610
class classproperty: # noqa
711
def __init__(self, getter):
@@ -33,6 +37,47 @@ def value(self) -> int:
3337
def __unicode__(self):
3438
return self.__doc__
3539

40+
@classmethod
41+
def __get_pydantic_core_schema__(cls, source_type, handler: GetCoreSchemaHandler):
42+
"""Customize Pydantic v2 validation to show titles in error messages."""
43+
return core_schema.no_info_after_validator_function(
44+
cls._validate,
45+
core_schema.union_schema(
46+
[
47+
core_schema.is_instance_schema(cls),
48+
core_schema.int_schema(),
49+
core_schema.str_schema(),
50+
]
51+
),
52+
serialization=core_schema.plain_serializer_function_ser_schema(
53+
lambda x: x.value if isinstance(x, cls) else x, when_used="json"
54+
),
55+
)
56+
57+
@classmethod
58+
def _validate(cls, value):
59+
"""Validate and convert value to enum member."""
60+
if isinstance(value, cls):
61+
return value
62+
63+
# Try to find by value
64+
if isinstance(value, int):
65+
for enum in cls:
66+
if enum.value == value:
67+
return enum
68+
69+
# Try to find by title or name
70+
if isinstance(value, str):
71+
for enum in cls:
72+
if enum.__doc__ and enum.__doc__.lower() == value.lower():
73+
return enum
74+
if value in cls.__members__:
75+
return cls.__members__[value]
76+
77+
# Build error message with titles
78+
available = ", ".join(f"{enum.__doc__.lower()}" for enum in cls if enum.__doc__)
79+
raise PydanticCustomError("enum", f"Input should be: {available}")
80+
3681
@classmethod
3782
def choices(cls) -> typing.Tuple[str]:
3883
"""Return all titles as choices."""

tests/integration/classes/test_create_update_annotation_class.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -658,6 +658,14 @@ def test_update_annotation_class_without_group_type(self):
658658
)
659659
assert res["attribute_groups"][0]["group_type"] == "radio"
660660

661+
def test_create_with_invalid_type(self):
662+
try:
663+
sa.create_annotation_class(
664+
self.PROJECT_NAME, "tt", "#FFFFFF", class_type="invalid"
665+
)
666+
except AppException as e:
667+
assert "Input should be: object, tag, relationship" in str(e)
668+
661669

662670
class TestVideoCreateAnnotationClasses(BaseTestCase):
663671
PROJECT_NAME = "TestVideoCreateAnnotationClasses"

tests/unit/test_init.py

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ class ClientInitTestCase(TestCase):
1616

1717
def test_init_via_invalid_token(self):
1818
_token = "123"
19-
with self.assertRaisesRegex(AppException, r"(\s+)token(\s+)Invalid token."):
19+
with self.assertRaisesRegex(AppException, r"Invalid token\."):
2020
SAClient(token=_token)
2121

2222
@patch("lib.infrastructure.controller.Controller.get_current_user")
@@ -61,9 +61,7 @@ def test_init_via_config_json_invalid_json(self):
6161
with open(f"{config_dir}/config.json", "w") as config_json:
6262
json.dump({"token": "INVALID_TOKEN"}, config_json)
6363
for kwargs in ({}, {"config_path": f"{config_dir}/config.json"}):
64-
with self.assertRaisesRegex(
65-
AppException, r"(\s+)token(\s+)Invalid token."
66-
):
64+
with self.assertRaisesRegex(AppException, r"Invalid token\."):
6765
SAClient(**kwargs)
6866

6967
@patch("lib.infrastructure.controller.Controller.get_current_user")
@@ -137,7 +135,7 @@ def test_init_env(self, get_team, get_current_user):
137135

138136
@patch.dict(os.environ, {"SA_URL": "SOME_URL", "SA_TOKEN": "SOME_TOKEN"})
139137
def test_init_env_invalid_token(self):
140-
with self.assertRaisesRegex(AppException, r"(\s+)SA_TOKEN(\s+)Invalid token."):
138+
with self.assertRaisesRegex(AppException, r"Invalid token\."):
141139
SAClient()
142140

143141
def test_init_via_config_ini_invalid_token(self):
@@ -157,9 +155,7 @@ def test_init_via_config_ini_invalid_token(self):
157155
config_parser.write(config_ini)
158156

159157
for kwargs in ({}, {"config_path": f"{config_dir}/config.ini"}):
160-
with self.assertRaisesRegex(
161-
AppException, r"(\s+)SA_TOKEN(\s+)Invalid token."
162-
):
158+
with self.assertRaisesRegex(AppException, r"Invalid token\."):
163159
SAClient(**kwargs)
164160

165161
def test_invalid_config_path(self):

0 commit comments

Comments
 (0)