Skip to content

Commit 5dcfc68

Browse files
google-genai-botcopybara-github
authored andcommitted
refactor: Simplifying State interfaces
This is really two changes: 1. Replace the interface of state from ConcurrentMap to Map 2. Under the covers use extenral synchornization (Collections.synchronize, etc) along with HashMap wich allows nulls to represent "remove this variable from the session" This devX improvement comes with a subtle assumption that State will be passed in as a HashMap. This change may cause subtle breaking changes. PiperOrigin-RevId: 872418434
1 parent 5262d4a commit 5dcfc68

File tree

16 files changed

+274
-324
lines changed

16 files changed

+274
-324
lines changed

contrib/firestore-session-service/src/main/java/com/google/adk/sessions/FirestoreSessionService.java

Lines changed: 13 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,6 @@
4646
import java.util.Optional;
4747
import java.util.Set;
4848
import java.util.UUID;
49-
import java.util.concurrent.ConcurrentHashMap;
5049
import java.util.concurrent.ConcurrentMap;
5150
import java.util.concurrent.atomic.AtomicBoolean;
5251
import java.util.regex.Matcher;
@@ -85,10 +84,19 @@ private CollectionReference getSessionsCollection(String userId) {
8584
.collection(SESSION_COLLECTION_NAME);
8685
}
8786

