Use tool calling instead of XML parsing to generate edit operations (#15385)
Release Notes: - N/A --------- Co-authored-by: Nathan <nathan@zed.dev>
This commit is contained in:
parent
f6012cd86e
commit
6e1f7c6e1d
22 changed files with 1155 additions and 853 deletions
|
@ -3,7 +3,7 @@ use futures::{io::BufReader, stream::BoxStream, AsyncBufReadExt, AsyncReadExt, S
|
|||
use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest};
|
||||
use isahc::config::Configurable;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::{convert::TryFrom, time::Duration};
|
||||
use std::time::Duration;
|
||||
use strum::EnumIter;
|
||||
|
||||
pub const ANTHROPIC_API_URL: &'static str = "https://api.anthropic.com";
|
||||
|
@ -70,112 +70,53 @@ impl Model {
|
|||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)]
|
||||
#[serde(rename_all = "lowercase")]
|
||||
pub enum Role {
|
||||
User,
|
||||
Assistant,
|
||||
}
|
||||
pub async fn complete(
|
||||
client: &dyn HttpClient,
|
||||
api_url: &str,
|
||||
api_key: &str,
|
||||
request: Request,
|
||||
) -> Result<Response> {
|
||||
let uri = format!("{api_url}/v1/messages");
|
||||
let request_builder = HttpRequest::builder()
|
||||
.method(Method::POST)
|
||||
.uri(uri)
|
||||
.header("Anthropic-Version", "2023-06-01")
|
||||
.header("Anthropic-Beta", "tools-2024-04-04")
|
||||
.header("X-Api-Key", api_key)
|
||||
.header("Content-Type", "application/json");
|
||||
|
||||
impl TryFrom<String> for Role {
|
||||
type Error = anyhow::Error;
|
||||
let serialized_request = serde_json::to_string(&request)?;
|
||||
let request = request_builder.body(AsyncBody::from(serialized_request))?;
|
||||
|
||||
fn try_from(value: String) -> Result<Self> {
|
||||
match value.as_str() {
|
||||
"user" => Ok(Self::User),
|
||||
"assistant" => Ok(Self::Assistant),
|
||||
_ => Err(anyhow!("invalid role '{value}'")),
|
||||
}
|
||||
let mut response = client.send(request).await?;
|
||||
if response.status().is_success() {
|
||||
let mut body = Vec::new();
|
||||
response.body_mut().read_to_end(&mut body).await?;
|
||||
let response_message: Response = serde_json::from_slice(&body)?;
|
||||
Ok(response_message)
|
||||
} else {
|
||||
let mut body = Vec::new();
|
||||
response.body_mut().read_to_end(&mut body).await?;
|
||||
let body_str = std::str::from_utf8(&body)?;
|
||||
Err(anyhow!(
|
||||
"Failed to connect to API: {} {}",
|
||||
response.status(),
|
||||
body_str
|
||||
))
|
||||
}
|
||||
}
|
||||
|
||||
impl From<Role> for String {
|
||||
fn from(val: Role) -> Self {
|
||||
match val {
|
||||
Role::User => "user".to_owned(),
|
||||
Role::Assistant => "assistant".to_owned(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
pub struct Request {
|
||||
pub model: String,
|
||||
pub messages: Vec<RequestMessage>,
|
||||
pub stream: bool,
|
||||
pub system: String,
|
||||
pub max_tokens: u32,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
|
||||
pub struct RequestMessage {
|
||||
pub role: Role,
|
||||
pub content: String,
|
||||
}
|
||||
|
||||
#[derive(Deserialize, Serialize, Debug)]
|
||||
#[serde(tag = "type", rename_all = "snake_case")]
|
||||
pub enum ResponseEvent {
|
||||
MessageStart {
|
||||
message: ResponseMessage,
|
||||
},
|
||||
ContentBlockStart {
|
||||
index: u32,
|
||||
content_block: ContentBlock,
|
||||
},
|
||||
Ping {},
|
||||
ContentBlockDelta {
|
||||
index: u32,
|
||||
delta: TextDelta,
|
||||
},
|
||||
ContentBlockStop {
|
||||
index: u32,
|
||||
},
|
||||
MessageDelta {
|
||||
delta: ResponseMessage,
|
||||
usage: Usage,
|
||||
},
|
||||
MessageStop {},
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug)]
|
||||
pub struct ResponseMessage {
|
||||
#[serde(rename = "type")]
|
||||
pub message_type: Option<String>,
|
||||
pub id: Option<String>,
|
||||
pub role: Option<String>,
|
||||
pub content: Option<Vec<String>>,
|
||||
pub model: Option<String>,
|
||||
pub stop_reason: Option<String>,
|
||||
pub stop_sequence: Option<String>,
|
||||
pub usage: Option<Usage>,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug)]
|
||||
pub struct Usage {
|
||||
pub input_tokens: Option<u32>,
|
||||
pub output_tokens: Option<u32>,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug)]
|
||||
#[serde(tag = "type", rename_all = "snake_case")]
|
||||
pub enum ContentBlock {
|
||||
Text { text: String },
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug)]
|
||||
#[serde(tag = "type", rename_all = "snake_case")]
|
||||
pub enum TextDelta {
|
||||
TextDelta { text: String },
|
||||
}
|
||||
|
||||
pub async fn stream_completion(
|
||||
client: &dyn HttpClient,
|
||||
api_url: &str,
|
||||
api_key: &str,
|
||||
request: Request,
|
||||
low_speed_timeout: Option<Duration>,
|
||||
) -> Result<BoxStream<'static, Result<ResponseEvent>>> {
|
||||
) -> Result<BoxStream<'static, Result<Event>>> {
|
||||
let request = StreamingRequest {
|
||||
base: request,
|
||||
stream: true,
|
||||
};
|
||||
let uri = format!("{api_url}/v1/messages");
|
||||
let mut request_builder = HttpRequest::builder()
|
||||
.method(Method::POST)
|
||||
|
@ -187,7 +128,9 @@ pub async fn stream_completion(
|
|||
if let Some(low_speed_timeout) = low_speed_timeout {
|
||||
request_builder = request_builder.low_speed_timeout(100, low_speed_timeout);
|
||||
}
|
||||
let request = request_builder.body(AsyncBody::from(serde_json::to_string(&request)?))?;
|
||||
let serialized_request = serde_json::to_string(&request)?;
|
||||
let request = request_builder.body(AsyncBody::from(serialized_request))?;
|
||||
|
||||
let mut response = client.send(request).await?;
|
||||
if response.status().is_success() {
|
||||
let reader = BufReader::new(response.into_body());
|
||||
|
@ -212,7 +155,7 @@ pub async fn stream_completion(
|
|||
|
||||
let body_str = std::str::from_utf8(&body)?;
|
||||
|
||||
match serde_json::from_str::<ResponseEvent>(body_str) {
|
||||
match serde_json::from_str::<Event>(body_str) {
|
||||
Ok(_) => Err(anyhow!(
|
||||
"Unexpected success response while expecting an error: {}",
|
||||
body_str,
|
||||
|
@ -227,16 +170,18 @@ pub async fn stream_completion(
|
|||
}
|
||||
|
||||
pub fn extract_text_from_events(
|
||||
response: impl Stream<Item = Result<ResponseEvent>>,
|
||||
response: impl Stream<Item = Result<Event>>,
|
||||
) -> impl Stream<Item = Result<String>> {
|
||||
response.filter_map(|response| async move {
|
||||
match response {
|
||||
Ok(response) => match response {
|
||||
ResponseEvent::ContentBlockStart { content_block, .. } => match content_block {
|
||||
ContentBlock::Text { text } => Some(Ok(text)),
|
||||
Event::ContentBlockStart { content_block, .. } => match content_block {
|
||||
Content::Text { text } => Some(Ok(text)),
|
||||
_ => None,
|
||||
},
|
||||
ResponseEvent::ContentBlockDelta { delta, .. } => match delta {
|
||||
TextDelta::TextDelta { text } => Some(Ok(text)),
|
||||
Event::ContentBlockDelta { delta, .. } => match delta {
|
||||
ContentDelta::TextDelta { text } => Some(Ok(text)),
|
||||
_ => None,
|
||||
},
|
||||
_ => None,
|
||||
},
|
||||
|
@ -245,42 +190,162 @@ pub fn extract_text_from_events(
|
|||
})
|
||||
}
|
||||
|
||||
// #[cfg(test)]
|
||||
// mod tests {
|
||||
// use super::*;
|
||||
// use http::IsahcHttpClient;
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
pub struct Message {
|
||||
pub role: Role,
|
||||
pub content: Vec<Content>,
|
||||
}
|
||||
|
||||
// #[tokio::test]
|
||||
// async fn stream_completion_success() {
|
||||
// let http_client = IsahcHttpClient::new().unwrap();
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "lowercase")]
|
||||
pub enum Role {
|
||||
User,
|
||||
Assistant,
|
||||
}
|
||||
|
||||
// let request = Request {
|
||||
// model: Model::Claude3Opus,
|
||||
// messages: vec![RequestMessage {
|
||||
// role: Role::User,
|
||||
// content: "Ping".to_string(),
|
||||
// }],
|
||||
// stream: true,
|
||||
// system: "Respond to ping with pong".to_string(),
|
||||
// max_tokens: 4096,
|
||||
// };
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
#[serde(tag = "type")]
|
||||
pub enum Content {
|
||||
#[serde(rename = "text")]
|
||||
Text { text: String },
|
||||
#[serde(rename = "image")]
|
||||
Image { source: ImageSource },
|
||||
#[serde(rename = "tool_use")]
|
||||
ToolUse {
|
||||
id: String,
|
||||
name: String,
|
||||
input: serde_json::Value,
|
||||
},
|
||||
#[serde(rename = "tool_result")]
|
||||
ToolResult {
|
||||
tool_use_id: String,
|
||||
content: String,
|
||||
},
|
||||
}
|
||||
|
||||
// let stream = stream_completion(
|
||||
// &http_client,
|
||||
// "https://api.anthropic.com",
|
||||
// &std::env::var("ANTHROPIC_API_KEY").expect("ANTHROPIC_API_KEY not set"),
|
||||
// request,
|
||||
// )
|
||||
// .await
|
||||
// .unwrap();
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
pub struct ImageSource {
|
||||
#[serde(rename = "type")]
|
||||
pub source_type: String,
|
||||
pub media_type: String,
|
||||
pub data: String,
|
||||
}
|
||||
|
||||
// stream
|
||||
// .for_each(|event| async {
|
||||
// match event {
|
||||
// Ok(event) => println!("{:?}", event),
|
||||
// Err(e) => eprintln!("Error: {:?}", e),
|
||||
// }
|
||||
// })
|
||||
// .await;
|
||||
// }
|
||||
// }
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
pub struct Tool {
|
||||
pub name: String,
|
||||
pub description: String,
|
||||
pub input_schema: serde_json::Value,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
#[serde(tag = "type", rename_all = "lowercase")]
|
||||
pub enum ToolChoice {
|
||||
Auto,
|
||||
Any,
|
||||
Tool { name: String },
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
pub struct Request {
|
||||
pub model: String,
|
||||
pub max_tokens: u32,
|
||||
pub messages: Vec<Message>,
|
||||
#[serde(default, skip_serializing_if = "Vec::is_empty")]
|
||||
pub tools: Vec<Tool>,
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
pub tool_choice: Option<ToolChoice>,
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
pub system: Option<String>,
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
pub metadata: Option<Metadata>,
|
||||
#[serde(default, skip_serializing_if = "Vec::is_empty")]
|
||||
pub stop_sequences: Vec<String>,
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
pub temperature: Option<f32>,
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
pub top_k: Option<u32>,
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
pub top_p: Option<f32>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
struct StreamingRequest {
|
||||
#[serde(flatten)]
|
||||
pub base: Request,
|
||||
pub stream: bool,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
pub struct Metadata {
|
||||
pub user_id: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
pub struct Usage {
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
pub input_tokens: Option<u32>,
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
pub output_tokens: Option<u32>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
pub struct Response {
|
||||
pub id: String,
|
||||
#[serde(rename = "type")]
|
||||
pub response_type: String,
|
||||
pub role: Role,
|
||||
pub content: Vec<Content>,
|
||||
pub model: String,
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
pub stop_reason: Option<String>,
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
pub stop_sequence: Option<String>,
|
||||
pub usage: Usage,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
#[serde(tag = "type")]
|
||||
pub enum Event {
|
||||
#[serde(rename = "message_start")]
|
||||
MessageStart { message: Response },
|
||||
#[serde(rename = "content_block_start")]
|
||||
ContentBlockStart {
|
||||
index: usize,
|
||||
content_block: Content,
|
||||
},
|
||||
#[serde(rename = "content_block_delta")]
|
||||
ContentBlockDelta { index: usize, delta: ContentDelta },
|
||||
#[serde(rename = "content_block_stop")]
|
||||
ContentBlockStop { index: usize },
|
||||
#[serde(rename = "message_delta")]
|
||||
MessageDelta { delta: MessageDelta, usage: Usage },
|
||||
#[serde(rename = "message_stop")]
|
||||
MessageStop,
|
||||
#[serde(rename = "ping")]
|
||||
Ping,
|
||||
#[serde(rename = "error")]
|
||||
Error { error: ApiError },
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
#[serde(tag = "type")]
|
||||
pub enum ContentDelta {
|
||||
#[serde(rename = "text_delta")]
|
||||
TextDelta { text: String },
|
||||
#[serde(rename = "input_json_delta")]
|
||||
InputJsonDelta { partial_json: String },
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
pub struct MessageDelta {
|
||||
pub stop_reason: Option<String>,
|
||||
pub stop_sequence: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
pub struct ApiError {
|
||||
#[serde(rename = "type")]
|
||||
pub error_type: String,
|
||||
pub message: String,
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue