Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -19,20 +19,76 @@

package org.apache.geaflow.ai.index.vector;

import java.util.Objects;

public class MagnitudeVector implements IVector {

private final double magnitude;

public MagnitudeVector() {
this.magnitude = 0.0;
}

public MagnitudeVector(double magnitude) {
this.magnitude = magnitude;
}

public double getMagnitude() {
return magnitude;
}

@Override
public double match(IVector other) {
return 0;
if (!(other instanceof MagnitudeVector)) {
throw new IllegalArgumentException("Other vector must be a MagnitudeVector");
}

MagnitudeVector otherVec = (MagnitudeVector) other;
double otherMagnitude = otherVec.magnitude;

return computeSimilarity(otherMagnitude);

}

private double computeSimilarity(double otherMagnitude) {
if (this.magnitude == 0.0 && otherMagnitude == 0.0) {
return 1.0;
}

double diff = Math.abs(this.magnitude - otherMagnitude);
double max = Math.max(Math.abs(this.magnitude), Math.abs(otherMagnitude));

if (max == 0.0) {
return 1.0;
}

return 1.0 - (diff / max);
}

@Override
public VectorType getType() {
return VectorType.MagnitudeVector;
}

@Override
public boolean equals(Object o) {
if (this == o) {
return true;
}
if (o == null || getClass() != o.getClass()) {
return false;
}
MagnitudeVector that = (MagnitudeVector) o;
return Double.compare(that.magnitude, magnitude) == 0;
}

@Override
public int hashCode() {
return Objects.hash(magnitude);
}

@Override
public String toString() {
return "MagnitudeVector{}";
return "MagnitudeVector{magnitude=" + magnitude + '}';
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,11 @@
import org.apache.geaflow.ai.index.EntityAttributeIndexStore;
import org.apache.geaflow.ai.index.IndexStore;
import org.apache.geaflow.ai.index.vector.EmbeddingVector;
import org.apache.geaflow.ai.index.vector.IVector;
import org.apache.geaflow.ai.index.vector.KeywordVector;
import org.apache.geaflow.ai.index.vector.MagnitudeVector;
import org.apache.geaflow.ai.index.vector.TraversalVector;
import org.apache.geaflow.ai.index.vector.VectorType;
import org.apache.geaflow.ai.search.VectorSearch;
import org.apache.geaflow.ai.verbalization.Context;
import org.apache.geaflow.ai.verbalization.SubgraphSemanticPromptFunction;
Expand All @@ -38,6 +40,9 @@
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertNotEquals;

public class GraphMemoryTest {

private static final Logger LOGGER = LoggerFactory.getLogger(GraphMemoryTest.class);
Expand All @@ -52,6 +57,55 @@ public void testVectorSearch() {
LOGGER.info(String.valueOf(search));
}

// ========== MagnitudeVector Tests ==========

@Test
public void testMagnitudeVectorConstructorAndGetter() {
MagnitudeVector vector = new MagnitudeVector(0.85);
assertEquals(vector.getMagnitude(), 0.85, 0.0001);
}

@Test
public void testMagnitudeVectorMatchExactSameValue() {
MagnitudeVector v1 = new MagnitudeVector(5.0);
MagnitudeVector v2 = new MagnitudeVector(5.0);

assertEquals(v1.match(v2), 1.0, 0.0001);
}

@Test
public void testMagnitudeVectorMatchDifferentValues() {
MagnitudeVector v1 = new MagnitudeVector(10.0);
MagnitudeVector v2 = new MagnitudeVector(5.0);

// Expected: 1 - |10-5|/max(10,5) = 1 - 5/10 = 0.5
assertEquals(v1.match(v2), 0.5, 0.0001);
}
@Test
public void testMagnitudeVectorEqualsAndHashCode() {
MagnitudeVector v1 = new MagnitudeVector(5.0);
MagnitudeVector v2 = new MagnitudeVector(5.0);
MagnitudeVector v3 = new MagnitudeVector(10.0);

assertEquals(v1, v2);
assertEquals(v1.hashCode(), v2.hashCode());
assertNotEquals(v1, v3);
}

@Test
public void testMagnitudeVectorToString() {
MagnitudeVector vector = new MagnitudeVector(0.75);
String str = vector.toString();

assertEquals(str, "MagnitudeVector{magnitude=0.75}");
}

@Test
public void testMagnitudeVectorGetType() {
MagnitudeVector vector = new MagnitudeVector(1.0);
assertEquals(vector.getType(), VectorType.MagnitudeVector);
}

@Test
public void testEmptyMainPipeline() {
GraphMemoryServer server = new GraphMemoryServer();
Expand Down
Loading