diff --git a/geaflow-ai/src/main/java/org/apache/geaflow/ai/index/vector/MagnitudeVector.java b/geaflow-ai/src/main/java/org/apache/geaflow/ai/index/vector/MagnitudeVector.java index dfe6696dd..58d7fb685 100644 --- a/geaflow-ai/src/main/java/org/apache/geaflow/ai/index/vector/MagnitudeVector.java +++ b/geaflow-ai/src/main/java/org/apache/geaflow/ai/index/vector/MagnitudeVector.java @@ -19,11 +19,50 @@ package org.apache.geaflow.ai.index.vector; +import java.util.Objects; + public class MagnitudeVector implements IVector { + private final double magnitude; + + public MagnitudeVector() { + this.magnitude = 0.0; + } + + public MagnitudeVector(double magnitude) { + this.magnitude = magnitude; + } + + public double getMagnitude() { + return magnitude; + } + @Override public double match(IVector other) { - return 0; + if (!(other instanceof MagnitudeVector)) { + throw new IllegalArgumentException("Other vector must be a MagnitudeVector"); + } + + MagnitudeVector otherVec = (MagnitudeVector) other; + double otherMagnitude = otherVec.magnitude; + + return computeSimilarity(otherMagnitude); + + } + + private double computeSimilarity(double otherMagnitude) { + if (this.magnitude == 0.0 && otherMagnitude == 0.0) { + return 1.0; + } + + double diff = Math.abs(this.magnitude - otherMagnitude); + double max = Math.max(Math.abs(this.magnitude), Math.abs(otherMagnitude)); + + if (max == 0.0) { + return 1.0; + } + + return 1.0 - (diff / max); } @Override @@ -31,8 +70,25 @@ public VectorType getType() { return VectorType.MagnitudeVector; } + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + MagnitudeVector that = (MagnitudeVector) o; + return Double.compare(that.magnitude, magnitude) == 0; + } + + @Override + public int hashCode() { + return Objects.hash(magnitude); + } + @Override public String toString() { - return "MagnitudeVector{}"; + return "MagnitudeVector{magnitude=" + magnitude + '}'; } } diff --git a/geaflow-ai/src/test/java/org/apache/geaflow/ai/GraphMemoryTest.java b/geaflow-ai/src/test/java/org/apache/geaflow/ai/GraphMemoryTest.java index 6216cb344..d0af4f435 100644 --- a/geaflow-ai/src/test/java/org/apache/geaflow/ai/GraphMemoryTest.java +++ b/geaflow-ai/src/test/java/org/apache/geaflow/ai/GraphMemoryTest.java @@ -27,9 +27,11 @@ import org.apache.geaflow.ai.index.EntityAttributeIndexStore; import org.apache.geaflow.ai.index.IndexStore; import org.apache.geaflow.ai.index.vector.EmbeddingVector; +import org.apache.geaflow.ai.index.vector.IVector; import org.apache.geaflow.ai.index.vector.KeywordVector; import org.apache.geaflow.ai.index.vector.MagnitudeVector; import org.apache.geaflow.ai.index.vector.TraversalVector; +import org.apache.geaflow.ai.index.vector.VectorType; import org.apache.geaflow.ai.search.VectorSearch; import org.apache.geaflow.ai.verbalization.Context; import org.apache.geaflow.ai.verbalization.SubgraphSemanticPromptFunction; @@ -38,6 +40,9 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNotEquals; + public class GraphMemoryTest { private static final Logger LOGGER = LoggerFactory.getLogger(GraphMemoryTest.class); @@ -52,6 +57,55 @@ public void testVectorSearch() { LOGGER.info(String.valueOf(search)); } + // ========== MagnitudeVector Tests ========== + + @Test + public void testMagnitudeVectorConstructorAndGetter() { + MagnitudeVector vector = new MagnitudeVector(0.85); + assertEquals(vector.getMagnitude(), 0.85, 0.0001); + } + + @Test + public void testMagnitudeVectorMatchExactSameValue() { + MagnitudeVector v1 = new MagnitudeVector(5.0); + MagnitudeVector v2 = new MagnitudeVector(5.0); + + assertEquals(v1.match(v2), 1.0, 0.0001); + } + + @Test + public void testMagnitudeVectorMatchDifferentValues() { + MagnitudeVector v1 = new MagnitudeVector(10.0); + MagnitudeVector v2 = new MagnitudeVector(5.0); + + // Expected: 1 - |10-5|/max(10,5) = 1 - 5/10 = 0.5 + assertEquals(v1.match(v2), 0.5, 0.0001); + } + @Test + public void testMagnitudeVectorEqualsAndHashCode() { + MagnitudeVector v1 = new MagnitudeVector(5.0); + MagnitudeVector v2 = new MagnitudeVector(5.0); + MagnitudeVector v3 = new MagnitudeVector(10.0); + + assertEquals(v1, v2); + assertEquals(v1.hashCode(), v2.hashCode()); + assertNotEquals(v1, v3); + } + + @Test + public void testMagnitudeVectorToString() { + MagnitudeVector vector = new MagnitudeVector(0.75); + String str = vector.toString(); + + assertEquals(str, "MagnitudeVector{magnitude=0.75}"); + } + + @Test + public void testMagnitudeVectorGetType() { + MagnitudeVector vector = new MagnitudeVector(1.0); + assertEquals(vector.getType(), VectorType.MagnitudeVector); + } + @Test public void testEmptyMainPipeline() { GraphMemoryServer server = new GraphMemoryServer();