55from dataclasses import dataclass
66from typing import Any
77
8- from arrayfire import backend, safe_call # TODO refactoring
9- from arrayfire.array import _in_display_dims_limit # TODO refactoring
8+ from arrayfire import backend, safe_call # TODO refactor
9+ from arrayfire.algorithm import count # TODO refactor
10+ from arrayfire.array import _get_indices, _in_display_dims_limit # TODO refactor
1011
1112from ._dtypes import CShape, Dtype
1213from ._dtypes import bool as af_bool
@@ -37,15 +38,15 @@ class Array:
3738 # arrayfire's __radd__() instead of numpy's __add__()
3839 __array_priority__ = 30
3940
40- # Initialisation
41- arr = ctypes.c_void_p(0)
42-
4341 def __init__(
4442 self, x: None | Array | py_array.array | int | ctypes.c_void_p | list = None, dtype: None | Dtype = None,
4543 pointer_source: PointerSource = PointerSource.host, shape: None | ShapeType = None,
4644 offset: None | ctypes._SimpleCData[int] = None, strides: None | ShapeType = None) -> None:
4745 _no_initial_dtype = False # HACK, FIXME
4846
47+ # Initialise array object
48+ self.arr = ctypes.c_void_p(0)
49+
4950 if isinstance(dtype, str):
5051 dtype = _str_to_dtype(dtype)
5152
@@ -127,7 +128,7 @@ def __str__(self) -> str: # FIXME
127128 if not _in_display_dims_limit(self.shape):
128129 return _metadata_string(self.dtype, self.shape)
129130
130- return _metadata_string(self.dtype) + self._as_str( )
131+ return _metadata_string(self.dtype) + _array_as_str(self )
131132
132133 def __repr__(self) -> str: # FIXME
133134 return _metadata_string(self.dtype, self.shape)
@@ -173,6 +174,7 @@ def __truediv__(self, other: int | float | bool | complex | Array, /) -> Array:
173174 return _process_c_function(self, other, backend.get().af_div)
174175
175176 def __floordiv__(self, other: int | float | bool | complex | Array, /) -> Array:
177+ # TODO
176178 return NotImplemented
177179
178180 def __mod__(self, other: int | float | bool | complex | Array, /) -> Array:
@@ -187,6 +189,25 @@ def __pow__(self, other: int | float | bool | complex | Array, /) -> Array:
187189 """
188190 return _process_c_function(self, other, backend.get().af_pow)
189191
192+ def __matmul__(self, other: Array, /) -> Array:
193+ # TODO
194+ return NotImplemented
195+
196+ def __getitem__(self, key: int | slice | tuple[int | slice] | Array, /) -> Array:
197+ # TODO: API Specification - key: int | slice | ellipsis | tuple[int | slice] | Array
198+ # TODO: refactor
199+ out = Array()
200+ ndims = self.ndim
201+
202+ if isinstance(key, Array) and key == af_bool.c_api_value:
203+ ndims = 1
204+ if count(key) == 0:
205+ return out
206+
207+ safe_call(backend.get().af_index_gen(
208+ ctypes.pointer(out.arr), self.arr, c_dim_t(ndims), _get_indices(key).pointer))
209+ return out
210+
190211 @property
191212 def dtype(self) -> Dtype:
192213 out = ctypes.c_int()
@@ -234,13 +255,23 @@ def shape(self) -> ShapeType:
234255 ctypes.pointer(d0), ctypes.pointer(d1), ctypes.pointer(d2), ctypes.pointer(d3), self.arr))
235256 return (d0.value, d1.value, d2.value, d3.value)[:self.ndim] # Skip passing None values
236257
237- def _as_str(self) -> str:
238- arr_str = ctypes.c_char_p(0)
239- # FIXME add description to passed arguments
240- safe_call(backend.get().af_array_to_string(ctypes.pointer(arr_str), "", self.arr, 4, True))
241- py_str = to_str(arr_str)
242- safe_call(backend.get().af_free_host(arr_str))
243- return py_str
258+ def scalar(self) -> int | float | bool | complex:
259+ """
260+ Return the first element of the array
261+ """
262+ # BUG seg fault on empty array
263+ out = self.dtype.c_type()
264+ safe_call(backend.get().af_get_scalar(ctypes.pointer(out), self.arr))
265+ return out.value # type: ignore[no-any-return] # FIXME
266+
267+
268+ def _array_as_str(array: Array) -> str:
269+ arr_str = ctypes.c_char_p(0)
270+ # FIXME add description to passed arguments
271+ safe_call(backend.get().af_array_to_string(ctypes.pointer(arr_str), "", array.arr, 4, True))
272+ py_str = to_str(arr_str)
273+ safe_call(backend.get().af_free_host(arr_str))
274+ return py_str
244275
245276
246277def _metadata_string(dtype: Dtype, dims: None | ShapeType = None) -> str:
@@ -283,9 +314,8 @@ def _process_c_function(
283314 if isinstance(other, Array):
284315 safe_call(c_function(ctypes.pointer(out.arr), target.arr, other.arr, _bcast_var))
285316 elif is_number(other):
286- target_c_shape = CShape(*target.shape)
287317 other_dtype = _implicit_dtype(other, target.dtype)
288- other_array = _constant_array(other, target_c_shape , other_dtype)
318+ other_array = _constant_array(other, CShape(*target.shape) , other_dtype)
289319 safe_call(c_function(ctypes.pointer(out.arr), target.arr, other_array.arr, _bcast_var))
290320 else:
291321 raise TypeError(f"{type(other)} is not supported and can not be passed to C binary function.")
@@ -326,7 +356,7 @@ def _constant_array(value: int | float | bool | complex, shape: CShape, dtype: D
326356
327357 safe_call(backend.get().af_constant_complex(
328358 ctypes.pointer(out.arr), ctypes.c_double(value.real), ctypes.c_double(value.imag), 4,
329- ctypes.pointer(shape.c_array), dtype))
359+ ctypes.pointer(shape.c_array), dtype.c_api_value ))
330360 elif dtype == af_int64:
331361 safe_call(backend.get().af_constant_long(
332362 ctypes.pointer(out.arr), ctypes.c_longlong(value.real), 4, ctypes.pointer(shape.c_array)))
@@ -335,6 +365,6 @@ def _constant_array(value: int | float | bool | complex, shape: CShape, dtype: D
335365 ctypes.pointer(out.arr), ctypes.c_ulonglong(value.real), 4, ctypes.pointer(shape.c_array)))
336366 else:
337367 safe_call(backend.get().af_constant(
338- ctypes.pointer(out.arr), ctypes.c_double(value), 4, ctypes.pointer(shape.c_array), dtype))
368+ ctypes.pointer(out.arr), ctypes.c_double(value), 4, ctypes.pointer(shape.c_array), dtype.c_api_value ))
339369
340370 return out
0 commit comments