bedrock: Fix bedrock not streaming (#28281)
Closes #26030

Release Notes:

- Fixed Bedrock bug causing streaming responses to return as one big chunk

---------

Co-authored-by: Peter Tripp <peter@zed.dev>
This commit is contained in:
parent
93b1e95a5d
commit
0d809c21ba
4 changed files with 177 additions and 271 deletions
1
Cargo.lock
generated
1
Cargo.lock
generated
|
@ -1911,7 +1911,6 @@ dependencies = [
|
|||
"serde_json",
|
||||
"strum 0.27.1",
|
||||
"thiserror 2.0.12",
|
||||
"tokio",
|
||||
"workspace-hack",
|
||||
]
|
||||
|
||||
|
|
|
@ -25,5 +25,4 @@ serde.workspace = true
|
|||
serde_json.workspace = true
|
||||
strum.workspace = true
|
||||
thiserror.workspace = true
|
||||
tokio = { workspace = true, features = ["rt", "rt-multi-thread"] }
|
||||
workspace-hack.workspace = true
|
||||
|
|
|
@ -1,9 +1,6 @@
|
|||
mod models;
|
||||
|
||||
use std::collections::HashMap;
|
||||
use std::pin::Pin;
|
||||
|
||||
use anyhow::{Context as _, Error, Result, anyhow};
|
||||
use anyhow::{Context, Error, Result, anyhow};
|
||||
use aws_sdk_bedrockruntime as bedrock;
|
||||
pub use aws_sdk_bedrockruntime as bedrock_client;
|
||||
pub use aws_sdk_bedrockruntime::types::{
|
||||
|
@ -24,9 +21,10 @@ pub use bedrock::types::{
|
|||
ToolResultContentBlock as BedrockToolResultContentBlock,
|
||||
ToolResultStatus as BedrockToolResultStatus, ToolUseBlock as BedrockToolUseBlock,
|
||||
};
|
||||
use futures::stream::{self, BoxStream, Stream};
|
||||
use futures::stream::{self, BoxStream};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::{Number, Value};
|
||||
use std::collections::HashMap;
|
||||
use thiserror::Error;
|
||||
|
||||
pub use crate::models::*;
|
||||
|
@ -34,70 +32,59 @@ pub use crate::models::*;
|
|||
pub async fn stream_completion(
|
||||
client: bedrock::Client,
|
||||
request: Request,
|
||||
handle: tokio::runtime::Handle,
|
||||
) -> Result<BoxStream<'static, Result<BedrockStreamingResponse, BedrockError>>, Error> {
|
||||
handle
|
||||
.spawn(async move {
|
||||
let mut response = bedrock::Client::converse_stream(&client)
|
||||
.model_id(request.model.clone())
|
||||
.set_messages(request.messages.into());
|
||||
let mut response = bedrock::Client::converse_stream(&client)
|
||||
.model_id(request.model.clone())
|
||||
.set_messages(request.messages.into());
|
||||
|
||||
if let Some(Thinking::Enabled {
|
||||
budget_tokens: Some(budget_tokens),
|
||||
}) = request.thinking
|
||||
{
|
||||
response =
|
||||
response.additional_model_request_fields(Document::Object(HashMap::from([(
|
||||
"thinking".to_string(),
|
||||
Document::from(HashMap::from([
|
||||
("type".to_string(), Document::String("enabled".to_string())),
|
||||
(
|
||||
"budget_tokens".to_string(),
|
||||
Document::Number(AwsNumber::PosInt(budget_tokens)),
|
||||
),
|
||||
])),
|
||||
)])));
|
||||
}
|
||||
if let Some(Thinking::Enabled {
|
||||
budget_tokens: Some(budget_tokens),
|
||||
}) = request.thinking
|
||||
{
|
||||
let thinking_config = HashMap::from([
|
||||
("type".to_string(), Document::String("enabled".to_string())),
|
||||
(
|
||||
"budget_tokens".to_string(),
|
||||
Document::Number(AwsNumber::PosInt(budget_tokens)),
|
||||
),
|
||||
]);
|
||||
response = response.additional_model_request_fields(Document::Object(HashMap::from([(
|
||||
"thinking".to_string(),
|
||||
Document::from(thinking_config),
|
||||
)])));
|
||||
}
|
||||
|
||||
if request.tools.is_some() && !request.tools.as_ref().unwrap().tools.is_empty() {
|
||||
response = response.set_tool_config(request.tools);
|
||||
}
|
||||
if request
|
||||
.tools
|
||||
.as_ref()
|
||||
.map_or(false, |t| !t.tools.is_empty())
|
||||
{
|
||||
response = response.set_tool_config(request.tools);
|
||||
}
|
||||
|
||||
let response = response.send().await;
|
||||
let output = response
|
||||
.send()
|
||||
.await
|
||||
.context("Failed to send API request to Bedrock");
|
||||
|
||||
match response {
|
||||
Ok(output) => {
|
||||
let stream: Pin<
|
||||
Box<
|
||||
dyn Stream<Item = Result<BedrockStreamingResponse, BedrockError>>
|
||||
+ Send,
|
||||
>,
|
||||
> = Box::pin(stream::unfold(output.stream, |mut stream| async move {
|
||||
match stream.recv().await {
|
||||
Ok(Some(output)) => Some(({ Ok(output) }, stream)),
|
||||
Ok(None) => None,
|
||||
Err(err) => {
|
||||
Some((
|
||||
// TODO: Figure out how we can capture Throttling Exceptions
|
||||
Err(BedrockError::ClientError(anyhow!(
|
||||
"{:?}",
|
||||
aws_sdk_bedrockruntime::error::DisplayErrorContext(err)
|
||||
))),
|
||||
stream,
|
||||
))
|
||||
}
|
||||
}
|
||||
}));
|
||||
Ok(stream)
|
||||
}
|
||||
Err(err) => Err(anyhow!(
|
||||
"{:?}",
|
||||
aws_sdk_bedrockruntime::error::DisplayErrorContext(err)
|
||||
let stream = Box::pin(stream::unfold(
|
||||
output?.stream,
|
||||
move |mut stream| async move {
|
||||
match stream.recv().await {
|
||||
Ok(Some(output)) => Some((Ok(output), stream)),
|
||||
Ok(None) => None,
|
||||
Err(err) => Some((
|
||||
Err(BedrockError::ClientError(anyhow!(
|
||||
"{:?}",
|
||||
aws_sdk_bedrockruntime::error::DisplayErrorContext(err)
|
||||
))),
|
||||
stream,
|
||||
)),
|
||||
}
|
||||
})
|
||||
.await
|
||||
.context("spawning a task")?
|
||||
},
|
||||
));
|
||||
|
||||
Ok(stream)
|
||||
}
|
||||
|
||||
pub fn aws_document_to_value(document: &Document) -> Value {
|
||||
|
|
|
@ -46,7 +46,6 @@ use settings::{Settings, SettingsStore};
|
|||
use smol::lock::OnceCell;
|
||||
use strum::{EnumIter, IntoEnumIterator, IntoStaticStr};
|
||||
use theme::ThemeSettings;
|
||||
use tokio::runtime::Handle;
|
||||
use ui::{Icon, IconName, List, Tooltip, prelude::*};
|
||||
use util::ResultExt;
|
||||
|
||||
|
@ -460,22 +459,22 @@ impl BedrockModel {
|
|||
&self,
|
||||
request: bedrock::Request,
|
||||
cx: &AsyncApp,
|
||||
) -> Result<
|
||||
BoxFuture<'static, BoxStream<'static, Result<BedrockStreamingResponse, BedrockError>>>,
|
||||
) -> BoxFuture<
|
||||
'static,
|
||||
Result<BoxStream<'static, Result<BedrockStreamingResponse, BedrockError>>>,
|
||||
> {
|
||||
let runtime_client = self
|
||||
.get_or_init_client(cx)
|
||||
let Ok(runtime_client) = self
|
||||
.get_or_init_client(&cx)
|
||||
.cloned()
|
||||
.context("Bedrock client not initialized")?;
|
||||
let owned_handle = self.handler.clone();
|
||||
.context("Bedrock client not initialized")
|
||||
else {
|
||||
return futures::future::ready(Err(anyhow!("App state dropped"))).boxed();
|
||||
};
|
||||
|
||||
Ok(async move {
|
||||
let request = bedrock::stream_completion(runtime_client, request, owned_handle);
|
||||
request.await.unwrap_or_else(|e| {
|
||||
futures::stream::once(async move { Err(BedrockError::ClientError(e)) }).boxed()
|
||||
})
|
||||
match Tokio::spawn(cx, bedrock::stream_completion(runtime_client, request)) {
|
||||
Ok(res) => async { res.await.map_err(|err| anyhow!(err))? }.boxed(),
|
||||
Err(err) => futures::future::ready(Err(anyhow!(err))).boxed(),
|
||||
}
|
||||
.boxed())
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -570,12 +569,10 @@ impl LanguageModel for BedrockModel {
|
|||
Err(err) => return futures::future::ready(Err(err.into())).boxed(),
|
||||
};
|
||||
|
||||
let owned_handle = self.handler.clone();
|
||||
|
||||
let request = self.stream_completion(request, cx);
|
||||
let future = self.request_limiter.stream(async move {
|
||||
let response = request.map_err(|err| anyhow!(err))?.await;
|
||||
let events = map_to_language_model_completion_events(response, owned_handle);
|
||||
let response = request.await.map_err(|err| anyhow!(err))?;
|
||||
let events = map_to_language_model_completion_events(response);
|
||||
|
||||
if deny_tool_calls {
|
||||
Ok(deny_tool_use_events(events).boxed())
|
||||
|
@ -879,7 +876,6 @@ pub fn get_bedrock_tokens(
|
|||
|
||||
pub fn map_to_language_model_completion_events(
|
||||
events: Pin<Box<dyn Send + Stream<Item = Result<BedrockStreamingResponse, BedrockError>>>>,
|
||||
handle: Handle,
|
||||
) -> impl Stream<Item = Result<LanguageModelCompletionEvent, LanguageModelCompletionError>> {
|
||||
struct RawToolUse {
|
||||
id: String,
|
||||
|
@ -892,198 +888,123 @@ pub fn map_to_language_model_completion_events(
|
|||
tool_uses_by_index: HashMap<i32, RawToolUse>,
|
||||
}
|
||||
|
||||
futures::stream::unfold(
|
||||
State {
|
||||
events,
|
||||
tool_uses_by_index: HashMap::default(),
|
||||
},
|
||||
move |mut state: State| {
|
||||
let inner_handle = handle.clone();
|
||||
async move {
|
||||
inner_handle
|
||||
.spawn(async {
|
||||
while let Some(event) = state.events.next().await {
|
||||
match event {
|
||||
Ok(event) => match event {
|
||||
ConverseStreamOutput::ContentBlockDelta(cb_delta) => {
|
||||
match cb_delta.delta {
|
||||
Some(ContentBlockDelta::Text(text_out)) => {
|
||||
let completion_event =
|
||||
LanguageModelCompletionEvent::Text(text_out);
|
||||
return Some((Some(Ok(completion_event)), state));
|
||||
}
|
||||
let initial_state = State {
|
||||
events,
|
||||
tool_uses_by_index: HashMap::default(),
|
||||
};
|
||||
|
||||
Some(ContentBlockDelta::ToolUse(text_out)) => {
|
||||
if let Some(tool_use) = state
|
||||
.tool_uses_by_index
|
||||
.get_mut(&cb_delta.content_block_index)
|
||||
{
|
||||
tool_use.input_json.push_str(text_out.input());
|
||||
}
|
||||
}
|
||||
|
||||
Some(ContentBlockDelta::ReasoningContent(thinking)) => {
|
||||
match thinking {
|
||||
ReasoningContentBlockDelta::RedactedContent(
|
||||
redacted,
|
||||
) => {
|
||||
let thinking_event =
|
||||
LanguageModelCompletionEvent::Thinking {
|
||||
text: String::from_utf8(
|
||||
redacted.into_inner(),
|
||||
)
|
||||
.unwrap_or("REDACTED".to_string()),
|
||||
signature: None,
|
||||
};
|
||||
|
||||
return Some((
|
||||
Some(Ok(thinking_event)),
|
||||
state,
|
||||
));
|
||||
}
|
||||
ReasoningContentBlockDelta::Signature(
|
||||
signature,
|
||||
) => {
|
||||
return Some((
|
||||
Some(Ok(LanguageModelCompletionEvent::Thinking {
|
||||
text: "".to_string(),
|
||||
signature: Some(signature)
|
||||
})),
|
||||
state,
|
||||
));
|
||||
}
|
||||
ReasoningContentBlockDelta::Text(thoughts) => {
|
||||
let thinking_event =
|
||||
LanguageModelCompletionEvent::Thinking {
|
||||
text: thoughts.to_string(),
|
||||
signature: None
|
||||
};
|
||||
|
||||
return Some((
|
||||
Some(Ok(thinking_event)),
|
||||
state,
|
||||
));
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
ConverseStreamOutput::ContentBlockStart(cb_start) => {
|
||||
if let Some(ContentBlockStart::ToolUse(text_out)) =
|
||||
cb_start.start
|
||||
{
|
||||
let tool_use = RawToolUse {
|
||||
id: text_out.tool_use_id,
|
||||
name: text_out.name,
|
||||
input_json: String::new(),
|
||||
};
|
||||
|
||||
state
|
||||
.tool_uses_by_index
|
||||
.insert(cb_start.content_block_index, tool_use);
|
||||
}
|
||||
}
|
||||
ConverseStreamOutput::ContentBlockStop(cb_stop) => {
|
||||
if let Some(tool_use) = state
|
||||
.tool_uses_by_index
|
||||
.remove(&cb_stop.content_block_index)
|
||||
{
|
||||
let tool_use_event = LanguageModelToolUse {
|
||||
id: tool_use.id.into(),
|
||||
name: tool_use.name.into(),
|
||||
is_input_complete: true,
|
||||
raw_input: tool_use.input_json.clone(),
|
||||
input: if tool_use.input_json.is_empty() {
|
||||
Value::Null
|
||||
} else {
|
||||
serde_json::Value::from_str(
|
||||
&tool_use.input_json,
|
||||
)
|
||||
.map_err(|err| anyhow!(err))
|
||||
.unwrap()
|
||||
},
|
||||
};
|
||||
|
||||
return Some((
|
||||
Some(Ok(LanguageModelCompletionEvent::ToolUse(
|
||||
tool_use_event,
|
||||
))),
|
||||
state,
|
||||
));
|
||||
}
|
||||
}
|
||||
|
||||
ConverseStreamOutput::Metadata(cb_meta) => {
|
||||
if let Some(metadata) = cb_meta.usage {
|
||||
let completion_event =
|
||||
LanguageModelCompletionEvent::UsageUpdate(
|
||||
TokenUsage {
|
||||
input_tokens: metadata.input_tokens as u64,
|
||||
output_tokens: metadata.output_tokens as u64,
|
||||
cache_creation_input_tokens:
|
||||
metadata.cache_write_input_tokens.unwrap_or_default() as u64,
|
||||
cache_read_input_tokens:
|
||||
metadata.cache_read_input_tokens.unwrap_or_default() as u64,
|
||||
},
|
||||
);
|
||||
return Some((Some(Ok(completion_event)), state));
|
||||
}
|
||||
}
|
||||
ConverseStreamOutput::MessageStop(message_stop) => {
|
||||
let reason = match message_stop.stop_reason {
|
||||
StopReason::ContentFiltered => {
|
||||
LanguageModelCompletionEvent::Stop(
|
||||
language_model::StopReason::EndTurn,
|
||||
)
|
||||
}
|
||||
StopReason::EndTurn => {
|
||||
LanguageModelCompletionEvent::Stop(
|
||||
language_model::StopReason::EndTurn,
|
||||
)
|
||||
}
|
||||
StopReason::GuardrailIntervened => {
|
||||
LanguageModelCompletionEvent::Stop(
|
||||
language_model::StopReason::EndTurn,
|
||||
)
|
||||
}
|
||||
StopReason::MaxTokens => {
|
||||
LanguageModelCompletionEvent::Stop(
|
||||
language_model::StopReason::EndTurn,
|
||||
)
|
||||
}
|
||||
StopReason::StopSequence => {
|
||||
LanguageModelCompletionEvent::Stop(
|
||||
language_model::StopReason::EndTurn,
|
||||
)
|
||||
}
|
||||
StopReason::ToolUse => {
|
||||
LanguageModelCompletionEvent::Stop(
|
||||
language_model::StopReason::ToolUse,
|
||||
)
|
||||
}
|
||||
_ => LanguageModelCompletionEvent::Stop(
|
||||
language_model::StopReason::EndTurn,
|
||||
),
|
||||
};
|
||||
return Some((Some(Ok(reason)), state));
|
||||
}
|
||||
_ => {}
|
||||
},
|
||||
|
||||
Err(err) => return Some((Some(Err(anyhow!(err).into())), state)),
|
||||
futures::stream::unfold(initial_state, |mut state| async move {
|
||||
match state.events.next().await {
|
||||
Some(event_result) => match event_result {
|
||||
Ok(event) => {
|
||||
let result = match event {
|
||||
ConverseStreamOutput::ContentBlockDelta(cb_delta) => match cb_delta.delta {
|
||||
Some(ContentBlockDelta::Text(text)) => {
|
||||
Some(Ok(LanguageModelCompletionEvent::Text(text)))
|
||||
}
|
||||
Some(ContentBlockDelta::ToolUse(tool_output)) => {
|
||||
if let Some(tool_use) = state
|
||||
.tool_uses_by_index
|
||||
.get_mut(&cb_delta.content_block_index)
|
||||
{
|
||||
tool_use.input_json.push_str(tool_output.input());
|
||||
}
|
||||
None
|
||||
}
|
||||
Some(ContentBlockDelta::ReasoningContent(thinking)) => match thinking {
|
||||
ReasoningContentBlockDelta::Text(thoughts) => {
|
||||
Some(Ok(LanguageModelCompletionEvent::Thinking {
|
||||
text: thoughts.clone(),
|
||||
signature: None,
|
||||
}))
|
||||
}
|
||||
ReasoningContentBlockDelta::Signature(sig) => {
|
||||
Some(Ok(LanguageModelCompletionEvent::Thinking {
|
||||
text: "".into(),
|
||||
signature: Some(sig),
|
||||
}))
|
||||
}
|
||||
ReasoningContentBlockDelta::RedactedContent(redacted) => {
|
||||
let content = String::from_utf8(redacted.into_inner())
|
||||
.unwrap_or("REDACTED".to_string());
|
||||
Some(Ok(LanguageModelCompletionEvent::Thinking {
|
||||
text: content,
|
||||
signature: None,
|
||||
}))
|
||||
}
|
||||
_ => None,
|
||||
},
|
||||
_ => None,
|
||||
},
|
||||
ConverseStreamOutput::ContentBlockStart(cb_start) => {
|
||||
if let Some(ContentBlockStart::ToolUse(tool_start)) = cb_start.start {
|
||||
state.tool_uses_by_index.insert(
|
||||
cb_start.content_block_index,
|
||||
RawToolUse {
|
||||
id: tool_start.tool_use_id,
|
||||
name: tool_start.name,
|
||||
input_json: String::new(),
|
||||
},
|
||||
);
|
||||
}
|
||||
None
|
||||
}
|
||||
None
|
||||
})
|
||||
.await
|
||||
.log_err()
|
||||
.flatten()
|
||||
}
|
||||
},
|
||||
)
|
||||
.filter_map(|event| async move { event })
|
||||
ConverseStreamOutput::ContentBlockStop(cb_stop) => state
|
||||
.tool_uses_by_index
|
||||
.remove(&cb_stop.content_block_index)
|
||||
.map(|tool_use| {
|
||||
let input = if tool_use.input_json.is_empty() {
|
||||
Value::Null
|
||||
} else {
|
||||
serde_json::Value::from_str(&tool_use.input_json)
|
||||
.unwrap_or(Value::Null)
|
||||
};
|
||||
|
||||
Ok(LanguageModelCompletionEvent::ToolUse(
|
||||
LanguageModelToolUse {
|
||||
id: tool_use.id.into(),
|
||||
name: tool_use.name.into(),
|
||||
is_input_complete: true,
|
||||
raw_input: tool_use.input_json.clone(),
|
||||
input,
|
||||
},
|
||||
))
|
||||
}),
|
||||
ConverseStreamOutput::Metadata(cb_meta) => cb_meta.usage.map(|metadata| {
|
||||
Ok(LanguageModelCompletionEvent::UsageUpdate(TokenUsage {
|
||||
input_tokens: metadata.input_tokens as u64,
|
||||
output_tokens: metadata.output_tokens as u64,
|
||||
cache_creation_input_tokens: metadata
|
||||
.cache_write_input_tokens
|
||||
.unwrap_or_default()
|
||||
as u64,
|
||||
cache_read_input_tokens: metadata
|
||||
.cache_read_input_tokens
|
||||
.unwrap_or_default()
|
||||
as u64,
|
||||
}))
|
||||
}),
|
||||
ConverseStreamOutput::MessageStop(message_stop) => {
|
||||
let stop_reason = match message_stop.stop_reason {
|
||||
StopReason::ToolUse => language_model::StopReason::ToolUse,
|
||||
_ => language_model::StopReason::EndTurn,
|
||||
};
|
||||
Some(Ok(LanguageModelCompletionEvent::Stop(stop_reason)))
|
||||
}
|
||||
_ => None,
|
||||
};
|
||||
|
||||
Some((result, state))
|
||||
}
|
||||
Err(err) => Some((
|
||||
Some(Err(LanguageModelCompletionError::Other(anyhow!(err)))),
|
||||
state,
|
||||
)),
|
||||
},
|
||||
None => None,
|
||||
}
|
||||
})
|
||||
.filter_map(|result| async move { result })
|
||||
}
|
||||
|
||||
struct ConfigurationView {
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue