Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,170 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.sql.common.cluster;

import java.util.HashMap;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import org.apache.commons.text.similarity.CosineSimilarity;

/**
 * Greedy single-pass text similarity clustering for grouping similar text values. Events are
 * processed in order; each is compared to existing cluster representatives using cosine similarity.
 * If the best match meets the threshold, the event joins that cluster; otherwise a new cluster is
 * created.
 *
 * <p>Optimized for incremental processing with vector caching and memory-efficient operations.
 */
public class TextSimilarityClustering {

  private static final CosineSimilarity COSINE = new CosineSimilarity();

  /** Matches tokens consisting solely of digits; such tokens are collapsed to "*". */
  private static final java.util.regex.Pattern NUMERIC_PATTERN =
      java.util.regex.Pattern.compile("^\\d+$");

  /** Cache of vectorized representations keyed by raw text, to avoid recomputation. */
  private final Map<String, Map<CharSequence, Integer>> vectorCache = new ConcurrentHashMap<>();

  /** Upper bound on cached entries; the cache is cleared once this bound is reached. */
  private static final int MAX_CACHE_SIZE = 10000;

  private final double threshold;
  private final String matchMode;
  private final String delims;

  /**
   * @param threshold similarity threshold; must be strictly between 0.0 and 1.0
   * @param matchMode one of "termlist", "termset", "ngramset" (case-insensitive; null defaults to
   *     "termlist")
   * @param delims tokenization delimiters, or the sentinel "non-alphanumeric"; null defaults to a
   *     single space
   * @throws IllegalArgumentException if threshold is out of range or matchMode is unrecognized
   */
  public TextSimilarityClustering(double threshold, String matchMode, String delims) {
    this.threshold = validateThreshold(threshold);
    this.matchMode = validateMatchMode(matchMode);
    this.delims = delims != null ? delims : " ";
  }

  private static double validateThreshold(double threshold) {
    if (threshold <= 0.0 || threshold >= 1.0) {
      throw new IllegalArgumentException(
          "The threshold must be > 0.0 and < 1.0, got: " + threshold);
    }
    return threshold;
  }

  private static String validateMatchMode(String matchMode) {
    if (matchMode == null) {
      return "termlist";
    }
    // Locale.ROOT keeps mode matching stable regardless of the JVM default locale
    // (e.g. the Turkish dotted/dotless-i casing rules).
    String normalized = matchMode.toLowerCase(java.util.Locale.ROOT);
    switch (normalized) {
      case "termlist":
      case "termset":
      case "ngramset":
        return normalized;
      default:
        throw new IllegalArgumentException(
            "Invalid match mode: " + matchMode + ". Must be one of: termlist, termset, ngramset");
    }
  }

  /**
   * Compute similarity between two text values using the configured match mode. Used for
   * incremental clustering against cluster representatives.
   *
   * @return a value in [0.0, 1.0]; null inputs are treated as empty strings, two empty strings
   *     score 1.0, and an empty vs. non-empty pair scores 0.0
   */
  public double computeSimilarity(String text1, String text2) {
    // Normalize nulls to empty strings
    String normalizedText1 = (text1 == null) ? "" : text1;
    String normalizedText2 = (text2 == null) ? "" : text2;

    // Both are empty - perfect match
    if (normalizedText1.isEmpty() && normalizedText2.isEmpty()) {
      return 1.0;
    }

    // One is empty, other isn't - no match
    if (normalizedText1.isEmpty() || normalizedText2.isEmpty()) {
      return 0.0;
    }

    // Both non-empty - compute cosine similarity
    Map<CharSequence, Integer> vector1 = vectorizeWithCache(normalizedText1);
    Map<CharSequence, Integer> vector2 = vectorizeWithCache(normalizedText2);

    return COSINE.cosineSimilarity(vector1, vector2);
  }

  private Map<CharSequence, Integer> vectorizeWithCache(String value) {
    // Evict BEFORE computeIfAbsent: the ConcurrentHashMap contract forbids the mapping
    // function from modifying the map itself — the previous in-lambda eviction could throw
    // IllegalStateException or livelock. A full clear on overflow is simple and safe; the
    // cache repopulates on demand.
    if (vectorCache.size() >= MAX_CACHE_SIZE) {
      vectorCache.clear();
    }
    return vectorCache.computeIfAbsent(value, this::vectorize);
  }

  /** Dispatch to the vectorizer for the configured match mode. */
  private Map<CharSequence, Integer> vectorize(String value) {
    if (value == null || value.isEmpty()) {
      return Map.of();
    }
    return switch (matchMode) {
      case "termset" -> vectorizeTermSet(value);
      case "ngramset" -> vectorizeNgramSet(value);
      default -> vectorizeTermList(value);
    };
  }

  /** Collapse purely numeric tokens to "*" so differing numbers still match. */
  private static String normalizeToken(String token) {
    return NUMERIC_PATTERN.matcher(token).matches() ? "*" : token;
  }

