Compare commits

...
Sign in to create a new pull request.

3 commits

13 changed files with 971 additions and 394 deletions

3
Cargo.lock generated
View file

@ -93,6 +93,7 @@ dependencies = [
"anyhow",
"async-trait",
"bincode",
"erased-serde",
"futures 0.3.28",
"gpui",
"isahc",
@ -103,6 +104,7 @@ dependencies = [
"parking_lot 0.11.2",
"parse_duration",
"postage",
"project",
"rand 0.8.5",
"regex",
"rusqlite",
@ -321,6 +323,7 @@ dependencies = [
"regex",
"schemars",
"search",
"semantic_index",
"serde",
"serde_json",
"settings",

View file

@ -11,6 +11,7 @@ doctest = false
[dependencies]
gpui = { path = "../gpui" }
util = { path = "../util" }
project = { path = "../project" }
async-trait.workspace = true
anyhow.workspace = true
futures.workspace = true
@ -29,6 +30,7 @@ tiktoken-rs = "0.5.0"
matrixmultiply = "0.3.7"
rusqlite = { version = "0.27.0", features = ["blob", "array", "modern_sqlite"] }
bincode = "1.3.3"
erased-serde = "0.3"
[dev-dependencies]
gpui = { path = "../gpui", features = ["test-support"] }

View file

@ -1,2 +1,50 @@
pub mod completion;
pub mod embedding;
pub mod function_calling;
pub mod skills;
use core::fmt;
use std::fmt::Display;
use serde::{Deserialize, Serialize};
#[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum Role {
User,
Assistant,
System,
}
impl Role {
pub fn cycle(&mut self) {
*self = match self {
Role::User => Role::Assistant,
Role::Assistant => Role::System,
Role::System => Role::User,
}
}
}
impl Display for Role {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Role::User => write!(f, "User"),
Role::Assistant => write!(f, "Assistant"),
Role::System => write!(f, "System"),
}
}
}
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq, Clone)]
pub struct RequestMessage {
pub role: Role,
pub content: String,
}
#[derive(Deserialize, Debug)]
pub struct OpenAIUsage {
pub prompt_tokens: u32,
pub completion_tokens: u32,
pub total_tokens: u32,
}

View file

