Use tool calling instead of XML parsing to generate edit operations (#15385)

Release Notes:

- N/A
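
In short: previously the model was asked to emit an `<operations>…</operations>` XML block that was then parsed with `roxmltree`; with this change the request advertises an `edit` tool whose JSON Schema is derived via `schemars`, and the model replies with a structured `tool_use` block. A minimal sketch of the idea using only `serde`/`serde_json` (the schema and payload below are illustrative, not the exact ones in this PR):

use serde::Deserialize;
use serde_json::{json, Value};

// Illustrative tool definition, mirroring the `Tool` struct this PR adds to
// the anthropic crate; in the real code the input schema comes from schemars.
fn edit_tool_definition() -> Value {
    json!({
        "name": "edit",
        "description": "suggest edits to one or more locations in the codebase",
        "input_schema": {
            "type": "object",
            "properties": {
                "operations": { "type": "array", "items": { "type": "object" } }
            },
            "required": ["operations"]
        }
    })
}

// Illustrative mirror of the PR's `EditTool` input: the tool arguments arrive
// as JSON, so no scanning of the completion text for XML tags is needed.
#[derive(Debug, Deserialize)]
struct EditToolInput {
    operations: Vec<Value>,
}

fn main() -> serde_json::Result<()> {
    println!("tool definition: {}", edit_tool_definition());

    // A hypothetical `tool_use` input as the Messages API would return it.
    let input: EditToolInput = serde_json::from_value(json!({
        "operations": [{
            "kind": "Update",
            "path": "src/canvas.rs",
            "symbol": "pub struct Canvas pub pixels",
            "description": "Remove pub keyword from pixels field"
        }]
    }))?;
    println!("parsed {} operation(s) without any XML", input.operations.len());
    Ok(())
}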

---------

Co-authored-by: Nathan <nathan@zed.dev>
Antonio Scandurra, 2024-07-29 16:42:08 +02:00, committed by GitHub
parent f6012cd86e
commit 6e1f7c6e1d
22 changed files with 1155 additions and 853 deletions

Cargo.lock (generated, 13 changes)

@@ -435,7 +435,6 @@ dependencies = [
  "rand 0.8.5",
  "regex",
  "rope",
- "roxmltree 0.20.0",
  "schemars",
  "search",
  "semantic_index",
@@ -2641,7 +2640,9 @@ dependencies = [
  "language_model",
  "project",
  "rand 0.8.5",
+ "schemars",
  "serde",
+ "serde_json",
  "settings",
  "smol",
  "text",
@@ -4237,7 +4238,7 @@ version = "0.5.6"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "6a595cb550439a117696039dfc69830492058211b771a2a165379f2a1a53d84d"
 dependencies = [
- "roxmltree 0.19.0",
+ "roxmltree",
 ]
 
 [[package]]
@@ -8918,12 +8919,6 @@ version = "0.19.0"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "3cd14fd5e3b777a7422cca79358c57a8f6e3a703d9ac187448d0daf220c2407f"
 
-[[package]]
-name = "roxmltree"
-version = "0.20.0"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "6c20b6793b5c2fa6553b250154b78d6d0db37e72700ae35fad9387a46f487c97"
-
 [[package]]
 name = "rpc"
 version = "0.1.0"
@@ -11878,7 +11873,7 @@ dependencies = [
  "kurbo",
  "log",
  "pico-args",
- "roxmltree 0.19.0",
+ "roxmltree",
  "simplecss",
  "siphasher 1.0.1",
  "strict-num",

crates/anthropic/src/anthropic.rs

@@ -3,7 +3,7 @@ use futures::{io::BufReader, stream::BoxStream, AsyncBufReadExt, AsyncReadExt, S
 use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest};
 use isahc::config::Configurable;
 use serde::{Deserialize, Serialize};
-use std::{convert::TryFrom, time::Duration};
+use std::time::Duration;
 use strum::EnumIter;
 
 pub const ANTHROPIC_API_URL: &'static str = "https://api.anthropic.com";
@@ -70,104 +70,41 @@ impl Model {
     }
 }
 
-#[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)]
-#[serde(rename_all = "lowercase")]
-pub enum Role {
-    User,
-    Assistant,
-}
-
-impl TryFrom<String> for Role {
-    type Error = anyhow::Error;
-
-    fn try_from(value: String) -> Result<Self> {
-        match value.as_str() {
-            "user" => Ok(Self::User),
-            "assistant" => Ok(Self::Assistant),
-            _ => Err(anyhow!("invalid role '{value}'")),
-        }
-    }
-}
-
-impl From<Role> for String {
-    fn from(val: Role) -> Self {
-        match val {
-            Role::User => "user".to_owned(),
-            Role::Assistant => "assistant".to_owned(),
-        }
-    }
-}
-
-#[derive(Debug, Serialize, Deserialize)]
-pub struct Request {
-    pub model: String,
-    pub messages: Vec<RequestMessage>,
-    pub stream: bool,
-    pub system: String,
-    pub max_tokens: u32,
-}
-
-#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
-pub struct RequestMessage {
-    pub role: Role,
-    pub content: String,
-}
-
-#[derive(Deserialize, Serialize, Debug)]
-#[serde(tag = "type", rename_all = "snake_case")]
-pub enum ResponseEvent {
-    MessageStart {
-        message: ResponseMessage,
-    },
-    ContentBlockStart {
-        index: u32,
-        content_block: ContentBlock,
-    },
-    Ping {},
-    ContentBlockDelta {
-        index: u32,
-        delta: TextDelta,
-    },
-    ContentBlockStop {
-        index: u32,
-    },
-    MessageDelta {
-        delta: ResponseMessage,
-        usage: Usage,
-    },
-    MessageStop {},
-}
-
-#[derive(Serialize, Deserialize, Debug)]
-pub struct ResponseMessage {
-    #[serde(rename = "type")]
-    pub message_type: Option<String>,
-    pub id: Option<String>,
-    pub role: Option<String>,
-    pub content: Option<Vec<String>>,
-    pub model: Option<String>,
-    pub stop_reason: Option<String>,
-    pub stop_sequence: Option<String>,
-    pub usage: Option<Usage>,
-}
-
-#[derive(Serialize, Deserialize, Debug)]
-pub struct Usage {
-    pub input_tokens: Option<u32>,
-    pub output_tokens: Option<u32>,
-}
-
-#[derive(Serialize, Deserialize, Debug)]
-#[serde(tag = "type", rename_all = "snake_case")]
-pub enum ContentBlock {
-    Text { text: String },
-}
-
-#[derive(Serialize, Deserialize, Debug)]
-#[serde(tag = "type", rename_all = "snake_case")]
-pub enum TextDelta {
-    TextDelta { text: String },
-}
+pub async fn complete(
+    client: &dyn HttpClient,
+    api_url: &str,
+    api_key: &str,
+    request: Request,
+) -> Result<Response> {
+    let uri = format!("{api_url}/v1/messages");
+    let request_builder = HttpRequest::builder()
+        .method(Method::POST)
+        .uri(uri)
+        .header("Anthropic-Version", "2023-06-01")
+        .header("Anthropic-Beta", "tools-2024-04-04")
+        .header("X-Api-Key", api_key)
+        .header("Content-Type", "application/json");
+
+    let serialized_request = serde_json::to_string(&request)?;
+    let request = request_builder.body(AsyncBody::from(serialized_request))?;
+
+    let mut response = client.send(request).await?;
+    if response.status().is_success() {
+        let mut body = Vec::new();
+        response.body_mut().read_to_end(&mut body).await?;
+        let response_message: Response = serde_json::from_slice(&body)?;
+        Ok(response_message)
+    } else {
+        let mut body = Vec::new();
+        response.body_mut().read_to_end(&mut body).await?;
+        let body_str = std::str::from_utf8(&body)?;
+        Err(anyhow!(
+            "Failed to connect to API: {} {}",
+            response.status(),
+            body_str
+        ))
+    }
+}
 
 pub async fn stream_completion(
     client: &dyn HttpClient,
@@ -175,7 +112,11 @@ pub async fn stream_completion(
     api_key: &str,
     request: Request,
     low_speed_timeout: Option<Duration>,
-) -> Result<BoxStream<'static, Result<ResponseEvent>>> {
+) -> Result<BoxStream<'static, Result<Event>>> {
+    let request = StreamingRequest {
+        base: request,
+        stream: true,
+    };
     let uri = format!("{api_url}/v1/messages");
     let mut request_builder = HttpRequest::builder()
         .method(Method::POST)
@@ -187,7 +128,9 @@ pub async fn stream_completion(
     if let Some(low_speed_timeout) = low_speed_timeout {
         request_builder = request_builder.low_speed_timeout(100, low_speed_timeout);
     }
-    let request = request_builder.body(AsyncBody::from(serde_json::to_string(&request)?))?;
+    let serialized_request = serde_json::to_string(&request)?;
+    let request = request_builder.body(AsyncBody::from(serialized_request))?;
     let mut response = client.send(request).await?;
     if response.status().is_success() {
         let reader = BufReader::new(response.into_body());
@@ -212,7 +155,7 @@ pub async fn stream_completion(
         let body_str = std::str::from_utf8(&body)?;
 
-        match serde_json::from_str::<ResponseEvent>(body_str) {
+        match serde_json::from_str::<Event>(body_str) {
             Ok(_) => Err(anyhow!(
                 "Unexpected success response while expecting an error: {}",
                 body_str,
@@ -227,16 +170,18 @@ pub async fn stream_completion(
 }
 
 pub fn extract_text_from_events(
-    response: impl Stream<Item = Result<ResponseEvent>>,
+    response: impl Stream<Item = Result<Event>>,
 ) -> impl Stream<Item = Result<String>> {
     response.filter_map(|response| async move {
         match response {
             Ok(response) => match response {
-                ResponseEvent::ContentBlockStart { content_block, .. } => match content_block {
-                    ContentBlock::Text { text } => Some(Ok(text)),
+                Event::ContentBlockStart { content_block, .. } => match content_block {
+                    Content::Text { text } => Some(Ok(text)),
+                    _ => None,
                 },
-                ResponseEvent::ContentBlockDelta { delta, .. } => match delta {
-                    TextDelta::TextDelta { text } => Some(Ok(text)),
+                Event::ContentBlockDelta { delta, .. } => match delta {
+                    ContentDelta::TextDelta { text } => Some(Ok(text)),
+                    _ => None,
                 },
                 _ => None,
             },
@@ -245,42 +190,162 @@ pub fn extract_text_from_events(
     })
 }
 
-// #[cfg(test)]
-// mod tests {
-//     use super::*;
-//     use http::IsahcHttpClient;
-
-//     #[tokio::test]
-//     async fn stream_completion_success() {
-//         let http_client = IsahcHttpClient::new().unwrap();
-
-//         let request = Request {
-//             model: Model::Claude3Opus,
-//             messages: vec![RequestMessage {
-//                 role: Role::User,
-//                 content: "Ping".to_string(),
-//             }],
-//             stream: true,
-//             system: "Respond to ping with pong".to_string(),
-//             max_tokens: 4096,
-//         };
-
-//         let stream = stream_completion(
-//             &http_client,
-//             "https://api.anthropic.com",
-//             &std::env::var("ANTHROPIC_API_KEY").expect("ANTHROPIC_API_KEY not set"),
-//             request,
-//         )
-//         .await
-//         .unwrap();
-
-//         stream
-//             .for_each(|event| async {
-//                 match event {
-//                     Ok(event) => println!("{:?}", event),
-//                     Err(e) => eprintln!("Error: {:?}", e),
-//                 }
-//             })
-//             .await;
-//     }
-// }
+#[derive(Debug, Serialize, Deserialize)]
+pub struct Message {
+    pub role: Role,
+    pub content: Vec<Content>,
+}
+
+#[derive(Debug, Serialize, Deserialize)]
+#[serde(rename_all = "lowercase")]
+pub enum Role {
+    User,
+    Assistant,
+}
+
+#[derive(Debug, Serialize, Deserialize)]
+#[serde(tag = "type")]
+pub enum Content {
+    #[serde(rename = "text")]
+    Text { text: String },
+    #[serde(rename = "image")]
+    Image { source: ImageSource },
+    #[serde(rename = "tool_use")]
+    ToolUse {
+        id: String,
+        name: String,
+        input: serde_json::Value,
+    },
+    #[serde(rename = "tool_result")]
+    ToolResult {
+        tool_use_id: String,
+        content: String,
+    },
+}
+
+#[derive(Debug, Serialize, Deserialize)]
+pub struct ImageSource {
+    #[serde(rename = "type")]
+    pub source_type: String,
+    pub media_type: String,
+    pub data: String,
+}
+
+#[derive(Debug, Serialize, Deserialize)]
+pub struct Tool {
+    pub name: String,
+    pub description: String,
+    pub input_schema: serde_json::Value,
+}
+
+#[derive(Debug, Serialize, Deserialize)]
+#[serde(tag = "type", rename_all = "lowercase")]
+pub enum ToolChoice {
+    Auto,
+    Any,
+    Tool { name: String },
+}
+
+#[derive(Debug, Serialize, Deserialize)]
+pub struct Request {
+    pub model: String,
+    pub max_tokens: u32,
+    pub messages: Vec<Message>,
+    #[serde(default, skip_serializing_if = "Vec::is_empty")]
+    pub tools: Vec<Tool>,
+    #[serde(default, skip_serializing_if = "Option::is_none")]
+    pub tool_choice: Option<ToolChoice>,
+    #[serde(default, skip_serializing_if = "Option::is_none")]
+    pub system: Option<String>,
+    #[serde(default, skip_serializing_if = "Option::is_none")]
+    pub metadata: Option<Metadata>,
+    #[serde(default, skip_serializing_if = "Vec::is_empty")]
+    pub stop_sequences: Vec<String>,
+    #[serde(default, skip_serializing_if = "Option::is_none")]
+    pub temperature: Option<f32>,
+    #[serde(default, skip_serializing_if = "Option::is_none")]
+    pub top_k: Option<u32>,
+    #[serde(default, skip_serializing_if = "Option::is_none")]
+    pub top_p: Option<f32>,
+}
+
+#[derive(Debug, Serialize, Deserialize)]
+struct StreamingRequest {
+    #[serde(flatten)]
+    pub base: Request,
+    pub stream: bool,
+}
+
+#[derive(Debug, Serialize, Deserialize)]
+pub struct Metadata {
+    pub user_id: Option<String>,
+}
+
+#[derive(Debug, Serialize, Deserialize)]
+pub struct Usage {
+    #[serde(default, skip_serializing_if = "Option::is_none")]
+    pub input_tokens: Option<u32>,
+    #[serde(default, skip_serializing_if = "Option::is_none")]
+    pub output_tokens: Option<u32>,
+}
+
+#[derive(Debug, Serialize, Deserialize)]
+pub struct Response {
+    pub id: String,
+    #[serde(rename = "type")]
+    pub response_type: String,
+    pub role: Role,
+    pub content: Vec<Content>,
+    pub model: String,
+    #[serde(default, skip_serializing_if = "Option::is_none")]
+    pub stop_reason: Option<String>,
+    #[serde(default, skip_serializing_if = "Option::is_none")]
+    pub stop_sequence: Option<String>,
+    pub usage: Usage,
+}
+
+#[derive(Debug, Serialize, Deserialize)]
+#[serde(tag = "type")]
+pub enum Event {
+    #[serde(rename = "message_start")]
+    MessageStart { message: Response },
+    #[serde(rename = "content_block_start")]
+    ContentBlockStart {
+        index: usize,
+        content_block: Content,
+    },
+    #[serde(rename = "content_block_delta")]
+    ContentBlockDelta { index: usize, delta: ContentDelta },
+    #[serde(rename = "content_block_stop")]
+    ContentBlockStop { index: usize },
+    #[serde(rename = "message_delta")]
+    MessageDelta { delta: MessageDelta, usage: Usage },
+    #[serde(rename = "message_stop")]
+    MessageStop,
+    #[serde(rename = "ping")]
+    Ping,
+    #[serde(rename = "error")]
+    Error { error: ApiError },
+}
+
+#[derive(Debug, Serialize, Deserialize)]
+#[serde(tag = "type")]
+pub enum ContentDelta {
+    #[serde(rename = "text_delta")]
+    TextDelta { text: String },
+    #[serde(rename = "input_json_delta")]
+    InputJsonDelta { partial_json: String },
+}
+
+#[derive(Debug, Serialize, Deserialize)]
+pub struct MessageDelta {
+    pub stop_reason: Option<String>,
+    pub stop_sequence: Option<String>,
+}
+
+#[derive(Debug, Serialize, Deserialize)]
+pub struct ApiError {
+    #[serde(rename = "type")]
+    pub error_type: String,
+    pub message: String,
+}
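
As a usage sketch of the request shape defined above (the model id, prompt, and schema literal are hypothetical; this assumes the `Request`, `Message`, `Content`, `Tool`, and `ToolChoice` types from this file are in scope):

// A minimal sketch: force the model to call the `edit` tool.
fn example_tool_request() -> Request {
    Request {
        model: "claude-3-5-sonnet-20240620".into(), // hypothetical model id
        max_tokens: 4096,
        messages: vec![Message {
            role: Role::User,
            content: vec![Content::Text {
                text: "Make all fields of the Canvas struct private".into(),
            }],
        }],
        tools: vec![Tool {
            name: "edit".into(),
            description: "suggest edits to one or more locations in the codebase".into(),
            input_schema: serde_json::json!({ "type": "object" }), // real schema elided
        }],
        tool_choice: Some(ToolChoice::Tool { name: "edit".into() }),
        system: None,
        metadata: None,
        stop_sequences: Vec::new(),
        temperature: None,
        top_k: None,
        top_p: None,
    }
}

The matching `Content::ToolUse` block in the `Response` then carries the tool's arguments as a `serde_json::Value`, ready to deserialize into a typed input.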

crates/assistant/Cargo.toml

@@ -75,7 +75,6 @@ util.workspace = true
 uuid.workspace = true
 workspace.workspace = true
 picker.workspace = true
-roxmltree = "0.20.0"
 
 [dev-dependencies]
 completion = { workspace = true, features = ["test-support"] }

crates/assistant/src/assistant_panel.rs

@@ -1232,12 +1232,16 @@ impl ContextEditor {
     fn apply_edit_step(&mut self, cx: &mut ViewContext<Self>) -> bool {
         if let Some(step) = self.active_edit_step.as_ref() {
+            let assist_ids = step.assist_ids.clone();
+            cx.window_context().defer(|cx| {
                 InlineAssistant::update_global(cx, |assistant, cx| {
-                for assist_id in &step.assist_ids {
-                    assistant.start_assist(*assist_id, cx);
+                    for assist_id in assist_ids {
+                        assistant.start_assist(assist_id, cx);
                     }
-                !step.assist_ids.is_empty()
                 })
+            });
+
+            !step.assist_ids.is_empty()
         } else {
             false
         }
@@ -1286,11 +1290,7 @@ impl ContextEditor {
                 .collect::<String>()
             ));
             match &step.operations {
-                Some(EditStepOperations::Parsed {
-                    operations,
-                    raw_output,
-                }) => {
-                    output.push_str(&format!("Raw Output:\n{raw_output}\n"));
+                Some(EditStepOperations::Ready(operations)) => {
                     output.push_str("Parsed Operations:\n");
                     for op in operations {
                         output.push_str(&format!("  {:?}\n", op));
@@ -1794,13 +1794,12 @@ impl ContextEditor {
                     .anchor_in_excerpt(excerpt_id, suggestion.range.end)
                     .unwrap()
             };
-            let initial_text = suggestion.prepend_newline.then(|| "\n".into());
             InlineAssistant::update_global(cx, |assistant, cx| {
                 assist_ids.push(assistant.suggest_assist(
                     &editor,
                     range,
                     description,
-                    initial_text,
+                    suggestion.initial_insertion,
                     Some(workspace.clone()),
                     assistant_panel.upgrade().as_ref(),
                     cx,
@@ -1862,9 +1861,11 @@ impl ContextEditor {
                     .anchor_in_excerpt(excerpt_id, suggestion.range.end)
                     .unwrap()
             };
-            let initial_text =
-                suggestion.prepend_newline.then(|| "\n".to_string());
-            inline_assist_suggestions.push((range, description, initial_text));
+            inline_assist_suggestions.push((
+                range,
+                description,
+                suggestion.initial_insertion,
+            ));
         }
     }
 }
@@ -1875,12 +1876,12 @@ impl ContextEditor {
             .new_view(|cx| Editor::for_multibuffer(multibuffer, Some(project), true, cx))?;
         cx.update(|cx| {
             InlineAssistant::update_global(cx, |assistant, cx| {
-                for (range, description, initial_text) in inline_assist_suggestions {
+                for (range, description, initial_insertion) in inline_assist_suggestions {
                     assist_ids.push(assistant.suggest_assist(
                         &editor,
                         range,
                         description,
-                        initial_text,
+                        initial_insertion,
                         Some(workspace.clone()),
                         assistant_panel.upgrade().as_ref(),
                         cx,
@@ -2188,7 +2189,7 @@ impl ContextEditor {
         let button_text = match self.edit_step_for_cursor(cx) {
             Some(edit_step) => match &edit_step.operations {
                 Some(EditStepOperations::Pending(_)) => "Computing Changes...",
-                Some(EditStepOperations::Parsed { .. }) => "Apply Changes",
+                Some(EditStepOperations::Ready(_)) => "Apply Changes",
                 None => "Send",
             },
             None => "Send",

crates/assistant/src/context.rs

@@ -1,6 +1,6 @@
 use crate::{
-    prompt_library::PromptStore, slash_command::SlashCommandLine, LanguageModelCompletionProvider,
-    MessageId, MessageStatus,
+    prompt_library::PromptStore, slash_command::SlashCommandLine, InitialInsertion,
+    LanguageModelCompletionProvider, MessageId, MessageStatus,
 };
 use anyhow::{anyhow, Context as _, Result};
 use assistant_slash_command::{
@@ -18,11 +18,11 @@ use gpui::{AppContext, Context as _, EventEmitter, Model, ModelContext, Subscrip
 use language::{
     AnchorRangeExt, Bias, Buffer, LanguageRegistry, OffsetRangeExt, ParseStatus, Point, ToOffset,
 };
-use language_model::LanguageModelRequestMessage;
-use language_model::{LanguageModelRequest, Role};
+use language_model::{LanguageModelRequest, LanguageModelRequestMessage, LanguageModelTool, Role};
 use open_ai::Model as OpenAiModel;
 use paths::contexts_dir;
 use project::Project;
+use schemars::JsonSchema;
 use serde::{Deserialize, Serialize};
 use std::{
     cmp,
@@ -352,7 +352,7 @@ pub struct EditSuggestion {
     pub range: Range<language::Anchor>,
     /// If None, assume this is a suggestion to delete the range rather than transform it.
     pub description: Option<String>,
-    pub prepend_newline: bool,
+    pub initial_insertion: Option<InitialInsertion>,
 }
 
 impl EditStep {
@@ -361,7 +361,7 @@ impl EditStep {
         project: &Model<Project>,
         cx: &AppContext,
     ) -> Task<HashMap<Model<Buffer>, Vec<EditSuggestionGroup>>> {
-        let Some(EditStepOperations::Parsed { operations, .. }) = &self.operations else {
+        let Some(EditStepOperations::Ready(operations)) = &self.operations else {
            return Task::ready(HashMap::default());
        };
@@ -471,32 +471,28 @@ impl EditStep {
 }
 
 pub enum EditStepOperations {
-    Pending(Task<Result<()>>),
-    Parsed {
-        operations: Vec<EditOperation>,
-        raw_output: String,
-    },
+    Pending(Task<Option<()>>),
+    Ready(Vec<EditOperation>),
 }
 
 impl Debug for EditStepOperations {
     fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
         match self {
             EditStepOperations::Pending(_) => write!(f, "EditStepOperations::Pending"),
-            EditStepOperations::Parsed {
-                operations,
-                raw_output,
-            } => f
+            EditStepOperations::Ready(operations) => f
                 .debug_struct("EditStepOperations::Parsed")
                 .field("operations", operations)
-                .field("raw_output", raw_output)
                 .finish(),
         }
     }
 }
 
-#[derive(Clone, Debug, PartialEq, Eq)]
+/// A description of an operation to apply to one location in the codebase.
+#[derive(Clone, Debug, PartialEq, Eq, Deserialize, JsonSchema)]
 pub struct EditOperation {
+    /// The path to the file containing the relevant operation
     pub path: String,
+    #[serde(flatten)]
     pub kind: EditOperationKind,
 }
@@ -523,7 +519,7 @@ impl EditOperation {
             parse_status.changed().await?;
         }
 
-        let prepend_newline = kind.prepend_newline();
+        let initial_insertion = kind.initial_insertion();
         let suggestion_range = if let Some(symbol) = kind.symbol() {
             let outline = buffer
                 .update(&mut cx, |buffer, _| buffer.snapshot().outline(None))?
@@ -601,39 +597,61 @@ impl EditOperation {
                 EditSuggestion {
                     range: suggestion_range,
                     description: kind.description().map(ToString::to_string),
-                    prepend_newline,
+                    initial_insertion,
                 },
             ))
         })
     }
 }
 
-#[derive(Clone, Debug, PartialEq, Eq)]
+#[derive(Clone, Debug, PartialEq, Eq, Deserialize, JsonSchema)]
+#[serde(tag = "kind")]
 pub enum EditOperationKind {
+    /// Rewrite the specified symbol in its entirely based on the given description.
     Update {
+        /// A full path to the symbol to be rewritten from the provided list.
         symbol: String,
+        /// A brief one-line description of the change that should be applied.
         description: String,
     },
+    /// Create a new file with the given path based on the given description.
     Create {
+        /// A brief one-line description of the change that should be applied.
         description: String,
     },
+    /// Insert a new symbol based on the given description before the specified symbol.
     InsertSiblingBefore {
+        /// A full path to the symbol to be rewritten from the provided list.
         symbol: String,
+        /// A brief one-line description of the change that should be applied.
         description: String,
     },
+    /// Insert a new symbol based on the given description after the specified symbol.
    InsertSiblingAfter {
+        /// A full path to the symbol to be rewritten from the provided list.
         symbol: String,
+        /// A brief one-line description of the change that should be applied.
         description: String,
     },
+    /// Insert a new symbol as a child of the specified symbol at the start.
     PrependChild {
+        /// An optional full path to the symbol to be rewritten from the provided list.
+        /// If not provided, the edit should be applied at the top of the file.
         symbol: Option<String>,
+        /// A brief one-line description of the change that should be applied.
         description: String,
     },
+    /// Insert a new symbol as a child of the specified symbol at the end.
     AppendChild {
+        /// An optional full path to the symbol to be rewritten from the provided list.
+        /// If not provided, the edit should be applied at the top of the file.
         symbol: Option<String>,
+        /// A brief one-line description of the change that should be applied.
         description: String,
     },
+    /// Delete the specified symbol.
     Delete {
+        /// A full path to the symbol to be rewritten from the provided list.
         symbol: String,
     },
 }
@@ -663,13 +681,13 @@ impl EditOperationKind {
         }
     }
 
-    pub fn prepend_newline(&self) -> bool {
+    pub fn initial_insertion(&self) -> Option<InitialInsertion> {
         match self {
-            Self::PrependChild { .. }
-            | Self::AppendChild { .. }
-            | Self::InsertSiblingAfter { .. }
-            | Self::InsertSiblingBefore { .. } => true,
-            _ => false,
+            EditOperationKind::InsertSiblingBefore { .. } => Some(InitialInsertion::NewlineAfter),
+            EditOperationKind::InsertSiblingAfter { .. } => Some(InitialInsertion::NewlineBefore),
+            EditOperationKind::PrependChild { .. } => Some(InitialInsertion::NewlineAfter),
+            EditOperationKind::AppendChild { .. } => Some(InitialInsertion::NewlineBefore),
+            _ => None,
         }
     }
 }
@@ -1137,18 +1155,15 @@ impl Context {
                     .timer(Duration::from_millis(200))
                     .await;
 
-                if let Some(token_count) = cx.update(|cx| {
-                    LanguageModelCompletionProvider::read_global(cx).count_tokens(request, cx)
-                })? {
-                    let token_count = token_count.await?;
-                    this.update(&mut cx, |this, cx| {
-                        this.token_count = Some(token_count);
-                        cx.notify()
-                    })?;
-                }
-
-                anyhow::Ok(())
+                let token_count = cx
+                    .update(|cx| {
+                        LanguageModelCompletionProvider::read_global(cx).count_tokens(request, cx)
+                    })?
+                    .await?;
+                this.update(&mut cx, |this, cx| {
+                    this.token_count = Some(token_count);
+                    cx.notify()
+                })
             }
             .log_err()
         });
@@ -1304,7 +1319,24 @@ impl Context {
         &self,
         edit_step: &EditStep,
         cx: &mut ModelContext<Self>,
-    ) -> Task<Result<()>> {
+    ) -> Task<Option<()>> {
+        #[derive(Debug, Deserialize, JsonSchema)]
+        struct EditTool {
+            /// A sequence of operations to apply to the codebase.
+            /// When multiple operations are required for a step, be sure to include multiple operations in this list.
+            operations: Vec<EditOperation>,
+        }
+
+        impl LanguageModelTool for EditTool {
+            fn name() -> String {
+                "edit".into()
+            }
+
+            fn description() -> String {
+                "suggest edits to one or more locations in the codebase".into()
+            }
+        }
+
         let mut request = self.to_completion_request(cx);
         let edit_step_range = edit_step.source_range.clone();
         let step_text = self
@@ -1313,7 +1345,8 @@ impl Context {
             .text_for_range(edit_step_range.clone())
             .collect::<String>();
 
-        cx.spawn(|this, mut cx| async move {
+        cx.spawn(|this, mut cx| {
+            async move {
                 let prompt_store = cx.update(|cx| PromptStore::global(cx))?.await?;
                 let mut prompt = prompt_store.operations_prompt();
@@ -1324,13 +1357,13 @@ impl Context {
                     content: prompt,
                 });
 
-            let raw_output = cx
-                .update(|cx| {
-                    LanguageModelCompletionProvider::read_global(cx).complete(request, cx)
-                })?
-                .await?;
-
-            let operations = Self::parse_edit_operations(&raw_output);
+                let tool_use = cx
+                    .update(|cx| {
+                        LanguageModelCompletionProvider::read_global(cx)
+                            .use_tool::<EditTool>(request, cx)
+                    })?
+                    .await?;
+
                 this.update(&mut cx, |this, cx| {
                     let step_index = this
                         .edit_steps
@@ -1340,133 +1373,13 @@ impl Context {
                     })
                     .map_err(|_| anyhow!("edit step not found"))?;
                 if let Some(edit_step) = this.edit_steps.get_mut(step_index) {
-                    edit_step.operations = Some(EditStepOperations::Parsed {
-                        operations,
-                        raw_output,
-                    });
+                    edit_step.operations = Some(EditStepOperations::Ready(tool_use.operations));
                     cx.emit(ContextEvent::EditStepsChanged);
                 }
                 anyhow::Ok(())
-            })?
-        })
-    }
-
-    fn parse_edit_operations(xml: &str) -> Vec<EditOperation> {
-        let Some(start_ix) = xml.find("<operations>") else {
-            return Vec::new();
-        };
-        let Some(end_ix) = xml[start_ix..].find("</operations>") else {
-            return Vec::new();
-        };
-        let end_ix = end_ix + start_ix + "</operations>".len();
-        let doc = roxmltree::Document::parse(&xml[start_ix..end_ix]).log_err();
-        doc.map_or(Vec::new(), |doc| {
-            doc.root_element()
-                .children()
-                .map(|node| {
-                    let tag_name = node.tag_name().name();
-                    let path = node
-                        .attribute("path")
-                        .with_context(|| {
-                            format!("invalid node {node:?}, missing attribute 'path'")
-                        })?
-                        .to_string();
-                    let kind = match tag_name {
-                        "update" => EditOperationKind::Update {
-                            symbol: node
-                                .attribute("symbol")
-                                .with_context(|| {
-                                    format!("invalid node {node:?}, missing attribute 'symbol'")
-                                })?
-                                .to_string(),
-                            description: node
-                                .attribute("description")
-                                .with_context(|| {
-                                    format!(
-                                        "invalid node {node:?}, missing attribute 'description'"
-                                    )
-                                })?
-                                .to_string(),
-                        },
-                        "create" => EditOperationKind::Create {
-                            description: node
-                                .attribute("description")
-                                .with_context(|| {
-                                    format!(
-                                        "invalid node {node:?}, missing attribute 'description'"
-                                    )
-                                })?
-                                .to_string(),
-                        },
-                        "insert_sibling_after" => EditOperationKind::InsertSiblingAfter {
-                            symbol: node
-                                .attribute("symbol")
-                                .with_context(|| {
-                                    format!("invalid node {node:?}, missing attribute 'symbol'")
-                                })?
-                                .to_string(),
-                            description: node
-                                .attribute("description")
-                                .with_context(|| {
-                                    format!(
-                                        "invalid node {node:?}, missing attribute 'description'"
-                                    )
-                                })?
-                                .to_string(),
-                        },
-                        "insert_sibling_before" => EditOperationKind::InsertSiblingBefore {
-                            symbol: node
-                                .attribute("symbol")
-                                .with_context(|| {
-                                    format!("invalid node {node:?}, missing attribute 'symbol'")
-                                })?
-                                .to_string(),
-                            description: node
-                                .attribute("description")
-                                .with_context(|| {
-                                    format!(
-                                        "invalid node {node:?}, missing attribute 'description'"
-                                    )
-                                })?
-                                .to_string(),
-                        },
-                        "prepend_child" => EditOperationKind::PrependChild {
-                            symbol: node.attribute("symbol").map(String::from),
-                            description: node
-                                .attribute("description")
-                                .with_context(|| {
-                                    format!(
-                                        "invalid node {node:?}, missing attribute 'description'"
-                                    )
-                                })?
-                                .to_string(),
-                        },
-                        "append_child" => EditOperationKind::AppendChild {
-                            symbol: node.attribute("symbol").map(String::from),
-                            description: node
-                                .attribute("description")
-                                .with_context(|| {
-                                    format!(
-                                        "invalid node {node:?}, missing attribute 'description'"
-                                    )
-                                })?
-                                .to_string(),
-                        },
-                        "delete" => EditOperationKind::Delete {
-                            symbol: node
-                                .attribute("symbol")
-                                .with_context(|| {
-                                    format!("invalid node {node:?}, missing attribute 'symbol'")
-                                })?
-                                .to_string(),
-                        },
-                        _ => return Err(anyhow!("invalid node {node:?}")),
-                    };
-                    anyhow::Ok(EditOperation { path, kind })
-                })
-                .filter_map(|op| op.log_err())
-                .collect()
+                })?
+            }
+            .log_err()
         })
     }
@@ -3083,55 +2996,6 @@ mod tests {
         }
     }
 
-    #[test]
-    fn test_parse_edit_operations() {
-        let operations = indoc! {r#"
-            Here are the operations to make all fields of the Canvas struct private:
-
-            <operations>
-                <update path="font-kit/src/canvas.rs" symbol="pub struct Canvas pub pixels" description="Remove pub keyword from pixels field" />
-                <update path="font-kit/src/canvas.rs" symbol="pub struct Canvas pub size" description="Remove pub keyword from size field" />
-                <update path="font-kit/src/canvas.rs" symbol="pub struct Canvas pub stride" description="Remove pub keyword from stride field" />
-                <update path="font-kit/src/canvas.rs" symbol="pub struct Canvas pub format" description="Remove pub keyword from format field" />
-            </operations>
-        "#};
-
-        let parsed_operations = Context::parse_edit_operations(operations);
-        assert_eq!(
-            parsed_operations,
-            vec![
-                EditOperation {
-                    path: "font-kit/src/canvas.rs".to_string(),
-                    kind: EditOperationKind::Update {
-                        symbol: "pub struct Canvas pub pixels".to_string(),
-                        description: "Remove pub keyword from pixels field".to_string(),
-                    },
-                },
-                EditOperation {
-                    path: "font-kit/src/canvas.rs".to_string(),
-                    kind: EditOperationKind::Update {
-                        symbol: "pub struct Canvas pub size".to_string(),
-                        description: "Remove pub keyword from size field".to_string(),
-                    },
-                },
-                EditOperation {
-                    path: "font-kit/src/canvas.rs".to_string(),
-                    kind: EditOperationKind::Update {
-                        symbol: "pub struct Canvas pub stride".to_string(),
-                        description: "Remove pub keyword from stride field".to_string(),
-                    },
-                },
-                EditOperation {
-                    path: "font-kit/src/canvas.rs".to_string(),
-                    kind: EditOperationKind::Update {
-                        symbol: "pub struct Canvas pub format".to_string(),
-                        description: "Remove pub keyword from format field".to_string(),
-                    },
-                },
-            ]
-        );
-    }
-
     #[gpui::test]
     async fn test_serialization(cx: &mut TestAppContext) {
         let settings_store = cx.update(SettingsStore::test);
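
Because `EditOperation` flattens an internally tagged `EditOperationKind` (`#[serde(tag = "kind")]`), each operation the model returns is a single flat JSON object discriminated by a "kind" field. A self-contained sketch of that wire format, with the structs abridged from the diff above and the sample data taken from the deleted XML test:

use serde::Deserialize;

// Abridged from the diff: just enough of the types to show the wire format.
#[derive(Debug, Deserialize)]
struct EditOperation {
    path: String,
    #[serde(flatten)]
    kind: EditOperationKind,
}

#[derive(Debug, Deserialize)]
#[serde(tag = "kind")]
enum EditOperationKind {
    Update { symbol: String, description: String },
    Delete { symbol: String },
}

fn main() -> serde_json::Result<()> {
    // What previously arrived as <update path="…" symbol="…" description="…"/>
    // now arrives as structured tool input:
    let op: EditOperation = serde_json::from_str(
        r#"{
            "kind": "Update",
            "path": "font-kit/src/canvas.rs",
            "symbol": "pub struct Canvas pub pixels",
            "description": "Remove pub keyword from pixels field"
        }"#,
    )?;
    println!("{op:?}");
    Ok(())
}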

crates/assistant/src/inline_assistant.rs

@ -17,7 +17,7 @@ use editor::{
use fs::Fs; use fs::Fs;
use futures::{ use futures::{
channel::mpsc, channel::mpsc,
future::LocalBoxFuture, future::{BoxFuture, LocalBoxFuture},
stream::{self, BoxStream}, stream::{self, BoxStream},
SinkExt, Stream, StreamExt, SinkExt, Stream, StreamExt,
}; };
@ -36,7 +36,7 @@ use similar::TextDiff;
use smol::future::FutureExt; use smol::future::FutureExt;
use std::{ use std::{
cmp, cmp,
future::Future, future::{self, Future},
mem, mem,
ops::{Range, RangeInclusive}, ops::{Range, RangeInclusive},
pin::Pin, pin::Pin,
@ -46,7 +46,7 @@ use std::{
}; };
use theme::ThemeSettings; use theme::ThemeSettings;
use ui::{prelude::*, IconButtonShape, Tooltip}; use ui::{prelude::*, IconButtonShape, Tooltip};
use util::RangeExt; use util::{RangeExt, ResultExt};
use workspace::{notifications::NotificationId, Toast, Workspace}; use workspace::{notifications::NotificationId, Toast, Workspace};
pub fn init(fs: Arc<dyn Fs>, telemetry: Arc<Telemetry>, cx: &mut AppContext) { pub fn init(fs: Arc<dyn Fs>, telemetry: Arc<Telemetry>, cx: &mut AppContext) {
@ -187,7 +187,13 @@ impl InlineAssistant {
let [prompt_block_id, end_block_id] = let [prompt_block_id, end_block_id] =
self.insert_assist_blocks(editor, &range, &prompt_editor, cx); self.insert_assist_blocks(editor, &range, &prompt_editor, cx);
assists.push((assist_id, prompt_editor, prompt_block_id, end_block_id)); assists.push((
assist_id,
range,
prompt_editor,
prompt_block_id,
end_block_id,
));
} }
let editor_assists = self let editor_assists = self
@ -195,7 +201,7 @@ impl InlineAssistant {
.entry(editor.downgrade()) .entry(editor.downgrade())
.or_insert_with(|| EditorInlineAssists::new(&editor, cx)); .or_insert_with(|| EditorInlineAssists::new(&editor, cx));
let mut assist_group = InlineAssistGroup::new(); let mut assist_group = InlineAssistGroup::new();
for (assist_id, prompt_editor, prompt_block_id, end_block_id) in assists { for (assist_id, range, prompt_editor, prompt_block_id, end_block_id) in assists {
self.assists.insert( self.assists.insert(
assist_id, assist_id,
InlineAssist::new( InlineAssist::new(
@ -206,6 +212,7 @@ impl InlineAssistant {
&prompt_editor, &prompt_editor,
prompt_block_id, prompt_block_id,
end_block_id, end_block_id,
range,
prompt_editor.read(cx).codegen.clone(), prompt_editor.read(cx).codegen.clone(),
workspace.clone(), workspace.clone(),
cx, cx,
@ -227,7 +234,7 @@ impl InlineAssistant {
editor: &View<Editor>, editor: &View<Editor>,
mut range: Range<Anchor>, mut range: Range<Anchor>,
initial_prompt: String, initial_prompt: String,
initial_insertion: Option<String>, initial_insertion: Option<InitialInsertion>,
workspace: Option<WeakView<Workspace>>, workspace: Option<WeakView<Workspace>>,
assistant_panel: Option<&View<AssistantPanel>>, assistant_panel: Option<&View<AssistantPanel>>,
cx: &mut WindowContext, cx: &mut WindowContext,
@ -239,22 +246,30 @@ impl InlineAssistant {
let assist_id = self.next_assist_id.post_inc(); let assist_id = self.next_assist_id.post_inc();
let buffer = editor.read(cx).buffer().clone(); let buffer = editor.read(cx).buffer().clone();
let prepend_transaction_id = initial_insertion.and_then(|initial_insertion| { {
buffer.update(cx, |buffer, cx| { let snapshot = buffer.read(cx).read(cx);
buffer.start_transaction(cx);
buffer.edit([(range.start..range.start, initial_insertion)], None, cx);
buffer.end_transaction(cx)
})
});
range.start = range.start.bias_left(&buffer.read(cx).read(cx)); let mut point_range = range.to_point(&snapshot);
range.end = range.end.bias_right(&buffer.read(cx).read(cx)); if point_range.is_empty() {
point_range.start.column = 0;
point_range.end.column = 0;
} else {
point_range.start.column = 0;
if point_range.end.row > point_range.start.row && point_range.end.column == 0 {
point_range.end.row -= 1;
}
point_range.end.column = snapshot.line_len(MultiBufferRow(point_range.end.row));
}
range.start = snapshot.anchor_before(point_range.start);
range.end = snapshot.anchor_after(point_range.end);
}
let codegen = cx.new_model(|cx| { let codegen = cx.new_model(|cx| {
Codegen::new( Codegen::new(
editor.read(cx).buffer().clone(), editor.read(cx).buffer().clone(),
range.clone(), range.clone(),
prepend_transaction_id, initial_insertion,
self.telemetry.clone(), self.telemetry.clone(),
cx, cx,
) )
@ -295,6 +310,7 @@ impl InlineAssistant {
&prompt_editor, &prompt_editor,
prompt_block_id, prompt_block_id,
end_block_id, end_block_id,
range,
prompt_editor.read(cx).codegen.clone(), prompt_editor.read(cx).codegen.clone(),
workspace.clone(), workspace.clone(),
cx, cx,
@ -445,7 +461,7 @@ impl InlineAssistant {
let buffer = editor.buffer().read(cx).snapshot(cx); let buffer = editor.buffer().read(cx).snapshot(cx);
for assist_id in &editor_assists.assist_ids { for assist_id in &editor_assists.assist_ids {
let assist = &self.assists[assist_id]; let assist = &self.assists[assist_id];
let assist_range = assist.codegen.read(cx).range.to_offset(&buffer); let assist_range = assist.range.to_offset(&buffer);
if assist_range.contains(&selection.start) && assist_range.contains(&selection.end) if assist_range.contains(&selection.start) && assist_range.contains(&selection.end)
{ {
if matches!(assist.codegen.read(cx).status, CodegenStatus::Pending) { if matches!(assist.codegen.read(cx).status, CodegenStatus::Pending) {
@ -473,7 +489,7 @@ impl InlineAssistant {
let buffer = editor.buffer().read(cx).snapshot(cx); let buffer = editor.buffer().read(cx).snapshot(cx);
for assist_id in &editor_assists.assist_ids { for assist_id in &editor_assists.assist_ids {
let assist = &self.assists[assist_id]; let assist = &self.assists[assist_id];
let assist_range = assist.codegen.read(cx).range.to_offset(&buffer); let assist_range = assist.range.to_offset(&buffer);
if assist.decorations.is_some() if assist.decorations.is_some()
&& assist_range.contains(&selection.start) && assist_range.contains(&selection.start)
&& assist_range.contains(&selection.end) && assist_range.contains(&selection.end)
@ -551,7 +567,7 @@ impl InlineAssistant {
assist.codegen.read(cx).status, assist.codegen.read(cx).status,
CodegenStatus::Error(_) | CodegenStatus::Done CodegenStatus::Error(_) | CodegenStatus::Done
) { ) {
let assist_range = assist.codegen.read(cx).range.to_offset(&snapshot); let assist_range = assist.range.to_offset(&snapshot);
if edited_ranges if edited_ranges
.iter() .iter()
.any(|range| range.overlaps(&assist_range)) .any(|range| range.overlaps(&assist_range))
@ -721,7 +737,7 @@ impl InlineAssistant {
}); });
} }
let position = assist.codegen.read(cx).range.start; let position = assist.range.start;
editor.update(cx, |editor, cx| { editor.update(cx, |editor, cx| {
editor.change_selections(None, cx, |selections| { editor.change_selections(None, cx, |selections| {
selections.select_anchor_ranges([position..position]) selections.select_anchor_ranges([position..position])
@ -740,8 +756,7 @@ impl InlineAssistant {
.0 as f32; .0 as f32;
} else { } else {
let snapshot = editor.snapshot(cx); let snapshot = editor.snapshot(cx);
let codegen = assist.codegen.read(cx); let start_row = assist
let start_row = codegen
.range .range
.start .start
.to_display_point(&snapshot.display_snapshot) .to_display_point(&snapshot.display_snapshot)
@ -829,11 +844,7 @@ impl InlineAssistant {
return; return;
} }
let Some(user_prompt) = assist let Some(user_prompt) = assist.user_prompt(cx) else {
.decorations
.as_ref()
.map(|decorations| decorations.prompt_editor.read(cx).prompt(cx))
else {
return; return;
}; };
@ -843,139 +854,19 @@ impl InlineAssistant {
self.prompt_history.pop_front(); self.prompt_history.pop_front();
} }
let codegen = assist.codegen.clone(); let assistant_panel_context = assist.assistant_panel_context(cx);
let telemetry_id = LanguageModelCompletionProvider::read_global(cx)
.active_model()
.map(|m| m.telemetry_id())
.unwrap_or_default();
let chunks: LocalBoxFuture<Result<BoxStream<Result<String>>>> =
if user_prompt.trim().to_lowercase() == "delete" {
async { Ok(stream::empty().boxed()) }.boxed_local()
} else {
let request = self.request_for_inline_assist(assist_id, cx);
let mut cx = cx.to_async();
async move {
let request = request.await?;
let chunks = cx
.update(|cx| {
LanguageModelCompletionProvider::read_global(cx)
.stream_completion(request, cx)
})?
.await?;
Ok(chunks.boxed())
}
.boxed_local()
};
codegen.update(cx, |codegen, cx| {
codegen.start(telemetry_id, chunks, cx);
});
}
fn request_for_inline_assist( assist
&self, .codegen
assist_id: InlineAssistId, .update(cx, |codegen, cx| {
cx: &mut WindowContext, codegen.start(
) -> Task<Result<LanguageModelRequest>> { assist.range.clone(),
cx.spawn(|mut cx| async move { user_prompt,
let (user_prompt, context_request, project_name, buffer, range) = assistant_panel_context,
cx.read_global(|this: &InlineAssistant, cx: &WindowContext| { cx,
let assist = this.assists.get(&assist_id).context("invalid assist")?;
let decorations = assist.decorations.as_ref().context("invalid assist")?;
let editor = assist.editor.upgrade().context("invalid assist")?;
let user_prompt = decorations.prompt_editor.read(cx).prompt(cx);
let context_request = if assist.include_context {
assist.workspace.as_ref().and_then(|workspace| {
let workspace = workspace.upgrade()?.read(cx);
let assistant_panel = workspace.panel::<AssistantPanel>(cx)?;
Some(
assistant_panel
.read(cx)
.active_context(cx)?
.read(cx)
.to_completion_request(cx),
) )
}) })
} else { .log_err();
None
};
let project_name = assist.workspace.as_ref().and_then(|workspace| {
let workspace = workspace.upgrade()?;
Some(
workspace
.read(cx)
.project()
.read(cx)
.worktree_root_names(cx)
.collect::<Vec<&str>>()
.join("/"),
)
});
let buffer = editor.read(cx).buffer().read(cx).snapshot(cx);
let range = assist.codegen.read(cx).range.clone();
anyhow::Ok((user_prompt, context_request, project_name, buffer, range))
})??;
let language = buffer.language_at(range.start);
let language_name = if let Some(language) = language.as_ref() {
if Arc::ptr_eq(language, &language::PLAIN_TEXT) {
None
} else {
Some(language.name())
}
} else {
None
};
// Higher Temperature increases the randomness of model outputs.
// If Markdown or No Language is Known, increase the randomness for more creative output
// If Code, decrease temperature to get more deterministic outputs
let temperature = if let Some(language) = language_name.clone() {
if language.as_ref() == "Markdown" {
1.0
} else {
0.5
}
} else {
1.0
};
let prompt = cx
.background_executor()
.spawn(async move {
let language_name = language_name.as_deref();
let start = buffer.point_to_buffer_offset(range.start);
let end = buffer.point_to_buffer_offset(range.end);
let (buffer, range) = if let Some((start, end)) = start.zip(end) {
let (start_buffer, start_buffer_offset) = start;
let (end_buffer, end_buffer_offset) = end;
if start_buffer.remote_id() == end_buffer.remote_id() {
(start_buffer.clone(), start_buffer_offset..end_buffer_offset)
} else {
return Err(anyhow!("invalid transformation range"));
}
} else {
return Err(anyhow!("invalid transformation range"));
};
generate_content_prompt(user_prompt, language_name, buffer, range, project_name)
})
.await?;
let mut messages = Vec::new();
if let Some(context_request) = context_request {
messages = context_request.messages;
}
messages.push(LanguageModelRequestMessage {
role: Role::User,
content: prompt,
});
Ok(LanguageModelRequest {
messages,
stop: vec!["|END|>".to_string()],
temperature,
})
})
} }
pub fn stop_assist(&mut self, assist_id: InlineAssistId, cx: &mut WindowContext) { pub fn stop_assist(&mut self, assist_id: InlineAssistId, cx: &mut WindowContext) {
@ -1006,12 +897,11 @@ impl InlineAssistant {
let codegen = assist.codegen.read(cx); let codegen = assist.codegen.read(cx);
foreground_ranges.extend(codegen.last_equal_ranges().iter().cloned()); foreground_ranges.extend(codegen.last_equal_ranges().iter().cloned());
if codegen.edit_position != codegen.range.end { gutter_pending_ranges
gutter_pending_ranges.push(codegen.edit_position..codegen.range.end); .push(codegen.edit_position.unwrap_or(assist.range.start)..assist.range.end);
}
if codegen.range.start != codegen.edit_position { if let Some(edit_position) = codegen.edit_position {
gutter_transformed_ranges.push(codegen.range.start..codegen.edit_position); gutter_transformed_ranges.push(assist.range.start..edit_position);
} }
if assist.decorations.is_some() { if assist.decorations.is_some() {
@ -1268,6 +1158,12 @@ fn build_assist_editor_renderer(editor: &View<PromptEditor>) -> RenderBlock {
}) })
} }
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
pub enum InitialInsertion {
NewlineBefore,
NewlineAfter,
}
#[derive(Copy, Clone, Default, Debug, PartialEq, Eq, Hash)] #[derive(Copy, Clone, Default, Debug, PartialEq, Eq, Hash)]
pub struct InlineAssistId(usize); pub struct InlineAssistId(usize);
@ -1629,24 +1525,20 @@ impl PromptEditor {
let assist_id = self.id; let assist_id = self.id;
self.pending_token_count = cx.spawn(|this, mut cx| async move { self.pending_token_count = cx.spawn(|this, mut cx| async move {
cx.background_executor().timer(Duration::from_secs(1)).await; cx.background_executor().timer(Duration::from_secs(1)).await;
let request = cx let token_count = cx
.update_global(|inline_assistant: &mut InlineAssistant, cx| { .update_global(|inline_assistant: &mut InlineAssistant, cx| {
inline_assistant.request_for_inline_assist(assist_id, cx) let assist = inline_assistant
})? .assists
.get(&assist_id)
.context("assist not found")?;
anyhow::Ok(assist.count_tokens(cx))
})??
.await?; .await?;
if let Some(token_count) = cx.update(|cx| {
LanguageModelCompletionProvider::read_global(cx).count_tokens(request, cx)
})? {
let token_count = token_count.await?;
this.update(&mut cx, |this, cx| { this.update(&mut cx, |this, cx| {
this.token_count = Some(token_count); this.token_count = Some(token_count);
cx.notify(); cx.notify();
}) })
} else {
Ok(())
}
}) })
} }
@ -1855,6 +1747,7 @@ impl PromptEditor {
struct InlineAssist { struct InlineAssist {
group_id: InlineAssistGroupId, group_id: InlineAssistGroupId,
range: Range<Anchor>,
editor: WeakView<Editor>, editor: WeakView<Editor>,
decorations: Option<InlineAssistDecorations>, decorations: Option<InlineAssistDecorations>,
codegen: Model<Codegen>, codegen: Model<Codegen>,
@ -1873,6 +1766,7 @@ impl InlineAssist {
prompt_editor: &View<PromptEditor>, prompt_editor: &View<PromptEditor>,
prompt_block_id: CustomBlockId, prompt_block_id: CustomBlockId,
end_block_id: CustomBlockId, end_block_id: CustomBlockId,
range: Range<Anchor>,
codegen: Model<Codegen>, codegen: Model<Codegen>,
workspace: Option<WeakView<Workspace>>, workspace: Option<WeakView<Workspace>>,
cx: &mut WindowContext, cx: &mut WindowContext,
@ -1888,6 +1782,7 @@ impl InlineAssist {
removed_line_block_ids: HashSet::default(), removed_line_block_ids: HashSet::default(),
end_block_id, end_block_id,
}), }),
range,
codegen: codegen.clone(), codegen: codegen.clone(),
workspace: workspace.clone(), workspace: workspace.clone(),
_subscriptions: vec![ _subscriptions: vec![
@ -1963,6 +1858,41 @@ impl InlineAssist {
], ],
} }
} }
fn user_prompt(&self, cx: &AppContext) -> Option<String> {
let decorations = self.decorations.as_ref()?;
Some(decorations.prompt_editor.read(cx).prompt(cx))
}
fn assistant_panel_context(&self, cx: &WindowContext) -> Option<LanguageModelRequest> {
if self.include_context {
let workspace = self.workspace.as_ref()?;
let workspace = workspace.upgrade()?.read(cx);
let assistant_panel = workspace.panel::<AssistantPanel>(cx)?;
Some(
assistant_panel
.read(cx)
.active_context(cx)?
.read(cx)
.to_completion_request(cx),
)
} else {
None
}
}
pub fn count_tokens(&self, cx: &WindowContext) -> BoxFuture<'static, Result<usize>> {
let Some(user_prompt) = self.user_prompt(cx) else {
return future::ready(Err(anyhow!("no user prompt"))).boxed();
};
let assistant_panel_context = self.assistant_panel_context(cx);
self.codegen.read(cx).count_tokens(
self.range.clone(),
user_prompt,
assistant_panel_context,
cx,
)
}
} }
struct InlineAssistDecorations { struct InlineAssistDecorations {
@ -1982,16 +1912,15 @@ pub struct Codegen {
buffer: Model<MultiBuffer>, buffer: Model<MultiBuffer>,
old_buffer: Model<Buffer>, old_buffer: Model<Buffer>,
snapshot: MultiBufferSnapshot, snapshot: MultiBufferSnapshot,
range: Range<Anchor>, edit_position: Option<Anchor>,
edit_position: Anchor,
last_equal_ranges: Vec<Range<Anchor>>, last_equal_ranges: Vec<Range<Anchor>>,
prepend_transaction_id: Option<TransactionId>, transaction_id: Option<TransactionId>,
generation_transaction_id: Option<TransactionId>,
status: CodegenStatus, status: CodegenStatus,
generation: Task<()>, generation: Task<()>,
diff: Diff, diff: Diff,
telemetry: Option<Arc<Telemetry>>, telemetry: Option<Arc<Telemetry>>,
_subscription: gpui::Subscription, _subscription: gpui::Subscription,
initial_insertion: Option<InitialInsertion>,
} }
enum CodegenStatus { enum CodegenStatus {
@ -2015,7 +1944,7 @@ impl Codegen {
pub fn new( pub fn new(
buffer: Model<MultiBuffer>, buffer: Model<MultiBuffer>,
range: Range<Anchor>, range: Range<Anchor>,
prepend_transaction_id: Option<TransactionId>, initial_insertion: Option<InitialInsertion>,
telemetry: Option<Arc<Telemetry>>, telemetry: Option<Arc<Telemetry>>,
cx: &mut ModelContext<Self>, cx: &mut ModelContext<Self>,
) -> Self { ) -> Self {
@ -2044,17 +1973,16 @@ impl Codegen {
Self { Self {
buffer: buffer.clone(), buffer: buffer.clone(),
old_buffer, old_buffer,
edit_position: range.start, edit_position: None,
range,
snapshot, snapshot,
last_equal_ranges: Default::default(), last_equal_ranges: Default::default(),
prepend_transaction_id, transaction_id: None,
generation_transaction_id: None,
status: CodegenStatus::Idle, status: CodegenStatus::Idle,
generation: Task::ready(()), generation: Task::ready(()),
diff: Diff::default(), diff: Diff::default(),
telemetry, telemetry,
_subscription: cx.subscribe(&buffer, Self::handle_buffer_event), _subscription: cx.subscribe(&buffer, Self::handle_buffer_event),
initial_insertion,
} }
} }
@ -2065,13 +1993,8 @@ impl Codegen {
cx: &mut ModelContext<Self>, cx: &mut ModelContext<Self>,
) { ) {
if let multi_buffer::Event::TransactionUndone { transaction_id } = event { if let multi_buffer::Event::TransactionUndone { transaction_id } = event {
if self.generation_transaction_id == Some(*transaction_id) { if self.transaction_id == Some(*transaction_id) {
self.generation_transaction_id = None; self.transaction_id = None;
self.generation = Task::ready(());
cx.emit(CodegenEvent::Undone);
} else if self.prepend_transaction_id == Some(*transaction_id) {
self.prepend_transaction_id = None;
self.generation_transaction_id = None;
self.generation = Task::ready(()); self.generation = Task::ready(());
cx.emit(CodegenEvent::Undone); cx.emit(CodegenEvent::Undone);
} }
@ -2082,19 +2005,152 @@ impl Codegen {
&self.last_equal_ranges &self.last_equal_ranges
} }
pub fn count_tokens(
&self,
edit_range: Range<Anchor>,
user_prompt: String,
assistant_panel_context: Option<LanguageModelRequest>,
cx: &AppContext,
) -> BoxFuture<'static, Result<usize>> {
let request = self.build_request(user_prompt, assistant_panel_context, edit_range, cx);
LanguageModelCompletionProvider::read_global(cx).count_tokens(request, cx)
}
pub fn start( pub fn start(
&mut self, &mut self,
telemetry_id: String, mut edit_range: Range<Anchor>,
user_prompt: String,
assistant_panel_context: Option<LanguageModelRequest>,
cx: &mut ModelContext<Self>,
) -> Result<()> {
self.undo(cx);
// Handle initial insertion
self.transaction_id = if let Some(initial_insertion) = self.initial_insertion {
self.buffer.update(cx, |buffer, cx| {
buffer.start_transaction(cx);
let offset = edit_range.start.to_offset(&self.snapshot);
let edit_position;
match initial_insertion {
InitialInsertion::NewlineBefore => {
buffer.edit([(offset..offset, "\n\n")], None, cx);
self.snapshot = buffer.snapshot(cx);
edit_position = self.snapshot.anchor_after(offset + 1);
}
InitialInsertion::NewlineAfter => {
buffer.edit([(offset..offset, "\n")], None, cx);
self.snapshot = buffer.snapshot(cx);
edit_position = self.snapshot.anchor_after(offset);
}
}
self.edit_position = Some(edit_position);
edit_range = edit_position.bias_left(&self.snapshot)..edit_position;
buffer.end_transaction(cx)
})
} else {
self.edit_position = Some(edit_range.start.bias_right(&self.snapshot));
None
};
let model_telemetry_id = LanguageModelCompletionProvider::read_global(cx)
.active_model_telemetry_id()
.context("no active model")?;
let chunks: LocalBoxFuture<Result<BoxStream<Result<String>>>> = if user_prompt
.trim()
.to_lowercase()
== "delete"
{
async { Ok(stream::empty().boxed()) }.boxed_local()
} else {
let request =
self.build_request(user_prompt, assistant_panel_context, edit_range.clone(), cx);
let chunks =
LanguageModelCompletionProvider::read_global(cx).stream_completion(request, cx);
async move { Ok(chunks.await?.boxed()) }.boxed_local()
};
self.handle_stream(model_telemetry_id, edit_range, chunks, cx);
Ok(())
}
fn build_request(
&self,
user_prompt: String,
assistant_panel_context: Option<LanguageModelRequest>,
edit_range: Range<Anchor>,
cx: &AppContext,
) -> LanguageModelRequest {
let buffer = self.buffer.read(cx).snapshot(cx);
let language = buffer.language_at(edit_range.start);
let language_name = if let Some(language) = language.as_ref() {
if Arc::ptr_eq(language, &language::PLAIN_TEXT) {
None
} else {
Some(language.name())
}
} else {
None
};
// Higher Temperature increases the randomness of model outputs.
// If Markdown or No Language is Known, increase the randomness for more creative output
// If Code, decrease temperature to get more deterministic outputs
let temperature = if let Some(language) = language_name.clone() {
if language.as_ref() == "Markdown" {
1.0
} else {
0.5
}
} else {
1.0
};
let language_name = language_name.as_deref();
let start = buffer.point_to_buffer_offset(edit_range.start);
let end = buffer.point_to_buffer_offset(edit_range.end);
let (buffer, range) = if let Some((start, end)) = start.zip(end) {
let (start_buffer, start_buffer_offset) = start;
let (end_buffer, end_buffer_offset) = end;
if start_buffer.remote_id() == end_buffer.remote_id() {
(start_buffer.clone(), start_buffer_offset..end_buffer_offset)
} else {
panic!("invalid transformation range");
}
} else {
panic!("invalid transformation range");
};
let prompt = generate_content_prompt(user_prompt, language_name, buffer, range);
let mut messages = Vec::new();
if let Some(context_request) = assistant_panel_context {
messages = context_request.messages;
}
messages.push(LanguageModelRequestMessage {
role: Role::User,
content: prompt,
});
LanguageModelRequest {
messages,
stop: vec!["|END|>".to_string()],
temperature,
}
}
pub fn handle_stream(
&mut self,
model_telemetry_id: String,
edit_range: Range<Anchor>,
stream: impl 'static + Future<Output = Result<BoxStream<'static, Result<String>>>>,
cx: &mut ModelContext<Self>,
) {
-let range = self.range.clone();
let snapshot = self.snapshot.clone();
let selected_text = snapshot
-.text_for_range(range.start..range.end)
+.text_for_range(edit_range.start..edit_range.end)
.collect::<Rope>();
-let selection_start = range.start.to_point(&snapshot);
+let selection_start = edit_range.start.to_point(&snapshot);
// Start with the indentation of the first line in the selection
let mut suggested_line_indent = snapshot
@@ -2105,7 +2161,7 @@ impl Codegen {
// If the first line in the selection does not have indentation, check the following lines
if suggested_line_indent.len == 0 && suggested_line_indent.kind == IndentKind::Space {
-for row in selection_start.row..=range.end.to_point(&snapshot).row {
+for row in selection_start.row..=edit_range.end.to_point(&snapshot).row {
let line_indent = snapshot.indent_size_for_line(MultiBufferRow(row));
// Prefer tabs if a line in the selection uses tabs as indentation
if line_indent.kind == IndentKind::Tab {
@@ -2116,19 +2172,13 @@ impl Codegen {
}
let telemetry = self.telemetry.clone();
-self.edit_position = range.start;
self.diff = Diff::default();
self.status = CodegenStatus::Pending;
-if let Some(transaction_id) = self.generation_transaction_id.take() {
-self.buffer
-.update(cx, |buffer, cx| buffer.undo_transaction(transaction_id, cx));
-}
+let mut edit_start = edit_range.start.to_offset(&snapshot);
self.generation = cx.spawn(|this, mut cx| {
async move {
let chunks = stream.await;
let generate = async {
-let mut edit_start = range.start.to_offset(&snapshot);
let (mut hunks_tx, mut hunks_rx) = mpsc::channel(1);
let diff: Task<anyhow::Result<()>> =
cx.background_executor().spawn(async move {
@@ -2218,7 +2268,7 @@ impl Codegen {
telemetry.report_assistant_event(
None,
telemetry_events::AssistantKind::Inline,
-telemetry_id,
+model_telemetry_id,
response_latency,
error_message,
);
@@ -2262,13 +2312,13 @@ impl Codegen {
None,
cx,
);
-this.edit_position = snapshot.anchor_after(edit_start);
+this.edit_position = Some(snapshot.anchor_after(edit_start));
buffer.end_transaction(cx)
});
if let Some(transaction) = transaction {
-if let Some(first_transaction) = this.generation_transaction_id {
+if let Some(first_transaction) = this.transaction_id {
// Group all assistant edits into the first transaction.
this.buffer.update(cx, |buffer, cx| {
buffer.merge_transactions(
@@ -2278,14 +2328,14 @@ impl Codegen {
)
});
} else {
-this.generation_transaction_id = Some(transaction);
+this.transaction_id = Some(transaction);
this.buffer.update(cx, |buffer, cx| {
buffer.finalize_last_transaction(cx)
});
}
}
-this.update_diff(cx);
+this.update_diff(edit_range.clone(), cx);
cx.notify();
})?;
}
@@ -2321,27 +2371,22 @@ impl Codegen {
}
pub fn undo(&mut self, cx: &mut ModelContext<Self>) {
-if let Some(transaction_id) = self.prepend_transaction_id.take() {
-self.buffer
-.update(cx, |buffer, cx| buffer.undo_transaction(transaction_id, cx));
-}
-if let Some(transaction_id) = self.generation_transaction_id.take() {
+if let Some(transaction_id) = self.transaction_id.take() {
self.buffer
.update(cx, |buffer, cx| buffer.undo_transaction(transaction_id, cx));
}
}
-fn update_diff(&mut self, cx: &mut ModelContext<Self>) {
+fn update_diff(&mut self, edit_range: Range<Anchor>, cx: &mut ModelContext<Self>) {
if self.diff.task.is_some() {
self.diff.should_update = true;
} else {
self.diff.should_update = false;
let old_snapshot = self.snapshot.clone();
-let old_range = self.range.to_point(&old_snapshot);
+let old_range = edit_range.to_point(&old_snapshot);
let new_snapshot = self.buffer.read(cx).snapshot(cx);
-let new_range = self.range.to_point(&new_snapshot);
+let new_range = edit_range.to_point(&new_snapshot);
self.diff.task = Some(cx.spawn(|this, mut cx| async move {
let (deleted_row_ranges, inserted_row_ranges) = cx
@@ -2422,7 +2467,7 @@ impl Codegen {
this.diff.inserted_row_ranges = inserted_row_ranges;
this.diff.task = None;
if this.diff.should_update {
-this.update_diff(cx);
+this.update_diff(edit_range, cx);
}
cx.notify();
})
@@ -2629,12 +2674,14 @@ mod tests {
let snapshot = buffer.snapshot(cx);
snapshot.anchor_before(Point::new(1, 0))..snapshot.anchor_after(Point::new(4, 5))
});
-let codegen = cx.new_model(|cx| Codegen::new(buffer.clone(), range, None, None, cx));
+let codegen =
+cx.new_model(|cx| Codegen::new(buffer.clone(), range.clone(), None, None, cx));
let (chunks_tx, chunks_rx) = mpsc::unbounded();
codegen.update(cx, |codegen, cx| {
-codegen.start(
+codegen.handle_stream(
String::new(),
+range,
future::ready(Ok(chunks_rx.map(|chunk| Ok(chunk)).boxed())),
cx,
)
@@ -2690,12 +2737,14 @@ mod tests {
let snapshot = buffer.snapshot(cx);
snapshot.anchor_before(Point::new(1, 6))..snapshot.anchor_after(Point::new(1, 6))
});
-let codegen = cx.new_model(|cx| Codegen::new(buffer.clone(), range, None, None, cx));
+let codegen =
+cx.new_model(|cx| Codegen::new(buffer.clone(), range.clone(), None, None, cx));
let (chunks_tx, chunks_rx) = mpsc::unbounded();
codegen.update(cx, |codegen, cx| {
-codegen.start(
+codegen.handle_stream(
String::new(),
+range.clone(),
future::ready(Ok(chunks_rx.map(|chunk| Ok(chunk)).boxed())),
cx,
)
@@ -2755,12 +2804,14 @@ mod tests {
let snapshot = buffer.snapshot(cx);
snapshot.anchor_before(Point::new(1, 2))..snapshot.anchor_after(Point::new(1, 2))
});
-let codegen = cx.new_model(|cx| Codegen::new(buffer.clone(), range, None, None, cx));
+let codegen =
+cx.new_model(|cx| Codegen::new(buffer.clone(), range.clone(), None, None, cx));
let (chunks_tx, chunks_rx) = mpsc::unbounded();
codegen.update(cx, |codegen, cx| {
-codegen.start(
+codegen.handle_stream(
String::new(),
+range.clone(),
future::ready(Ok(chunks_rx.map(|chunk| Ok(chunk)).boxed())),
cx,
)
@@ -2819,12 +2870,14 @@ mod tests {
let snapshot = buffer.snapshot(cx);
snapshot.anchor_before(Point::new(0, 0))..snapshot.anchor_after(Point::new(4, 2))
});
-let codegen = cx.new_model(|cx| Codegen::new(buffer.clone(), range, None, None, cx));
+let codegen =
+cx.new_model(|cx| Codegen::new(buffer.clone(), range.clone(), None, None, cx));
let (chunks_tx, chunks_rx) = mpsc::unbounded();
codegen.update(cx, |codegen, cx| {
-codegen.start(
+codegen.handle_stream(
String::new(),
+range.clone(),
future::ready(Ok(chunks_rx.map(|chunk| Ok(chunk)).boxed())),
cx,
)

View file

@@ -734,7 +734,8 @@ impl PromptLibrary {
const DEBOUNCE_TIMEOUT: Duration = Duration::from_secs(1);
cx.background_executor().timer(DEBOUNCE_TIMEOUT).await;
-if let Some(token_count) = cx.update(|cx| {
+let token_count = cx
+.update(|cx| {
LanguageModelCompletionProvider::read_global(cx).count_tokens(
LanguageModelRequest {
messages: vec![LanguageModelRequestMessage {
@@ -746,17 +747,14 @@ impl PromptLibrary {
},
cx,
)
-})? {
-let token_count = token_count.await?;
+})?
+.await?;
this.update(&mut cx, |this, cx| {
let prompt_editor = this.prompt_editors.get_mut(&prompt_id).unwrap();
prompt_editor.token_count = Some(token_count);
cx.notify();
})
-} else {
-Ok(())
-}
}
.log_err()
});

View file

@@ -6,8 +6,7 @@ pub fn generate_content_prompt(
language_name: Option<&str>,
buffer: BufferSnapshot,
range: Range<usize>,
-_project_name: Option<String>,
-) -> anyhow::Result<String> {
+) -> String {
let mut prompt = String::new();
let content_type = match language_name {
@@ -15,14 +14,16 @@ pub fn generate_content_prompt(
writeln!(
prompt,
"Here's a file of text that I'm going to ask you to make an edit to."
-)?;
+)
+.unwrap();
"text"
}
Some(language_name) => {
writeln!(
prompt,
"Here's a file of {language_name} that I'm going to ask you to make an edit to."
-)?;
+)
+.unwrap();
"code"
}
};
@@ -70,7 +71,7 @@ pub fn generate_content_prompt(
write!(prompt, "</document>\n\n").unwrap();
if is_truncated {
-writeln!(prompt, "The context around the relevant section has been truncated (possibly in the middle of a line) for brevity.\n")?;
+writeln!(prompt, "The context around the relevant section has been truncated (possibly in the middle of a line) for brevity.\n").unwrap();
}
if range.is_empty() {
@@ -107,7 +108,7 @@ pub fn generate_content_prompt(
prompt.push_str("\n\nImmediately start with the following format with no remarks:\n\n```\n{{REWRITTEN_CODE}}\n```");
}
-Ok(prompt)
+prompt
}
pub fn generate_terminal_assistant_prompt(

View file

@@ -707,18 +707,15 @@ impl PromptEditor {
inline_assistant.request_for_inline_assist(assist_id, cx)
})??;
-if let Some(token_count) = cx.update(|cx| {
+let token_count = cx
+.update(|cx| {
LanguageModelCompletionProvider::read_global(cx).count_tokens(request, cx)
-})? {
-let token_count = token_count.await?;
+})?
+.await?;
this.update(&mut cx, |this, cx| {
this.token_count = Some(token_count);
cx.notify();
})
-} else {
-Ok(())
-}
})
}

View file

@@ -10,7 +10,7 @@ use crate::{
ServerId, UpdatedChannelMessage, User, UserId,
},
executor::Executor,
-AppState, Error, RateLimit, RateLimiter, Result,
+AppState, Config, Error, RateLimit, RateLimiter, Result,
};
use anyhow::{anyhow, bail, Context as _};
use async_tungstenite::tungstenite::{
@@ -605,17 +605,39 @@ impl Server {
))
.add_message_handler(broadcast_project_message_from_host::<proto::AdvertiseContexts>)
.add_message_handler(update_context)
+.add_request_handler({
+let app_state = app_state.clone();
+move |request, response, session| {
+let app_state = app_state.clone();
+async move {
+complete_with_language_model(request, response, session, &app_state.config)
+.await
+}
+}
+})
.add_streaming_request_handler({
let app_state = app_state.clone();
move |request, response, session| {
-complete_with_language_model(
+let app_state = app_state.clone();
+async move {
+stream_complete_with_language_model(
request,
response,
session,
-app_state.config.openai_api_key.clone(),
-app_state.config.google_ai_api_key.clone(),
-app_state.config.anthropic_api_key.clone(),
+&app_state.config,
)
+.await
+}
+}
+})
+.add_request_handler({
+let app_state = app_state.clone();
+move |request, response, session| {
+let app_state = app_state.clone();
+async move {
+count_language_model_tokens(request, response, session, &app_state.config)
+.await
+}
}
})
.add_request_handler({
@@ -4503,116 +4525,172 @@ impl RateLimit for CompleteWithLanguageModelRateLimit {
}
async fn complete_with_language_model(
-query: proto::QueryLanguageModel,
-response: StreamingResponse<proto::QueryLanguageModel>,
+request: proto::CompleteWithLanguageModel,
+response: Response<proto::CompleteWithLanguageModel>,
session: Session,
-open_ai_api_key: Option<Arc<str>>,
-google_ai_api_key: Option<Arc<str>>,
-anthropic_api_key: Option<Arc<str>>,
+config: &Config,
) -> Result<()> {
let Some(session) = session.for_user() else {
return Err(anyhow!("user not found"))?;
};
authorize_access_to_language_models(&session).await?;
-match proto::LanguageModelRequestKind::from_i32(query.kind) {
-Some(proto::LanguageModelRequestKind::Complete) => {
session
.rate_limiter
.check::<CompleteWithLanguageModelRateLimit>(session.user_id())
.await?;
+let result = match proto::LanguageModelProvider::from_i32(request.provider) {
+Some(proto::LanguageModelProvider::Anthropic) => {
+let api_key = config
+.anthropic_api_key
+.as_ref()
+.context("no Anthropic AI API key configured on the server")?;
+anthropic::complete(
+session.http_client.as_ref(),
+anthropic::ANTHROPIC_API_URL,
+api_key,
+serde_json::from_str(&request.request)?,
+)
+.await?
}
-Some(proto::LanguageModelRequestKind::CountTokens) => {
-session
-.rate_limiter
-.check::<CountTokensWithLanguageModelRateLimit>(session.user_id())
-.await?;
-}
-None => Err(anyhow!("unknown request kind"))?,
+_ => return Err(anyhow!("unsupported provider"))?,
+};
+response.send(proto::CompleteWithLanguageModelResponse {
+completion: serde_json::to_string(&result)?,
+})?;
+Ok(())
}
-match proto::LanguageModelProvider::from_i32(query.provider) {
+async fn stream_complete_with_language_model(
+request: proto::StreamCompleteWithLanguageModel,
+response: StreamingResponse<proto::StreamCompleteWithLanguageModel>,
+session: Session,
+config: &Config,
+) -> Result<()> {
+let Some(session) = session.for_user() else {
+return Err(anyhow!("user not found"))?;
+};
+authorize_access_to_language_models(&session).await?;
+session
+.rate_limiter
+.check::<CompleteWithLanguageModelRateLimit>(session.user_id())
+.await?;
+match proto::LanguageModelProvider::from_i32(request.provider) {
Some(proto::LanguageModelProvider::Anthropic) => {
-let api_key =
-anthropic_api_key.context("no Anthropic AI API key configured on the server")?;
+let api_key = config
+.anthropic_api_key
+.as_ref()
+.context("no Anthropic AI API key configured on the server")?;
let mut chunks = anthropic::stream_completion(
session.http_client.as_ref(),
anthropic::ANTHROPIC_API_URL,
-&api_key,
-serde_json::from_str(&query.request)?,
+api_key,
+serde_json::from_str(&request.request)?,
None,
)
.await?;
-while let Some(chunk) = chunks.next().await {
-let chunk = chunk?;
-response.send(proto::QueryLanguageModelResponse {
-response: serde_json::to_string(&chunk)?,
+while let Some(event) = chunks.next().await {
+let chunk = event?;
+response.send(proto::StreamCompleteWithLanguageModelResponse {
+event: serde_json::to_string(&chunk)?,
})?;
}
}
Some(proto::LanguageModelProvider::OpenAi) => {
-let api_key = open_ai_api_key.context("no OpenAI API key configured on the server")?;
-let mut chunks = open_ai::stream_completion(
+let api_key = config
+.openai_api_key
+.as_ref()
+.context("no OpenAI API key configured on the server")?;
+let mut events = open_ai::stream_completion(
session.http_client.as_ref(),
open_ai::OPEN_AI_API_URL,
-&api_key,
-serde_json::from_str(&query.request)?,
+api_key,
+serde_json::from_str(&request.request)?,
None,
)
.await?;
-while let Some(chunk) = chunks.next().await {
-let chunk = chunk?;
-response.send(proto::QueryLanguageModelResponse {
-response: serde_json::to_string(&chunk)?,
+while let Some(event) = events.next().await {
+let event = event?;
+response.send(proto::StreamCompleteWithLanguageModelResponse {
+event: serde_json::to_string(&event)?,
})?;
}
}
Some(proto::LanguageModelProvider::Google) => {
-let api_key =
-google_ai_api_key.context("no Google AI API key configured on the server")?;
-match proto::LanguageModelRequestKind::from_i32(query.kind) {
-Some(proto::LanguageModelRequestKind::Complete) => {
-let mut chunks = google_ai::stream_generate_content(
+let api_key = config
+.google_ai_api_key
+.as_ref()
+.context("no Google AI API key configured on the server")?;
+let mut events = google_ai::stream_generate_content(
session.http_client.as_ref(),
google_ai::API_URL,
-&api_key,
-serde_json::from_str(&query.request)?,
+api_key,
+serde_json::from_str(&request.request)?,
)
.await?;
-while let Some(chunk) = chunks.next().await {
-let chunk = chunk?;
-response.send(proto::QueryLanguageModelResponse {
-response: serde_json::to_string(&chunk)?,
+while let Some(event) = events.next().await {
+let event = event?;
+response.send(proto::StreamCompleteWithLanguageModelResponse {
+event: serde_json::to_string(&event)?,
})?;
}
}
-Some(proto::LanguageModelRequestKind::CountTokens) => {
-let tokens_response = google_ai::count_tokens(
-session.http_client.as_ref(),
-google_ai::API_URL,
-&api_key,
-serde_json::from_str(&query.request)?,
-)
-.await?;
-response.send(proto::QueryLanguageModelResponse {
-response: serde_json::to_string(&tokens_response)?,
-})?;
-}
-None => Err(anyhow!("unknown request kind"))?,
-}
-}
None => return Err(anyhow!("unknown provider"))?,
}
Ok(())
}
-struct CountTokensWithLanguageModelRateLimit;
-impl RateLimit for CountTokensWithLanguageModelRateLimit {
+async fn count_language_model_tokens(
+request: proto::CountLanguageModelTokens,
+response: Response<proto::CountLanguageModelTokens>,
+session: Session,
+config: &Config,
+) -> Result<()> {
+let Some(session) = session.for_user() else {
+return Err(anyhow!("user not found"))?;
+};
+authorize_access_to_language_models(&session).await?;
+session
+.rate_limiter
+.check::<CountLanguageModelTokensRateLimit>(session.user_id())
+.await?;
+let result = match proto::LanguageModelProvider::from_i32(request.provider) {
+Some(proto::LanguageModelProvider::Google) => {
+let api_key = config
+.google_ai_api_key
+.as_ref()
+.context("no Google AI API key configured on the server")?;
+google_ai::count_tokens(
+session.http_client.as_ref(),
+google_ai::API_URL,
+api_key,
+serde_json::from_str(&request.request)?,
+)
+.await?
+}
+_ => return Err(anyhow!("unsupported provider"))?,
+};
+response.send(proto::CountLanguageModelTokensResponse {
+token_count: result.total_tokens as u32,
+})?;
+Ok(())
+}
+struct CountLanguageModelTokensRateLimit;
+impl RateLimit for CountLanguageModelTokensRateLimit {
fn capacity() -> usize {
-std::env::var("COUNT_TOKENS_WITH_LANGUAGE_MODEL_RATE_LIMIT_PER_HOUR")
+std::env::var("COUNT_LANGUAGE_MODEL_TOKENS_RATE_LIMIT_PER_HOUR")
.ok()
.and_then(|v| v.parse().ok())
.unwrap_or(600) // Picked arbitrarily
@@ -4623,7 +4701,7 @@ impl RateLimit for CountTokensWithLanguageModelRateLimit {
}
fn db_name() -> &'static str {
-"count-tokens-with-language-model"
+"count-language-model-tokens"
}
}
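
For orientation, the client-side counterpart of the new non-streaming handler looks roughly like the following. This is a minimal sketch; `client` and `anthropic_request` are illustrative names assumed to be in scope, not part of this diff.

// Hypothetical round trip over the new CompleteWithLanguageModel RPC.
let response = client
    .request(proto::CompleteWithLanguageModel {
        provider: proto::LanguageModelProvider::Anthropic as i32,
        // Provider requests travel as opaque JSON strings.
        request: serde_json::to_string(&anthropic_request)?,
    })
    .await?;
// The server sends the provider's response back as opaque JSON as well.
let completion: anthropic::Response = serde_json::from_str(&response.completion)?;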

View file

@@ -26,7 +26,9 @@ anyhow.workspace = true
futures.workspace = true
gpui.workspace = true
language_model.workspace = true
+schemars.workspace = true
serde.workspace = true
+serde_json.workspace = true
settings.workspace = true
smol.workspace = true
ui.workspace = true

View file

@@ -3,10 +3,13 @@ use futures::{future::BoxFuture, stream::BoxStream, StreamExt};
use gpui::{AppContext, Global, Model, ModelContext, Task};
use language_model::{
LanguageModel, LanguageModelProvider, LanguageModelProviderId, LanguageModelRegistry,
-LanguageModelRequest,
+LanguageModelRequest, LanguageModelTool,
};
-use smol::lock::{Semaphore, SemaphoreGuardArc};
-use std::{pin::Pin, sync::Arc, task::Poll};
+use smol::{
+future::FutureExt,
+lock::{Semaphore, SemaphoreGuardArc},
+};
+use std::{future, pin::Pin, sync::Arc, task::Poll};
use ui::Context;
pub fn init(cx: &mut AppContext) {
@@ -143,11 +146,11 @@ impl LanguageModelCompletionProvider {
&self,
request: LanguageModelRequest,
cx: &AppContext,
-) -> Option<BoxFuture<'static, Result<usize>>> {
+) -> BoxFuture<'static, Result<usize>> {
if let Some(model) = self.active_model() {
-Some(model.count_tokens(request, cx))
+model.count_tokens(request, cx)
} else {
-None
+future::ready(Err(anyhow!("no active model"))).boxed()
}
}
@@ -183,6 +186,29 @@ impl LanguageModelCompletionProvider {
Ok(completion)
})
}
+pub fn use_tool<T: LanguageModelTool>(
+&self,
+request: LanguageModelRequest,
+cx: &AppContext,
+) -> Task<Result<T>> {
+if let Some(language_model) = self.active_model() {
+cx.spawn(|cx| async move {
+let schema = schemars::schema_for!(T);
+let schema_json = serde_json::to_value(&schema).unwrap();
+let request =
+language_model.use_tool(request, T::name(), T::description(), schema_json, &cx);
+let response = request.await?;
+Ok(serde_json::from_value(response)?)
+})
+} else {
+Task::ready(Err(anyhow!("No active model set")))
+}
+}
+pub fn active_model_telemetry_id(&self) -> Option<String> {
+self.active_model.as_ref().map(|m| m.telemetry_id())
+}
}
#[cfg(test)]

View file

@@ -16,6 +16,8 @@ pub use model::*;
pub use registry::*;
pub use request::*;
pub use role::*;
+use schemars::JsonSchema;
+use serde::de::DeserializeOwned;
pub fn init(client: Arc<Client>, cx: &mut AppContext) {
settings::init(cx);
@@ -42,6 +44,20 @@ pub trait LanguageModel: Send + Sync {
request: LanguageModelRequest,
cx: &AsyncAppContext,
) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>>;
+fn use_tool(
+&self,
+request: LanguageModelRequest,
+name: String,
+description: String,
+schema: serde_json::Value,
+cx: &AsyncAppContext,
+) -> BoxFuture<'static, Result<serde_json::Value>>;
+}
+pub trait LanguageModelTool: 'static + DeserializeOwned + JsonSchema {
+fn name() -> String;
+fn description() -> String;
}
pub trait LanguageModelProvider: 'static {
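
To illustrate how the two new traits fit together, here is a minimal sketch of a tool type. The struct `EditOperations` and its field are hypothetical; the derives and trait shapes come straight from the definitions above.

use schemars::JsonSchema;
use serde::Deserialize;

// Hypothetical tool: the model is asked to emit JSON that deserializes
// into this struct.
#[derive(Deserialize, JsonSchema)]
struct EditOperations {
    operations: Vec<String>,
}

impl LanguageModelTool for EditOperations {
    fn name() -> String {
        "edit_operations".to_string()
    }

    fn description() -> String {
        "Generate a list of edit operations to apply to a buffer.".to_string()
    }
}

// A caller goes through the completion provider, which derives the JSON
// schema from the type and deserializes the model's tool input:
// let task: Task<Result<EditOperations>> =
//     LanguageModelCompletionProvider::read_global(cx).use_tool::<EditOperations>(request, cx);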

View file

@@ -1,5 +1,9 @@
-use anthropic::stream_completion;
-use anyhow::{anyhow, Result};
+use crate::{
+settings::AllLanguageModelSettings, LanguageModel, LanguageModelId, LanguageModelName,
+LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderName,
+LanguageModelProviderState, LanguageModelRequest, Role,
+};
+use anyhow::{anyhow, Context as _, Result};
use collections::BTreeMap;
use editor::{Editor, EditorElement, EditorStyle};
use futures::{future::BoxFuture, stream::BoxStream, FutureExt, StreamExt};
@@ -15,12 +19,6 @@ use theme::ThemeSettings;
use ui::prelude::*;
use util::ResultExt;
-use crate::{
-settings::AllLanguageModelSettings, LanguageModel, LanguageModelId, LanguageModelName,
-LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderName,
-LanguageModelProviderState, LanguageModelRequest, Role,
-};
const PROVIDER_ID: &str = "anthropic";
const PROVIDER_NAME: &str = "Anthropic";
@@ -188,6 +186,61 @@ pub fn count_anthropic_tokens(
.boxed()
}
+impl AnthropicModel {
+fn request_completion(
+&self,
+request: anthropic::Request,
+cx: &AsyncAppContext,
+) -> BoxFuture<'static, Result<anthropic::Response>> {
+let http_client = self.http_client.clone();
+let Ok((api_key, api_url)) = cx.read_model(&self.state, |state, cx| {
+let settings = &AllLanguageModelSettings::get_global(cx).anthropic;
+(state.api_key.clone(), settings.api_url.clone())
+}) else {
+return futures::future::ready(Err(anyhow!("App state dropped"))).boxed();
+};
+async move {
+let api_key = api_key.ok_or_else(|| anyhow!("missing api key"))?;
+anthropic::complete(http_client.as_ref(), &api_url, &api_key, request).await
+}
+.boxed()
+}
+fn stream_completion(
+&self,
+request: anthropic::Request,
+cx: &AsyncAppContext,
+) -> BoxFuture<'static, Result<BoxStream<'static, Result<anthropic::Event>>>> {
+let http_client = self.http_client.clone();
+let Ok((api_key, api_url, low_speed_timeout)) = cx.read_model(&self.state, |state, cx| {
+let settings = &AllLanguageModelSettings::get_global(cx).anthropic;
+(
+state.api_key.clone(),
+settings.api_url.clone(),
+settings.low_speed_timeout,
+)
+}) else {
+return futures::future::ready(Err(anyhow!("App state dropped"))).boxed();
+};
+async move {
+let api_key = api_key.ok_or_else(|| anyhow!("missing api key"))?;
+let request = anthropic::stream_completion(
+http_client.as_ref(),
+&api_url,
+&api_key,
+request,
+low_speed_timeout,
+);
+request.await
+}
+.boxed()
+}
+}
impl LanguageModel for AnthropicModel {
fn id(&self) -> LanguageModelId {
self.id.clone()
@@ -227,34 +280,53 @@ impl LanguageModel for AnthropicModel {
cx: &AsyncAppContext,
) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
let request = request.into_anthropic(self.model.id().into());
+let request = self.stream_completion(request, cx);
-let http_client = self.http_client.clone();
-let Ok((api_key, api_url, low_speed_timeout)) = cx.read_model(&self.state, |state, cx| {
-let settings = &AllLanguageModelSettings::get_global(cx).anthropic;
-(
-state.api_key.clone(),
-settings.api_url.clone(),
-settings.low_speed_timeout,
-)
-}) else {
-return futures::future::ready(Err(anyhow!("App state dropped"))).boxed();
-};
async move {
-let api_key = api_key.ok_or_else(|| anyhow!("missing api key"))?;
-let request = stream_completion(
-http_client.as_ref(),
-&api_url,
-&api_key,
-request,
-low_speed_timeout,
-);
let response = request.await?;
Ok(anthropic::extract_text_from_events(response).boxed())
}
.boxed()
}
+fn use_tool(
+&self,
+request: LanguageModelRequest,
+tool_name: String,
+tool_description: String,
+input_schema: serde_json::Value,
+cx: &AsyncAppContext,
+) -> BoxFuture<'static, Result<serde_json::Value>> {
+let mut request = request.into_anthropic(self.model.id().into());
+request.tool_choice = Some(anthropic::ToolChoice::Tool {
+name: tool_name.clone(),
+});
+request.tools = vec![anthropic::Tool {
+name: tool_name.clone(),
+description: tool_description,
+input_schema,
+}];
+let response = self.request_completion(request, cx);
+async move {
+let response = response.await?;
+response
+.content
+.into_iter()
+.find_map(|content| {
+if let anthropic::Content::ToolUse { name, input, .. } = content {
+if name == tool_name {
+Some(input)
+} else {
+None
+}
+} else {
+None
+}
+})
+.context("tool not used")
+}
+.boxed()
+}
}
struct AuthenticationPrompt {

View file

@@ -4,7 +4,7 @@ use crate::{
LanguageModelName, LanguageModelProviderId, LanguageModelProviderName,
LanguageModelProviderState, LanguageModelRequest,
};
-use anyhow::Result;
+use anyhow::{anyhow, Context as _, Result};
use client::Client;
use collections::BTreeMap;
use futures::{future::BoxFuture, stream::BoxStream, FutureExt, StreamExt};
@@ -12,7 +12,7 @@ use gpui::{AnyView, AppContext, AsyncAppContext, Subscription, Task};
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use settings::{Settings, SettingsStore};
-use std::sync::Arc;
+use std::{future, sync::Arc};
use strum::IntoEnumIterator;
use ui::prelude::*;
@@ -234,15 +234,13 @@ impl LanguageModel for CloudLanguageModel {
};
async move {
let request = serde_json::to_string(&request)?;
-let response = client.request(proto::QueryLanguageModel {
+let response = client
+.request(proto::CountLanguageModelTokens {
provider: proto::LanguageModelProvider::Google as i32,
-kind: proto::LanguageModelRequestKind::CountTokens as i32,
request,
-});
-let response = response.await?;
-let response =
-serde_json::from_str::<google_ai::CountTokensResponse>(&response.response)?;
-Ok(response.total_tokens)
+})
+.await?;
+Ok(response.token_count as usize)
}
.boxed()
}
@@ -260,14 +258,14 @@ impl LanguageModel for CloudLanguageModel {
let request = request.into_anthropic(model.id().into());
async move {
let request = serde_json::to_string(&request)?;
-let response = client.request_stream(proto::QueryLanguageModel {
+let stream = client
+.request_stream(proto::StreamCompleteWithLanguageModel {
provider: proto::LanguageModelProvider::Anthropic as i32,
-kind: proto::LanguageModelRequestKind::Complete as i32,
request,
-});
-let chunks = response.await?;
+})
+.await?;
Ok(anthropic::extract_text_from_events(
-chunks.map(|chunk| Ok(serde_json::from_str(&chunk?.response)?)),
+stream.map(|item| Ok(serde_json::from_str(&item?.event)?)),
)
.boxed())
}
@@ -278,14 +276,14 @@ impl LanguageModel for CloudLanguageModel {
let request = request.into_open_ai(model.id().into());
async move {
let request = serde_json::to_string(&request)?;
-let response = client.request_stream(proto::QueryLanguageModel {
+let stream = client
+.request_stream(proto::StreamCompleteWithLanguageModel {
provider: proto::LanguageModelProvider::OpenAi as i32,
-kind: proto::LanguageModelRequestKind::Complete as i32,
request,
-});
-let chunks = response.await?;
+})
+.await?;
Ok(open_ai::extract_text_from_events(
-chunks.map(|chunk| Ok(serde_json::from_str(&chunk?.response)?)),
+stream.map(|item| Ok(serde_json::from_str(&item?.event)?)),
)
.boxed())
}
@@ -296,14 +294,14 @@ impl LanguageModel for CloudLanguageModel {
let request = request.into_google(model.id().into());
async move {
let request = serde_json::to_string(&request)?;
-let response = client.request_stream(proto::QueryLanguageModel {
+let stream = client
+.request_stream(proto::StreamCompleteWithLanguageModel {
provider: proto::LanguageModelProvider::Google as i32,
-kind: proto::LanguageModelRequestKind::Complete as i32,
request,
-});
-let chunks = response.await?;
+})
+.await?;
Ok(google_ai::extract_text_from_events(
-chunks.map(|chunk| Ok(serde_json::from_str(&chunk?.response)?)),
+stream.map(|item| Ok(serde_json::from_str(&item?.event)?)),
)
.boxed())
}
@@ -311,6 +309,63 @@ impl LanguageModel for CloudLanguageModel {
}
}
}
+fn use_tool(
+&self,
+request: LanguageModelRequest,
+tool_name: String,
+tool_description: String,
+input_schema: serde_json::Value,
+_cx: &AsyncAppContext,
+) -> BoxFuture<'static, Result<serde_json::Value>> {
+match &self.model {
+CloudModel::Anthropic(model) => {
+let client = self.client.clone();
+let mut request = request.into_anthropic(model.id().into());
+request.tool_choice = Some(anthropic::ToolChoice::Tool {
+name: tool_name.clone(),
+});
+request.tools = vec![anthropic::Tool {
+name: tool_name.clone(),
+description: tool_description,
+input_schema,
+}];
+async move {
+let request = serde_json::to_string(&request)?;
+let response = client
+.request(proto::CompleteWithLanguageModel {
+provider: proto::LanguageModelProvider::Anthropic as i32,
+request,
+})
+.await?;
+let response: anthropic::Response = serde_json::from_str(&response.completion)?;
+response
+.content
+.into_iter()
+.find_map(|content| {
+if let anthropic::Content::ToolUse { name, input, .. } = content {
+if name == tool_name {
+Some(input)
+} else {
+None
+}
+} else {
+None
+}
+})
+.context("tool not used")
+}
+.boxed()
+}
+CloudModel::OpenAi(_) => {
+future::ready(Err(anyhow!("tool use not implemented for OpenAI"))).boxed()
+}
+CloudModel::Google(_) => {
+future::ready(Err(anyhow!("tool use not implemented for Google AI"))).boxed()
+}
+}
+}
}
struct AuthenticationPrompt {

View file

@@ -1,15 +1,17 @@
-use std::sync::{Arc, Mutex};
-use collections::HashMap;
-use futures::{channel::mpsc, future::BoxFuture, stream::BoxStream, FutureExt, StreamExt};
use crate::{
LanguageModel, LanguageModelId, LanguageModelName, LanguageModelProvider,
LanguageModelProviderId, LanguageModelProviderName, LanguageModelProviderState,
LanguageModelRequest,
};
+use anyhow::anyhow;
+use collections::HashMap;
+use futures::{channel::mpsc, future::BoxFuture, stream::BoxStream, FutureExt, StreamExt};
use gpui::{AnyView, AppContext, AsyncAppContext, Task};
use http_client::Result;
+use std::{
+future,
+sync::{Arc, Mutex},
+};
use ui::WindowContext;
pub fn language_model_id() -> LanguageModelId {
@@ -170,4 +172,15 @@ impl LanguageModel for FakeLanguageModel {
.insert(serde_json::to_string(&request).unwrap(), tx);
async move { Ok(rx.map(Ok).boxed()) }.boxed()
}
+fn use_tool(
+&self,
+_request: LanguageModelRequest,
+_name: String,
+_description: String,
+_schema: serde_json::Value,
+_cx: &AsyncAppContext,
+) -> BoxFuture<'static, Result<serde_json::Value>> {
+future::ready(Err(anyhow!("not implemented"))).boxed()
+}
}

View file

@@ -9,7 +9,7 @@ use gpui::{
};
use http_client::HttpClient;
use settings::{Settings, SettingsStore};
-use std::{sync::Arc, time::Duration};
+use std::{future, sync::Arc, time::Duration};
use strum::IntoEnumIterator;
use theme::ThemeSettings;
use ui::prelude::*;
@@ -238,6 +238,17 @@ impl LanguageModel for GoogleLanguageModel {
}
.boxed()
}
+fn use_tool(
+&self,
+_request: LanguageModelRequest,
+_name: String,
+_description: String,
+_schema: serde_json::Value,
+_cx: &AsyncAppContext,
+) -> BoxFuture<'static, Result<serde_json::Value>> {
+future::ready(Err(anyhow!("not implemented"))).boxed()
+}
}
struct AuthenticationPrompt {

View file

@@ -6,7 +6,7 @@ use ollama::{
get_models, preload_model, stream_chat_completion, ChatMessage, ChatOptions, ChatRequest,
};
use settings::{Settings, SettingsStore};
-use std::{sync::Arc, time::Duration};
+use std::{future, sync::Arc, time::Duration};
use ui::{prelude::*, ButtonLike, ElevationIndex};
use crate::{
@@ -298,6 +298,17 @@ impl LanguageModel for OllamaLanguageModel {
}
.boxed()
}
+fn use_tool(
+&self,
+_request: LanguageModelRequest,
+_name: String,
+_description: String,
+_schema: serde_json::Value,
+_cx: &AsyncAppContext,
+) -> BoxFuture<'static, Result<serde_json::Value>> {
+future::ready(Err(anyhow!("not implemented"))).boxed()
+}
}
struct DownloadOllamaMessage {

View file

@@ -9,7 +9,7 @@ use gpui::{
use http_client::HttpClient;
use open_ai::stream_completion;
use settings::{Settings, SettingsStore};
-use std::{sync::Arc, time::Duration};
+use std::{future, sync::Arc, time::Duration};
use strum::IntoEnumIterator;
use theme::ThemeSettings;
use ui::prelude::*;
@@ -225,6 +225,17 @@ impl LanguageModel for OpenAiLanguageModel {
}
.boxed()
}
+fn use_tool(
+&self,
+_request: LanguageModelRequest,
+_name: String,
+_description: String,
+_schema: serde_json::Value,
+_cx: &AsyncAppContext,
+) -> BoxFuture<'static, Result<serde_json::Value>> {
+future::ready(Err(anyhow!("not implemented"))).boxed()
+}
}
pub fn count_open_ai_tokens(

View file

@@ -106,19 +106,27 @@ impl LanguageModelRequest {
messages: new_messages
.into_iter()
.filter_map(|message| {
-Some(anthropic::RequestMessage {
+Some(anthropic::Message {
role: match message.role {
Role::User => anthropic::Role::User,
Role::Assistant => anthropic::Role::Assistant,
Role::System => return None,
},
-content: message.content,
+content: vec![anthropic::Content::Text {
+text: message.content,
+}],
})
})
.collect(),
-stream: true,
max_tokens: 4092,
-system: system_message,
+system: Some(system_message),
+tools: Vec::new(),
+tool_choice: None,
+metadata: None,
+stop_sequences: Vec::new(),
+temperature: None,
+top_k: None,
+top_p: None,
}
}
}
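
As a worked example of what the tool-calling paths layer on top of this conversion, a short sketch follows. The tool name, schema, and `model_id` are illustrative, and the schema would normally come from schemars rather than being written by hand.

// Convert a LanguageModelRequest, then force a specific tool, as the
// use_tool implementations above do.
let mut request = language_model_request.into_anthropic(model_id);
request.tool_choice = Some(anthropic::ToolChoice::Tool {
    name: "edit_operations".to_string(),
});
request.tools = vec![anthropic::Tool {
    name: "edit_operations".to_string(),
    description: "Generate a list of edit operations.".to_string(),
    input_schema: serde_json::json!({
        "type": "object",
        "properties": { "operations": { "type": "array", "items": { "type": "string" } } }
    }),
}];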

View file

@@ -194,8 +194,12 @@ message Envelope {
JoinHostedProject join_hosted_project = 164;
-QueryLanguageModel query_language_model = 224;
-QueryLanguageModelResponse query_language_model_response = 225; // current max
+CompleteWithLanguageModel complete_with_language_model = 226;
+CompleteWithLanguageModelResponse complete_with_language_model_response = 227;
+StreamCompleteWithLanguageModel stream_complete_with_language_model = 228;
+StreamCompleteWithLanguageModelResponse stream_complete_with_language_model_response = 229;
+CountLanguageModelTokens count_language_model_tokens = 230;
+CountLanguageModelTokensResponse count_language_model_tokens_response = 231; // current max
GetCachedEmbeddings get_cached_embeddings = 189;
GetCachedEmbeddingsResponse get_cached_embeddings_response = 190;
ComputeEmbeddings compute_embeddings = 191;
@@ -267,6 +271,7 @@ message Envelope {
reserved 158 to 161;
reserved 166 to 169;
+reserved 224 to 225;
}
// Messages
@@ -2050,10 +2055,31 @@ enum LanguageModelRole {
reserved 3;
}
-message QueryLanguageModel {
+message CompleteWithLanguageModel {
LanguageModelProvider provider = 1;
-LanguageModelRequestKind kind = 2;
-string request = 3;
+string request = 2;
+}
+message CompleteWithLanguageModelResponse {
+string completion = 1;
+}
+message StreamCompleteWithLanguageModel {
+LanguageModelProvider provider = 1;
+string request = 2;
+}
+message StreamCompleteWithLanguageModelResponse {
+string event = 1;
+}
+message CountLanguageModelTokens {
+LanguageModelProvider provider = 1;
+string request = 2;
+}
+message CountLanguageModelTokensResponse {
+uint32 token_count = 1;
}
enum LanguageModelProvider {
@@ -2062,15 +2088,6 @@ enum LanguageModelProvider {
Google = 2;
}
-enum LanguageModelRequestKind {
-Complete = 0;
-CountTokens = 1;
-}
-message QueryLanguageModelResponse {
-string response = 1;
-}
message GetCachedEmbeddings {
string model = 1;
repeated bytes digests = 2;

View file

@@ -294,8 +294,12 @@ messages!(
(PrepareRename, Background),
(PrepareRenameResponse, Background),
(ProjectEntryResponse, Foreground),
-(QueryLanguageModel, Background),
-(QueryLanguageModelResponse, Background),
+(CompleteWithLanguageModel, Background),
+(CompleteWithLanguageModelResponse, Background),
+(StreamCompleteWithLanguageModel, Background),
+(StreamCompleteWithLanguageModelResponse, Background),
+(CountLanguageModelTokens, Background),
+(CountLanguageModelTokensResponse, Background),
(RefreshInlayHints, Foreground),
(RejoinChannelBuffers, Foreground),
(RejoinChannelBuffersResponse, Foreground),
@@ -463,7 +467,12 @@ request_messages!(
(PerformRename, PerformRenameResponse),
(Ping, Ack),
(PrepareRename, PrepareRenameResponse),
-(QueryLanguageModel, QueryLanguageModelResponse),
+(CompleteWithLanguageModel, CompleteWithLanguageModelResponse),
+(
+StreamCompleteWithLanguageModel,
+StreamCompleteWithLanguageModelResponse
+),
+(CountLanguageModelTokens, CountLanguageModelTokensResponse),
(RefreshInlayHints, Ack),
(RejoinChannelBuffers, RejoinChannelBuffersResponse),
(RejoinRoom, RejoinRoomResponse),