Use tool calling instead of XML parsing to generate edit operations (#15385)
Release Notes: - N/A --------- Co-authored-by: Nathan <nathan@zed.dev>
This commit is contained in:
parent
f6012cd86e
commit
6e1f7c6e1d
22 changed files with 1155 additions and 853 deletions
13
Cargo.lock
generated
13
Cargo.lock
generated
|
@ -435,7 +435,6 @@ dependencies = [
|
||||||
"rand 0.8.5",
|
"rand 0.8.5",
|
||||||
"regex",
|
"regex",
|
||||||
"rope",
|
"rope",
|
||||||
"roxmltree 0.20.0",
|
|
||||||
"schemars",
|
"schemars",
|
||||||
"search",
|
"search",
|
||||||
"semantic_index",
|
"semantic_index",
|
||||||
|
@ -2641,7 +2640,9 @@ dependencies = [
|
||||||
"language_model",
|
"language_model",
|
||||||
"project",
|
"project",
|
||||||
"rand 0.8.5",
|
"rand 0.8.5",
|
||||||
|
"schemars",
|
||||||
"serde",
|
"serde",
|
||||||
|
"serde_json",
|
||||||
"settings",
|
"settings",
|
||||||
"smol",
|
"smol",
|
||||||
"text",
|
"text",
|
||||||
|
@ -4237,7 +4238,7 @@ version = "0.5.6"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "6a595cb550439a117696039dfc69830492058211b771a2a165379f2a1a53d84d"
|
checksum = "6a595cb550439a117696039dfc69830492058211b771a2a165379f2a1a53d84d"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"roxmltree 0.19.0",
|
"roxmltree",
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
|
@ -8918,12 +8919,6 @@ version = "0.19.0"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "3cd14fd5e3b777a7422cca79358c57a8f6e3a703d9ac187448d0daf220c2407f"
|
checksum = "3cd14fd5e3b777a7422cca79358c57a8f6e3a703d9ac187448d0daf220c2407f"
|
||||||
|
|
||||||
[[package]]
|
|
||||||
name = "roxmltree"
|
|
||||||
version = "0.20.0"
|
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
|
||||||
checksum = "6c20b6793b5c2fa6553b250154b78d6d0db37e72700ae35fad9387a46f487c97"
|
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "rpc"
|
name = "rpc"
|
||||||
version = "0.1.0"
|
version = "0.1.0"
|
||||||
|
@ -11878,7 +11873,7 @@ dependencies = [
|
||||||
"kurbo",
|
"kurbo",
|
||||||
"log",
|
"log",
|
||||||
"pico-args",
|
"pico-args",
|
||||||
"roxmltree 0.19.0",
|
"roxmltree",
|
||||||
"simplecss",
|
"simplecss",
|
||||||
"siphasher 1.0.1",
|
"siphasher 1.0.1",
|
||||||
"strict-num",
|
"strict-num",
|
||||||
|
|
|
@ -3,7 +3,7 @@ use futures::{io::BufReader, stream::BoxStream, AsyncBufReadExt, AsyncReadExt, S
|
||||||
use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest};
|
use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest};
|
||||||
use isahc::config::Configurable;
|
use isahc::config::Configurable;
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
use std::{convert::TryFrom, time::Duration};
|
use std::time::Duration;
|
||||||
use strum::EnumIter;
|
use strum::EnumIter;
|
||||||
|
|
||||||
pub const ANTHROPIC_API_URL: &'static str = "https://api.anthropic.com";
|
pub const ANTHROPIC_API_URL: &'static str = "https://api.anthropic.com";
|
||||||
|
@ -70,103 +70,40 @@ impl Model {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)]
|
pub async fn complete(
|
||||||
#[serde(rename_all = "lowercase")]
|
client: &dyn HttpClient,
|
||||||
pub enum Role {
|
api_url: &str,
|
||||||
User,
|
api_key: &str,
|
||||||
Assistant,
|
request: Request,
|
||||||
}
|
) -> Result<Response> {
|
||||||
|
let uri = format!("{api_url}/v1/messages");
|
||||||
|
let request_builder = HttpRequest::builder()
|
||||||
|
.method(Method::POST)
|
||||||
|
.uri(uri)
|
||||||
|
.header("Anthropic-Version", "2023-06-01")
|
||||||
|
.header("Anthropic-Beta", "tools-2024-04-04")
|
||||||
|
.header("X-Api-Key", api_key)
|
||||||
|
.header("Content-Type", "application/json");
|
||||||
|
|
||||||
impl TryFrom<String> for Role {
|
let serialized_request = serde_json::to_string(&request)?;
|
||||||
type Error = anyhow::Error;
|
let request = request_builder.body(AsyncBody::from(serialized_request))?;
|
||||||
|
|
||||||
fn try_from(value: String) -> Result<Self> {
|
let mut response = client.send(request).await?;
|
||||||
match value.as_str() {
|
if response.status().is_success() {
|
||||||
"user" => Ok(Self::User),
|
let mut body = Vec::new();
|
||||||
"assistant" => Ok(Self::Assistant),
|
response.body_mut().read_to_end(&mut body).await?;
|
||||||
_ => Err(anyhow!("invalid role '{value}'")),
|
let response_message: Response = serde_json::from_slice(&body)?;
|
||||||
|
Ok(response_message)
|
||||||
|
} else {
|
||||||
|
let mut body = Vec::new();
|
||||||
|
response.body_mut().read_to_end(&mut body).await?;
|
||||||
|
let body_str = std::str::from_utf8(&body)?;
|
||||||
|
Err(anyhow!(
|
||||||
|
"Failed to connect to API: {} {}",
|
||||||
|
response.status(),
|
||||||
|
body_str
|
||||||
|
))
|
||||||
}
|
}
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl From<Role> for String {
|
|
||||||
fn from(val: Role) -> Self {
|
|
||||||
match val {
|
|
||||||
Role::User => "user".to_owned(),
|
|
||||||
Role::Assistant => "assistant".to_owned(),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, Serialize, Deserialize)]
|
|
||||||
pub struct Request {
|
|
||||||
pub model: String,
|
|
||||||
pub messages: Vec<RequestMessage>,
|
|
||||||
pub stream: bool,
|
|
||||||
pub system: String,
|
|
||||||
pub max_tokens: u32,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
|
|
||||||
pub struct RequestMessage {
|
|
||||||
pub role: Role,
|
|
||||||
pub content: String,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Deserialize, Serialize, Debug)]
|
|
||||||
#[serde(tag = "type", rename_all = "snake_case")]
|
|
||||||
pub enum ResponseEvent {
|
|
||||||
MessageStart {
|
|
||||||
message: ResponseMessage,
|
|
||||||
},
|
|
||||||
ContentBlockStart {
|
|
||||||
index: u32,
|
|
||||||
content_block: ContentBlock,
|
|
||||||
},
|
|
||||||
Ping {},
|
|
||||||
ContentBlockDelta {
|
|
||||||
index: u32,
|
|
||||||
delta: TextDelta,
|
|
||||||
},
|
|
||||||
ContentBlockStop {
|
|
||||||
index: u32,
|
|
||||||
},
|
|
||||||
MessageDelta {
|
|
||||||
delta: ResponseMessage,
|
|
||||||
usage: Usage,
|
|
||||||
},
|
|
||||||
MessageStop {},
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Serialize, Deserialize, Debug)]
|
|
||||||
pub struct ResponseMessage {
|
|
||||||
#[serde(rename = "type")]
|
|
||||||
pub message_type: Option<String>,
|
|
||||||
pub id: Option<String>,
|
|
||||||
pub role: Option<String>,
|
|
||||||
pub content: Option<Vec<String>>,
|
|
||||||
pub model: Option<String>,
|
|
||||||
pub stop_reason: Option<String>,
|
|
||||||
pub stop_sequence: Option<String>,
|
|
||||||
pub usage: Option<Usage>,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Serialize, Deserialize, Debug)]
|
|
||||||
pub struct Usage {
|
|
||||||
pub input_tokens: Option<u32>,
|
|
||||||
pub output_tokens: Option<u32>,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Serialize, Deserialize, Debug)]
|
|
||||||
#[serde(tag = "type", rename_all = "snake_case")]
|
|
||||||
pub enum ContentBlock {
|
|
||||||
Text { text: String },
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Serialize, Deserialize, Debug)]
|
|
||||||
#[serde(tag = "type", rename_all = "snake_case")]
|
|
||||||
pub enum TextDelta {
|
|
||||||
TextDelta { text: String },
|
|
||||||
}
|
}
|
||||||
|
|
||||||
pub async fn stream_completion(
|
pub async fn stream_completion(
|
||||||
|
@ -175,7 +112,11 @@ pub async fn stream_completion(
|
||||||
api_key: &str,
|
api_key: &str,
|
||||||
request: Request,
|
request: Request,
|
||||||
low_speed_timeout: Option<Duration>,
|
low_speed_timeout: Option<Duration>,
|
||||||
) -> Result<BoxStream<'static, Result<ResponseEvent>>> {
|
) -> Result<BoxStream<'static, Result<Event>>> {
|
||||||
|
let request = StreamingRequest {
|
||||||
|
base: request,
|
||||||
|
stream: true,
|
||||||
|
};
|
||||||
let uri = format!("{api_url}/v1/messages");
|
let uri = format!("{api_url}/v1/messages");
|
||||||
let mut request_builder = HttpRequest::builder()
|
let mut request_builder = HttpRequest::builder()
|
||||||
.method(Method::POST)
|
.method(Method::POST)
|
||||||
|
@ -187,7 +128,9 @@ pub async fn stream_completion(
|
||||||
if let Some(low_speed_timeout) = low_speed_timeout {
|
if let Some(low_speed_timeout) = low_speed_timeout {
|
||||||
request_builder = request_builder.low_speed_timeout(100, low_speed_timeout);
|
request_builder = request_builder.low_speed_timeout(100, low_speed_timeout);
|
||||||
}
|
}
|
||||||
let request = request_builder.body(AsyncBody::from(serde_json::to_string(&request)?))?;
|
let serialized_request = serde_json::to_string(&request)?;
|
||||||
|
let request = request_builder.body(AsyncBody::from(serialized_request))?;
|
||||||
|
|
||||||
let mut response = client.send(request).await?;
|
let mut response = client.send(request).await?;
|
||||||
if response.status().is_success() {
|
if response.status().is_success() {
|
||||||
let reader = BufReader::new(response.into_body());
|
let reader = BufReader::new(response.into_body());
|
||||||
|
@ -212,7 +155,7 @@ pub async fn stream_completion(
|
||||||
|
|
||||||
let body_str = std::str::from_utf8(&body)?;
|
let body_str = std::str::from_utf8(&body)?;
|
||||||
|
|
||||||
match serde_json::from_str::<ResponseEvent>(body_str) {
|
match serde_json::from_str::<Event>(body_str) {
|
||||||
Ok(_) => Err(anyhow!(
|
Ok(_) => Err(anyhow!(
|
||||||
"Unexpected success response while expecting an error: {}",
|
"Unexpected success response while expecting an error: {}",
|
||||||
body_str,
|
body_str,
|
||||||
|
@ -227,16 +170,18 @@ pub async fn stream_completion(
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn extract_text_from_events(
|
pub fn extract_text_from_events(
|
||||||
response: impl Stream<Item = Result<ResponseEvent>>,
|
response: impl Stream<Item = Result<Event>>,
|
||||||
) -> impl Stream<Item = Result<String>> {
|
) -> impl Stream<Item = Result<String>> {
|
||||||
response.filter_map(|response| async move {
|
response.filter_map(|response| async move {
|
||||||
match response {
|
match response {
|
||||||
Ok(response) => match response {
|
Ok(response) => match response {
|
||||||
ResponseEvent::ContentBlockStart { content_block, .. } => match content_block {
|
Event::ContentBlockStart { content_block, .. } => match content_block {
|
||||||
ContentBlock::Text { text } => Some(Ok(text)),
|
Content::Text { text } => Some(Ok(text)),
|
||||||
|
_ => None,
|
||||||
},
|
},
|
||||||
ResponseEvent::ContentBlockDelta { delta, .. } => match delta {
|
Event::ContentBlockDelta { delta, .. } => match delta {
|
||||||
TextDelta::TextDelta { text } => Some(Ok(text)),
|
ContentDelta::TextDelta { text } => Some(Ok(text)),
|
||||||
|
_ => None,
|
||||||
},
|
},
|
||||||
_ => None,
|
_ => None,
|
||||||
},
|
},
|
||||||
|
@ -245,42 +190,162 @@ pub fn extract_text_from_events(
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
// #[cfg(test)]
|
#[derive(Debug, Serialize, Deserialize)]
|
||||||
// mod tests {
|
pub struct Message {
|
||||||
// use super::*;
|
pub role: Role,
|
||||||
// use http::IsahcHttpClient;
|
pub content: Vec<Content>,
|
||||||
|
}
|
||||||
|
|
||||||
// #[tokio::test]
|
#[derive(Debug, Serialize, Deserialize)]
|
||||||
// async fn stream_completion_success() {
|
#[serde(rename_all = "lowercase")]
|
||||||
// let http_client = IsahcHttpClient::new().unwrap();
|
pub enum Role {
|
||||||
|
User,
|
||||||
|
Assistant,
|
||||||
|
}
|
||||||
|
|
||||||
// let request = Request {
|
#[derive(Debug, Serialize, Deserialize)]
|
||||||
// model: Model::Claude3Opus,
|
#[serde(tag = "type")]
|
||||||
// messages: vec![RequestMessage {
|
pub enum Content {
|
||||||
// role: Role::User,
|
#[serde(rename = "text")]
|
||||||
// content: "Ping".to_string(),
|
Text { text: String },
|
||||||
// }],
|
#[serde(rename = "image")]
|
||||||
// stream: true,
|
Image { source: ImageSource },
|
||||||
// system: "Respond to ping with pong".to_string(),
|
#[serde(rename = "tool_use")]
|
||||||
// max_tokens: 4096,
|
ToolUse {
|
||||||
// };
|
id: String,
|
||||||
|
name: String,
|
||||||
|
input: serde_json::Value,
|
||||||
|
},
|
||||||
|
#[serde(rename = "tool_result")]
|
||||||
|
ToolResult {
|
||||||
|
tool_use_id: String,
|
||||||
|
content: String,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
// let stream = stream_completion(
|
#[derive(Debug, Serialize, Deserialize)]
|
||||||
// &http_client,
|
pub struct ImageSource {
|
||||||
// "https://api.anthropic.com",
|
#[serde(rename = "type")]
|
||||||
// &std::env::var("ANTHROPIC_API_KEY").expect("ANTHROPIC_API_KEY not set"),
|
pub source_type: String,
|
||||||
// request,
|
pub media_type: String,
|
||||||
// )
|
pub data: String,
|
||||||
// .await
|
}
|
||||||
// .unwrap();
|
|
||||||
|
|
||||||
// stream
|
#[derive(Debug, Serialize, Deserialize)]
|
||||||
// .for_each(|event| async {
|
pub struct Tool {
|
||||||
// match event {
|
pub name: String,
|
||||||
// Ok(event) => println!("{:?}", event),
|
pub description: String,
|
||||||
// Err(e) => eprintln!("Error: {:?}", e),
|
pub input_schema: serde_json::Value,
|
||||||
// }
|
}
|
||||||
// })
|
|
||||||
// .await;
|
#[derive(Debug, Serialize, Deserialize)]
|
||||||
// }
|
#[serde(tag = "type", rename_all = "lowercase")]
|
||||||
// }
|
pub enum ToolChoice {
|
||||||
|
Auto,
|
||||||
|
Any,
|
||||||
|
Tool { name: String },
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Serialize, Deserialize)]
|
||||||
|
pub struct Request {
|
||||||
|
pub model: String,
|
||||||
|
pub max_tokens: u32,
|
||||||
|
pub messages: Vec<Message>,
|
||||||
|
#[serde(default, skip_serializing_if = "Vec::is_empty")]
|
||||||
|
pub tools: Vec<Tool>,
|
||||||
|
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||||
|
pub tool_choice: Option<ToolChoice>,
|
||||||
|
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||||
|
pub system: Option<String>,
|
||||||
|
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||||
|
pub metadata: Option<Metadata>,
|
||||||
|
#[serde(default, skip_serializing_if = "Vec::is_empty")]
|
||||||
|
pub stop_sequences: Vec<String>,
|
||||||
|
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||||
|
pub temperature: Option<f32>,
|
||||||
|
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||||
|
pub top_k: Option<u32>,
|
||||||
|
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||||
|
pub top_p: Option<f32>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Serialize, Deserialize)]
|
||||||
|
struct StreamingRequest {
|
||||||
|
#[serde(flatten)]
|
||||||
|
pub base: Request,
|
||||||
|
pub stream: bool,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Serialize, Deserialize)]
|
||||||
|
pub struct Metadata {
|
||||||
|
pub user_id: Option<String>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Serialize, Deserialize)]
|
||||||
|
pub struct Usage {
|
||||||
|
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||||
|
pub input_tokens: Option<u32>,
|
||||||
|
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||||
|
pub output_tokens: Option<u32>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Serialize, Deserialize)]
|
||||||
|
pub struct Response {
|
||||||
|
pub id: String,
|
||||||
|
#[serde(rename = "type")]
|
||||||
|
pub response_type: String,
|
||||||
|
pub role: Role,
|
||||||
|
pub content: Vec<Content>,
|
||||||
|
pub model: String,
|
||||||
|
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||||
|
pub stop_reason: Option<String>,
|
||||||
|
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||||
|
pub stop_sequence: Option<String>,
|
||||||
|
pub usage: Usage,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Serialize, Deserialize)]
|
||||||
|
#[serde(tag = "type")]
|
||||||
|
pub enum Event {
|
||||||
|
#[serde(rename = "message_start")]
|
||||||
|
MessageStart { message: Response },
|
||||||
|
#[serde(rename = "content_block_start")]
|
||||||
|
ContentBlockStart {
|
||||||
|
index: usize,
|
||||||
|
content_block: Content,
|
||||||
|
},
|
||||||
|
#[serde(rename = "content_block_delta")]
|
||||||
|
ContentBlockDelta { index: usize, delta: ContentDelta },
|
||||||
|
#[serde(rename = "content_block_stop")]
|
||||||
|
ContentBlockStop { index: usize },
|
||||||
|
#[serde(rename = "message_delta")]
|
||||||
|
MessageDelta { delta: MessageDelta, usage: Usage },
|
||||||
|
#[serde(rename = "message_stop")]
|
||||||
|
MessageStop,
|
||||||
|
#[serde(rename = "ping")]
|
||||||
|
Ping,
|
||||||
|
#[serde(rename = "error")]
|
||||||
|
Error { error: ApiError },
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Serialize, Deserialize)]
|
||||||
|
#[serde(tag = "type")]
|
||||||
|
pub enum ContentDelta {
|
||||||
|
#[serde(rename = "text_delta")]
|
||||||
|
TextDelta { text: String },
|
||||||
|
#[serde(rename = "input_json_delta")]
|
||||||
|
InputJsonDelta { partial_json: String },
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Serialize, Deserialize)]
|
||||||
|
pub struct MessageDelta {
|
||||||
|
pub stop_reason: Option<String>,
|
||||||
|
pub stop_sequence: Option<String>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Serialize, Deserialize)]
|
||||||
|
pub struct ApiError {
|
||||||
|
#[serde(rename = "type")]
|
||||||
|
pub error_type: String,
|
||||||
|
pub message: String,
|
||||||
|
}
|
||||||
|
|
|
@ -75,7 +75,6 @@ util.workspace = true
|
||||||
uuid.workspace = true
|
uuid.workspace = true
|
||||||
workspace.workspace = true
|
workspace.workspace = true
|
||||||
picker.workspace = true
|
picker.workspace = true
|
||||||
roxmltree = "0.20.0"
|
|
||||||
|
|
||||||
[dev-dependencies]
|
[dev-dependencies]
|
||||||
completion = { workspace = true, features = ["test-support"] }
|
completion = { workspace = true, features = ["test-support"] }
|
||||||
|
|
|
@ -1232,12 +1232,16 @@ impl ContextEditor {
|
||||||
|
|
||||||
fn apply_edit_step(&mut self, cx: &mut ViewContext<Self>) -> bool {
|
fn apply_edit_step(&mut self, cx: &mut ViewContext<Self>) -> bool {
|
||||||
if let Some(step) = self.active_edit_step.as_ref() {
|
if let Some(step) = self.active_edit_step.as_ref() {
|
||||||
|
let assist_ids = step.assist_ids.clone();
|
||||||
|
cx.window_context().defer(|cx| {
|
||||||
InlineAssistant::update_global(cx, |assistant, cx| {
|
InlineAssistant::update_global(cx, |assistant, cx| {
|
||||||
for assist_id in &step.assist_ids {
|
for assist_id in assist_ids {
|
||||||
assistant.start_assist(*assist_id, cx);
|
assistant.start_assist(assist_id, cx);
|
||||||
}
|
}
|
||||||
!step.assist_ids.is_empty()
|
|
||||||
})
|
})
|
||||||
|
});
|
||||||
|
|
||||||
|
!step.assist_ids.is_empty()
|
||||||
} else {
|
} else {
|
||||||
false
|
false
|
||||||
}
|
}
|
||||||
|
@ -1286,11 +1290,7 @@ impl ContextEditor {
|
||||||
.collect::<String>()
|
.collect::<String>()
|
||||||
));
|
));
|
||||||
match &step.operations {
|
match &step.operations {
|
||||||
Some(EditStepOperations::Parsed {
|
Some(EditStepOperations::Ready(operations)) => {
|
||||||
operations,
|
|
||||||
raw_output,
|
|
||||||
}) => {
|
|
||||||
output.push_str(&format!("Raw Output:\n{raw_output}\n"));
|
|
||||||
output.push_str("Parsed Operations:\n");
|
output.push_str("Parsed Operations:\n");
|
||||||
for op in operations {
|
for op in operations {
|
||||||
output.push_str(&format!(" {:?}\n", op));
|
output.push_str(&format!(" {:?}\n", op));
|
||||||
|
@ -1794,13 +1794,12 @@ impl ContextEditor {
|
||||||
.anchor_in_excerpt(excerpt_id, suggestion.range.end)
|
.anchor_in_excerpt(excerpt_id, suggestion.range.end)
|
||||||
.unwrap()
|
.unwrap()
|
||||||
};
|
};
|
||||||
let initial_text = suggestion.prepend_newline.then(|| "\n".into());
|
|
||||||
InlineAssistant::update_global(cx, |assistant, cx| {
|
InlineAssistant::update_global(cx, |assistant, cx| {
|
||||||
assist_ids.push(assistant.suggest_assist(
|
assist_ids.push(assistant.suggest_assist(
|
||||||
&editor,
|
&editor,
|
||||||
range,
|
range,
|
||||||
description,
|
description,
|
||||||
initial_text,
|
suggestion.initial_insertion,
|
||||||
Some(workspace.clone()),
|
Some(workspace.clone()),
|
||||||
assistant_panel.upgrade().as_ref(),
|
assistant_panel.upgrade().as_ref(),
|
||||||
cx,
|
cx,
|
||||||
|
@ -1862,9 +1861,11 @@ impl ContextEditor {
|
||||||
.anchor_in_excerpt(excerpt_id, suggestion.range.end)
|
.anchor_in_excerpt(excerpt_id, suggestion.range.end)
|
||||||
.unwrap()
|
.unwrap()
|
||||||
};
|
};
|
||||||
let initial_text =
|
inline_assist_suggestions.push((
|
||||||
suggestion.prepend_newline.then(|| "\n".to_string());
|
range,
|
||||||
inline_assist_suggestions.push((range, description, initial_text));
|
description,
|
||||||
|
suggestion.initial_insertion,
|
||||||
|
));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -1875,12 +1876,12 @@ impl ContextEditor {
|
||||||
.new_view(|cx| Editor::for_multibuffer(multibuffer, Some(project), true, cx))?;
|
.new_view(|cx| Editor::for_multibuffer(multibuffer, Some(project), true, cx))?;
|
||||||
cx.update(|cx| {
|
cx.update(|cx| {
|
||||||
InlineAssistant::update_global(cx, |assistant, cx| {
|
InlineAssistant::update_global(cx, |assistant, cx| {
|
||||||
for (range, description, initial_text) in inline_assist_suggestions {
|
for (range, description, initial_insertion) in inline_assist_suggestions {
|
||||||
assist_ids.push(assistant.suggest_assist(
|
assist_ids.push(assistant.suggest_assist(
|
||||||
&editor,
|
&editor,
|
||||||
range,
|
range,
|
||||||
description,
|
description,
|
||||||
initial_text,
|
initial_insertion,
|
||||||
Some(workspace.clone()),
|
Some(workspace.clone()),
|
||||||
assistant_panel.upgrade().as_ref(),
|
assistant_panel.upgrade().as_ref(),
|
||||||
cx,
|
cx,
|
||||||
|
@ -2188,7 +2189,7 @@ impl ContextEditor {
|
||||||
let button_text = match self.edit_step_for_cursor(cx) {
|
let button_text = match self.edit_step_for_cursor(cx) {
|
||||||
Some(edit_step) => match &edit_step.operations {
|
Some(edit_step) => match &edit_step.operations {
|
||||||
Some(EditStepOperations::Pending(_)) => "Computing Changes...",
|
Some(EditStepOperations::Pending(_)) => "Computing Changes...",
|
||||||
Some(EditStepOperations::Parsed { .. }) => "Apply Changes",
|
Some(EditStepOperations::Ready(_)) => "Apply Changes",
|
||||||
None => "Send",
|
None => "Send",
|
||||||
},
|
},
|
||||||
None => "Send",
|
None => "Send",
|
||||||
|
|
|
@ -1,6 +1,6 @@
|
||||||
use crate::{
|
use crate::{
|
||||||
prompt_library::PromptStore, slash_command::SlashCommandLine, LanguageModelCompletionProvider,
|
prompt_library::PromptStore, slash_command::SlashCommandLine, InitialInsertion,
|
||||||
MessageId, MessageStatus,
|
LanguageModelCompletionProvider, MessageId, MessageStatus,
|
||||||
};
|
};
|
||||||
use anyhow::{anyhow, Context as _, Result};
|
use anyhow::{anyhow, Context as _, Result};
|
||||||
use assistant_slash_command::{
|
use assistant_slash_command::{
|
||||||
|
@ -18,11 +18,11 @@ use gpui::{AppContext, Context as _, EventEmitter, Model, ModelContext, Subscrip
|
||||||
use language::{
|
use language::{
|
||||||
AnchorRangeExt, Bias, Buffer, LanguageRegistry, OffsetRangeExt, ParseStatus, Point, ToOffset,
|
AnchorRangeExt, Bias, Buffer, LanguageRegistry, OffsetRangeExt, ParseStatus, Point, ToOffset,
|
||||||
};
|
};
|
||||||
use language_model::LanguageModelRequestMessage;
|
use language_model::{LanguageModelRequest, LanguageModelRequestMessage, LanguageModelTool, Role};
|
||||||
use language_model::{LanguageModelRequest, Role};
|
|
||||||
use open_ai::Model as OpenAiModel;
|
use open_ai::Model as OpenAiModel;
|
||||||
use paths::contexts_dir;
|
use paths::contexts_dir;
|
||||||
use project::Project;
|
use project::Project;
|
||||||
|
use schemars::JsonSchema;
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
use std::{
|
use std::{
|
||||||
cmp,
|
cmp,
|
||||||
|
@ -352,7 +352,7 @@ pub struct EditSuggestion {
|
||||||
pub range: Range<language::Anchor>,
|
pub range: Range<language::Anchor>,
|
||||||
/// If None, assume this is a suggestion to delete the range rather than transform it.
|
/// If None, assume this is a suggestion to delete the range rather than transform it.
|
||||||
pub description: Option<String>,
|
pub description: Option<String>,
|
||||||
pub prepend_newline: bool,
|
pub initial_insertion: Option<InitialInsertion>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl EditStep {
|
impl EditStep {
|
||||||
|
@ -361,7 +361,7 @@ impl EditStep {
|
||||||
project: &Model<Project>,
|
project: &Model<Project>,
|
||||||
cx: &AppContext,
|
cx: &AppContext,
|
||||||
) -> Task<HashMap<Model<Buffer>, Vec<EditSuggestionGroup>>> {
|
) -> Task<HashMap<Model<Buffer>, Vec<EditSuggestionGroup>>> {
|
||||||
let Some(EditStepOperations::Parsed { operations, .. }) = &self.operations else {
|
let Some(EditStepOperations::Ready(operations)) = &self.operations else {
|
||||||
return Task::ready(HashMap::default());
|
return Task::ready(HashMap::default());
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -471,32 +471,28 @@ impl EditStep {
|
||||||
}
|
}
|
||||||
|
|
||||||
pub enum EditStepOperations {
|
pub enum EditStepOperations {
|
||||||
Pending(Task<Result<()>>),
|
Pending(Task<Option<()>>),
|
||||||
Parsed {
|
Ready(Vec<EditOperation>),
|
||||||
operations: Vec<EditOperation>,
|
|
||||||
raw_output: String,
|
|
||||||
},
|
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Debug for EditStepOperations {
|
impl Debug for EditStepOperations {
|
||||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||||
match self {
|
match self {
|
||||||
EditStepOperations::Pending(_) => write!(f, "EditStepOperations::Pending"),
|
EditStepOperations::Pending(_) => write!(f, "EditStepOperations::Pending"),
|
||||||
EditStepOperations::Parsed {
|
EditStepOperations::Ready(operations) => f
|
||||||
operations,
|
|
||||||
raw_output,
|
|
||||||
} => f
|
|
||||||
.debug_struct("EditStepOperations::Parsed")
|
.debug_struct("EditStepOperations::Parsed")
|
||||||
.field("operations", operations)
|
.field("operations", operations)
|
||||||
.field("raw_output", raw_output)
|
|
||||||
.finish(),
|
.finish(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Clone, Debug, PartialEq, Eq)]
|
/// A description of an operation to apply to one location in the codebase.
|
||||||
|
#[derive(Clone, Debug, PartialEq, Eq, Deserialize, JsonSchema)]
|
||||||
pub struct EditOperation {
|
pub struct EditOperation {
|
||||||
|
/// The path to the file containing the relevant operation
|
||||||
pub path: String,
|
pub path: String,
|
||||||
|
#[serde(flatten)]
|
||||||
pub kind: EditOperationKind,
|
pub kind: EditOperationKind,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -523,7 +519,7 @@ impl EditOperation {
|
||||||
parse_status.changed().await?;
|
parse_status.changed().await?;
|
||||||
}
|
}
|
||||||
|
|
||||||
let prepend_newline = kind.prepend_newline();
|
let initial_insertion = kind.initial_insertion();
|
||||||
let suggestion_range = if let Some(symbol) = kind.symbol() {
|
let suggestion_range = if let Some(symbol) = kind.symbol() {
|
||||||
let outline = buffer
|
let outline = buffer
|
||||||
.update(&mut cx, |buffer, _| buffer.snapshot().outline(None))?
|
.update(&mut cx, |buffer, _| buffer.snapshot().outline(None))?
|
||||||
|
@ -601,39 +597,61 @@ impl EditOperation {
|
||||||
EditSuggestion {
|
EditSuggestion {
|
||||||
range: suggestion_range,
|
range: suggestion_range,
|
||||||
description: kind.description().map(ToString::to_string),
|
description: kind.description().map(ToString::to_string),
|
||||||
prepend_newline,
|
initial_insertion,
|
||||||
},
|
},
|
||||||
))
|
))
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Clone, Debug, PartialEq, Eq)]
|
#[derive(Clone, Debug, PartialEq, Eq, Deserialize, JsonSchema)]
|
||||||
|
#[serde(tag = "kind")]
|
||||||
pub enum EditOperationKind {
|
pub enum EditOperationKind {
|
||||||
|
/// Rewrite the specified symbol in its entirely based on the given description.
|
||||||
Update {
|
Update {
|
||||||
|
/// A full path to the symbol to be rewritten from the provided list.
|
||||||
symbol: String,
|
symbol: String,
|
||||||
|
/// A brief one-line description of the change that should be applied.
|
||||||
description: String,
|
description: String,
|
||||||
},
|
},
|
||||||
|
/// Create a new file with the given path based on the given description.
|
||||||
Create {
|
Create {
|
||||||
|
/// A brief one-line description of the change that should be applied.
|
||||||
description: String,
|
description: String,
|
||||||
},
|
},
|
||||||
|
/// Insert a new symbol based on the given description before the specified symbol.
|
||||||
InsertSiblingBefore {
|
InsertSiblingBefore {
|
||||||
|
/// A full path to the symbol to be rewritten from the provided list.
|
||||||
symbol: String,
|
symbol: String,
|
||||||
|
/// A brief one-line description of the change that should be applied.
|
||||||
description: String,
|
description: String,
|
||||||
},
|
},
|
||||||
|
/// Insert a new symbol based on the given description after the specified symbol.
|
||||||
InsertSiblingAfter {
|
InsertSiblingAfter {
|
||||||
|
/// A full path to the symbol to be rewritten from the provided list.
|
||||||
symbol: String,
|
symbol: String,
|
||||||
|
/// A brief one-line description of the change that should be applied.
|
||||||
description: String,
|
description: String,
|
||||||
},
|
},
|
||||||
|
/// Insert a new symbol as a child of the specified symbol at the start.
|
||||||
PrependChild {
|
PrependChild {
|
||||||
|
/// An optional full path to the symbol to be rewritten from the provided list.
|
||||||
|
/// If not provided, the edit should be applied at the top of the file.
|
||||||
symbol: Option<String>,
|
symbol: Option<String>,
|
||||||
|
/// A brief one-line description of the change that should be applied.
|
||||||
description: String,
|
description: String,
|
||||||
},
|
},
|
||||||
|
/// Insert a new symbol as a child of the specified symbol at the end.
|
||||||
AppendChild {
|
AppendChild {
|
||||||
|
/// An optional full path to the symbol to be rewritten from the provided list.
|
||||||
|
/// If not provided, the edit should be applied at the top of the file.
|
||||||
symbol: Option<String>,
|
symbol: Option<String>,
|
||||||
|
/// A brief one-line description of the change that should be applied.
|
||||||
description: String,
|
description: String,
|
||||||
},
|
},
|
||||||
|
/// Delete the specified symbol.
|
||||||
Delete {
|
Delete {
|
||||||
|
/// A full path to the symbol to be rewritten from the provided list.
|
||||||
symbol: String,
|
symbol: String,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
@ -663,13 +681,13 @@ impl EditOperationKind {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn prepend_newline(&self) -> bool {
|
pub fn initial_insertion(&self) -> Option<InitialInsertion> {
|
||||||
match self {
|
match self {
|
||||||
Self::PrependChild { .. }
|
EditOperationKind::InsertSiblingBefore { .. } => Some(InitialInsertion::NewlineAfter),
|
||||||
| Self::AppendChild { .. }
|
EditOperationKind::InsertSiblingAfter { .. } => Some(InitialInsertion::NewlineBefore),
|
||||||
| Self::InsertSiblingAfter { .. }
|
EditOperationKind::PrependChild { .. } => Some(InitialInsertion::NewlineAfter),
|
||||||
| Self::InsertSiblingBefore { .. } => true,
|
EditOperationKind::AppendChild { .. } => Some(InitialInsertion::NewlineBefore),
|
||||||
_ => false,
|
_ => None,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -1137,18 +1155,15 @@ impl Context {
|
||||||
.timer(Duration::from_millis(200))
|
.timer(Duration::from_millis(200))
|
||||||
.await;
|
.await;
|
||||||
|
|
||||||
if let Some(token_count) = cx.update(|cx| {
|
let token_count = cx
|
||||||
|
.update(|cx| {
|
||||||
LanguageModelCompletionProvider::read_global(cx).count_tokens(request, cx)
|
LanguageModelCompletionProvider::read_global(cx).count_tokens(request, cx)
|
||||||
})? {
|
})?
|
||||||
let token_count = token_count.await?;
|
.await?;
|
||||||
|
|
||||||
this.update(&mut cx, |this, cx| {
|
this.update(&mut cx, |this, cx| {
|
||||||
this.token_count = Some(token_count);
|
this.token_count = Some(token_count);
|
||||||
cx.notify()
|
cx.notify()
|
||||||
})?;
|
})
|
||||||
}
|
|
||||||
|
|
||||||
anyhow::Ok(())
|
|
||||||
}
|
}
|
||||||
.log_err()
|
.log_err()
|
||||||
});
|
});
|
||||||
|
@ -1304,7 +1319,24 @@ impl Context {
|
||||||
&self,
|
&self,
|
||||||
edit_step: &EditStep,
|
edit_step: &EditStep,
|
||||||
cx: &mut ModelContext<Self>,
|
cx: &mut ModelContext<Self>,
|
||||||
) -> Task<Result<()>> {
|
) -> Task<Option<()>> {
|
||||||
|
#[derive(Debug, Deserialize, JsonSchema)]
|
||||||
|
struct EditTool {
|
||||||
|
/// A sequence of operations to apply to the codebase.
|
||||||
|
/// When multiple operations are required for a step, be sure to include multiple operations in this list.
|
||||||
|
operations: Vec<EditOperation>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl LanguageModelTool for EditTool {
|
||||||
|
fn name() -> String {
|
||||||
|
"edit".into()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn description() -> String {
|
||||||
|
"suggest edits to one or more locations in the codebase".into()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
let mut request = self.to_completion_request(cx);
|
let mut request = self.to_completion_request(cx);
|
||||||
let edit_step_range = edit_step.source_range.clone();
|
let edit_step_range = edit_step.source_range.clone();
|
||||||
let step_text = self
|
let step_text = self
|
||||||
|
@ -1313,7 +1345,8 @@ impl Context {
|
||||||
.text_for_range(edit_step_range.clone())
|
.text_for_range(edit_step_range.clone())
|
||||||
.collect::<String>();
|
.collect::<String>();
|
||||||
|
|
||||||
cx.spawn(|this, mut cx| async move {
|
cx.spawn(|this, mut cx| {
|
||||||
|
async move {
|
||||||
let prompt_store = cx.update(|cx| PromptStore::global(cx))?.await?;
|
let prompt_store = cx.update(|cx| PromptStore::global(cx))?.await?;
|
||||||
|
|
||||||
let mut prompt = prompt_store.operations_prompt();
|
let mut prompt = prompt_store.operations_prompt();
|
||||||
|
@ -1324,13 +1357,13 @@ impl Context {
|
||||||
content: prompt,
|
content: prompt,
|
||||||
});
|
});
|
||||||
|
|
||||||
let raw_output = cx
|
let tool_use = cx
|
||||||
.update(|cx| {
|
.update(|cx| {
|
||||||
LanguageModelCompletionProvider::read_global(cx).complete(request, cx)
|
LanguageModelCompletionProvider::read_global(cx)
|
||||||
|
.use_tool::<EditTool>(request, cx)
|
||||||
})?
|
})?
|
||||||
.await?;
|
.await?;
|
||||||
|
|
||||||
let operations = Self::parse_edit_operations(&raw_output);
|
|
||||||
this.update(&mut cx, |this, cx| {
|
this.update(&mut cx, |this, cx| {
|
||||||
let step_index = this
|
let step_index = this
|
||||||
.edit_steps
|
.edit_steps
|
||||||
|
@ -1340,133 +1373,13 @@ impl Context {
|
||||||
})
|
})
|
||||||
.map_err(|_| anyhow!("edit step not found"))?;
|
.map_err(|_| anyhow!("edit step not found"))?;
|
||||||
if let Some(edit_step) = this.edit_steps.get_mut(step_index) {
|
if let Some(edit_step) = this.edit_steps.get_mut(step_index) {
|
||||||
edit_step.operations = Some(EditStepOperations::Parsed {
|
edit_step.operations = Some(EditStepOperations::Ready(tool_use.operations));
|
||||||
operations,
|
|
||||||
raw_output,
|
|
||||||
});
|
|
||||||
cx.emit(ContextEvent::EditStepsChanged);
|
cx.emit(ContextEvent::EditStepsChanged);
|
||||||
}
|
}
|
||||||
anyhow::Ok(())
|
anyhow::Ok(())
|
||||||
})?
|
})?
|
||||||
})
|
|
||||||
}
|
}
|
||||||
|
.log_err()
|
||||||
fn parse_edit_operations(xml: &str) -> Vec<EditOperation> {
|
|
||||||
let Some(start_ix) = xml.find("<operations>") else {
|
|
||||||
return Vec::new();
|
|
||||||
};
|
|
||||||
let Some(end_ix) = xml[start_ix..].find("</operations>") else {
|
|
||||||
return Vec::new();
|
|
||||||
};
|
|
||||||
let end_ix = end_ix + start_ix + "</operations>".len();
|
|
||||||
|
|
||||||
let doc = roxmltree::Document::parse(&xml[start_ix..end_ix]).log_err();
|
|
||||||
doc.map_or(Vec::new(), |doc| {
|
|
||||||
doc.root_element()
|
|
||||||
.children()
|
|
||||||
.map(|node| {
|
|
||||||
let tag_name = node.tag_name().name();
|
|
||||||
let path = node
|
|
||||||
.attribute("path")
|
|
||||||
.with_context(|| {
|
|
||||||
format!("invalid node {node:?}, missing attribute 'path'")
|
|
||||||
})?
|
|
||||||
.to_string();
|
|
||||||
let kind = match tag_name {
|
|
||||||
"update" => EditOperationKind::Update {
|
|
||||||
symbol: node
|
|
||||||
.attribute("symbol")
|
|
||||||
.with_context(|| {
|
|
||||||
format!("invalid node {node:?}, missing attribute 'symbol'")
|
|
||||||
})?
|
|
||||||
.to_string(),
|
|
||||||
description: node
|
|
||||||
.attribute("description")
|
|
||||||
.with_context(|| {
|
|
||||||
format!(
|
|
||||||
"invalid node {node:?}, missing attribute 'description'"
|
|
||||||
)
|
|
||||||
})?
|
|
||||||
.to_string(),
|
|
||||||
},
|
|
||||||
"create" => EditOperationKind::Create {
|
|
||||||
description: node
|
|
||||||
.attribute("description")
|
|
||||||
.with_context(|| {
|
|
||||||
format!(
|
|
||||||
"invalid node {node:?}, missing attribute 'description'"
|
|
||||||
)
|
|
||||||
})?
|
|
||||||
.to_string(),
|
|
||||||
},
|
|
||||||
"insert_sibling_after" => EditOperationKind::InsertSiblingAfter {
|
|
||||||
symbol: node
|
|
||||||
.attribute("symbol")
|
|
||||||
.with_context(|| {
|
|
||||||
format!("invalid node {node:?}, missing attribute 'symbol'")
|
|
||||||
})?
|
|
||||||
.to_string(),
|
|
||||||
description: node
|
|
||||||
.attribute("description")
|
|
||||||
.with_context(|| {
|
|
||||||
format!(
|
|
||||||
"invalid node {node:?}, missing attribute 'description'"
|
|
||||||
)
|
|
||||||
})?
|
|
||||||
.to_string(),
|
|
||||||
},
|
|
||||||
"insert_sibling_before" => EditOperationKind::InsertSiblingBefore {
|
|
||||||
symbol: node
|
|
||||||
.attribute("symbol")
|
|
||||||
.with_context(|| {
|
|
||||||
format!("invalid node {node:?}, missing attribute 'symbol'")
|
|
||||||
})?
|
|
||||||
.to_string(),
|
|
||||||
description: node
|
|
||||||
.attribute("description")
|
|
||||||
.with_context(|| {
|
|
||||||
format!(
|
|
||||||
"invalid node {node:?}, missing attribute 'description'"
|
|
||||||
)
|
|
||||||
})?
|
|
||||||
.to_string(),
|
|
||||||
},
|
|
||||||
"prepend_child" => EditOperationKind::PrependChild {
|
|
||||||
symbol: node.attribute("symbol").map(String::from),
|
|
||||||
description: node
|
|
||||||
.attribute("description")
|
|
||||||
.with_context(|| {
|
|
||||||
format!(
|
|
||||||
"invalid node {node:?}, missing attribute 'description'"
|
|
||||||
)
|
|
||||||
})?
|
|
||||||
.to_string(),
|
|
||||||
},
|
|
||||||
"append_child" => EditOperationKind::AppendChild {
|
|
||||||
symbol: node.attribute("symbol").map(String::from),
|
|
||||||
description: node
|
|
||||||
.attribute("description")
|
|
||||||
.with_context(|| {
|
|
||||||
format!(
|
|
||||||
"invalid node {node:?}, missing attribute 'description'"
|
|
||||||
)
|
|
||||||
})?
|
|
||||||
.to_string(),
|
|
||||||
},
|
|
||||||
"delete" => EditOperationKind::Delete {
|
|
||||||
symbol: node
|
|
||||||
.attribute("symbol")
|
|
||||||
.with_context(|| {
|
|
||||||
format!("invalid node {node:?}, missing attribute 'symbol'")
|
|
||||||
})?
|
|
||||||
.to_string(),
|
|
||||||
},
|
|
||||||
_ => return Err(anyhow!("invalid node {node:?}")),
|
|
||||||
};
|
|
||||||
anyhow::Ok(EditOperation { path, kind })
|
|
||||||
})
|
|
||||||
.filter_map(|op| op.log_err())
|
|
||||||
.collect()
|
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -3083,55 +2996,6 @@ mod tests {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn test_parse_edit_operations() {
|
|
||||||
let operations = indoc! {r#"
|
|
||||||
Here are the operations to make all fields of the Canvas struct private:
|
|
||||||
|
|
||||||
<operations>
|
|
||||||
<update path="font-kit/src/canvas.rs" symbol="pub struct Canvas pub pixels" description="Remove pub keyword from pixels field" />
|
|
||||||
<update path="font-kit/src/canvas.rs" symbol="pub struct Canvas pub size" description="Remove pub keyword from size field" />
|
|
||||||
<update path="font-kit/src/canvas.rs" symbol="pub struct Canvas pub stride" description="Remove pub keyword from stride field" />
|
|
||||||
<update path="font-kit/src/canvas.rs" symbol="pub struct Canvas pub format" description="Remove pub keyword from format field" />
|
|
||||||
</operations>
|
|
||||||
"#};
|
|
||||||
|
|
||||||
let parsed_operations = Context::parse_edit_operations(operations);
|
|
||||||
assert_eq!(
|
|
||||||
parsed_operations,
|
|
||||||
vec![
|
|
||||||
EditOperation {
|
|
||||||
path: "font-kit/src/canvas.rs".to_string(),
|
|
||||||
kind: EditOperationKind::Update {
|
|
||||||
symbol: "pub struct Canvas pub pixels".to_string(),
|
|
||||||
description: "Remove pub keyword from pixels field".to_string(),
|
|
||||||
},
|
|
||||||
},
|
|
||||||
EditOperation {
|
|
||||||
path: "font-kit/src/canvas.rs".to_string(),
|
|
||||||
kind: EditOperationKind::Update {
|
|
||||||
symbol: "pub struct Canvas pub size".to_string(),
|
|
||||||
description: "Remove pub keyword from size field".to_string(),
|
|
||||||
},
|
|
||||||
},
|
|
||||||
EditOperation {
|
|
||||||
path: "font-kit/src/canvas.rs".to_string(),
|
|
||||||
kind: EditOperationKind::Update {
|
|
||||||
symbol: "pub struct Canvas pub stride".to_string(),
|
|
||||||
description: "Remove pub keyword from stride field".to_string(),
|
|
||||||
},
|
|
||||||
},
|
|
||||||
EditOperation {
|
|
||||||
path: "font-kit/src/canvas.rs".to_string(),
|
|
||||||
kind: EditOperationKind::Update {
|
|
||||||
symbol: "pub struct Canvas pub format".to_string(),
|
|
||||||
description: "Remove pub keyword from format field".to_string(),
|
|
||||||
},
|
|
||||||
},
|
|
||||||
]
|
|
||||||
);
|
|
||||||
}
|
|
||||||
|
|
||||||
#[gpui::test]
|
#[gpui::test]
|
||||||
async fn test_serialization(cx: &mut TestAppContext) {
|
async fn test_serialization(cx: &mut TestAppContext) {
|
||||||
let settings_store = cx.update(SettingsStore::test);
|
let settings_store = cx.update(SettingsStore::test);
|
||||||
|
|
|
@ -17,7 +17,7 @@ use editor::{
|
||||||
use fs::Fs;
|
use fs::Fs;
|
||||||
use futures::{
|
use futures::{
|
||||||
channel::mpsc,
|
channel::mpsc,
|
||||||
future::LocalBoxFuture,
|
future::{BoxFuture, LocalBoxFuture},
|
||||||
stream::{self, BoxStream},
|
stream::{self, BoxStream},
|
||||||
SinkExt, Stream, StreamExt,
|
SinkExt, Stream, StreamExt,
|
||||||
};
|
};
|
||||||
|
@ -36,7 +36,7 @@ use similar::TextDiff;
|
||||||
use smol::future::FutureExt;
|
use smol::future::FutureExt;
|
||||||
use std::{
|
use std::{
|
||||||
cmp,
|
cmp,
|
||||||
future::Future,
|
future::{self, Future},
|
||||||
mem,
|
mem,
|
||||||
ops::{Range, RangeInclusive},
|
ops::{Range, RangeInclusive},
|
||||||
pin::Pin,
|
pin::Pin,
|
||||||
|
@ -46,7 +46,7 @@ use std::{
|
||||||
};
|
};
|
||||||
use theme::ThemeSettings;
|
use theme::ThemeSettings;
|
||||||
use ui::{prelude::*, IconButtonShape, Tooltip};
|
use ui::{prelude::*, IconButtonShape, Tooltip};
|
||||||
use util::RangeExt;
|
use util::{RangeExt, ResultExt};
|
||||||
use workspace::{notifications::NotificationId, Toast, Workspace};
|
use workspace::{notifications::NotificationId, Toast, Workspace};
|
||||||
|
|
||||||
pub fn init(fs: Arc<dyn Fs>, telemetry: Arc<Telemetry>, cx: &mut AppContext) {
|
pub fn init(fs: Arc<dyn Fs>, telemetry: Arc<Telemetry>, cx: &mut AppContext) {
|
||||||
|
@ -187,7 +187,13 @@ impl InlineAssistant {
|
||||||
let [prompt_block_id, end_block_id] =
|
let [prompt_block_id, end_block_id] =
|
||||||
self.insert_assist_blocks(editor, &range, &prompt_editor, cx);
|
self.insert_assist_blocks(editor, &range, &prompt_editor, cx);
|
||||||
|
|
||||||
assists.push((assist_id, prompt_editor, prompt_block_id, end_block_id));
|
assists.push((
|
||||||
|
assist_id,
|
||||||
|
range,
|
||||||
|
prompt_editor,
|
||||||
|
prompt_block_id,
|
||||||
|
end_block_id,
|
||||||
|
));
|
||||||
}
|
}
|
||||||
|
|
||||||
let editor_assists = self
|
let editor_assists = self
|
||||||
|
@ -195,7 +201,7 @@ impl InlineAssistant {
|
||||||
.entry(editor.downgrade())
|
.entry(editor.downgrade())
|
||||||
.or_insert_with(|| EditorInlineAssists::new(&editor, cx));
|
.or_insert_with(|| EditorInlineAssists::new(&editor, cx));
|
||||||
let mut assist_group = InlineAssistGroup::new();
|
let mut assist_group = InlineAssistGroup::new();
|
||||||
for (assist_id, prompt_editor, prompt_block_id, end_block_id) in assists {
|
for (assist_id, range, prompt_editor, prompt_block_id, end_block_id) in assists {
|
||||||
self.assists.insert(
|
self.assists.insert(
|
||||||
assist_id,
|
assist_id,
|
||||||
InlineAssist::new(
|
InlineAssist::new(
|
||||||
|
@ -206,6 +212,7 @@ impl InlineAssistant {
|
||||||
&prompt_editor,
|
&prompt_editor,
|
||||||
prompt_block_id,
|
prompt_block_id,
|
||||||
end_block_id,
|
end_block_id,
|
||||||
|
range,
|
||||||
prompt_editor.read(cx).codegen.clone(),
|
prompt_editor.read(cx).codegen.clone(),
|
||||||
workspace.clone(),
|
workspace.clone(),
|
||||||
cx,
|
cx,
|
||||||
|
@ -227,7 +234,7 @@ impl InlineAssistant {
|
||||||
editor: &View<Editor>,
|
editor: &View<Editor>,
|
||||||
mut range: Range<Anchor>,
|
mut range: Range<Anchor>,
|
||||||
initial_prompt: String,
|
initial_prompt: String,
|
||||||
initial_insertion: Option<String>,
|
initial_insertion: Option<InitialInsertion>,
|
||||||
workspace: Option<WeakView<Workspace>>,
|
workspace: Option<WeakView<Workspace>>,
|
||||||
assistant_panel: Option<&View<AssistantPanel>>,
|
assistant_panel: Option<&View<AssistantPanel>>,
|
||||||
cx: &mut WindowContext,
|
cx: &mut WindowContext,
|
||||||
|
@ -239,22 +246,30 @@ impl InlineAssistant {
|
||||||
let assist_id = self.next_assist_id.post_inc();
|
let assist_id = self.next_assist_id.post_inc();
|
||||||
|
|
||||||
let buffer = editor.read(cx).buffer().clone();
|
let buffer = editor.read(cx).buffer().clone();
|
||||||
let prepend_transaction_id = initial_insertion.and_then(|initial_insertion| {
|
{
|
||||||
buffer.update(cx, |buffer, cx| {
|
let snapshot = buffer.read(cx).read(cx);
|
||||||
buffer.start_transaction(cx);
|
|
||||||
buffer.edit([(range.start..range.start, initial_insertion)], None, cx);
|
|
||||||
buffer.end_transaction(cx)
|
|
||||||
})
|
|
||||||
});
|
|
||||||
|
|
||||||
range.start = range.start.bias_left(&buffer.read(cx).read(cx));
|
let mut point_range = range.to_point(&snapshot);
|
||||||
range.end = range.end.bias_right(&buffer.read(cx).read(cx));
|
if point_range.is_empty() {
|
||||||
|
point_range.start.column = 0;
|
||||||
|
point_range.end.column = 0;
|
||||||
|
} else {
|
||||||
|
point_range.start.column = 0;
|
||||||
|
if point_range.end.row > point_range.start.row && point_range.end.column == 0 {
|
||||||
|
point_range.end.row -= 1;
|
||||||
|
}
|
||||||
|
point_range.end.column = snapshot.line_len(MultiBufferRow(point_range.end.row));
|
||||||
|
}
|
||||||
|
|
||||||
|
range.start = snapshot.anchor_before(point_range.start);
|
||||||
|
range.end = snapshot.anchor_after(point_range.end);
|
||||||
|
}
|
||||||
|
|
||||||
let codegen = cx.new_model(|cx| {
|
let codegen = cx.new_model(|cx| {
|
||||||
Codegen::new(
|
Codegen::new(
|
||||||
editor.read(cx).buffer().clone(),
|
editor.read(cx).buffer().clone(),
|
||||||
range.clone(),
|
range.clone(),
|
||||||
prepend_transaction_id,
|
initial_insertion,
|
||||||
self.telemetry.clone(),
|
self.telemetry.clone(),
|
||||||
cx,
|
cx,
|
||||||
)
|
)
|
||||||
|
@ -295,6 +310,7 @@ impl InlineAssistant {
|
||||||
&prompt_editor,
|
&prompt_editor,
|
||||||
prompt_block_id,
|
prompt_block_id,
|
||||||
end_block_id,
|
end_block_id,
|
||||||
|
range,
|
||||||
prompt_editor.read(cx).codegen.clone(),
|
prompt_editor.read(cx).codegen.clone(),
|
||||||
workspace.clone(),
|
workspace.clone(),
|
||||||
cx,
|
cx,
|
||||||
|
@ -445,7 +461,7 @@ impl InlineAssistant {
|
||||||
let buffer = editor.buffer().read(cx).snapshot(cx);
|
let buffer = editor.buffer().read(cx).snapshot(cx);
|
||||||
for assist_id in &editor_assists.assist_ids {
|
for assist_id in &editor_assists.assist_ids {
|
||||||
let assist = &self.assists[assist_id];
|
let assist = &self.assists[assist_id];
|
||||||
let assist_range = assist.codegen.read(cx).range.to_offset(&buffer);
|
let assist_range = assist.range.to_offset(&buffer);
|
||||||
if assist_range.contains(&selection.start) && assist_range.contains(&selection.end)
|
if assist_range.contains(&selection.start) && assist_range.contains(&selection.end)
|
||||||
{
|
{
|
||||||
if matches!(assist.codegen.read(cx).status, CodegenStatus::Pending) {
|
if matches!(assist.codegen.read(cx).status, CodegenStatus::Pending) {
|
||||||
|
@ -473,7 +489,7 @@ impl InlineAssistant {
|
||||||
let buffer = editor.buffer().read(cx).snapshot(cx);
|
let buffer = editor.buffer().read(cx).snapshot(cx);
|
||||||
for assist_id in &editor_assists.assist_ids {
|
for assist_id in &editor_assists.assist_ids {
|
||||||
let assist = &self.assists[assist_id];
|
let assist = &self.assists[assist_id];
|
||||||
let assist_range = assist.codegen.read(cx).range.to_offset(&buffer);
|
let assist_range = assist.range.to_offset(&buffer);
|
||||||
if assist.decorations.is_some()
|
if assist.decorations.is_some()
|
||||||
&& assist_range.contains(&selection.start)
|
&& assist_range.contains(&selection.start)
|
||||||
&& assist_range.contains(&selection.end)
|
&& assist_range.contains(&selection.end)
|
||||||
|
@ -551,7 +567,7 @@ impl InlineAssistant {
|
||||||
assist.codegen.read(cx).status,
|
assist.codegen.read(cx).status,
|
||||||
CodegenStatus::Error(_) | CodegenStatus::Done
|
CodegenStatus::Error(_) | CodegenStatus::Done
|
||||||
) {
|
) {
|
||||||
let assist_range = assist.codegen.read(cx).range.to_offset(&snapshot);
|
let assist_range = assist.range.to_offset(&snapshot);
|
||||||
if edited_ranges
|
if edited_ranges
|
||||||
.iter()
|
.iter()
|
||||||
.any(|range| range.overlaps(&assist_range))
|
.any(|range| range.overlaps(&assist_range))
|
||||||
|
@ -721,7 +737,7 @@ impl InlineAssistant {
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
let position = assist.codegen.read(cx).range.start;
|
let position = assist.range.start;
|
||||||
editor.update(cx, |editor, cx| {
|
editor.update(cx, |editor, cx| {
|
||||||
editor.change_selections(None, cx, |selections| {
|
editor.change_selections(None, cx, |selections| {
|
||||||
selections.select_anchor_ranges([position..position])
|
selections.select_anchor_ranges([position..position])
|
||||||
|
@ -740,8 +756,7 @@ impl InlineAssistant {
|
||||||
.0 as f32;
|
.0 as f32;
|
||||||
} else {
|
} else {
|
||||||
let snapshot = editor.snapshot(cx);
|
let snapshot = editor.snapshot(cx);
|
||||||
let codegen = assist.codegen.read(cx);
|
let start_row = assist
|
||||||
let start_row = codegen
|
|
||||||
.range
|
.range
|
||||||
.start
|
.start
|
||||||
.to_display_point(&snapshot.display_snapshot)
|
.to_display_point(&snapshot.display_snapshot)
|
||||||
|
@ -829,11 +844,7 @@ impl InlineAssistant {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
let Some(user_prompt) = assist
|
let Some(user_prompt) = assist.user_prompt(cx) else {
|
||||||
.decorations
|
|
||||||
.as_ref()
|
|
||||||
.map(|decorations| decorations.prompt_editor.read(cx).prompt(cx))
|
|
||||||
else {
|
|
||||||
return;
|
return;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -843,139 +854,19 @@ impl InlineAssistant {
|
||||||
self.prompt_history.pop_front();
|
self.prompt_history.pop_front();
|
||||||
}
|
}
|
||||||
|
|
||||||
let codegen = assist.codegen.clone();
|
let assistant_panel_context = assist.assistant_panel_context(cx);
|
||||||
let telemetry_id = LanguageModelCompletionProvider::read_global(cx)
|
|
||||||
.active_model()
|
|
||||||
.map(|m| m.telemetry_id())
|
|
||||||
.unwrap_or_default();
|
|
||||||
let chunks: LocalBoxFuture<Result<BoxStream<Result<String>>>> =
|
|
||||||
if user_prompt.trim().to_lowercase() == "delete" {
|
|
||||||
async { Ok(stream::empty().boxed()) }.boxed_local()
|
|
||||||
} else {
|
|
||||||
let request = self.request_for_inline_assist(assist_id, cx);
|
|
||||||
let mut cx = cx.to_async();
|
|
||||||
async move {
|
|
||||||
let request = request.await?;
|
|
||||||
let chunks = cx
|
|
||||||
.update(|cx| {
|
|
||||||
LanguageModelCompletionProvider::read_global(cx)
|
|
||||||
.stream_completion(request, cx)
|
|
||||||
})?
|
|
||||||
.await?;
|
|
||||||
Ok(chunks.boxed())
|
|
||||||
}
|
|
||||||
.boxed_local()
|
|
||||||
};
|
|
||||||
codegen.update(cx, |codegen, cx| {
|
|
||||||
codegen.start(telemetry_id, chunks, cx);
|
|
||||||
});
|
|
||||||
}
|
|
||||||
|
|
||||||
fn request_for_inline_assist(
|
assist
|
||||||
&self,
|
.codegen
|
||||||
assist_id: InlineAssistId,
|
.update(cx, |codegen, cx| {
|
||||||
cx: &mut WindowContext,
|
codegen.start(
|
||||||
) -> Task<Result<LanguageModelRequest>> {
|
assist.range.clone(),
|
||||||
cx.spawn(|mut cx| async move {
|
user_prompt,
|
||||||
let (user_prompt, context_request, project_name, buffer, range) =
|
assistant_panel_context,
|
||||||
cx.read_global(|this: &InlineAssistant, cx: &WindowContext| {
|
cx,
|
||||||
let assist = this.assists.get(&assist_id).context("invalid assist")?;
|
|
||||||
let decorations = assist.decorations.as_ref().context("invalid assist")?;
|
|
||||||
let editor = assist.editor.upgrade().context("invalid assist")?;
|
|
||||||
let user_prompt = decorations.prompt_editor.read(cx).prompt(cx);
|
|
||||||
let context_request = if assist.include_context {
|
|
||||||
assist.workspace.as_ref().and_then(|workspace| {
|
|
||||||
let workspace = workspace.upgrade()?.read(cx);
|
|
||||||
let assistant_panel = workspace.panel::<AssistantPanel>(cx)?;
|
|
||||||
Some(
|
|
||||||
assistant_panel
|
|
||||||
.read(cx)
|
|
||||||
.active_context(cx)?
|
|
||||||
.read(cx)
|
|
||||||
.to_completion_request(cx),
|
|
||||||
)
|
)
|
||||||
})
|
})
|
||||||
} else {
|
.log_err();
|
||||||
None
|
|
||||||
};
|
|
||||||
let project_name = assist.workspace.as_ref().and_then(|workspace| {
|
|
||||||
let workspace = workspace.upgrade()?;
|
|
||||||
Some(
|
|
||||||
workspace
|
|
||||||
.read(cx)
|
|
||||||
.project()
|
|
||||||
.read(cx)
|
|
||||||
.worktree_root_names(cx)
|
|
||||||
.collect::<Vec<&str>>()
|
|
||||||
.join("/"),
|
|
||||||
)
|
|
||||||
});
|
|
||||||
let buffer = editor.read(cx).buffer().read(cx).snapshot(cx);
|
|
||||||
let range = assist.codegen.read(cx).range.clone();
|
|
||||||
anyhow::Ok((user_prompt, context_request, project_name, buffer, range))
|
|
||||||
})??;
|
|
||||||
|
|
||||||
let language = buffer.language_at(range.start);
|
|
||||||
let language_name = if let Some(language) = language.as_ref() {
|
|
||||||
if Arc::ptr_eq(language, &language::PLAIN_TEXT) {
|
|
||||||
None
|
|
||||||
} else {
|
|
||||||
Some(language.name())
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
None
|
|
||||||
};
|
|
||||||
|
|
||||||
// Higher Temperature increases the randomness of model outputs.
|
|
||||||
// If Markdown or No Language is Known, increase the randomness for more creative output
|
|
||||||
// If Code, decrease temperature to get more deterministic outputs
|
|
||||||
let temperature = if let Some(language) = language_name.clone() {
|
|
||||||
if language.as_ref() == "Markdown" {
|
|
||||||
1.0
|
|
||||||
} else {
|
|
||||||
0.5
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
1.0
|
|
||||||
};
|
|
||||||
|
|
||||||
let prompt = cx
|
|
||||||
.background_executor()
|
|
||||||
.spawn(async move {
|
|
||||||
let language_name = language_name.as_deref();
|
|
||||||
let start = buffer.point_to_buffer_offset(range.start);
|
|
||||||
let end = buffer.point_to_buffer_offset(range.end);
|
|
||||||
let (buffer, range) = if let Some((start, end)) = start.zip(end) {
|
|
||||||
let (start_buffer, start_buffer_offset) = start;
|
|
||||||
let (end_buffer, end_buffer_offset) = end;
|
|
||||||
if start_buffer.remote_id() == end_buffer.remote_id() {
|
|
||||||
(start_buffer.clone(), start_buffer_offset..end_buffer_offset)
|
|
||||||
} else {
|
|
||||||
return Err(anyhow!("invalid transformation range"));
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
return Err(anyhow!("invalid transformation range"));
|
|
||||||
};
|
|
||||||
generate_content_prompt(user_prompt, language_name, buffer, range, project_name)
|
|
||||||
})
|
|
||||||
.await?;
|
|
||||||
|
|
||||||
let mut messages = Vec::new();
|
|
||||||
if let Some(context_request) = context_request {
|
|
||||||
messages = context_request.messages;
|
|
||||||
}
|
|
||||||
|
|
||||||
messages.push(LanguageModelRequestMessage {
|
|
||||||
role: Role::User,
|
|
||||||
content: prompt,
|
|
||||||
});
|
|
||||||
|
|
||||||
Ok(LanguageModelRequest {
|
|
||||||
messages,
|
|
||||||
stop: vec!["|END|>".to_string()],
|
|
||||||
temperature,
|
|
||||||
})
|
|
||||||
})
|
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn stop_assist(&mut self, assist_id: InlineAssistId, cx: &mut WindowContext) {
|
pub fn stop_assist(&mut self, assist_id: InlineAssistId, cx: &mut WindowContext) {
|
||||||
|
@ -1006,12 +897,11 @@ impl InlineAssistant {
|
||||||
let codegen = assist.codegen.read(cx);
|
let codegen = assist.codegen.read(cx);
|
||||||
foreground_ranges.extend(codegen.last_equal_ranges().iter().cloned());
|
foreground_ranges.extend(codegen.last_equal_ranges().iter().cloned());
|
||||||
|
|
||||||
if codegen.edit_position != codegen.range.end {
|
gutter_pending_ranges
|
||||||
gutter_pending_ranges.push(codegen.edit_position..codegen.range.end);
|
.push(codegen.edit_position.unwrap_or(assist.range.start)..assist.range.end);
|
||||||
}
|
|
||||||
|
|
||||||
if codegen.range.start != codegen.edit_position {
|
if let Some(edit_position) = codegen.edit_position {
|
||||||
gutter_transformed_ranges.push(codegen.range.start..codegen.edit_position);
|
gutter_transformed_ranges.push(assist.range.start..edit_position);
|
||||||
}
|
}
|
||||||
|
|
||||||
if assist.decorations.is_some() {
|
if assist.decorations.is_some() {
|
||||||
|
@ -1268,6 +1158,12 @@ fn build_assist_editor_renderer(editor: &View<PromptEditor>) -> RenderBlock {
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
|
||||||
|
pub enum InitialInsertion {
|
||||||
|
NewlineBefore,
|
||||||
|
NewlineAfter,
|
||||||
|
}
|
||||||
|
|
||||||
#[derive(Copy, Clone, Default, Debug, PartialEq, Eq, Hash)]
|
#[derive(Copy, Clone, Default, Debug, PartialEq, Eq, Hash)]
|
||||||
pub struct InlineAssistId(usize);
|
pub struct InlineAssistId(usize);
|
||||||
|
|
||||||
|
@ -1629,24 +1525,20 @@ impl PromptEditor {
|
||||||
let assist_id = self.id;
|
let assist_id = self.id;
|
||||||
self.pending_token_count = cx.spawn(|this, mut cx| async move {
|
self.pending_token_count = cx.spawn(|this, mut cx| async move {
|
||||||
cx.background_executor().timer(Duration::from_secs(1)).await;
|
cx.background_executor().timer(Duration::from_secs(1)).await;
|
||||||
let request = cx
|
let token_count = cx
|
||||||
.update_global(|inline_assistant: &mut InlineAssistant, cx| {
|
.update_global(|inline_assistant: &mut InlineAssistant, cx| {
|
||||||
inline_assistant.request_for_inline_assist(assist_id, cx)
|
let assist = inline_assistant
|
||||||
})?
|
.assists
|
||||||
|
.get(&assist_id)
|
||||||
|
.context("assist not found")?;
|
||||||
|
anyhow::Ok(assist.count_tokens(cx))
|
||||||
|
})??
|
||||||
.await?;
|
.await?;
|
||||||
|
|
||||||
if let Some(token_count) = cx.update(|cx| {
|
|
||||||
LanguageModelCompletionProvider::read_global(cx).count_tokens(request, cx)
|
|
||||||
})? {
|
|
||||||
let token_count = token_count.await?;
|
|
||||||
|
|
||||||
this.update(&mut cx, |this, cx| {
|
this.update(&mut cx, |this, cx| {
|
||||||
this.token_count = Some(token_count);
|
this.token_count = Some(token_count);
|
||||||
cx.notify();
|
cx.notify();
|
||||||
})
|
})
|
||||||
} else {
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1855,6 +1747,7 @@ impl PromptEditor {
|
||||||
|
|
||||||
struct InlineAssist {
|
struct InlineAssist {
|
||||||
group_id: InlineAssistGroupId,
|
group_id: InlineAssistGroupId,
|
||||||
|
range: Range<Anchor>,
|
||||||
editor: WeakView<Editor>,
|
editor: WeakView<Editor>,
|
||||||
decorations: Option<InlineAssistDecorations>,
|
decorations: Option<InlineAssistDecorations>,
|
||||||
codegen: Model<Codegen>,
|
codegen: Model<Codegen>,
|
||||||
|
@ -1873,6 +1766,7 @@ impl InlineAssist {
|
||||||
prompt_editor: &View<PromptEditor>,
|
prompt_editor: &View<PromptEditor>,
|
||||||
prompt_block_id: CustomBlockId,
|
prompt_block_id: CustomBlockId,
|
||||||
end_block_id: CustomBlockId,
|
end_block_id: CustomBlockId,
|
||||||
|
range: Range<Anchor>,
|
||||||
codegen: Model<Codegen>,
|
codegen: Model<Codegen>,
|
||||||
workspace: Option<WeakView<Workspace>>,
|
workspace: Option<WeakView<Workspace>>,
|
||||||
cx: &mut WindowContext,
|
cx: &mut WindowContext,
|
||||||
|
@ -1888,6 +1782,7 @@ impl InlineAssist {
|
||||||
removed_line_block_ids: HashSet::default(),
|
removed_line_block_ids: HashSet::default(),
|
||||||
end_block_id,
|
end_block_id,
|
||||||
}),
|
}),
|
||||||
|
range,
|
||||||
codegen: codegen.clone(),
|
codegen: codegen.clone(),
|
||||||
workspace: workspace.clone(),
|
workspace: workspace.clone(),
|
||||||
_subscriptions: vec![
|
_subscriptions: vec![
|
||||||
|
@ -1963,6 +1858,41 @@ impl InlineAssist {
|
||||||
],
|
],
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn user_prompt(&self, cx: &AppContext) -> Option<String> {
|
||||||
|
let decorations = self.decorations.as_ref()?;
|
||||||
|
Some(decorations.prompt_editor.read(cx).prompt(cx))
|
||||||
|
}
|
||||||
|
|
||||||
|
fn assistant_panel_context(&self, cx: &WindowContext) -> Option<LanguageModelRequest> {
|
||||||
|
if self.include_context {
|
||||||
|
let workspace = self.workspace.as_ref()?;
|
||||||
|
let workspace = workspace.upgrade()?.read(cx);
|
||||||
|
let assistant_panel = workspace.panel::<AssistantPanel>(cx)?;
|
||||||
|
Some(
|
||||||
|
assistant_panel
|
||||||
|
.read(cx)
|
||||||
|
.active_context(cx)?
|
||||||
|
.read(cx)
|
||||||
|
.to_completion_request(cx),
|
||||||
|
)
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn count_tokens(&self, cx: &WindowContext) -> BoxFuture<'static, Result<usize>> {
|
||||||
|
let Some(user_prompt) = self.user_prompt(cx) else {
|
||||||
|
return future::ready(Err(anyhow!("no user prompt"))).boxed();
|
||||||
|
};
|
||||||
|
let assistant_panel_context = self.assistant_panel_context(cx);
|
||||||
|
self.codegen.read(cx).count_tokens(
|
||||||
|
self.range.clone(),
|
||||||
|
user_prompt,
|
||||||
|
assistant_panel_context,
|
||||||
|
cx,
|
||||||
|
)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
struct InlineAssistDecorations {
|
struct InlineAssistDecorations {
|
||||||
|
@ -1982,16 +1912,15 @@ pub struct Codegen {
|
||||||
buffer: Model<MultiBuffer>,
|
buffer: Model<MultiBuffer>,
|
||||||
old_buffer: Model<Buffer>,
|
old_buffer: Model<Buffer>,
|
||||||
snapshot: MultiBufferSnapshot,
|
snapshot: MultiBufferSnapshot,
|
||||||
range: Range<Anchor>,
|
edit_position: Option<Anchor>,
|
||||||
edit_position: Anchor,
|
|
||||||
last_equal_ranges: Vec<Range<Anchor>>,
|
last_equal_ranges: Vec<Range<Anchor>>,
|
||||||
prepend_transaction_id: Option<TransactionId>,
|
transaction_id: Option<TransactionId>,
|
||||||
generation_transaction_id: Option<TransactionId>,
|
|
||||||
status: CodegenStatus,
|
status: CodegenStatus,
|
||||||
generation: Task<()>,
|
generation: Task<()>,
|
||||||
diff: Diff,
|
diff: Diff,
|
||||||
telemetry: Option<Arc<Telemetry>>,
|
telemetry: Option<Arc<Telemetry>>,
|
||||||
_subscription: gpui::Subscription,
|
_subscription: gpui::Subscription,
|
||||||
|
initial_insertion: Option<InitialInsertion>,
|
||||||
}
|
}
|
||||||
|
|
||||||
enum CodegenStatus {
|
enum CodegenStatus {
|
||||||
|
@ -2015,7 +1944,7 @@ impl Codegen {
|
||||||
pub fn new(
|
pub fn new(
|
||||||
buffer: Model<MultiBuffer>,
|
buffer: Model<MultiBuffer>,
|
||||||
range: Range<Anchor>,
|
range: Range<Anchor>,
|
||||||
prepend_transaction_id: Option<TransactionId>,
|
initial_insertion: Option<InitialInsertion>,
|
||||||
telemetry: Option<Arc<Telemetry>>,
|
telemetry: Option<Arc<Telemetry>>,
|
||||||
cx: &mut ModelContext<Self>,
|
cx: &mut ModelContext<Self>,
|
||||||
) -> Self {
|
) -> Self {
|
||||||
|
@ -2044,17 +1973,16 @@ impl Codegen {
|
||||||
Self {
|
Self {
|
||||||
buffer: buffer.clone(),
|
buffer: buffer.clone(),
|
||||||
old_buffer,
|
old_buffer,
|
||||||
edit_position: range.start,
|
edit_position: None,
|
||||||
range,
|
|
||||||
snapshot,
|
snapshot,
|
||||||
last_equal_ranges: Default::default(),
|
last_equal_ranges: Default::default(),
|
||||||
prepend_transaction_id,
|
transaction_id: None,
|
||||||
generation_transaction_id: None,
|
|
||||||
status: CodegenStatus::Idle,
|
status: CodegenStatus::Idle,
|
||||||
generation: Task::ready(()),
|
generation: Task::ready(()),
|
||||||
diff: Diff::default(),
|
diff: Diff::default(),
|
||||||
telemetry,
|
telemetry,
|
||||||
_subscription: cx.subscribe(&buffer, Self::handle_buffer_event),
|
_subscription: cx.subscribe(&buffer, Self::handle_buffer_event),
|
||||||
|
initial_insertion,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -2065,13 +1993,8 @@ impl Codegen {
|
||||||
cx: &mut ModelContext<Self>,
|
cx: &mut ModelContext<Self>,
|
||||||
) {
|
) {
|
||||||
if let multi_buffer::Event::TransactionUndone { transaction_id } = event {
|
if let multi_buffer::Event::TransactionUndone { transaction_id } = event {
|
||||||
if self.generation_transaction_id == Some(*transaction_id) {
|
if self.transaction_id == Some(*transaction_id) {
|
||||||
self.generation_transaction_id = None;
|
self.transaction_id = None;
|
||||||
self.generation = Task::ready(());
|
|
||||||
cx.emit(CodegenEvent::Undone);
|
|
||||||
} else if self.prepend_transaction_id == Some(*transaction_id) {
|
|
||||||
self.prepend_transaction_id = None;
|
|
||||||
self.generation_transaction_id = None;
|
|
||||||
self.generation = Task::ready(());
|
self.generation = Task::ready(());
|
||||||
cx.emit(CodegenEvent::Undone);
|
cx.emit(CodegenEvent::Undone);
|
||||||
}
|
}
|
||||||
|
@ -2082,19 +2005,152 @@ impl Codegen {
|
||||||
&self.last_equal_ranges
|
&self.last_equal_ranges
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn count_tokens(
|
||||||
|
&self,
|
||||||
|
edit_range: Range<Anchor>,
|
||||||
|
user_prompt: String,
|
||||||
|
assistant_panel_context: Option<LanguageModelRequest>,
|
||||||
|
cx: &AppContext,
|
||||||
|
) -> BoxFuture<'static, Result<usize>> {
|
||||||
|
let request = self.build_request(user_prompt, assistant_panel_context, edit_range, cx);
|
||||||
|
LanguageModelCompletionProvider::read_global(cx).count_tokens(request, cx)
|
||||||
|
}
|
||||||
|
|
||||||
pub fn start(
|
pub fn start(
|
||||||
&mut self,
|
&mut self,
|
||||||
telemetry_id: String,
|
mut edit_range: Range<Anchor>,
|
||||||
|
user_prompt: String,
|
||||||
|
assistant_panel_context: Option<LanguageModelRequest>,
|
||||||
|
cx: &mut ModelContext<Self>,
|
||||||
|
) -> Result<()> {
|
||||||
|
self.undo(cx);
|
||||||
|
|
||||||
|
// Handle initial insertion
|
||||||
|
self.transaction_id = if let Some(initial_insertion) = self.initial_insertion {
|
||||||
|
self.buffer.update(cx, |buffer, cx| {
|
||||||
|
buffer.start_transaction(cx);
|
||||||
|
let offset = edit_range.start.to_offset(&self.snapshot);
|
||||||
|
let edit_position;
|
||||||
|
match initial_insertion {
|
||||||
|
InitialInsertion::NewlineBefore => {
|
||||||
|
buffer.edit([(offset..offset, "\n\n")], None, cx);
|
||||||
|
self.snapshot = buffer.snapshot(cx);
|
||||||
|
edit_position = self.snapshot.anchor_after(offset + 1);
|
||||||
|
}
|
||||||
|
InitialInsertion::NewlineAfter => {
|
||||||
|
buffer.edit([(offset..offset, "\n")], None, cx);
|
||||||
|
self.snapshot = buffer.snapshot(cx);
|
||||||
|
edit_position = self.snapshot.anchor_after(offset);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
self.edit_position = Some(edit_position);
|
||||||
|
edit_range = edit_position.bias_left(&self.snapshot)..edit_position;
|
||||||
|
buffer.end_transaction(cx)
|
||||||
|
})
|
||||||
|
} else {
|
||||||
|
self.edit_position = Some(edit_range.start.bias_right(&self.snapshot));
|
||||||
|
None
|
||||||
|
};
|
||||||
|
|
||||||
|
let model_telemetry_id = LanguageModelCompletionProvider::read_global(cx)
|
||||||
|
.active_model_telemetry_id()
|
||||||
|
.context("no active model")?;
|
||||||
|
|
||||||
|
let chunks: LocalBoxFuture<Result<BoxStream<Result<String>>>> = if user_prompt
|
||||||
|
.trim()
|
||||||
|
.to_lowercase()
|
||||||
|
== "delete"
|
||||||
|
{
|
||||||
|
async { Ok(stream::empty().boxed()) }.boxed_local()
|
||||||
|
} else {
|
||||||
|
let request =
|
||||||
|
self.build_request(user_prompt, assistant_panel_context, edit_range.clone(), cx);
|
||||||
|
let chunks =
|
||||||
|
LanguageModelCompletionProvider::read_global(cx).stream_completion(request, cx);
|
||||||
|
async move { Ok(chunks.await?.boxed()) }.boxed_local()
|
||||||
|
};
|
||||||
|
self.handle_stream(model_telemetry_id, edit_range, chunks, cx);
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
fn build_request(
|
||||||
|
&self,
|
||||||
|
user_prompt: String,
|
||||||
|
assistant_panel_context: Option<LanguageModelRequest>,
|
||||||
|
edit_range: Range<Anchor>,
|
||||||
|
cx: &AppContext,
|
||||||
|
) -> LanguageModelRequest {
|
||||||
|
let buffer = self.buffer.read(cx).snapshot(cx);
|
||||||
|
let language = buffer.language_at(edit_range.start);
|
||||||
|
let language_name = if let Some(language) = language.as_ref() {
|
||||||
|
if Arc::ptr_eq(language, &language::PLAIN_TEXT) {
|
||||||
|
None
|
||||||
|
} else {
|
||||||
|
Some(language.name())
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
};
|
||||||
|
|
||||||
|
// Higher Temperature increases the randomness of model outputs.
|
||||||
|
// If Markdown or No Language is Known, increase the randomness for more creative output
|
||||||
|
// If Code, decrease temperature to get more deterministic outputs
|
||||||
|
let temperature = if let Some(language) = language_name.clone() {
|
||||||
|
if language.as_ref() == "Markdown" {
|
||||||
|
1.0
|
||||||
|
} else {
|
||||||
|
0.5
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
1.0
|
||||||
|
};
|
||||||
|
|
||||||
|
let language_name = language_name.as_deref();
|
||||||
|
let start = buffer.point_to_buffer_offset(edit_range.start);
|
||||||
|
let end = buffer.point_to_buffer_offset(edit_range.end);
|
||||||
|
let (buffer, range) = if let Some((start, end)) = start.zip(end) {
|
||||||
|
let (start_buffer, start_buffer_offset) = start;
|
||||||
|
let (end_buffer, end_buffer_offset) = end;
|
||||||
|
if start_buffer.remote_id() == end_buffer.remote_id() {
|
||||||
|
(start_buffer.clone(), start_buffer_offset..end_buffer_offset)
|
||||||
|
} else {
|
||||||
|
panic!("invalid transformation range");
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
panic!("invalid transformation range");
|
||||||
|
};
|
||||||
|
let prompt = generate_content_prompt(user_prompt, language_name, buffer, range);
|
||||||
|
|
||||||
|
let mut messages = Vec::new();
|
||||||
|
if let Some(context_request) = assistant_panel_context {
|
||||||
|
messages = context_request.messages;
|
||||||
|
}
|
||||||
|
|
||||||
|
messages.push(LanguageModelRequestMessage {
|
||||||
|
role: Role::User,
|
||||||
|
content: prompt,
|
||||||
|
});
|
||||||
|
|
||||||
|
LanguageModelRequest {
|
||||||
|
messages,
|
||||||
|
stop: vec!["|END|>".to_string()],
|
||||||
|
temperature,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn handle_stream(
|
||||||
|
&mut self,
|
||||||
|
model_telemetry_id: String,
|
||||||
|
edit_range: Range<Anchor>,
|
||||||
stream: impl 'static + Future<Output = Result<BoxStream<'static, Result<String>>>>,
|
stream: impl 'static + Future<Output = Result<BoxStream<'static, Result<String>>>>,
|
||||||
cx: &mut ModelContext<Self>,
|
cx: &mut ModelContext<Self>,
|
||||||
) {
|
) {
|
||||||
let range = self.range.clone();
|
|
||||||
let snapshot = self.snapshot.clone();
|
let snapshot = self.snapshot.clone();
|
||||||
let selected_text = snapshot
|
let selected_text = snapshot
|
||||||
.text_for_range(range.start..range.end)
|
.text_for_range(edit_range.start..edit_range.end)
|
||||||
.collect::<Rope>();
|
.collect::<Rope>();
|
||||||
|
|
||||||
let selection_start = range.start.to_point(&snapshot);
|
let selection_start = edit_range.start.to_point(&snapshot);
|
||||||
|
|
||||||
// Start with the indentation of the first line in the selection
|
// Start with the indentation of the first line in the selection
|
||||||
let mut suggested_line_indent = snapshot
|
let mut suggested_line_indent = snapshot
|
||||||
|
@ -2105,7 +2161,7 @@ impl Codegen {
|
||||||
|
|
||||||
// If the first line in the selection does not have indentation, check the following lines
|
// If the first line in the selection does not have indentation, check the following lines
|
||||||
if suggested_line_indent.len == 0 && suggested_line_indent.kind == IndentKind::Space {
|
if suggested_line_indent.len == 0 && suggested_line_indent.kind == IndentKind::Space {
|
||||||
for row in selection_start.row..=range.end.to_point(&snapshot).row {
|
for row in selection_start.row..=edit_range.end.to_point(&snapshot).row {
|
||||||
let line_indent = snapshot.indent_size_for_line(MultiBufferRow(row));
|
let line_indent = snapshot.indent_size_for_line(MultiBufferRow(row));
|
||||||
// Prefer tabs if a line in the selection uses tabs as indentation
|
// Prefer tabs if a line in the selection uses tabs as indentation
|
||||||
if line_indent.kind == IndentKind::Tab {
|
if line_indent.kind == IndentKind::Tab {
|
||||||
|
@ -2116,19 +2172,13 @@ impl Codegen {
|
||||||
}
|
}
|
||||||
|
|
||||||
let telemetry = self.telemetry.clone();
|
let telemetry = self.telemetry.clone();
|
||||||
self.edit_position = range.start;
|
|
||||||
self.diff = Diff::default();
|
self.diff = Diff::default();
|
||||||
self.status = CodegenStatus::Pending;
|
self.status = CodegenStatus::Pending;
|
||||||
if let Some(transaction_id) = self.generation_transaction_id.take() {
|
let mut edit_start = edit_range.start.to_offset(&snapshot);
|
||||||
self.buffer
|
|
||||||
.update(cx, |buffer, cx| buffer.undo_transaction(transaction_id, cx));
|
|
||||||
}
|
|
||||||
self.generation = cx.spawn(|this, mut cx| {
|
self.generation = cx.spawn(|this, mut cx| {
|
||||||
async move {
|
async move {
|
||||||
let chunks = stream.await;
|
let chunks = stream.await;
|
||||||
let generate = async {
|
let generate = async {
|
||||||
let mut edit_start = range.start.to_offset(&snapshot);
|
|
||||||
|
|
||||||
let (mut hunks_tx, mut hunks_rx) = mpsc::channel(1);
|
let (mut hunks_tx, mut hunks_rx) = mpsc::channel(1);
|
||||||
let diff: Task<anyhow::Result<()>> =
|
let diff: Task<anyhow::Result<()>> =
|
||||||
cx.background_executor().spawn(async move {
|
cx.background_executor().spawn(async move {
|
||||||
|
@ -2218,7 +2268,7 @@ impl Codegen {
|
||||||
telemetry.report_assistant_event(
|
telemetry.report_assistant_event(
|
||||||
None,
|
None,
|
||||||
telemetry_events::AssistantKind::Inline,
|
telemetry_events::AssistantKind::Inline,
|
||||||
telemetry_id,
|
model_telemetry_id,
|
||||||
response_latency,
|
response_latency,
|
||||||
error_message,
|
error_message,
|
||||||
);
|
);
|
||||||
|
@ -2262,13 +2312,13 @@ impl Codegen {
|
||||||
None,
|
None,
|
||||||
cx,
|
cx,
|
||||||
);
|
);
|
||||||
this.edit_position = snapshot.anchor_after(edit_start);
|
this.edit_position = Some(snapshot.anchor_after(edit_start));
|
||||||
|
|
||||||
buffer.end_transaction(cx)
|
buffer.end_transaction(cx)
|
||||||
});
|
});
|
||||||
|
|
||||||
if let Some(transaction) = transaction {
|
if let Some(transaction) = transaction {
|
||||||
if let Some(first_transaction) = this.generation_transaction_id {
|
if let Some(first_transaction) = this.transaction_id {
|
||||||
// Group all assistant edits into the first transaction.
|
// Group all assistant edits into the first transaction.
|
||||||
this.buffer.update(cx, |buffer, cx| {
|
this.buffer.update(cx, |buffer, cx| {
|
||||||
buffer.merge_transactions(
|
buffer.merge_transactions(
|
||||||
|
@ -2278,14 +2328,14 @@ impl Codegen {
|
||||||
)
|
)
|
||||||
});
|
});
|
||||||
} else {
|
} else {
|
||||||
this.generation_transaction_id = Some(transaction);
|
this.transaction_id = Some(transaction);
|
||||||
this.buffer.update(cx, |buffer, cx| {
|
this.buffer.update(cx, |buffer, cx| {
|
||||||
buffer.finalize_last_transaction(cx)
|
buffer.finalize_last_transaction(cx)
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
this.update_diff(cx);
|
this.update_diff(edit_range.clone(), cx);
|
||||||
cx.notify();
|
cx.notify();
|
||||||
})?;
|
})?;
|
||||||
}
|
}
|
||||||
|
@ -2321,27 +2371,22 @@ impl Codegen {
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn undo(&mut self, cx: &mut ModelContext<Self>) {
|
pub fn undo(&mut self, cx: &mut ModelContext<Self>) {
|
||||||
if let Some(transaction_id) = self.prepend_transaction_id.take() {
|
if let Some(transaction_id) = self.transaction_id.take() {
|
||||||
self.buffer
|
|
||||||
.update(cx, |buffer, cx| buffer.undo_transaction(transaction_id, cx));
|
|
||||||
}
|
|
||||||
|
|
||||||
if let Some(transaction_id) = self.generation_transaction_id.take() {
|
|
||||||
self.buffer
|
self.buffer
|
||||||
.update(cx, |buffer, cx| buffer.undo_transaction(transaction_id, cx));
|
.update(cx, |buffer, cx| buffer.undo_transaction(transaction_id, cx));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn update_diff(&mut self, cx: &mut ModelContext<Self>) {
|
fn update_diff(&mut self, edit_range: Range<Anchor>, cx: &mut ModelContext<Self>) {
|
||||||
if self.diff.task.is_some() {
|
if self.diff.task.is_some() {
|
||||||
self.diff.should_update = true;
|
self.diff.should_update = true;
|
||||||
} else {
|
} else {
|
||||||
self.diff.should_update = false;
|
self.diff.should_update = false;
|
||||||
|
|
||||||
let old_snapshot = self.snapshot.clone();
|
let old_snapshot = self.snapshot.clone();
|
||||||
let old_range = self.range.to_point(&old_snapshot);
|
let old_range = edit_range.to_point(&old_snapshot);
|
||||||
let new_snapshot = self.buffer.read(cx).snapshot(cx);
|
let new_snapshot = self.buffer.read(cx).snapshot(cx);
|
||||||
let new_range = self.range.to_point(&new_snapshot);
|
let new_range = edit_range.to_point(&new_snapshot);
|
||||||
|
|
||||||
self.diff.task = Some(cx.spawn(|this, mut cx| async move {
|
self.diff.task = Some(cx.spawn(|this, mut cx| async move {
|
||||||
let (deleted_row_ranges, inserted_row_ranges) = cx
|
let (deleted_row_ranges, inserted_row_ranges) = cx
|
||||||
|
@ -2422,7 +2467,7 @@ impl Codegen {
|
||||||
this.diff.inserted_row_ranges = inserted_row_ranges;
|
this.diff.inserted_row_ranges = inserted_row_ranges;
|
||||||
this.diff.task = None;
|
this.diff.task = None;
|
||||||
if this.diff.should_update {
|
if this.diff.should_update {
|
||||||
this.update_diff(cx);
|
this.update_diff(edit_range, cx);
|
||||||
}
|
}
|
||||||
cx.notify();
|
cx.notify();
|
||||||
})
|
})
|
||||||
|
@ -2629,12 +2674,14 @@ mod tests {
|
||||||
let snapshot = buffer.snapshot(cx);
|
let snapshot = buffer.snapshot(cx);
|
||||||
snapshot.anchor_before(Point::new(1, 0))..snapshot.anchor_after(Point::new(4, 5))
|
snapshot.anchor_before(Point::new(1, 0))..snapshot.anchor_after(Point::new(4, 5))
|
||||||
});
|
});
|
||||||
let codegen = cx.new_model(|cx| Codegen::new(buffer.clone(), range, None, None, cx));
|
let codegen =
|
||||||
|
cx.new_model(|cx| Codegen::new(buffer.clone(), range.clone(), None, None, cx));
|
||||||
|
|
||||||
let (chunks_tx, chunks_rx) = mpsc::unbounded();
|
let (chunks_tx, chunks_rx) = mpsc::unbounded();
|
||||||
codegen.update(cx, |codegen, cx| {
|
codegen.update(cx, |codegen, cx| {
|
||||||
codegen.start(
|
codegen.handle_stream(
|
||||||
String::new(),
|
String::new(),
|
||||||
|
range,
|
||||||
future::ready(Ok(chunks_rx.map(|chunk| Ok(chunk)).boxed())),
|
future::ready(Ok(chunks_rx.map(|chunk| Ok(chunk)).boxed())),
|
||||||
cx,
|
cx,
|
||||||
)
|
)
|
||||||
|
@ -2690,12 +2737,14 @@ mod tests {
|
||||||
let snapshot = buffer.snapshot(cx);
|
let snapshot = buffer.snapshot(cx);
|
||||||
snapshot.anchor_before(Point::new(1, 6))..snapshot.anchor_after(Point::new(1, 6))
|
snapshot.anchor_before(Point::new(1, 6))..snapshot.anchor_after(Point::new(1, 6))
|
||||||
});
|
});
|
||||||
let codegen = cx.new_model(|cx| Codegen::new(buffer.clone(), range, None, None, cx));
|
let codegen =
|
||||||
|
cx.new_model(|cx| Codegen::new(buffer.clone(), range.clone(), None, None, cx));
|
||||||
|
|
||||||
let (chunks_tx, chunks_rx) = mpsc::unbounded();
|
let (chunks_tx, chunks_rx) = mpsc::unbounded();
|
||||||
codegen.update(cx, |codegen, cx| {
|
codegen.update(cx, |codegen, cx| {
|
||||||
codegen.start(
|
codegen.handle_stream(
|
||||||
String::new(),
|
String::new(),
|
||||||
|
range.clone(),
|
||||||
future::ready(Ok(chunks_rx.map(|chunk| Ok(chunk)).boxed())),
|
future::ready(Ok(chunks_rx.map(|chunk| Ok(chunk)).boxed())),
|
||||||
cx,
|
cx,
|
||||||
)
|
)
|
||||||
|
@ -2755,12 +2804,14 @@ mod tests {
|
||||||
let snapshot = buffer.snapshot(cx);
|
let snapshot = buffer.snapshot(cx);
|
||||||
snapshot.anchor_before(Point::new(1, 2))..snapshot.anchor_after(Point::new(1, 2))
|
snapshot.anchor_before(Point::new(1, 2))..snapshot.anchor_after(Point::new(1, 2))
|
||||||
});
|
});
|
||||||
let codegen = cx.new_model(|cx| Codegen::new(buffer.clone(), range, None, None, cx));
|
let codegen =
|
||||||
|
cx.new_model(|cx| Codegen::new(buffer.clone(), range.clone(), None, None, cx));
|
||||||
|
|
||||||
let (chunks_tx, chunks_rx) = mpsc::unbounded();
|
let (chunks_tx, chunks_rx) = mpsc::unbounded();
|
||||||
codegen.update(cx, |codegen, cx| {
|
codegen.update(cx, |codegen, cx| {
|
||||||
codegen.start(
|
codegen.handle_stream(
|
||||||
String::new(),
|
String::new(),
|
||||||
|
range.clone(),
|
||||||
future::ready(Ok(chunks_rx.map(|chunk| Ok(chunk)).boxed())),
|
future::ready(Ok(chunks_rx.map(|chunk| Ok(chunk)).boxed())),
|
||||||
cx,
|
cx,
|
||||||
)
|
)
|
||||||
|
@ -2819,12 +2870,14 @@ mod tests {
|
||||||
let snapshot = buffer.snapshot(cx);
|
let snapshot = buffer.snapshot(cx);
|
||||||
snapshot.anchor_before(Point::new(0, 0))..snapshot.anchor_after(Point::new(4, 2))
|
snapshot.anchor_before(Point::new(0, 0))..snapshot.anchor_after(Point::new(4, 2))
|
||||||
});
|
});
|
||||||
let codegen = cx.new_model(|cx| Codegen::new(buffer.clone(), range, None, None, cx));
|
let codegen =
|
||||||
|
cx.new_model(|cx| Codegen::new(buffer.clone(), range.clone(), None, None, cx));
|
||||||
|
|
||||||
let (chunks_tx, chunks_rx) = mpsc::unbounded();
|
let (chunks_tx, chunks_rx) = mpsc::unbounded();
|
||||||
codegen.update(cx, |codegen, cx| {
|
codegen.update(cx, |codegen, cx| {
|
||||||
codegen.start(
|
codegen.handle_stream(
|
||||||
String::new(),
|
String::new(),
|
||||||
|
range.clone(),
|
||||||
future::ready(Ok(chunks_rx.map(|chunk| Ok(chunk)).boxed())),
|
future::ready(Ok(chunks_rx.map(|chunk| Ok(chunk)).boxed())),
|
||||||
cx,
|
cx,
|
||||||
)
|
)
|
||||||
|
|
|
@ -734,7 +734,8 @@ impl PromptLibrary {
|
||||||
const DEBOUNCE_TIMEOUT: Duration = Duration::from_secs(1);
|
const DEBOUNCE_TIMEOUT: Duration = Duration::from_secs(1);
|
||||||
|
|
||||||
cx.background_executor().timer(DEBOUNCE_TIMEOUT).await;
|
cx.background_executor().timer(DEBOUNCE_TIMEOUT).await;
|
||||||
if let Some(token_count) = cx.update(|cx| {
|
let token_count = cx
|
||||||
|
.update(|cx| {
|
||||||
LanguageModelCompletionProvider::read_global(cx).count_tokens(
|
LanguageModelCompletionProvider::read_global(cx).count_tokens(
|
||||||
LanguageModelRequest {
|
LanguageModelRequest {
|
||||||
messages: vec![LanguageModelRequestMessage {
|
messages: vec![LanguageModelRequestMessage {
|
||||||
|
@ -746,17 +747,14 @@ impl PromptLibrary {
|
||||||
},
|
},
|
||||||
cx,
|
cx,
|
||||||
)
|
)
|
||||||
})? {
|
})?
|
||||||
let token_count = token_count.await?;
|
.await?;
|
||||||
|
|
||||||
this.update(&mut cx, |this, cx| {
|
this.update(&mut cx, |this, cx| {
|
||||||
let prompt_editor = this.prompt_editors.get_mut(&prompt_id).unwrap();
|
let prompt_editor = this.prompt_editors.get_mut(&prompt_id).unwrap();
|
||||||
prompt_editor.token_count = Some(token_count);
|
prompt_editor.token_count = Some(token_count);
|
||||||
cx.notify();
|
cx.notify();
|
||||||
})
|
})
|
||||||
} else {
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
.log_err()
|
.log_err()
|
||||||
});
|
});
|
||||||
|
|
|
@ -6,8 +6,7 @@ pub fn generate_content_prompt(
|
||||||
language_name: Option<&str>,
|
language_name: Option<&str>,
|
||||||
buffer: BufferSnapshot,
|
buffer: BufferSnapshot,
|
||||||
range: Range<usize>,
|
range: Range<usize>,
|
||||||
_project_name: Option<String>,
|
) -> String {
|
||||||
) -> anyhow::Result<String> {
|
|
||||||
let mut prompt = String::new();
|
let mut prompt = String::new();
|
||||||
|
|
||||||
let content_type = match language_name {
|
let content_type = match language_name {
|
||||||
|
@ -15,14 +14,16 @@ pub fn generate_content_prompt(
|
||||||
writeln!(
|
writeln!(
|
||||||
prompt,
|
prompt,
|
||||||
"Here's a file of text that I'm going to ask you to make an edit to."
|
"Here's a file of text that I'm going to ask you to make an edit to."
|
||||||
)?;
|
)
|
||||||
|
.unwrap();
|
||||||
"text"
|
"text"
|
||||||
}
|
}
|
||||||
Some(language_name) => {
|
Some(language_name) => {
|
||||||
writeln!(
|
writeln!(
|
||||||
prompt,
|
prompt,
|
||||||
"Here's a file of {language_name} that I'm going to ask you to make an edit to."
|
"Here's a file of {language_name} that I'm going to ask you to make an edit to."
|
||||||
)?;
|
)
|
||||||
|
.unwrap();
|
||||||
"code"
|
"code"
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
@ -70,7 +71,7 @@ pub fn generate_content_prompt(
|
||||||
write!(prompt, "</document>\n\n").unwrap();
|
write!(prompt, "</document>\n\n").unwrap();
|
||||||
|
|
||||||
if is_truncated {
|
if is_truncated {
|
||||||
writeln!(prompt, "The context around the relevant section has been truncated (possibly in the middle of a line) for brevity.\n")?;
|
writeln!(prompt, "The context around the relevant section has been truncated (possibly in the middle of a line) for brevity.\n").unwrap();
|
||||||
}
|
}
|
||||||
|
|
||||||
if range.is_empty() {
|
if range.is_empty() {
|
||||||
|
@ -107,7 +108,7 @@ pub fn generate_content_prompt(
|
||||||
prompt.push_str("\n\nImmediately start with the following format with no remarks:\n\n```\n{{REWRITTEN_CODE}}\n```");
|
prompt.push_str("\n\nImmediately start with the following format with no remarks:\n\n```\n{{REWRITTEN_CODE}}\n```");
|
||||||
}
|
}
|
||||||
|
|
||||||
Ok(prompt)
|
prompt
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn generate_terminal_assistant_prompt(
|
pub fn generate_terminal_assistant_prompt(
|
||||||
|
|
|
@ -707,18 +707,15 @@ impl PromptEditor {
|
||||||
inline_assistant.request_for_inline_assist(assist_id, cx)
|
inline_assistant.request_for_inline_assist(assist_id, cx)
|
||||||
})??;
|
})??;
|
||||||
|
|
||||||
if let Some(token_count) = cx.update(|cx| {
|
let token_count = cx
|
||||||
|
.update(|cx| {
|
||||||
LanguageModelCompletionProvider::read_global(cx).count_tokens(request, cx)
|
LanguageModelCompletionProvider::read_global(cx).count_tokens(request, cx)
|
||||||
})? {
|
})?
|
||||||
let token_count = token_count.await?;
|
.await?;
|
||||||
|
|
||||||
this.update(&mut cx, |this, cx| {
|
this.update(&mut cx, |this, cx| {
|
||||||
this.token_count = Some(token_count);
|
this.token_count = Some(token_count);
|
||||||
cx.notify();
|
cx.notify();
|
||||||
})
|
})
|
||||||
} else {
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -10,7 +10,7 @@ use crate::{
|
||||||
ServerId, UpdatedChannelMessage, User, UserId,
|
ServerId, UpdatedChannelMessage, User, UserId,
|
||||||
},
|
},
|
||||||
executor::Executor,
|
executor::Executor,
|
||||||
AppState, Error, RateLimit, RateLimiter, Result,
|
AppState, Config, Error, RateLimit, RateLimiter, Result,
|
||||||
};
|
};
|
||||||
use anyhow::{anyhow, bail, Context as _};
|
use anyhow::{anyhow, bail, Context as _};
|
||||||
use async_tungstenite::tungstenite::{
|
use async_tungstenite::tungstenite::{
|
||||||
|
@ -605,17 +605,39 @@ impl Server {
|
||||||
))
|
))
|
||||||
.add_message_handler(broadcast_project_message_from_host::<proto::AdvertiseContexts>)
|
.add_message_handler(broadcast_project_message_from_host::<proto::AdvertiseContexts>)
|
||||||
.add_message_handler(update_context)
|
.add_message_handler(update_context)
|
||||||
|
.add_request_handler({
|
||||||
|
let app_state = app_state.clone();
|
||||||
|
move |request, response, session| {
|
||||||
|
let app_state = app_state.clone();
|
||||||
|
async move {
|
||||||
|
complete_with_language_model(request, response, session, &app_state.config)
|
||||||
|
.await
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
.add_streaming_request_handler({
|
.add_streaming_request_handler({
|
||||||
let app_state = app_state.clone();
|
let app_state = app_state.clone();
|
||||||
move |request, response, session| {
|
move |request, response, session| {
|
||||||
complete_with_language_model(
|
let app_state = app_state.clone();
|
||||||
|
async move {
|
||||||
|
stream_complete_with_language_model(
|
||||||
request,
|
request,
|
||||||
response,
|
response,
|
||||||
session,
|
session,
|
||||||
app_state.config.openai_api_key.clone(),
|
&app_state.config,
|
||||||
app_state.config.google_ai_api_key.clone(),
|
|
||||||
app_state.config.anthropic_api_key.clone(),
|
|
||||||
)
|
)
|
||||||
|
.await
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
.add_request_handler({
|
||||||
|
let app_state = app_state.clone();
|
||||||
|
move |request, response, session| {
|
||||||
|
let app_state = app_state.clone();
|
||||||
|
async move {
|
||||||
|
count_language_model_tokens(request, response, session, &app_state.config)
|
||||||
|
.await
|
||||||
|
}
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
.add_request_handler({
|
.add_request_handler({
|
||||||
|
@ -4503,116 +4525,172 @@ impl RateLimit for CompleteWithLanguageModelRateLimit {
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn complete_with_language_model(
|
async fn complete_with_language_model(
|
||||||
query: proto::QueryLanguageModel,
|
request: proto::CompleteWithLanguageModel,
|
||||||
response: StreamingResponse<proto::QueryLanguageModel>,
|
response: Response<proto::CompleteWithLanguageModel>,
|
||||||
session: Session,
|
session: Session,
|
||||||
open_ai_api_key: Option<Arc<str>>,
|
config: &Config,
|
||||||
google_ai_api_key: Option<Arc<str>>,
|
|
||||||
anthropic_api_key: Option<Arc<str>>,
|
|
||||||
) -> Result<()> {
|
) -> Result<()> {
|
||||||
let Some(session) = session.for_user() else {
|
let Some(session) = session.for_user() else {
|
||||||
return Err(anyhow!("user not found"))?;
|
return Err(anyhow!("user not found"))?;
|
||||||
};
|
};
|
||||||
authorize_access_to_language_models(&session).await?;
|
authorize_access_to_language_models(&session).await?;
|
||||||
|
|
||||||
match proto::LanguageModelRequestKind::from_i32(query.kind) {
|
|
||||||
Some(proto::LanguageModelRequestKind::Complete) => {
|
|
||||||
session
|
session
|
||||||
.rate_limiter
|
.rate_limiter
|
||||||
.check::<CompleteWithLanguageModelRateLimit>(session.user_id())
|
.check::<CompleteWithLanguageModelRateLimit>(session.user_id())
|
||||||
.await?;
|
.await?;
|
||||||
|
|
||||||
|
let result = match proto::LanguageModelProvider::from_i32(request.provider) {
|
||||||
|
Some(proto::LanguageModelProvider::Anthropic) => {
|
||||||
|
let api_key = config
|
||||||
|
.anthropic_api_key
|
||||||
|
.as_ref()
|
||||||
|
.context("no Anthropic AI API key configured on the server")?;
|
||||||
|
anthropic::complete(
|
||||||
|
session.http_client.as_ref(),
|
||||||
|
anthropic::ANTHROPIC_API_URL,
|
||||||
|
api_key,
|
||||||
|
serde_json::from_str(&request.request)?,
|
||||||
|
)
|
||||||
|
.await?
|
||||||
}
|
}
|
||||||
Some(proto::LanguageModelRequestKind::CountTokens) => {
|
_ => return Err(anyhow!("unsupported provider"))?,
|
||||||
|
};
|
||||||
|
|
||||||
|
response.send(proto::CompleteWithLanguageModelResponse {
|
||||||
|
completion: serde_json::to_string(&result)?,
|
||||||
|
})?;
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn stream_complete_with_language_model(
|
||||||
|
request: proto::StreamCompleteWithLanguageModel,
|
||||||
|
response: StreamingResponse<proto::StreamCompleteWithLanguageModel>,
|
||||||
|
session: Session,
|
||||||
|
config: &Config,
|
||||||
|
) -> Result<()> {
|
||||||
|
let Some(session) = session.for_user() else {
|
||||||
|
return Err(anyhow!("user not found"))?;
|
||||||
|
};
|
||||||
|
authorize_access_to_language_models(&session).await?;
|
||||||
|
|
||||||
session
|
session
|
||||||
.rate_limiter
|
.rate_limiter
|
||||||
.check::<CountTokensWithLanguageModelRateLimit>(session.user_id())
|
.check::<CompleteWithLanguageModelRateLimit>(session.user_id())
|
||||||
.await?;
|
.await?;
|
||||||
}
|
|
||||||
None => Err(anyhow!("unknown request kind"))?,
|
|
||||||
}
|
|
||||||
|
|
||||||
match proto::LanguageModelProvider::from_i32(query.provider) {
|
match proto::LanguageModelProvider::from_i32(request.provider) {
|
||||||
Some(proto::LanguageModelProvider::Anthropic) => {
|
Some(proto::LanguageModelProvider::Anthropic) => {
|
||||||
let api_key =
|
let api_key = config
|
||||||
anthropic_api_key.context("no Anthropic AI API key configured on the server")?;
|
.anthropic_api_key
|
||||||
|
.as_ref()
|
||||||
|
.context("no Anthropic AI API key configured on the server")?;
|
||||||
let mut chunks = anthropic::stream_completion(
|
let mut chunks = anthropic::stream_completion(
|
||||||
session.http_client.as_ref(),
|
session.http_client.as_ref(),
|
||||||
anthropic::ANTHROPIC_API_URL,
|
anthropic::ANTHROPIC_API_URL,
|
||||||
&api_key,
|
api_key,
|
||||||
serde_json::from_str(&query.request)?,
|
serde_json::from_str(&request.request)?,
|
||||||
None,
|
None,
|
||||||
)
|
)
|
||||||
.await?;
|
.await?;
|
||||||
while let Some(chunk) = chunks.next().await {
|
while let Some(event) = chunks.next().await {
|
||||||
let chunk = chunk?;
|
let chunk = event?;
|
||||||
response.send(proto::QueryLanguageModelResponse {
|
response.send(proto::StreamCompleteWithLanguageModelResponse {
|
||||||
response: serde_json::to_string(&chunk)?,
|
event: serde_json::to_string(&chunk)?,
|
||||||
})?;
|
})?;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
Some(proto::LanguageModelProvider::OpenAi) => {
|
Some(proto::LanguageModelProvider::OpenAi) => {
|
||||||
let api_key = open_ai_api_key.context("no OpenAI API key configured on the server")?;
|
let api_key = config
|
||||||
let mut chunks = open_ai::stream_completion(
|
.openai_api_key
|
||||||
|
.as_ref()
|
||||||
|
.context("no OpenAI API key configured on the server")?;
|
||||||
|
let mut events = open_ai::stream_completion(
|
||||||
session.http_client.as_ref(),
|
session.http_client.as_ref(),
|
||||||
open_ai::OPEN_AI_API_URL,
|
open_ai::OPEN_AI_API_URL,
|
||||||
&api_key,
|
api_key,
|
||||||
serde_json::from_str(&query.request)?,
|
serde_json::from_str(&request.request)?,
|
||||||
None,
|
None,
|
||||||
)
|
)
|
||||||
.await?;
|
.await?;
|
||||||
while let Some(chunk) = chunks.next().await {
|
while let Some(event) = events.next().await {
|
||||||
let chunk = chunk?;
|
let event = event?;
|
||||||
response.send(proto::QueryLanguageModelResponse {
|
response.send(proto::StreamCompleteWithLanguageModelResponse {
|
||||||
response: serde_json::to_string(&chunk)?,
|
event: serde_json::to_string(&event)?,
|
||||||
})?;
|
})?;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
Some(proto::LanguageModelProvider::Google) => {
|
Some(proto::LanguageModelProvider::Google) => {
|
||||||
let api_key =
|
let api_key = config
|
||||||
google_ai_api_key.context("no Google AI API key configured on the server")?;
|
.google_ai_api_key
|
||||||
|
.as_ref()
|
||||||
match proto::LanguageModelRequestKind::from_i32(query.kind) {
|
.context("no Google AI API key configured on the server")?;
|
||||||
Some(proto::LanguageModelRequestKind::Complete) => {
|
let mut events = google_ai::stream_generate_content(
|
||||||
let mut chunks = google_ai::stream_generate_content(
|
|
||||||
session.http_client.as_ref(),
|
session.http_client.as_ref(),
|
||||||
google_ai::API_URL,
|
google_ai::API_URL,
|
||||||
&api_key,
|
api_key,
|
||||||
serde_json::from_str(&query.request)?,
|
serde_json::from_str(&request.request)?,
|
||||||
)
|
)
|
||||||
.await?;
|
.await?;
|
||||||
while let Some(chunk) = chunks.next().await {
|
while let Some(event) = events.next().await {
|
||||||
let chunk = chunk?;
|
let event = event?;
|
||||||
response.send(proto::QueryLanguageModelResponse {
|
response.send(proto::StreamCompleteWithLanguageModelResponse {
|
||||||
response: serde_json::to_string(&chunk)?,
|
event: serde_json::to_string(&event)?,
|
||||||
})?;
|
})?;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
Some(proto::LanguageModelRequestKind::CountTokens) => {
|
|
||||||
let tokens_response = google_ai::count_tokens(
|
|
||||||
session.http_client.as_ref(),
|
|
||||||
google_ai::API_URL,
|
|
||||||
&api_key,
|
|
||||||
serde_json::from_str(&query.request)?,
|
|
||||||
)
|
|
||||||
.await?;
|
|
||||||
response.send(proto::QueryLanguageModelResponse {
|
|
||||||
response: serde_json::to_string(&tokens_response)?,
|
|
||||||
})?;
|
|
||||||
}
|
|
||||||
None => Err(anyhow!("unknown request kind"))?,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
None => return Err(anyhow!("unknown provider"))?,
|
None => return Err(anyhow!("unknown provider"))?,
|
||||||
}
|
}
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
struct CountTokensWithLanguageModelRateLimit;
|
async fn count_language_model_tokens(
|
||||||
|
request: proto::CountLanguageModelTokens,
|
||||||
|
response: Response<proto::CountLanguageModelTokens>,
|
||||||
|
session: Session,
|
||||||
|
config: &Config,
|
||||||
|
) -> Result<()> {
|
||||||
|
let Some(session) = session.for_user() else {
|
||||||
|
return Err(anyhow!("user not found"))?;
|
||||||
|
};
|
||||||
|
authorize_access_to_language_models(&session).await?;
|
||||||
|
|
||||||
impl RateLimit for CountTokensWithLanguageModelRateLimit {
|
session
|
||||||
|
.rate_limiter
|
||||||
|
.check::<CountLanguageModelTokensRateLimit>(session.user_id())
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
let result = match proto::LanguageModelProvider::from_i32(request.provider) {
|
||||||
|
Some(proto::LanguageModelProvider::Google) => {
|
||||||
|
let api_key = config
|
||||||
|
.google_ai_api_key
|
||||||
|
.as_ref()
|
||||||
|
.context("no Google AI API key configured on the server")?;
|
||||||
|
google_ai::count_tokens(
|
||||||
|
session.http_client.as_ref(),
|
||||||
|
google_ai::API_URL,
|
||||||
|
api_key,
|
||||||
|
serde_json::from_str(&request.request)?,
|
||||||
|
)
|
||||||
|
.await?
|
||||||
|
}
|
||||||
|
_ => return Err(anyhow!("unsupported provider"))?,
|
||||||
|
};
|
||||||
|
|
||||||
|
response.send(proto::CountLanguageModelTokensResponse {
|
||||||
|
token_count: result.total_tokens as u32,
|
||||||
|
})?;
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
struct CountLanguageModelTokensRateLimit;
|
||||||
|
|
||||||
|
impl RateLimit for CountLanguageModelTokensRateLimit {
|
||||||
fn capacity() -> usize {
|
fn capacity() -> usize {
|
||||||
std::env::var("COUNT_TOKENS_WITH_LANGUAGE_MODEL_RATE_LIMIT_PER_HOUR")
|
std::env::var("COUNT_LANGUAGE_MODEL_TOKENS_RATE_LIMIT_PER_HOUR")
|
||||||
.ok()
|
.ok()
|
||||||
.and_then(|v| v.parse().ok())
|
.and_then(|v| v.parse().ok())
|
||||||
.unwrap_or(600) // Picked arbitrarily
|
.unwrap_or(600) // Picked arbitrarily
|
||||||
|
@ -4623,7 +4701,7 @@ impl RateLimit for CountTokensWithLanguageModelRateLimit {
|
||||||
}
|
}
|
||||||
|
|
||||||
fn db_name() -> &'static str {
|
fn db_name() -> &'static str {
|
||||||
"count-tokens-with-language-model"
|
"count-language-model-tokens"
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -26,7 +26,9 @@ anyhow.workspace = true
|
||||||
futures.workspace = true
|
futures.workspace = true
|
||||||
gpui.workspace = true
|
gpui.workspace = true
|
||||||
language_model.workspace = true
|
language_model.workspace = true
|
||||||
|
schemars.workspace = true
|
||||||
serde.workspace = true
|
serde.workspace = true
|
||||||
|
serde_json.workspace = true
|
||||||
settings.workspace = true
|
settings.workspace = true
|
||||||
smol.workspace = true
|
smol.workspace = true
|
||||||
ui.workspace = true
|
ui.workspace = true
|
||||||
|
|
|
@ -3,10 +3,13 @@ use futures::{future::BoxFuture, stream::BoxStream, StreamExt};
|
||||||
use gpui::{AppContext, Global, Model, ModelContext, Task};
|
use gpui::{AppContext, Global, Model, ModelContext, Task};
|
||||||
use language_model::{
|
use language_model::{
|
||||||
LanguageModel, LanguageModelProvider, LanguageModelProviderId, LanguageModelRegistry,
|
LanguageModel, LanguageModelProvider, LanguageModelProviderId, LanguageModelRegistry,
|
||||||
LanguageModelRequest,
|
LanguageModelRequest, LanguageModelTool,
|
||||||
};
|
};
|
||||||
use smol::lock::{Semaphore, SemaphoreGuardArc};
|
use smol::{
|
||||||
use std::{pin::Pin, sync::Arc, task::Poll};
|
future::FutureExt,
|
||||||
|
lock::{Semaphore, SemaphoreGuardArc},
|
||||||
|
};
|
||||||
|
use std::{future, pin::Pin, sync::Arc, task::Poll};
|
||||||
use ui::Context;
|
use ui::Context;
|
||||||
|
|
||||||
pub fn init(cx: &mut AppContext) {
|
pub fn init(cx: &mut AppContext) {
|
||||||
|
@ -143,11 +146,11 @@ impl LanguageModelCompletionProvider {
|
||||||
&self,
|
&self,
|
||||||
request: LanguageModelRequest,
|
request: LanguageModelRequest,
|
||||||
cx: &AppContext,
|
cx: &AppContext,
|
||||||
) -> Option<BoxFuture<'static, Result<usize>>> {
|
) -> BoxFuture<'static, Result<usize>> {
|
||||||
if let Some(model) = self.active_model() {
|
if let Some(model) = self.active_model() {
|
||||||
Some(model.count_tokens(request, cx))
|
model.count_tokens(request, cx)
|
||||||
} else {
|
} else {
|
||||||
None
|
future::ready(Err(anyhow!("no active model"))).boxed()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -183,6 +186,29 @@ impl LanguageModelCompletionProvider {
|
||||||
Ok(completion)
|
Ok(completion)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn use_tool<T: LanguageModelTool>(
|
||||||
|
&self,
|
||||||
|
request: LanguageModelRequest,
|
||||||
|
cx: &AppContext,
|
||||||
|
) -> Task<Result<T>> {
|
||||||
|
if let Some(language_model) = self.active_model() {
|
||||||
|
cx.spawn(|cx| async move {
|
||||||
|
let schema = schemars::schema_for!(T);
|
||||||
|
let schema_json = serde_json::to_value(&schema).unwrap();
|
||||||
|
let request =
|
||||||
|
language_model.use_tool(request, T::name(), T::description(), schema_json, &cx);
|
||||||
|
let response = request.await?;
|
||||||
|
Ok(serde_json::from_value(response)?)
|
||||||
|
})
|
||||||
|
} else {
|
||||||
|
Task::ready(Err(anyhow!("No active model set")))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn active_model_telemetry_id(&self) -> Option<String> {
|
||||||
|
self.active_model.as_ref().map(|m| m.telemetry_id())
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
|
|
|
@ -16,6 +16,8 @@ pub use model::*;
|
||||||
pub use registry::*;
|
pub use registry::*;
|
||||||
pub use request::*;
|
pub use request::*;
|
||||||
pub use role::*;
|
pub use role::*;
|
||||||
|
use schemars::JsonSchema;
|
||||||
|
use serde::de::DeserializeOwned;
|
||||||
|
|
||||||
pub fn init(client: Arc<Client>, cx: &mut AppContext) {
|
pub fn init(client: Arc<Client>, cx: &mut AppContext) {
|
||||||
settings::init(cx);
|
settings::init(cx);
|
||||||
|
@ -42,6 +44,20 @@ pub trait LanguageModel: Send + Sync {
|
||||||
request: LanguageModelRequest,
|
request: LanguageModelRequest,
|
||||||
cx: &AsyncAppContext,
|
cx: &AsyncAppContext,
|
||||||
) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>>;
|
) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>>;
|
||||||
|
|
||||||
|
fn use_tool(
|
||||||
|
&self,
|
||||||
|
request: LanguageModelRequest,
|
||||||
|
name: String,
|
||||||
|
description: String,
|
||||||
|
schema: serde_json::Value,
|
||||||
|
cx: &AsyncAppContext,
|
||||||
|
) -> BoxFuture<'static, Result<serde_json::Value>>;
|
||||||
|
}
|
||||||
|
|
||||||
|
pub trait LanguageModelTool: 'static + DeserializeOwned + JsonSchema {
|
||||||
|
fn name() -> String;
|
||||||
|
fn description() -> String;
|
||||||
}
|
}
|
||||||
|
|
||||||
pub trait LanguageModelProvider: 'static {
|
pub trait LanguageModelProvider: 'static {
|
||||||
|
|
|
@ -1,5 +1,9 @@
|
||||||
use anthropic::stream_completion;
|
use crate::{
|
||||||
use anyhow::{anyhow, Result};
|
settings::AllLanguageModelSettings, LanguageModel, LanguageModelId, LanguageModelName,
|
||||||
|
LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderName,
|
||||||
|
LanguageModelProviderState, LanguageModelRequest, Role,
|
||||||
|
};
|
||||||
|
use anyhow::{anyhow, Context as _, Result};
|
||||||
use collections::BTreeMap;
|
use collections::BTreeMap;
|
||||||
use editor::{Editor, EditorElement, EditorStyle};
|
use editor::{Editor, EditorElement, EditorStyle};
|
||||||
use futures::{future::BoxFuture, stream::BoxStream, FutureExt, StreamExt};
|
use futures::{future::BoxFuture, stream::BoxStream, FutureExt, StreamExt};
|
||||||
|
@ -15,12 +19,6 @@ use theme::ThemeSettings;
|
||||||
use ui::prelude::*;
|
use ui::prelude::*;
|
||||||
use util::ResultExt;
|
use util::ResultExt;
|
||||||
|
|
||||||
use crate::{
|
|
||||||
settings::AllLanguageModelSettings, LanguageModel, LanguageModelId, LanguageModelName,
|
|
||||||
LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderName,
|
|
||||||
LanguageModelProviderState, LanguageModelRequest, Role,
|
|
||||||
};
|
|
||||||
|
|
||||||
const PROVIDER_ID: &str = "anthropic";
|
const PROVIDER_ID: &str = "anthropic";
|
||||||
const PROVIDER_NAME: &str = "Anthropic";
|
const PROVIDER_NAME: &str = "Anthropic";
|
||||||
|
|
||||||
|
@ -188,6 +186,61 @@ pub fn count_anthropic_tokens(
|
||||||
.boxed()
|
.boxed()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
impl AnthropicModel {
|
||||||
|
fn request_completion(
|
||||||
|
&self,
|
||||||
|
request: anthropic::Request,
|
||||||
|
cx: &AsyncAppContext,
|
||||||
|
) -> BoxFuture<'static, Result<anthropic::Response>> {
|
||||||
|
let http_client = self.http_client.clone();
|
||||||
|
|
||||||
|
let Ok((api_key, api_url)) = cx.read_model(&self.state, |state, cx| {
|
||||||
|
let settings = &AllLanguageModelSettings::get_global(cx).anthropic;
|
||||||
|
(state.api_key.clone(), settings.api_url.clone())
|
||||||
|
}) else {
|
||||||
|
return futures::future::ready(Err(anyhow!("App state dropped"))).boxed();
|
||||||
|
};
|
||||||
|
|
||||||
|
async move {
|
||||||
|
let api_key = api_key.ok_or_else(|| anyhow!("missing api key"))?;
|
||||||
|
anthropic::complete(http_client.as_ref(), &api_url, &api_key, request).await
|
||||||
|
}
|
||||||
|
.boxed()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn stream_completion(
|
||||||
|
&self,
|
||||||
|
request: anthropic::Request,
|
||||||
|
cx: &AsyncAppContext,
|
||||||
|
) -> BoxFuture<'static, Result<BoxStream<'static, Result<anthropic::Event>>>> {
|
||||||
|
let http_client = self.http_client.clone();
|
||||||
|
|
||||||
|
let Ok((api_key, api_url, low_speed_timeout)) = cx.read_model(&self.state, |state, cx| {
|
||||||
|
let settings = &AllLanguageModelSettings::get_global(cx).anthropic;
|
||||||
|
(
|
||||||
|
state.api_key.clone(),
|
||||||
|
settings.api_url.clone(),
|
||||||
|
settings.low_speed_timeout,
|
||||||
|
)
|
||||||
|
}) else {
|
||||||
|
return futures::future::ready(Err(anyhow!("App state dropped"))).boxed();
|
||||||
|
};
|
||||||
|
|
||||||
|
async move {
|
||||||
|
let api_key = api_key.ok_or_else(|| anyhow!("missing api key"))?;
|
||||||
|
let request = anthropic::stream_completion(
|
||||||
|
http_client.as_ref(),
|
||||||
|
&api_url,
|
||||||
|
&api_key,
|
||||||
|
request,
|
||||||
|
low_speed_timeout,
|
||||||
|
);
|
||||||
|
request.await
|
||||||
|
}
|
||||||
|
.boxed()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
impl LanguageModel for AnthropicModel {
|
impl LanguageModel for AnthropicModel {
|
||||||
fn id(&self) -> LanguageModelId {
|
fn id(&self) -> LanguageModelId {
|
||||||
self.id.clone()
|
self.id.clone()
|
||||||
|
@ -227,34 +280,53 @@ impl LanguageModel for AnthropicModel {
|
||||||
cx: &AsyncAppContext,
|
cx: &AsyncAppContext,
|
||||||
) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
|
) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
|
||||||
let request = request.into_anthropic(self.model.id().into());
|
let request = request.into_anthropic(self.model.id().into());
|
||||||
|
let request = self.stream_completion(request, cx);
|
||||||
let http_client = self.http_client.clone();
|
|
||||||
|
|
||||||
let Ok((api_key, api_url, low_speed_timeout)) = cx.read_model(&self.state, |state, cx| {
|
|
||||||
let settings = &AllLanguageModelSettings::get_global(cx).anthropic;
|
|
||||||
(
|
|
||||||
state.api_key.clone(),
|
|
||||||
settings.api_url.clone(),
|
|
||||||
settings.low_speed_timeout,
|
|
||||||
)
|
|
||||||
}) else {
|
|
||||||
return futures::future::ready(Err(anyhow!("App state dropped"))).boxed();
|
|
||||||
};
|
|
||||||
|
|
||||||
async move {
|
async move {
|
||||||
let api_key = api_key.ok_or_else(|| anyhow!("missing api key"))?;
|
|
||||||
let request = stream_completion(
|
|
||||||
http_client.as_ref(),
|
|
||||||
&api_url,
|
|
||||||
&api_key,
|
|
||||||
request,
|
|
||||||
low_speed_timeout,
|
|
||||||
);
|
|
||||||
let response = request.await?;
|
let response = request.await?;
|
||||||
Ok(anthropic::extract_text_from_events(response).boxed())
|
Ok(anthropic::extract_text_from_events(response).boxed())
|
||||||
}
|
}
|
||||||
.boxed()
|
.boxed()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn use_tool(
|
||||||
|
&self,
|
||||||
|
request: LanguageModelRequest,
|
||||||
|
tool_name: String,
|
||||||
|
tool_description: String,
|
||||||
|
input_schema: serde_json::Value,
|
||||||
|
cx: &AsyncAppContext,
|
||||||
|
) -> BoxFuture<'static, Result<serde_json::Value>> {
|
||||||
|
let mut request = request.into_anthropic(self.model.id().into());
|
||||||
|
request.tool_choice = Some(anthropic::ToolChoice::Tool {
|
||||||
|
name: tool_name.clone(),
|
||||||
|
});
|
||||||
|
request.tools = vec![anthropic::Tool {
|
||||||
|
name: tool_name.clone(),
|
||||||
|
description: tool_description,
|
||||||
|
input_schema,
|
||||||
|
}];
|
||||||
|
|
||||||
|
let response = self.request_completion(request, cx);
|
||||||
|
async move {
|
||||||
|
let response = response.await?;
|
||||||
|
response
|
||||||
|
.content
|
||||||
|
.into_iter()
|
||||||
|
.find_map(|content| {
|
||||||
|
if let anthropic::Content::ToolUse { name, input, .. } = content {
|
||||||
|
if name == tool_name {
|
||||||
|
Some(input)
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
}
|
||||||
|
})
|
||||||
|
.context("tool not used")
|
||||||
|
}
|
||||||
|
.boxed()
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
struct AuthenticationPrompt {
|
struct AuthenticationPrompt {
|
||||||
|
|
|
@ -4,7 +4,7 @@ use crate::{
|
||||||
LanguageModelName, LanguageModelProviderId, LanguageModelProviderName,
|
LanguageModelName, LanguageModelProviderId, LanguageModelProviderName,
|
||||||
LanguageModelProviderState, LanguageModelRequest,
|
LanguageModelProviderState, LanguageModelRequest,
|
||||||
};
|
};
|
||||||
use anyhow::Result;
|
use anyhow::{anyhow, Context as _, Result};
|
||||||
use client::Client;
|
use client::Client;
|
||||||
use collections::BTreeMap;
|
use collections::BTreeMap;
|
||||||
use futures::{future::BoxFuture, stream::BoxStream, FutureExt, StreamExt};
|
use futures::{future::BoxFuture, stream::BoxStream, FutureExt, StreamExt};
|
||||||
|
@ -12,7 +12,7 @@ use gpui::{AnyView, AppContext, AsyncAppContext, Subscription, Task};
|
||||||
use schemars::JsonSchema;
|
use schemars::JsonSchema;
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
use settings::{Settings, SettingsStore};
|
use settings::{Settings, SettingsStore};
|
||||||
use std::sync::Arc;
|
use std::{future, sync::Arc};
|
||||||
use strum::IntoEnumIterator;
|
use strum::IntoEnumIterator;
|
||||||
use ui::prelude::*;
|
use ui::prelude::*;
|
||||||
|
|
||||||
|
@ -234,15 +234,13 @@ impl LanguageModel for CloudLanguageModel {
|
||||||
};
|
};
|
||||||
async move {
|
async move {
|
||||||
let request = serde_json::to_string(&request)?;
|
let request = serde_json::to_string(&request)?;
|
||||||
let response = client.request(proto::QueryLanguageModel {
|
let response = client
|
||||||
|
.request(proto::CountLanguageModelTokens {
|
||||||
provider: proto::LanguageModelProvider::Google as i32,
|
provider: proto::LanguageModelProvider::Google as i32,
|
||||||
kind: proto::LanguageModelRequestKind::CountTokens as i32,
|
|
||||||
request,
|
request,
|
||||||
});
|
})
|
||||||
let response = response.await?;
|
.await?;
|
||||||
let response =
|
Ok(response.token_count as usize)
|
||||||
serde_json::from_str::<google_ai::CountTokensResponse>(&response.response)?;
|
|
||||||
Ok(response.total_tokens)
|
|
||||||
}
|
}
|
||||||
.boxed()
|
.boxed()
|
||||||
}
|
}
|
||||||
|
@ -260,14 +258,14 @@ impl LanguageModel for CloudLanguageModel {
|
||||||
let request = request.into_anthropic(model.id().into());
|
let request = request.into_anthropic(model.id().into());
|
||||||
async move {
|
async move {
|
||||||
let request = serde_json::to_string(&request)?;
|
let request = serde_json::to_string(&request)?;
|
||||||
let response = client.request_stream(proto::QueryLanguageModel {
|
let stream = client
|
||||||
|
.request_stream(proto::StreamCompleteWithLanguageModel {
|
||||||
provider: proto::LanguageModelProvider::Anthropic as i32,
|
provider: proto::LanguageModelProvider::Anthropic as i32,
|
||||||
kind: proto::LanguageModelRequestKind::Complete as i32,
|
|
||||||
request,
|
request,
|
||||||
});
|
})
|
||||||
let chunks = response.await?;
|
.await?;
|
||||||
Ok(anthropic::extract_text_from_events(
|
Ok(anthropic::extract_text_from_events(
|
||||||
chunks.map(|chunk| Ok(serde_json::from_str(&chunk?.response)?)),
|
stream.map(|item| Ok(serde_json::from_str(&item?.event)?)),
|
||||||
)
|
)
|
||||||
.boxed())
|
.boxed())
|
||||||
}
|
}
|
||||||
|
@ -278,14 +276,14 @@ impl LanguageModel for CloudLanguageModel {
|
||||||
let request = request.into_open_ai(model.id().into());
|
let request = request.into_open_ai(model.id().into());
|
||||||
async move {
|
async move {
|
||||||
let request = serde_json::to_string(&request)?;
|
let request = serde_json::to_string(&request)?;
|
||||||
let response = client.request_stream(proto::QueryLanguageModel {
|
let stream = client
|
||||||
|
.request_stream(proto::StreamCompleteWithLanguageModel {
|
||||||
provider: proto::LanguageModelProvider::OpenAi as i32,
|
provider: proto::LanguageModelProvider::OpenAi as i32,
|
||||||
kind: proto::LanguageModelRequestKind::Complete as i32,
|
|
||||||
request,
|
request,
|
||||||
});
|
})
|
||||||
let chunks = response.await?;
|
.await?;
|
||||||
Ok(open_ai::extract_text_from_events(
|
Ok(open_ai::extract_text_from_events(
|
||||||
chunks.map(|chunk| Ok(serde_json::from_str(&chunk?.response)?)),
|
stream.map(|item| Ok(serde_json::from_str(&item?.event)?)),
|
||||||
)
|
)
|
||||||
.boxed())
|
.boxed())
|
||||||
}
|
}
|
||||||
|
@ -296,14 +294,14 @@ impl LanguageModel for CloudLanguageModel {
|
||||||
let request = request.into_google(model.id().into());
|
let request = request.into_google(model.id().into());
|
||||||
async move {
|
async move {
|
||||||
let request = serde_json::to_string(&request)?;
|
let request = serde_json::to_string(&request)?;
|
||||||
let response = client.request_stream(proto::QueryLanguageModel {
|
let stream = client
|
||||||
|
.request_stream(proto::StreamCompleteWithLanguageModel {
|
||||||
provider: proto::LanguageModelProvider::Google as i32,
|
provider: proto::LanguageModelProvider::Google as i32,
|
||||||
kind: proto::LanguageModelRequestKind::Complete as i32,
|
|
||||||
request,
|
request,
|
||||||
});
|
})
|
||||||
let chunks = response.await?;
|
.await?;
|
||||||
Ok(google_ai::extract_text_from_events(
|
Ok(google_ai::extract_text_from_events(
|
||||||
chunks.map(|chunk| Ok(serde_json::from_str(&chunk?.response)?)),
|
stream.map(|item| Ok(serde_json::from_str(&item?.event)?)),
|
||||||
)
|
)
|
||||||
.boxed())
|
.boxed())
|
||||||
}
|
}
|
||||||
|
@ -311,6 +309,63 @@ impl LanguageModel for CloudLanguageModel {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn use_tool(
|
||||||
|
&self,
|
||||||
|
request: LanguageModelRequest,
|
||||||
|
tool_name: String,
|
||||||
|
tool_description: String,
|
||||||
|
input_schema: serde_json::Value,
|
||||||
|
_cx: &AsyncAppContext,
|
||||||
|
) -> BoxFuture<'static, Result<serde_json::Value>> {
|
||||||
|
match &self.model {
|
||||||
|
CloudModel::Anthropic(model) => {
|
||||||
|
let client = self.client.clone();
|
||||||
|
let mut request = request.into_anthropic(model.id().into());
|
||||||
|
request.tool_choice = Some(anthropic::ToolChoice::Tool {
|
||||||
|
name: tool_name.clone(),
|
||||||
|
});
|
||||||
|
request.tools = vec![anthropic::Tool {
|
||||||
|
name: tool_name.clone(),
|
||||||
|
description: tool_description,
|
||||||
|
input_schema,
|
||||||
|
}];
|
||||||
|
|
||||||
|
async move {
|
||||||
|
let request = serde_json::to_string(&request)?;
|
||||||
|
let response = client
|
||||||
|
.request(proto::CompleteWithLanguageModel {
|
||||||
|
provider: proto::LanguageModelProvider::Anthropic as i32,
|
||||||
|
request,
|
||||||
|
})
|
||||||
|
.await?;
|
||||||
|
let response: anthropic::Response = serde_json::from_str(&response.completion)?;
|
||||||
|
response
|
||||||
|
.content
|
||||||
|
.into_iter()
|
||||||
|
.find_map(|content| {
|
||||||
|
if let anthropic::Content::ToolUse { name, input, .. } = content {
|
||||||
|
if name == tool_name {
|
||||||
|
Some(input)
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
}
|
||||||
|
})
|
||||||
|
.context("tool not used")
|
||||||
|
}
|
||||||
|
.boxed()
|
||||||
|
}
|
||||||
|
CloudModel::OpenAi(_) => {
|
||||||
|
future::ready(Err(anyhow!("tool use not implemented for OpenAI"))).boxed()
|
||||||
|
}
|
||||||
|
CloudModel::Google(_) => {
|
||||||
|
future::ready(Err(anyhow!("tool use not implemented for Google AI"))).boxed()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
struct AuthenticationPrompt {
|
struct AuthenticationPrompt {
|
||||||
|
|
|
@ -1,15 +1,17 @@
|
||||||
use std::sync::{Arc, Mutex};
|
|
||||||
|
|
||||||
use collections::HashMap;
|
|
||||||
use futures::{channel::mpsc, future::BoxFuture, stream::BoxStream, FutureExt, StreamExt};
|
|
||||||
|
|
||||||
use crate::{
|
use crate::{
|
||||||
LanguageModel, LanguageModelId, LanguageModelName, LanguageModelProvider,
|
LanguageModel, LanguageModelId, LanguageModelName, LanguageModelProvider,
|
||||||
LanguageModelProviderId, LanguageModelProviderName, LanguageModelProviderState,
|
LanguageModelProviderId, LanguageModelProviderName, LanguageModelProviderState,
|
||||||
LanguageModelRequest,
|
LanguageModelRequest,
|
||||||
};
|
};
|
||||||
|
use anyhow::anyhow;
|
||||||
|
use collections::HashMap;
|
||||||
|
use futures::{channel::mpsc, future::BoxFuture, stream::BoxStream, FutureExt, StreamExt};
|
||||||
use gpui::{AnyView, AppContext, AsyncAppContext, Task};
|
use gpui::{AnyView, AppContext, AsyncAppContext, Task};
|
||||||
use http_client::Result;
|
use http_client::Result;
|
||||||
|
use std::{
|
||||||
|
future,
|
||||||
|
sync::{Arc, Mutex},
|
||||||
|
};
|
||||||
use ui::WindowContext;
|
use ui::WindowContext;
|
||||||
|
|
||||||
pub fn language_model_id() -> LanguageModelId {
|
pub fn language_model_id() -> LanguageModelId {
|
||||||
|
@ -170,4 +172,15 @@ impl LanguageModel for FakeLanguageModel {
|
||||||
.insert(serde_json::to_string(&request).unwrap(), tx);
|
.insert(serde_json::to_string(&request).unwrap(), tx);
|
||||||
async move { Ok(rx.map(Ok).boxed()) }.boxed()
|
async move { Ok(rx.map(Ok).boxed()) }.boxed()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn use_tool(
|
||||||
|
&self,
|
||||||
|
_request: LanguageModelRequest,
|
||||||
|
_name: String,
|
||||||
|
_description: String,
|
||||||
|
_schema: serde_json::Value,
|
||||||
|
_cx: &AsyncAppContext,
|
||||||
|
) -> BoxFuture<'static, Result<serde_json::Value>> {
|
||||||
|
future::ready(Err(anyhow!("not implemented"))).boxed()
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -9,7 +9,7 @@ use gpui::{
|
||||||
};
|
};
|
||||||
use http_client::HttpClient;
|
use http_client::HttpClient;
|
||||||
use settings::{Settings, SettingsStore};
|
use settings::{Settings, SettingsStore};
|
||||||
use std::{sync::Arc, time::Duration};
|
use std::{future, sync::Arc, time::Duration};
|
||||||
use strum::IntoEnumIterator;
|
use strum::IntoEnumIterator;
|
||||||
use theme::ThemeSettings;
|
use theme::ThemeSettings;
|
||||||
use ui::prelude::*;
|
use ui::prelude::*;
|
||||||
|
@ -238,6 +238,17 @@ impl LanguageModel for GoogleLanguageModel {
|
||||||
}
|
}
|
||||||
.boxed()
|
.boxed()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn use_tool(
|
||||||
|
&self,
|
||||||
|
_request: LanguageModelRequest,
|
||||||
|
_name: String,
|
||||||
|
_description: String,
|
||||||
|
_schema: serde_json::Value,
|
||||||
|
_cx: &AsyncAppContext,
|
||||||
|
) -> BoxFuture<'static, Result<serde_json::Value>> {
|
||||||
|
future::ready(Err(anyhow!("not implemented"))).boxed()
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
struct AuthenticationPrompt {
|
struct AuthenticationPrompt {
|
||||||
|
|
|
@ -6,7 +6,7 @@ use ollama::{
|
||||||
get_models, preload_model, stream_chat_completion, ChatMessage, ChatOptions, ChatRequest,
|
get_models, preload_model, stream_chat_completion, ChatMessage, ChatOptions, ChatRequest,
|
||||||
};
|
};
|
||||||
use settings::{Settings, SettingsStore};
|
use settings::{Settings, SettingsStore};
|
||||||
use std::{sync::Arc, time::Duration};
|
use std::{future, sync::Arc, time::Duration};
|
||||||
use ui::{prelude::*, ButtonLike, ElevationIndex};
|
use ui::{prelude::*, ButtonLike, ElevationIndex};
|
||||||
|
|
||||||
use crate::{
|
use crate::{
|
||||||
|
@ -298,6 +298,17 @@ impl LanguageModel for OllamaLanguageModel {
|
||||||
}
|
}
|
||||||
.boxed()
|
.boxed()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn use_tool(
|
||||||
|
&self,
|
||||||
|
_request: LanguageModelRequest,
|
||||||
|
_name: String,
|
||||||
|
_description: String,
|
||||||
|
_schema: serde_json::Value,
|
||||||
|
_cx: &AsyncAppContext,
|
||||||
|
) -> BoxFuture<'static, Result<serde_json::Value>> {
|
||||||
|
future::ready(Err(anyhow!("not implemented"))).boxed()
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
struct DownloadOllamaMessage {
|
struct DownloadOllamaMessage {
|
||||||
|
|
|
@ -9,7 +9,7 @@ use gpui::{
|
||||||
use http_client::HttpClient;
|
use http_client::HttpClient;
|
||||||
use open_ai::stream_completion;
|
use open_ai::stream_completion;
|
||||||
use settings::{Settings, SettingsStore};
|
use settings::{Settings, SettingsStore};
|
||||||
use std::{sync::Arc, time::Duration};
|
use std::{future, sync::Arc, time::Duration};
|
||||||
use strum::IntoEnumIterator;
|
use strum::IntoEnumIterator;
|
||||||
use theme::ThemeSettings;
|
use theme::ThemeSettings;
|
||||||
use ui::prelude::*;
|
use ui::prelude::*;
|
||||||
|
@ -225,6 +225,17 @@ impl LanguageModel for OpenAiLanguageModel {
|
||||||
}
|
}
|
||||||
.boxed()
|
.boxed()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn use_tool(
|
||||||
|
&self,
|
||||||
|
_request: LanguageModelRequest,
|
||||||
|
_name: String,
|
||||||
|
_description: String,
|
||||||
|
_schema: serde_json::Value,
|
||||||
|
_cx: &AsyncAppContext,
|
||||||
|
) -> BoxFuture<'static, Result<serde_json::Value>> {
|
||||||
|
future::ready(Err(anyhow!("not implemented"))).boxed()
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn count_open_ai_tokens(
|
pub fn count_open_ai_tokens(
|
||||||
|
|
|
@ -106,19 +106,27 @@ impl LanguageModelRequest {
|
||||||
messages: new_messages
|
messages: new_messages
|
||||||
.into_iter()
|
.into_iter()
|
||||||
.filter_map(|message| {
|
.filter_map(|message| {
|
||||||
Some(anthropic::RequestMessage {
|
Some(anthropic::Message {
|
||||||
role: match message.role {
|
role: match message.role {
|
||||||
Role::User => anthropic::Role::User,
|
Role::User => anthropic::Role::User,
|
||||||
Role::Assistant => anthropic::Role::Assistant,
|
Role::Assistant => anthropic::Role::Assistant,
|
||||||
Role::System => return None,
|
Role::System => return None,
|
||||||
},
|
},
|
||||||
content: message.content,
|
content: vec![anthropic::Content::Text {
|
||||||
|
text: message.content,
|
||||||
|
}],
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
.collect(),
|
.collect(),
|
||||||
stream: true,
|
|
||||||
max_tokens: 4092,
|
max_tokens: 4092,
|
||||||
system: system_message,
|
system: Some(system_message),
|
||||||
|
tools: Vec::new(),
|
||||||
|
tool_choice: None,
|
||||||
|
metadata: None,
|
||||||
|
stop_sequences: Vec::new(),
|
||||||
|
temperature: None,
|
||||||
|
top_k: None,
|
||||||
|
top_p: None,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -194,8 +194,12 @@ message Envelope {
|
||||||
|
|
||||||
JoinHostedProject join_hosted_project = 164;
|
JoinHostedProject join_hosted_project = 164;
|
||||||
|
|
||||||
QueryLanguageModel query_language_model = 224;
|
CompleteWithLanguageModel complete_with_language_model = 226;
|
||||||
QueryLanguageModelResponse query_language_model_response = 225; // current max
|
CompleteWithLanguageModelResponse complete_with_language_model_response = 227;
|
||||||
|
StreamCompleteWithLanguageModel stream_complete_with_language_model = 228;
|
||||||
|
StreamCompleteWithLanguageModelResponse stream_complete_with_language_model_response = 229;
|
||||||
|
CountLanguageModelTokens count_language_model_tokens = 230;
|
||||||
|
CountLanguageModelTokensResponse count_language_model_tokens_response = 231; // current max
|
||||||
GetCachedEmbeddings get_cached_embeddings = 189;
|
GetCachedEmbeddings get_cached_embeddings = 189;
|
||||||
GetCachedEmbeddingsResponse get_cached_embeddings_response = 190;
|
GetCachedEmbeddingsResponse get_cached_embeddings_response = 190;
|
||||||
ComputeEmbeddings compute_embeddings = 191;
|
ComputeEmbeddings compute_embeddings = 191;
|
||||||
|
@ -267,6 +271,7 @@ message Envelope {
|
||||||
|
|
||||||
reserved 158 to 161;
|
reserved 158 to 161;
|
||||||
reserved 166 to 169;
|
reserved 166 to 169;
|
||||||
|
reserved 224 to 225;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Messages
|
// Messages
|
||||||
|
@ -2050,10 +2055,31 @@ enum LanguageModelRole {
|
||||||
reserved 3;
|
reserved 3;
|
||||||
}
|
}
|
||||||
|
|
||||||
message QueryLanguageModel {
|
message CompleteWithLanguageModel {
|
||||||
LanguageModelProvider provider = 1;
|
LanguageModelProvider provider = 1;
|
||||||
LanguageModelRequestKind kind = 2;
|
string request = 2;
|
||||||
string request = 3;
|
}
|
||||||
|
|
||||||
|
message CompleteWithLanguageModelResponse {
|
||||||
|
string completion = 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
message StreamCompleteWithLanguageModel {
|
||||||
|
LanguageModelProvider provider = 1;
|
||||||
|
string request = 2;
|
||||||
|
}
|
||||||
|
|
||||||
|
message StreamCompleteWithLanguageModelResponse {
|
||||||
|
string event = 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
message CountLanguageModelTokens {
|
||||||
|
LanguageModelProvider provider = 1;
|
||||||
|
string request = 2;
|
||||||
|
}
|
||||||
|
|
||||||
|
message CountLanguageModelTokensResponse {
|
||||||
|
uint32 token_count = 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
enum LanguageModelProvider {
|
enum LanguageModelProvider {
|
||||||
|
@ -2062,15 +2088,6 @@ enum LanguageModelProvider {
|
||||||
Google = 2;
|
Google = 2;
|
||||||
}
|
}
|
||||||
|
|
||||||
enum LanguageModelRequestKind {
|
|
||||||
Complete = 0;
|
|
||||||
CountTokens = 1;
|
|
||||||
}
|
|
||||||
|
|
||||||
message QueryLanguageModelResponse {
|
|
||||||
string response = 1;
|
|
||||||
}
|
|
||||||
|
|
||||||
message GetCachedEmbeddings {
|
message GetCachedEmbeddings {
|
||||||
string model = 1;
|
string model = 1;
|
||||||
repeated bytes digests = 2;
|
repeated bytes digests = 2;
|
||||||
|
|
|
@ -294,8 +294,12 @@ messages!(
|
||||||
(PrepareRename, Background),
|
(PrepareRename, Background),
|
||||||
(PrepareRenameResponse, Background),
|
(PrepareRenameResponse, Background),
|
||||||
(ProjectEntryResponse, Foreground),
|
(ProjectEntryResponse, Foreground),
|
||||||
(QueryLanguageModel, Background),
|
(CompleteWithLanguageModel, Background),
|
||||||
(QueryLanguageModelResponse, Background),
|
(CompleteWithLanguageModelResponse, Background),
|
||||||
|
(StreamCompleteWithLanguageModel, Background),
|
||||||
|
(StreamCompleteWithLanguageModelResponse, Background),
|
||||||
|
(CountLanguageModelTokens, Background),
|
||||||
|
(CountLanguageModelTokensResponse, Background),
|
||||||
(RefreshInlayHints, Foreground),
|
(RefreshInlayHints, Foreground),
|
||||||
(RejoinChannelBuffers, Foreground),
|
(RejoinChannelBuffers, Foreground),
|
||||||
(RejoinChannelBuffersResponse, Foreground),
|
(RejoinChannelBuffersResponse, Foreground),
|
||||||
|
@ -463,7 +467,12 @@ request_messages!(
|
||||||
(PerformRename, PerformRenameResponse),
|
(PerformRename, PerformRenameResponse),
|
||||||
(Ping, Ack),
|
(Ping, Ack),
|
||||||
(PrepareRename, PrepareRenameResponse),
|
(PrepareRename, PrepareRenameResponse),
|
||||||
(QueryLanguageModel, QueryLanguageModelResponse),
|
(CompleteWithLanguageModel, CompleteWithLanguageModelResponse),
|
||||||
|
(
|
||||||
|
StreamCompleteWithLanguageModel,
|
||||||
|
StreamCompleteWithLanguageModelResponse
|
||||||
|
),
|
||||||
|
(CountLanguageModelTokens, CountLanguageModelTokensResponse),
|
||||||
(RefreshInlayHints, Ack),
|
(RefreshInlayHints, Ack),
|
||||||
(RejoinChannelBuffers, RejoinChannelBuffersResponse),
|
(RejoinChannelBuffers, RejoinChannelBuffersResponse),
|
||||||
(RejoinRoom, RejoinRoomResponse),
|
(RejoinRoom, RejoinRoomResponse),
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue