Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
194 changes: 158 additions & 36 deletions paimon-python/pypaimon/table/row/generic_row.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@
# limitations under the License.
################################################################################

import calendar
import decimal
import struct
from datetime import date, datetime, time, timedelta
from decimal import Decimal
Expand All @@ -29,6 +31,63 @@
from pypaimon.table.row.blob import BlobData


_DECIMAL_CTX = decimal.Context(prec=100, rounding=decimal.ROUND_HALF_UP)


def _decimal_to_unscaled_with_check(d: Decimal, precision: int, scale: int):
"""Round decimal with HALF_UP, check precision overflow, and return unscaled value.
Returns (unscaled_int, True) on overflow, (unscaled_int, False) on success."""
rounded = d.quantize(Decimal(10) ** -scale, context=_DECIMAL_CTX)
sign, digits, exponent = rounded.as_tuple()
# Precision overflow check
if rounded != 0 and len(digits) > precision:
return 0, True
int_digits = int(''.join(str(x) for x in digits)) if digits != (0,) else 0
shift = exponent + scale
if shift >= 0:
unscaled = int_digits * (10 ** shift)
else:
unscaled = int_digits // (10 ** (-shift))
return (-unscaled if sign else unscaled), False


def _parse_type_precision_scale(data_type):
"""Parse precision and scale from type string like DECIMAL(38, 10)."""
type_str = str(data_type)
if '(' in type_str and ')' in type_str:
try:
params_str = type_str.split('(')[1].split(')')[0]
parts = [p.strip() for p in params_str.split(',')]
precision = int(parts[0])
scale = int(parts[1]) if len(parts) > 1 else 0
return precision, scale
except (ValueError, IndexError):
return 0, 0
return 0, 0


_EPOCH = datetime(1970, 1, 1)


def _datetime_to_millis_and_nanos(value: datetime):
"""Convert datetime to (epoch_millis, nano_of_millisecond) without float arithmetic."""
epoch_seconds = calendar.timegm(value.timetuple())
millis = epoch_seconds * 1000 + value.microsecond // 1000
nano_of_millisecond = (value.microsecond % 1000) * 1000
return millis, nano_of_millisecond


def _millis_nanos_to_datetime(millis: int, nano_of_millisecond: int = 0) -> datetime:
"""Convert (epoch_millis, nano_of_millisecond) to datetime. Nanos truncated to micros."""
total_micros = millis * 1000 + nano_of_millisecond // 1000
seconds = total_micros // 1_000_000
micros = total_micros % 1_000_000
if micros < 0:
seconds -= 1
micros += 1_000_000
return _EPOCH + timedelta(seconds=seconds, microseconds=micros)


@dataclass
class GenericRow(InternalRow):

Expand Down Expand Up @@ -233,26 +292,49 @@ def _parse_blob(cls, bytes_data: bytes, base_offset: int, field_offset: int) ->
return BlobData.from_bytes(binary_data)

@classmethod
def _unscaled_to_decimal(cls, unscaled_value: int, scale: int) -> Decimal:
    """Build an exact Decimal from an unscaled integer and a scale.

    Constructed via the (sign, digits, exponent) tuple form, so no context
    rounding can occur no matter how many digits the value has.
    """
    if unscaled_value == 0:
        digit_tuple = (0,)
    else:
        digit_tuple = tuple(int(ch) for ch in str(abs(unscaled_value)))
    negative = 1 if unscaled_value < 0 else 0
    return Decimal((negative, digit_tuple, -scale))

@classmethod
def _parse_decimal(cls, bytes_data: bytes, base_offset: int, field_offset: int, data_type: DataType):
    """Read a decimal field out of serialized row bytes.

    Compact form (precision <= 18): the unscaled value is stored directly
    as a little-endian long in the fixed-width part.  Non-compact form:
    the fixed part packs ``(cursor << 32 | byte_length)`` and the unscaled
    value lives as big-endian signed bytes in the variable-length area at
    ``base_offset + cursor``.  Returns None on precision overflow.
    """
    precision, scale = _parse_type_precision_scale(data_type)
    if precision <= 0:
        raise ValueError(f"Decimal requires precision > 0, got {precision}")

    packed = struct.unpack('<q', bytes_data[field_offset:field_offset + 8])[0]
    if precision <= 18:
        return cls._unscaled_to_decimal(packed, scale)

    cursor = (packed >> 32) & 0xFFFFFFFF
    byte_length = packed & 0xFFFFFFFF
    start = base_offset + cursor
    unscaled_value = int.from_bytes(
        bytes_data[start:start + byte_length], byteorder='big', signed=True)

    result = cls._unscaled_to_decimal(unscaled_value, scale)
    # A stored value that exceeds the declared precision reads back as null.
    _, digit_tuple, _ = result.as_tuple()
    if result != 0 and len(digit_tuple) > precision:
        return None
    return result

@classmethod
def _parse_timestamp(cls, bytes_data: bytes, base_offset: int, field_offset: int, data_type: DataType) -> datetime:
    """Read a timestamp field out of serialized row bytes.

    Compact form (precision <= 3): epoch millis stored directly as a
    little-endian long in the fixed part.  Non-compact form: the fixed
    part packs ``(cursor << 32 | nanoOfMillisecond)`` and the millis long
    lives in the variable-length area at ``base_offset + cursor``.
    """
    precision = _parse_type_precision_scale(data_type)[0]
    header = struct.unpack('<q', bytes_data[field_offset:field_offset + 8])[0]
    if precision <= 3:
        return _millis_nanos_to_datetime(header)

    cursor = (header >> 32) & 0xFFFFFFFF
    nano_of_millisecond = header & 0xFFFFFFFF
    millis_at = base_offset + cursor
    millis = struct.unpack('<q', bytes_data[millis_at:millis_at + 8])[0]
    return _millis_nanos_to_datetime(millis, nano_of_millisecond)

@classmethod
def _parse_date(cls, bytes_data: bytes, field_offset: int) -> date:
Expand Down Expand Up @@ -301,24 +383,64 @@ def to_bytes(cls, row: Union[GenericRow, BinaryRow]) -> bytes:
raise ValueError(f"BinaryRow only support AtomicType yet, meet {field.type.__class__}")

type_name = field.type.type.upper()
if any(type_name.startswith(p) for p in ['CHAR', 'VARCHAR', 'STRING',
'BINARY', 'VARBINARY', 'BYTES', 'BLOB']):
if any(type_name.startswith(p) for p in ['CHAR', 'VARCHAR', 'STRING']):
is_var_len_type = any(type_name.startswith(p) for p in [
'CHAR', 'VARCHAR', 'STRING', 'BINARY', 'VARBINARY', 'BYTES', 'BLOB'])
is_decimal_type = type_name.startswith('DECIMAL') or type_name.startswith('NUMERIC')
is_timestamp_type = type_name.startswith('TIMESTAMP')
decimal_precision, decimal_scale = _parse_type_precision_scale(field.type) if is_decimal_type else (0, 0)
is_high_precision_decimal = is_decimal_type and decimal_precision > 18
timestamp_precision = _parse_type_precision_scale(field.type)[0] if is_timestamp_type else 0
is_non_compact_timestamp = is_timestamp_type and timestamp_precision > 3

# Precision overflow -> null
if is_decimal_type and value is not None:
d = value if isinstance(value, Decimal) else Decimal(str(value))
unscaled_value, overflow = _decimal_to_unscaled_with_check(d, decimal_precision, decimal_scale)
if overflow:
cls._set_null_bit(fixed_part, 0, i)
struct.pack_into('<q', fixed_part, field_fixed_offset, 0)
continue

