diff --git a/crates/anthropic/src/anthropic.rs b/crates/anthropic/src/anthropic.rs index 3345765643..6f26ee4c00 100644 --- a/crates/anthropic/src/anthropic.rs +++ b/crates/anthropic/src/anthropic.rs @@ -1,10 +1,10 @@ mod supported_countries; -use std::{pin::Pin, str::FromStr}; +use std::str::FromStr; use anyhow::{Context as _, Result, anyhow}; use chrono::{DateTime, Utc}; -use futures::{AsyncBufReadExt, AsyncReadExt, Stream, StreamExt, io::BufReader, stream::BoxStream}; +use futures::{AsyncBufReadExt, AsyncReadExt, StreamExt, io::BufReader, stream::BoxStream}; use http_client::http::{HeaderMap, HeaderValue}; use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest}; use serde::{Deserialize, Serialize}; @@ -437,50 +437,6 @@ pub async fn stream_completion_with_rate_limit_info( } } -pub async fn extract_tool_args_from_events( - tool_name: String, - mut events: Pin>>>, -) -> Result>> { - let mut tool_use_index = None; - while let Some(event) = events.next().await { - if let Event::ContentBlockStart { - index, - content_block: ResponseContent::ToolUse { name, .. }, - } = event? - { - if name == tool_name { - tool_use_index = Some(index); - break; - } - } - } - - let Some(tool_use_index) = tool_use_index else { - return Err(anyhow!("tool not used")); - }; - - Ok(events.filter_map(move |event| { - let result = match event { - Err(error) => Some(Err(error)), - Ok(Event::ContentBlockDelta { index, delta }) => match delta { - ContentDelta::TextDelta { .. } => None, - ContentDelta::ThinkingDelta { .. } => None, - ContentDelta::SignatureDelta { .. } => None, - ContentDelta::InputJsonDelta { partial_json } => { - if index == tool_use_index { - Some(Ok(partial_json)) - } else { - None - } - } - }, - _ => None, - }; - - async move { result } - })) -} - #[derive(Debug, Serialize, Deserialize, Copy, Clone)] #[serde(rename_all = "lowercase")] pub enum CacheControlType { diff --git a/crates/language_models/src/provider/bedrock.rs b/crates/language_models/src/provider/bedrock.rs index 0e04e12773..14bd6b436d 100644 --- a/crates/language_models/src/provider/bedrock.rs +++ b/crates/language_models/src/provider/bedrock.rs @@ -8,9 +8,7 @@ use aws_config::Region; use aws_config::stalled_stream_protection::StalledStreamProtectionConfig; use aws_credential_types::Credentials; use aws_http_client::AwsHttpClient; -use bedrock::bedrock_client::types::{ - ContentBlockDelta, ContentBlockStart, ContentBlockStartEvent, ConverseStreamOutput, -}; +use bedrock::bedrock_client::types::{ContentBlockDelta, ContentBlockStart, ConverseStreamOutput}; use bedrock::bedrock_client::{self, Config}; use bedrock::{BedrockError, BedrockInnerContent, BedrockMessage, BedrockStreamingResponse, Model}; use collections::{BTreeMap, HashMap}; @@ -544,70 +542,6 @@ pub fn get_bedrock_tokens( .boxed() } -pub async fn extract_tool_args_from_events( - name: String, - mut events: Pin>>>, - handle: Handle, -) -> Result>> { - handle - .spawn(async move { - let mut tool_use_index = None; - while let Some(event) = events.next().await { - if let BedrockStreamingResponse::ContentBlockStart(ContentBlockStartEvent { - content_block_index, - start, - .. - }) = event? - { - match start { - None => { - continue; - } - Some(start) => match start.as_tool_use() { - Ok(tool_use) => { - if name == tool_use.name { - tool_use_index = Some(content_block_index); - break; - } - } - Err(err) => { - return Err(anyhow!("Failed to parse tool use event: {:?}", err)); - } - }, - } - } - } - - let Some(tool_use_index) = tool_use_index else { - return Err(anyhow!("Tool is not used")); - }; - - Ok(events.filter_map(move |event| { - let result = match event { - Err(_err) => None, - Ok(output) => match output.clone() { - BedrockStreamingResponse::ContentBlockDelta(inner) => { - match inner.clone().delta { - Some(ContentBlockDelta::ToolUse(tool_use)) => { - if inner.content_block_index == tool_use_index { - Some(Ok(tool_use.input)) - } else { - None - } - } - _ => None, - } - } - _ => None, - }, - }; - - async move { result } - })) - }) - .await? -} - pub fn map_to_language_model_completion_events( events: Pin>>>, handle: Handle, diff --git a/crates/open_ai/src/open_ai.rs b/crates/open_ai/src/open_ai.rs index ca5eafaacd..4ad13b1d89 100644 --- a/crates/open_ai/src/open_ai.rs +++ b/crates/open_ai/src/open_ai.rs @@ -12,7 +12,6 @@ use serde_json::Value; use std::{ convert::TryFrom, future::{self, Future}, - pin::Pin, }; use strum::EnumIter; @@ -620,57 +619,6 @@ pub fn embed<'a>( } } -pub async fn extract_tool_args_from_events( - tool_name: String, - mut events: Pin>>>, -) -> Result>> { - let mut tool_use_index = None; - let mut first_chunk = None; - while let Some(event) = events.next().await { - let call = event?.choices.into_iter().find_map(|choice| { - choice.delta.tool_calls?.into_iter().find_map(|call| { - if call.function.as_ref()?.name.as_deref()? == tool_name { - Some(call) - } else { - None - } - }) - }); - if let Some(call) = call { - tool_use_index = Some(call.index); - first_chunk = call.function.and_then(|func| func.arguments); - break; - } - } - - let Some(tool_use_index) = tool_use_index else { - return Err(anyhow!("tool not used")); - }; - - Ok(events.filter_map(move |event| { - let result = match event { - Err(error) => Some(Err(error)), - Ok(ResponseStreamEvent { choices, .. }) => choices.into_iter().find_map(|choice| { - choice.delta.tool_calls?.into_iter().find_map(|call| { - if call.index == tool_use_index { - let func = call.function?; - let mut arguments = func.arguments?; - if let Some(mut first_chunk) = first_chunk.take() { - first_chunk.push_str(&arguments); - arguments = first_chunk - } - Some(Ok(arguments)) - } else { - None - } - }) - }), - }; - - async move { result } - })) -} - pub fn extract_text_from_events( response: impl Stream>, ) -> impl Stream> {