87+
@Override
88+
public Single<Session> createSession(
89+
String appName,
90+
String userId,
91+
@Nullable ConcurrentMap<String, Object> state,
92+
@Nullable String sessionId) {
93+
return createSession(appName, userId, (Map<String, Object>) state, sessionId);
94+
}
95+
8896
/** Creates a new session in Firestore. */
8997
@Override
9098
public Single<Session> createSession(
91-
String appName, String userId, ConcurrentMap<String, Object> state, String sessionId) {
99+
String appName, String userId, Map<String, Object> state, String sessionId) {
92100
return Single.fromCallable(
93101
() -> {
94102
Objects.requireNonNull(appName, "appName cannot be null");
@@ -100,21 +108,17 @@ public Single<Session> createSession(
100108
.filter(s -> !s.isEmpty())
101109
.orElseGet(() -> UUID.randomUUID().toString());
102110

103-
ConcurrentMap<String, Object> initialState =
104-
(state == null) ? new ConcurrentHashMap<>() : new ConcurrentHashMap<>(state);
105111
logger.info(
106112
"Creating session for userId: {} with sessionId: {} and initial state: {}",
107113
userId,
108114
resolvedSessionId,
109-
initialState);
110-
List<Event> initialEvents = new ArrayList<>();
115+
state);
111116
Instant now = Instant.now();
112117
Session newSession =
113118
Session.builder(resolvedSessionId)
114119
.appName(appName)
115120
.userId(userId)
116-
.state(initialState)
117-
.events(initialEvents)
121+
.state(state)
118122
.lastUpdateTime(now)
119123
.build();
120124

@@ -200,8 +204,7 @@ public Maybe<Session> getSession(
200204
})
201205
.map(
202206
events -> {
203-
ConcurrentMap<String, Object> state =
204-
new ConcurrentHashMap<>((Map<String, Object>) data.get(STATE_KEY));
207+
Map<String, Object> state = (Map<String, Object>) data.get(STATE_KEY);
205208
return Session.builder((String) data.get(ID_KEY))
206209
.appName((String) data.get(APP_NAME_KEY))
207210
.userId((String) data.get(USER_ID_KEY))
@@ -451,8 +454,6 @@ public Single<ListSessionsResponse> listSessions(String appName, String userId)
451454
.appName((String) data.get(APP_NAME_KEY))
452455
.userId((String) data.get(USER_ID_KEY))
453456
.lastUpdateTime(Instant.parse((String) data.get(UPDATE_TIME_KEY)))
454-
.state(new ConcurrentHashMap<>()) // Empty state
455-
.events(new ArrayList<>()) // Empty events
456457
.build();
457458
sessions.add(session);
458459
}

core/src/main/java/com/google/adk/events/EventActions.java

Lines changed: 49 additions & 69 deletions
Original file line numberDiff line numberDiff line change
@@ -19,14 +19,13 @@
1919
import com.fasterxml.jackson.annotation.JsonProperty;
2020
import com.fasterxml.jackson.databind.annotation.JsonDeserialize;
2121
import com.google.adk.JsonBaseModel;
22-
import com.google.adk.sessions.State;
2322
import com.google.errorprone.annotations.CanIgnoreReturnValue;
24-
import java.util.HashSet;
23+
import java.util.Collections;
24+
import java.util.HashMap;
25+
import java.util.Map;
2526
import java.util.Objects;
2627
import java.util.Optional;
2728
import java.util.Set;
28-
import java.util.concurrent.ConcurrentHashMap;
29-
import java.util.concurrent.ConcurrentMap;
3029
import javax.annotation.Nullable;
3130

3231
/** Represents the actions attached to an event. */
@@ -35,39 +34,37 @@
3534
public class EventActions extends JsonBaseModel {
3635

3736
private Optional<Boolean> skipSummarization;
38-
private ConcurrentMap<String, Object> stateDelta;
39-
private ConcurrentMap<String, Integer> artifactDelta;
40-
private Set<String> deletedArtifactIds;
37+
private Map<String, Object> stateDelta;
38+
private Map<String, Integer> artifactDelta;
4139
private Optional<String> transferToAgent;
4240
private Optional<Boolean> escalate;
43-
private ConcurrentMap<String, ConcurrentMap<String, Object>> requestedAuthConfigs;
44-
private ConcurrentMap<String, ToolConfirmation> requestedToolConfirmations;
41+
private Map<String, Map<String, Object>> requestedAuthConfigs;
42+
private Map<String, ToolConfirmation> requestedToolConfirmations;
4543
private boolean endOfAgent;
4644
private Optional<EventCompaction> compaction;
4745

4846
/** Default constructor for Jackson. */
4947
public EventActions() {
5048
this.skipSummarization = Optional.empty();
51-
this.stateDelta = new ConcurrentHashMap<>();
52-
this.artifactDelta = new ConcurrentHashMap<>();
53-
this.deletedArtifactIds = new HashSet<>();
49+
this.stateDelta = Collections.synchronizedMap(new HashMap<>());
50+
this.artifactDelta = Collections.synchronizedMap(new HashMap<>());
5451
this.transferToAgent = Optional.empty();
5552
this.escalate = Optional.empty();
56-
this.requestedAuthConfigs = new ConcurrentHashMap<>();
57-
this.requestedToolConfirmations = new ConcurrentHashMap<>();
53+
this.requestedAuthConfigs = Collections.synchronizedMap(new HashMap<>());
54+
this.requestedToolConfirmations = Collections.synchronizedMap(new HashMap<>());
5855
this.endOfAgent = false;
5956
this.compaction = Optional.empty();
6057
}
6158

6259
private EventActions(Builder builder) {
6360
this.skipSummarization = builder.skipSummarization;
64-
this.stateDelta = builder.stateDelta;
65-
this.artifactDelta = builder.artifactDelta;
66-
this.deletedArtifactIds = builder.deletedArtifactIds;
61+
this.stateDelta = Collections.synchronizedMap(builder.stateDelta);
62+
this.artifactDelta = Collections.synchronizedMap(builder.artifactDelta);
6763
this.transferToAgent = builder.transferToAgent;
6864
this.escalate = builder.escalate;
69-
this.requestedAuthConfigs = builder.requestedAuthConfigs;
70-
this.requestedToolConfirmations = builder.requestedToolConfirmations;
65+
this.requestedAuthConfigs = Collections.synchronizedMap(builder.requestedAuthConfigs);
66+
this.requestedToolConfirmations =
67+
Collections.synchronizedMap(builder.requestedToolConfirmations);
7168
this.endOfAgent = builder.endOfAgent;
7269
this.compaction = builder.compaction;
7370
}
@@ -90,41 +87,32 @@ public void setSkipSummarization(boolean skipSummarization) {
9087
}
9188

9289
@JsonProperty("stateDelta")
93-
public ConcurrentMap<String, Object> stateDelta() {
90+
public Map<String, Object> stateDelta() {
9491
return stateDelta;
9592
}
9693

97-
@Deprecated // Use stateDelta(), addState() and removeStateByKey() instead.
98-
public void setStateDelta(ConcurrentMap<String, Object> stateDelta) {
99-
this.stateDelta = stateDelta;
94+
public void setStateDelta(Map<String, Object> stateDelta) {
95+
this.stateDelta = Collections.synchronizedMap(new HashMap<>(stateDelta));
10096
}
10197

10298
/**
10399
* Removes a key from the state delta.
104100
*
105101
* @param key The key to remove.
102+
* @deprecated Use {@link #stateDelta()}.put(key, null) instead.
106103
*/
104+
@Deprecated
107105
public void removeStateByKey(String key) {
108-
stateDelta.put(key, State.REMOVED);
106+
stateDelta().put(key, null);
109107
}
110108

111109
@JsonProperty("artifactDelta")
112-
public ConcurrentMap<String, Integer> artifactDelta() {
110+
public Map<String, Integer> artifactDelta() {
113111
return artifactDelta;
114112
}
115113

116-
public void setArtifactDelta(ConcurrentMap<String, Integer> artifactDelta) {
117-
this.artifactDelta = artifactDelta;
118-
}
119-
120-
@JsonProperty("deletedArtifactIds")
121-
@JsonInclude(JsonInclude.Include.NON_EMPTY)
122-
public Set<String> deletedArtifactIds() {
123-
return deletedArtifactIds;
124-
}
125-
126-
public void setDeletedArtifactIds(Set<String> deletedArtifactIds) {
127-
this.deletedArtifactIds = deletedArtifactIds;
114+
public void setArtifactDelta(Map<String, Integer> artifactDelta) {
115+
this.artifactDelta = Collections.synchronizedMap(new HashMap<>(artifactDelta));
128116
}
129117

130118
@JsonProperty("transferToAgent")
@@ -154,23 +142,23 @@ public void setEscalate(boolean escalate) {
154142
}
155143

156144
@JsonProperty("requestedAuthConfigs")
157-
public ConcurrentMap<String, ConcurrentMap<String, Object>> requestedAuthConfigs() {
145+
public Map<String, Map<String, Object>> requestedAuthConfigs() {
158146
return requestedAuthConfigs;
159147
}
160148

161-
public void setRequestedAuthConfigs(
162-
ConcurrentMap<String, ConcurrentMap<String, Object>> requestedAuthConfigs) {
149+
public void setRequestedAuthConfigs(Map<String, Map<String, Object>> requestedAuthConfigs) {
163150
this.requestedAuthConfigs = requestedAuthConfigs;
164151
}
165152

166153
@JsonProperty("requestedToolConfirmations")
167-
public ConcurrentMap<String, ToolConfirmation> requestedToolConfirmations() {
154+
public Map<String, ToolConfirmation> requestedToolConfirmations() {
168155
return requestedToolConfirmations;
169156
}
170157

171158
public void setRequestedToolConfirmations(
172-
ConcurrentMap<String, ToolConfirmation> requestedToolConfirmations) {
173-
this.requestedToolConfirmations = requestedToolConfirmations;
159+
Map<String, ToolConfirmation> requestedToolConfirmations) {
160+
this.requestedToolConfirmations =
161+
Collections.synchronizedMap(new HashMap<>(requestedToolConfirmations));
174162
}
175163

176164
@JsonProperty("endOfAgent")
@@ -235,7 +223,6 @@ public boolean equals(Object o) {
235223
return Objects.equals(skipSummarization, that.skipSummarization)
236224
&& Objects.equals(stateDelta, that.stateDelta)
237225
&& Objects.equals(artifactDelta, that.artifactDelta)
238-
&& Objects.equals(deletedArtifactIds, that.deletedArtifactIds)
239226
&& Objects.equals(transferToAgent, that.transferToAgent)
240227
&& Objects.equals(escalate, that.escalate)
241228
&& Objects.equals(requestedAuthConfigs, that.requestedAuthConfigs)
@@ -250,7 +237,6 @@ public int hashCode() {
250237
skipSummarization,
251238
stateDelta,
252239
artifactDelta,
253-
deletedArtifactIds,
254240
transferToAgent,
255241
escalate,
256242
requestedAuthConfigs,
@@ -262,38 +248,34 @@ public int hashCode() {
262248
/** Builder for {@link EventActions}. */
263249
public static class Builder {
264250
private Optional<Boolean> skipSummarization;
265-
private ConcurrentMap<String, Object> stateDelta;
266-
private ConcurrentMap<String, Integer> artifactDelta;
267-
private Set<String> deletedArtifactIds;
251+
private Map<String, Object> stateDelta;
252+
private Map<String, Integer> artifactDelta;
268253
private Optional<String> transferToAgent;
269254
private Optional<Boolean> escalate;
270-
private ConcurrentMap<String, ConcurrentMap<String, Object>> requestedAuthConfigs;
271-
private ConcurrentMap<String, ToolConfirmation> requestedToolConfirmations;
255+
private Map<String, Map<String, Object>> requestedAuthConfigs;
256+
private Map<String, ToolConfirmation> requestedToolConfirmations;
272257
private boolean endOfAgent = false;
273258
private Optional<EventCompaction> compaction;
274259

275260
public Builder() {
276261
this.skipSummarization = Optional.empty();
277-
this.stateDelta = new ConcurrentHashMap<>();
278-
this.artifactDelta = new ConcurrentHashMap<>();
279-
this.deletedArtifactIds = new HashSet<>();
262+
this.stateDelta = new HashMap<>();
263+
this.artifactDelta = new HashMap<>();
280264
this.transferToAgent = Optional.empty();
281265
this.escalate = Optional.empty();
282-
this.requestedAuthConfigs = new ConcurrentHashMap<>();
283-
this.requestedToolConfirmations = new ConcurrentHashMap<>();
266+
this.requestedAuthConfigs = new HashMap<>();
267+
this.requestedToolConfirmations = new HashMap<>();
284268
this.compaction = Optional.empty();
285269
}
286270

287271
private Builder(EventActions eventActions) {
288272
this.skipSummarization = eventActions.skipSummarization();
289-
this.stateDelta = new ConcurrentHashMap<>(eventActions.stateDelta());
290-
this.artifactDelta = new ConcurrentHashMap<>(eventActions.artifactDelta());
291-
this.deletedArtifactIds = new HashSet<>(eventActions.deletedArtifactIds());
273+
this.stateDelta = new HashMap<>(eventActions.stateDelta());
274+
this.artifactDelta = new HashMap<>(eventActions.artifactDelta());
292275
this.transferToAgent = eventActions.transferToAgent();
293276
this.escalate = eventActions.escalate();
294-
this.requestedAuthConfigs = new ConcurrentHashMap<>(eventActions.requestedAuthConfigs());
295-
this.requestedToolConfirmations =
296-
new ConcurrentHashMap<>(eventActions.requestedToolConfirmations());
277+
this.requestedAuthConfigs = new HashMap<>(eventActions.requestedAuthConfigs());
278+
this.requestedToolConfirmations = new HashMap<>(eventActions.requestedToolConfirmations());
297279
this.endOfAgent = eventActions.endOfAgent();
298280
this.compaction = eventActions.compaction();
299281
}
@@ -307,22 +289,22 @@ public Builder skipSummarization(boolean skipSummarization) {
307289

308290
@CanIgnoreReturnValue
309291
@JsonProperty("stateDelta")
310-
public Builder stateDelta(ConcurrentMap<String, Object> value) {
292+
public Builder stateDelta(Map<String, Object> value) {
311293
this.stateDelta = value;
312294
return this;
313295
}
314296

315297
@CanIgnoreReturnValue
316298
@JsonProperty("artifactDelta")
317-
public Builder artifactDelta(ConcurrentMap<String, Integer> value) {
299+
public Builder artifactDelta(Map<String, Integer> value) {
318300
this.artifactDelta = value;
319301
return this;
320302
}
321303

322304
@CanIgnoreReturnValue
323305
@JsonProperty("deletedArtifactIds")
324306
public Builder deletedArtifactIds(Set<String> value) {
325-
this.deletedArtifactIds = value;
307+
value.forEach(v -> artifactDelta.put(v, null));
326308
return this;
327309
}
328310

@@ -342,16 +324,15 @@ public Builder escalate(boolean escalate) {
342324

343325
@CanIgnoreReturnValue
344326
@JsonProperty("requestedAuthConfigs")
345-
public Builder requestedAuthConfigs(
346-
ConcurrentMap<String, ConcurrentMap<String, Object>> value) {
327+
public Builder requestedAuthConfigs(Map<String, Map<String, Object>> value) {
347328
this.requestedAuthConfigs = value;
348329
return this;
349330
}
350331

351332
@CanIgnoreReturnValue
352333
@JsonProperty("requestedToolConfirmations")
353-
public Builder requestedToolConfirmations(ConcurrentMap<String, ToolConfirmation> value) {
354-
this.requestedToolConfirmations = value;
334+
public Builder requestedToolConfirmations(Map<String, ToolConfirmation> value) {
335+
this.requestedToolConfirmations = Collections.synchronizedMap(new HashMap<>(value));
355336
return this;
356337
}
357338

@@ -385,7 +366,6 @@ public Builder merge(EventActions other) {
385366
other.skipSummarization().ifPresent(this::skipSummarization);
386367
this.stateDelta.putAll(other.stateDelta());
387368
this.artifactDelta.putAll(other.artifactDelta());
388-
this.deletedArtifactIds.addAll(other.deletedArtifactIds());
389369
other.transferToAgent().ifPresent(this::transferToAgent);
390370
other.escalate().ifPresent(this::escalate);
391371
this.requestedAuthConfigs.putAll(other.requestedAuthConfigs());

core/src/main/java/com/google/adk/runner/Runner.java

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,7 @@
5858
import java.util.ArrayList;
5959
import java.util.Arrays;
6060
import java.util.Collections;
61+
import java.util.HashMap;
6162
import java.util.List;
6263
import java.util.Map;
6364
import java.util.Optional;
@@ -337,7 +338,9 @@ private Single<Event> appendNewMessageToSession(
337338
// Add state delta if provided
338339
if (stateDelta != null && !stateDelta.isEmpty()) {
339340
eventBuilder.actions(
340-
EventActions.builder().stateDelta(new ConcurrentHashMap<>(stateDelta)).build());
341+
EventActions.builder()
342+
.stateDelta(stateDelta == null ? new HashMap<>() : new HashMap<>(stateDelta))
343+
.build());
341344
}
342345

343346
return this.sessionService.appendEvent(session, eventBuilder.build());

0 commit comments

Comments
 (0)