Skip to content

Commit 75d8e4c

Browse files
committed
Consolidate special case regexp_match logic
1 parent 8cf70ec commit 75d8e4c

File tree

1 file changed

+100
-82
lines changed

1 file changed

+100
-82
lines changed

datafusion/functions/src/regex/regexpreplace.rs

Lines changed: 100 additions & 82 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,8 @@ use datafusion_expr::{
4242
Documentation, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility,
4343
};
4444
use datafusion_macros::user_doc;
45-
use regex::Regex;
45+
use regex::{CaptureLocations, Regex};
46+
use std::borrow::Cow;
4647
use std::collections::HashMap;
4748
use std::sync::{Arc, LazyLock};
4849

@@ -201,6 +202,80 @@ fn regex_replace_posix_groups(replacement: &str) -> String {
201202
.into_owned()
202203
}
203204

205+
struct ShortRegex {
206+
/// Shortened anchored regex used to extract capture group 1 directly.
207+
/// See [`try_build_short_extract_regex`] for details.
208+
short_re: Regex,
209+
/// Reusable capture locations for `short_re` to avoid per-row allocation.
210+
locs: CaptureLocations,
211+
}
212+
213+
/// Holds the normal compiled regex together with the optional fast path used
214+
/// for `regexp_replace(str, '^...(capture)...*$', '\1')`.
215+
struct OptimizedRegex {
216+
/// Full regex used for the normal replacement path and as a correctness fallback.
217+
re: Regex,
218+
/// Precomputed state for the direct-extraction fast path, when applicable.
219+
short_re: Option<ShortRegex>,
220+
}
221+
222+
impl OptimizedRegex {
223+
/// Builds any reusable state needed by the extraction fast path.
224+
///
225+
/// The fast path is only enabled for single replacements where the pattern
226+
/// and replacement satisfy [`try_build_short_extract_regex`].
227+
fn new(re: Regex, limit: usize, pattern: &str, replacement: &str) -> Self {
228+
let short_re = if limit == 1 {
229+
try_build_short_extract_regex(pattern, replacement)
230+
} else {
231+
None
232+
};
233+
234+
let short_re = short_re.map(|short_re| {
235+
let locs = short_re.capture_locations();
236+
ShortRegex { short_re, locs }
237+
});
238+
239+
Self { re, short_re }
240+
}
241+
242+
/// Applies the direct-extraction fast path when it preserves the result of
243+
/// `Regex::replacen`; otherwise falls back to the full regex replacement.
244+
fn replacen<'a>(
245+
&mut self,
246+
val: &'a str,
247+
limit: usize,
248+
replacement: &str,
249+
) -> Cow<'a, str> {
250+
// If this pattern is not eligible for direct extraction, use the full regex.
251+
let Some(ShortRegex { short_re, locs }) = self.short_re.as_mut() else {
252+
return self.re.replacen(val, limit, replacement);
253+
};
254+
255+
// If the shortened regex does not match, the original anchored regex would
256+
// also leave the input unchanged.
257+
if short_re.captures_read(locs, val).is_none() {
258+
return Cow::Borrowed(val);
259+
};
260+
261+
// `captures_read` succeeded, so the overall shortened match is present.
262+
let match_end = locs.get(0).unwrap().1;
263+
if memchr(b'\n', &val.as_bytes()[match_end..]).is_some() {
264+
// If there is a newline after the match, we can't use the short
265+
// regex since it won't match across lines. Fall back to the full
266+
// regex replacement.
267+
return self.re.replacen(val, limit, replacement);
268+
};
269+
// The fast path only applies to `${1}` replacements, so the result is
270+
// either capture group 1 or the empty string if that group did not match.
271+
if let Some((start, end)) = locs.get(1) {
272+
Cow::Borrowed(&val[start..end])
273+
} else {
274+
Cow::Borrowed("")
275+
}
276+
}
277+
}
278+
204279
/// For anchored patterns like `^...(capture)....*$` where the replacement
205280
/// is `\1`, build a shorter regex (stripping trailing `.*$`) and use
206281
/// `captures_read` with `CaptureLocations` for direct extraction — no
@@ -442,7 +517,7 @@ macro_rules! fetch_string_arg {
442517
/// hold a single Regex object for the replace operation. This also speeds
443518
/// up the pre-processing time of the replacement string, since it only
444519
/// needs to processed once.
445-
fn _regexp_replace_static_pattern_replace<T: OffsetSizeTrait>(
520+
fn regexp_replace_static_pattern_replace<T: OffsetSizeTrait>(
446521
args: &[ArrayRef],
447522
) -> Result<ArrayRef> {
448523
let array_size = args[0].len();
@@ -477,13 +552,7 @@ fn _regexp_replace_static_pattern_replace<T: OffsetSizeTrait>(
477552
// with rust ones.
478553
let replacement = regex_replace_posix_groups(replacement);
479554

480-
// For anchored patterns like ^...(capture)....*$, build a shorter
481-
// regex and use captures_read for direct extraction.
482-
let short_re = if limit == 1 {
483-
try_build_short_extract_regex(&pattern, &replacement)
484-
} else {
485-
None
486-
};
555+
let mut opt_re = OptimizedRegex::new(re, limit, &pattern, &replacement);
487556

488557
let string_array_type = args[0].data_type();
489558
match string_array_type {
@@ -501,37 +570,13 @@ fn _regexp_replace_static_pattern_replace<T: OffsetSizeTrait>(
501570
let mut new_offsets = BufferBuilder::<T>::new(string_array.len() + 1);
502571
new_offsets.append(T::zero());
503572

504-
if let Some(ref short_re) = short_re {
505-
let mut locs = short_re.capture_locations();
506-
string_array.iter().for_each(|val| {
507-
if let Some(val) = val {
508-
if short_re.captures_read(&mut locs, val).is_some() {
509-
let match_end = locs.get(0).unwrap().1;
510-
if memchr(b'\n', &val.as_bytes()[match_end..]).is_none() {
511-
if let Some((start, end)) = locs.get(1) {
512-
vals.append_slice(&val.as_bytes()[start..end]);
513-
}
514-
} else {
515-
// Newline in remainder: .*$ wouldn't match without 's' flag
516-
let result =
517-
re.replacen(val, limit, replacement.as_str());
518-
vals.append_slice(result.as_bytes());
519-
}
520-
} else {
521-
vals.append_slice(val.as_bytes());
522-
}
523-
}
524-
new_offsets.append(T::from_usize(vals.len()).unwrap());
525-
});
526-
} else {
527-
string_array.iter().for_each(|val| {
528-
if let Some(val) = val {
529-
let result = re.replacen(val, limit, replacement.as_str());
530-
vals.append_slice(result.as_bytes());
531-
}
532-
new_offsets.append(T::from_usize(vals.len()).unwrap());
533-
});
534-
}
573+
string_array.iter().for_each(|val| {
574+
if let Some(val) = val {
575+
let result = opt_re.replacen(val, limit, replacement.as_str());
576+
vals.append_slice(result.as_bytes());
577+
}
578+
new_offsets.append(T::from_usize(vals.len()).unwrap());
579+
});
535580

536581
let data = ArrayDataBuilder::new(GenericStringArray::<T>::DATA_TYPE)
537582
.len(string_array.len())
@@ -546,39 +591,12 @@ fn _regexp_replace_static_pattern_replace<T: OffsetSizeTrait>(
546591

547592
let mut builder = StringViewBuilder::with_capacity(string_view_array.len());
548593

549-
if let Some(ref short_re) = short_re {
550-
let mut locs = short_re.capture_locations();
551-
for val in string_view_array.iter() {
552-
if let Some(val) = val {
553-
if short_re.captures_read(&mut locs, val).is_some() {
554-
let match_end = locs.get(0).unwrap().1;
555-
if memchr(b'\n', &val.as_bytes()[match_end..]).is_none() {
556-
if let Some((start, end)) = locs.get(1) {
557-
builder.append_value(&val[start..end]);
558-
} else {
559-
builder.append_value("");
560-
}
561-
} else {
562-
// Newline in remainder: .*$ wouldn't match without 's' flag
563-
let result =
564-
re.replacen(val, limit, replacement.as_str());
565-
builder.append_value(result);
566-
}
567-
} else {
568-
builder.append_value(val);
569-
}
570-
} else {
571-
builder.append_null();
572-
}
573-
}
574-
} else {
575-
for val in string_view_array.iter() {
576-
if let Some(val) = val {
577-
let result = re.replacen(val, limit, replacement.as_str());
578-
builder.append_value(result);
579-
} else {
580-
builder.append_null();
581-
}
594+
for val in string_view_array.iter() {
595+
if let Some(val) = val {
596+
let result = opt_re.replacen(val, limit, replacement.as_str());
597+
builder.append_value(result.as_ref());
598+
} else {
599+
builder.append_null();
582600
}
583601
}
584602

@@ -655,7 +673,7 @@ fn specialize_regexp_replace<T: OffsetSizeTrait>(
655673
arg.to_array(expansion_len)
656674
})
657675
.collect::<Result<Vec<_>>>()?;
658-
_regexp_replace_static_pattern_replace::<T>(&args)
676+
regexp_replace_static_pattern_replace::<T>(&args)
659677
}
660678

661679
// If there are no specialized implementations, we'll fall back to the
@@ -789,7 +807,7 @@ mod tests {
789807
let replacements = <$T>::from(replacement);
790808
let expected = <$T>::from(expected);
791809

792-
let re = _regexp_replace_static_pattern_replace::<$O>(&[
810+
let re = regexp_replace_static_pattern_replace::<$O>(&[
793811
Arc::new(values),
794812
Arc::new(patterns),
795813
Arc::new(replacements),
@@ -834,7 +852,7 @@ mod tests {
834852
let flags = StringArray::from(vec!["i"; 5]);
835853
let expected = <$T>::from(expected);
836854

837-
let re = _regexp_replace_static_pattern_replace::<$O>(&[
855+
let re = regexp_replace_static_pattern_replace::<$O>(&[
838856
Arc::new(values),
839857
Arc::new(patterns),
840858
Arc::new(replacements),
@@ -866,7 +884,7 @@ mod tests {
866884
let replacements = StringArray::from(vec!["foo"; 5]);
867885
let expected = StringArray::from(vec![None::<&str>; 5]);
868886

869-
let re = _regexp_replace_static_pattern_replace::<i32>(&[
887+
let re = regexp_replace_static_pattern_replace::<i32>(&[
870888
Arc::new(values),
871889
Arc::new(patterns),
872890
Arc::new(replacements),
@@ -883,7 +901,7 @@ mod tests {
883901
let replacements = StringArray::from(Vec::<Option<&str>>::new());
884902
let expected = StringArray::from(Vec::<Option<&str>>::new());
885903

886-
let re = _regexp_replace_static_pattern_replace::<i32>(&[
904+
let re = regexp_replace_static_pattern_replace::<i32>(&[
887905
Arc::new(values),
888906
Arc::new(patterns),
889907
Arc::new(replacements),
@@ -901,7 +919,7 @@ mod tests {
901919
let flags = StringArray::from(vec![None::<&str>; 5]);
902920
let expected = StringArray::from(vec![None::<&str>; 5]);
903921

904-
let re = _regexp_replace_static_pattern_replace::<i32>(&[
922+
let re = regexp_replace_static_pattern_replace::<i32>(&[
905923
Arc::new(values),
906924
Arc::new(patterns),
907925
Arc::new(replacements),
@@ -920,7 +938,7 @@ mod tests {
920938
let patterns = StringArray::from(vec!["["; 5]);
921939
let replacements = StringArray::from(vec!["foo"; 5]);
922940

923-
let re = _regexp_replace_static_pattern_replace::<i32>(&[
941+
let re = regexp_replace_static_pattern_replace::<i32>(&[
924942
Arc::new(values),
925943
Arc::new(patterns),
926944
Arc::new(replacements),
@@ -957,7 +975,7 @@ mod tests {
957975
Some("c"),
958976
]);
959977

960-
let re = _regexp_replace_static_pattern_replace::<i32>(&[
978+
let re = regexp_replace_static_pattern_replace::<i32>(&[
961979
Arc::new(values),
962980
Arc::new(patterns),
963981
Arc::new(replacements),
@@ -985,7 +1003,7 @@ mod tests {
9851003
let replacements = StringArray::from(vec!["foo"; 1]);
9861004
let expected = StringArray::from(vec![Some("b"), None, Some("foo"), None, None]);
9871005

988-
let re = _regexp_replace_static_pattern_replace::<i32>(&[
1006+
let re = regexp_replace_static_pattern_replace::<i32>(&[
9891007
Arc::new(values),
9901008
Arc::new(patterns),
9911009
Arc::new(replacements),

0 commit comments

Comments
 (0)