Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 1 addition & 5 deletions crates/core/src/array.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,6 @@ use arrow::array::{Array, ArrayRef};
use arrow::datatypes::{Field, FieldRef};
use arrow::ffi::{FFI_ArrowArray, FFI_ArrowSchema};
use arrow::pyarrow::ToPyArrow;
use datafusion_python_util::validate_pycapsule;
use pyo3::ffi::c_str;
use pyo3::prelude::{PyAnyMethods, PyCapsuleMethods};
use pyo3::types::PyCapsule;
use pyo3::{Bound, PyAny, PyResult, Python, pyclass, pymethods};
Expand Down Expand Up @@ -53,10 +51,8 @@ impl PyArrowArrayExportable {
requested_schema: Option<Bound<'py, PyCapsule>>,
) -> PyDataFusionResult<(Bound<'py, PyCapsule>, Bound<'py, PyCapsule>)> {
let field = if let Some(schema_capsule) = requested_schema {
validate_pycapsule(&schema_capsule, "arrow_schema")?;

let data: NonNull<FFI_ArrowSchema> = schema_capsule
.pointer_checked(Some(c_str!("arrow_schema")))?
.pointer_checked(Some(c"arrow_schema"))?
.cast();
let schema_ptr = unsafe { data.as_ref() };
let desired_field = Field::try_from(schema_ptr)?;
Expand Down
11 changes: 3 additions & 8 deletions crates/core/src/catalog.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,12 +31,10 @@ use datafusion_ffi::catalog_provider::FFI_CatalogProvider;
use datafusion_ffi::proto::logical_extension_codec::FFI_LogicalExtensionCodec;
use datafusion_ffi::schema_provider::FFI_SchemaProvider;
use datafusion_python_util::{
create_logical_extension_capsule, ffi_logical_codec_from_pycapsule, validate_pycapsule,
wait_for_future,
create_logical_extension_capsule, ffi_logical_codec_from_pycapsule, wait_for_future,
};
use pyo3::IntoPyObjectExt;
use pyo3::exceptions::PyKeyError;
use pyo3::ffi::c_str;
use pyo3::prelude::*;
use pyo3::types::PyCapsule;

Expand Down Expand Up @@ -659,9 +657,8 @@ fn extract_catalog_provider_from_pyobj(
}

let provider = if let Ok(capsule) = catalog_provider.cast::<PyCapsule>() {
validate_pycapsule(capsule, "datafusion_catalog_provider")?;
let data: NonNull<FFI_CatalogProvider> = capsule
.pointer_checked(Some(c_str!("datafusion_catalog_provider")))?
.pointer_checked(Some(c"datafusion_catalog_provider"))?
.cast();
let provider = unsafe { data.as_ref() };
let provider: Arc<dyn CatalogProvider + Send> = provider.into();
Expand Down Expand Up @@ -692,10 +689,8 @@ fn extract_schema_provider_from_pyobj(
}

let provider = if let Ok(capsule) = schema_provider.cast::<PyCapsule>() {
validate_pycapsule(capsule, "datafusion_schema_provider")?;

let data: NonNull<FFI_SchemaProvider> = capsule
.pointer_checked(Some(c_str!("datafusion_schema_provider")))?
.pointer_checked(Some(c"datafusion_schema_provider"))?
.cast();
let provider = unsafe { data.as_ref() };
let provider: Arc<dyn SchemaProvider + Send> = provider.into();
Expand Down
21 changes: 6 additions & 15 deletions crates/core/src/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -55,12 +55,11 @@ use datafusion_ffi::table_provider_factory::FFI_TableProviderFactory;
use datafusion_proto::logical_plan::DefaultLogicalExtensionCodec;
use datafusion_python_util::{
create_logical_extension_capsule, ffi_logical_codec_from_pycapsule, get_global_ctx,
get_tokio_runtime, spawn_future, validate_pycapsule, wait_for_future,
get_tokio_runtime, spawn_future, wait_for_future,
};
use object_store::ObjectStore;
use pyo3::IntoPyObjectExt;
use pyo3::exceptions::{PyKeyError, PyValueError};
use pyo3::ffi::c_str;
use pyo3::prelude::*;
use pyo3::types::{PyCapsule, PyDict, PyList, PyTuple};
use url::Url;
Expand Down Expand Up @@ -675,10 +674,8 @@ impl PySessionContext {

let factory: Arc<dyn TableProviderFactory> =
if let Ok(capsule) = factory.cast::<PyCapsule>().map_err(py_datafusion_err) {
validate_pycapsule(capsule, "datafusion_table_provider_factory")?;

let data: NonNull<FFI_TableProviderFactory> = capsule
.pointer_checked(Some(c_str!("datafusion_table_provider_factory")))?
.pointer_checked(Some(c"datafusion_table_provider_factory"))?
.cast();
let factory = unsafe { data.as_ref() };
factory.into()
Expand Down Expand Up @@ -709,12 +706,9 @@ impl PySessionContext {
.call1((codec_capsule,))?;
}

let provider = if let Ok(capsule) = provider.cast::<PyCapsule>().map_err(py_datafusion_err)
{
validate_pycapsule(capsule, "datafusion_catalog_provider_list")?;

let provider = if let Ok(capsule) = provider.cast::<PyCapsule>() {
let data: NonNull<FFI_CatalogProviderList> = capsule
.pointer_checked(Some(c_str!("datafusion_catalog_provider_list")))?
.pointer_checked(Some(c"datafusion_catalog_provider_list"))?
.cast();
let provider = unsafe { data.as_ref() };
let provider: Arc<dyn CatalogProviderList + Send> = provider.into();
Expand Down Expand Up @@ -747,12 +741,9 @@ impl PySessionContext {
.call1((codec_capsule,))?;
}

let provider = if let Ok(capsule) = provider.cast::<PyCapsule>().map_err(py_datafusion_err)
{
validate_pycapsule(capsule, "datafusion_catalog_provider")?;

let provider = if let Ok(capsule) = provider.cast::<PyCapsule>() {
let data: NonNull<FFI_CatalogProvider> = capsule
.pointer_checked(Some(c_str!("datafusion_catalog_provider")))?
.pointer_checked(Some(c"datafusion_catalog_provider"))?
.cast();
let provider = unsafe { data.as_ref() };
let provider: Arc<dyn CatalogProvider + Send> = provider.into();
Expand Down
7 changes: 2 additions & 5 deletions crates/core/src/dataframe.rs
Original file line number Diff line number Diff line change
Expand Up @@ -41,12 +41,11 @@ use datafusion::logical_expr::SortExpr;
use datafusion::logical_expr::dml::InsertOp;
use datafusion::parquet::basic::{BrotliLevel, Compression, GzipLevel, ZstdLevel};
use datafusion::prelude::*;
use datafusion_python_util::{is_ipython_env, spawn_future, validate_pycapsule, wait_for_future};
use datafusion_python_util::{is_ipython_env, spawn_future, wait_for_future};
use futures::{StreamExt, TryStreamExt};
use parking_lot::Mutex;
use pyo3::PyErr;
use pyo3::exceptions::PyValueError;
use pyo3::ffi::c_str;
use pyo3::prelude::*;
use pyo3::pybacked::PyBackedStr;
use pyo3::types::{PyCapsule, PyList, PyTuple, PyTupleMethods};
Expand Down Expand Up @@ -1117,10 +1116,8 @@ impl PyDataFrame {
let mut projection: Option<SchemaRef> = None;

if let Some(schema_capsule) = requested_schema {
validate_pycapsule(&schema_capsule, "arrow_schema")?;

let data: NonNull<FFI_ArrowSchema> = schema_capsule
.pointer_checked(Some(c_str!("arrow_schema")))?
.pointer_checked(Some(c"arrow_schema"))?
.cast();
let schema_ptr = unsafe { data.as_ref() };
let desired_schema = Schema::try_from(schema_ptr)?;
Expand Down
7 changes: 2 additions & 5 deletions crates/core/src/udaf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,7 @@ use datafusion::logical_expr::{
Accumulator, AccumulatorFactoryFunction, AggregateUDF, AggregateUDFImpl, create_udaf,
};
use datafusion_ffi::udaf::FFI_AggregateUDF;
use datafusion_python_util::{parse_volatility, validate_pycapsule};
use pyo3::ffi::c_str;
use datafusion_python_util::parse_volatility;
use pyo3::prelude::*;
use pyo3::types::{PyCapsule, PyTuple};

Expand Down Expand Up @@ -157,10 +156,8 @@ pub fn to_rust_accumulator(accum: Py<PyAny>) -> AccumulatorFactoryFunction {
}

fn aggregate_udf_from_capsule(capsule: &Bound<'_, PyCapsule>) -> PyDataFusionResult<AggregateUDF> {
validate_pycapsule(capsule, "datafusion_aggregate_udf")?;

let data: NonNull<FFI_AggregateUDF> = capsule
.pointer_checked(Some(c_str!("datafusion_aggregate_udf")))?
.pointer_checked(Some(c"datafusion_aggregate_udf"))?
.cast();
let udaf = unsafe { data.as_ref() };
let udaf: Arc<dyn AggregateUDFImpl> = udaf.into();
Expand Down
11 changes: 4 additions & 7 deletions crates/core/src/udf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,13 +32,12 @@ use datafusion::logical_expr::{
Volatility,
};
use datafusion_ffi::udf::FFI_ScalarUDF;
use datafusion_python_util::{parse_volatility, validate_pycapsule};
use pyo3::ffi::c_str;
use datafusion_python_util::parse_volatility;
use pyo3::prelude::*;
use pyo3::types::{PyCapsule, PyTuple};

use crate::array::PyArrowArrayExportable;
use crate::errors::{PyDataFusionResult, py_datafusion_err, to_datafusion_err};
use crate::errors::{PyDataFusionResult, to_datafusion_err};
use crate::expr::PyExpr;

/// This struct holds the Python written function that is a
Expand Down Expand Up @@ -194,11 +193,9 @@ impl PyScalarUDF {
pub fn from_pycapsule(func: Bound<'_, PyAny>) -> PyDataFusionResult<Self> {
if func.hasattr("__datafusion_scalar_udf__")? {
let capsule = func.getattr("__datafusion_scalar_udf__")?.call0()?;
let capsule = capsule.cast::<PyCapsule>().map_err(py_datafusion_err)?;
validate_pycapsule(capsule, "datafusion_scalar_udf")?;

let capsule = capsule.cast::<PyCapsule>().map_err(to_datafusion_err)?;
let data: NonNull<FFI_ScalarUDF> = capsule
.pointer_checked(Some(c_str!("datafusion_scalar_udf")))?
.pointer_checked(Some(c"datafusion_scalar_udf"))?
.cast();
let udf = unsafe { data.as_ref() };
let udf: Arc<dyn ScalarUDFImpl> = udf.into();
Expand Down
8 changes: 2 additions & 6 deletions crates/core/src/udtf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,8 @@ use datafusion::catalog::{TableFunctionImpl, TableProvider};
use datafusion::error::Result as DataFusionResult;
use datafusion::logical_expr::Expr;
use datafusion_ffi::udtf::FFI_TableFunction;
use datafusion_python_util::validate_pycapsule;
use pyo3::IntoPyObjectExt;
use pyo3::exceptions::{PyImportError, PyTypeError};
use pyo3::ffi::c_str;
use pyo3::prelude::*;
use pyo3::types::{PyCapsule, PyTuple, PyType};

Expand Down Expand Up @@ -73,11 +71,9 @@ impl PyTableFunction {
err
}
})?;
let capsule = capsule.cast::<PyCapsule>().map_err(py_datafusion_err)?;
validate_pycapsule(capsule, "datafusion_table_function")?;

let capsule = capsule.cast::<PyCapsule>()?;
let data: NonNull<FFI_TableFunction> = capsule
.pointer_checked(Some(c_str!("datafusion_table_function")))?
.pointer_checked(Some(c"datafusion_table_function"))?
.cast();
let ffi_func = unsafe { data.as_ref() };
let foreign_func: Arc<dyn TableFunctionImpl> = ffi_func.to_owned().into();
Expand Down
11 changes: 4 additions & 7 deletions crates/core/src/udwf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,14 +32,13 @@ use datafusion::logical_expr::{
};
use datafusion::scalar::ScalarValue;
use datafusion_ffi::udwf::FFI_WindowUDF;
use datafusion_python_util::{parse_volatility, validate_pycapsule};
use datafusion_python_util::parse_volatility;
use pyo3::exceptions::PyValueError;
use pyo3::ffi::c_str;
use pyo3::prelude::*;
use pyo3::types::{PyCapsule, PyList, PyTuple};

use crate::common::data_type::PyScalarValue;
use crate::errors::{PyDataFusionResult, py_datafusion_err, to_datafusion_err};
use crate::errors::{PyDataFusionResult, to_datafusion_err};
use crate::expr::PyExpr;

#[derive(Debug)]
Expand Down Expand Up @@ -262,11 +261,9 @@ impl PyWindowUDF {
func
};

let capsule = capsule.cast::<PyCapsule>().map_err(py_datafusion_err)?;
validate_pycapsule(capsule, "datafusion_window_udf")?;

let capsule = capsule.cast::<PyCapsule>().map_err(to_datafusion_err)?;
let data: NonNull<FFI_WindowUDF> = capsule
.pointer_checked(Some(c_str!("datafusion_window_udf")))?
.pointer_checked(Some(c"datafusion_window_udf"))?
.cast();
let udwf = unsafe { data.as_ref() };
let udwf: Arc<dyn WindowUDFImpl> = udwf.into();
Expand Down
13 changes: 4 additions & 9 deletions crates/util/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,14 +26,13 @@ use datafusion::logical_expr::Volatility;
use datafusion_ffi::proto::logical_extension_codec::FFI_LogicalExtensionCodec;
use datafusion_ffi::table_provider::FFI_TableProvider;
use pyo3::exceptions::{PyImportError, PyTypeError, PyValueError};
use pyo3::ffi::c_str;
use pyo3::prelude::*;
use pyo3::types::{PyCapsule, PyType};
use tokio::runtime::Runtime;
use tokio::task::JoinHandle;
use tokio::time::sleep;

use crate::errors::{PyDataFusionError, PyDataFusionResult, py_datafusion_err, to_datafusion_err};
use crate::errors::{PyDataFusionError, PyDataFusionResult, to_datafusion_err};

pub mod errors;

Expand Down Expand Up @@ -186,11 +185,9 @@ pub fn table_provider_from_pycapsule<'py>(
})?;
}

if let Ok(capsule) = obj.cast::<PyCapsule>().map_err(py_datafusion_err) {
validate_pycapsule(capsule, "datafusion_table_provider")?;

if let Ok(capsule) = obj.cast::<PyCapsule>() {
let data: NonNull<FFI_TableProvider> = capsule
.pointer_checked(Some(c_str!("datafusion_table_provider")))?
.pointer_checked(Some(c"datafusion_table_provider"))?
.cast();
let provider = unsafe { data.as_ref() };
let provider: Arc<dyn TableProvider> = provider.into();
Expand Down Expand Up @@ -220,10 +217,8 @@ pub fn ffi_logical_codec_from_pycapsule(obj: Bound<PyAny>) -> PyResult<FFI_Logic
};

let capsule = capsule.cast::<PyCapsule>()?;
validate_pycapsule(capsule, "datafusion_logical_extension_codec")?;

let data: NonNull<FFI_LogicalExtensionCodec> = capsule
.pointer_checked(Some(c_str!("datafusion_logical_extension_codec")))?
.pointer_checked(Some(c"datafusion_logical_extension_codec"))?
.cast();
let codec = unsafe { data.as_ref() };

Expand Down