Add support for queuing status updates in cloud language model provider (#29818)
This sets us up to display queue position information to the user, once our language model backend is updated to support request queuing. The JSON returned by the LLM backend will need to look like this:

```json
{"queue": {"status": "queued", "position": 1}}
{"queue": {"status": "started"}}
{"event": {"THE_UPSTREAM_MODEL_PROVIDER_EVENT": "..."}}
```

Release Notes:

- N/A

---------

Co-authored-by: Marshall Bowers <git@maxdeviant.com>
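For illustration, here is a minimal sketch of how these newline-delimited JSON lines could be deserialized on the client with `serde`/`serde_json`. The type and field names are hypothetical, not the provider's actual types:

```rust
use serde::Deserialize;

#[derive(Debug, Deserialize)]
#[serde(rename_all = "snake_case")]
enum QueueStatus {
    Queued,
    Started,
}

#[derive(Debug, Deserialize)]
struct QueueUpdate {
    status: QueueStatus,
    // Only present while status is "queued".
    position: Option<u64>,
}

// Externally tagged enum: the top-level key ("queue" or "event")
// selects the variant, matching the wire shape above directly.
#[derive(Debug, Deserialize)]
#[serde(rename_all = "snake_case")]
enum BackendLine {
    Queue(QueueUpdate),
    // Passthrough payload from the upstream model provider, kept opaque here.
    Event(serde_json::Value),
}

fn main() -> serde_json::Result<()> {
    for line in [
        r#"{"queue": {"status": "queued", "position": 1}}"#,
        r#"{"queue": {"status": "started"}}"#,
        r#"{"event": {"THE_UPSTREAM_MODEL_PROVIDER_EVENT": "..."}}"#,
    ] {
        let parsed: BackendLine = serde_json::from_str(line)?;
        println!("{parsed:?}");
    }
    Ok(())
}
```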
parent 4d1df7bcd7
commit 04772bf17d
9 changed files with 492 additions and 430 deletions
```diff
@@ -330,8 +330,11 @@ impl LanguageModel for OpenAiLanguageModel {
     > {
         let request = into_open_ai(request, &self.model, self.max_output_tokens());
         let completions = self.stream_completion(request, cx);
-        async move { Ok(map_to_language_model_completion_events(completions.await?).boxed()) }
-            .boxed()
+        async move {
+            let mapper = OpenAiEventMapper::new();
+            Ok(mapper.map_stream(completions.await?).boxed())
+        }
+        .boxed()
     }
 }
 
@@ -422,123 +425,108 @@ pub fn into_open_ai(
     }
 }
 
-pub fn map_to_language_model_completion_events(
-    events: Pin<Box<dyn Send + Stream<Item = Result<ResponseStreamEvent>>>>,
-) -> impl Stream<Item = Result<LanguageModelCompletionEvent, LanguageModelCompletionError>> {
-    #[derive(Default)]
-    struct RawToolCall {
-        id: String,
-        name: String,
-        arguments: String,
-    }
-
-    struct State {
-        events: Pin<Box<dyn Send + Stream<Item = Result<ResponseStreamEvent>>>>,
-        tool_calls_by_index: HashMap<usize, RawToolCall>,
-    }
-
-    futures::stream::unfold(
-        State {
-            events,
-            tool_calls_by_index: HashMap::default(),
-        },
-        |mut state| async move {
-            if let Some(event) = state.events.next().await {
-                match event {
-                    Ok(event) => {
-                        let Some(choice) = event.choices.first() else {
-                            return Some((
-                                vec![Err(LanguageModelCompletionError::Other(anyhow!(
-                                    "Response contained no choices"
-                                )))],
-                                state,
-                            ));
-                        };
-
-                        let mut events = Vec::new();
-                        if let Some(content) = choice.delta.content.clone() {
-                            events.push(Ok(LanguageModelCompletionEvent::Text(content)));
-                        }
-
-                        if let Some(tool_calls) = choice.delta.tool_calls.as_ref() {
-                            for tool_call in tool_calls {
-                                let entry = state
-                                    .tool_calls_by_index
-                                    .entry(tool_call.index)
-                                    .or_default();
-
-                                if let Some(tool_id) = tool_call.id.clone() {
-                                    entry.id = tool_id;
-                                }
-
-                                if let Some(function) = tool_call.function.as_ref() {
-                                    if let Some(name) = function.name.clone() {
-                                        entry.name = name;
-                                    }
-
-                                    if let Some(arguments) = function.arguments.clone() {
-                                        entry.arguments.push_str(&arguments);
-                                    }
-                                }
-                            }
-                        }
-
-                        match choice.finish_reason.as_deref() {
-                            Some("stop") => {
-                                events.push(Ok(LanguageModelCompletionEvent::Stop(
-                                    StopReason::EndTurn,
-                                )));
-                            }
-                            Some("tool_calls") => {
-                                events.extend(state.tool_calls_by_index.drain().map(
-                                    |(_, tool_call)| match serde_json::Value::from_str(
-                                        &tool_call.arguments,
-                                    ) {
-                                        Ok(input) => Ok(LanguageModelCompletionEvent::ToolUse(
-                                            LanguageModelToolUse {
-                                                id: tool_call.id.clone().into(),
-                                                name: tool_call.name.as_str().into(),
-                                                is_input_complete: true,
-                                                input,
-                                                raw_input: tool_call.arguments.clone(),
-                                            },
-                                        )),
-                                        Err(error) => {
-                                            Err(LanguageModelCompletionError::BadInputJson {
-                                                id: tool_call.id.into(),
-                                                tool_name: tool_call.name.as_str().into(),
-                                                raw_input: tool_call.arguments.into(),
-                                                json_parse_error: error.to_string(),
-                                            })
-                                        }
-                                    },
-                                ));
-
-                                events.push(Ok(LanguageModelCompletionEvent::Stop(
-                                    StopReason::ToolUse,
-                                )));
-                            }
-                            Some(stop_reason) => {
-                                log::error!("Unexpected OpenAI stop_reason: {stop_reason:?}",);
-                                events.push(Ok(LanguageModelCompletionEvent::Stop(
-                                    StopReason::EndTurn,
-                                )));
-                            }
-                            None => {}
-                        }
-
-                        return Some((events, state));
-                    }
-                    Err(err) => {
-                        return Some((vec![Err(LanguageModelCompletionError::Other(err))], state));
-                    }
-                }
-            }
-
-            None
-        },
-    )
-    .flat_map(futures::stream::iter)
-}
+pub struct OpenAiEventMapper {
+    tool_calls_by_index: HashMap<usize, RawToolCall>,
+}
+
+impl OpenAiEventMapper {
+    pub fn new() -> Self {
+        Self {
+            tool_calls_by_index: HashMap::default(),
+        }
+    }
+
+    pub fn map_stream(
+        mut self,
+        events: Pin<Box<dyn Send + Stream<Item = Result<ResponseStreamEvent>>>>,
+    ) -> impl Stream<Item = Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>
+    {
+        events.flat_map(move |event| {
+            futures::stream::iter(match event {
+                Ok(event) => self.map_event(event),
+                Err(error) => vec![Err(LanguageModelCompletionError::Other(anyhow!(error)))],
+            })
+        })
+    }
+
+    pub fn map_event(
+        &mut self,
+        event: ResponseStreamEvent,
+    ) -> Vec<Result<LanguageModelCompletionEvent, LanguageModelCompletionError>> {
+        let Some(choice) = event.choices.first() else {
+            return vec![Err(LanguageModelCompletionError::Other(anyhow!(
+                "Response contained no choices"
+            )))];
+        };
+
+        let mut events = Vec::new();
+        if let Some(content) = choice.delta.content.clone() {
+            events.push(Ok(LanguageModelCompletionEvent::Text(content)));
+        }
+
+        if let Some(tool_calls) = choice.delta.tool_calls.as_ref() {
+            for tool_call in tool_calls {
+                let entry = self.tool_calls_by_index.entry(tool_call.index).or_default();
+
+                if let Some(tool_id) = tool_call.id.clone() {
+                    entry.id = tool_id;
+                }
+
+                if let Some(function) = tool_call.function.as_ref() {
+                    if let Some(name) = function.name.clone() {
+                        entry.name = name;
+                    }
+
+                    if let Some(arguments) = function.arguments.clone() {
+                        entry.arguments.push_str(&arguments);
+                    }
+                }
+            }
+        }
+
+        match choice.finish_reason.as_deref() {
+            Some("stop") => {
+                events.push(Ok(LanguageModelCompletionEvent::Stop(StopReason::EndTurn)));
+            }
+            Some("tool_calls") => {
+                events.extend(self.tool_calls_by_index.drain().map(|(_, tool_call)| {
+                    match serde_json::Value::from_str(&tool_call.arguments) {
+                        Ok(input) => Ok(LanguageModelCompletionEvent::ToolUse(
+                            LanguageModelToolUse {
+                                id: tool_call.id.clone().into(),
+                                name: tool_call.name.as_str().into(),
+                                is_input_complete: true,
+                                input,
+                                raw_input: tool_call.arguments.clone(),
+                            },
+                        )),
+                        Err(error) => Err(LanguageModelCompletionError::BadInputJson {
+                            id: tool_call.id.into(),
+                            tool_name: tool_call.name.as_str().into(),
+                            raw_input: tool_call.arguments.into(),
+                            json_parse_error: error.to_string(),
+                        }),
+                    }
+                }));
+
+                events.push(Ok(LanguageModelCompletionEvent::Stop(StopReason::ToolUse)));
+            }
+            Some(stop_reason) => {
+                log::error!("Unexpected OpenAI stop_reason: {stop_reason:?}",);
+                events.push(Ok(LanguageModelCompletionEvent::Stop(StopReason::EndTurn)));
+            }
+            None => {}
+        }
+
+        events
+    }
+}
+
+#[derive(Default)]
+struct RawToolCall {
+    id: String,
+    name: String,
+    arguments: String,
+}
 
 pub fn count_open_ai_tokens(
```
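The first hunk shows the call site; for context, here is a rough sketch of the same pattern in isolation, assuming the mapper and event types from the diff above are in scope (the helper name is hypothetical):

```rust
use std::pin::Pin;

use anyhow::Result;
use futures::{Stream, StreamExt};

// Hypothetical helper: drain a raw OpenAI event stream through the mapper.
// Assumes OpenAiEventMapper, ResponseStreamEvent, LanguageModelCompletionEvent,
// and LanguageModelCompletionError from the diff above are in scope.
async fn collect_completion_events(
    completions: Pin<Box<dyn Send + Stream<Item = Result<ResponseStreamEvent>>>>,
) -> Vec<Result<LanguageModelCompletionEvent, LanguageModelCompletionError>> {
    // One mapper per request: map_stream takes `mut self`, and map_event
    // accumulates partial tool calls in tool_calls_by_index until a
    // "tool_calls" finish reason flushes them as ToolUse events.
    OpenAiEventMapper::new()
        .map_stream(completions)
        .collect()
        .await
}
```

Splitting `map_event` out of the stream plumbing (versus the previous `futures::stream::unfold` state machine) is presumably what lets the cloud provider unwrap the new `queue` envelope itself and delegate each remaining upstream event to the mapper.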