assistant: Add foundation for receiving tool uses from Anthropic models (#17170)
This PR updates the Assistant with support for receiving tool uses from Anthropic models and capturing them as text in the context editor. This is just laying the foundation for tool use. We don't yet fulfill the tool uses yet, or define any tools for the model to use. Here's an example of what it looks like using the example `get_weather` tool from the Anthropic docs: <img width="644" alt="Screenshot 2024-08-30 at 1 51 13 PM" src="https://github.com/user-attachments/assets/3614f953-0689-423c-8955-b146729ea638"> Release Notes: - N/A
This commit is contained in:
parent
ea25d438d1
commit
68ea661711
8 changed files with 114 additions and 25 deletions
|
@ -330,26 +330,94 @@ pub async fn stream_completion_with_rate_limit_info(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn extract_text_from_events(
|
pub fn extract_content_from_events(
|
||||||
response: impl Stream<Item = Result<Event, AnthropicError>>,
|
events: Pin<Box<dyn Send + Stream<Item = Result<Event, AnthropicError>>>>,
|
||||||
) -> impl Stream<Item = Result<String, AnthropicError>> {
|
) -> impl Stream<Item = Result<String, AnthropicError>> {
|
||||||
response.filter_map(|response| async move {
|
struct State {
|
||||||
match response {
|
events: Pin<Box<dyn Send + Stream<Item = Result<Event, AnthropicError>>>>,
|
||||||
Ok(response) => match response {
|
current_tool_use_index: Option<usize>,
|
||||||
Event::ContentBlockStart { content_block, .. } => match content_block {
|
|
||||||
ResponseContent::Text { text, .. } => Some(Ok(text)),
|
|
||||||
_ => None,
|
|
||||||
},
|
|
||||||
Event::ContentBlockDelta { delta, .. } => match delta {
|
|
||||||
ContentDelta::TextDelta { text } => Some(Ok(text)),
|
|
||||||
_ => None,
|
|
||||||
},
|
|
||||||
Event::Error { error } => Some(Err(AnthropicError::ApiError(error))),
|
|
||||||
_ => None,
|
|
||||||
},
|
|
||||||
Err(error) => Some(Err(error)),
|
|
||||||
}
|
}
|
||||||
})
|
|
||||||
|
const INDENT: &str = " ";
|
||||||
|
const NEWLINE: char = '\n';
|
||||||
|
|
||||||
|
futures::stream::unfold(
|
||||||
|
State {
|
||||||
|
events,
|
||||||
|
current_tool_use_index: None,
|
||||||
|
},
|
||||||
|
|mut state| async move {
|
||||||
|
while let Some(event) = state.events.next().await {
|
||||||
|
match event {
|
||||||
|
Ok(event) => match event {
|
||||||
|
Event::ContentBlockStart {
|
||||||
|
index,
|
||||||
|
content_block,
|
||||||
|
} => match content_block {
|
||||||
|
ResponseContent::Text { text } => {
|
||||||
|
return Some((Ok(text), state));
|
||||||
|
}
|
||||||
|
ResponseContent::ToolUse { id, name, .. } => {
|
||||||
|
state.current_tool_use_index = Some(index);
|
||||||
|
|
||||||
|
let mut text = String::new();
|
||||||
|
text.push(NEWLINE);
|
||||||
|
|
||||||
|
text.push_str("<tool_use>");
|
||||||
|
text.push(NEWLINE);
|
||||||
|
|
||||||
|
text.push_str(INDENT);
|
||||||
|
text.push_str("<id>");
|
||||||
|
text.push_str(&id);
|
||||||
|
text.push_str("</id>");
|
||||||
|
text.push(NEWLINE);
|
||||||
|
|
||||||
|
text.push_str(INDENT);
|
||||||
|
text.push_str("<name>");
|
||||||
|
text.push_str(&name);
|
||||||
|
text.push_str("</name>");
|
||||||
|
text.push(NEWLINE);
|
||||||
|
|
||||||
|
text.push_str(INDENT);
|
||||||
|
text.push_str("<input>");
|
||||||
|
|
||||||
|
return Some((Ok(text), state));
|
||||||
|
}
|
||||||
|
},
|
||||||
|
Event::ContentBlockDelta { index, delta } => match delta {
|
||||||
|
ContentDelta::TextDelta { text } => {
|
||||||
|
return Some((Ok(text), state));
|
||||||
|
}
|
||||||
|
ContentDelta::InputJsonDelta { partial_json } => {
|
||||||
|
if Some(index) == state.current_tool_use_index {
|
||||||
|
return Some((Ok(partial_json), state));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
Event::ContentBlockStop { index } => {
|
||||||
|
if Some(index) == state.current_tool_use_index.take() {
|
||||||
|
let mut text = String::new();
|
||||||
|
text.push_str("</input>");
|
||||||
|
text.push(NEWLINE);
|
||||||
|
text.push_str("</tool_use>");
|
||||||
|
|
||||||
|
return Some((Ok(text), state));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Event::Error { error } => {
|
||||||
|
return Some((Err(AnthropicError::ApiError(error)), state));
|
||||||
|
}
|
||||||
|
_ => {}
|
||||||
|
},
|
||||||
|
Err(err) => {
|
||||||
|
return Some((Err(err), state));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
None
|
||||||
|
},
|
||||||
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
pub async fn extract_tool_args_from_events(
|
pub async fn extract_tool_args_from_events(
|
||||||
|
|
|
@ -2048,7 +2048,8 @@ impl Context {
|
||||||
|
|
||||||
LanguageModelRequest {
|
LanguageModelRequest {
|
||||||
messages: request_messages,
|
messages: request_messages,
|
||||||
stop: vec![],
|
tools: Vec::new(),
|
||||||
|
stop: Vec::new(),
|
||||||
temperature: 1.0,
|
temperature: 1.0,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -2398,7 +2399,8 @@ impl Context {
|
||||||
}));
|
}));
|
||||||
let request = LanguageModelRequest {
|
let request = LanguageModelRequest {
|
||||||
messages: messages.collect(),
|
messages: messages.collect(),
|
||||||
stop: vec![],
|
tools: Vec::new(),
|
||||||
|
stop: Vec::new(),
|
||||||
temperature: 1.0,
|
temperature: 1.0,
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
|
@ -2413,6 +2413,7 @@ impl Codegen {
|
||||||
|
|
||||||
Ok(LanguageModelRequest {
|
Ok(LanguageModelRequest {
|
||||||
messages,
|
messages,
|
||||||
|
tools: Vec::new(),
|
||||||
stop: vec!["|END|>".to_string()],
|
stop: vec!["|END|>".to_string()],
|
||||||
temperature,
|
temperature,
|
||||||
})
|
})
|
||||||
|
|
|
@ -794,6 +794,7 @@ impl PromptLibrary {
|
||||||
content: vec![body.to_string().into()],
|
content: vec![body.to_string().into()],
|
||||||
cache: false,
|
cache: false,
|
||||||
}],
|
}],
|
||||||
|
tools: Vec::new(),
|
||||||
stop: Vec::new(),
|
stop: Vec::new(),
|
||||||
temperature: 1.,
|
temperature: 1.,
|
||||||
},
|
},
|
||||||
|
|
|
@ -282,6 +282,7 @@ impl TerminalInlineAssistant {
|
||||||
|
|
||||||
Ok(LanguageModelRequest {
|
Ok(LanguageModelRequest {
|
||||||
messages,
|
messages,
|
||||||
|
tools: Vec::new(),
|
||||||
stop: Vec::new(),
|
stop: Vec::new(),
|
||||||
temperature: 1.0,
|
temperature: 1.0,
|
||||||
})
|
})
|
||||||
|
|
|
@ -370,7 +370,7 @@ impl LanguageModel for AnthropicModel {
|
||||||
let request = self.stream_completion(request, cx);
|
let request = self.stream_completion(request, cx);
|
||||||
let future = self.request_limiter.stream(async move {
|
let future = self.request_limiter.stream(async move {
|
||||||
let response = request.await.map_err(|err| anyhow!(err))?;
|
let response = request.await.map_err(|err| anyhow!(err))?;
|
||||||
Ok(anthropic::extract_text_from_events(response))
|
Ok(anthropic::extract_content_from_events(response))
|
||||||
});
|
});
|
||||||
async move {
|
async move {
|
||||||
Ok(future
|
Ok(future
|
||||||
|
|
|
@ -515,9 +515,9 @@ impl LanguageModel for CloudLanguageModel {
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
.await?;
|
.await?;
|
||||||
Ok(anthropic::extract_text_from_events(
|
Ok(anthropic::extract_content_from_events(Box::pin(
|
||||||
response_lines(response).map_err(AnthropicError::Other),
|
response_lines(response).map_err(AnthropicError::Other),
|
||||||
))
|
)))
|
||||||
});
|
});
|
||||||
async move {
|
async move {
|
||||||
Ok(future
|
Ok(future
|
||||||
|
|
|
@ -221,9 +221,17 @@ impl LanguageModelRequestMessage {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, PartialEq, Hash, Clone, Serialize, Deserialize)]
|
||||||
|
pub struct LanguageModelRequestTool {
|
||||||
|
pub name: String,
|
||||||
|
pub description: String,
|
||||||
|
pub input_schema: serde_json::Value,
|
||||||
|
}
|
||||||
|
|
||||||
#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq)]
|
#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq)]
|
||||||
pub struct LanguageModelRequest {
|
pub struct LanguageModelRequest {
|
||||||
pub messages: Vec<LanguageModelRequestMessage>,
|
pub messages: Vec<LanguageModelRequestMessage>,
|
||||||
|
pub tools: Vec<LanguageModelRequestTool>,
|
||||||
pub stop: Vec<String>,
|
pub stop: Vec<String>,
|
||||||
pub temperature: f32,
|
pub temperature: f32,
|
||||||
}
|
}
|
||||||
|
@ -355,7 +363,15 @@ impl LanguageModelRequest {
|
||||||
messages: new_messages,
|
messages: new_messages,
|
||||||
max_tokens: max_output_tokens,
|
max_tokens: max_output_tokens,
|
||||||
system: Some(system_message),
|
system: Some(system_message),
|
||||||
tools: Vec::new(),
|
tools: self
|
||||||
|
.tools
|
||||||
|
.into_iter()
|
||||||
|
.map(|tool| anthropic::Tool {
|
||||||
|
name: tool.name,
|
||||||
|
description: tool.description,
|
||||||
|
input_schema: tool.input_schema,
|
||||||
|
})
|
||||||
|
.collect(),
|
||||||
tool_choice: None,
|
tool_choice: None,
|
||||||
metadata: None,
|
metadata: None,
|
||||||
stop_sequences: Vec::new(),
|
stop_sequences: Vec::new(),
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue