bedrock: Fix bedrock not streaming (#28281)

Closes #26030 

Release Notes:

- Fixed a Bedrock bug that caused streaming responses to be returned as
one big chunk

---------

Co-authored-by: Peter Tripp <peter@zed.dev>
This commit is contained in:
Shardul Vaidya 2025-07-01 05:51:09 -04:00 committed by GitHub
parent 93b1e95a5d
commit 0d809c21ba
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 177 additions and 271 deletions

1
Cargo.lock generated
View file

@ -1911,7 +1911,6 @@ dependencies = [
"serde_json", "serde_json",
"strum 0.27.1", "strum 0.27.1",
"thiserror 2.0.12", "thiserror 2.0.12",
"tokio",
"workspace-hack", "workspace-hack",
] ]

View file

@ -25,5 +25,4 @@ serde.workspace = true
serde_json.workspace = true serde_json.workspace = true
strum.workspace = true strum.workspace = true
thiserror.workspace = true thiserror.workspace = true
tokio = { workspace = true, features = ["rt", "rt-multi-thread"] }
workspace-hack.workspace = true workspace-hack.workspace = true

View file

@ -1,9 +1,6 @@
mod models; mod models;
use std::collections::HashMap; use anyhow::{Context, Error, Result, anyhow};
use std::pin::Pin;
use anyhow::{Context as _, Error, Result, anyhow};
use aws_sdk_bedrockruntime as bedrock; use aws_sdk_bedrockruntime as bedrock;
pub use aws_sdk_bedrockruntime as bedrock_client; pub use aws_sdk_bedrockruntime as bedrock_client;
pub use aws_sdk_bedrockruntime::types::{ pub use aws_sdk_bedrockruntime::types::{
@ -24,9 +21,10 @@ pub use bedrock::types::{
ToolResultContentBlock as BedrockToolResultContentBlock, ToolResultContentBlock as BedrockToolResultContentBlock,
ToolResultStatus as BedrockToolResultStatus, ToolUseBlock as BedrockToolUseBlock, ToolResultStatus as BedrockToolResultStatus, ToolUseBlock as BedrockToolUseBlock,
}; };
use futures::stream::{self, BoxStream, Stream}; use futures::stream::{self, BoxStream};
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use serde_json::{Number, Value}; use serde_json::{Number, Value};
use std::collections::HashMap;
use thiserror::Error; use thiserror::Error;
pub use crate::models::*; pub use crate::models::*;
@ -34,70 +32,59 @@ pub use crate::models::*;
pub async fn stream_completion( pub async fn stream_completion(
client: bedrock::Client, client: bedrock::Client,
request: Request, request: Request,
handle: tokio::runtime::Handle,
) -> Result<BoxStream<'static, Result<BedrockStreamingResponse, BedrockError>>, Error> { ) -> Result<BoxStream<'static, Result<BedrockStreamingResponse, BedrockError>>, Error> {
handle let mut response = bedrock::Client::converse_stream(&client)
.spawn(async move { .model_id(request.model.clone())
let mut response = bedrock::Client::converse_stream(&client) .set_messages(request.messages.into());
.model_id(request.model.clone())
.set_messages(request.messages.into());
if let Some(Thinking::Enabled { if let Some(Thinking::Enabled {
budget_tokens: Some(budget_tokens), budget_tokens: Some(budget_tokens),
}) = request.thinking }) = request.thinking
{ {
response = let thinking_config = HashMap::from([
response.additional_model_request_fields(Document::Object(HashMap::from([( ("type".to_string(), Document::String("enabled".to_string())),
"thinking".to_string(), (
Document::from(HashMap::from([ "budget_tokens".to_string(),
("type".to_string(), Document::String("enabled".to_string())), Document::Number(AwsNumber::PosInt(budget_tokens)),
( ),
"budget_tokens".to_string(), ]);
Document::Number(AwsNumber::PosInt(budget_tokens)), response = response.additional_model_request_fields(Document::Object(HashMap::from([(
), "thinking".to_string(),
])), Document::from(thinking_config),
)]))); )])));
} }
if request.tools.is_some() && !request.tools.as_ref().unwrap().tools.is_empty() { if request
response = response.set_tool_config(request.tools); .tools
} .as_ref()
.map_or(false, |t| !t.tools.is_empty())
{
response = response.set_tool_config(request.tools);
}
let response = response.send().await; let output = response
.send()
.await
.context("Failed to send API request to Bedrock");
match response { let stream = Box::pin(stream::unfold(
Ok(output) => { output?.stream,
let stream: Pin< move |mut stream| async move {
Box< match stream.recv().await {
dyn Stream<Item = Result<BedrockStreamingResponse, BedrockError>> Ok(Some(output)) => Some((Ok(output), stream)),
+ Send, Ok(None) => None,
>, Err(err) => Some((
> = Box::pin(stream::unfold(output.stream, |mut stream| async move { Err(BedrockError::ClientError(anyhow!(
match stream.recv().await { "{:?}",
Ok(Some(output)) => Some(({ Ok(output) }, stream)), aws_sdk_bedrockruntime::error::DisplayErrorContext(err)
Ok(None) => None, ))),
Err(err) => { stream,
Some((
// TODO: Figure out how we can capture Throttling Exceptions
Err(BedrockError::ClientError(anyhow!(
"{:?}",
aws_sdk_bedrockruntime::error::DisplayErrorContext(err)
))),
stream,
))
}
}
}));
Ok(stream)
}
Err(err) => Err(anyhow!(
"{:?}",
aws_sdk_bedrockruntime::error::DisplayErrorContext(err)
)), )),
} }
}) },
.await ));
.context("spawning a task")?
Ok(stream)
} }
pub fn aws_document_to_value(document: &Document) -> Value { pub fn aws_document_to_value(document: &Document) -> Value {

View file

@ -46,7 +46,6 @@ use settings::{Settings, SettingsStore};
use smol::lock::OnceCell; use smol::lock::OnceCell;
use strum::{EnumIter, IntoEnumIterator, IntoStaticStr}; use strum::{EnumIter, IntoEnumIterator, IntoStaticStr};
use theme::ThemeSettings; use theme::ThemeSettings;
use tokio::runtime::Handle;
use ui::{Icon, IconName, List, Tooltip, prelude::*}; use ui::{Icon, IconName, List, Tooltip, prelude::*};
use util::ResultExt; use util::ResultExt;
@ -460,22 +459,22 @@ impl BedrockModel {
&self, &self,
request: bedrock::Request, request: bedrock::Request,
cx: &AsyncApp, cx: &AsyncApp,
) -> Result< ) -> BoxFuture<
BoxFuture<'static, BoxStream<'static, Result<BedrockStreamingResponse, BedrockError>>>, 'static,
Result<BoxStream<'static, Result<BedrockStreamingResponse, BedrockError>>>,
> { > {
let runtime_client = self let Ok(runtime_client) = self
.get_or_init_client(cx) .get_or_init_client(&cx)
.cloned() .cloned()
.context("Bedrock client not initialized")?; .context("Bedrock client not initialized")
let owned_handle = self.handler.clone(); else {
return futures::future::ready(Err(anyhow!("App state dropped"))).boxed();
};
Ok(async move { match Tokio::spawn(cx, bedrock::stream_completion(runtime_client, request)) {
let request = bedrock::stream_completion(runtime_client, request, owned_handle); Ok(res) => async { res.await.map_err(|err| anyhow!(err))? }.boxed(),
request.await.unwrap_or_else(|e| { Err(err) => futures::future::ready(Err(anyhow!(err))).boxed(),
futures::stream::once(async move { Err(BedrockError::ClientError(e)) }).boxed()
})
} }
.boxed())
} }
} }
@ -570,12 +569,10 @@ impl LanguageModel for BedrockModel {
Err(err) => return futures::future::ready(Err(err.into())).boxed(), Err(err) => return futures::future::ready(Err(err.into())).boxed(),
}; };
let owned_handle = self.handler.clone();
let request = self.stream_completion(request, cx); let request = self.stream_completion(request, cx);
let future = self.request_limiter.stream(async move { let future = self.request_limiter.stream(async move {
let response = request.map_err(|err| anyhow!(err))?.await; let response = request.await.map_err(|err| anyhow!(err))?;
let events = map_to_language_model_completion_events(response, owned_handle); let events = map_to_language_model_completion_events(response);
if deny_tool_calls { if deny_tool_calls {
Ok(deny_tool_use_events(events).boxed()) Ok(deny_tool_use_events(events).boxed())
@ -879,7 +876,6 @@ pub fn get_bedrock_tokens(
pub fn map_to_language_model_completion_events( pub fn map_to_language_model_completion_events(
events: Pin<Box<dyn Send + Stream<Item = Result<BedrockStreamingResponse, BedrockError>>>>, events: Pin<Box<dyn Send + Stream<Item = Result<BedrockStreamingResponse, BedrockError>>>>,
handle: Handle,
) -> impl Stream<Item = Result<LanguageModelCompletionEvent, LanguageModelCompletionError>> { ) -> impl Stream<Item = Result<LanguageModelCompletionEvent, LanguageModelCompletionError>> {
struct RawToolUse { struct RawToolUse {
id: String, id: String,
@ -892,198 +888,123 @@ pub fn map_to_language_model_completion_events(
tool_uses_by_index: HashMap<i32, RawToolUse>, tool_uses_by_index: HashMap<i32, RawToolUse>,
} }
futures::stream::unfold( let initial_state = State {
State { events,
events, tool_uses_by_index: HashMap::default(),
tool_uses_by_index: HashMap::default(), };
},
move |mut state: State| {
let inner_handle = handle.clone();
async move {
inner_handle
.spawn(async {
while let Some(event) = state.events.next().await {
match event {
Ok(event) => match event {
ConverseStreamOutput::ContentBlockDelta(cb_delta) => {
match cb_delta.delta {
Some(ContentBlockDelta::Text(text_out)) => {
let completion_event =
LanguageModelCompletionEvent::Text(text_out);
return Some((Some(Ok(completion_event)), state));
}
Some(ContentBlockDelta::ToolUse(text_out)) => { futures::stream::unfold(initial_state, |mut state| async move {
if let Some(tool_use) = state match state.events.next().await {
.tool_uses_by_index Some(event_result) => match event_result {
.get_mut(&cb_delta.content_block_index) Ok(event) => {
{ let result = match event {
tool_use.input_json.push_str(text_out.input()); ConverseStreamOutput::ContentBlockDelta(cb_delta) => match cb_delta.delta {
} Some(ContentBlockDelta::Text(text)) => {
} Some(Ok(LanguageModelCompletionEvent::Text(text)))
Some(ContentBlockDelta::ReasoningContent(thinking)) => {
match thinking {
ReasoningContentBlockDelta::RedactedContent(
redacted,
) => {
let thinking_event =
LanguageModelCompletionEvent::Thinking {
text: String::from_utf8(
redacted.into_inner(),
)
.unwrap_or("REDACTED".to_string()),
signature: None,
};
return Some((
Some(Ok(thinking_event)),
state,
));
}
ReasoningContentBlockDelta::Signature(
signature,
) => {
return Some((
Some(Ok(LanguageModelCompletionEvent::Thinking {
text: "".to_string(),
signature: Some(signature)
})),
state,
));
}
ReasoningContentBlockDelta::Text(thoughts) => {
let thinking_event =
LanguageModelCompletionEvent::Thinking {
text: thoughts.to_string(),
signature: None
};
return Some((
Some(Ok(thinking_event)),
state,
));
}
_ => {}
}
}
_ => {}
}
}
ConverseStreamOutput::ContentBlockStart(cb_start) => {
if let Some(ContentBlockStart::ToolUse(text_out)) =
cb_start.start
{
let tool_use = RawToolUse {
id: text_out.tool_use_id,
name: text_out.name,
input_json: String::new(),
};
state
.tool_uses_by_index
.insert(cb_start.content_block_index, tool_use);
}
}
ConverseStreamOutput::ContentBlockStop(cb_stop) => {
if let Some(tool_use) = state
.tool_uses_by_index
.remove(&cb_stop.content_block_index)
{
let tool_use_event = LanguageModelToolUse {
id: tool_use.id.into(),
name: tool_use.name.into(),
is_input_complete: true,
raw_input: tool_use.input_json.clone(),
input: if tool_use.input_json.is_empty() {
Value::Null
} else {
serde_json::Value::from_str(
&tool_use.input_json,
)
.map_err(|err| anyhow!(err))
.unwrap()
},
};
return Some((
Some(Ok(LanguageModelCompletionEvent::ToolUse(
tool_use_event,
))),
state,
));
}
}
ConverseStreamOutput::Metadata(cb_meta) => {
if let Some(metadata) = cb_meta.usage {
let completion_event =
LanguageModelCompletionEvent::UsageUpdate(
TokenUsage {
input_tokens: metadata.input_tokens as u64,
output_tokens: metadata.output_tokens as u64,
cache_creation_input_tokens:
metadata.cache_write_input_tokens.unwrap_or_default() as u64,
cache_read_input_tokens:
metadata.cache_read_input_tokens.unwrap_or_default() as u64,
},
);
return Some((Some(Ok(completion_event)), state));
}
}
ConverseStreamOutput::MessageStop(message_stop) => {
let reason = match message_stop.stop_reason {
StopReason::ContentFiltered => {
LanguageModelCompletionEvent::Stop(
language_model::StopReason::EndTurn,
)
}
StopReason::EndTurn => {
LanguageModelCompletionEvent::Stop(
language_model::StopReason::EndTurn,
)
}
StopReason::GuardrailIntervened => {
LanguageModelCompletionEvent::Stop(
language_model::StopReason::EndTurn,
)
}
StopReason::MaxTokens => {
LanguageModelCompletionEvent::Stop(
language_model::StopReason::EndTurn,
)
}
StopReason::StopSequence => {
LanguageModelCompletionEvent::Stop(
language_model::StopReason::EndTurn,
)
}
StopReason::ToolUse => {
LanguageModelCompletionEvent::Stop(
language_model::StopReason::ToolUse,
)
}
_ => LanguageModelCompletionEvent::Stop(
language_model::StopReason::EndTurn,
),
};
return Some((Some(Ok(reason)), state));
}
_ => {}
},
Err(err) => return Some((Some(Err(anyhow!(err).into())), state)),
} }
Some(ContentBlockDelta::ToolUse(tool_output)) => {
if let Some(tool_use) = state
.tool_uses_by_index
.get_mut(&cb_delta.content_block_index)
{
tool_use.input_json.push_str(tool_output.input());
}
None
}
Some(ContentBlockDelta::ReasoningContent(thinking)) => match thinking {
ReasoningContentBlockDelta::Text(thoughts) => {
Some(Ok(LanguageModelCompletionEvent::Thinking {
text: thoughts.clone(),
signature: None,
}))
}
ReasoningContentBlockDelta::Signature(sig) => {
Some(Ok(LanguageModelCompletionEvent::Thinking {
text: "".into(),
signature: Some(sig),
}))
}
ReasoningContentBlockDelta::RedactedContent(redacted) => {
let content = String::from_utf8(redacted.into_inner())
.unwrap_or("REDACTED".to_string());
Some(Ok(LanguageModelCompletionEvent::Thinking {
text: content,
signature: None,
}))
}
_ => None,
},
_ => None,
},
ConverseStreamOutput::ContentBlockStart(cb_start) => {
if let Some(ContentBlockStart::ToolUse(tool_start)) = cb_start.start {
state.tool_uses_by_index.insert(
cb_start.content_block_index,
RawToolUse {
id: tool_start.tool_use_id,
name: tool_start.name,
input_json: String::new(),
},
);
}
None
} }
None ConverseStreamOutput::ContentBlockStop(cb_stop) => state
}) .tool_uses_by_index
.await .remove(&cb_stop.content_block_index)
.log_err() .map(|tool_use| {
.flatten() let input = if tool_use.input_json.is_empty() {
} Value::Null
}, } else {
) serde_json::Value::from_str(&tool_use.input_json)
.filter_map(|event| async move { event }) .unwrap_or(Value::Null)
};
Ok(LanguageModelCompletionEvent::ToolUse(
LanguageModelToolUse {
id: tool_use.id.into(),
name: tool_use.name.into(),
is_input_complete: true,
raw_input: tool_use.input_json.clone(),
input,
},
))
}),
ConverseStreamOutput::Metadata(cb_meta) => cb_meta.usage.map(|metadata| {
Ok(LanguageModelCompletionEvent::UsageUpdate(TokenUsage {
input_tokens: metadata.input_tokens as u64,
output_tokens: metadata.output_tokens as u64,
cache_creation_input_tokens: metadata
.cache_write_input_tokens
.unwrap_or_default()
as u64,
cache_read_input_tokens: metadata
.cache_read_input_tokens
.unwrap_or_default()
as u64,
}))
}),
ConverseStreamOutput::MessageStop(message_stop) => {
let stop_reason = match message_stop.stop_reason {
StopReason::ToolUse => language_model::StopReason::ToolUse,
_ => language_model::StopReason::EndTurn,
};
Some(Ok(LanguageModelCompletionEvent::Stop(stop_reason)))
}
_ => None,
};
Some((result, state))
}
Err(err) => Some((
Some(Err(LanguageModelCompletionError::Other(anyhow!(err)))),
state,
)),
},
None => None,
}
})
.filter_map(|result| async move { result })
} }
struct ConfigurationView { struct ConfigurationView {