Compare commits
3 commits
main
...
semantic_c
Author | SHA1 | Date | |
---|---|---|---|
![]() |
a29b70552f | ||
![]() |
3cc499ee7c | ||
![]() |
e21a92284f |
13 changed files with 971 additions and 394 deletions
3
Cargo.lock
generated
3
Cargo.lock
generated
|
@ -93,6 +93,7 @@ dependencies = [
|
|||
"anyhow",
|
||||
"async-trait",
|
||||
"bincode",
|
||||
"erased-serde",
|
||||
"futures 0.3.28",
|
||||
"gpui",
|
||||
"isahc",
|
||||
|
@ -103,6 +104,7 @@ dependencies = [
|
|||
"parking_lot 0.11.2",
|
||||
"parse_duration",
|
||||
"postage",
|
||||
"project",
|
||||
"rand 0.8.5",
|
||||
"regex",
|
||||
"rusqlite",
|
||||
|
@ -321,6 +323,7 @@ dependencies = [
|
|||
"regex",
|
||||
"schemars",
|
||||
"search",
|
||||
"semantic_index",
|
||||
"serde",
|
||||
"serde_json",
|
||||
"settings",
|
||||
|
|
|
@ -11,6 +11,7 @@ doctest = false
|
|||
[dependencies]
|
||||
gpui = { path = "../gpui" }
|
||||
util = { path = "../util" }
|
||||
project = { path = "../project" }
|
||||
async-trait.workspace = true
|
||||
anyhow.workspace = true
|
||||
futures.workspace = true
|
||||
|
@ -29,6 +30,7 @@ tiktoken-rs = "0.5.0"
|
|||
matrixmultiply = "0.3.7"
|
||||
rusqlite = { version = "0.27.0", features = ["blob", "array", "modern_sqlite"] }
|
||||
bincode = "1.3.3"
|
||||
erased-serde = "0.3"
|
||||
|
||||
[dev-dependencies]
|
||||
gpui = { path = "../gpui", features = ["test-support"] }
|
||||
|
|
|
@ -1,2 +1,50 @@
|
|||
pub mod completion;
|
||||
pub mod embedding;
|
||||
pub mod function_calling;
|
||||
pub mod skills;
|
||||
|
||||
use core::fmt;
|
||||
use std::fmt::Display;
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
#[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)]
|
||||
#[serde(rename_all = "lowercase")]
|
||||
pub enum Role {
|
||||
User,
|
||||
Assistant,
|
||||
System,
|
||||
}
|
||||
|
||||
impl Role {
|
||||
pub fn cycle(&mut self) {
|
||||
*self = match self {
|
||||
Role::User => Role::Assistant,
|
||||
Role::Assistant => Role::System,
|
||||
Role::System => Role::User,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Display for Role {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
match self {
|
||||
Role::User => write!(f, "User"),
|
||||
Role::Assistant => write!(f, "Assistant"),
|
||||
Role::System => write!(f, "System"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq, Clone)]
|
||||
pub struct RequestMessage {
|
||||
pub role: Role,
|
||||
pub content: String,
|
||||
}
|
||||
|
||||
#[derive(Deserialize, Debug)]
|
||||
pub struct OpenAIUsage {
|
||||
pub prompt_tokens: u32,
|
||||
pub completion_tokens: u32,
|
||||
pub total_tokens: u32,
|
||||
}
|
||||
|
|
|
@ -1,3 +1,4 @@
|
|||
use crate::{OpenAIUsage, RequestMessage, Role};
|
||||
use anyhow::{anyhow, Result};
|
||||
use futures::{
|
||||
future::BoxFuture, io::BufReader, stream::BoxStream, AsyncBufReadExt, AsyncReadExt, FutureExt,
|
||||
|
@ -6,48 +7,10 @@ use futures::{
|
|||
use gpui::executor::Background;
|
||||
use isahc::{http::StatusCode, Request, RequestExt};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::{
|
||||
fmt::{self, Display},
|
||||
io,
|
||||
sync::Arc,
|
||||
};
|
||||
use std::{io, sync::Arc};
|
||||
|
||||
pub const OPENAI_API_URL: &'static str = "https://api.openai.com/v1";
|
||||
|
||||
#[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)]
|
||||
#[serde(rename_all = "lowercase")]
|
||||
pub enum Role {
|
||||
User,
|
||||
Assistant,
|
||||
System,
|
||||
}
|
||||
|
||||
impl Role {
|
||||
pub fn cycle(&mut self) {
|
||||
*self = match self {
|
||||
Role::User => Role::Assistant,
|
||||
Role::Assistant => Role::System,
|
||||
Role::System => Role::User,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Display for Role {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
match self {
|
||||
Role::User => write!(f, "User"),
|
||||
Role::Assistant => write!(f, "Assistant"),
|
||||
Role::System => write!(f, "System"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
|
||||
pub struct RequestMessage {
|
||||
pub role: Role,
|
||||
pub content: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Default, Serialize)]
|
||||
pub struct OpenAIRequest {
|
||||
pub model: String,
|
||||
|
@ -61,13 +24,6 @@ pub struct ResponseMessage {
|
|||
pub content: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Deserialize, Debug)]
|
||||
pub struct OpenAIUsage {
|
||||
pub prompt_tokens: u32,
|
||||
pub completion_tokens: u32,
|
||||
pub total_tokens: u32,
|
||||
}
|
||||
|
||||
#[derive(Deserialize, Debug)]
|
||||
pub struct ChatChoiceDelta {
|
||||
pub index: u32,
|
||||
|
|
176
crates/ai/src/function_calling.rs
Normal file
176
crates/ai/src/function_calling.rs
Normal file
|
@ -0,0 +1,176 @@
|
|||
use crate::{OpenAIUsage, RequestMessage, Role};
|
||||
use anyhow::anyhow;
|
||||
use erased_serde::serialize_trait_object;
|
||||
use futures::AsyncReadExt;
|
||||
use isahc::{http::StatusCode, Request, RequestExt};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::fmt::Write;
|
||||
|
||||
const OPENAI_API_URL: &'static str = "https://api.openai.com/v1";
|
||||
|
||||
pub trait OpenAIFunction: erased_serde::Serialize {
|
||||
fn name(&self) -> String;
|
||||
fn description(&self) -> String;
|
||||
fn system_prompt(&self) -> String;
|
||||
fn parameters(&self) -> serde_json::Value;
|
||||
fn complete(&self, arguments: serde_json::Value) -> anyhow::Result<String>;
|
||||
}
|
||||
serialize_trait_object!(OpenAIFunction);
|
||||
|
||||
#[derive(Serialize)]
|
||||
struct OpenAIFunctionCallingRequest {
|
||||
model: String,
|
||||
messages: Vec<RequestMessage>,
|
||||
functions: Vec<Box<dyn OpenAIFunction>>,
|
||||
temperature: f32,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct FunctionCall {
|
||||
pub name: String,
|
||||
pub arguments: String,
|
||||
}
|
||||
|
||||
impl FunctionCall {
|
||||
fn arguments(&self) -> anyhow::Result<serde_json::Value> {
|
||||
serde_json::from_str(&self.arguments).map_err(|err| anyhow!(err))
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct FunctionCallingMessage {
|
||||
pub role: Option<Role>,
|
||||
pub content: Option<String>,
|
||||
pub function_call: FunctionCall,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct FunctionCallingChoice {
|
||||
pub message: FunctionCallingMessage,
|
||||
pub finish_reason: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct OpenAIFunctionCallingResponse {
|
||||
pub id: Option<String>,
|
||||
pub object: String,
|
||||
pub created: u32,
|
||||
pub model: String,
|
||||
pub choices: Vec<FunctionCallingChoice>,
|
||||
pub usage: OpenAIUsage,
|
||||
}
|
||||
|
||||
impl OpenAIFunctionCallingResponse {
|
||||
fn get_details(&self) -> anyhow::Result<FunctionCallDetails> {
|
||||
if let Some(choice) = self.choices.first() {
|
||||
let name = choice.message.function_call.name.clone();
|
||||
let arguments = choice.message.function_call.arguments()?;
|
||||
|
||||
Ok(FunctionCallDetails {
|
||||
name,
|
||||
arguments,
|
||||
message: choice.message.content.clone(),
|
||||
})
|
||||
} else {
|
||||
Err(anyhow!("no function call details available"))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct FunctionCallDetails {
|
||||
pub name: String, // name of function to call
|
||||
pub message: Option<String>, // message if provided
|
||||
pub arguments: serde_json::Value, // json object respresenting provided arguments
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct OpenAIFunctionCallingProvider {
|
||||
api_key: String,
|
||||
}
|
||||
|
||||
impl OpenAIFunctionCallingProvider {
|
||||
pub fn new(api_key: String) -> Self {
|
||||
Self { api_key }
|
||||
}
|
||||
|
||||
fn generate_system_message(
|
||||
&self,
|
||||
messages: &Vec<RequestMessage>,
|
||||
functions: &Vec<Box<dyn OpenAIFunction>>,
|
||||
) -> RequestMessage {
|
||||
let mut system_message = messages
|
||||
.iter()
|
||||
.filter_map(|message| {
|
||||
if message.role == Role::System {
|
||||
Some(message.content.clone())
|
||||
} else {
|
||||
None
|
||||
}
|
||||
})
|
||||
.collect::<Vec<String>>()
|
||||
.join("\n");
|
||||
|
||||
let function_string = functions
|
||||
.iter()
|
||||
.map(|function| format!("'{}'", function.name()))
|
||||
.collect::<Vec<String>>()
|
||||
.join(",");
|
||||
|
||||
writeln!(
|
||||
system_message,
|
||||
"You have access to the following functions: {function_string} you MUST return a function calling response using at one of the below functions."
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
for function in functions {
|
||||
writeln!(system_message, "\n{}", function.system_prompt()).unwrap();
|
||||
}
|
||||
|
||||
RequestMessage {
|
||||
role: Role::System,
|
||||
content: system_message,
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn complete(
|
||||
&self,
|
||||
model: String,
|
||||
mut messages: Vec<RequestMessage>,
|
||||
functions: Vec<Box<dyn OpenAIFunction>>,
|
||||
) -> anyhow::Result<FunctionCallDetails> {
|
||||
// TODO: Rename all this.
|
||||
let mut system_message = vec![self.generate_system_message(&messages, &functions)];
|
||||
messages.retain(|message| message.role != Role::System);
|
||||
|
||||
system_message.extend(messages);
|
||||
// Lower temperature values, result in less randomness,
|
||||
// this is helping keep the function calling consistent
|
||||
let request = OpenAIFunctionCallingRequest {
|
||||
model,
|
||||
messages: system_message,
|
||||
functions,
|
||||
temperature: 0.0,
|
||||
};
|
||||
|
||||
let json_data = serde_json::to_string(&request)?;
|
||||
println!("\nREQUEST: {:?}\n", &json_data);
|
||||
let mut response = Request::post(format!("{OPENAI_API_URL}/chat/completions"))
|
||||
.header("Content-Type", "application/json")
|
||||
.header("Authorization", format!("Bearer {}", self.api_key))
|
||||
.body(json_data)?
|
||||
.send_async()
|
||||
.await?;
|
||||
|
||||
let mut body = String::new();
|
||||
response.body_mut().read_to_string(&mut body).await?;
|
||||
println!("\nRESPONSE: {:?}\n", &body);
|
||||
match response.status() {
|
||||
StatusCode::OK => {
|
||||
let response_data: OpenAIFunctionCallingResponse = serde_json::from_str(&body)?;
|
||||
response_data.get_details()
|
||||
}
|
||||
_ => Err(anyhow!("open ai function calling failed: {:?}", body)),
|
||||
}
|
||||
}
|
||||
}
|
50
crates/ai/src/skills.rs
Normal file
50
crates/ai/src/skills.rs
Normal file
|
@ -0,0 +1,50 @@
|
|||
use crate::function_calling::OpenAIFunction;
|
||||
use gpui::{AppContext, ModelHandle};
|
||||
use project::Project;
|
||||
use serde::{Serialize, Serializer};
|
||||
use serde_json::json;
|
||||
|
||||
pub struct RewritePrompt;
|
||||
impl Serialize for RewritePrompt {
|
||||
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
|
||||
where
|
||||
S: Serializer,
|
||||
{
|
||||
json!({"name": self.name(),
|
||||
"description": self.description(),
|
||||
"parameters": self.parameters()})
|
||||
.serialize(serializer)
|
||||
}
|
||||
}
|
||||
|
||||
impl RewritePrompt {
|
||||
pub fn load() -> Self {
|
||||
Self {}
|
||||
}
|
||||
}
|
||||
|
||||
impl OpenAIFunction for RewritePrompt {
|
||||
fn name(&self) -> String {
|
||||
"rewrite_prompt".to_string()
|
||||
}
|
||||
fn description(&self) -> String {
|
||||
"Rewrite prompt given prompt from user".to_string()
|
||||
}
|
||||
fn system_prompt(&self) -> String {
|
||||
"'rewrite_prompt':
|
||||
If all information is available in the above prompt, and you need no further information.
|
||||
Rewrite the entire prompt to clarify what should be generated, do not actually complete the users request.
|
||||
Assume this rewritten message will be passed to another completion agent, to fulfill the users request.".to_string()
|
||||
}
|
||||
fn parameters(&self) -> serde_json::Value {
|
||||
json!({
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"prompt": {}
|
||||
}
|
||||
})
|
||||
}
|
||||
fn complete(&self, arguments: serde_json::Value) -> anyhow::Result<String> {
|
||||
Ok(arguments.get("prompt").unwrap().to_string())
|
||||
}
|
||||
}
|
|
@ -23,6 +23,8 @@ theme = { path = "../theme" }
|
|||
util = { path = "../util" }
|
||||
uuid = { version = "1.1.2", features = ["v4"] }
|
||||
workspace = { path = "../workspace" }
|
||||
semantic_index = { path = "../semantic_index" }
|
||||
project = { path = "../project" }
|
||||
|
||||
anyhow.workspace = true
|
||||
chrono = { version = "0.4", features = ["serde"] }
|
||||
|
|
|
@ -4,7 +4,7 @@ mod codegen;
|
|||
mod prompts;
|
||||
mod streaming_diff;
|
||||
|
||||
use ai::completion::Role;
|
||||
use ai::Role;
|
||||
use anyhow::Result;
|
||||
pub use assistant_panel::AssistantPanel;
|
||||
use assistant_settings::OpenAIModel;
|
||||
|
|
|
@ -1,13 +1,16 @@
|
|||
use crate::{
|
||||
assistant_settings::{AssistantDockPosition, AssistantSettings, OpenAIModel},
|
||||
codegen::{self, Codegen, CodegenKind},
|
||||
prompts::generate_content_prompt,
|
||||
prompts::{generate_codegen_planning_prompt, generate_content_prompt},
|
||||
MessageId, MessageMetadata, MessageStatus, Role, SavedConversation, SavedConversationMetadata,
|
||||
SavedMessage,
|
||||
};
|
||||
use ai::completion::{
|
||||
stream_completion, OpenAICompletionProvider, OpenAIRequest, RequestMessage, OPENAI_API_URL,
|
||||
use ai::{
|
||||
completion::{stream_completion, OpenAICompletionProvider, OpenAIRequest, OPENAI_API_URL},
|
||||
function_calling::OpenAIFunctionCallingProvider,
|
||||
skills::RewritePrompt,
|
||||
};
|
||||
use ai::{function_calling::OpenAIFunction, RequestMessage};
|
||||
use anyhow::{anyhow, Result};
|
||||
use chrono::{DateTime, Local};
|
||||
use client::{telemetry::AssistantKind, ClickhouseEvent, TelemetrySettings};
|
||||
|
@ -35,7 +38,9 @@ use gpui::{
|
|||
WindowContext,
|
||||
};
|
||||
use language::{language_settings::SoftWrap, Buffer, LanguageRegistry, ToOffset as _};
|
||||
use project::Project;
|
||||
use search::BufferSearchBar;
|
||||
use semantic_index::{skills::RepositoryContextRetriever, SemanticIndex};
|
||||
use settings::SettingsStore;
|
||||
use std::{
|
||||
cell::{Cell, RefCell},
|
||||
|
@ -145,6 +150,8 @@ pub struct AssistantPanel {
|
|||
include_conversation_in_next_inline_assist: bool,
|
||||
inline_prompt_history: VecDeque<String>,
|
||||
_watch_saved_conversations: Task<Result<()>>,
|
||||
semantic_index: ModelHandle<SemanticIndex>,
|
||||
project: ModelHandle<Project>,
|
||||
}
|
||||
|
||||
impl AssistantPanel {
|
||||
|
@ -154,6 +161,7 @@ impl AssistantPanel {
|
|||
workspace: WeakViewHandle<Workspace>,
|
||||
cx: AsyncAppContext,
|
||||
) -> Task<Result<ViewHandle<Self>>> {
|
||||
let index = cx.read(|cx| SemanticIndex::global(cx).unwrap());
|
||||
cx.spawn(|mut cx| async move {
|
||||
let fs = workspace.read_with(&cx, |workspace, _| workspace.app_state().fs.clone())?;
|
||||
let saved_conversations = SavedConversationMetadata::list(fs.clone())
|
||||
|
@ -191,6 +199,9 @@ impl AssistantPanel {
|
|||
toolbar.add_item(cx.add_view(|cx| BufferSearchBar::new(cx)), cx);
|
||||
toolbar
|
||||
});
|
||||
|
||||
let project = workspace.project().clone();
|
||||
|
||||
let mut this = Self {
|
||||
workspace: workspace_handle,
|
||||
active_editor_index: Default::default(),
|
||||
|
@ -215,6 +226,8 @@ impl AssistantPanel {
|
|||
include_conversation_in_next_inline_assist: false,
|
||||
inline_prompt_history: Default::default(),
|
||||
_watch_saved_conversations,
|
||||
semantic_index: index,
|
||||
project,
|
||||
};
|
||||
|
||||
let mut old_dock_position = this.position(cx);
|
||||
|
@ -277,9 +290,10 @@ impl AssistantPanel {
|
|||
let inline_assist_id = post_inc(&mut self.next_inline_assist_id);
|
||||
let snapshot = editor.read(cx).buffer().read(cx).snapshot(cx);
|
||||
let provider = Arc::new(OpenAICompletionProvider::new(
|
||||
api_key,
|
||||
api_key.clone(),
|
||||
cx.background().clone(),
|
||||
));
|
||||
let fc_provider = OpenAIFunctionCallingProvider::new(api_key);
|
||||
let selection = editor.read(cx).selections.newest_anchor().clone();
|
||||
let codegen_kind = if editor.read(cx).selections.newest::<usize>(cx).is_empty() {
|
||||
CodegenKind::Generate {
|
||||
|
@ -290,8 +304,18 @@ impl AssistantPanel {
|
|||
range: selection.start..selection.end,
|
||||
}
|
||||
};
|
||||
|
||||
let project = self.project.clone();
|
||||
|
||||
let codegen = cx.add_model(|cx| {
|
||||
Codegen::new(editor.read(cx).buffer().clone(), codegen_kind, provider, cx)
|
||||
Codegen::new(
|
||||
editor.read(cx).buffer().clone(),
|
||||
codegen_kind,
|
||||
provider,
|
||||
fc_provider,
|
||||
cx,
|
||||
project.clone(),
|
||||
)
|
||||
});
|
||||
|
||||
let measurements = Rc::new(Cell::new(BlockMeasurements::default()));
|
||||
|
@ -573,42 +597,74 @@ impl AssistantPanel {
|
|||
let language_name = language_name.as_deref();
|
||||
|
||||
let codegen_kind = pending_assist.codegen.read(cx).kind().clone();
|
||||
let prompt = generate_content_prompt(
|
||||
user_prompt.to_string(),
|
||||
language_name,
|
||||
&snapshot,
|
||||
language_range,
|
||||
cx,
|
||||
codegen_kind,
|
||||
);
|
||||
let index = self.semantic_index.clone();
|
||||
|
||||
let mut messages = Vec::new();
|
||||
let mut model = settings::get::<AssistantSettings>(cx)
|
||||
.default_open_ai_model
|
||||
.clone();
|
||||
if let Some(conversation) = conversation {
|
||||
let conversation = conversation.read(cx);
|
||||
let buffer = conversation.buffer.read(cx);
|
||||
messages.extend(
|
||||
conversation
|
||||
.messages(cx)
|
||||
.map(|message| message.to_open_ai_message(buffer)),
|
||||
);
|
||||
model = conversation.model.clone();
|
||||
}
|
||||
|
||||
messages.push(RequestMessage {
|
||||
role: Role::User,
|
||||
content: prompt,
|
||||
pending_assist.codegen.update(cx, |codegen, cx| {
|
||||
codegen.start(
|
||||
user_prompt.to_string(),
|
||||
cx,
|
||||
language_name,
|
||||
snapshot,
|
||||
language_range.clone(),
|
||||
codegen_kind.clone(),
|
||||
index,
|
||||
)
|
||||
});
|
||||
let request = OpenAIRequest {
|
||||
model: model.full_name().into(),
|
||||
messages,
|
||||
stream: true,
|
||||
};
|
||||
pending_assist
|
||||
.codegen
|
||||
.update(cx, |codegen, cx| codegen.start(request, cx));
|
||||
|
||||
// let api_key = self.api_key.as_ref().clone().into_inner().clone().unwrap();
|
||||
// let function_provider = OpenAIFunctionCallingProvider::new(api_key);
|
||||
|
||||
// let planning_messages = vec![RequestMessage {
|
||||
// role: Role::User,
|
||||
// content: planning_prompt,
|
||||
// }];
|
||||
|
||||
// println!("GETTING HERE");
|
||||
|
||||
// let function_call = cx
|
||||
// .spawn(|this, mut cx| async move {
|
||||
// let result = function_provider
|
||||
// .complete("gpt-4".to_string(), planning_messages, functions)
|
||||
// .await;
|
||||
// dbg!(&result);
|
||||
// result
|
||||
// })
|
||||
// .detach();
|
||||
|
||||
// let function_name = function_call.name.as_str();
|
||||
// let prompt = match function_name {
|
||||
// "rewrite_prompt" => {
|
||||
// let user_prompt = RewritePrompt::load()
|
||||
// .complete(function_call.arguments)
|
||||
// .unwrap();
|
||||
// generate_content_prompt(
|
||||
// user_prompt.to_string(),
|
||||
// language_name,
|
||||
// &snapshot,
|
||||
// language_range,
|
||||
// cx,
|
||||
// codegen_kind,
|
||||
// )
|
||||
// }
|
||||
// _ => {
|
||||
// todo!();
|
||||
// }
|
||||
// };
|
||||
|
||||
// let mut messages = Vec::new();
|
||||
// let mut model = settings::get::<AssistantSettings>(cx)
|
||||
// .default_open_ai_model
|
||||
// .clone();
|
||||
// if let Some(conversation) = conversation {
|
||||
// let conversation = conversation.read(cx);
|
||||
// let buffer = conversation.buffer.read(cx);
|
||||
// messages.extend(
|
||||
// conversation
|
||||
// .messages(cx)
|
||||
// .map(|message| message.to_open_ai_message(buffer)),
|
||||
// );
|
||||
// model = conversation.model.clone();
|
||||
// }
|
||||
}
|
||||
|
||||
fn update_highlights_for_editor(
|
||||
|
|
|
@ -1,12 +1,22 @@
|
|||
use crate::streaming_diff::{Hunk, StreamingDiff};
|
||||
use ai::completion::{CompletionProvider, OpenAIRequest};
|
||||
use crate::{
|
||||
prompts::{generate_codegen_planning_prompt, generate_content_prompt},
|
||||
streaming_diff::{Hunk, StreamingDiff},
|
||||
};
|
||||
use ai::{
|
||||
completion::{CompletionProvider, OpenAIRequest},
|
||||
function_calling::{OpenAIFunction, OpenAIFunctionCallingProvider},
|
||||
skills::RewritePrompt,
|
||||
RequestMessage, Role,
|
||||
};
|
||||
use anyhow::Result;
|
||||
use editor::{
|
||||
multi_buffer, Anchor, AnchorRangeExt, MultiBuffer, MultiBufferSnapshot, ToOffset, ToPoint,
|
||||
};
|
||||
use futures::{channel::mpsc, SinkExt, Stream, StreamExt};
|
||||
use gpui::{Entity, ModelContext, ModelHandle, Task};
|
||||
use language::{Rope, TransactionId};
|
||||
use gpui::{BorrowAppContext, Entity, ModelContext, ModelHandle, Task};
|
||||
use language::{BufferSnapshot, Rope, TransactionId};
|
||||
use project::Project;
|
||||
use semantic_index::{skills::RepositoryContextRetriever, SemanticIndex};
|
||||
use std::{cmp, future, ops::Range, sync::Arc};
|
||||
|
||||
pub enum Event {
|
||||
|
@ -22,6 +32,7 @@ pub enum CodegenKind {
|
|||
|
||||
pub struct Codegen {
|
||||
provider: Arc<dyn CompletionProvider>,
|
||||
fc_provider: OpenAIFunctionCallingProvider,
|
||||
buffer: ModelHandle<MultiBuffer>,
|
||||
snapshot: MultiBufferSnapshot,
|
||||
kind: CodegenKind,
|
||||
|
@ -31,6 +42,7 @@ pub struct Codegen {
|
|||
generation: Task<()>,
|
||||
idle: bool,
|
||||
_subscription: gpui::Subscription,
|
||||
project: ModelHandle<Project>,
|
||||
}
|
||||
|
||||
impl Entity for Codegen {
|
||||
|
@ -42,7 +54,9 @@ impl Codegen {
|
|||
buffer: ModelHandle<MultiBuffer>,
|
||||
mut kind: CodegenKind,
|
||||
provider: Arc<dyn CompletionProvider>,
|
||||
fc_provider: OpenAIFunctionCallingProvider,
|
||||
cx: &mut ModelContext<Self>,
|
||||
project: ModelHandle<Project>,
|
||||
) -> Self {
|
||||
let snapshot = buffer.read(cx).snapshot(cx);
|
||||
match &mut kind {
|
||||
|
@ -62,6 +76,7 @@ impl Codegen {
|
|||
|
||||
Self {
|
||||
provider,
|
||||
fc_provider,
|
||||
buffer: buffer.clone(),
|
||||
snapshot,
|
||||
kind,
|
||||
|
@ -71,6 +86,7 @@ impl Codegen {
|
|||
idle: true,
|
||||
generation: Task::ready(()),
|
||||
_subscription: cx.subscribe(&buffer, Self::handle_buffer_event),
|
||||
project,
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -112,7 +128,17 @@ impl Codegen {
|
|||
self.error.as_ref()
|
||||
}
|
||||
|
||||
pub fn start(&mut self, prompt: OpenAIRequest, cx: &mut ModelContext<Self>) {
|
||||
pub fn start(
|
||||
&mut self,
|
||||
prompt: String,
|
||||
cx: &mut ModelContext<Self>,
|
||||
language_name: Option<&str>,
|
||||
buffer: BufferSnapshot,
|
||||
range: Range<language::Anchor>,
|
||||
kind: CodegenKind,
|
||||
index: ModelHandle<SemanticIndex>,
|
||||
) {
|
||||
let language_range = range.clone();
|
||||
let range = self.range();
|
||||
let snapshot = self.snapshot.clone();
|
||||
let selected_text = snapshot
|
||||
|
@ -126,9 +152,101 @@ impl Codegen {
|
|||
.next()
|
||||
.unwrap_or_else(|| snapshot.indent_size_for_line(selection_start.row));
|
||||
|
||||
let response = self.provider.complete(prompt);
|
||||
let messages = vec![RequestMessage {
|
||||
role: Role::User,
|
||||
content: prompt.clone(),
|
||||
}];
|
||||
|
||||
let request = OpenAIRequest {
|
||||
model: "gpt-4".to_string(),
|
||||
messages: messages.clone(),
|
||||
stream: true,
|
||||
};
|
||||
|
||||
let (planning_prompt, outline) = generate_codegen_planning_prompt(
|
||||
prompt.clone(),
|
||||
language_name.clone(),
|
||||
&buffer,
|
||||
language_range.clone(),
|
||||
cx,
|
||||
kind.clone(),
|
||||
);
|
||||
|
||||
let project = self.project.clone();
|
||||
|
||||
self.generation = cx.spawn_weak(|this, mut cx| {
|
||||
// Plan Ahead
|
||||
let planning_messages = vec![RequestMessage {
|
||||
role: Role::User,
|
||||
content: planning_prompt,
|
||||
}];
|
||||
|
||||
let repo_retriever = RepositoryContextRetriever::load(index, project);
|
||||
let functions: Vec<Box<dyn OpenAIFunction>> = vec![
|
||||
Box::new(RewritePrompt::load()),
|
||||
Box::new(repo_retriever.clone()),
|
||||
];
|
||||
|
||||
let completion_provider = self.provider.clone();
|
||||
let fc_provider = self.fc_provider.clone();
|
||||
let language_name = language_name.clone();
|
||||
let language_name = if let Some(language_name) = language_name.clone() {
|
||||
Some(language_name.to_string())
|
||||
} else {
|
||||
None
|
||||
};
|
||||
let kind = kind.clone();
|
||||
async move {
|
||||
let mut user_prompt = prompt.clone();
|
||||
let user_prompt = if let Ok(function_call) = fc_provider
|
||||
.complete("gpt-4".to_string(), planning_messages, functions)
|
||||
.await
|
||||
{
|
||||
let function_name = function_call.name.as_str();
|
||||
println!("FUNCTION NAME: {:?}", function_name);
|
||||
let user_prompt = match function_name {
|
||||
"rewrite_prompt" => {
|
||||
let user_prompt = RewritePrompt::load()
|
||||
.complete(function_call.arguments)
|
||||
.unwrap();
|
||||
generate_content_prompt(
|
||||
user_prompt,
|
||||
language_name,
|
||||
outline,
|
||||
kind,
|
||||
vec![],
|
||||
)
|
||||
}
|
||||
_ => {
|
||||
let arguments = function_call.arguments.clone();
|
||||
let snippet = repo_retriever
|
||||
.complete_test(arguments, &mut cx)
|
||||
.await
|
||||
.unwrap();
|
||||
let snippet = vec![snippet];
|
||||
|
||||
generate_content_prompt(prompt, language_name, outline, kind, snippet)
|
||||
}
|
||||
};
|
||||
user_prompt
|
||||
} else {
|
||||
user_prompt
|
||||
};
|
||||
|
||||
println!("{:?}", user_prompt.clone());
|
||||
|
||||
let messages = vec![RequestMessage {
|
||||
role: Role::User,
|
||||
content: user_prompt.clone(),
|
||||
}];
|
||||
|
||||
let request = OpenAIRequest {
|
||||
model: "gpt-4".to_string(),
|
||||
messages: messages.clone(),
|
||||
stream: true,
|
||||
};
|
||||
|
||||
let response = completion_provider.complete(request);
|
||||
let generate = async {
|
||||
let mut edit_start = range.start.to_offset(&snapshot);
|
||||
|
||||
|
@ -349,315 +467,317 @@ fn strip_markdown_codeblock(
|
|||
})
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use futures::{
|
||||
future::BoxFuture,
|
||||
stream::{self, BoxStream},
|
||||
};
|
||||
use gpui::{executor::Deterministic, TestAppContext};
|
||||
use indoc::indoc;
|
||||
use language::{language_settings, tree_sitter_rust, Buffer, Language, LanguageConfig, Point};
|
||||
use parking_lot::Mutex;
|
||||
use rand::prelude::*;
|
||||
use settings::SettingsStore;
|
||||
use smol::future::FutureExt;
|
||||
// #[cfg(test)]
|
||||
// mod tests {
|
||||
// use super::*;
|
||||
// use futures::{
|
||||
// future::BoxFuture,
|
||||
// stream::{self, BoxStream},
|
||||
// };
|
||||
// use gpui::{executor::Deterministic, TestAppContext};
|
||||
// use indoc::indoc;
|
||||
// use language::{language_settings, tree_sitter_rust, Buffer, Language, LanguageConfig, Point};
|
||||
// use parking_lot::Mutex;
|
||||
// use rand::prelude::*;
|
||||
// use settings::SettingsStore;
|
||||
// use smol::future::FutureExt;
|
||||
|
||||
#[gpui::test(iterations = 10)]
|
||||
async fn test_transform_autoindent(
|
||||
cx: &mut TestAppContext,
|
||||
mut rng: StdRng,
|
||||
deterministic: Arc<Deterministic>,
|
||||
) {
|
||||
cx.set_global(cx.read(SettingsStore::test));
|
||||
cx.update(language_settings::init);
|
||||
// #[gpui::test(iterations = 10)]
|
||||
// async fn test_transform_autoindent(
|
||||
// cx: &mut TestAppContext,
|
||||
// mut rng: StdRng,
|
||||
// deterministic: Arc<Deterministic>,
|
||||
// ) {
|
||||
// cx.set_global(cx.read(SettingsStore::test));
|
||||
// cx.update(language_settings::init);
|
||||
|
||||
let text = indoc! {"
|
||||
fn main() {
|
||||
let x = 0;
|
||||
for _ in 0..10 {
|
||||
x += 1;
|
||||
}
|
||||
}
|
||||
"};
|
||||
let buffer =
|
||||
cx.add_model(|cx| Buffer::new(0, 0, text).with_language(Arc::new(rust_lang()), cx));
|
||||
let buffer = cx.add_model(|cx| MultiBuffer::singleton(buffer, cx));
|
||||
let range = buffer.read_with(cx, |buffer, cx| {
|
||||
let snapshot = buffer.snapshot(cx);
|
||||
snapshot.anchor_before(Point::new(1, 4))..snapshot.anchor_after(Point::new(4, 4))
|
||||
});
|
||||
let provider = Arc::new(TestCompletionProvider::new());
|
||||
let codegen = cx.add_model(|cx| {
|
||||
Codegen::new(
|
||||
buffer.clone(),
|
||||
CodegenKind::Transform { range },
|
||||
provider.clone(),
|
||||
cx,
|
||||
)
|
||||
});
|
||||
codegen.update(cx, |codegen, cx| codegen.start(Default::default(), cx));
|
||||
// let text = indoc! {"
|
||||
// fn main() {
|
||||
// let x = 0;
|
||||
// for _ in 0..10 {
|
||||
// x += 1;
|
||||
// }
|
||||
// }
|
||||
// "};
|
||||
// let buffer =
|
||||
// cx.add_model(|cx| Buffer::new(0, 0, text).with_language(Arc::new(rust_lang()), cx));
|
||||
// let buffer = cx.add_model(|cx| MultiBuffer::singleton(buffer, cx));
|
||||
// let range = buffer.read_with(cx, |buffer, cx| {
|
||||
// let snapshot = buffer.snapshot(cx);
|
||||
// snapshot.anchor_before(Point::new(1, 4))..snapshot.anchor_after(Point::new(4, 4))
|
||||
// });
|
||||
// let provider = Arc::new(TestCompletionProvider::new());
|
||||
// let fc_provider = OpenAIFunctionCallingProvider::new("".to_string());
|
||||
// let codegen = cx.add_model(|cx| {
|
||||
// Codegen::new(
|
||||
// buffer.clone(),
|
||||
// CodegenKind::Transform { range },
|
||||
// provider.clone(),
|
||||
// fc_provider,
|
||||
// cx,
|
||||
// )
|
||||
// });
|
||||
// codegen.update(cx, |codegen, cx| codegen.start(Default::default(), cx));
|
||||
|
||||
let mut new_text = concat!(
|
||||
" let mut x = 0;\n",
|
||||
" while x < 10 {\n",
|
||||
" x += 1;\n",
|
||||
" }",
|
||||
);
|
||||
while !new_text.is_empty() {
|
||||
let max_len = cmp::min(new_text.len(), 10);
|
||||
let len = rng.gen_range(1..=max_len);
|
||||
let (chunk, suffix) = new_text.split_at(len);
|
||||
provider.send_completion(chunk);
|
||||
new_text = suffix;
|
||||
deterministic.run_until_parked();
|
||||
}
|
||||
provider.finish_completion();
|
||||
deterministic.run_until_parked();
|
||||
// let mut new_text = concat!(
|
||||
// " let mut x = 0;\n",
|
||||
// " while x < 10 {\n",
|
||||
// " x += 1;\n",
|
||||
// " }",
|
||||
// );
|
||||
// while !new_text.is_empty() {
|
||||
// let max_len = cmp::min(new_text.len(), 10);
|
||||
// let len = rng.gen_range(1..=max_len);
|
||||
// let (chunk, suffix) = new_text.split_at(len);
|
||||
// provider.send_completion(chunk);
|
||||
// new_text = suffix;
|
||||
// deterministic.run_until_parked();
|
||||
// }
|
||||
// provider.finish_completion();
|
||||
// deterministic.run_until_parked();
|
||||
|
||||
assert_eq!(
|
||||
buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()),
|
||||
indoc! {"
|
||||
fn main() {
|
||||
let mut x = 0;
|
||||
while x < 10 {
|
||||
x += 1;
|
||||
}
|
||||
}
|
||||
"}
|
||||
);
|
||||
}
|
||||
// assert_eq!(
|
||||
// buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()),
|
||||
// indoc! {"
|
||||
// fn main() {
|
||||
// let mut x = 0;
|
||||
// while x < 10 {
|
||||
// x += 1;
|
||||
// }
|
||||
// }
|
||||
// "}
|
||||
// );
|
||||
// }
|
||||
|
||||
#[gpui::test(iterations = 10)]
|
||||
async fn test_autoindent_when_generating_past_indentation(
|
||||
cx: &mut TestAppContext,
|
||||
mut rng: StdRng,
|
||||
deterministic: Arc<Deterministic>,
|
||||
) {
|
||||
cx.set_global(cx.read(SettingsStore::test));
|
||||
cx.update(language_settings::init);
|
||||
// #[gpui::test(iterations = 10)]
|
||||
// async fn test_autoindent_when_generating_past_indentation(
|
||||
// cx: &mut TestAppContext,
|
||||
// mut rng: StdRng,
|
||||
// deterministic: Arc<Deterministic>,
|
||||
// ) {
|
||||
// cx.set_global(cx.read(SettingsStore::test));
|
||||
// cx.update(language_settings::init);
|
||||
|
||||
let text = indoc! {"
|
||||
fn main() {
|
||||
le
|
||||
}
|
||||
"};
|
||||
let buffer =
|
||||
cx.add_model(|cx| Buffer::new(0, 0, text).with_language(Arc::new(rust_lang()), cx));
|
||||
let buffer = cx.add_model(|cx| MultiBuffer::singleton(buffer, cx));
|
||||
let position = buffer.read_with(cx, |buffer, cx| {
|
||||
let snapshot = buffer.snapshot(cx);
|
||||
snapshot.anchor_before(Point::new(1, 6))
|
||||
});
|
||||
let provider = Arc::new(TestCompletionProvider::new());
|
||||
let codegen = cx.add_model(|cx| {
|
||||
Codegen::new(
|
||||
buffer.clone(),
|
||||
CodegenKind::Generate { position },
|
||||
provider.clone(),
|
||||
cx,
|
||||
)
|
||||
});
|
||||
codegen.update(cx, |codegen, cx| codegen.start(Default::default(), cx));
|
||||
// let text = indoc! {"
|
||||
// fn main() {
|
||||
// le
|
||||
// }
|
||||
// "};
|
||||
// let buffer =
|
||||
// cx.add_model(|cx| Buffer::new(0, 0, text).with_language(Arc::new(rust_lang()), cx));
|
||||
// let buffer = cx.add_model(|cx| MultiBuffer::singleton(buffer, cx));
|
||||
// let position = buffer.read_with(cx, |buffer, cx| {
|
||||
// let snapshot = buffer.snapshot(cx);
|
||||
// snapshot.anchor_before(Point::new(1, 6))
|
||||
// });
|
||||
// let provider = Arc::new(TestCompletionProvider::new());
|
||||
// let codegen = cx.add_model(|cx| {
|
||||
// Codegen::new(
|
||||
// buffer.clone(),
|
||||
// CodegenKind::Generate { position },
|
||||
// provider.clone(),
|
||||
// cx,
|
||||
// )
|
||||
// });
|
||||
// codegen.update(cx, |codegen, cx| codegen.start(Default::default(), cx));
|
||||
|
||||
let mut new_text = concat!(
|
||||
"t mut x = 0;\n",
|
||||
"while x < 10 {\n",
|
||||
" x += 1;\n",
|
||||
"}", //
|
||||
);
|
||||
while !new_text.is_empty() {
|
||||
let max_len = cmp::min(new_text.len(), 10);
|
||||
let len = rng.gen_range(1..=max_len);
|
||||
let (chunk, suffix) = new_text.split_at(len);
|
||||
provider.send_completion(chunk);
|
||||
new_text = suffix;
|
||||
deterministic.run_until_parked();
|
||||
}
|
||||
provider.finish_completion();
|
||||
deterministic.run_until_parked();
|
||||
// let mut new_text = concat!(
|
||||
// "t mut x = 0;\n",
|
||||
// "while x < 10 {\n",
|
||||
// " x += 1;\n",
|
||||
// "}", //
|
||||
// );
|
||||
// while !new_text.is_empty() {
|
||||
// let max_len = cmp::min(new_text.len(), 10);
|
||||
// let len = rng.gen_range(1..=max_len);
|
||||
// let (chunk, suffix) = new_text.split_at(len);
|
||||
// provider.send_completion(chunk);
|
||||
// new_text = suffix;
|
||||
// deterministic.run_until_parked();
|
||||
// }
|
||||
// provider.finish_completion();
|
||||
// deterministic.run_until_parked();
|
||||
|
||||
assert_eq!(
|
||||
buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()),
|
||||
indoc! {"
|
||||
fn main() {
|
||||
let mut x = 0;
|
||||
while x < 10 {
|
||||
x += 1;
|
||||
}
|
||||
}
|
||||
"}
|
||||
);
|
||||
}
|
||||
// assert_eq!(
|
||||
// buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()),
|
||||
// indoc! {"
|
||||
// fn main() {
|
||||
// let mut x = 0;
|
||||
// while x < 10 {
|
||||
// x += 1;
|
||||
// }
|
||||
// }
|
||||
// "}
|
||||
// );
|
||||
// }
|
||||
|
||||
#[gpui::test(iterations = 10)]
|
||||
async fn test_autoindent_when_generating_before_indentation(
|
||||
cx: &mut TestAppContext,
|
||||
mut rng: StdRng,
|
||||
deterministic: Arc<Deterministic>,
|
||||
) {
|
||||
cx.set_global(cx.read(SettingsStore::test));
|
||||
cx.update(language_settings::init);
|
||||
// #[gpui::test(iterations = 10)]
|
||||
// async fn test_autoindent_when_generating_before_indentation(
|
||||
// cx: &mut TestAppContext,
|
||||
// mut rng: StdRng,
|
||||
// deterministic: Arc<Deterministic>,
|
||||
// ) {
|
||||
// cx.set_global(cx.read(SettingsStore::test));
|
||||
// cx.update(language_settings::init);
|
||||
|
||||
let text = concat!(
|
||||
"fn main() {\n",
|
||||
" \n",
|
||||
"}\n" //
|
||||
);
|
||||
let buffer =
|
||||
cx.add_model(|cx| Buffer::new(0, 0, text).with_language(Arc::new(rust_lang()), cx));
|
||||
let buffer = cx.add_model(|cx| MultiBuffer::singleton(buffer, cx));
|
||||
let position = buffer.read_with(cx, |buffer, cx| {
|
||||
let snapshot = buffer.snapshot(cx);
|
||||
snapshot.anchor_before(Point::new(1, 2))
|
||||
});
|
||||
let provider = Arc::new(TestCompletionProvider::new());
|
||||
let codegen = cx.add_model(|cx| {
|
||||
Codegen::new(
|
||||
buffer.clone(),
|
||||
CodegenKind::Generate { position },
|
||||
provider.clone(),
|
||||
cx,
|
||||
)
|
||||
});
|
||||
codegen.update(cx, |codegen, cx| codegen.start(Default::default(), cx));
|
||||
// let text = concat!(
|
||||
// "fn main() {\n",
|
||||
// " \n",
|
||||
// "}\n" //
|
||||
// );
|
||||
// let buffer =
|
||||
// cx.add_model(|cx| Buffer::new(0, 0, text).with_language(Arc::new(rust_lang()), cx));
|
||||
// let buffer = cx.add_model(|cx| MultiBuffer::singleton(buffer, cx));
|
||||
// let position = buffer.read_with(cx, |buffer, cx| {
|
||||
// let snapshot = buffer.snapshot(cx);
|
||||
// snapshot.anchor_before(Point::new(1, 2))
|
||||
// });
|
||||
// let provider = Arc::new(TestCompletionProvider::new());
|
||||
// let codegen = cx.add_model(|cx| {
|
||||
// Codegen::new(
|
||||
// buffer.clone(),
|
||||
// CodegenKind::Generate { position },
|
||||
// provider.clone(),
|
||||
// cx,
|
||||
// )
|
||||
// });
|
||||
// codegen.update(cx, |codegen, cx| codegen.start(Default::default(), cx));
|
||||
|
||||
let mut new_text = concat!(
|
||||
"let mut x = 0;\n",
|
||||
"while x < 10 {\n",
|
||||
" x += 1;\n",
|
||||
"}", //
|
||||
);
|
||||
while !new_text.is_empty() {
|
||||
let max_len = cmp::min(new_text.len(), 10);
|
||||
let len = rng.gen_range(1..=max_len);
|
||||
let (chunk, suffix) = new_text.split_at(len);
|
||||
provider.send_completion(chunk);
|
||||
new_text = suffix;
|
||||
deterministic.run_until_parked();
|
||||
}
|
||||
provider.finish_completion();
|
||||
deterministic.run_until_parked();
|
||||
// let mut new_text = concat!(
|
||||
// "let mut x = 0;\n",
|
||||
// "while x < 10 {\n",
|
||||
// " x += 1;\n",
|
||||
// "}", //
|
||||
// );
|
||||
// while !new_text.is_empty() {
|
||||
// let max_len = cmp::min(new_text.len(), 10);
|
||||
// let len = rng.gen_range(1..=max_len);
|
||||
// let (chunk, suffix) = new_text.split_at(len);
|
||||
// provider.send_completion(chunk);
|
||||
// new_text = suffix;
|
||||
// deterministic.run_until_parked();
|
||||
// }
|
||||
// provider.finish_completion();
|
||||
// deterministic.run_until_parked();
|
||||
|
||||
assert_eq!(
|
||||
buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()),
|
||||
indoc! {"
|
||||
fn main() {
|
||||
let mut x = 0;
|
||||
while x < 10 {
|
||||
x += 1;
|
||||
}
|
||||
}
|
||||
"}
|
||||
);
|
||||
}
|
||||
// assert_eq!(
|
||||
// buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()),
|
||||
// indoc! {"
|
||||
// fn main() {
|
||||
// let mut x = 0;
|
||||
// while x < 10 {
|
||||
// x += 1;
|
||||
// }
|
||||
// }
|
||||
// "}
|
||||
// );
|
||||
// }
|
||||
|
||||
#[gpui::test]
|
||||
async fn test_strip_markdown_codeblock() {
|
||||
assert_eq!(
|
||||
strip_markdown_codeblock(chunks("Lorem ipsum dolor", 2))
|
||||
.map(|chunk| chunk.unwrap())
|
||||
.collect::<String>()
|
||||
.await,
|
||||
"Lorem ipsum dolor"
|
||||
);
|
||||
assert_eq!(
|
||||
strip_markdown_codeblock(chunks("```\nLorem ipsum dolor", 2))
|
||||
.map(|chunk| chunk.unwrap())
|
||||
.collect::<String>()
|
||||
.await,
|
||||
"Lorem ipsum dolor"
|
||||
);
|
||||
assert_eq!(
|
||||
strip_markdown_codeblock(chunks("```\nLorem ipsum dolor\n```", 2))
|
||||
.map(|chunk| chunk.unwrap())
|
||||
.collect::<String>()
|
||||
.await,
|
||||
"Lorem ipsum dolor"
|
||||
);
|
||||
assert_eq!(
|
||||
strip_markdown_codeblock(chunks("```\nLorem ipsum dolor\n```\n", 2))
|
||||
.map(|chunk| chunk.unwrap())
|
||||
.collect::<String>()
|
||||
.await,
|
||||
"Lorem ipsum dolor"
|
||||
);
|
||||
assert_eq!(
|
||||
strip_markdown_codeblock(chunks("```html\n```js\nLorem ipsum dolor\n```\n```", 2))
|
||||
.map(|chunk| chunk.unwrap())
|
||||
.collect::<String>()
|
||||
.await,
|
||||
"```js\nLorem ipsum dolor\n```"
|
||||
);
|
||||
assert_eq!(
|
||||
strip_markdown_codeblock(chunks("``\nLorem ipsum dolor\n```", 2))
|
||||
.map(|chunk| chunk.unwrap())
|
||||
.collect::<String>()
|
||||
.await,
|
||||
"``\nLorem ipsum dolor\n```"
|
||||
);
|
||||
// #[gpui::test]
|
||||
// async fn test_strip_markdown_codeblock() {
|
||||
// assert_eq!(
|
||||
// strip_markdown_codeblock(chunks("Lorem ipsum dolor", 2))
|
||||
// .map(|chunk| chunk.unwrap())
|
||||
// .collect::<String>()
|
||||
// .await,
|
||||
// "Lorem ipsum dolor"
|
||||
// );
|
||||
// assert_eq!(
|
||||
// strip_markdown_codeblock(chunks("```\nLorem ipsum dolor", 2))
|
||||
// .map(|chunk| chunk.unwrap())
|
||||
// .collect::<String>()
|
||||
// .await,
|
||||
// "Lorem ipsum dolor"
|
||||
// );
|
||||
// assert_eq!(
|
||||
// strip_markdown_codeblock(chunks("```\nLorem ipsum dolor\n```", 2))
|
||||
// .map(|chunk| chunk.unwrap())
|
||||
// .collect::<String>()
|
||||
// .await,
|
||||
// "Lorem ipsum dolor"
|
||||
// );
|
||||
// assert_eq!(
|
||||
// strip_markdown_codeblock(chunks("```\nLorem ipsum dolor\n```\n", 2))
|
||||
// .map(|chunk| chunk.unwrap())
|
||||
// .collect::<String>()
|
||||
// .await,
|
||||
// "Lorem ipsum dolor"
|
||||
// );
|
||||
// assert_eq!(
|
||||
// strip_markdown_codeblock(chunks("```html\n```js\nLorem ipsum dolor\n```\n```", 2))
|
||||
// .map(|chunk| chunk.unwrap())
|
||||
// .collect::<String>()
|
||||
// .await,
|
||||
// "```js\nLorem ipsum dolor\n```"
|
||||
// );
|
||||
// assert_eq!(
|
||||
// strip_markdown_codeblock(chunks("``\nLorem ipsum dolor\n```", 2))
|
||||
// .map(|chunk| chunk.unwrap())
|
||||
// .collect::<String>()
|
||||
// .await,
|
||||
// "``\nLorem ipsum dolor\n```"
|
||||
// );
|
||||
|
||||
fn chunks(text: &str, size: usize) -> impl Stream<Item = Result<String>> {
|
||||
stream::iter(
|
||||
text.chars()
|
||||
.collect::<Vec<_>>()
|
||||
.chunks(size)
|
||||
.map(|chunk| Ok(chunk.iter().collect::<String>()))
|
||||
.collect::<Vec<_>>(),
|
||||
)
|
||||
}
|
||||
}
|
||||
// fn chunks(text: &str, size: usize) -> impl Stream<Item = Result<String>> {
|
||||
// stream::iter(
|
||||
// text.chars()
|
||||
// .collect::<Vec<_>>()
|
||||
// .chunks(size)
|
||||
// .map(|chunk| Ok(chunk.iter().collect::<String>()))
|
||||
// .collect::<Vec<_>>(),
|
||||
// )
|
||||
// }
|
||||
// }
|
||||
|
||||
struct TestCompletionProvider {
|
||||
last_completion_tx: Mutex<Option<mpsc::Sender<String>>>,
|
||||
}
|
||||
// struct TestCompletionProvider {
|
||||
// last_completion_tx: Mutex<Option<mpsc::Sender<String>>>,
|
||||
// }
|
||||
|
||||
impl TestCompletionProvider {
|
||||
fn new() -> Self {
|
||||
Self {
|
||||
last_completion_tx: Mutex::new(None),
|
||||
}
|
||||
}
|
||||
// impl TestCompletionProvider {
|
||||
// fn new() -> Self {
|
||||
// Self {
|
||||
// last_completion_tx: Mutex::new(None),
|
||||
// }
|
||||
// }
|
||||
|
||||
fn send_completion(&self, completion: impl Into<String>) {
|
||||
let mut tx = self.last_completion_tx.lock();
|
||||
tx.as_mut().unwrap().try_send(completion.into()).unwrap();
|
||||
}
|
||||
// fn send_completion(&self, completion: impl Into<String>) {
|
||||
// let mut tx = self.last_completion_tx.lock();
|
||||
// tx.as_mut().unwrap().try_send(completion.into()).unwrap();
|
||||
// }
|
||||
|
||||
fn finish_completion(&self) {
|
||||
self.last_completion_tx.lock().take().unwrap();
|
||||
}
|
||||
}
|
||||
// fn finish_completion(&self) {
|
||||
// self.last_completion_tx.lock().take().unwrap();
|
||||
// }
|
||||
// }
|
||||
|
||||
impl CompletionProvider for TestCompletionProvider {
|
||||
fn complete(
|
||||
&self,
|
||||
_prompt: OpenAIRequest,
|
||||
) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
|
||||
let (tx, rx) = mpsc::channel(1);
|
||||
*self.last_completion_tx.lock() = Some(tx);
|
||||
async move { Ok(rx.map(|rx| Ok(rx)).boxed()) }.boxed()
|
||||
}
|
||||
}
|
||||
// impl CompletionProvider for TestCompletionProvider {
|
||||
// fn complete(
|
||||
// &self,
|
||||
// _prompt: OpenAIRequest,
|
||||
// ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
|
||||
// let (tx, rx) = mpsc::channel(1);
|
||||
// *self.last_completion_tx.lock() = Some(tx);
|
||||
// async move { Ok(rx.map(|rx| Ok(rx)).boxed()) }.boxed()
|
||||
// }
|
||||
// }
|
||||
|
||||
/// Builds a minimal Rust `Language` with an indentation query, used by the
/// autoindent tests above. Panics if the query fails to compile.
fn rust_lang() -> Language {
    Language::new(
        LanguageConfig {
            name: "Rust".into(),
            path_suffixes: vec!["rs".to_string()],
            ..Default::default()
        },
        Some(tree_sitter_rust::language()),
    )
    .with_indents_query(
        r#"
        (call_expression) @indent
        (field_expression) @indent
        (_ "(" ")" @end) @indent
        (_ "{" "}" @end) @indent
        "#,
    )
    .unwrap()
}
|
||||
}
|
||||
// fn rust_lang() -> Language {
|
||||
// Language::new(
|
||||
// LanguageConfig {
|
||||
// name: "Rust".into(),
|
||||
// path_suffixes: vec!["rs".to_string()],
|
||||
// ..Default::default()
|
||||
// },
|
||||
// Some(tree_sitter_rust::language()),
|
||||
// )
|
||||
// .with_indents_query(
|
||||
// r#"
|
||||
// (call_expression) @indent
|
||||
// (field_expression) @indent
|
||||
// (_ "(" ")" @end) @indent
|
||||
// (_ "{" "}" @end) @indent
|
||||
// "#,
|
||||
// )
|
||||
// .unwrap()
|
||||
// }
|
||||
// }
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
use gpui::AppContext;
|
||||
use gpui::{AppContext, AsyncAppContext};
|
||||
use language::{BufferSnapshot, OffsetRangeExt, ToOffset};
|
||||
use std::cmp;
|
||||
use std::ops::Range;
|
||||
|
@ -83,28 +83,28 @@ fn outline_for_prompt(
|
|||
Some(text)
|
||||
}
|
||||
|
||||
pub fn generate_content_prompt(
|
||||
pub fn generate_codegen_planning_prompt(
|
||||
user_prompt: String,
|
||||
language_name: Option<&str>,
|
||||
buffer: &BufferSnapshot,
|
||||
range: Range<language::Anchor>,
|
||||
cx: &AppContext,
|
||||
kind: CodegenKind,
|
||||
) -> String {
|
||||
) -> (String, Option<String>) {
|
||||
let mut prompt = String::new();
|
||||
|
||||
// General Preamble
|
||||
if let Some(language_name) = language_name {
|
||||
writeln!(prompt, "You're an expert {language_name} engineer.\n").unwrap();
|
||||
} else {
|
||||
writeln!(prompt, "You're an expert engineer.\n").unwrap();
|
||||
writeln!(prompt, "You're an expert software engineer.\n").unwrap();
|
||||
}
|
||||
|
||||
let outline = outline_for_prompt(buffer, range.clone(), cx);
|
||||
if let Some(outline) = outline {
|
||||
if let Some(outline) = outline.clone() {
|
||||
writeln!(
|
||||
prompt,
|
||||
"The file you are currently working on has the following outline:"
|
||||
"You're currently working inside the Zed editor on a file with the following outline:"
|
||||
)
|
||||
.unwrap();
|
||||
if let Some(language_name) = language_name {
|
||||
|
@ -115,15 +115,71 @@ pub fn generate_content_prompt(
|
|||
}
|
||||
}
|
||||
|
||||
// Assume for now that we are just generating
|
||||
if range.clone().start == range.end {
|
||||
writeln!(prompt, "In particular, the user's cursor is current on the '<|START|>' span in the above outline, with no text selected.").unwrap();
|
||||
match kind {
|
||||
CodegenKind::Generate { position: _ } => {
|
||||
writeln!(prompt, "In particular, the user's cursor is currently on the '<|START|>' span in the above outline, with no text selected.").unwrap();
|
||||
}
|
||||
CodegenKind::Transform { range: _ } => {
|
||||
writeln!(prompt, "In particular, the user has selected a section of the text between the '<|START|' and '|END|>' spans.").unwrap();
|
||||
}
|
||||
}
|
||||
|
||||
writeln!(
|
||||
prompt,
|
||||
"The user has provided the following prompt: '{user_prompt}'\n"
|
||||
)
|
||||
.unwrap();
|
||||
writeln!(
|
||||
prompt,
|
||||
"It is your task to identify if any additional context is needed from the repository"
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
(prompt, outline)
|
||||
}
|
||||
pub fn generate_content_prompt(
|
||||
user_prompt: String,
|
||||
language_name: Option<String>,
|
||||
outline: Option<String>,
|
||||
kind: CodegenKind,
|
||||
snippet: Vec<String>,
|
||||
) -> String {
|
||||
let mut prompt = String::new();
|
||||
|
||||
// General Preamble
|
||||
if let Some(language_name) = language_name.clone() {
|
||||
writeln!(prompt, "You're an expert {language_name} engineer.\n").unwrap();
|
||||
} else {
|
||||
writeln!(prompt, "In particular, the user has selected a section of the text between the '<|START|' and '|END|>' spans.").unwrap();
|
||||
writeln!(prompt, "You're an expert software engineer.\n").unwrap();
|
||||
}
|
||||
|
||||
if snippet.len() > 0 {
|
||||
writeln!(
|
||||
prompt,
|
||||
"Here are a few snippets from the codebase which may help: "
|
||||
);
|
||||
}
|
||||
for snip in snippet {
|
||||
writeln!(prompt, "{snip}");
|
||||
}
|
||||
|
||||
if let Some(outline) = outline {
|
||||
writeln!(
|
||||
prompt,
|
||||
"The file you are currently working on has the following outline:"
|
||||
)
|
||||
.unwrap();
|
||||
if let Some(language_name) = language_name.clone() {
|
||||
let language_name = language_name.to_lowercase();
|
||||
writeln!(prompt, "```{language_name}\n{outline}\n```").unwrap();
|
||||
} else {
|
||||
writeln!(prompt, "```\n{outline}\n```").unwrap();
|
||||
}
|
||||
}
|
||||
|
||||
match kind {
|
||||
CodegenKind::Generate { position: _ } => {
|
||||
writeln!(prompt, "In particular, the user's cursor is current on the '<|START|>' span in the above outline, with no text selected.").unwrap();
|
||||
writeln!(
|
||||
prompt,
|
||||
"Assume the cursor is located where the `<|START|` marker is."
|
||||
|
@ -141,6 +197,7 @@ pub fn generate_content_prompt(
|
|||
.unwrap();
|
||||
}
|
||||
CodegenKind::Transform { range: _ } => {
|
||||
writeln!(prompt, "In particular, the user has selected a section of the text between the '<|START|' and '|END|>' spans.").unwrap();
|
||||
writeln!(
|
||||
prompt,
|
||||
"Modify the users code selected text based upon the users prompt: {user_prompt}"
|
||||
|
|
|
@ -2,6 +2,7 @@ mod db;
|
|||
mod embedding_queue;
|
||||
mod parsing;
|
||||
pub mod semantic_index_settings;
|
||||
pub mod skills;
|
||||
|
||||
#[cfg(test)]
|
||||
mod semantic_index_tests;
|
||||
|
|
106
crates/semantic_index/src/skills.rs
Normal file
106
crates/semantic_index/src/skills.rs
Normal file
|
@ -0,0 +1,106 @@
|
|||
use ai::function_calling::OpenAIFunction;
|
||||
use anyhow::anyhow;
|
||||
use gpui::{AppContext, AsyncAppContext, ModelHandle};
|
||||
use project::Project;
|
||||
use serde::{Serialize, Serializer};
|
||||
use serde_json::json;
|
||||
use std::fmt::Write;
|
||||
|
||||
use crate::SemanticIndex;
|
||||
|
||||
/// An OpenAI function-calling skill that retrieves snippets from the
/// project's codebase via the semantic index.
#[derive(Clone)]
pub struct RepositoryContextRetriever {
    /// Semantic index used to run the searches.
    index: ModelHandle<SemanticIndex>,
    /// Project whose contents are searched.
    project: ModelHandle<Project>,
}
|
||||
|
||||
impl RepositoryContextRetriever {
|
||||
pub fn load(index: ModelHandle<SemanticIndex>, project: ModelHandle<Project>) -> Self {
|
||||
Self { index, project }
|
||||
}
|
||||
pub async fn complete_test(
|
||||
&self,
|
||||
arguments: serde_json::Value,
|
||||
cx: &mut AsyncAppContext,
|
||||
) -> anyhow::Result<String> {
|
||||
let queries = arguments.get("queries").unwrap().as_array().unwrap();
|
||||
let mut prompt = String::new();
|
||||
let query = queries
|
||||
.iter()
|
||||
.map(|query| query.to_string())
|
||||
.collect::<Vec<String>>()
|
||||
.join(";");
|
||||
let project = self.project.clone();
|
||||
let results = self
|
||||
.index
|
||||
.update(cx, |this, cx| {
|
||||
this.search_project(project, query, 10, vec![], vec![], cx)
|
||||
})
|
||||
.await?;
|
||||
|
||||
for result in results {
|
||||
result.buffer.read_with(cx, |buffer, cx| {
|
||||
let text = buffer.text_for_range(result.range).collect::<String>();
|
||||
let file_path = buffer.file().unwrap().path().to_string_lossy();
|
||||
let language = buffer.language();
|
||||
|
||||
writeln!(
|
||||
prompt,
|
||||
"The following is a relevant snippet from file ({}):",
|
||||
file_path
|
||||
)
|
||||
.unwrap();
|
||||
if let Some(language) = language {
|
||||
writeln!(prompt, "```{}\n{text}\n```", language.name().to_lowercase()).unwrap();
|
||||
} else {
|
||||
writeln!(prompt, "```\n{text}\n```").unwrap();
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
Ok(prompt)
|
||||
}
|
||||
}
|
||||
|
||||
impl OpenAIFunction for RepositoryContextRetriever {
|
||||
fn name(&self) -> String {
|
||||
"retrieve_context_from_repository".to_string()
|
||||
}
|
||||
fn description(&self) -> String {
|
||||
"Retrieve relevant content from repository with natural language".to_string()
|
||||
}
|
||||
fn system_prompt(&self) -> String {
|
||||
"'retrieve_context_from_repository'
|
||||
If more information is needed from the repository, to complete the users prompt reliably, pass up to 3 queries describing pieces of code or text you would like additional context upon.
|
||||
Do not make these queries general about programming, include very specific lexical references to the pieces of code you need more information on.
|
||||
We are passing these into a semantic similarity retrieval engine, with all the information in the current codebase included.
|
||||
As such, these should be phrased as descriptions of code of interest as opposed to questions".to_string()
|
||||
}
|
||||
fn parameters(&self) -> serde_json::Value {
|
||||
json!({
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"queries": {
|
||||
"title": "queries",
|
||||
"type": "array",
|
||||
"items": {"type": "string"}
|
||||
}
|
||||
},
|
||||
"required": ["queries"]
|
||||
})
|
||||
}
|
||||
fn complete(&self, arguments: serde_json::Value) -> anyhow::Result<String> {
|
||||
todo!();
|
||||
}
|
||||
}
|
||||
impl Serialize for RepositoryContextRetriever {
|
||||
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
|
||||
where
|
||||
S: Serializer,
|
||||
{
|
||||
json!({"name": self.name(),
|
||||
"description": self.description(),
|
||||
"parameters": self.parameters()})
|
||||
.serialize(serializer)
|
||||
}
|
||||
}
|
Loading…
Add table
Add a link
Reference in a new issue