  /** Positional term frequency — token order matters (keys are "index-token"). */
  private Map<CharSequence, Integer> vectorizeTermList(String value) {
    String[] tokens = tokenize(value);
    // Presize for the expected entry count at the default 0.75 load factor.
    Map<CharSequence, Integer> vector = new HashMap<>((int) (tokens.length * 1.4));

    for (int i = 0; i < tokens.length; i++) {
      if (!tokens[i].isEmpty()) {
        String key = i + "-" + normalizeToken(tokens[i]);
        vector.merge(key, 1, Integer::sum);
      }
    }
    return vector;
  }

  /** Bag-of-words term frequency — token order ignored. */
  private Map<CharSequence, Integer> vectorizeTermSet(String value) {
    String[] tokens = tokenize(value);
    Map<CharSequence, Integer> vector = new HashMap<>((int) (tokens.length * 1.4));

    for (String token : tokens) {
      if (!token.isEmpty()) {
        vector.merge(normalizeToken(token), 1, Integer::sum);
      }
    }
    return vector;
  }

  /** Character trigram frequency; strings shorter than 3 fall back to per-character counts. */
  private Map<CharSequence, Integer> vectorizeNgramSet(String value) {
    if (value.length() < 3) {
      // For very short strings, fall back to character frequency
      Map<CharSequence, Integer> vector = new HashMap<>();
      for (char c : value.toCharArray()) {
        vector.merge(String.valueOf(c), 1, Integer::sum);
      }
      return vector;
    }

    Map<CharSequence, Integer> vector = new HashMap<>((int) ((value.length() - 2) * 1.4));
    for (int i = 0; i <= value.length() - 3; i++) {
      String ngram = value.substring(i, i + 3);
      vector.merge(ngram, 1, Integer::sum);
    }
    return vector;
  }

  /**
   * Split on the configured delimiters. The sentinel value "non-alphanumeric" splits on any run
   * of non-word characters; otherwise each character of {@code delims} is a delimiter
   * (Pattern.quote'd inside a character class, so regex metacharacters are literal).
   */
  private String[] tokenize(String value) {
    if ("non-alphanumeric".equals(delims)) {
      return value.split("[^a-zA-Z0-9_]+");
    }
    String pattern = "[" + java.util.regex.Pattern.quote(delims) + "]+";
    return value.split(pattern);
  }
}
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@
import org.opensearch.sql.ast.tree.Bin;
import org.opensearch.sql.ast.tree.Chart;
import org.opensearch.sql.ast.tree.CloseCursor;
import org.opensearch.sql.ast.tree.Cluster;
import org.opensearch.sql.ast.tree.Convert;
import org.opensearch.sql.ast.tree.Dedupe;
import org.opensearch.sql.ast.tree.Eval;
Expand Down Expand Up @@ -432,6 +433,10 @@ public T visitPatterns(Patterns patterns, C context) {
return visitChildren(patterns, context);
}

/**
 * Visit a {@link Cluster} AST node (PPL cluster command). Default behavior delegates to
 * {@link #visitChildren}; concrete visitors override this to plan or analyze the node.
 */
public T visitCluster(Cluster node, C context) {
return visitChildren(node, context);
}

/**
 * Visit a {@link Window} AST node. Default behavior delegates to {@link #visitChildren};
 * concrete visitors override this to handle window semantics.
 */
public T visitWindow(Window window, C context) {
return visitChildren(window, context);
}
Expand Down
53 changes: 53 additions & 0 deletions core/src/main/java/org/opensearch/sql/ast/tree/Cluster.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.sql.ast.tree;

import com.google.common.collect.ImmutableList;
import java.util.List;
import lombok.AllArgsConstructor;
import lombok.EqualsAndHashCode;
import lombok.Getter;
import lombok.RequiredArgsConstructor;
import lombok.Setter;
import lombok.ToString;
import org.opensearch.sql.ast.AbstractNodeVisitor;
import org.opensearch.sql.ast.expression.UnresolvedExpression;

/**
 * Unresolved AST node representing the PPL {@code cluster} command.
 *
 * <p>Carries the clustering configuration (source field, similarity threshold, match mode,
 * delimiters) together with output options (label/count field names, label-only and show-count
 * flags). The single child is the upstream plan attached via {@link #attach}.
 */
