Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -105,35 +105,43 @@ internal constructor(
* connection with the server.
*/
@OptIn(ExperimentalSerializationApi::class)
public suspend fun connect(): LiveSession {
val clientMessage =
LiveClientSetupMessage(
modelName,
config?.toInternal(),
tools.map { it.toInternal() }.takeIf { it.isNotEmpty() },
systemInstruction?.toInternal(),
config?.inputAudioTranscription?.toInternal(),
config?.outputAudioTranscription?.toInternal()
)
.toInternal()
val data: String = Json.encodeToString(clientMessage)
var webSession: DefaultClientWebSocketSession? = null
try {
webSession = controller.getWebSocketSession(location)
public suspend fun connect(sessionResumption: SessionResumptionConfig? = null): LiveSession {
val connectFactory: suspend (SessionResumptionConfig?) -> DefaultClientWebSocketSession = { newResumption ->
val clientMessage =
LiveClientSetupMessage(
modelName,
config?.toInternal(),
tools.map { it.toInternal() }.takeIf { it.isNotEmpty() },
systemInstruction?.toInternal(),
config?.inputAudioTranscription?.toInternal(),
config?.outputAudioTranscription?.toInternal(),
newResumption?.toInternal(),
config?.contextWindowCompression?.toInternal()
)
.toInternal()
val data: String = Json.encodeToString(clientMessage)
val webSession = controller.getWebSocketSession(location)
webSession.send(Frame.Text(data))
val receivedJsonStr = webSession.incoming.receive().readBytes().toString(Charsets.UTF_8)
val receivedJson = JSON.parseToJsonElement(receivedJsonStr)

return if (receivedJson is JsonObject && "setupComplete" in receivedJson) {
LiveSession(
session = webSession,
blockingDispatcher = blockingDispatcher,
firebaseApp = firebaseApp
)
if (receivedJson is JsonObject && "setupComplete" in receivedJson) {
webSession
} else {
webSession.close()
throw ServiceConnectionHandshakeFailedException("Unable to connect to the server")
}
}

var webSession: DefaultClientWebSocketSession? = null
try {
webSession = connectFactory(sessionResumption)
return LiveSession(
session = webSession,
blockingDispatcher = blockingDispatcher,
firebaseApp = firebaseApp,
connectionFactory = connectFactory
)
} catch (e: ClosedReceiveChannelException) {
val reason = webSession?.closeReason?.await()
val message =
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
/*
* Copyright 2026 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package com.google.firebase.ai.type

import kotlinx.serialization.Serializable

/**
 * Sliding-window strategy for compressing the context window.
 *
 * Compression truncates the context window so that only a suffix of it is kept.
 *
 * @property targetTokens The session reduction target, i.e. the number of tokens to keep.
 */
@PublicPreviewAPI
public class SlidingWindow(public val targetTokens: Int? = null) {

  /** Converts this configuration into its serializable [Internal] counterpart. */
  internal fun toInternal(): Internal {
    return Internal(targetTokens = targetTokens)
  }

  /** Serializable form of [SlidingWindow]. */
  @Serializable internal data class Internal(val targetTokens: Int? = null)
}

/**
 * Turns on context window compression, which keeps the model's context window from exceeding a
 * given length.
 *
 * @property triggerTokens The number of tokens (measured before running a turn) at which context
 * window compression is triggered.
 * @property slidingWindow The sliding window compression mechanism.
 */
@PublicPreviewAPI
public class ContextWindowCompressionConfig(
  public val triggerTokens: Int? = null,
  public val slidingWindow: SlidingWindow? = null,
) {

  /** Converts this configuration, including any nested [SlidingWindow], into its [Internal] form. */
  internal fun toInternal(): Internal {
    val window = slidingWindow?.toInternal()
    return Internal(triggerTokens = triggerTokens, slidingWindow = window)
  }

  /** Serializable form of [ContextWindowCompressionConfig]. */
  @Serializable
  internal data class Internal(
    val triggerTokens: Int? = null,
    val slidingWindow: SlidingWindow.Internal? = null,
  )
}
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,8 @@ internal class LiveClientSetupMessage(
val systemInstruction: Content.Internal?,
val inputAudioTranscription: AudioTranscriptionConfig.Internal?,
val outputAudioTranscription: AudioTranscriptionConfig.Internal?,
val sessionResumption: SessionResumptionConfig.Internal?,
val contextWindowCompression: ContextWindowCompressionConfig.Internal?,
) {
@Serializable
internal class Internal(val setup: LiveClientSetup) {
Expand All @@ -46,6 +48,8 @@ internal class LiveClientSetupMessage(
val systemInstruction: Content.Internal?,
val inputAudioTranscription: AudioTranscriptionConfig.Internal?,
val outputAudioTranscription: AudioTranscriptionConfig.Internal?,
@SerialName("session_resumption") val sessionResumption: SessionResumptionConfig.Internal? = null,
@SerialName("context_window_compression") val contextWindowCompression: ContextWindowCompressionConfig.Internal? = null,
)
}

Expand All @@ -57,7 +61,9 @@ internal class LiveClientSetupMessage(
tools,
systemInstruction,
inputAudioTranscription,
outputAudioTranscription
outputAudioTranscription,
sessionResumption,
contextWindowCompression
)
)
}
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@ private constructor(
internal val speechConfig: SpeechConfig?,
internal val inputAudioTranscription: AudioTranscriptionConfig?,
internal val outputAudioTranscription: AudioTranscriptionConfig?,
public val contextWindowCompression: ContextWindowCompressionConfig?,
) {

/**
Expand Down Expand Up @@ -102,6 +103,8 @@ private constructor(
* @property inputAudioTranscription see [LiveGenerationConfig.inputAudioTranscription]
*
* @property outputAudioTranscription see [LiveGenerationConfig.outputAudioTranscription]
*
* @property contextWindowCompression see [LiveGenerationConfig.contextWindowCompression]
*/
public class Builder {
@JvmField public var temperature: Float? = null
Expand All @@ -114,6 +117,7 @@ private constructor(
@JvmField public var speechConfig: SpeechConfig? = null
@JvmField public var inputAudioTranscription: AudioTranscriptionConfig? = null
@JvmField public var outputAudioTranscription: AudioTranscriptionConfig? = null
@JvmField public var contextWindowCompression: ContextWindowCompressionConfig? = null

public fun setTemperature(temperature: Float?): Builder = apply {
this.temperature = temperature
Expand Down Expand Up @@ -144,6 +148,10 @@ private constructor(
this.outputAudioTranscription = config
}

/** Sets the [contextWindowCompression] value and returns this builder for chaining. */
public fun setContextWindowCompression(config: ContextWindowCompressionConfig?): Builder {
  contextWindowCompression = config
  return this
}

/** Create a new [LiveGenerationConfig] with the attached arguments. */
public fun build(): LiveGenerationConfig =
LiveGenerationConfig(
Expand All @@ -157,6 +165,7 @@ private constructor(
responseModality = responseModality,
inputAudioTranscription = inputAudioTranscription,
outputAudioTranscription = outputAudioTranscription,
contextWindowCompression = contextWindowCompression,
)
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ import kotlinx.serialization.json.jsonObject
* @see LiveServerToolCallCancellation
* @see LiveServerSetupComplete
* @see LiveServerGoAway
* @see LiveSessionResumptionUpdate
*/
@PublicPreviewAPI public interface LiveServerMessage

Expand Down Expand Up @@ -212,6 +213,41 @@ public class LiveServerGoAway(public val timeLeft: Duration?) : LiveServerMessag
}
}

/**
 * Server message carrying an update of the session resumption state.
 *
 * The server only sends this message when a [SessionResumptionConfig] was provided in the session
 * setup.
 *
 * @property newHandle The new handle representing the state that can be resumed. Empty if
 * `resumable` is false.
 * @property resumable Whether the session can be resumed at this point.
 * @property lastConsumedClientMessageIndex The index of the last client message that is included
 * in the state represented by this update.
 */
@PublicPreviewAPI
public class LiveSessionResumptionUpdate(
  public val newHandle: String? = null,
  public val resumable: Boolean? = null,
  public val lastConsumedClientMessageIndex: Int? = null,
) : LiveServerMessage {

  /** Serializable payload found under the `sessionResumptionUpdate` key. */
  @Serializable
  internal data class Internal(
    val newHandle: String? = null,
    val resumable: Boolean? = null,
    val lastConsumedClientMessageIndex: Int? = null,
  )

  /** Serializable envelope that wraps the [Internal] payload. */
  @Serializable
  internal data class InternalWrapper(val sessionResumptionUpdate: Internal) :
    InternalLiveServerMessage {

    /** Copies every field of the wrapped payload into a public [LiveSessionResumptionUpdate]. */
    override fun toPublic(): LiveSessionResumptionUpdate {
      val update = sessionResumptionUpdate
      return LiveSessionResumptionUpdate(
        update.newHandle,
        update.resumable,
        update.lastConsumedClientMessageIndex,
      )
    }
  }
}

@PublicPreviewAPI
@Serializable(LiveServerMessageSerializer::class)
internal sealed interface InternalLiveServerMessage {
Expand All @@ -233,6 +269,7 @@ internal object LiveServerMessageSerializer :
"toolCallCancellation" in jsonObject ->
LiveServerToolCallCancellation.InternalWrapper.serializer()
"goAway" in jsonObject -> LiveServerGoAway.InternalWrapper.serializer()
"sessionResumptionUpdate" in jsonObject -> LiveSessionResumptionUpdate.InternalWrapper.serializer()
else ->
throw SerializationException(
"Unknown LiveServerMessage response type. Keys found: ${jsonObject.keys}"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -70,10 +70,11 @@ import kotlinx.serialization.json.Json
@OptIn(ExperimentalSerializationApi::class)
public class LiveSession
internal constructor(
private val session: DefaultClientWebSocketSession,
private var session: DefaultClientWebSocketSession,
@Blocking private val blockingDispatcher: CoroutineContext,
private var audioHelper: AudioHelper? = null,
private val firebaseApp: FirebaseApp,
private val connectionFactory: (suspend (SessionResumptionConfig?) -> DefaultClientWebSocketSession)? = null
) {
/**
* Coroutine scope that we batch data on for network related behavior.
Expand Down Expand Up @@ -314,8 +315,17 @@ internal constructor(
// TODO(b/410059569): Remove when fixed
flow {
while (true) {
val response = session.incoming.tryReceive()
if (response.isClosed || !startedReceiving.get()) break
val currentSession = session
val response = currentSession.incoming.tryReceive()
if (!startedReceiving.get()) break
if (response.isClosed) {
if (currentSession === session) {
break
} else {
delay(0)
continue
}
}
response
.getOrNull()
?.let {
Expand Down Expand Up @@ -501,6 +511,31 @@ internal constructor(
}
}

/**
 * Resumes an existing live session with the server.
 *
 * This closes the current WebSocket connection and establishes a new one using
 * the same configuration (URI, headers, model, system instruction, tools, etc.)
 * as the original session.
 *
 * The replacement connection is opened before the current one is closed. If opening it throws,
 * the exception propagates to the caller and the current connection is neither replaced nor
 * closed.
 *
 * @param sessionResumption The configuration for session resumption, such as the handle to the
 * previous session state to restore.
 * @throws IllegalStateException if this instance was created without a connection factory.
 */
public suspend fun resumeSession(sessionResumption: SessionResumptionConfig? = null) {
  val factory =
    checkNotNull(connectionFactory) { "resumeSession is not supported on this instance." }

  // Connect first: a failure here must not disturb the session that is currently in use.
  val newSession = factory(sessionResumption)
  val oldSession = session
  session = newSession

  try {
    oldSession.close()
  } catch (e: kotlinx.coroutines.CancellationException) {
    // Never swallow cancellation: the caller's coroutine has to observe it.
    throw e
  } catch (e: Exception) {
    // Best effort: the old connection has already been replaced, so a failure while closing it
    // does not affect the resumed session.
  }
}

/** Listen to the user's microphone and send the data to the model. */
private fun recordUserAudio() {
// Buffer the recording so we can keep recording while data is sent to the server
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
/*
* Copyright 2026 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package com.google.firebase.ai.type

import kotlinx.serialization.Serializable

/**
 * Settings for the session resumption mechanism.
 *
 * Including this in the session setup makes the server send [LiveSessionResumptionUpdate]
 * messages.
 *
 * @property handle The session resumption handle of the previous session to restore.
 */
@PublicPreviewAPI
public class SessionResumptionConfig(public val handle: String? = null) {

  /** Converts this configuration into its serializable [Internal] counterpart. */
  internal fun toInternal(): Internal {
    return Internal(handle = handle)
  }

  /** Serializable form of [SessionResumptionConfig]. */
  @Serializable internal data class Internal(val handle: String? = null)
}
Loading