diff --git a/core/src/main/java/com/google/adk/sessions/State.java b/core/src/main/java/com/google/adk/sessions/State.java index ec23857d9..f828c3c1b 100644 --- a/core/src/main/java/com/google/adk/sessions/State.java +++ b/core/src/main/java/com/google/adk/sessions/State.java @@ -51,21 +51,43 @@ public State(ConcurrentMap state, ConcurrentMap @Override public void clear() { state.clear(); + // Delta should likely be cleared too if we are clearing the state, + // or we might want to mark everything as removed in delta. + // Given the Python implementation doesn't have clear, and this is a local view, + // clearing both seems appropriate to reset the object. + delta.clear(); } @Override public boolean containsKey(Object key) { + if (delta.containsKey(key)) { + return delta.get(key) != REMOVED; + } return state.containsKey(key); } @Override public boolean containsValue(Object value) { - return state.containsValue(value); + // This is expensive but necessary for correctness with the merged view. + return values().contains(value); } @Override public Set> entrySet() { - return state.entrySet(); + // This provides a snapshot, not a live view backed by the map, which differs from standard Map + // contract. + // However, given the complexity of merging two concurrent maps, this is a reasonable compromise + // for this specific implementation. + // TODO: Consider implementing a live view if needed. + Map merged = new ConcurrentHashMap<>(state); + for (Entry entry : delta.entrySet()) { + if (entry.getValue() == REMOVED) { + merged.remove(entry.getKey()); + } else { + merged.put(entry.getKey(), entry.getValue()); + } + } + return merged.entrySet(); } @Override @@ -73,46 +95,91 @@ public boolean equals(Object o) { if (o == this) { return true; } - if (!(o instanceof State other)) { + if (!(o instanceof Map)) { return false; } - return state.equals(other.state); + Map other = (Map) o; + // We can't easily rely on state.equals() because our "content" is merged. + // Validating equality against another Map requires checking the merged view. + if (size() != other.size()) { + return false; + } + try { + for (Entry e : entrySet()) { + String key = e.getKey(); + Object value = e.getValue(); + if (value == null) { + if (!(other.get(key) == null && other.containsKey(key))) return false; + } else { + if (!value.equals(other.get(key))) return false; + } + } + } catch (ClassCastException | NullPointerException unused) { + return false; + } + return true; } @Override public Object get(Object key) { + if (delta.containsKey(key)) { + Object value = delta.get(key); + return value == REMOVED ? null : value; + } return state.get(key); } @Override public int hashCode() { - return state.hashCode(); + // Similar to equals, we need to calculate hash code based on the merged entry set. + int h = 0; + for (Entry entry : entrySet()) { + h += entry.hashCode(); + } + return h; } @Override public boolean isEmpty() { - return state.isEmpty(); + if (delta.isEmpty()) { + return state.isEmpty(); + } + // If delta is not empty, we need to check if it effectively removes everything from state + // or adds something. + return size() == 0; } @Override public Set keySet() { - return state.keySet(); + // Snapshot view + Map merged = new ConcurrentHashMap<>(state); + for (Entry entry : delta.entrySet()) { + if (entry.getValue() == REMOVED) { + merged.remove(entry.getKey()); + } else { + merged.put(entry.getKey(), entry.getValue()); + } + } + return merged.keySet(); } @Override public Object put(String key, Object value) { - Object oldValue = state.put(key, value); + // Current value logic needs to check delta first to return correct "oldValue" + Object oldValue = get(key); + state.put(key, value); delta.put(key, value); return oldValue; } @Override public Object putIfAbsent(String key, Object value) { - Object existingValue = state.putIfAbsent(key, value); - if (existingValue == null) { - delta.put(key, value); + Object currentValue = get(key); + if (currentValue == null) { + put(key, value); + return null; } - return existingValue; + return currentValue; } @Override @@ -123,47 +190,67 @@ public void putAll(Map m) { @Override public Object remove(Object key) { - if (state.containsKey(key)) { + Object oldValue = get(key); + // We explicitly check for containment in the merged view to ensure we return the correct old + // value. + if (state.containsKey(key) || (delta.containsKey(key) && delta.get(key) != REMOVED)) { delta.put((String) key, REMOVED); } - return state.remove(key); + + // We remove from the state map to keep it consistent with the write-through behavior of this + // class. + state.remove(key); + return oldValue; } @Override public boolean remove(Object key, Object value) { - boolean removed = state.remove(key, value); - if (removed) { - delta.put((String) key, REMOVED); + Object currentValue = get(key); + if (Objects.equals(currentValue, value) && (currentValue != null || containsKey(key))) { + remove(key); + return true; } - return removed; + return false; } @Override public boolean replace(String key, Object oldValue, Object newValue) { - boolean replaced = state.replace(key, oldValue, newValue); - if (replaced) { - delta.put(key, newValue); + Object currentValue = get(key); + if (Objects.equals(currentValue, oldValue) && (currentValue != null || containsKey(key))) { + put(key, newValue); + return true; } - return replaced; + return false; } @Override public Object replace(String key, Object value) { - Object oldValue = state.replace(key, value); - if (oldValue != null) { - delta.put(key, value); + Object currentValue = get(key); + if (currentValue != null || containsKey(key)) { + put(key, value); + return currentValue; } - return oldValue; + return null; } @Override public int size() { - return state.size(); + // Expensive, but accurate merged size. + return entrySet().size(); } @Override public Collection values() { - return state.values(); + // Snapshot view + Map merged = new ConcurrentHashMap<>(state); + for (Entry entry : delta.entrySet()) { + if (entry.getValue() == REMOVED) { + merged.remove(entry.getKey()); + } else { + merged.put(entry.getKey(), entry.getValue()); + } + } + return merged.values(); } public boolean hasDelta() { diff --git a/core/src/test/java/com/google/adk/sessions/StateDiffTest.java b/core/src/test/java/com/google/adk/sessions/StateDiffTest.java new file mode 100644 index 000000000..55d659e89 --- /dev/null +++ b/core/src/test/java/com/google/adk/sessions/StateDiffTest.java @@ -0,0 +1,143 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.adk.sessions; + +import static com.google.common.truth.Truth.assertThat; + +import java.util.HashMap; +import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.ConcurrentMap; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +@RunWith(JUnit4.class) +public final class StateDiffTest { + + @Test + public void get_returnsValueFromDeltaIfPresent() { + ConcurrentMap stateMap = new ConcurrentHashMap<>(); + stateMap.put("key", "initialValue"); + ConcurrentMap deltaMap = new ConcurrentHashMap<>(); + deltaMap.put("key", "newValue"); + State state = new State(stateMap, deltaMap); + + assertThat(state.get("key")).isEqualTo("newValue"); + } + + @Test + public void get_returnsValueFromStateIfNotInDelta() { + ConcurrentMap stateMap = new ConcurrentHashMap<>(); + stateMap.put("key", "initialValue"); + State state = new State(stateMap); + + assertThat(state.get("key")).isEqualTo("initialValue"); + } + + @Test + public void get_returnsNullIfKeyInDeltaAsRemoved() { + ConcurrentMap stateMap = new ConcurrentHashMap<>(); + stateMap.put("key", "initialValue"); + ConcurrentMap deltaMap = new ConcurrentHashMap<>(); + deltaMap.put("key", State.REMOVED); + State state = new State(stateMap, deltaMap); + + assertThat(state.get("key")).isNull(); + } + + @Test + public void containsKey_respectsDelta() { + ConcurrentMap stateMap = new ConcurrentHashMap<>(); + stateMap.put("key1", "value1"); + ConcurrentMap deltaMap = new ConcurrentHashMap<>(); + deltaMap.put("key1", State.REMOVED); + deltaMap.put("key2", "value2"); + State state = new State(stateMap, deltaMap); + + assertThat(state.containsKey("key1")).isFalse(); + assertThat(state.containsKey("key2")).isTrue(); + } + + @Test + public void size_respectsDelta() { + ConcurrentMap stateMap = new ConcurrentHashMap<>(); + stateMap.put("key1", "value1"); + stateMap.put("key2", "value2"); + ConcurrentMap deltaMap = new ConcurrentHashMap<>(); + deltaMap.put("key1", State.REMOVED); + deltaMap.put("key3", "value3"); + State state = new State(stateMap, deltaMap); + + assertThat(state.size()).isEqualTo(2); // key2, key3 + } + + @Test + public void isEmpty_respectsDelta() { + ConcurrentMap stateMap = new ConcurrentHashMap<>(); + stateMap.put("key1", "value1"); + ConcurrentMap deltaMap = new ConcurrentHashMap<>(); + deltaMap.put("key1", State.REMOVED); + State state = new State(stateMap, deltaMap); + + assertThat(state.isEmpty()).isTrue(); + } + + @Test + public void entrySet_reflectsMergedState() { + ConcurrentMap stateMap = new ConcurrentHashMap<>(); + stateMap.put("key1", "value1"); + stateMap.put("key2", "value2"); + ConcurrentMap deltaMap = new ConcurrentHashMap<>(); + deltaMap.put("key1", "newValue1"); + deltaMap.put("key3", "value3"); + State state = new State(stateMap, deltaMap); + + Map expected = new HashMap<>(); + expected.put("key1", "newValue1"); + expected.put("key2", "value2"); + expected.put("key3", "value3"); + + assertThat(state.entrySet()).containsExactlyElementsIn(expected.entrySet()); + } + + @Test + public void keySet_reflectsMergedState() { + ConcurrentMap stateMap = new ConcurrentHashMap<>(); + stateMap.put("key1", "value1"); + stateMap.put("key2", "value2"); + ConcurrentMap deltaMap = new ConcurrentHashMap<>(); + deltaMap.put("key1", State.REMOVED); + deltaMap.put("key3", "value3"); + State state = new State(stateMap, deltaMap); + + assertThat(state.keySet()).containsExactly("key2", "key3"); + } + + @Test + public void values_reflectsMergedState() { + ConcurrentMap stateMap = new ConcurrentHashMap<>(); + stateMap.put("key1", "value1"); + stateMap.put("key2", "value2"); + ConcurrentMap deltaMap = new ConcurrentHashMap<>(); + deltaMap.put("key1", "newValue1"); + deltaMap.put("key3", "value3"); + State state = new State(stateMap, deltaMap); + + assertThat(state.values()).containsExactly("newValue1", "value2", "value3"); + } +}