@ -1,3 +1,4 @@
use crate::{OpenAIUsage, RequestMessage, Role};
use anyhow::{anyhow, Result};
use futures::{
future::BoxFuture, io::BufReader, stream::BoxStream, AsyncBufReadExt, AsyncReadExt, FutureExt,
@ -6,48 +7,10 @@ use futures::{
use gpui::executor::Background;
use isahc::{http::StatusCode, Request, RequestExt};
use serde::{Deserialize, Serialize};
use std::{
fmt::{self, Display},
io,
sync::Arc,
};
use std::{io, sync::Arc};
pub const OPENAI_API_URL: &'static str = "https://api.openai.com/v1";
#[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum Role {
User,
Assistant,
System,
}
impl Role {
pub fn cycle(&mut self) {
*self = match self {
Role::User => Role::Assistant,
Role::Assistant => Role::System,
Role::System => Role::User,
}
}
}
impl Display for Role {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Role::User => write!(f, "User"),
Role::Assistant => write!(f, "Assistant"),
Role::System => write!(f, "System"),
}
}
}
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct RequestMessage {
pub role: Role,
pub content: String,
}
#[derive(Debug, Default, Serialize)]
pub struct OpenAIRequest {
pub model: String,
@ -61,13 +24,6 @@ pub struct ResponseMessage {
pub content: Option<String>,
}
#[derive(Deserialize, Debug)]
pub struct OpenAIUsage {
pub prompt_tokens: u32,
pub completion_tokens: u32,
pub total_tokens: u32,
}
#[derive(Deserialize, Debug)]
pub struct ChatChoiceDelta {
pub index: u32,

View file

@ -0,0 +1,176 @@
use crate::{OpenAIUsage, RequestMessage, Role};
use anyhow::anyhow;
use erased_serde::serialize_trait_object;
use futures::AsyncReadExt;
use isahc::{http::StatusCode, Request, RequestExt};
use serde::{Deserialize, Serialize};
use std::fmt::Write;
const OPENAI_API_URL: &'static str = "https://api.openai.com/v1";
pub trait OpenAIFunction: erased_serde::Serialize {
fn name(&self) -> String;
fn description(&self) -> String;
fn system_prompt(&self) -> String;
fn parameters(&self) -> serde_json::Value;
fn complete(&self, arguments: serde_json::Value) -> anyhow::Result<String>;
}
serialize_trait_object!(OpenAIFunction);
#[derive(Serialize)]
struct OpenAIFunctionCallingRequest {
model: String,
messages: Vec<RequestMessage>,
functions: Vec<Box<dyn OpenAIFunction>>,
temperature: f32,
}
#[derive(Debug, Deserialize)]
pub struct FunctionCall {
pub name: String,
pub arguments: String,
}
impl FunctionCall {
fn arguments(&self) -> anyhow::Result<serde_json::Value> {
serde_json::from_str(&self.arguments).map_err(|err| anyhow!(err))
}
}
#[derive(Debug, Deserialize)]
pub struct FunctionCallingMessage {
pub role: Option<Role>,
pub content: Option<String>,
pub function_call: FunctionCall,
}
#[derive(Debug, Deserialize)]
pub struct FunctionCallingChoice {
pub message: FunctionCallingMessage,
pub finish_reason: Option<String>,
}
#[derive(Debug, Deserialize)]
pub struct OpenAIFunctionCallingResponse {
pub id: Option<String>,
pub object: String,
pub created: u32,
pub model: String,
pub choices: Vec<FunctionCallingChoice>,
pub usage: OpenAIUsage,
}
impl OpenAIFunctionCallingResponse {
fn get_details(&self) -> anyhow::Result<FunctionCallDetails> {
if let Some(choice) = self.choices.first() {
let name = choice.message.function_call.name.clone();
let arguments = choice.message.function_call.arguments()?;
Ok(FunctionCallDetails {
name,
arguments,
message: choice.message.content.clone(),
})
} else {
Err(anyhow!("no function call details available"))
}
}
}
#[derive(Debug)]
pub struct FunctionCallDetails {
pub name: String, // name of function to call
pub message: Option<String>, // message if provided
pub arguments: serde_json::Value, // json object respresenting provided arguments
}
#[derive(Clone)]
pub struct OpenAIFunctionCallingProvider {
api_key: String,
}
impl OpenAIFunctionCallingProvider {
pub fn new(api_key: String) -> Self {
Self { api_key }
}
fn generate_system_message(
&self,
messages: &Vec<RequestMessage>,
functions: &Vec<Box<dyn OpenAIFunction>>,
) -> RequestMessage {
let mut system_message = messages
.iter()
.filter_map(|message| {
if message.role == Role::System {
Some(message.content.clone())
} else {
None
}
})
.collect::<Vec<String>>()
.join("\n");
let function_string = functions
.iter()
.map(|function| format!("'{}'", function.name()))
.collect::<Vec<String>>()
.join(",");
writeln!(
system_message,
"You have access to the following functions: {function_string} you MUST return a function calling response using at one of the below functions."
)
.unwrap();
for function in functions {
writeln!(system_message, "\n{}", function.system_prompt()).unwrap();
}
RequestMessage {
role: Role::System,
content: system_message,
}
}
pub async fn complete(
&self,
model: String,
mut messages: Vec<RequestMessage>,
functions: Vec<Box<dyn OpenAIFunction>>,
) -> anyhow::Result<FunctionCallDetails> {
// TODO: Rename all this.
let mut system_message = vec![self.generate_system_message(&messages, &functions)];
messages.retain(|message| message.role != Role::System);
system_message.extend(messages);
// Lower temperature values, result in less randomness,
// this is helping keep the function calling consistent
let request = OpenAIFunctionCallingRequest {
model,
messages: system_message,
functions,
temperature: 0.0,
};
let json_data = serde_json::to_string(&request)?;
println!("\nREQUEST: {:?}\n", &json_data);
let mut response = Request::post(format!("{OPENAI_API_URL}/chat/completions"))
.header("Content-Type", "application/json")
.header("Authorization", format!("Bearer {}", self.api_key))
.body(json_data)?
.send_async()
.await?;
let mut body = String::new();
response.body_mut().read_to_string(&mut body).await?;
println!("\nRESPONSE: {:?}\n", &body);
match response.status() {
StatusCode::OK => {
let response_data: OpenAIFunctionCallingResponse = serde_json::from_str(&body)?;
response_data.get_details()
}
_ => Err(anyhow!("open ai function calling failed: {:?}", body)),
}
}
}

50
crates/ai/src/skills.rs Normal file
View file

@ -0,0 +1,50 @@
use crate::function_calling::OpenAIFunction;
use gpui::{AppContext, ModelHandle};
use project::Project;
use serde::{Serialize, Serializer};
use serde_json::json;
pub struct RewritePrompt;
impl Serialize for RewritePrompt {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: Serializer,
{
json!({"name": self.name(),
"description": self.description(),
"parameters": self.parameters()})
.serialize(serializer)
}
}
impl RewritePrompt {
pub fn load() -> Self {
Self {}
}
}
impl OpenAIFunction for RewritePrompt {
fn name(&self) -> String {
"rewrite_prompt".to_string()
}
fn description(&self) -> String {
"Rewrite prompt given prompt from user".to_string()
}
fn system_prompt(&self) -> String {
"'rewrite_prompt':
If all information is available in the above prompt, and you need no further information.
Rewrite the entire prompt to clarify what should be generated, do not actually complete the users request.
Assume this rewritten message will be passed to another completion agent, to fulfill the users request.".to_string()
}
fn parameters(&self) -> serde_json::Value {
json!({
"type": "object",
"properties": {
"prompt": {}
}
})
}
fn complete(&self, arguments: serde_json::Value) -> anyhow::Result<String> {
Ok(arguments.get("prompt").unwrap().to_string())
}
}

View file

@ -23,6 +23,8 @@ theme = { path = "../theme" }
util = { path = "../util" }
uuid = { version = "1.1.2", features = ["v4"] }
workspace = { path = "../workspace" }
semantic_index = { path = "../semantic_index" }
project = { path = "../project" }
anyhow.workspace = true
chrono = { version = "0.4", features = ["serde"] }

View file

@ -4,7 +4,7 @@ mod codegen;
mod prompts;
mod streaming_diff;
use ai::completion::Role;
use ai::Role;
use anyhow::Result;
pub use assistant_panel::AssistantPanel;
use assistant_settings::OpenAIModel;

View file

@ -1,13 +1,16 @@
use crate::{
assistant_settings::{AssistantDockPosition, AssistantSettings, OpenAIModel},
codegen::{self, Codegen, CodegenKind},
prompts::generate_content_prompt,
prompts::{generate_codegen_planning_prompt, generate_content_prompt},
MessageId, MessageMetadata, MessageStatus, Role, SavedConversation, SavedConversationMetadata,
SavedMessage,
};
use ai::completion::{
stream_completion, OpenAICompletionProvider, OpenAIRequest, RequestMessage, OPENAI_API_URL,
use ai::{
completion::{stream_completion, OpenAICompletionProvider, OpenAIRequest, OPENAI_API_URL},
function_calling::OpenAIFunctionCallingProvider,
skills::RewritePrompt,
};
use ai::{function_calling::OpenAIFunction, RequestMessage};
use anyhow::{anyhow, Result};
use chrono::{DateTime, Local};
use client::{telemetry::AssistantKind, ClickhouseEvent, TelemetrySettings};
@ -35,7 +38,9 @@ use gpui::{
WindowContext,
};
use language::{language_settings::SoftWrap, Buffer, LanguageRegistry, ToOffset as _};
use project::Project;
use search::BufferSearchBar;
use semantic_index::{skills::RepositoryContextRetriever, SemanticIndex};
use settings::SettingsStore;
use std::{
cell::{Cell, RefCell},
@ -145,6 +150,8 @@ pub struct AssistantPanel {
include_conversation_in_next_inline_assist: bool,
inline_prompt_history: VecDeque<String>,
_watch_saved_conversations: Task<Result<()>>,
semantic_index: ModelHandle<SemanticIndex>,
project: ModelHandle<Project>,
}
impl AssistantPanel {
@ -154,6 +161,7 @@ impl AssistantPanel {
workspace: WeakViewHandle<Workspace>,
cx: AsyncAppContext,
) -> Task<Result<ViewHandle<Self>>> {
let index = cx.read(|cx| SemanticIndex::global(cx).unwrap());
cx.spawn(|mut cx| async move {
let fs = workspace.read_with(&cx, |workspace, _| workspace.app_state().fs.clone())?;
let saved_conversations = SavedConversationMetadata::list(fs.clone())
@ -191,6 +199,9 @@ impl AssistantPanel {
toolbar.add_item(cx.add_view(|cx| BufferSearchBar::new(cx)), cx);
toolbar
});
let project = workspace.project().clone();
let mut this = Self {
workspace: workspace_handle,
active_editor_index: Default::default(),
@ -215,6 +226,8 @@ impl AssistantPanel {
include_conversation_in_next_inline_assist: false,
inline_prompt_history: Default::default(),
_watch_saved_conversations,
semantic_index: index,
project,
};
let mut old_dock_position = this.position(cx);
@ -277,9 +290,10 @@ impl AssistantPanel {
let inline_assist_id = post_inc(&mut self.next_inline_assist_id);
let snapshot = editor.read(cx).buffer().read(cx).snapshot(cx);
let provider = Arc::new(OpenAICompletionProvider::new(
api_key,
api_key.clone(),
cx.background().clone(),
));
let fc_provider = OpenAIFunctionCallingProvider::new(api_key);
let selection = editor.read(cx).selections.newest_anchor().clone();
let codegen_kind = if editor.read(cx).selections.newest::<usize>(cx).is_empty() {
CodegenKind::Generate {
@ -290,8 +304,18 @@ impl AssistantPanel {
range: selection.start..selection.end,
}
};
let project = self.project.clone();
let codegen = cx.add_model(|cx| {
Codegen::new(editor.read(cx).buffer().clone(), codegen_kind, provider, cx)
Codegen::new(
editor.read(cx).buffer().clone(),
codegen_kind,
provider,
fc_provider,
cx,
project.clone(),
)
});
let measurements = Rc::new(Cell::new(BlockMeasurements::default()));
@ -573,42 +597,74 @@ impl AssistantPanel {
let language_name = language_name.as_deref();
let codegen_kind = pending_assist.codegen.read(cx).kind().clone();
let prompt = generate_content_prompt(
user_prompt.to_string(),
language_name,
&snapshot,
language_range,
cx,
codegen_kind,
);
let index = self.semantic_index.clone();
let mut messages = Vec::new();
let mut model = settings::get::<AssistantSettings>(cx)
.default_open_ai_model
.clone();
if let Some(conversation) = conversation {
let conversation = conversation.read(cx);
let buffer = conversation.buffer.read(cx);
messages.extend(
conversation
.messages(cx)
.map(|message| message.to_open_ai_message(buffer)),
);
model = conversation.model.clone();
}
messages.push(RequestMessage {
role: Role::User,
content: prompt,
pending_assist.codegen.update(cx, |codegen, cx| {
codegen.start(
user_prompt.to_string(),
cx,
language_name,
snapshot,
language_range.clone(),
codegen_kind.clone(),
index,
)
});
let request = OpenAIRequest {
model: model.full_name().into(),
messages,
stream: true,
};
pending_assist
.codegen
.update(cx, |codegen, cx| codegen.start(request, cx));
// let api_key = self.api_key.as_ref().clone().into_inner().clone().unwrap();
// let function_provider = OpenAIFunctionCallingProvider::new(api_key);
// let planning_messages = vec![RequestMessage {
// role: Role::User,
// content: planning_prompt,
// }];
// println!("GETTING HERE");
// let function_call = cx
// .spawn(|this, mut cx| async move {
// let result = function_provider
// .complete("gpt-4".to_string(), planning_messages, functions)
// .await;
// dbg!(&result);
// result
// })
// .detach();
// let function_name = function_call.name.as_str();
// let prompt = match function_name {
// "rewrite_prompt" => {
// let user_prompt = RewritePrompt::load()
// .complete(function_call.arguments)
// .unwrap();
// generate_content_prompt(
// user_prompt.to_string(),
// language_name,
// &snapshot,
// language_range,
// cx,
// codegen_kind,
// )
// }
// _ => {
// todo!();
// }
// };
// let mut messages = Vec::new();
// let mut model = settings::get::<AssistantSettings>(cx)
// .default_open_ai_model
// .clone();
// if let Some(conversation) = conversation {
// let conversation = conversation.read(cx);
// let buffer = conversation.buffer.read(cx);
// messages.extend(
// conversation
// .messages(cx)
// .map(|message| message.to_open_ai_message(buffer)),
// );
// model = conversation.model.clone();
// }
}
fn update_highlights_for_editor(

View file

@ -1,12 +1,22 @@
use crate::streaming_diff::{Hunk, StreamingDiff};
use ai::completion::{CompletionProvider, OpenAIRequest};
use crate::{
prompts::{generate_codegen_planning_prompt, generate_content_prompt},
streaming_diff::{Hunk, StreamingDiff},
};
use ai::{
completion::{CompletionProvider, OpenAIRequest},
function_calling::{OpenAIFunction, OpenAIFunctionCallingProvider},
skills::RewritePrompt,
RequestMessage, Role,
};
use anyhow::Result;
use editor::{
multi_buffer, Anchor, AnchorRangeExt, MultiBuffer, MultiBufferSnapshot, ToOffset, ToPoint,
};
use futures::{channel::mpsc, SinkExt, Stream, StreamExt};
use gpui::{Entity, ModelContext, ModelHandle, Task};
use language::{Rope, TransactionId};
use gpui::{BorrowAppContext, Entity, ModelContext, ModelHandle, Task};
use language::{BufferSnapshot, Rope, TransactionId};
use project::Project;
use semantic_index::{skills::RepositoryContextRetriever, SemanticIndex};
use std::{cmp, future, ops::Range, sync::Arc};
pub enum Event {
@ -22,6 +32,7 @@ pub enum CodegenKind {
pub struct Codegen {
provider: Arc<dyn CompletionProvider>,
fc_provider: OpenAIFunctionCallingProvider,
buffer: ModelHandle<MultiBuffer>,
snapshot: MultiBufferSnapshot,
kind: CodegenKind,
@ -31,6 +42,7 @@ pub struct Codegen {
generation: Task<()>,
idle: bool,
_subscription: gpui::Subscription,
project: ModelHandle<Project>,
}
impl Entity for Codegen {
@ -42,7 +54,9 @@ impl Codegen {
buffer: ModelHandle<MultiBuffer>,
mut kind: CodegenKind,
provider: Arc<dyn CompletionProvider>,
fc_provider: OpenAIFunctionCallingProvider,
cx: &mut ModelContext<Self>,
project: ModelHandle<Project>,
) -> Self {
let snapshot = buffer.read(cx).snapshot(cx);
match &mut kind {
@ -62,6 +76,7 @@ impl Codegen {
Self {
provider,
fc_provider,
buffer: buffer.clone(),
snapshot,
kind,
@ -71,6 +86,7 @@ impl Codegen {
idle: true,
generation: Task::ready(()),
_subscription: cx.subscribe(&buffer, Self::handle_buffer_event),
project,
}
}
@ -112,7 +128,17 @@ impl Codegen {
self.error.as_ref()
}
pub fn start(&mut self, prompt: OpenAIRequest, cx: &mut ModelContext<Self>) {
pub fn start(
&mut self,
prompt: String,
cx: &mut ModelContext<Self>,
language_name: Option<&str>,
buffer: BufferSnapshot,
range: Range<language::Anchor>,
kind: CodegenKind,
index: ModelHandle<SemanticIndex>,
) {
let language_range = range.clone();
let range = self.range();
let snapshot = self.snapshot.clone();
let selected_text = snapshot
@ -126,9 +152,101 @@ impl Codegen {
.next()
.unwrap_or_else(|| snapshot.indent_size_for_line(selection_start.row));
let response = self.provider.complete(prompt);
let messages = vec![RequestMessage {
role: Role::User,
content: prompt.clone(),
}];
let request = OpenAIRequest {
model: "gpt-4".to_string(),
messages: messages.clone(),
stream: true,
};
let (planning_prompt, outline) = generate_codegen_planning_prompt(
prompt.clone(),
language_name.clone(),
&buffer,
language_range.clone(),
cx,
kind.clone(),
);
let project = self.project.clone();
self.generation = cx.spawn_weak(|this, mut cx| {
// Plan Ahead
let planning_messages = vec![RequestMessage {
role: Role::User,
content: planning_prompt,
}];
let repo_retriever = RepositoryContextRetriever::load(index, project);
let functions: Vec<Box<dyn OpenAIFunction>> = vec![
Box::new(RewritePrompt::load()),
Box::new(repo_retriever.clone()),
];
let completion_provider = self.provider.clone();
let fc_provider = self.fc_provider.clone();
let language_name = language_name.clone();
let language_name = if let Some(language_name) = language_name.clone() {
Some(language_name.to_string())
} else {
None
};
let kind = kind.clone();
async move {
let mut user_prompt = prompt.clone();
let user_prompt = if let Ok(function_call) = fc_provider
.complete("gpt-4".to_string(), planning_messages, functions)
.await
{
let function_name = function_call.name.as_str();
println!("FUNCTION NAME: {:?}", function_name);
let user_prompt = match function_name {
"rewrite_prompt" => {
let user_prompt = RewritePrompt::load()
.complete(function_call.arguments)
.unwrap();
generate_content_prompt(
user_prompt,
language_name,
outline,
kind,
vec![],
)
}
_ => {
let arguments = function_call.arguments.clone();
let snippet = repo_retriever
.complete_test(arguments, &mut cx)
.await
.unwrap();
let snippet = vec![snippet];
generate_content_prompt(prompt, language_name, outline, kind, snippet)
}
};
user_prompt
} else {
user_prompt
};
println!("{:?}", user_prompt.clone());
let messages = vec![RequestMessage {
role: Role::User,
content: user_prompt.clone(),
}];
let request = OpenAIRequest {
model: "gpt-4".to_string(),
messages: messages.clone(),
stream: true,
};
let response = completion_provider.complete(request);
let generate = async {
let mut edit_start = range.start.to_offset(&snapshot);
@ -349,315 +467,317 @@ fn strip_markdown_codeblock(
})
}
#[cfg(test)]
mod tests {
use super::*;
use futures::{
future::BoxFuture,
stream::{self, BoxStream},
};
use gpui::{executor::Deterministic, TestAppContext};
use indoc::indoc;
use language::{language_settings, tree_sitter_rust, Buffer, Language, LanguageConfig, Point};
use parking_lot::Mutex;
use rand::prelude::*;
use settings::SettingsStore;
use smol::future::FutureExt;
// #[cfg(test)]
// mod tests {
// use super::*;
// use futures::{
// future::BoxFuture,
// stream::{self, BoxStream},
// };
// use gpui::{executor::Deterministic, TestAppContext};
// use indoc::indoc;
// use language::{language_settings, tree_sitter_rust, Buffer, Language, LanguageConfig, Point};
// use parking_lot::Mutex;
// use rand::prelude::*;
// use settings::SettingsStore;
// use smol::future::FutureExt;
#[gpui::test(iterations = 10)]
async fn test_transform_autoindent(
cx: &mut TestAppContext,
mut rng: StdRng,
deterministic: Arc<Deterministic>,
) {
cx.set_global(cx.read(SettingsStore::test));
cx.update(language_settings::init);
// #[gpui::test(iterations = 10)]
// async fn test_transform_autoindent(
// cx: &mut TestAppContext,
// mut rng: StdRng,
// deterministic: Arc<Deterministic>,
// ) {
// cx.set_global(cx.read(SettingsStore::test));
// cx.update(language_settings::init);
let text = indoc! {"
fn main() {
let x = 0;
for _ in 0..10 {
x += 1;
}
}
"};
let buffer =
cx.add_model(|cx| Buffer::new(0, 0, text).with_language(Arc::new(rust_lang()), cx));
let buffer = cx.add_model(|cx| MultiBuffer::singleton(buffer, cx));
let range = buffer.read_with(cx, |buffer, cx| {
let snapshot = buffer.snapshot(cx);
snapshot.anchor_before(Point::new(1, 4))..snapshot.anchor_after(Point::new(4, 4))
});
let provider = Arc::new(TestCompletionProvider::new());
let codegen = cx.add_model(|cx| {
Codegen::new(
buffer.clone(),
CodegenKind::Transform { range },
provider.clone(),
cx,
)
});
codegen.update(cx, |codegen, cx| codegen.start(Default::default(), cx));
// let text = indoc! {"
// fn main() {
// let x = 0;
// for _ in 0..10 {
// x += 1;
// }
// }
// "};
// let buffer =
// cx.add_model(|cx| Buffer::new(0, 0, text).with_language(Arc::new(rust_lang()), cx));
// let buffer = cx.add_model(|cx| MultiBuffer::singleton(buffer, cx));
// let range = buffer.read_with(cx, |buffer, cx| {
// let snapshot = buffer.snapshot(cx);
// snapshot.anchor_before(Point::new(1, 4))..snapshot.anchor_after(Point::new(4, 4))
// });
// let provider = Arc::new(TestCompletionProvider::new());
// let fc_provider = OpenAIFunctionCallingProvider::new("".to_string());
// let codegen = cx.add_model(|cx| {
// Codegen::new(
// buffer.clone(),
// CodegenKind::Transform { range },
// provider.clone(),
// fc_provider,
// cx,
// )
// });
// codegen.update(cx, |codegen, cx| codegen.start(Default::default(), cx));
let mut new_text = concat!(
" let mut x = 0;\n",
" while x < 10 {\n",
" x += 1;\n",
" }",
);
while !new_text.is_empty() {
let max_len = cmp::min(new_text.len(), 10);
let len = rng.gen_range(1..=max_len);
let (chunk, suffix) = new_text.split_at(len);
provider.send_completion(chunk);
new_text = suffix;
deterministic.run_until_parked();
}
provider.finish_completion();
deterministic.run_until_parked();
// let mut new_text = concat!(
// " let mut x = 0;\n",
// " while x < 10 {\n",
// " x += 1;\n",
// " }",
// );
// while !new_text.is_empty() {
// let max_len = cmp::min(new_text.len(), 10);
// let len = rng.gen_range(1..=max_len);
// let (chunk, suffix) = new_text.split_at(len);
// provider.send_completion(chunk);
// new_text = suffix;
// deterministic.run_until_parked();
// }
// provider.finish_completion();
// deterministic.run_until_parked();
assert_eq!(
buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()),
indoc! {"
fn main() {
let mut x = 0;
while x < 10 {
x += 1;
}
}
"}
);
}
// assert_eq!(
// buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()),
// indoc! {"
// fn main() {
// let mut x = 0;
// while x < 10 {
// x += 1;
// }
// }
// "}
// );
// }
#[gpui::test(iterations = 10)]
async fn test_autoindent_when_generating_past_indentation(
cx: &mut TestAppContext,
mut rng: StdRng,
deterministic: Arc<Deterministic>,
) {
cx.set_global(cx.read(SettingsStore::test));
cx.update(language_settings::init);
// #[gpui::test(iterations = 10)]
// async fn test_autoindent_when_generating_past_indentation(
// cx: &mut TestAppContext,
// mut rng: StdRng,
// deterministic: Arc<Deterministic>,
// ) {
// cx.set_global(cx.read(SettingsStore::test));
// cx.update(language_settings::init);
let text = indoc! {"
fn main() {
le
}
"};
let buffer =
cx.add_model(|cx| Buffer::new(0, 0, text).with_language(Arc::new(rust_lang()), cx));
let buffer = cx.add_model(|cx| MultiBuffer::singleton(buffer, cx));
let position = buffer.read_with(cx, |buffer, cx| {
let snapshot = buffer.snapshot(cx);
snapshot.anchor_before(Point::new(1, 6))
});
let provider = Arc::new(TestCompletionProvider::new());
let codegen = cx.add_model(|cx| {
Codegen::new(
buffer.clone(),
CodegenKind::Generate { position },
provider.clone(),
cx,
)
});
codegen.update(cx, |codegen, cx| codegen.start(Default::default(), cx));
// let text = indoc! {"
// fn main() {
// le
// }
// "};
// let buffer =
// cx.add_model(|cx| Buffer::new(0, 0, text).with_language(Arc::new(rust_lang()), cx));
// let buffer = cx.add_model(|cx| MultiBuffer::singleton(buffer, cx));
// let position = buffer.read_with(cx, |buffer, cx| {
// let snapshot = buffer.snapshot(cx);
// snapshot.anchor_before(Point::new(1, 6))
// });
// let provider = Arc::new(TestCompletionProvider::new());
// let codegen = cx.add_model(|cx| {
// Codegen::new(
// buffer.clone(),
// CodegenKind::Generate { position },
// provider.clone(),
// cx,
// )
// });
// codegen.update(cx, |codegen, cx| codegen.start(Default::default(), cx));
let mut new_text = concat!(
"t mut x = 0;\n",
"while x < 10 {\n",
" x += 1;\n",
"}", //
);
while !new_text.is_empty() {
let max_len = cmp::min(new_text.len(), 10);
let len = rng.gen_range(1..=max_len);
let (chunk, suffix) = new_text.split_at(len);
provider.send_completion(chunk);
new_text = suffix;
deterministic.run_until_parked();
}
provider.finish_completion();
deterministic.run_until_parked();
// let mut new_text = concat!(
// "t mut x = 0;\n",
// "while x < 10 {\n",
// " x += 1;\n",
// "}", //
// );
// while !new_text.is_empty() {
// let max_len = cmp::min(new_text.len(), 10);
// let len = rng.gen_range(1..=max_len);
// let (chunk, suffix) = new_text.split_at(len);
// provider.send_completion(chunk);
// new_text = suffix;
// deterministic.run_until_parked();
// }
// provider.finish_completion();
// deterministic.run_until_parked();
assert_eq!(
buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()),
indoc! {"
fn main() {
let mut x = 0;
while x < 10 {
x += 1;
}
}
"}
);
}
// assert_eq!(
// buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()),
// indoc! {"
// fn main() {
// let mut x = 0;
// while x < 10 {
// x += 1;
// }
// }
// "}
// );
// }
#[gpui::test(iterations = 10)]
async fn test_autoindent_when_generating_before_indentation(
cx: &mut TestAppContext,
mut rng: StdRng,
deterministic: Arc<Deterministic>,
) {
cx.set_global(cx.read(SettingsStore::test));
cx.update(language_settings::init);
// #[gpui::test(iterations = 10)]
// async fn test_autoindent_when_generating_before_indentation(
// cx: &mut TestAppContext,
// mut rng: StdRng,
// deterministic: Arc<Deterministic>,
// ) {
// cx.set_global(cx.read(SettingsStore::test));
// cx.update(language_settings::init);
let text = concat!(
"fn main() {\n",
" \n",
"}\n" //
);
let buffer =
cx.add_model(|cx| Buffer::new(0, 0, text).with_language(Arc::new(rust_lang()), cx));
let buffer = cx.add_model(|cx| MultiBuffer::singleton(buffer, cx));
let position = buffer.read_with(cx, |buffer, cx| {
let snapshot = buffer.snapshot(cx);
snapshot.anchor_before(Point::new(1, 2))
});
let provider = Arc::new(TestCompletionProvider::new());
let codegen = cx.add_model(|cx| {
Codegen::new(
buffer.clone(),
CodegenKind::Generate { position },
provider.clone(),
cx,
)
});
codegen.update(cx, |codegen, cx| codegen.start(Default::default(), cx));
// let text = concat!(
// "fn main() {\n",
// " \n",
// "}\n" //
// );
// let buffer =
// cx.add_model(|cx| Buffer::new(0, 0, text).with_language(Arc::new(rust_lang()), cx));
// let buffer = cx.add_model(|cx| MultiBuffer::singleton(buffer, cx));
// let position = buffer.read_with(cx, |buffer, cx| {
// let snapshot = buffer.snapshot(cx);
// snapshot.anchor_before(Point::new(1, 2))
// });
// let provider = Arc::new(TestCompletionProvider::new());
// let codegen = cx.add_model(|cx| {
// Codegen::new(
// buffer.clone(),
// CodegenKind::Generate { position },
// provider.clone(),
// cx,
// )
// });
// codegen.update(cx, |codegen, cx| codegen.start(Default::default(), cx));
let mut new_text = concat!(
"let mut x = 0;\n",
"while x < 10 {\n",
" x += 1;\n",
"}", //
);
while !new_text.is_empty() {
let max_len = cmp::min(new_text.len(), 10);
let len = rng.gen_range(1..=max_len);
let (chunk, suffix) = new_text.split_at(len);
provider.send_completion(chunk);
new_text = suffix;
deterministic.run_until_parked();
}
provider.finish_completion();
deterministic.run_until_parked();
// let mut new_text = concat!(
// "let mut x = 0;\n",
// "while x < 10 {\n",
// " x += 1;\n",
// "}", //
// );
// while !new_text.is_empty() {
// let max_len = cmp::min(new_text.len(), 10);
// let len = rng.gen_range(1..=max_len);
// let (chunk, suffix) = new_text.split_at(len);
// provider.send_completion(chunk);
// new_text = suffix;
// deterministic.run_until_parked();
// }
// provider.finish_completion();
// deterministic.run_until_parked();
assert_eq!(
buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()),
indoc! {"
fn main() {
let mut x = 0;
while x < 10 {
x += 1;
}
}
"}
);
}
// assert_eq!(
// buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()),
// indoc! {"
// fn main() {
// let mut x = 0;
// while x < 10 {
// x += 1;
// }
// }
// "}
// );
// }
#[gpui::test]
async fn test_strip_markdown_codeblock() {
assert_eq!(
strip_markdown_codeblock(chunks("Lorem ipsum dolor", 2))
.map(|chunk| chunk.unwrap())
.collect::<String>()
.await,
"Lorem ipsum dolor"
);
assert_eq!(
strip_markdown_codeblock(chunks("```\nLorem ipsum dolor", 2))
.map(|chunk| chunk.unwrap())
.collect::<String>()
.await,
"Lorem ipsum dolor"
);
assert_eq!(
strip_markdown_codeblock(chunks("```\nLorem ipsum dolor\n```", 2))
.map(|chunk| chunk.unwrap())
.collect::<String>()
.await,
"Lorem ipsum dolor"
);
assert_eq!(
strip_markdown_codeblock(chunks("```\nLorem ipsum dolor\n```\n", 2))
.map(|chunk| chunk.unwrap())
.collect::<String>()
.await,
"Lorem ipsum dolor"
);
assert_eq!(
strip_markdown_codeblock(chunks("```html\n```js\nLorem ipsum dolor\n```\n```", 2))
.map(|chunk| chunk.unwrap())
.collect::<String>()
.await,
"```js\nLorem ipsum dolor\n```"
);
assert_eq!(
strip_markdown_codeblock(chunks("``\nLorem ipsum dolor\n```", 2))
.map(|chunk| chunk.unwrap())
.collect::<String>()
.await,
"``\nLorem ipsum dolor\n```"
);
// #[gpui::test]
// async fn test_strip_markdown_codeblock() {
// assert_eq!(
// strip_markdown_codeblock(chunks("Lorem ipsum dolor", 2))
// .map(|chunk| chunk.unwrap())
// .collect::<String>()
// .await,
// "Lorem ipsum dolor"
// );
// assert_eq!(
// strip_markdown_codeblock(chunks("```\nLorem ipsum dolor", 2))
// .map(|chunk| chunk.unwrap())
// .collect::<String>()
// .await,
// "Lorem ipsum dolor"
// );
// assert_eq!(
// strip_markdown_codeblock(chunks("```\nLorem ipsum dolor\n```", 2))
// .map(|chunk| chunk.unwrap())
// .collect::<String>()
// .await,
// "Lorem ipsum dolor"
// );
// assert_eq!(
// strip_markdown_codeblock(chunks("```\nLorem ipsum dolor\n```\n", 2))
// .map(|chunk| chunk.unwrap())
// .collect::<String>()
// .await,
// "Lorem ipsum dolor"
// );
// assert_eq!(
// strip_markdown_codeblock(chunks("```html\n```js\nLorem ipsum dolor\n```\n```", 2))
// .map(|chunk| chunk.unwrap())
// .collect::<String>()
// .await,
// "```js\nLorem ipsum dolor\n```"
// );
// assert_eq!(
// strip_markdown_codeblock(chunks("``\nLorem ipsum dolor\n```", 2))
// .map(|chunk| chunk.unwrap())
// .collect::<String>()
// .await,
// "``\nLorem ipsum dolor\n```"
// );
fn chunks(text: &str, size: usize) -> impl Stream<Item = Result<String>> {
stream::iter(
text.chars()
.collect::<Vec<_>>()
.chunks(size)
.map(|chunk| Ok(chunk.iter().collect::<String>()))
.collect::<Vec<_>>(),
)
}
}
// fn chunks(text: &str, size: usize) -> impl Stream<Item = Result<String>> {
// stream::iter(
// text.chars()
// .collect::<Vec<_>>()
// .chunks(size)
// .map(|chunk| Ok(chunk.iter().collect::<String>()))
// .collect::<Vec<_>>(),
// )
// }
// }
struct TestCompletionProvider {
last_completion_tx: Mutex<Option<mpsc::Sender<String>>>,
}
// struct TestCompletionProvider {
// last_completion_tx: Mutex<Option<mpsc::Sender<String>>>,
// }
impl TestCompletionProvider {
fn new() -> Self {
Self {
last_completion_tx: Mutex::new(None),
}
}
// impl TestCompletionProvider {
// fn new() -> Self {
// Self {
// last_completion_tx: Mutex::new(None),
// }
// }
fn send_completion(&self, completion: impl Into<String>) {
let mut tx = self.last_completion_tx.lock();
tx.as_mut().unwrap().try_send(completion.into()).unwrap();
}
// fn send_completion(&self, completion: impl Into<String>) {
// let mut tx = self.last_completion_tx.lock();
// tx.as_mut().unwrap().try_send(completion.into()).unwrap();
// }
fn finish_completion(&self) {
self.last_completion_tx.lock().take().unwrap();
}
}
// fn finish_completion(&self) {
// self.last_completion_tx.lock().take().unwrap();
// }
// }
impl CompletionProvider for TestCompletionProvider {
fn complete(
&self,
_prompt: OpenAIRequest,
) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
let (tx, rx) = mpsc::channel(1);
*self.last_completion_tx.lock() = Some(tx);
async move { Ok(rx.map(|rx| Ok(rx)).boxed()) }.boxed()
}
}
// impl CompletionProvider for TestCompletionProvider {
// fn complete(
// &self,
// _prompt: OpenAIRequest,
// ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
// let (tx, rx) = mpsc::channel(1);
// *self.last_completion_tx.lock() = Some(tx);
// async move { Ok(rx.map(|rx| Ok(rx)).boxed()) }.boxed()
// }
// }
fn rust_lang() -> Language {
Language::new(
LanguageConfig {
name: "Rust".into(),
path_suffixes: vec!["rs".to_string()],
..Default::default()
},
Some(tree_sitter_rust::language()),
)
.with_indents_query(
r#"
(call_expression) @indent
(field_expression) @indent
(_ "(" ")" @end) @indent
(_ "{" "}" @end) @indent
"#,
)
.unwrap()
}
}
// fn rust_lang() -> Language {
// Language::new(
// LanguageConfig {
// name: "Rust".into(),
// path_suffixes: vec!["rs".to_string()],
// ..Default::default()
// },
// Some(tree_sitter_rust::language()),
// )
// .with_indents_query(
// r#"
// (call_expression) @indent
// (field_expression) @indent
// (_ "(" ")" @end) @indent
// (_ "{" "}" @end) @indent
// "#,
// )
// .unwrap()
// }
// }

View file

@ -1,4 +1,4 @@
use gpui::AppContext;
use gpui::{AppContext, AsyncAppContext};
use language::{BufferSnapshot, OffsetRangeExt, ToOffset};
use std::cmp;
use std::ops::Range;
@ -83,28 +83,28 @@ fn outline_for_prompt(
Some(text)
}
pub fn generate_content_prompt(
pub fn generate_codegen_planning_prompt(
user_prompt: String,
language_name: Option<&str>,
buffer: &BufferSnapshot,
range: Range<language::Anchor>,
cx: &AppContext,
kind: CodegenKind,
) -> String {
) -> (String, Option<String>) {
let mut prompt = String::new();
// General Preamble
if let Some(language_name) = language_name {
writeln!(prompt, "You're an expert {language_name} engineer.\n").unwrap();
} else {
writeln!(prompt, "You're an expert engineer.\n").unwrap();
writeln!(prompt, "You're an expert software engineer.\n").unwrap();
}
let outline = outline_for_prompt(buffer, range.clone(), cx);
if let Some(outline) = outline {
if let Some(outline) = outline.clone() {
writeln!(
prompt,
"The file you are currently working on has the following outline:"
"You're currently working inside the Zed editor on a file with the following outline:"
)
.unwrap();
if let Some(language_name) = language_name {
@ -115,15 +115,71 @@ pub fn generate_content_prompt(
}
}
// Assume for now that we are just generating
if range.clone().start == range.end {
writeln!(prompt, "In particular, the user's cursor is current on the '<|START|>' span in the above outline, with no text selected.").unwrap();
match kind {
CodegenKind::Generate { position: _ } => {
writeln!(prompt, "In particular, the user's cursor is currently on the '<|START|>' span in the above outline, with no text selected.").unwrap();
}
CodegenKind::Transform { range: _ } => {
writeln!(prompt, "In particular, the user has selected a section of the text between the '<|START|' and '|END|>' spans.").unwrap();
}
}
writeln!(
prompt,
"The user has provided the following prompt: '{user_prompt}'\n"
)
.unwrap();
writeln!(
prompt,
"It is your task to identify if any additional context is needed from the repository"
)
.unwrap();
(prompt, outline)
}
pub fn generate_content_prompt(
user_prompt: String,
language_name: Option<String>,
outline: Option<String>,
kind: CodegenKind,
snippet: Vec<String>,
) -> String {
let mut prompt = String::new();
// General Preamble
if let Some(language_name) = language_name.clone() {
writeln!(prompt, "You're an expert {language_name} engineer.\n").unwrap();
} else {
writeln!(prompt, "In particular, the user has selected a section of the text between the '<|START|' and '|END|>' spans.").unwrap();
writeln!(prompt, "You're an expert software engineer.\n").unwrap();
}
if snippet.len() > 0 {
writeln!(
prompt,
"Here are a few snippets from the codebase which may help: "
);
}
for snip in snippet {
writeln!(prompt, "{snip}");
}
if let Some(outline) = outline {
writeln!(
prompt,
"The file you are currently working on has the following outline:"
)
.unwrap();
if let Some(language_name) = language_name.clone() {
let language_name = language_name.to_lowercase();
writeln!(prompt, "```{language_name}\n{outline}\n```").unwrap();
} else {
writeln!(prompt, "```\n{outline}\n```").unwrap();
}
}
match kind {
CodegenKind::Generate { position: _ } => {
writeln!(prompt, "In particular, the user's cursor is current on the '<|START|>' span in the above outline, with no text selected.").unwrap();
writeln!(
prompt,
"Assume the cursor is located where the `<|START|` marker is."
@ -141,6 +197,7 @@ pub fn generate_content_prompt(
.unwrap();
}
CodegenKind::Transform { range: _ } => {
writeln!(prompt, "In particular, the user has selected a section of the text between the '<|START|' and '|END|>' spans.").unwrap();
writeln!(
prompt,
"Modify the users code selected text based upon the users prompt: {user_prompt}"

View file

@ -2,6 +2,7 @@ mod db;
mod embedding_queue;
mod parsing;
pub mod semantic_index_settings;
pub mod skills;
#[cfg(test)]
mod semantic_index_tests;

View file

@ -0,0 +1,106 @@
use ai::function_calling::OpenAIFunction;
use anyhow::anyhow;
use gpui::{AppContext, AsyncAppContext, ModelHandle};
use project::Project;
use serde::{Serialize, Serializer};
use serde_json::json;
use std::fmt::Write;
use crate::SemanticIndex;
#[derive(Clone)]
pub struct RepositoryContextRetriever {
index: ModelHandle<SemanticIndex>,
project: ModelHandle<Project>,
}
impl RepositoryContextRetriever {
pub fn load(index: ModelHandle<SemanticIndex>, project: ModelHandle<Project>) -> Self {
Self { index, project }
}
pub async fn complete_test(
&self,
arguments: serde_json::Value,
cx: &mut AsyncAppContext,
) -> anyhow::Result<String> {
let queries = arguments.get("queries").unwrap().as_array().unwrap();
let mut prompt = String::new();
let query = queries
.iter()
.map(|query| query.to_string())
.collect::<Vec<String>>()
.join(";");
let project = self.project.clone();
let results = self
.index
.update(cx, |this, cx| {
this.search_project(project, query, 10, vec![], vec![], cx)
})
.await?;
for result in results {
result.buffer.read_with(cx, |buffer, cx| {
let text = buffer.text_for_range(result.range).collect::<String>();
let file_path = buffer.file().unwrap().path().to_string_lossy();
let language = buffer.language();
writeln!(
prompt,
"The following is a relevant snippet from file ({}):",
file_path
)
.unwrap();
if let Some(language) = language {
writeln!(prompt, "```{}\n{text}\n```", language.name().to_lowercase()).unwrap();
} else {
writeln!(prompt, "```\n{text}\n```").unwrap();
}
});
}
Ok(prompt)
}
}
impl OpenAIFunction for RepositoryContextRetriever {
fn name(&self) -> String {
"retrieve_context_from_repository".to_string()
}
fn description(&self) -> String {
"Retrieve relevant content from repository with natural language".to_string()
}
fn system_prompt(&self) -> String {
"'retrieve_context_from_repository'
If more information is needed from the repository, to complete the users prompt reliably, pass up to 3 queries describing pieces of code or text you would like additional context upon.
Do not make these queries general about programming, include very specific lexical references to the pieces of code you need more information on.
We are passing these into a semantic similarity retrieval engine, with all the information in the current codebase included.
As such, these should be phrased as descriptions of code of interest as opposed to questions".to_string()
}
fn parameters(&self) -> serde_json::Value {
json!({
"type": "object",
"properties": {
"queries": {
"title": "queries",
"type": "array",
"items": {"type": "string"}
}
},
"required": ["queries"]
})
}
fn complete(&self, arguments: serde_json::Value) -> anyhow::Result<String> {
todo!();
}
}
impl Serialize for RepositoryContextRetriever {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: Serializer,
{
json!({"name": self.name(),
"description": self.description(),
"parameters": self.parameters()})
.serialize(serializer)
}
}