Add support for queuing status updates in cloud language model provider (#29818)
This sets us up to display queue position information to the user, once our language model backend is updated to support request queuing. The JSON returned by the LLM backend will need to look like this:

```json
{"queue": {"status": "queued", "position": 1}}
{"queue": {"status": "started"}}
{"event": {"THE_UPSTREAM_MODEL_PROVIDER_EVENT": "..."}}
```

Release Notes:

- N/A

---------

Co-authored-by: Marshall Bowers <git@maxdeviant.com>
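As a rough sketch of how a client might decode the queue messages described above, assuming `serde`/`serde_json` (the `CloudEvent` and `QueueState` names are hypothetical, for illustration only; they are not the definitions this PR adds):

```rust
use serde::Deserialize;

// Hypothetical types for illustration; not the definitions this PR adds.
#[derive(Debug, Deserialize)]
#[serde(rename_all = "snake_case")]
enum CloudEvent {
    /// {"queue": {...}}: queuing status from our backend.
    Queue(QueueState),
    /// {"event": {...}}: an upstream model provider event, passed through opaquely.
    Event(serde_json::Value),
}

#[derive(Debug, Deserialize)]
#[serde(tag = "status", rename_all = "snake_case")]
enum QueueState {
    Queued { position: usize },
    Started,
}

fn main() {
    for line in [
        r#"{"queue": {"status": "queued", "position": 1}}"#,
        r#"{"queue": {"status": "started"}}"#,
        r#"{"event": {"THE_UPSTREAM_MODEL_PROVIDER_EVENT": "..."}}"#,
    ] {
        let event: CloudEvent = serde_json::from_str(line).expect("valid event JSON");
        println!("{event:?}");
    }
}
```

The externally tagged outer enum keys on `"queue"` vs. `"event"`, and the internally tagged `QueueState` keys on `"status"`, matching the three message shapes above.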
Parent: 4d1df7bcd7
Commit: 04772bf17d

9 changed files with 492 additions and 430 deletions
```diff
@@ -24,7 +24,10 @@ use schemars::JsonSchema;
 use serde::{Deserialize, Serialize};
 use settings::{Settings, SettingsStore};
 use std::pin::Pin;
-use std::sync::Arc;
+use std::sync::{
+    Arc,
+    atomic::{self, AtomicU64},
+};
 use strum::IntoEnumIterator;
 use theme::ThemeSettings;
 use ui::{Icon, IconName, List, Tooltip, prelude::*};
```
```diff
@@ -371,7 +374,7 @@ impl LanguageModel for GoogleLanguageModel {
             let response = request
                 .await
                 .map_err(|err| LanguageModelCompletionError::Other(anyhow!(err)))?;
-            Ok(map_to_language_model_completion_events(response))
+            Ok(GoogleEventMapper::new().map_stream(response))
         });
         async move { Ok(future.await?.boxed()) }.boxed()
     }
```
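The remaining hunk replaces the free `map_to_language_model_completion_events` function with a stateful `GoogleEventMapper`, whose per-event entry point (`map_event`) can be driven by an outer layer rather than only by a whole-stream adapter. Reduced to a self-contained sketch of the pattern (illustrative names only, assuming the `futures` 0.3 crate):

```rust
use futures::{Stream, StreamExt, stream};

// Illustrative stand-in for a provider event mapper: whole-stream mapping is
// just `flat_map` over a per-event method that a wrapper can also call
// directly (e.g. after stripping a queue-status envelope off each message).
#[derive(Default)]
struct EventMapper {
    events_seen: usize,
}

impl EventMapper {
    // Per-event entry point.
    fn map_event(&mut self, event: String) -> Vec<String> {
        self.events_seen += 1;
        vec![format!("#{}: {event}", self.events_seen)]
    }

    // Whole-stream entry point, built on `map_event`.
    fn map_stream(
        mut self,
        events: impl Stream<Item = String> + Unpin,
    ) -> impl Stream<Item = String> {
        events.flat_map(move |event| stream::iter(self.map_event(event)))
    }
}

fn main() {
    let mapped = EventMapper::default()
        .map_stream(stream::iter(["a".to_string(), "b".to_string()]));
    let out: Vec<String> = futures::executor::block_on(mapped.collect());
    println!("{out:?}"); // ["#1: a", "#2: b"]
}
```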
```diff
@@ -486,108 +489,98 @@ pub fn into_google(
     }
 }
 
-pub fn map_to_language_model_completion_events(
-    events: Pin<Box<dyn Send + Stream<Item = Result<GenerateContentResponse>>>>,
-) -> impl Stream<Item = Result<LanguageModelCompletionEvent, LanguageModelCompletionError>> {
-    use std::sync::atomic::{AtomicU64, Ordering};
-
-    static TOOL_CALL_COUNTER: AtomicU64 = AtomicU64::new(0);
-
-    struct State {
-        events: Pin<Box<dyn Send + Stream<Item = Result<GenerateContentResponse>>>>,
-        usage: UsageMetadata,
-        stop_reason: StopReason,
-    }
-
-    futures::stream::unfold(
-        State {
-            events,
-            usage: UsageMetadata::default(),
-            stop_reason: StopReason::EndTurn,
-        },
-        |mut state| async move {
-            if let Some(event) = state.events.next().await {
-                match event {
-                    Ok(event) => {
-                        let mut events: Vec<_> = Vec::new();
-                        let mut wants_to_use_tool = false;
-                        if let Some(usage_metadata) = event.usage_metadata {
-                            update_usage(&mut state.usage, &usage_metadata);
-                            events.push(Ok(LanguageModelCompletionEvent::UsageUpdate(
-                                convert_usage(&state.usage),
-                            )))
-                        }
-                        if let Some(candidates) = event.candidates {
-                            for candidate in candidates {
-                                if let Some(finish_reason) = candidate.finish_reason.as_deref() {
-                                    state.stop_reason = match finish_reason {
-                                        "STOP" => StopReason::EndTurn,
-                                        "MAX_TOKENS" => StopReason::MaxTokens,
-                                        _ => {
-                                            log::error!(
-                                                "Unexpected google finish_reason: {finish_reason}"
-                                            );
-                                            StopReason::EndTurn
-                                        }
-                                    };
-                                }
-                                candidate
-                                    .content
-                                    .parts
-                                    .into_iter()
-                                    .for_each(|part| match part {
-                                        Part::TextPart(text_part) => events.push(Ok(
-                                            LanguageModelCompletionEvent::Text(text_part.text),
-                                        )),
-                                        Part::InlineDataPart(_) => {}
-                                        Part::FunctionCallPart(function_call_part) => {
-                                            wants_to_use_tool = true;
-                                            let name: Arc<str> =
-                                                function_call_part.function_call.name.into();
-                                            let next_tool_id =
-                                                TOOL_CALL_COUNTER.fetch_add(1, Ordering::SeqCst);
-                                            let id: LanguageModelToolUseId =
-                                                format!("{}-{}", name, next_tool_id).into();
-
-                                            events.push(Ok(LanguageModelCompletionEvent::ToolUse(
-                                                LanguageModelToolUse {
-                                                    id,
-                                                    name,
-                                                    is_input_complete: true,
-                                                    raw_input: function_call_part
-                                                        .function_call
-                                                        .args
-                                                        .to_string(),
-                                                    input: function_call_part.function_call.args,
-                                                },
-                                            )));
-                                        }
-                                        Part::FunctionResponsePart(_) => {}
-                                    });
-                            }
-                        }
-
-                        // Even when Gemini wants to use a Tool, the API
-                        // responds with `finish_reason: STOP`
-                        if wants_to_use_tool {
-                            state.stop_reason = StopReason::ToolUse;
-                        }
-                        events.push(Ok(LanguageModelCompletionEvent::Stop(state.stop_reason)));
-                        return Some((events, state));
-                    }
-                    Err(err) => {
-                        return Some((
-                            vec![Err(LanguageModelCompletionError::Other(anyhow!(err)))],
-                            state,
-                        ));
-                    }
-                };
-            }
-
-            None
-        },
-    )
-    .flat_map(futures::stream::iter)
-}
+pub struct GoogleEventMapper {
+    usage: UsageMetadata,
+    stop_reason: StopReason,
+}
+
+impl GoogleEventMapper {
+    pub fn new() -> Self {
+        Self {
+            usage: UsageMetadata::default(),
+            stop_reason: StopReason::EndTurn,
+        }
+    }
+
+    pub fn map_stream(
+        mut self,
+        events: Pin<Box<dyn Send + Stream<Item = Result<GenerateContentResponse>>>>,
+    ) -> impl Stream<Item = Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>
+    {
+        events.flat_map(move |event| {
+            futures::stream::iter(match event {
+                Ok(event) => self.map_event(event),
+                Err(error) => vec![Err(LanguageModelCompletionError::Other(anyhow!(error)))],
+            })
+        })
+    }
+
+    pub fn map_event(
+        &mut self,
+        event: GenerateContentResponse,
+    ) -> Vec<Result<LanguageModelCompletionEvent, LanguageModelCompletionError>> {
+        static TOOL_CALL_COUNTER: AtomicU64 = AtomicU64::new(0);
+
+        let mut events: Vec<_> = Vec::new();
+        let mut wants_to_use_tool = false;
+        if let Some(usage_metadata) = event.usage_metadata {
+            update_usage(&mut self.usage, &usage_metadata);
+            events.push(Ok(LanguageModelCompletionEvent::UsageUpdate(
+                convert_usage(&self.usage),
+            )))
+        }
+        if let Some(candidates) = event.candidates {
+            for candidate in candidates {
+                if let Some(finish_reason) = candidate.finish_reason.as_deref() {
+                    self.stop_reason = match finish_reason {
+                        "STOP" => StopReason::EndTurn,
+                        "MAX_TOKENS" => StopReason::MaxTokens,
+                        _ => {
+                            log::error!("Unexpected google finish_reason: {finish_reason}");
+                            StopReason::EndTurn
+                        }
+                    };
+                }
+                candidate
+                    .content
+                    .parts
+                    .into_iter()
+                    .for_each(|part| match part {
+                        Part::TextPart(text_part) => {
+                            events.push(Ok(LanguageModelCompletionEvent::Text(text_part.text)))
+                        }
+                        Part::InlineDataPart(_) => {}
+                        Part::FunctionCallPart(function_call_part) => {
+                            wants_to_use_tool = true;
+                            let name: Arc<str> = function_call_part.function_call.name.into();
+                            let next_tool_id =
+                                TOOL_CALL_COUNTER.fetch_add(1, atomic::Ordering::SeqCst);
+                            let id: LanguageModelToolUseId =
+                                format!("{}-{}", name, next_tool_id).into();
+
+                            events.push(Ok(LanguageModelCompletionEvent::ToolUse(
+                                LanguageModelToolUse {
+                                    id,
+                                    name,
+                                    is_input_complete: true,
+                                    raw_input: function_call_part.function_call.args.to_string(),
+                                    input: function_call_part.function_call.args,
+                                },
+                            )));
+                        }
+                        Part::FunctionResponsePart(_) => {}
+                    });
+            }
+        }
+
+        // Even when Gemini wants to use a Tool, the API
+        // responds with `finish_reason: STOP`
+        if wants_to_use_tool {
+            self.stop_reason = StopReason::ToolUse;
+        }
+        events.push(Ok(LanguageModelCompletionEvent::Stop(self.stop_reason)));
+        events
+    }
+}
 
 pub fn count_google_tokens(
```
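Not shown in this file is how the cloud provider consumes the new per-event API. Speculatively, reusing the hypothetical `CloudEvent`/`QueueState` types sketched earlier plus an assumed queue-update completion event (none of which this diff confirms), handling one line of the backend stream could look like:

```rust
// Purely illustrative; `QueueUpdate` is an assumed event variant, and
// `CloudEvent`/`QueueState` are the hypothetical types sketched above.
fn map_cloud_line(
    mapper: &mut GoogleEventMapper,
    line: &str,
) -> Vec<Result<LanguageModelCompletionEvent, LanguageModelCompletionError>> {
    match serde_json::from_str::<CloudEvent>(line) {
        // Queue envelopes never reach the upstream mapper; they are surfaced
        // directly so the UI can show a queue position.
        Ok(CloudEvent::Queue(state)) => {
            vec![Ok(LanguageModelCompletionEvent::QueueUpdate(state))]
        }
        // Upstream provider events are decoded and delegated to the
        // provider-specific mapper, one event at a time.
        Ok(CloudEvent::Event(raw)) => match serde_json::from_value(raw) {
            Ok(event) => mapper.map_event(event),
            Err(err) => vec![Err(LanguageModelCompletionError::Other(anyhow!(err)))],
        },
        Err(err) => vec![Err(LanguageModelCompletionError::Other(anyhow!(err)))],
    }
}
```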