@Getter
@Setter
@ToString
@EqualsAndHashCode(callSuper = false)
@RequiredArgsConstructor
@AllArgsConstructor
public class Cluster extends UnresolvedPlan {

private final UnresolvedExpression sourceField;
private final double threshold;
private final String matchMode;
private final String labelField;
private final String countField;
private final boolean labelOnly;
private final boolean showCount;
private final String delims;
private UnresolvedPlan child;

@Override
public Cluster attach(UnresolvedPlan child) {
// Fluent attachment: remember the upstream plan and return this node.
this.child = child;
return this;
}

@Override
public List<UnresolvedPlan> getChild() {
// No child attached yet: report an empty child list rather than null.
if (this.child == null) {
return ImmutableList.of();
}
return ImmutableList.of(this.child);
}

@Override
public <T, C> T accept(AbstractNodeVisitor<T, C> nodeVisitor, C context) {
return nodeVisitor.visitCluster(this, context);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -2478,6 +2478,93 @@ public RelNode visitKmeans(Kmeans node, CalcitePlanContext context) {
throw new CalciteUnsupportedException("Kmeans command is unsupported in Calcite");
}

/**
 * Plans the PPL {@code cluster} command over the child plan using window functions: an internal
 * clustering UDF window buffers all rows and returns an array of per-row cluster labels, a
 * ROW_NUMBER window indexes back into that array, and the temporary columns are then dropped.
 * Optionally appends a per-cluster count and (when not label-only) filters down to one
 * representative row per cluster.
 */
@Override
public RelNode visitCluster(
org.opensearch.sql.ast.tree.Cluster node, CalcitePlanContext context) {
visitChildren(node, context);

// Resolve clustering as a window function over all rows (unbounded frame).
// The window function buffers all rows, runs the greedy clustering algorithm,
// and returns an array of cluster labels (one per input row, in order).
List<UnresolvedExpression> funcParams = new ArrayList<>();
funcParams.add(node.getSourceField());
funcParams.add(AstDSL.doubleLiteral(node.getThreshold()));
funcParams.add(AstDSL.stringLiteral(node.getMatchMode()));
funcParams.add(AstDSL.stringLiteral(node.getDelims()));

RexNode clusterWindow =
rexVisitor.analyze(
new WindowFunction(
new Function(
BuiltinFunctionName.INTERNAL_CLUSTER_LABEL.getName().getFunctionName(),
funcParams),
List.of(),
List.of()),
context);
String arrayAlias = "_cluster_labels_array";
context.relBuilder.projectPlus(context.relBuilder.alias(clusterWindow, arrayAlias));

// Add ROW_NUMBER to index into the array (1-based).
// NOTE(review): no ORDER BY here — row numbering follows engine row order; this must match
// the order in which the clustering UDF buffered the rows. Confirm the engine guarantees a
// stable order between the two windows.
String rowNumAlias = "_cluster_row_idx";
RexNode rowNum =
context
.relBuilder
.aggregateCall(SqlStdOperatorTable.ROW_NUMBER)
.over()
.rowsBetween(RexWindowBounds.UNBOUNDED_PRECEDING, RexWindowBounds.CURRENT_ROW)
.as(rowNumAlias);
context.relBuilder.projectPlus(rowNum);

// Extract the label for this row: array[row_number] (ITEM access is 1-based).
// ROW_NUMBER yields BIGINT; cast to INTEGER for the ITEM index argument.
RexNode rowIdxAsInt =
context.rexBuilder.makeCast(
context.rexBuilder.getTypeFactory().createSqlType(SqlTypeName.INTEGER),
context.relBuilder.field(rowNumAlias));
RexNode labelExpr =
context.rexBuilder.makeCall(
SqlStdOperatorTable.ITEM, context.relBuilder.field(arrayAlias), rowIdxAsInt);
context.relBuilder.projectPlus(context.relBuilder.alias(labelExpr, node.getLabelField()));

// Remove the temporary array and row index columns.
context.relBuilder.projectExcept(
context.relBuilder.field(arrayAlias), context.relBuilder.field(rowNumAlias));

if (node.isShowCount()) {
// cluster_count = COUNT(*) OVER (PARTITION BY cluster_label)
RexNode countWindow =
context
.relBuilder
.aggregateCall(SqlStdOperatorTable.COUNT)
.over()
.partitionBy(context.relBuilder.field(node.getLabelField()))
.rowsBetween(RexWindowBounds.UNBOUNDED_PRECEDING, RexWindowBounds.UNBOUNDED_FOLLOWING)
.as(node.getCountField());
context.relBuilder.projectPlus(countWindow);
}

if (!node.isLabelOnly()) {
// Filter to representative rows only: keep the first event per cluster.
// NOTE(review): ROW_NUMBER partitioned by label with no ORDER BY — "first" is
// engine-order dependent, so the chosen representative may be nondeterministic across
// runs; confirm this is acceptable.
String convergenceRowNum = "_cluster_convergence_row_num";
RexNode convergenceRn =
context
.relBuilder
.aggregateCall(SqlStdOperatorTable.ROW_NUMBER)
.over()
.partitionBy(context.relBuilder.field(node.getLabelField()))
.rowsTo(RexWindowBounds.CURRENT_ROW)
.as(convergenceRowNum);
context.relBuilder.projectPlus(convergenceRn);
context.relBuilder.filter(
context.rexBuilder.makeCall(
SqlStdOperatorTable.EQUALS,
context.relBuilder.field(convergenceRowNum),
context.rexBuilder.makeExactLiteral(java.math.BigDecimal.ONE)));
context.relBuilder.projectExcept(context.relBuilder.field(convergenceRowNum));
}

return context.relBuilder.peek();
}

@Override
public RelNode visitRareTopN(RareTopN node, CalcitePlanContext context) {
visitChildren(node, context);
Expand Down
Loading
Loading