diff --git a/ai-logic/firebase-ai/src/main/kotlin/com/google/firebase/ai/LiveGenerativeModel.kt b/ai-logic/firebase-ai/src/main/kotlin/com/google/firebase/ai/LiveGenerativeModel.kt index 25088a39cd3..fc0c3adb1fe 100644 --- a/ai-logic/firebase-ai/src/main/kotlin/com/google/firebase/ai/LiveGenerativeModel.kt +++ b/ai-logic/firebase-ai/src/main/kotlin/com/google/firebase/ai/LiveGenerativeModel.kt @@ -105,35 +105,43 @@ internal constructor( * connection with the server. */ @OptIn(ExperimentalSerializationApi::class) - public suspend fun connect(): LiveSession { - val clientMessage = - LiveClientSetupMessage( - modelName, - config?.toInternal(), - tools.map { it.toInternal() }.takeIf { it.isNotEmpty() }, - systemInstruction?.toInternal(), - config?.inputAudioTranscription?.toInternal(), - config?.outputAudioTranscription?.toInternal() - ) - .toInternal() - val data: String = Json.encodeToString(clientMessage) - var webSession: DefaultClientWebSocketSession? = null - try { - webSession = controller.getWebSocketSession(location) + public suspend fun connect(sessionResumption: SessionResumptionConfig? = null): LiveSession { + val connectFactory: suspend (SessionResumptionConfig?) 
-> DefaultClientWebSocketSession = { newResumption -> + val clientMessage = + LiveClientSetupMessage( + modelName, + config?.toInternal(), + tools.map { it.toInternal() }.takeIf { it.isNotEmpty() }, + systemInstruction?.toInternal(), + config?.inputAudioTranscription?.toInternal(), + config?.outputAudioTranscription?.toInternal(), + newResumption?.toInternal(), + config?.contextWindowCompression?.toInternal() + ) + .toInternal() + val data: String = Json.encodeToString(clientMessage) + val webSession = controller.getWebSocketSession(location) webSession.send(Frame.Text(data)) val receivedJsonStr = webSession.incoming.receive().readBytes().toString(Charsets.UTF_8) val receivedJson = JSON.parseToJsonElement(receivedJsonStr) - return if (receivedJson is JsonObject && "setupComplete" in receivedJson) { - LiveSession( - session = webSession, - blockingDispatcher = blockingDispatcher, - firebaseApp = firebaseApp - ) + if (receivedJson is JsonObject && "setupComplete" in receivedJson) { + webSession } else { webSession.close() throw ServiceConnectionHandshakeFailedException("Unable to connect to the server") } + } + + var webSession: DefaultClientWebSocketSession? 
= null
+    try {
+      webSession = connectFactory(sessionResumption)
+      return LiveSession(
+        session = webSession,
+        blockingDispatcher = blockingDispatcher,
+        firebaseApp = firebaseApp,
+        connectionFactory = connectFactory
+      )
     } catch (e: ClosedReceiveChannelException) {
       val reason = webSession?.closeReason?.await()
       val message =
diff --git a/ai-logic/firebase-ai/src/main/kotlin/com/google/firebase/ai/type/ContextWindowCompressionConfig.kt b/ai-logic/firebase-ai/src/main/kotlin/com/google/firebase/ai/type/ContextWindowCompressionConfig.kt
new file mode 100644
index 00000000000..79632f245ee
--- /dev/null
+++ b/ai-logic/firebase-ai/src/main/kotlin/com/google/firebase/ai/type/ContextWindowCompressionConfig.kt
@@ -0,0 +1,58 @@
+/*
+ * Copyright 2026 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.google.firebase.ai.type
+
+import kotlinx.serialization.Serializable
+
+/**
+ * Sliding-window strategy for context window compression.
+ *
+ * When applied, the context window is truncated so that only a suffix of it is kept.
+ *
+ * @property targetTokens The number of tokens to keep once the session has been reduced.
+ */
+@PublicPreviewAPI
+public class SlidingWindow(public val targetTokens: Int? = null) {
+  /** Converts this configuration into its serializable counterpart. */
+  internal fun toInternal(): Internal = Internal(targetTokens = targetTokens)
+
+  /** Serializable counterpart of [SlidingWindow]. */
+  @Serializable
+  internal data class Internal(val targetTokens: Int? = null)
+}
+
+/**
+ * Enables context window compression to manage the model's context window.
+ *
+ * Compression keeps the context from growing beyond a given length.
+ *
+ * @property triggerTokens The token count, measured before a turn runs, at which compression starts.
+ * @property slidingWindow The sliding-window mechanism used to compress the context.
+ */
+@PublicPreviewAPI
+public class ContextWindowCompressionConfig(
+  public val triggerTokens: Int? = null,
+  public val slidingWindow: SlidingWindow? = null
+) {
+  @Serializable
+  internal data class Internal(
+    val triggerTokens: Int? = null,
+    val slidingWindow: SlidingWindow.Internal? = null
+  )
+
+  internal fun toInternal(): Internal = Internal(triggerTokens, slidingWindow?.toInternal())
+}
diff --git a/ai-logic/firebase-ai/src/main/kotlin/com/google/firebase/ai/type/LiveClientSetupMessage.kt b/ai-logic/firebase-ai/src/main/kotlin/com/google/firebase/ai/type/LiveClientSetupMessage.kt
index 856eebbdde5..1c4ee414ee0 100644
--- a/ai-logic/firebase-ai/src/main/kotlin/com/google/firebase/ai/type/LiveClientSetupMessage.kt
+++ b/ai-logic/firebase-ai/src/main/kotlin/com/google/firebase/ai/type/LiveClientSetupMessage.kt
@@ -35,6 +35,8 @@ internal class LiveClientSetupMessage(
   val systemInstruction: Content.Internal?,
   val inputAudioTranscription: AudioTranscriptionConfig.Internal?,
   val outputAudioTranscription: AudioTranscriptionConfig.Internal?,
+  val sessionResumption: SessionResumptionConfig.Internal?,
+  val contextWindowCompression: ContextWindowCompressionConfig.Internal?,
 ) {
   @Serializable
   internal class Internal(val setup: LiveClientSetup) {
@@ -46,6 +48,8 @@ internal class LiveClientSetupMessage(
       val systemInstruction: Content.Internal?,
       val inputAudioTranscription: AudioTranscriptionConfig.Internal?,
       val outputAudioTranscription: AudioTranscriptionConfig.Internal?,
+      @SerialName("session_resumption") val sessionResumption: SessionResumptionConfig.Internal? = null,
+      @SerialName("context_window_compression") val contextWindowCompression: ContextWindowCompressionConfig.Internal?
= null, ) } @@ -57,7 +61,9 @@ internal class LiveClientSetupMessage( tools, systemInstruction, inputAudioTranscription, - outputAudioTranscription + outputAudioTranscription, + sessionResumption, + contextWindowCompression ) ) } diff --git a/ai-logic/firebase-ai/src/main/kotlin/com/google/firebase/ai/type/LiveGenerationConfig.kt b/ai-logic/firebase-ai/src/main/kotlin/com/google/firebase/ai/type/LiveGenerationConfig.kt index 3e014d43162..6c84c41a9f0 100644 --- a/ai-logic/firebase-ai/src/main/kotlin/com/google/firebase/ai/type/LiveGenerationConfig.kt +++ b/ai-logic/firebase-ai/src/main/kotlin/com/google/firebase/ai/type/LiveGenerationConfig.kt @@ -75,6 +75,7 @@ private constructor( internal val speechConfig: SpeechConfig?, internal val inputAudioTranscription: AudioTranscriptionConfig?, internal val outputAudioTranscription: AudioTranscriptionConfig?, + public val contextWindowCompression: ContextWindowCompressionConfig?, ) { /** @@ -102,6 +103,8 @@ private constructor( * @property inputAudioTranscription see [LiveGenerationConfig.inputAudioTranscription] * * @property outputAudioTranscription see [LiveGenerationConfig.outputAudioTranscription] + * + * @property contextWindowCompression see [LiveGenerationConfig.contextWindowCompression] */ public class Builder { @JvmField public var temperature: Float? = null @@ -114,6 +117,7 @@ private constructor( @JvmField public var speechConfig: SpeechConfig? = null @JvmField public var inputAudioTranscription: AudioTranscriptionConfig? = null @JvmField public var outputAudioTranscription: AudioTranscriptionConfig? = null + @JvmField public var contextWindowCompression: ContextWindowCompressionConfig? 
= null
 
     public fun setTemperature(temperature: Float?): Builder = apply {
       this.temperature = temperature
@@ -144,6 +148,10 @@ private constructor(
       this.outputAudioTranscription = config
     }
 
+    /** Sets [contextWindowCompression] and returns this builder so that calls can be chained. */
+    public fun setContextWindowCompression(config: ContextWindowCompressionConfig?): Builder =
+      also { it.contextWindowCompression = config }
+
     /** Create a new [LiveGenerationConfig] with the attached arguments. */
     public fun build(): LiveGenerationConfig =
       LiveGenerationConfig(
@@ -157,6 +165,7 @@ private constructor(
         responseModality = responseModality,
         inputAudioTranscription = inputAudioTranscription,
         outputAudioTranscription = outputAudioTranscription,
+        contextWindowCompression = contextWindowCompression,
       )
   }
 
diff --git a/ai-logic/firebase-ai/src/main/kotlin/com/google/firebase/ai/type/LiveServerMessage.kt b/ai-logic/firebase-ai/src/main/kotlin/com/google/firebase/ai/type/LiveServerMessage.kt
index 1efc06f9a52..dd13b321041 100644
--- a/ai-logic/firebase-ai/src/main/kotlin/com/google/firebase/ai/type/LiveServerMessage.kt
+++ b/ai-logic/firebase-ai/src/main/kotlin/com/google/firebase/ai/type/LiveServerMessage.kt
@@ -35,6 +35,7 @@ import kotlinx.serialization.json.jsonObject
  * @see LiveServerToolCallCancellation
  * @see LiveServerSetupComplete
  * @see LiveServerGoAway
+ * @see LiveSessionResumptionUpdate
  */
 @PublicPreviewAPI
 public interface LiveServerMessage
@@ -212,6 +213,41 @@ public class LiveServerGoAway(public val timeLeft: Duration?) : LiveServerMessag
   }
 }
 
+/**
+ * An update of the session resumption state.
+ *
+ * This message is only sent if [SessionResumptionConfig] was set in the session setup.
+ *
+ * @property newHandle The new handle that represents the state that can be resumed. Empty if `resumable` is false.
+ * @property resumable Indicates if the session can be resumed at this point.
+ * @property lastConsumedClientMessageIndex The index of the last client message that is included in the state represented by this update.
+ */
+@PublicPreviewAPI
+public class LiveSessionResumptionUpdate(
+  public val newHandle: String? = null,
+  public val resumable: Boolean? = null,
+  public val lastConsumedClientMessageIndex: Int? = null
+) : LiveServerMessage {
+  /** Serializable payload carried under the `sessionResumptionUpdate` key. */
+  @Serializable
+  internal data class Internal(
+    val newHandle: String? = null,
+    val resumable: Boolean? = null,
+    val lastConsumedClientMessageIndex: Int? = null
+  )
+
+  /** Envelope used for server messages that contain a `sessionResumptionUpdate` entry. */
+  @Serializable
+  internal data class InternalWrapper(val sessionResumptionUpdate: Internal) :
+    InternalLiveServerMessage {
+    /** Copies every field of the payload into the public message type. */
+    override fun toPublic(): LiveSessionResumptionUpdate {
+      val (handle, canResume, lastIndex) = sessionResumptionUpdate
+      return LiveSessionResumptionUpdate(handle, canResume, lastIndex)
+    }
+  }
+}
+
 @PublicPreviewAPI
 @Serializable(LiveServerMessageSerializer::class)
 internal sealed interface InternalLiveServerMessage {
@@ -233,6 +269,7 @@ internal object LiveServerMessageSerializer :
       "toolCallCancellation" in jsonObject ->
         LiveServerToolCallCancellation.InternalWrapper.serializer()
       "goAway" in jsonObject -> LiveServerGoAway.InternalWrapper.serializer()
+      "sessionResumptionUpdate" in jsonObject -> LiveSessionResumptionUpdate.InternalWrapper.serializer()
       else ->
         throw SerializationException(
           "Unknown LiveServerMessage response type.
Keys found: ${jsonObject.keys}" diff --git a/ai-logic/firebase-ai/src/main/kotlin/com/google/firebase/ai/type/LiveSession.kt b/ai-logic/firebase-ai/src/main/kotlin/com/google/firebase/ai/type/LiveSession.kt index 152c6b8d0a3..e20307e55e0 100644 --- a/ai-logic/firebase-ai/src/main/kotlin/com/google/firebase/ai/type/LiveSession.kt +++ b/ai-logic/firebase-ai/src/main/kotlin/com/google/firebase/ai/type/LiveSession.kt @@ -70,10 +70,11 @@ import kotlinx.serialization.json.Json @OptIn(ExperimentalSerializationApi::class) public class LiveSession internal constructor( - private val session: DefaultClientWebSocketSession, + private var session: DefaultClientWebSocketSession, @Blocking private val blockingDispatcher: CoroutineContext, private var audioHelper: AudioHelper? = null, private val firebaseApp: FirebaseApp, + private val connectionFactory: (suspend (SessionResumptionConfig?) -> DefaultClientWebSocketSession)? = null ) { /** * Coroutine scope that we batch data on for network related behavior. @@ -314,8 +315,17 @@ internal constructor( // TODO(b/410059569): Remove when fixed flow { while (true) { - val response = session.incoming.tryReceive() - if (response.isClosed || !startedReceiving.get()) break + val currentSession = session + val response = currentSession.incoming.tryReceive() + if (!startedReceiving.get()) break + if (response.isClosed) { + if (currentSession === session) { + break + } else { + delay(0) + continue + } + } response .getOrNull() ?.let { @@ -501,6 +511,31 @@ internal constructor( } } + /** + * Resumes an existing live session with the server. + * + * This closes the current WebSocket connection and establishes a new one using + * the same configuration (URI, headers, model, system instruction, tools, etc.) + * as the original session. + * + * @param sessionResumption The configuration for session resumption, such as the handle to the previous session state to restore. 
+   *
+   * @throws IllegalStateException if this session was created without a connection factory.
+   */
+  public suspend fun resumeSession(sessionResumption: SessionResumptionConfig? = null) {
+    val reconnect =
+      checkNotNull(connectionFactory) { "resumeSession is not supported on this instance." }
+
+    // Connect before swapping so that a failed handshake leaves the current session in place.
+    val newSession = reconnect(sessionResumption)
+    val oldSession = session
+    session = newSession
+
+    try {
+      oldSession.close()
+    } catch (e: kotlin.coroutines.cancellation.CancellationException) {
+      // Cancellation must reach the caller; only genuine close failures are ignored below.
+      throw e
+    } catch (e: Exception) {
+      // Best effort: the superseded socket may already be closed or broken.
+    }
+  }
+
   /** Listen to the user's microphone and send the data to the model. */
   private fun recordUserAudio() {
     // Buffer the recording so we can keep recording while data is sent to the server
diff --git a/ai-logic/firebase-ai/src/main/kotlin/com/google/firebase/ai/type/SessionResumptionConfig.kt b/ai-logic/firebase-ai/src/main/kotlin/com/google/firebase/ai/type/SessionResumptionConfig.kt
new file mode 100644
index 00000000000..cd8521d8acc
--- /dev/null
+++ b/ai-logic/firebase-ai/src/main/kotlin/com/google/firebase/ai/type/SessionResumptionConfig.kt
@@ -0,0 +1,36 @@
+/*
+ * Copyright 2026 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.google.firebase.ai.type
+
+import kotlinx.serialization.Serializable
+
+/**
+ * Configuration for the session resumption mechanism.
+ *
+ * When included in the session setup, the server will send [LiveSessionResumptionUpdate] messages.
+ *
+ * @property handle The session resumption handle of the previous session to restore.
+ */
+@PublicPreviewAPI
+public class SessionResumptionConfig(public val handle: String? = null) {
+  /** Converts this configuration into its serializable counterpart. */
+  internal fun toInternal(): Internal = Internal(handle = handle)
+
+  /** Serializable counterpart of [SessionResumptionConfig]. */
+  @Serializable
+  internal data class Internal(val handle: String? = null)
+}