if is_non_compact_timestamp:
# Non-compact: millis in var area, (offset << 32 | nanoOfMilli) in fixed part
if value.tzinfo is not None:
raise RuntimeError("datetime tzinfo not supported yet")
ts_millis, nano_of_millisecond = _datetime_to_millis_and_nanos(value)
var_value_bytes = struct.pack('<q', ts_millis)
offset_in_variable_part = current_variable_offset
variable_part_data.append(var_value_bytes)
current_variable_offset += 8
absolute_offset = fixed_part_size + offset_in_variable_part
offset_and_nano = (absolute_offset << 32) | nano_of_millisecond
struct.pack_into('<q', fixed_part, field_fixed_offset, offset_and_nano)
elif is_var_len_type or is_high_precision_decimal:
if is_high_precision_decimal:
# Big-endian signed bytes
if unscaled_value == 0:
value_bytes = b'\x00'
else:
byte_length = (unscaled_value.bit_length() + 8) // 8 # +8 for sign bit
value_bytes = unscaled_value.to_bytes(byte_length, byteorder='big', signed=True)
elif any(type_name.startswith(p) for p in ['CHAR', 'VARCHAR', 'STRING']):
value_bytes = str(value).encode('utf-8')
elif type_name == 'BLOB':
value_bytes = value.to_data()
else:
value_bytes = bytes(value)

length = len(value_bytes)
if length <= cls.MAX_FIX_PART_DATA_SIZE:
if length <= cls.MAX_FIX_PART_DATA_SIZE and not is_high_precision_decimal:
fixed_part[field_fixed_offset: field_fixed_offset + length] = value_bytes
for j in range(length, 7):
fixed_part[field_fixed_offset + j] = 0
header_byte = 0x80 | length
fixed_part[field_fixed_offset + 7] = header_byte
else:
var_length = cls._round_number_of_bytes_to_nearest_word(len(value_bytes))
# Non-compact decimal: fixed 16 bytes; others: 8-byte aligned
if is_high_precision_decimal:
var_length = 16
else:
var_length = cls._round_number_of_bytes_to_nearest_word(len(value_bytes))
var_value_bytes = value_bytes + b'\x00' * (var_length - length)
offset_in_variable_part = current_variable_offset
variable_part_data.append(var_value_bytes)
Expand Down Expand Up @@ -365,8 +487,18 @@ def _serialize_field_value(cls, value: Any, data_type: AtomicType) -> bytes:
elif type_name in ['DOUBLE']:
return cls._serialize_double(value)
elif type_name.startswith('DECIMAL') or type_name.startswith('NUMERIC'):
precision, _ = _parse_type_precision_scale(data_type)
if precision > 18:
raise ValueError(
f"Non-compact decimal (precision={precision}) must be serialized "
f"via the variable-length path in to_bytes(), not _serialize_field_value()")
return cls._serialize_decimal(value, data_type)
elif type_name.startswith('TIMESTAMP'):
precision = _parse_type_precision_scale(data_type)[0]
if precision > 3:
raise ValueError(
f"Non-compact timestamp (precision={precision}) must be serialized "
f"via the variable-length path in to_bytes(), not _serialize_field_value()")
return cls._serialize_timestamp(value)
elif type_name in ['DATE']:
return cls._serialize_date(value) + b'\x00' * 4
Expand Down Expand Up @@ -405,27 +537,17 @@ def _serialize_double(cls, value: float) -> bytes:

@classmethod
def _serialize_decimal(cls, value: Decimal, data_type: DataType) -> bytes:
    """Serialize a compact decimal (precision <= 18) as the little-endian
    unscaled long that goes into the fixed-width part.

    Non-Decimal inputs are coerced via ``Decimal(str(value))``.  The
    overflow flag is ignored here; callers are expected to have performed
    the precision-overflow check (and nulled the field) beforehand.
    """
    precision, scale = _parse_type_precision_scale(data_type)
    as_decimal = value if isinstance(value, Decimal) else Decimal(str(value))
    unscaled, _overflow = _decimal_to_unscaled_with_check(as_decimal, precision, scale)
    return struct.pack('<q', unscaled)

@classmethod
def _serialize_timestamp(cls, value: datetime) -> bytes:
    """Serialize a compact timestamp (precision <= 3) as little-endian
    epoch millis; sub-millisecond precision is dropped on this path.

    Only naive datetimes are accepted (treated as UTC).
    """
    if value.tzinfo is not None:
        raise RuntimeError("datetime tzinfo not supported yet")
    epoch_millis, _nanos = _datetime_to_millis_and_nanos(value)
    return struct.pack('<q', epoch_millis)

@classmethod
Expand Down
Loading
Loading