Introduce Context Retrieval in Inline Assistant (#3097)
This PR introduces a new Inline Assistant feature "Retrieve Context", to dynamically fill the content in your generation prompt based on relevant results returned from the Semantic Search for the Prompt. Release Notes: - Introduce "Retrieve Context" button in Inline Assistant
This commit is contained in:
commit
2795091f0c
8 changed files with 734 additions and 139 deletions
|
@ -85,25 +85,6 @@ impl Embedding {
|
|||
}
|
||||
}
|
||||
|
||||
// impl FromSql for Embedding {
|
||||
// fn column_result(value: ValueRef) -> FromSqlResult<Self> {
|
||||
// let bytes = value.as_blob()?;
|
||||
// let embedding: Result<Vec<f32>, Box<bincode::ErrorKind>> = bincode::deserialize(bytes);
|
||||
// if embedding.is_err() {
|
||||
// return Err(rusqlite::types::FromSqlError::Other(embedding.unwrap_err()));
|
||||
// }
|
||||
// Ok(Embedding(embedding.unwrap()))
|
||||
// }
|
||||
// }
|
||||
|
||||
// impl ToSql for Embedding {
|
||||
// fn to_sql(&self) -> rusqlite::Result<ToSqlOutput> {
|
||||
// let bytes = bincode::serialize(&self.0)
|
||||
// .map_err(|err| rusqlite::Error::ToSqlConversionFailure(Box::new(err)))?;
|
||||
// Ok(ToSqlOutput::Owned(rusqlite::types::Value::Blob(bytes)))
|
||||
// }
|
||||
// }
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct OpenAIEmbeddings {
|
||||
pub client: Arc<dyn HttpClient>,
|
||||
|
@ -300,6 +281,7 @@ impl EmbeddingProvider for OpenAIEmbeddings {
|
|||
request_timeout,
|
||||
)
|
||||
.await?;
|
||||
|
||||
request_number += 1;
|
||||
|
||||
match response.status() {
|
||||
|
|
|
@ -22,8 +22,11 @@ settings = { path = "../settings" }
|
|||
theme = { path = "../theme" }
|
||||
util = { path = "../util" }
|
||||
workspace = { path = "../workspace" }
|
||||
uuid.workspace = true
|
||||
semantic_index = { path = "../semantic_index" }
|
||||
project = { path = "../project" }
|
||||
|
||||
uuid.workspace = true
|
||||
log.workspace = true
|
||||
anyhow.workspace = true
|
||||
chrono = { version = "0.4", features = ["serde"] }
|
||||
futures.workspace = true
|
||||
|
@ -36,7 +39,7 @@ schemars.workspace = true
|
|||
serde.workspace = true
|
||||
serde_json.workspace = true
|
||||
smol.workspace = true
|
||||
tiktoken-rs = "0.4"
|
||||
tiktoken-rs = "0.5"
|
||||
|
||||
[dev-dependencies]
|
||||
editor = { path = "../editor", features = ["test-support"] }
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
use crate::{
|
||||
assistant_settings::{AssistantDockPosition, AssistantSettings, OpenAIModel},
|
||||
codegen::{self, Codegen, CodegenKind},
|
||||
prompts::generate_content_prompt,
|
||||
prompts::{generate_content_prompt, PromptCodeSnippet},
|
||||
MessageId, MessageMetadata, MessageStatus, Role, SavedConversation, SavedConversationMetadata,
|
||||
SavedMessage,
|
||||
};
|
||||
|
@ -29,13 +29,15 @@ use gpui::{
|
|||
},
|
||||
fonts::HighlightStyle,
|
||||
geometry::vector::{vec2f, Vector2F},
|
||||
platform::{CursorStyle, MouseButton},
|
||||
platform::{CursorStyle, MouseButton, PromptLevel},
|
||||
Action, AnyElement, AppContext, AsyncAppContext, ClipboardItem, Element, Entity, ModelContext,
|
||||
ModelHandle, SizeConstraint, Subscription, Task, View, ViewContext, ViewHandle, WeakViewHandle,
|
||||
WindowContext,
|
||||
ModelHandle, SizeConstraint, Subscription, Task, View, ViewContext, ViewHandle,
|
||||
WeakModelHandle, WeakViewHandle, WindowContext,
|
||||
};
|
||||
use language::{language_settings::SoftWrap, Buffer, LanguageRegistry, ToOffset as _};
|
||||
use project::Project;
|
||||
use search::BufferSearchBar;
|
||||
use semantic_index::{SemanticIndex, SemanticIndexStatus};
|
||||
use settings::SettingsStore;
|
||||
use std::{
|
||||
cell::{Cell, RefCell},
|
||||
|
@ -46,7 +48,7 @@ use std::{
|
|||
path::{Path, PathBuf},
|
||||
rc::Rc,
|
||||
sync::Arc,
|
||||
time::Duration,
|
||||
time::{Duration, Instant},
|
||||
};
|
||||
use theme::{
|
||||
components::{action_button::Button, ComponentExt},
|
||||
|
@ -72,6 +74,7 @@ actions!(
|
|||
ResetKey,
|
||||
InlineAssist,
|
||||
ToggleIncludeConversation,
|
||||
ToggleRetrieveContext,
|
||||
]
|
||||
);
|
||||
|
||||
|
@ -108,6 +111,7 @@ pub fn init(cx: &mut AppContext) {
|
|||
cx.add_action(InlineAssistant::confirm);
|
||||
cx.add_action(InlineAssistant::cancel);
|
||||
cx.add_action(InlineAssistant::toggle_include_conversation);
|
||||
cx.add_action(InlineAssistant::toggle_retrieve_context);
|
||||
cx.add_action(InlineAssistant::move_up);
|
||||
cx.add_action(InlineAssistant::move_down);
|
||||
}
|
||||
|
@ -145,6 +149,8 @@ pub struct AssistantPanel {
|
|||
include_conversation_in_next_inline_assist: bool,
|
||||
inline_prompt_history: VecDeque<String>,
|
||||
_watch_saved_conversations: Task<Result<()>>,
|
||||
semantic_index: Option<ModelHandle<SemanticIndex>>,
|
||||
retrieve_context_in_next_inline_assist: bool,
|
||||
}
|
||||
|
||||
impl AssistantPanel {
|
||||
|
@ -191,6 +197,9 @@ impl AssistantPanel {
|
|||
toolbar.add_item(cx.add_view(|cx| BufferSearchBar::new(cx)), cx);
|
||||
toolbar
|
||||
});
|
||||
|
||||
let semantic_index = SemanticIndex::global(cx);
|
||||
|
||||
let mut this = Self {
|
||||
workspace: workspace_handle,
|
||||
active_editor_index: Default::default(),
|
||||
|
@ -215,6 +224,8 @@ impl AssistantPanel {
|
|||
include_conversation_in_next_inline_assist: false,
|
||||
inline_prompt_history: Default::default(),
|
||||
_watch_saved_conversations,
|
||||
semantic_index,
|
||||
retrieve_context_in_next_inline_assist: false,
|
||||
};
|
||||
|
||||
let mut old_dock_position = this.position(cx);
|
||||
|
@ -262,12 +273,19 @@ impl AssistantPanel {
|
|||
return;
|
||||
};
|
||||
|
||||
let project = workspace.project();
|
||||
|
||||
this.update(cx, |assistant, cx| {
|
||||
assistant.new_inline_assist(&active_editor, cx)
|
||||
assistant.new_inline_assist(&active_editor, cx, project)
|
||||
});
|
||||
}
|
||||
|
||||
fn new_inline_assist(&mut self, editor: &ViewHandle<Editor>, cx: &mut ViewContext<Self>) {
|
||||
fn new_inline_assist(
|
||||
&mut self,
|
||||
editor: &ViewHandle<Editor>,
|
||||
cx: &mut ViewContext<Self>,
|
||||
project: &ModelHandle<Project>,
|
||||
) {
|
||||
let api_key = if let Some(api_key) = self.api_key.borrow().clone() {
|
||||
api_key
|
||||
} else {
|
||||
|
@ -312,6 +330,27 @@ impl AssistantPanel {
|
|||
Codegen::new(editor.read(cx).buffer().clone(), codegen_kind, provider, cx)
|
||||
});
|
||||
|
||||
if let Some(semantic_index) = self.semantic_index.clone() {
|
||||
let project = project.clone();
|
||||
cx.spawn(|_, mut cx| async move {
|
||||
let previously_indexed = semantic_index
|
||||
.update(&mut cx, |index, cx| {
|
||||
index.project_previously_indexed(&project, cx)
|
||||
})
|
||||
.await
|
||||
.unwrap_or(false);
|
||||
if previously_indexed {
|
||||
let _ = semantic_index
|
||||
.update(&mut cx, |index, cx| {
|
||||
index.index_project(project.clone(), cx)
|
||||
})
|
||||
.await;
|
||||
}
|
||||
anyhow::Ok(())
|
||||
})
|
||||
.detach_and_log_err(cx);
|
||||
}
|
||||
|
||||
let measurements = Rc::new(Cell::new(BlockMeasurements::default()));
|
||||
let inline_assistant = cx.add_view(|cx| {
|
||||
let assistant = InlineAssistant::new(
|
||||
|
@ -322,6 +361,9 @@ impl AssistantPanel {
|
|||
codegen.clone(),
|
||||
self.workspace.clone(),
|
||||
cx,
|
||||
self.retrieve_context_in_next_inline_assist,
|
||||
self.semantic_index.clone(),
|
||||
project.clone(),
|
||||
);
|
||||
cx.focus_self();
|
||||
assistant
|
||||
|
@ -362,6 +404,7 @@ impl AssistantPanel {
|
|||
editor: editor.downgrade(),
|
||||
inline_assistant: Some((block_id, inline_assistant.clone())),
|
||||
codegen: codegen.clone(),
|
||||
project: project.downgrade(),
|
||||
_subscriptions: vec![
|
||||
cx.subscribe(&inline_assistant, Self::handle_inline_assistant_event),
|
||||
cx.subscribe(editor, {
|
||||
|
@ -440,8 +483,15 @@ impl AssistantPanel {
|
|||
InlineAssistantEvent::Confirmed {
|
||||
prompt,
|
||||
include_conversation,
|
||||
retrieve_context,
|
||||
} => {
|
||||
self.confirm_inline_assist(assist_id, prompt, *include_conversation, cx);
|
||||
self.confirm_inline_assist(
|
||||
assist_id,
|
||||
prompt,
|
||||
*include_conversation,
|
||||
cx,
|
||||
*retrieve_context,
|
||||
);
|
||||
}
|
||||
InlineAssistantEvent::Canceled => {
|
||||
self.finish_inline_assist(assist_id, true, cx);
|
||||
|
@ -454,6 +504,9 @@ impl AssistantPanel {
|
|||
} => {
|
||||
self.include_conversation_in_next_inline_assist = *include_conversation;
|
||||
}
|
||||
InlineAssistantEvent::RetrieveContextToggled { retrieve_context } => {
|
||||
self.retrieve_context_in_next_inline_assist = *retrieve_context
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -532,6 +585,7 @@ impl AssistantPanel {
|
|||
user_prompt: &str,
|
||||
include_conversation: bool,
|
||||
cx: &mut ViewContext<Self>,
|
||||
retrieve_context: bool,
|
||||
) {
|
||||
let conversation = if include_conversation {
|
||||
self.active_editor()
|
||||
|
@ -553,6 +607,8 @@ impl AssistantPanel {
|
|||
return;
|
||||
};
|
||||
|
||||
let project = pending_assist.project.clone();
|
||||
|
||||
self.inline_prompt_history
|
||||
.retain(|prompt| prompt != user_prompt);
|
||||
self.inline_prompt_history.push_back(user_prompt.into());
|
||||
|
@ -593,10 +649,62 @@ impl AssistantPanel {
|
|||
let codegen_kind = codegen.read(cx).kind().clone();
|
||||
let user_prompt = user_prompt.to_string();
|
||||
|
||||
let mut messages = Vec::new();
|
||||
let snippets = if retrieve_context {
|
||||
let Some(project) = project.upgrade(cx) else {
|
||||
return;
|
||||
};
|
||||
|
||||
let search_results = if let Some(semantic_index) = self.semantic_index.clone() {
|
||||
let search_results = semantic_index.update(cx, |this, cx| {
|
||||
this.search_project(project, user_prompt.to_string(), 10, vec![], vec![], cx)
|
||||
});
|
||||
|
||||
cx.background()
|
||||
.spawn(async move { search_results.await.unwrap_or_default() })
|
||||
} else {
|
||||
Task::ready(Vec::new())
|
||||
};
|
||||
|
||||
let snippets = cx.spawn(|_, cx| async move {
|
||||
let mut snippets = Vec::new();
|
||||
for result in search_results.await {
|
||||
snippets.push(PromptCodeSnippet::new(result, &cx));
|
||||
|
||||
// snippets.push(result.buffer.read_with(&cx, |buffer, _| {
|
||||
// buffer
|
||||
// .snapshot()
|
||||
// .text_for_range(result.range)
|
||||
// .collect::<String>()
|
||||
// }));
|
||||
}
|
||||
snippets
|
||||
});
|
||||
snippets
|
||||
} else {
|
||||
Task::ready(Vec::new())
|
||||
};
|
||||
|
||||
let mut model = settings::get::<AssistantSettings>(cx)
|
||||
.default_open_ai_model
|
||||
.clone();
|
||||
let model_name = model.full_name();
|
||||
|
||||
let prompt = cx.background().spawn(async move {
|
||||
let snippets = snippets.await;
|
||||
|
||||
let language_name = language_name.as_deref();
|
||||
generate_content_prompt(
|
||||
user_prompt,
|
||||
language_name,
|
||||
&buffer,
|
||||
range,
|
||||
codegen_kind,
|
||||
snippets,
|
||||
model_name,
|
||||
)
|
||||
});
|
||||
|
||||
let mut messages = Vec::new();
|
||||
if let Some(conversation) = conversation {
|
||||
let conversation = conversation.read(cx);
|
||||
let buffer = conversation.buffer.read(cx);
|
||||
|
@ -608,11 +716,6 @@ impl AssistantPanel {
|
|||
model = conversation.model.clone();
|
||||
}
|
||||
|
||||
let prompt = cx.background().spawn(async move {
|
||||
let language_name = language_name.as_deref();
|
||||
generate_content_prompt(user_prompt, language_name, &buffer, range, codegen_kind)
|
||||
});
|
||||
|
||||
cx.spawn(|_, mut cx| async move {
|
||||
let prompt = prompt.await;
|
||||
|
||||
|
@ -1514,12 +1617,14 @@ impl Conversation {
|
|||
Role::Assistant => "assistant".into(),
|
||||
Role::System => "system".into(),
|
||||
},
|
||||
content: self
|
||||
.buffer
|
||||
.read(cx)
|
||||
.text_for_range(message.offset_range)
|
||||
.collect(),
|
||||
content: Some(
|
||||
self.buffer
|
||||
.read(cx)
|
||||
.text_for_range(message.offset_range)
|
||||
.collect(),
|
||||
),
|
||||
name: None,
|
||||
function_call: None,
|
||||
})
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
|
@ -2638,12 +2743,16 @@ enum InlineAssistantEvent {
|
|||
Confirmed {
|
||||
prompt: String,
|
||||
include_conversation: bool,
|
||||
retrieve_context: bool,
|
||||
},
|
||||
Canceled,
|
||||
Dismissed,
|
||||
IncludeConversationToggled {
|
||||
include_conversation: bool,
|
||||
},
|
||||
RetrieveContextToggled {
|
||||
retrieve_context: bool,
|
||||
},
|
||||
}
|
||||
|
||||
struct InlineAssistant {
|
||||
|
@ -2659,6 +2768,11 @@ struct InlineAssistant {
|
|||
pending_prompt: String,
|
||||
codegen: ModelHandle<Codegen>,
|
||||
_subscriptions: Vec<Subscription>,
|
||||
retrieve_context: bool,
|
||||
semantic_index: Option<ModelHandle<SemanticIndex>>,
|
||||
semantic_permissioned: Option<bool>,
|
||||
project: WeakModelHandle<Project>,
|
||||
maintain_rate_limit: Option<Task<()>>,
|
||||
}
|
||||
|
||||
impl Entity for InlineAssistant {
|
||||
|
@ -2675,51 +2789,65 @@ impl View for InlineAssistant {
|
|||
let theme = theme::current(cx);
|
||||
|
||||
Flex::row()
|
||||
.with_child(
|
||||
Flex::row()
|
||||
.with_child(
|
||||
Button::action(ToggleIncludeConversation)
|
||||
.with_tooltip("Include Conversation", theme.tooltip.clone())
|
||||
.with_children([Flex::row()
|
||||
.with_child(
|
||||
Button::action(ToggleIncludeConversation)
|
||||
.with_tooltip("Include Conversation", theme.tooltip.clone())
|
||||
.with_id(self.id)
|
||||
.with_contents(theme::components::svg::Svg::new("icons/ai.svg"))
|
||||
.toggleable(self.include_conversation)
|
||||
.with_style(theme.assistant.inline.include_conversation.clone())
|
||||
.element()
|
||||
.aligned(),
|
||||
)
|
||||
.with_children(if SemanticIndex::enabled(cx) {
|
||||
Some(
|
||||
Button::action(ToggleRetrieveContext)
|
||||
.with_tooltip("Retrieve Context", theme.tooltip.clone())
|
||||
.with_id(self.id)
|
||||
.with_contents(theme::components::svg::Svg::new("icons/ai.svg"))
|
||||
.toggleable(self.include_conversation)
|
||||
.with_style(theme.assistant.inline.include_conversation.clone())
|
||||
.with_contents(theme::components::svg::Svg::new(
|
||||
"icons/magnifying_glass.svg",
|
||||
))
|
||||
.toggleable(self.retrieve_context)
|
||||
.with_style(theme.assistant.inline.retrieve_context.clone())
|
||||
.element()
|
||||
.aligned(),
|
||||
)
|
||||
.with_children(if let Some(error) = self.codegen.read(cx).error() {
|
||||
Some(
|
||||
Svg::new("icons/error.svg")
|
||||
.with_color(theme.assistant.error_icon.color)
|
||||
.constrained()
|
||||
.with_width(theme.assistant.error_icon.width)
|
||||
.contained()
|
||||
.with_style(theme.assistant.error_icon.container)
|
||||
.with_tooltip::<ErrorIcon>(
|
||||
self.id,
|
||||
error.to_string(),
|
||||
None,
|
||||
theme.tooltip.clone(),
|
||||
cx,
|
||||
)
|
||||
.aligned(),
|
||||
)
|
||||
} else {
|
||||
None
|
||||
})
|
||||
.aligned()
|
||||
.constrained()
|
||||
.dynamically({
|
||||
let measurements = self.measurements.clone();
|
||||
move |constraint, _, _| {
|
||||
let measurements = measurements.get();
|
||||
SizeConstraint {
|
||||
min: vec2f(measurements.gutter_width, constraint.min.y()),
|
||||
max: vec2f(measurements.gutter_width, constraint.max.y()),
|
||||
}
|
||||
} else {
|
||||
None
|
||||
})
|
||||
.with_children(if let Some(error) = self.codegen.read(cx).error() {
|
||||
Some(
|
||||
Svg::new("icons/error.svg")
|
||||
.with_color(theme.assistant.error_icon.color)
|
||||
.constrained()
|
||||
.with_width(theme.assistant.error_icon.width)
|
||||
.contained()
|
||||
.with_style(theme.assistant.error_icon.container)
|
||||
.with_tooltip::<ErrorIcon>(
|
||||
self.id,
|
||||
error.to_string(),
|
||||
None,
|
||||
theme.tooltip.clone(),
|
||||
cx,
|
||||
)
|
||||
.aligned(),
|
||||
)
|
||||
} else {
|
||||
None
|
||||
})
|
||||
.aligned()
|
||||
.constrained()
|
||||
.dynamically({
|
||||
let measurements = self.measurements.clone();
|
||||
move |constraint, _, _| {
|
||||
let measurements = measurements.get();
|
||||
SizeConstraint {
|
||||
min: vec2f(measurements.gutter_width, constraint.min.y()),
|
||||
max: vec2f(measurements.gutter_width, constraint.max.y()),
|
||||
}
|
||||
}),
|
||||
)
|
||||
}
|
||||
})])
|
||||
.with_child(Empty::new().constrained().dynamically({
|
||||
let measurements = self.measurements.clone();
|
||||
move |constraint, _, _| {
|
||||
|
@ -2742,6 +2870,16 @@ impl View for InlineAssistant {
|
|||
.left()
|
||||
.flex(1., true),
|
||||
)
|
||||
.with_children(if self.retrieve_context {
|
||||
Some(
|
||||
Flex::row()
|
||||
.with_children(self.retrieve_context_status(cx))
|
||||
.flex(1., true)
|
||||
.aligned(),
|
||||
)
|
||||
} else {
|
||||
None
|
||||
})
|
||||
.contained()
|
||||
.with_style(theme.assistant.inline.container)
|
||||
.into_any()
|
||||
|
@ -2767,6 +2905,9 @@ impl InlineAssistant {
|
|||
codegen: ModelHandle<Codegen>,
|
||||
workspace: WeakViewHandle<Workspace>,
|
||||
cx: &mut ViewContext<Self>,
|
||||
retrieve_context: bool,
|
||||
semantic_index: Option<ModelHandle<SemanticIndex>>,
|
||||
project: ModelHandle<Project>,
|
||||
) -> Self {
|
||||
let prompt_editor = cx.add_view(|cx| {
|
||||
let mut editor = Editor::single_line(
|
||||
|
@ -2780,11 +2921,16 @@ impl InlineAssistant {
|
|||
editor.set_placeholder_text(placeholder, cx);
|
||||
editor
|
||||
});
|
||||
let subscriptions = vec![
|
||||
let mut subscriptions = vec![
|
||||
cx.observe(&codegen, Self::handle_codegen_changed),
|
||||
cx.subscribe(&prompt_editor, Self::handle_prompt_editor_events),
|
||||
];
|
||||
Self {
|
||||
|
||||
if let Some(semantic_index) = semantic_index.clone() {
|
||||
subscriptions.push(cx.observe(&semantic_index, Self::semantic_index_changed));
|
||||
}
|
||||
|
||||
let assistant = Self {
|
||||
id,
|
||||
prompt_editor,
|
||||
workspace,
|
||||
|
@ -2797,7 +2943,33 @@ impl InlineAssistant {
|
|||
pending_prompt: String::new(),
|
||||
codegen,
|
||||
_subscriptions: subscriptions,
|
||||
retrieve_context,
|
||||
semantic_permissioned: None,
|
||||
semantic_index,
|
||||
project: project.downgrade(),
|
||||
maintain_rate_limit: None,
|
||||
};
|
||||
|
||||
assistant.index_project(cx).log_err();
|
||||
|
||||
assistant
|
||||
}
|
||||
|
||||
fn semantic_permissioned(&self, cx: &mut ViewContext<Self>) -> Task<Result<bool>> {
|
||||
if let Some(value) = self.semantic_permissioned {
|
||||
return Task::ready(Ok(value));
|
||||
}
|
||||
|
||||
let Some(project) = self.project.upgrade(cx) else {
|
||||
return Task::ready(Err(anyhow!("project was dropped")));
|
||||
};
|
||||
|
||||
self.semantic_index
|
||||
.as_ref()
|
||||
.map(|semantic| {
|
||||
semantic.update(cx, |this, cx| this.project_previously_indexed(&project, cx))
|
||||
})
|
||||
.unwrap_or(Task::ready(Ok(false)))
|
||||
}
|
||||
|
||||
fn handle_prompt_editor_events(
|
||||
|
@ -2812,6 +2984,37 @@ impl InlineAssistant {
|
|||
}
|
||||
}
|
||||
|
||||
fn semantic_index_changed(
|
||||
&mut self,
|
||||
semantic_index: ModelHandle<SemanticIndex>,
|
||||
cx: &mut ViewContext<Self>,
|
||||
) {
|
||||
let Some(project) = self.project.upgrade(cx) else {
|
||||
return;
|
||||
};
|
||||
|
||||
let status = semantic_index.read(cx).status(&project);
|
||||
match status {
|
||||
SemanticIndexStatus::Indexing {
|
||||
rate_limit_expiry: Some(_),
|
||||
..
|
||||
} => {
|
||||
if self.maintain_rate_limit.is_none() {
|
||||
self.maintain_rate_limit = Some(cx.spawn(|this, mut cx| async move {
|
||||
loop {
|
||||
cx.background().timer(Duration::from_secs(1)).await;
|
||||
this.update(&mut cx, |_, cx| cx.notify()).log_err();
|
||||
}
|
||||
}));
|
||||
}
|
||||
return;
|
||||
}
|
||||
_ => {
|
||||
self.maintain_rate_limit = None;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn handle_codegen_changed(&mut self, _: ModelHandle<Codegen>, cx: &mut ViewContext<Self>) {
|
||||
let is_read_only = !self.codegen.read(cx).idle();
|
||||
self.prompt_editor.update(cx, |editor, cx| {
|
||||
|
@ -2861,12 +3064,241 @@ impl InlineAssistant {
|
|||
cx.emit(InlineAssistantEvent::Confirmed {
|
||||
prompt,
|
||||
include_conversation: self.include_conversation,
|
||||
retrieve_context: self.retrieve_context,
|
||||
});
|
||||
self.confirmed = true;
|
||||
cx.notify();
|
||||
}
|
||||
}
|
||||
|
||||
fn toggle_retrieve_context(&mut self, _: &ToggleRetrieveContext, cx: &mut ViewContext<Self>) {
|
||||
let semantic_permissioned = self.semantic_permissioned(cx);
|
||||
|
||||
let Some(project) = self.project.upgrade(cx) else {
|
||||
return;
|
||||
};
|
||||
|
||||
let project_name = project
|
||||
.read(cx)
|
||||
.worktree_root_names(cx)
|
||||
.collect::<Vec<&str>>()
|
||||
.join("/");
|
||||
let is_plural = project_name.chars().filter(|letter| *letter == '/').count() > 0;
|
||||
let prompt_text = format!("Would you like to index the '{}' project{} for context retrieval? This requires sending code to the OpenAI API", project_name,
|
||||
if is_plural {
|
||||
"s"
|
||||
} else {""});
|
||||
|
||||
cx.spawn(|this, mut cx| async move {
|
||||
// If Necessary prompt user
|
||||
if !semantic_permissioned.await.unwrap_or(false) {
|
||||
let mut answer = this.update(&mut cx, |_, cx| {
|
||||
cx.prompt(
|
||||
PromptLevel::Info,
|
||||
prompt_text.as_str(),
|
||||
&["Continue", "Cancel"],
|
||||
)
|
||||
})?;
|
||||
|
||||
if answer.next().await == Some(0) {
|
||||
this.update(&mut cx, |this, _| {
|
||||
this.semantic_permissioned = Some(true);
|
||||
})?;
|
||||
} else {
|
||||
return anyhow::Ok(());
|
||||
}
|
||||
}
|
||||
|
||||
// If permissioned, update context appropriately
|
||||
this.update(&mut cx, |this, cx| {
|
||||
this.retrieve_context = !this.retrieve_context;
|
||||
|
||||
cx.emit(InlineAssistantEvent::RetrieveContextToggled {
|
||||
retrieve_context: this.retrieve_context,
|
||||
});
|
||||
|
||||
if this.retrieve_context {
|
||||
this.index_project(cx).log_err();
|
||||
}
|
||||
|
||||
cx.notify();
|
||||
})?;
|
||||
|
||||
anyhow::Ok(())
|
||||
})
|
||||
.detach_and_log_err(cx);
|
||||
}
|
||||
|
||||
fn index_project(&self, cx: &mut ViewContext<Self>) -> anyhow::Result<()> {
|
||||
let Some(project) = self.project.upgrade(cx) else {
|
||||
return Err(anyhow!("project was dropped!"));
|
||||
};
|
||||
|
||||
let semantic_permissioned = self.semantic_permissioned(cx);
|
||||
if let Some(semantic_index) = SemanticIndex::global(cx) {
|
||||
cx.spawn(|_, mut cx| async move {
|
||||
// This has to be updated to accomodate for semantic_permissions
|
||||
if semantic_permissioned.await.unwrap_or(false) {
|
||||
semantic_index
|
||||
.update(&mut cx, |index, cx| index.index_project(project, cx))
|
||||
.await
|
||||
} else {
|
||||
Err(anyhow!("project is not permissioned for semantic indexing"))
|
||||
}
|
||||
})
|
||||
.detach_and_log_err(cx);
|
||||
}
|
||||
|
||||
anyhow::Ok(())
|
||||
}
|
||||
|
||||
fn retrieve_context_status(
|
||||
&self,
|
||||
cx: &mut ViewContext<Self>,
|
||||
) -> Option<AnyElement<InlineAssistant>> {
|
||||
enum ContextStatusIcon {}
|
||||
|
||||
let Some(project) = self.project.upgrade(cx) else {
|
||||
return None;
|
||||
};
|
||||
|
||||
if let Some(semantic_index) = SemanticIndex::global(cx) {
|
||||
let status = semantic_index.update(cx, |index, _| index.status(&project));
|
||||
let theme = theme::current(cx);
|
||||
match status {
|
||||
SemanticIndexStatus::NotAuthenticated {} => Some(
|
||||
Svg::new("icons/error.svg")
|
||||
.with_color(theme.assistant.error_icon.color)
|
||||
.constrained()
|
||||
.with_width(theme.assistant.error_icon.width)
|
||||
.contained()
|
||||
.with_style(theme.assistant.error_icon.container)
|
||||
.with_tooltip::<ContextStatusIcon>(
|
||||
self.id,
|
||||
"Not Authenticated. Please ensure you have a valid 'OPENAI_API_KEY' in your environment variables.",
|
||||
None,
|
||||
theme.tooltip.clone(),
|
||||
cx,
|
||||
)
|
||||
.aligned()
|
||||
.into_any(),
|
||||
),
|
||||
SemanticIndexStatus::NotIndexed {} => Some(
|
||||
Svg::new("icons/error.svg")
|
||||
.with_color(theme.assistant.inline.context_status.error_icon.color)
|
||||
.constrained()
|
||||
.with_width(theme.assistant.inline.context_status.error_icon.width)
|
||||
.contained()
|
||||
.with_style(theme.assistant.inline.context_status.error_icon.container)
|
||||
.with_tooltip::<ContextStatusIcon>(
|
||||
self.id,
|
||||
"Not Indexed",
|
||||
None,
|
||||
theme.tooltip.clone(),
|
||||
cx,
|
||||
)
|
||||
.aligned()
|
||||
.into_any(),
|
||||
),
|
||||
SemanticIndexStatus::Indexing {
|
||||
remaining_files,
|
||||
rate_limit_expiry,
|
||||
} => {
|
||||
|
||||
let mut status_text = if remaining_files == 0 {
|
||||
"Indexing...".to_string()
|
||||
} else {
|
||||
format!("Remaining files to index: {remaining_files}")
|
||||
};
|
||||
|
||||
if let Some(rate_limit_expiry) = rate_limit_expiry {
|
||||
let remaining_seconds = rate_limit_expiry.duration_since(Instant::now());
|
||||
if remaining_seconds > Duration::from_secs(0) && remaining_files > 0 {
|
||||
write!(
|
||||
status_text,
|
||||
" (rate limit expires in {}s)",
|
||||
remaining_seconds.as_secs()
|
||||
)
|
||||
.unwrap();
|
||||
}
|
||||
}
|
||||
Some(
|
||||
Svg::new("icons/update.svg")
|
||||
.with_color(theme.assistant.inline.context_status.in_progress_icon.color)
|
||||
.constrained()
|
||||
.with_width(theme.assistant.inline.context_status.in_progress_icon.width)
|
||||
.contained()
|
||||
.with_style(theme.assistant.inline.context_status.in_progress_icon.container)
|
||||
.with_tooltip::<ContextStatusIcon>(
|
||||
self.id,
|
||||
status_text,
|
||||
None,
|
||||
theme.tooltip.clone(),
|
||||
cx,
|
||||
)
|
||||
.aligned()
|
||||
.into_any(),
|
||||
)
|
||||
}
|
||||
SemanticIndexStatus::Indexed {} => Some(
|
||||
Svg::new("icons/check.svg")
|
||||
.with_color(theme.assistant.inline.context_status.complete_icon.color)
|
||||
.constrained()
|
||||
.with_width(theme.assistant.inline.context_status.complete_icon.width)
|
||||
.contained()
|
||||
.with_style(theme.assistant.inline.context_status.complete_icon.container)
|
||||
.with_tooltip::<ContextStatusIcon>(
|
||||
self.id,
|
||||
"Index up to date",
|
||||
None,
|
||||
theme.tooltip.clone(),
|
||||
cx,
|
||||
)
|
||||
.aligned()
|
||||
.into_any(),
|
||||
),
|
||||
}
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
// fn retrieve_context_status(&self, cx: &mut ViewContext<Self>) -> String {
|
||||
// let project = self.project.clone();
|
||||
// if let Some(semantic_index) = self.semantic_index.clone() {
|
||||
// let status = semantic_index.update(cx, |index, cx| index.status(&project));
|
||||
// return match status {
|
||||
// // This theoretically shouldnt be a valid code path
|
||||
// // As the inline assistant cant be launched without an API key
|
||||
// // We keep it here for safety
|
||||
// semantic_index::SemanticIndexStatus::NotAuthenticated => {
|
||||
// "Not Authenticated!\nPlease ensure you have an `OPENAI_API_KEY` in your environment variables.".to_string()
|
||||
// }
|
||||
// semantic_index::SemanticIndexStatus::Indexed => {
|
||||
// "Indexing Complete!".to_string()
|
||||
// }
|
||||
// semantic_index::SemanticIndexStatus::Indexing { remaining_files, rate_limit_expiry } => {
|
||||
|
||||
// let mut status = format!("Remaining files to index for Context Retrieval: {remaining_files}");
|
||||
|
||||
// if let Some(rate_limit_expiry) = rate_limit_expiry {
|
||||
// let remaining_seconds =
|
||||
// rate_limit_expiry.duration_since(Instant::now());
|
||||
// if remaining_seconds > Duration::from_secs(0) {
|
||||
// write!(status, " (rate limit resets in {}s)", remaining_seconds.as_secs()).unwrap();
|
||||
// }
|
||||
// }
|
||||
// status
|
||||
// }
|
||||
// semantic_index::SemanticIndexStatus::NotIndexed => {
|
||||
// "Not Indexed for Context Retrieval".to_string()
|
||||
// }
|
||||
// };
|
||||
// }
|
||||
|
||||
// "".to_string()
|
||||
// }
|
||||
|
||||
fn toggle_include_conversation(
|
||||
&mut self,
|
||||
_: &ToggleIncludeConversation,
|
||||
|
@ -2929,6 +3361,7 @@ struct PendingInlineAssist {
|
|||
inline_assistant: Option<(BlockId, ViewHandle<InlineAssistant>)>,
|
||||
codegen: ModelHandle<Codegen>,
|
||||
_subscriptions: Vec<Subscription>,
|
||||
project: WeakModelHandle<Project>,
|
||||
}
|
||||
|
||||
fn merge_ranges(ranges: &mut Vec<Range<Anchor>>, buffer: &MultiBufferSnapshot) {
|
||||
|
|
|
@ -1,8 +1,60 @@
|
|||
use crate::codegen::CodegenKind;
|
||||
use gpui::AsyncAppContext;
|
||||
use language::{BufferSnapshot, OffsetRangeExt, ToOffset};
|
||||
use semantic_index::SearchResult;
|
||||
use std::cmp::{self, Reverse};
|
||||
use std::fmt::Write;
|
||||
use std::ops::Range;
|
||||
use std::path::PathBuf;
|
||||
use tiktoken_rs::ChatCompletionRequestMessage;
|
||||
|
||||
pub struct PromptCodeSnippet {
|
||||
path: Option<PathBuf>,
|
||||
language_name: Option<String>,
|
||||
content: String,
|
||||
}
|
||||
|
||||
impl PromptCodeSnippet {
|
||||
pub fn new(search_result: SearchResult, cx: &AsyncAppContext) -> Self {
|
||||
let (content, language_name, file_path) =
|
||||
search_result.buffer.read_with(cx, |buffer, _| {
|
||||
let snapshot = buffer.snapshot();
|
||||
let content = snapshot
|
||||
.text_for_range(search_result.range.clone())
|
||||
.collect::<String>();
|
||||
|
||||
let language_name = buffer
|
||||
.language()
|
||||
.and_then(|language| Some(language.name().to_string()));
|
||||
|
||||
let file_path = buffer
|
||||
.file()
|
||||
.and_then(|file| Some(file.path().to_path_buf()));
|
||||
|
||||
(content, language_name, file_path)
|
||||
});
|
||||
|
||||
PromptCodeSnippet {
|
||||
path: file_path,
|
||||
language_name,
|
||||
content,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl ToString for PromptCodeSnippet {
|
||||
fn to_string(&self) -> String {
|
||||
let path = self
|
||||
.path
|
||||
.as_ref()
|
||||
.and_then(|path| Some(path.to_string_lossy().to_string()))
|
||||
.unwrap_or("".to_string());
|
||||
let language_name = self.language_name.clone().unwrap_or("".to_string());
|
||||
let content = self.content.clone();
|
||||
|
||||
format!("The below code snippet may be relevant from file: {path}\n```{language_name}\n{content}\n```")
|
||||
}
|
||||
}
|
||||
|
||||
#[allow(dead_code)]
|
||||
fn summarize(buffer: &BufferSnapshot, selected_range: Range<impl ToOffset>) -> String {
|
||||
|
@ -121,17 +173,25 @@ pub fn generate_content_prompt(
|
|||
buffer: &BufferSnapshot,
|
||||
range: Range<impl ToOffset>,
|
||||
kind: CodegenKind,
|
||||
search_results: Vec<PromptCodeSnippet>,
|
||||
model: &str,
|
||||
) -> String {
|
||||
const MAXIMUM_SNIPPET_TOKEN_COUNT: usize = 500;
|
||||
const RESERVED_TOKENS_FOR_GENERATION: usize = 1000;
|
||||
|
||||
let mut prompts = Vec::new();
|
||||
let range = range.to_offset(buffer);
|
||||
let mut prompt = String::new();
|
||||
|
||||
// General Preamble
|
||||
if let Some(language_name) = language_name {
|
||||
writeln!(prompt, "You're an expert {language_name} engineer.\n").unwrap();
|
||||
prompts.push(format!("You're an expert {language_name} engineer.\n"));
|
||||
} else {
|
||||
writeln!(prompt, "You're an expert engineer.\n").unwrap();
|
||||
prompts.push("You're an expert engineer.\n".to_string());
|
||||
}
|
||||
|
||||
// Snippets
|
||||
let mut snippet_position = prompts.len() - 1;
|
||||
|
||||
let mut content = String::new();
|
||||
content.extend(buffer.text_for_range(0..range.start));
|
||||
if range.start == range.end {
|
||||
|
@ -145,59 +205,99 @@ pub fn generate_content_prompt(
|
|||
}
|
||||
content.extend(buffer.text_for_range(range.end..buffer.len()));
|
||||
|
||||
writeln!(
|
||||
prompt,
|
||||
"The file you are currently working on has the following content:"
|
||||
)
|
||||
.unwrap();
|
||||
prompts.push("The file you are currently working on has the following content:\n".to_string());
|
||||
|
||||
if let Some(language_name) = language_name {
|
||||
let language_name = language_name.to_lowercase();
|
||||
writeln!(prompt, "```{language_name}\n{content}\n```").unwrap();
|
||||
prompts.push(format!("```{language_name}\n{content}\n```"));
|
||||
} else {
|
||||
writeln!(prompt, "```\n{content}\n```").unwrap();
|
||||
prompts.push(format!("```\n{content}\n```"));
|
||||
}
|
||||
|
||||
match kind {
|
||||
CodegenKind::Generate { position: _ } => {
|
||||
writeln!(prompt, "In particular, the user's cursor is current on the '<|START|>' span in the above outline, with no text selected.").unwrap();
|
||||
writeln!(
|
||||
prompt,
|
||||
"Assume the cursor is located where the `<|START|` marker is."
|
||||
)
|
||||
.unwrap();
|
||||
writeln!(
|
||||
prompt,
|
||||
prompts.push("In particular, the user's cursor is currently on the '<|START|>' span in the above outline, with no text selected.".to_string());
|
||||
prompts
|
||||
.push("Assume the cursor is located where the `<|START|` marker is.".to_string());
|
||||
prompts.push(
|
||||
"Text can't be replaced, so assume your answer will be inserted at the cursor."
|
||||
)
|
||||
.unwrap();
|
||||
writeln!(
|
||||
prompt,
|
||||
.to_string(),
|
||||
);
|
||||
prompts.push(format!(
|
||||
"Generate text based on the users prompt: {user_prompt}"
|
||||
)
|
||||
.unwrap();
|
||||
));
|
||||
}
|
||||
CodegenKind::Transform { range: _ } => {
|
||||
writeln!(prompt, "In particular, the user has selected a section of the text between the '<|START|' and '|END|>' spans.").unwrap();
|
||||
writeln!(
|
||||
prompt,
|
||||
"Modify the users code selected text based upon the users prompt: {user_prompt}"
|
||||
)
|
||||
.unwrap();
|
||||
writeln!(
|
||||
prompt,
|
||||
"You MUST reply with only the adjusted code (within the '<|START|' and '|END|>' spans), not the entire file."
|
||||
)
|
||||
.unwrap();
|
||||
prompts.push("In particular, the user has selected a section of the text between the '<|START|' and '|END|>' spans.".to_string());
|
||||
prompts.push(format!(
|
||||
"Modify the users code selected text based upon the users prompt: '{user_prompt}'"
|
||||
));
|
||||
prompts.push("You MUST reply with only the adjusted code (within the '<|START|' and '|END|>' spans), not the entire file.".to_string());
|
||||
}
|
||||
}
|
||||
|
||||
if let Some(language_name) = language_name {
|
||||
writeln!(prompt, "Your answer MUST always be valid {language_name}").unwrap();
|
||||
prompts.push(format!(
|
||||
"Your answer MUST always and only be valid {language_name}"
|
||||
));
|
||||
}
|
||||
writeln!(prompt, "Always wrap your response in a Markdown codeblock").unwrap();
|
||||
writeln!(prompt, "Never make remarks about the output.").unwrap();
|
||||
prompts.push("Never make remarks about the output.".to_string());
|
||||
prompts.push("Do not return any text, except the generated code.".to_string());
|
||||
prompts.push("Do not wrap your text in a Markdown block".to_string());
|
||||
|
||||
prompt
|
||||
let current_messages = [ChatCompletionRequestMessage {
|
||||
role: "user".to_string(),
|
||||
content: Some(prompts.join("\n")),
|
||||
function_call: None,
|
||||
name: None,
|
||||
}];
|
||||
|
||||
let mut remaining_token_count = if let Ok(current_token_count) =
|
||||
tiktoken_rs::num_tokens_from_messages(model, ¤t_messages)
|
||||
{
|
||||
let max_token_count = tiktoken_rs::model::get_context_size(model);
|
||||
let intermediate_token_count = max_token_count - current_token_count;
|
||||
|
||||
if intermediate_token_count < RESERVED_TOKENS_FOR_GENERATION {
|
||||
0
|
||||
} else {
|
||||
intermediate_token_count - RESERVED_TOKENS_FOR_GENERATION
|
||||
}
|
||||
} else {
|
||||
// If tiktoken fails to count token count, assume we have no space remaining.
|
||||
0
|
||||
};
|
||||
|
||||
// TODO:
|
||||
// - add repository name to snippet
|
||||
// - add file path
|
||||
// - add language
|
||||
if let Ok(encoding) = tiktoken_rs::get_bpe_from_model(model) {
|
||||
let mut template = "You are working inside a large repository, here are a few code snippets that may be useful";
|
||||
|
||||
for search_result in search_results {
|
||||
let mut snippet_prompt = template.to_string();
|
||||
let snippet = search_result.to_string();
|
||||
writeln!(snippet_prompt, "```\n{snippet}\n```").unwrap();
|
||||
|
||||
let token_count = encoding
|
||||
.encode_with_special_tokens(snippet_prompt.as_str())
|
||||
.len();
|
||||
if token_count <= remaining_token_count {
|
||||
if token_count < MAXIMUM_SNIPPET_TOKEN_COUNT {
|
||||
prompts.insert(snippet_position, snippet_prompt);
|
||||
snippet_position += 1;
|
||||
remaining_token_count -= token_count;
|
||||
// If you have already added the template to the prompt, remove the template.
|
||||
template = "";
|
||||
}
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
prompts.join("\n")
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
|
|
|
@ -1199,6 +1199,15 @@ pub struct InlineAssistantStyle {
|
|||
pub disabled_editor: FieldEditor,
|
||||
pub pending_edit_background: Color,
|
||||
pub include_conversation: ToggleIconButtonStyle,
|
||||
pub retrieve_context: ToggleIconButtonStyle,
|
||||
pub context_status: ContextStatusStyle,
|
||||
}
|
||||
|
||||
#[derive(Clone, Deserialize, Default, JsonSchema)]
|
||||
pub struct ContextStatusStyle {
|
||||
pub error_icon: Icon,
|
||||
pub in_progress_icon: Icon,
|
||||
pub complete_icon: Icon,
|
||||
}
|
||||
|
||||
#[derive(Clone, Deserialize, Default, JsonSchema)]
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue