diff --git a/Cargo.lock b/Cargo.lock index 20ea6472b3..19c73433ed 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1911,7 +1911,6 @@ dependencies = [ "serde_json", "strum 0.27.1", "thiserror 2.0.12", - "tokio", "workspace-hack", ] diff --git a/crates/bedrock/Cargo.toml b/crates/bedrock/Cargo.toml index 84fd584601..3000af50bb 100644 --- a/crates/bedrock/Cargo.toml +++ b/crates/bedrock/Cargo.toml @@ -25,5 +25,4 @@ serde.workspace = true serde_json.workspace = true strum.workspace = true thiserror.workspace = true -tokio = { workspace = true, features = ["rt", "rt-multi-thread"] } workspace-hack.workspace = true diff --git a/crates/bedrock/src/bedrock.rs b/crates/bedrock/src/bedrock.rs index e32a456dba..1c6a9bd0a1 100644 --- a/crates/bedrock/src/bedrock.rs +++ b/crates/bedrock/src/bedrock.rs @@ -1,9 +1,6 @@ mod models; -use std::collections::HashMap; -use std::pin::Pin; - -use anyhow::{Context as _, Error, Result, anyhow}; +use anyhow::{Context, Error, Result, anyhow}; use aws_sdk_bedrockruntime as bedrock; pub use aws_sdk_bedrockruntime as bedrock_client; pub use aws_sdk_bedrockruntime::types::{ @@ -24,9 +21,10 @@ pub use bedrock::types::{ ToolResultContentBlock as BedrockToolResultContentBlock, ToolResultStatus as BedrockToolResultStatus, ToolUseBlock as BedrockToolUseBlock, }; -use futures::stream::{self, BoxStream, Stream}; +use futures::stream::{self, BoxStream}; use serde::{Deserialize, Serialize}; use serde_json::{Number, Value}; +use std::collections::HashMap; use thiserror::Error; pub use crate::models::*; @@ -34,70 +32,59 @@ pub use crate::models::*; pub async fn stream_completion( client: bedrock::Client, request: Request, - handle: tokio::runtime::Handle, ) -> Result>, Error> { - handle - .spawn(async move { - let mut response = bedrock::Client::converse_stream(&client) - .model_id(request.model.clone()) - .set_messages(request.messages.into()); + let mut response = bedrock::Client::converse_stream(&client) + .model_id(request.model.clone()) + .set_messages(request.messages.into()); - if let Some(Thinking::Enabled { - budget_tokens: Some(budget_tokens), - }) = request.thinking - { - response = - response.additional_model_request_fields(Document::Object(HashMap::from([( - "thinking".to_string(), - Document::from(HashMap::from([ - ("type".to_string(), Document::String("enabled".to_string())), - ( - "budget_tokens".to_string(), - Document::Number(AwsNumber::PosInt(budget_tokens)), - ), - ])), - )]))); - } + if let Some(Thinking::Enabled { + budget_tokens: Some(budget_tokens), + }) = request.thinking + { + let thinking_config = HashMap::from([ + ("type".to_string(), Document::String("enabled".to_string())), + ( + "budget_tokens".to_string(), + Document::Number(AwsNumber::PosInt(budget_tokens)), + ), + ]); + response = response.additional_model_request_fields(Document::Object(HashMap::from([( + "thinking".to_string(), + Document::from(thinking_config), + )]))); + } - if request.tools.is_some() && !request.tools.as_ref().unwrap().tools.is_empty() { - response = response.set_tool_config(request.tools); - } + if request + .tools + .as_ref() + .map_or(false, |t| !t.tools.is_empty()) + { + response = response.set_tool_config(request.tools); + } - let response = response.send().await; + let output = response + .send() + .await + .context("Failed to send API request to Bedrock"); - match response { - Ok(output) => { - let stream: Pin< - Box< - dyn Stream> - + Send, - >, - > = Box::pin(stream::unfold(output.stream, |mut stream| async move { - match stream.recv().await { - Ok(Some(output)) => Some(({ Ok(output) }, stream)), - Ok(None) => None, - Err(err) => { - Some(( - // TODO: Figure out how we can capture Throttling Exceptions - Err(BedrockError::ClientError(anyhow!( - "{:?}", - aws_sdk_bedrockruntime::error::DisplayErrorContext(err) - ))), - stream, - )) - } - } - })); - Ok(stream) - } - Err(err) => Err(anyhow!( - "{:?}", - aws_sdk_bedrockruntime::error::DisplayErrorContext(err) + let stream = Box::pin(stream::unfold( + output?.stream, + move |mut stream| async move { + match stream.recv().await { + Ok(Some(output)) => Some((Ok(output), stream)), + Ok(None) => None, + Err(err) => Some(( + Err(BedrockError::ClientError(anyhow!( + "{:?}", + aws_sdk_bedrockruntime::error::DisplayErrorContext(err) + ))), + stream, )), } - }) - .await - .context("spawning a task")? + }, + )); + + Ok(stream) } pub fn aws_document_to_value(document: &Document) -> Value { diff --git a/crates/language_models/src/provider/bedrock.rs b/crates/language_models/src/provider/bedrock.rs index dd19915f93..9c0d481607 100644 --- a/crates/language_models/src/provider/bedrock.rs +++ b/crates/language_models/src/provider/bedrock.rs @@ -46,7 +46,6 @@ use settings::{Settings, SettingsStore}; use smol::lock::OnceCell; use strum::{EnumIter, IntoEnumIterator, IntoStaticStr}; use theme::ThemeSettings; -use tokio::runtime::Handle; use ui::{Icon, IconName, List, Tooltip, prelude::*}; use util::ResultExt; @@ -460,22 +459,22 @@ impl BedrockModel { &self, request: bedrock::Request, cx: &AsyncApp, - ) -> Result< - BoxFuture<'static, BoxStream<'static, Result>>, + ) -> BoxFuture< + 'static, + Result>>, > { - let runtime_client = self - .get_or_init_client(cx) + let Ok(runtime_client) = self + .get_or_init_client(&cx) .cloned() - .context("Bedrock client not initialized")?; - let owned_handle = self.handler.clone(); + .context("Bedrock client not initialized") + else { + return futures::future::ready(Err(anyhow!("App state dropped"))).boxed(); + }; - Ok(async move { - let request = bedrock::stream_completion(runtime_client, request, owned_handle); - request.await.unwrap_or_else(|e| { - futures::stream::once(async move { Err(BedrockError::ClientError(e)) }).boxed() - }) + match Tokio::spawn(cx, bedrock::stream_completion(runtime_client, request)) { + Ok(res) => async { res.await.map_err(|err| anyhow!(err))? }.boxed(), + Err(err) => futures::future::ready(Err(anyhow!(err))).boxed(), } - .boxed()) } } @@ -570,12 +569,10 @@ impl LanguageModel for BedrockModel { Err(err) => return futures::future::ready(Err(err.into())).boxed(), }; - let owned_handle = self.handler.clone(); - let request = self.stream_completion(request, cx); let future = self.request_limiter.stream(async move { - let response = request.map_err(|err| anyhow!(err))?.await; - let events = map_to_language_model_completion_events(response, owned_handle); + let response = request.await.map_err(|err| anyhow!(err))?; + let events = map_to_language_model_completion_events(response); if deny_tool_calls { Ok(deny_tool_use_events(events).boxed()) @@ -879,7 +876,6 @@ pub fn get_bedrock_tokens( pub fn map_to_language_model_completion_events( events: Pin>>>, - handle: Handle, ) -> impl Stream> { struct RawToolUse { id: String, @@ -892,198 +888,123 @@ pub fn map_to_language_model_completion_events( tool_uses_by_index: HashMap, } - futures::stream::unfold( - State { - events, - tool_uses_by_index: HashMap::default(), - }, - move |mut state: State| { - let inner_handle = handle.clone(); - async move { - inner_handle - .spawn(async { - while let Some(event) = state.events.next().await { - match event { - Ok(event) => match event { - ConverseStreamOutput::ContentBlockDelta(cb_delta) => { - match cb_delta.delta { - Some(ContentBlockDelta::Text(text_out)) => { - let completion_event = - LanguageModelCompletionEvent::Text(text_out); - return Some((Some(Ok(completion_event)), state)); - } + let initial_state = State { + events, + tool_uses_by_index: HashMap::default(), + }; - Some(ContentBlockDelta::ToolUse(text_out)) => { - if let Some(tool_use) = state - .tool_uses_by_index - .get_mut(&cb_delta.content_block_index) - { - tool_use.input_json.push_str(text_out.input()); - } - } - - Some(ContentBlockDelta::ReasoningContent(thinking)) => { - match thinking { - ReasoningContentBlockDelta::RedactedContent( - redacted, - ) => { - let thinking_event = - LanguageModelCompletionEvent::Thinking { - text: String::from_utf8( - redacted.into_inner(), - ) - .unwrap_or("REDACTED".to_string()), - signature: None, - }; - - return Some(( - Some(Ok(thinking_event)), - state, - )); - } - ReasoningContentBlockDelta::Signature( - signature, - ) => { - return Some(( - Some(Ok(LanguageModelCompletionEvent::Thinking { - text: "".to_string(), - signature: Some(signature) - })), - state, - )); - } - ReasoningContentBlockDelta::Text(thoughts) => { - let thinking_event = - LanguageModelCompletionEvent::Thinking { - text: thoughts.to_string(), - signature: None - }; - - return Some(( - Some(Ok(thinking_event)), - state, - )); - } - _ => {} - } - } - _ => {} - } - } - ConverseStreamOutput::ContentBlockStart(cb_start) => { - if let Some(ContentBlockStart::ToolUse(text_out)) = - cb_start.start - { - let tool_use = RawToolUse { - id: text_out.tool_use_id, - name: text_out.name, - input_json: String::new(), - }; - - state - .tool_uses_by_index - .insert(cb_start.content_block_index, tool_use); - } - } - ConverseStreamOutput::ContentBlockStop(cb_stop) => { - if let Some(tool_use) = state - .tool_uses_by_index - .remove(&cb_stop.content_block_index) - { - let tool_use_event = LanguageModelToolUse { - id: tool_use.id.into(), - name: tool_use.name.into(), - is_input_complete: true, - raw_input: tool_use.input_json.clone(), - input: if tool_use.input_json.is_empty() { - Value::Null - } else { - serde_json::Value::from_str( - &tool_use.input_json, - ) - .map_err(|err| anyhow!(err)) - .unwrap() - }, - }; - - return Some(( - Some(Ok(LanguageModelCompletionEvent::ToolUse( - tool_use_event, - ))), - state, - )); - } - } - - ConverseStreamOutput::Metadata(cb_meta) => { - if let Some(metadata) = cb_meta.usage { - let completion_event = - LanguageModelCompletionEvent::UsageUpdate( - TokenUsage { - input_tokens: metadata.input_tokens as u64, - output_tokens: metadata.output_tokens as u64, - cache_creation_input_tokens: - metadata.cache_write_input_tokens.unwrap_or_default() as u64, - cache_read_input_tokens: - metadata.cache_read_input_tokens.unwrap_or_default() as u64, - }, - ); - return Some((Some(Ok(completion_event)), state)); - } - } - ConverseStreamOutput::MessageStop(message_stop) => { - let reason = match message_stop.stop_reason { - StopReason::ContentFiltered => { - LanguageModelCompletionEvent::Stop( - language_model::StopReason::EndTurn, - ) - } - StopReason::EndTurn => { - LanguageModelCompletionEvent::Stop( - language_model::StopReason::EndTurn, - ) - } - StopReason::GuardrailIntervened => { - LanguageModelCompletionEvent::Stop( - language_model::StopReason::EndTurn, - ) - } - StopReason::MaxTokens => { - LanguageModelCompletionEvent::Stop( - language_model::StopReason::EndTurn, - ) - } - StopReason::StopSequence => { - LanguageModelCompletionEvent::Stop( - language_model::StopReason::EndTurn, - ) - } - StopReason::ToolUse => { - LanguageModelCompletionEvent::Stop( - language_model::StopReason::ToolUse, - ) - } - _ => LanguageModelCompletionEvent::Stop( - language_model::StopReason::EndTurn, - ), - }; - return Some((Some(Ok(reason)), state)); - } - _ => {} - }, - - Err(err) => return Some((Some(Err(anyhow!(err).into())), state)), + futures::stream::unfold(initial_state, |mut state| async move { + match state.events.next().await { + Some(event_result) => match event_result { + Ok(event) => { + let result = match event { + ConverseStreamOutput::ContentBlockDelta(cb_delta) => match cb_delta.delta { + Some(ContentBlockDelta::Text(text)) => { + Some(Ok(LanguageModelCompletionEvent::Text(text))) } + Some(ContentBlockDelta::ToolUse(tool_output)) => { + if let Some(tool_use) = state + .tool_uses_by_index + .get_mut(&cb_delta.content_block_index) + { + tool_use.input_json.push_str(tool_output.input()); + } + None + } + Some(ContentBlockDelta::ReasoningContent(thinking)) => match thinking { + ReasoningContentBlockDelta::Text(thoughts) => { + Some(Ok(LanguageModelCompletionEvent::Thinking { + text: thoughts.clone(), + signature: None, + })) + } + ReasoningContentBlockDelta::Signature(sig) => { + Some(Ok(LanguageModelCompletionEvent::Thinking { + text: "".into(), + signature: Some(sig), + })) + } + ReasoningContentBlockDelta::RedactedContent(redacted) => { + let content = String::from_utf8(redacted.into_inner()) + .unwrap_or("REDACTED".to_string()); + Some(Ok(LanguageModelCompletionEvent::Thinking { + text: content, + signature: None, + })) + } + _ => None, + }, + _ => None, + }, + ConverseStreamOutput::ContentBlockStart(cb_start) => { + if let Some(ContentBlockStart::ToolUse(tool_start)) = cb_start.start { + state.tool_uses_by_index.insert( + cb_start.content_block_index, + RawToolUse { + id: tool_start.tool_use_id, + name: tool_start.name, + input_json: String::new(), + }, + ); + } + None } - None - }) - .await - .log_err() - .flatten() - } - }, - ) - .filter_map(|event| async move { event }) + ConverseStreamOutput::ContentBlockStop(cb_stop) => state + .tool_uses_by_index + .remove(&cb_stop.content_block_index) + .map(|tool_use| { + let input = if tool_use.input_json.is_empty() { + Value::Null + } else { + serde_json::Value::from_str(&tool_use.input_json) + .unwrap_or(Value::Null) + }; + + Ok(LanguageModelCompletionEvent::ToolUse( + LanguageModelToolUse { + id: tool_use.id.into(), + name: tool_use.name.into(), + is_input_complete: true, + raw_input: tool_use.input_json.clone(), + input, + }, + )) + }), + ConverseStreamOutput::Metadata(cb_meta) => cb_meta.usage.map(|metadata| { + Ok(LanguageModelCompletionEvent::UsageUpdate(TokenUsage { + input_tokens: metadata.input_tokens as u64, + output_tokens: metadata.output_tokens as u64, + cache_creation_input_tokens: metadata + .cache_write_input_tokens + .unwrap_or_default() + as u64, + cache_read_input_tokens: metadata + .cache_read_input_tokens + .unwrap_or_default() + as u64, + })) + }), + ConverseStreamOutput::MessageStop(message_stop) => { + let stop_reason = match message_stop.stop_reason { + StopReason::ToolUse => language_model::StopReason::ToolUse, + _ => language_model::StopReason::EndTurn, + }; + Some(Ok(LanguageModelCompletionEvent::Stop(stop_reason))) + } + _ => None, + }; + + Some((result, state)) + } + Err(err) => Some(( + Some(Err(LanguageModelCompletionError::Other(anyhow!(err)))), + state, + )), + }, + None => None, + } + }) + .filter_map(|result| async move { result }) } struct ConfigurationView {