Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
143 changes: 115 additions & 28 deletions core/src/main/java/com/google/adk/sessions/State.java
Original file line number Diff line number Diff line change
Expand Up @@ -51,68 +51,135 @@ public State(ConcurrentMap<String, Object> state, ConcurrentMap<String, Object>
@Override
public void clear() {
state.clear();
// Delta should likely be cleared too if we are clearing the state,
// or we might want to mark everything as removed in delta.
// Given the Python implementation doesn't have clear, and this is a local view,
// clearing both seems appropriate to reset the object.
delta.clear();
}

@Override
public boolean containsKey(Object key) {
if (delta.containsKey(key)) {
return delta.get(key) != REMOVED;
}
return state.containsKey(key);
}

@Override
public boolean containsValue(Object value) {
return state.containsValue(value);
// This is expensive but necessary for correctness with the merged view.
return values().contains(value);
}

@Override
public Set<Entry<String, Object>> entrySet() {
return state.entrySet();
// This provides a snapshot, not a live view backed by the map, which differs from standard Map
// contract.
// However, given the complexity of merging two concurrent maps, this is a reasonable compromise
// for this specific implementation.
// TODO: Consider implementing a live view if needed.
Map<String, Object> merged = new ConcurrentHashMap<>(state);
for (Entry<String, Object> entry : delta.entrySet()) {
if (entry.getValue() == REMOVED) {
merged.remove(entry.getKey());
} else {
merged.put(entry.getKey(), entry.getValue());
}
}
return merged.entrySet();
}

@Override
public boolean equals(Object o) {
if (o == this) {
return true;
}
if (!(o instanceof State other)) {
if (!(o instanceof Map)) {
return false;
}
return state.equals(other.state);
Map<?, ?> other = (Map<?, ?>) o;
// We can't easily rely on state.equals() because our "content" is merged.
// Validating equality against another Map requires checking the merged view.
if (size() != other.size()) {
return false;
}
try {
for (Entry<String, Object> e : entrySet()) {
String key = e.getKey();
Object value = e.getValue();
if (value == null) {
if (!(other.get(key) == null && other.containsKey(key))) return false;
} else {
if (!value.equals(other.get(key))) return false;
}
}
} catch (ClassCastException | NullPointerException unused) {
return false;
}
return true;
}

@Override
public Object get(Object key) {
if (delta.containsKey(key)) {
Object value = delta.get(key);
return value == REMOVED ? null : value;
}
return state.get(key);
}

@Override
public int hashCode() {
return state.hashCode();
// Similar to equals, we need to calculate hash code based on the merged entry set.
int h = 0;
for (Entry<String, Object> entry : entrySet()) {
h += entry.hashCode();
}
return h;
}

@Override
public boolean isEmpty() {
return state.isEmpty();
if (delta.isEmpty()) {
return state.isEmpty();
}
// If delta is not empty, we need to check if it effectively removes everything from state
// or adds something.
return size() == 0;
}

@Override
public Set<String> keySet() {
return state.keySet();
// Snapshot view
Map<String, Object> merged = new ConcurrentHashMap<>(state);
for (Entry<String, Object> entry : delta.entrySet()) {
if (entry.getValue() == REMOVED) {
merged.remove(entry.getKey());
} else {
merged.put(entry.getKey(), entry.getValue());
}
}
return merged.keySet();
}

@Override
public Object put(String key, Object value) {
Object oldValue = state.put(key, value);
// Current value logic needs to check delta first to return correct "oldValue"
Object oldValue = get(key);
state.put(key, value);
delta.put(key, value);
return oldValue;
}

@Override
public Object putIfAbsent(String key, Object value) {
Object existingValue = state.putIfAbsent(key, value);
if (existingValue == null) {
delta.put(key, value);
Object currentValue = get(key);
if (currentValue == null) {
put(key, value);
return null;
}
return existingValue;
return currentValue;
}

@Override
Expand All @@ -123,47 +190,67 @@ public void putAll(Map<? extends String, ? extends Object> m) {

@Override
public Object remove(Object key) {
if (state.containsKey(key)) {
Object oldValue = get(key);
// We explicitly check for containment in the merged view to ensure we return the correct old
// value.
if (state.containsKey(key) || (delta.containsKey(key) && delta.get(key) != REMOVED)) {
delta.put((String) key, REMOVED);
}
return state.remove(key);

// We remove from the state map to keep it consistent with the write-through behavior of this
// class.
state.remove(key);
return oldValue;
}

@Override
public boolean remove(Object key, Object value) {
boolean removed = state.remove(key, value);
if (removed) {
delta.put((String) key, REMOVED);
Object currentValue = get(key);
if (Objects.equals(currentValue, value) && (currentValue != null || containsKey(key))) {
remove(key);
return true;
}
return removed;
return false;
}

@Override
public boolean replace(String key, Object oldValue, Object newValue) {
boolean replaced = state.replace(key, oldValue, newValue);
if (replaced) {
delta.put(key, newValue);
Object currentValue = get(key);
if (Objects.equals(currentValue, oldValue) && (currentValue != null || containsKey(key))) {
put(key, newValue);
return true;
}
return replaced;
return false;
}

@Override
public Object replace(String key, Object value) {
Object oldValue = state.replace(key, value);
if (oldValue != null) {
delta.put(key, value);
Object currentValue = get(key);
if (currentValue != null || containsKey(key)) {
put(key, value);
return currentValue;
}
return oldValue;
return null;
}

@Override
public int size() {
return state.size();
// Expensive, but accurate merged size.
return entrySet().size();
}

@Override
public Collection<Object> values() {
return state.values();
// Snapshot view
Map<String, Object> merged = new ConcurrentHashMap<>(state);
for (Entry<String, Object> entry : delta.entrySet()) {
if (entry.getValue() == REMOVED) {
merged.remove(entry.getKey());
} else {
merged.put(entry.getKey(), entry.getValue());
}
}
return merged.values();
}

public boolean hasDelta() {
Expand Down
143 changes: 143 additions & 0 deletions core/src/test/java/com/google/adk/sessions/StateDiffTest.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,143 @@
/*
* Copyright 2025 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package com.google.adk.sessions;

import static com.google.common.truth.Truth.assertThat;

import java.util.HashMap;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;

@RunWith(JUnit4.class)
public final class StateDiffTest {

@Test
public void get_returnsValueFromDeltaIfPresent() {
ConcurrentMap<String, Object> stateMap = new ConcurrentHashMap<>();
stateMap.put("key", "initialValue");
ConcurrentMap<String, Object> deltaMap = new ConcurrentHashMap<>();
deltaMap.put("key", "newValue");
State state = new State(stateMap, deltaMap);

assertThat(state.get("key")).isEqualTo("newValue");
}

@Test
public void get_returnsValueFromStateIfNotInDelta() {
ConcurrentMap<String, Object> stateMap = new ConcurrentHashMap<>();
stateMap.put("key", "initialValue");
State state = new State(stateMap);

assertThat(state.get("key")).isEqualTo("initialValue");
}

@Test
public void get_returnsNullIfKeyInDeltaAsRemoved() {
ConcurrentMap<String, Object> stateMap = new ConcurrentHashMap<>();
stateMap.put("key", "initialValue");
ConcurrentMap<String, Object> deltaMap = new ConcurrentHashMap<>();
deltaMap.put("key", State.REMOVED);
State state = new State(stateMap, deltaMap);

assertThat(state.get("key")).isNull();
}

@Test
public void containsKey_respectsDelta() {
ConcurrentMap<String, Object> stateMap = new ConcurrentHashMap<>();
stateMap.put("key1", "value1");
ConcurrentMap<String, Object> deltaMap = new ConcurrentHashMap<>();
deltaMap.put("key1", State.REMOVED);
deltaMap.put("key2", "value2");
State state = new State(stateMap, deltaMap);

assertThat(state.containsKey("key1")).isFalse();
assertThat(state.containsKey("key2")).isTrue();
}

@Test
public void size_respectsDelta() {
ConcurrentMap<String, Object> stateMap = new ConcurrentHashMap<>();
stateMap.put("key1", "value1");
stateMap.put("key2", "value2");
ConcurrentMap<String, Object> deltaMap = new ConcurrentHashMap<>();
deltaMap.put("key1", State.REMOVED);
deltaMap.put("key3", "value3");
State state = new State(stateMap, deltaMap);

assertThat(state.size()).isEqualTo(2); // key2, key3
}

@Test
public void isEmpty_respectsDelta() {
ConcurrentMap<String, Object> stateMap = new ConcurrentHashMap<>();
stateMap.put("key1", "value1");
ConcurrentMap<String, Object> deltaMap = new ConcurrentHashMap<>();
deltaMap.put("key1", State.REMOVED);
State state = new State(stateMap, deltaMap);

assertThat(state.isEmpty()).isTrue();
}

@Test
public void entrySet_reflectsMergedState() {
ConcurrentMap<String, Object> stateMap = new ConcurrentHashMap<>();
stateMap.put("key1", "value1");
stateMap.put("key2", "value2");
ConcurrentMap<String, Object> deltaMap = new ConcurrentHashMap<>();
deltaMap.put("key1", "newValue1");
deltaMap.put("key3", "value3");
State state = new State(stateMap, deltaMap);

Map<String, Object> expected = new HashMap<>();
expected.put("key1", "newValue1");
expected.put("key2", "value2");
expected.put("key3", "value3");

assertThat(state.entrySet()).containsExactlyElementsIn(expected.entrySet());
}

@Test
public void keySet_reflectsMergedState() {
ConcurrentMap<String, Object> stateMap = new ConcurrentHashMap<>();
stateMap.put("key1", "value1");
stateMap.put("key2", "value2");
ConcurrentMap<String, Object> deltaMap = new ConcurrentHashMap<>();
deltaMap.put("key1", State.REMOVED);
deltaMap.put("key3", "value3");
State state = new State(stateMap, deltaMap);

assertThat(state.keySet()).containsExactly("key2", "key3");
}

@Test
public void values_reflectsMergedState() {
ConcurrentMap<String, Object> stateMap = new ConcurrentHashMap<>();
stateMap.put("key1", "value1");
stateMap.put("key2", "value2");
ConcurrentMap<String, Object> deltaMap = new ConcurrentHashMap<>();
deltaMap.put("key1", "newValue1");
deltaMap.put("key3", "value3");
State state = new State(stateMap, deltaMap);

assertThat(state.values()).containsExactly("newValue1", "value2", "value3");
}
}