Skip to content

Commit 475c3d3

Browse files
google-genai-botcopybara-github
authored andcommitted
refactor: Simplifying State interfaces
This is really two changes: 1. Replace the interface of state from ConcurrentMap to Map 2. Under the covers use extenral synchornization (Collections.synchronize, etc) along with HashMap wich allows nulls to represent "remove this variable from the session" This devX improvement comes with a subtle assumption that State will be passed in as a HashMap. This change may cause subtle breaking changes. PiperOrigin-RevId: 872418434
1 parent 5262d4a commit 475c3d3

File tree

13 files changed

+280
-330
lines changed

13 files changed

+280
-330
lines changed

contrib/firestore-session-service/src/main/java/com/google/adk/sessions/FirestoreSessionService.java

Lines changed: 13 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,6 @@
4646
import java.util.Optional;
4747
import java.util.Set;
4848
import java.util.UUID;
49-
import java.util.concurrent.ConcurrentHashMap;
5049
import java.util.concurrent.ConcurrentMap;
5150
import java.util.concurrent.atomic.AtomicBoolean;
5251
import java.util.regex.Matcher;
@@ -85,10 +84,19 @@ private CollectionReference getSessionsCollection(String userId) {
8584
.collection(SESSION_COLLECTION_NAME);
8685
}
8786

87+
@Override
88+
public Single<Session> createSession(
89+
String appName,
90+
String userId,
91+
@Nullable ConcurrentMap<String, Object> state,
92+
@Nullable String sessionId) {
93+
return createSession(appName, userId, (Map<String, Object>) state, sessionId);
94+
}
95+
8896
/** Creates a new session in Firestore. */
8997
@Override
9098
public Single<Session> createSession(
91-
String appName, String userId, ConcurrentMap<String, Object> state, String sessionId) {
99+
String appName, String userId, Map<String, Object> state, String sessionId) {
92100
return Single.fromCallable(
93101
() -> {
94102
Objects.requireNonNull(appName, "appName cannot be null");
@@ -100,21 +108,17 @@ public Single<Session> createSession(
100108
.filter(s -> !s.isEmpty())
101109
.orElseGet(() -> UUID.randomUUID().toString());
102110

103-
ConcurrentMap<String, Object> initialState =
104-
(state == null) ? new ConcurrentHashMap<>() : new ConcurrentHashMap<>(state);
105111
logger.info(
106112
"Creating session for userId: {} with sessionId: {} and initial state: {}",
107113
userId,
108114
resolvedSessionId,
109-
initialState);
110-
List<Event> initialEvents = new ArrayList<>();
115+
state);
111116
Instant now = Instant.now();
112117
Session newSession =
113118
Session.builder(resolvedSessionId)
114119
.appName(appName)
115120
.userId(userId)
116-
.state(initialState)
117-
.events(initialEvents)
121+
.state(state)
118122
.lastUpdateTime(now)
119123
.build();
120124

@@ -200,8 +204,7 @@ public Maybe<Session> getSession(
200204
})
201205
.map(
202206
events -> {
203-
ConcurrentMap<String, Object> state =
204-
new ConcurrentHashMap<>((Map<String, Object>) data.get(STATE_KEY));
207+
Map<String, Object> state = (Map<String, Object>) data.get(STATE_KEY);
205208
return Session.builder((String) data.get(ID_KEY))
206209
.appName((String) data.get(APP_NAME_KEY))
207210
.userId((String) data.get(USER_ID_KEY))
@@ -451,8 +454,6 @@ public Single<ListSessionsResponse> listSessions(String appName, String userId)
451454
.appName((String) data.get(APP_NAME_KEY))
452455
.userId((String) data.get(USER_ID_KEY))
453456
.lastUpdateTime(Instant.parse((String) data.get(UPDATE_TIME_KEY)))
454-
.state(new ConcurrentHashMap<>()) // Empty state
455-
.events(new ArrayList<>()) // Empty events
456457
.build();
457458
sessions.add(session);
458459
}

core/src/main/java/com/google/adk/events/EventActions.java

Lines changed: 50 additions & 66 deletions
Original file line numberDiff line numberDiff line change
@@ -19,14 +19,13 @@
1919
import com.fasterxml.jackson.annotation.JsonProperty;
2020
import com.fasterxml.jackson.databind.annotation.JsonDeserialize;
2121
import com.google.adk.JsonBaseModel;
22-
import com.google.adk.sessions.State;
2322
import com.google.errorprone.annotations.CanIgnoreReturnValue;
24-
import java.util.HashSet;
23+
import java.util.Collections;
24+
import java.util.HashMap;
25+
import java.util.Map;
2526
import java.util.Objects;
2627
import java.util.Optional;
2728
import java.util.Set;
28-
import java.util.concurrent.ConcurrentHashMap;
29-
import java.util.concurrent.ConcurrentMap;
3029
import javax.annotation.Nullable;
3130

3231
/** Represents the actions attached to an event. */
@@ -35,35 +34,32 @@
3534
public class EventActions extends JsonBaseModel {
3635

3736
private Optional<Boolean> skipSummarization;
38-
private ConcurrentMap<String, Object> stateDelta;
39-
private ConcurrentMap<String, Integer> artifactDelta;
40-
private Set<String> deletedArtifactIds;
37+
private Map<String, Object> stateDelta;
38+
private Map<String, Integer> artifactDelta;
4139
private Optional<String> transferToAgent;
4240
private Optional<Boolean> escalate;
43-
private ConcurrentMap<String, ConcurrentMap<String, Object>> requestedAuthConfigs;
44-
private ConcurrentMap<String, ToolConfirmation> requestedToolConfirmations;
41+
private Map<String, Map<String, Object>> requestedAuthConfigs;
42+
private Map<String, ToolConfirmation> requestedToolConfirmations;
4543
private boolean endOfAgent;
4644
private Optional<EventCompaction> compaction;
4745

4846
/** Default constructor for Jackson. */
4947
public EventActions() {
5048
this.skipSummarization = Optional.empty();
51-
this.stateDelta = new ConcurrentHashMap<>();
52-
this.artifactDelta = new ConcurrentHashMap<>();
53-
this.deletedArtifactIds = new HashSet<>();
49+
this.stateDelta = Collections.synchronizedMap(new HashMap<>());
50+
this.artifactDelta = Collections.synchronizedMap(new HashMap<>());
5451
this.transferToAgent = Optional.empty();
5552
this.escalate = Optional.empty();
56-
this.requestedAuthConfigs = new ConcurrentHashMap<>();
57-
this.requestedToolConfirmations = new ConcurrentHashMap<>();
53+
this.requestedAuthConfigs = Collections.synchronizedMap(new HashMap<>());
54+
this.requestedToolConfirmations = Collections.synchronizedMap(new HashMap<>());
5855
this.endOfAgent = false;
5956
this.compaction = Optional.empty();
6057
}
6158

6259
private EventActions(Builder builder) {
6360
this.skipSummarization = builder.skipSummarization;
64-
this.stateDelta = builder.stateDelta;
65-
this.artifactDelta = builder.artifactDelta;
66-
this.deletedArtifactIds = builder.deletedArtifactIds;
61+
this.stateDelta = Collections.synchronizedMap(builder.stateDelta);
62+
this.artifactDelta = Collections.synchronizedMap(builder.artifactDelta);
6763
this.transferToAgent = builder.transferToAgent;
6864
this.escalate = builder.escalate;
6965
this.requestedAuthConfigs = builder.requestedAuthConfigs;
@@ -90,41 +86,37 @@ public void setSkipSummarization(boolean skipSummarization) {
9086
}
9187

9288
@JsonProperty("stateDelta")
93-
public ConcurrentMap<String, Object> stateDelta() {
89+
public Map<String, Object> stateDelta() {
9490
return stateDelta;
9591
}
9692

93+
/**
94+
* @deprecated Use {@link #stateDelta()} or {@link Builder#stateDelta(Map<String, Object>)}
95+
* instead.
96+
*/
9797
@Deprecated // Use stateDelta(), addState() and removeStateByKey() instead.
98-
public void setStateDelta(ConcurrentMap<String, Object> stateDelta) {
99-
this.stateDelta = stateDelta;
98+
public void setStateDelta(Map<String, Object> stateDelta) {
99+
this.stateDelta = Collections.synchronizedMap(new HashMap<>(stateDelta));
100100
}
101101

102102
/**
103103
* Removes a key from the state delta.
104104
*
105105
* @param key The key to remove.
106+
* @deprecated Use {@link #stateDelta()}.put(key, null) instead.
106107
*/
108+
@Deprecated
107109
public void removeStateByKey(String key) {
108-
stateDelta.put(key, State.REMOVED);
110+
stateDelta().put(key, null);
109111
}
110112

111113
@JsonProperty("artifactDelta")
112-
public ConcurrentMap<String, Integer> artifactDelta() {
114+
public Map<String, Integer> artifactDelta() {
113115
return artifactDelta;
114116
}
115117

116-
public void setArtifactDelta(ConcurrentMap<String, Integer> artifactDelta) {
117-
this.artifactDelta = artifactDelta;
118-
}
119-
120-
@JsonProperty("deletedArtifactIds")
121-
@JsonInclude(JsonInclude.Include.NON_EMPTY)
122-
public Set<String> deletedArtifactIds() {
123-
return deletedArtifactIds;
124-
}
125-
126-
public void setDeletedArtifactIds(Set<String> deletedArtifactIds) {
127-
this.deletedArtifactIds = deletedArtifactIds;
118+
public void setArtifactDelta(Map<String, Integer> artifactDelta) {
119+
this.artifactDelta = Collections.synchronizedMap(new HashMap<>(artifactDelta));
128120
}
129121

130122
@JsonProperty("transferToAgent")
@@ -154,23 +146,23 @@ public void setEscalate(boolean escalate) {
154146
}
155147

156148
@JsonProperty("requestedAuthConfigs")
157-
public ConcurrentMap<String, ConcurrentMap<String, Object>> requestedAuthConfigs() {
149+
public Map<String, Map<String, Object>> requestedAuthConfigs() {
158150
return requestedAuthConfigs;
159151
}
160152

161-
public void setRequestedAuthConfigs(
162-
ConcurrentMap<String, ConcurrentMap<String, Object>> requestedAuthConfigs) {
153+
public void setRequestedAuthConfigs(Map<String, Map<String, Object>> requestedAuthConfigs) {
163154
this.requestedAuthConfigs = requestedAuthConfigs;
164155
}
165156

166157
@JsonProperty("requestedToolConfirmations")
167-
public ConcurrentMap<String, ToolConfirmation> requestedToolConfirmations() {
158+
public Map<String, ToolConfirmation> requestedToolConfirmations() {
168159
return requestedToolConfirmations;
169160
}
170161

171162
public void setRequestedToolConfirmations(
172-
ConcurrentMap<String, ToolConfirmation> requestedToolConfirmations) {
173-
this.requestedToolConfirmations = requestedToolConfirmations;
163+
Map<String, ToolConfirmation> requestedToolConfirmations) {
164+
this.requestedToolConfirmations =
165+
Collections.synchronizedMap(new HashMap<>(requestedToolConfirmations));
174166
}
175167

176168
@JsonProperty("endOfAgent")
@@ -235,7 +227,6 @@ public boolean equals(Object o) {
235227
return Objects.equals(skipSummarization, that.skipSummarization)
236228
&& Objects.equals(stateDelta, that.stateDelta)
237229
&& Objects.equals(artifactDelta, that.artifactDelta)
238-
&& Objects.equals(deletedArtifactIds, that.deletedArtifactIds)
239230
&& Objects.equals(transferToAgent, that.transferToAgent)
240231
&& Objects.equals(escalate, that.escalate)
241232
&& Objects.equals(requestedAuthConfigs, that.requestedAuthConfigs)
@@ -250,7 +241,6 @@ public int hashCode() {
250241
skipSummarization,
251242
stateDelta,
252243
artifactDelta,
253-
deletedArtifactIds,
254244
transferToAgent,
255245
escalate,
256246
requestedAuthConfigs,
@@ -262,38 +252,34 @@ public int hashCode() {
262252
/** Builder for {@link EventActions}. */
263253
public static class Builder {
264254
private Optional<Boolean> skipSummarization;
265-
private ConcurrentMap<String, Object> stateDelta;
266-
private ConcurrentMap<String, Integer> artifactDelta;
267-
private Set<String> deletedArtifactIds;
255+
private Map<String, Object> stateDelta;
256+
private Map<String, Integer> artifactDelta;
268257
private Optional<String> transferToAgent;
269258
private Optional<Boolean> escalate;
270-
private ConcurrentMap<String, ConcurrentMap<String, Object>> requestedAuthConfigs;
271-
private ConcurrentMap<String, ToolConfirmation> requestedToolConfirmations;
259+
private Map<String, Map<String, Object>> requestedAuthConfigs;
260+
private Map<String, ToolConfirmation> requestedToolConfirmations;
272261
private boolean endOfAgent = false;
273262
private Optional<EventCompaction> compaction;
274263

275264
public Builder() {
276265
this.skipSummarization = Optional.empty();
277-
this.stateDelta = new ConcurrentHashMap<>();
278-
this.artifactDelta = new ConcurrentHashMap<>();
279-
this.deletedArtifactIds = new HashSet<>();
266+
this.stateDelta = new HashMap<>();
267+
this.artifactDelta = new HashMap<>();
280268
this.transferToAgent = Optional.empty();
281269
this.escalate = Optional.empty();
282-
this.requestedAuthConfigs = new ConcurrentHashMap<>();
283-
this.requestedToolConfirmations = new ConcurrentHashMap<>();
270+
this.requestedAuthConfigs = new HashMap<>();
271+
this.requestedToolConfirmations = new HashMap<>();
284272
this.compaction = Optional.empty();
285273
}
286274

287275
private Builder(EventActions eventActions) {
288276
this.skipSummarization = eventActions.skipSummarization();
289-
this.stateDelta = new ConcurrentHashMap<>(eventActions.stateDelta());
290-
this.artifactDelta = new ConcurrentHashMap<>(eventActions.artifactDelta());
291-
this.deletedArtifactIds = new HashSet<>(eventActions.deletedArtifactIds());
277+
this.stateDelta = new HashMap<>(eventActions.stateDelta());
278+
this.artifactDelta = new HashMap<>(eventActions.artifactDelta());
292279
this.transferToAgent = eventActions.transferToAgent();
293280
this.escalate = eventActions.escalate();
294-
this.requestedAuthConfigs = new ConcurrentHashMap<>(eventActions.requestedAuthConfigs());
295-
this.requestedToolConfirmations =
296-
new ConcurrentHashMap<>(eventActions.requestedToolConfirmations());
281+
this.requestedAuthConfigs = new HashMap<>(eventActions.requestedAuthConfigs());
282+
this.requestedToolConfirmations = new HashMap<>(eventActions.requestedToolConfirmations());
297283
this.endOfAgent = eventActions.endOfAgent();
298284
this.compaction = eventActions.compaction();
299285
}
@@ -307,22 +293,22 @@ public Builder skipSummarization(boolean skipSummarization) {
307293

308294
@CanIgnoreReturnValue
309295
@JsonProperty("stateDelta")
310-
public Builder stateDelta(ConcurrentMap<String, Object> value) {
296+
public Builder stateDelta(Map<String, Object> value) {
311297
this.stateDelta = value;
312298
return this;
313299
}
314300

315301
@CanIgnoreReturnValue
316302
@JsonProperty("artifactDelta")
317-
public Builder artifactDelta(ConcurrentMap<String, Integer> value) {
303+
public Builder artifactDelta(Map<String, Integer> value) {
318304
this.artifactDelta = value;
319305
return this;
320306
}
321307

322308
@CanIgnoreReturnValue
323309
@JsonProperty("deletedArtifactIds")
324310
public Builder deletedArtifactIds(Set<String> value) {
325-
this.deletedArtifactIds = value;
311+
value.forEach(v -> artifactDelta.put(v, null));
326312
return this;
327313
}
328314

@@ -342,16 +328,15 @@ public Builder escalate(boolean escalate) {
342328

343329
@CanIgnoreReturnValue
344330
@JsonProperty("requestedAuthConfigs")
345-
public Builder requestedAuthConfigs(
346-
ConcurrentMap<String, ConcurrentMap<String, Object>> value) {
331+
public Builder requestedAuthConfigs(Map<String, Map<String, Object>> value) {
347332
this.requestedAuthConfigs = value;
348333
return this;
349334
}
350335

351336
@CanIgnoreReturnValue
352337
@JsonProperty("requestedToolConfirmations")
353-
public Builder requestedToolConfirmations(ConcurrentMap<String, ToolConfirmation> value) {
354-
this.requestedToolConfirmations = value;
338+
public Builder requestedToolConfirmations(Map<String, ToolConfirmation> value) {
339+
this.requestedToolConfirmations = Collections.synchronizedMap(new HashMap<>(value));
355340
return this;
356341
}
357342

@@ -385,7 +370,6 @@ public Builder merge(EventActions other) {
385370
other.skipSummarization().ifPresent(this::skipSummarization);
386371
this.stateDelta.putAll(other.stateDelta());
387372
this.artifactDelta.putAll(other.artifactDelta());
388-
this.deletedArtifactIds.addAll(other.deletedArtifactIds());
389373
other.transferToAgent().ifPresent(this::transferToAgent);
390374
other.escalate().ifPresent(this::escalate);
391375
this.requestedAuthConfigs.putAll(other.requestedAuthConfigs());

0 commit comments

Comments
 (0)