assistant2: Wire up context picker with inline assist (#22106)

This PR wire up the context picker with the inline assist.

UI is not finalized.

Release Notes:

- N/A

---------

Co-authored-by: Richard <richard@zed.dev>
Co-authored-by: Agus <agus@zed.dev>
This commit is contained in:
Marshall Bowers 2024-12-16 15:46:28 -05:00 committed by GitHub
parent 082469e173
commit 4bf005ef52
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
12 changed files with 391 additions and 241 deletions

View file

@ -1,3 +1,8 @@
use crate::context::attach_context_to_message;
use crate::context_store::ContextStore;
use crate::context_strip::ContextStrip;
use crate::thread_store::ThreadStore;
use crate::AssistantPanel;
use crate::{
assistant_settings::AssistantSettings,
prompts::PromptBuilder,
@ -24,7 +29,8 @@ use futures::{channel::mpsc, future::LocalBoxFuture, join, SinkExt, Stream, Stre
use gpui::{
anchored, deferred, point, AnyElement, AppContext, ClickEvent, CursorStyle, EventEmitter,
FocusHandle, FocusableView, FontWeight, Global, HighlightStyle, Model, ModelContext,
Subscription, Task, TextStyle, UpdateGlobal, View, ViewContext, WeakView, WindowContext,
Subscription, Task, TextStyle, UpdateGlobal, View, ViewContext, WeakModel, WeakView,
WindowContext,
};
use language::{Buffer, IndentKind, Point, Selection, TransactionId};
use language_model::{
@ -178,10 +184,16 @@ impl InlineAssistant {
) {
if let Some(editor) = item.act_as::<Editor>(cx) {
editor.update(cx, |editor, cx| {
let thread_store = workspace
.read(cx)
.panel::<AssistantPanel>(cx)
.map(|assistant_panel| assistant_panel.read(cx).thread_store().downgrade());
editor.push_code_action_provider(
Rc::new(AssistantCodeActionProvider {
editor: cx.view().downgrade(),
workspace: workspace.downgrade(),
thread_store,
}),
cx,
);
@ -212,7 +224,11 @@ impl InlineAssistant {
let handle_assist = |cx: &mut ViewContext<Workspace>| match inline_assist_target {
InlineAssistTarget::Editor(active_editor) => {
InlineAssistant::update_global(cx, |assistant, cx| {
assistant.assist(&active_editor, cx.view().downgrade(), cx)
let thread_store = workspace
.panel::<AssistantPanel>(cx)
.map(|assistant_panel| assistant_panel.read(cx).thread_store().downgrade());
assistant.assist(&active_editor, cx.view().downgrade(), thread_store, cx)
})
}
InlineAssistTarget::Terminal(active_terminal) => {
@ -265,6 +281,7 @@ impl InlineAssistant {
&mut self,
editor: &View<Editor>,
workspace: WeakView<Workspace>,
thread_store: Option<WeakModel<ThreadStore>>,
cx: &mut WindowContext,
) {
let (snapshot, initial_selections) = editor.update(cx, |editor, cx| {
@ -343,11 +360,13 @@ impl InlineAssistant {
let mut assist_to_focus = None;
for range in codegen_ranges {
let assist_id = self.next_assist_id.post_inc();
let context_store = cx.new_model(|_cx| ContextStore::new());
let codegen = cx.new_model(|cx| {
Codegen::new(
editor.read(cx).buffer().clone(),
range.clone(),
None,
context_store.clone(),
self.telemetry.clone(),
self.prompt_builder.clone(),
cx,
@ -363,6 +382,9 @@ impl InlineAssistant {
prompt_buffer.clone(),
codegen.clone(),
self.fs.clone(),
context_store,
workspace.clone(),
thread_store.clone(),
cx,
)
});
@ -430,6 +452,7 @@ impl InlineAssistant {
initial_transaction_id: Option<TransactionId>,
focus: bool,
workspace: WeakView<Workspace>,
thread_store: Option<WeakModel<ThreadStore>>,
cx: &mut WindowContext,
) -> InlineAssistId {
let assist_group_id = self.next_assist_group_id.post_inc();
@ -445,11 +468,14 @@ impl InlineAssistant {
range.end = range.end.bias_right(&snapshot);
}
let context_store = cx.new_model(|_cx| ContextStore::new());
let codegen = cx.new_model(|cx| {
Codegen::new(
editor.read(cx).buffer().clone(),
range.clone(),
initial_transaction_id,
context_store.clone(),
self.telemetry.clone(),
self.prompt_builder.clone(),
cx,
@ -465,6 +491,9 @@ impl InlineAssistant {
prompt_buffer.clone(),
codegen.clone(),
self.fs.clone(),
context_store,
workspace.clone(),
thread_store,
cx,
)
});
@ -1456,6 +1485,7 @@ enum PromptEditorEvent {
struct PromptEditor {
id: InlineAssistId,
editor: View<Editor>,
context_strip: View<ContextStrip>,
language_model_selector: View<LanguageModelSelector>,
edited_since_done: bool,
gutter_dimensions: Arc<Mutex<GutterDimensions>>,
@ -1473,11 +1503,7 @@ impl EventEmitter<PromptEditorEvent> for PromptEditor {}
impl Render for PromptEditor {
fn render(&mut self, cx: &mut ViewContext<Self>) -> impl IntoElement {
let gutter_dimensions = *self.gutter_dimensions.lock();
let mut buttons = vec![Button::new("add-context", "Add Context")
.style(ButtonStyle::Filled)
.icon(IconName::Plus)
.icon_position(IconPosition::Start)
.into_any_element()];
let mut buttons = Vec::new();
let codegen = self.codegen.read(cx);
if codegen.alternative_count(cx) > 1 {
buttons.push(self.render_cycle_controls(cx));
@ -1570,91 +1596,114 @@ impl Render for PromptEditor {
}
});
h_flex()
.key_context("PromptEditor")
.bg(cx.theme().colors().editor_background)
.block_mouse_down()
.cursor(CursorStyle::Arrow)
v_flex()
.border_y_1()
.border_color(cx.theme().status().info_border)
.size_full()
.py(cx.line_height() / 2.5)
.on_action(cx.listener(Self::confirm))
.on_action(cx.listener(Self::cancel))
.on_action(cx.listener(Self::move_up))
.on_action(cx.listener(Self::move_down))
.capture_action(cx.listener(Self::cycle_prev))
.capture_action(cx.listener(Self::cycle_next))
.child(
h_flex()
.w(gutter_dimensions.full_width() + (gutter_dimensions.margin / 2.0))
.justify_center()
.gap_2()
.child(LanguageModelSelectorPopoverMenu::new(
self.language_model_selector.clone(),
IconButton::new("context", IconName::SettingsAlt)
.shape(IconButtonShape::Square)
.icon_size(IconSize::Small)
.icon_color(Color::Muted)
.tooltip(move |cx| {
Tooltip::with_meta(
format!(
"Using {}",
LanguageModelRegistry::read_global(cx)
.active_model()
.map(|model| model.name().0)
.unwrap_or_else(|| "No model selected".into()),
),
None,
"Change Model",
cx,
)
}),
))
.map(|el| {
let CodegenStatus::Error(error) = self.codegen.read(cx).status(cx) else {
return el;
};
let error_message = SharedString::from(error.to_string());
if error.error_code() == proto::ErrorCode::RateLimitExceeded
&& cx.has_flag::<ZedPro>()
{
el.child(
v_flex()
.child(
IconButton::new("rate-limit-error", IconName::XCircle)
.toggle_state(self.show_rate_limit_notice)
.shape(IconButtonShape::Square)
.icon_size(IconSize::Small)
.on_click(cx.listener(Self::toggle_rate_limit_notice)),
)
.children(self.show_rate_limit_notice.then(|| {
deferred(
anchored()
.position_mode(gpui::AnchoredPositionMode::Local)
.position(point(px(0.), px(24.)))
.anchor(gpui::AnchorCorner::TopLeft)
.child(self.render_rate_limit_notice(cx)),
.key_context("PromptEditor")
.bg(cx.theme().colors().editor_background)
.block_mouse_down()
.cursor(CursorStyle::Arrow)
.on_action(cx.listener(Self::confirm))
.on_action(cx.listener(Self::cancel))
.on_action(cx.listener(Self::move_up))
.on_action(cx.listener(Self::move_down))
.capture_action(cx.listener(Self::cycle_prev))
.capture_action(cx.listener(Self::cycle_next))
.child(
h_flex()
.w(gutter_dimensions.full_width() + (gutter_dimensions.margin / 2.0))
.justify_center()
.gap_2()
.child(LanguageModelSelectorPopoverMenu::new(
self.language_model_selector.clone(),
IconButton::new("context", IconName::SettingsAlt)
.shape(IconButtonShape::Square)
.icon_size(IconSize::Small)
.icon_color(Color::Muted)
.tooltip(move |cx| {
Tooltip::with_meta(
format!(
"Using {}",
LanguageModelRegistry::read_global(cx)
.active_model()
.map(|model| model.name().0)
.unwrap_or_else(|| "No model selected".into()),
),
None,
"Change Model",
cx,
)
})),
)
} else {
el.child(
div()
.id("error")
.tooltip(move |cx| Tooltip::text(error_message.clone(), cx))
.child(
Icon::new(IconName::XCircle)
.size(IconSize::Small)
.color(Color::Error),
),
)
}
}),
}),
))
.map(|el| {
let CodegenStatus::Error(error) = self.codegen.read(cx).status(cx)
else {
return el;
};
let error_message = SharedString::from(error.to_string());
if error.error_code() == proto::ErrorCode::RateLimitExceeded
&& cx.has_flag::<ZedPro>()
{
el.child(
v_flex()
.child(
IconButton::new(
"rate-limit-error",
IconName::XCircle,
)
.toggle_state(self.show_rate_limit_notice)
.shape(IconButtonShape::Square)
.icon_size(IconSize::Small)
.on_click(
cx.listener(Self::toggle_rate_limit_notice),
),
)
.children(self.show_rate_limit_notice.then(|| {
deferred(
anchored()
.position_mode(
gpui::AnchoredPositionMode::Local,
)
.position(point(px(0.), px(24.)))
.anchor(gpui::AnchorCorner::TopLeft)
.child(self.render_rate_limit_notice(cx)),
)
})),
)
} else {
el.child(
div()
.id("error")
.tooltip(move |cx| {
Tooltip::text(error_message.clone(), cx)
})
.child(
Icon::new(IconName::XCircle)
.size(IconSize::Small)
.color(Color::Error),
),
)
}
}),
)
.child(div().flex_1().child(self.render_editor(cx)))
.child(h_flex().gap_2().pr_6().children(buttons)),
)
.child(
h_flex()
.child(
h_flex()
.w(gutter_dimensions.full_width() + (gutter_dimensions.margin / 2.0))
.justify_center()
.gap_2(),
)
.child(self.context_strip.clone()),
)
.child(div().flex_1().child(self.render_editor(cx)))
.child(h_flex().gap_2().pr_6().children(buttons))
}
}
@ -1675,6 +1724,9 @@ impl PromptEditor {
prompt_buffer: Model<MultiBuffer>,
codegen: Model<Codegen>,
fs: Arc<dyn Fs>,
context_store: Model<ContextStore>,
workspace: WeakView<Workspace>,
thread_store: Option<WeakModel<ThreadStore>>,
cx: &mut ViewContext<Self>,
) -> Self {
let prompt_editor = cx.new_view(|cx| {
@ -1699,6 +1751,9 @@ impl PromptEditor {
let mut this = Self {
id,
editor: prompt_editor,
context_strip: cx.new_view(|cx| {
ContextStrip::new(context_store, workspace.clone(), thread_store.clone(), cx)
}),
language_model_selector: cx.new_view(|cx| {
let fs = fs.clone();
LanguageModelSelector::new(
@ -2293,6 +2348,7 @@ pub struct Codegen {
buffer: Model<MultiBuffer>,
range: Range<Anchor>,
initial_transaction_id: Option<TransactionId>,
context_store: Model<ContextStore>,
telemetry: Arc<Telemetry>,
builder: Arc<PromptBuilder>,
is_insertion: bool,
@ -2303,6 +2359,7 @@ impl Codegen {
buffer: Model<MultiBuffer>,
range: Range<Anchor>,
initial_transaction_id: Option<TransactionId>,
context_store: Model<ContextStore>,
telemetry: Arc<Telemetry>,
builder: Arc<PromptBuilder>,
cx: &mut ModelContext<Self>,
@ -2312,6 +2369,7 @@ impl Codegen {
buffer.clone(),
range.clone(),
false,
Some(context_store.clone()),
Some(telemetry.clone()),
builder.clone(),
cx,
@ -2326,6 +2384,7 @@ impl Codegen {
buffer,
range,
initial_transaction_id,
context_store,
telemetry,
builder,
};
@ -2398,6 +2457,7 @@ impl Codegen {
self.buffer.clone(),
self.range.clone(),
false,
Some(self.context_store.clone()),
Some(self.telemetry.clone()),
self.builder.clone(),
cx,
@ -2477,6 +2537,7 @@ pub struct CodegenAlternative {
status: CodegenStatus,
generation: Task<()>,
diff: Diff,
context_store: Option<Model<ContextStore>>,
telemetry: Option<Arc<Telemetry>>,
_subscription: gpui::Subscription,
builder: Arc<PromptBuilder>,
@ -2515,6 +2576,7 @@ impl CodegenAlternative {
buffer: Model<MultiBuffer>,
range: Range<Anchor>,
active: bool,
context_store: Option<Model<ContextStore>>,
telemetry: Option<Arc<Telemetry>>,
builder: Arc<PromptBuilder>,
cx: &mut ModelContext<Self>,
@ -2552,6 +2614,7 @@ impl CodegenAlternative {
status: CodegenStatus::Idle,
generation: Task::ready(()),
diff: Diff::default(),
context_store,
telemetry,
_subscription: cx.subscribe(&buffer, Self::handle_buffer_event),
builder,
@ -2637,7 +2700,11 @@ impl CodegenAlternative {
Ok(())
}
fn build_request(&self, user_prompt: String, cx: &AppContext) -> Result<LanguageModelRequest> {
fn build_request(
&self,
user_prompt: String,
cx: &mut AppContext,
) -> Result<LanguageModelRequest> {
let buffer = self.buffer.read(cx).snapshot(cx);
let language = buffer.language_at(self.range.start);
let language_name = if let Some(language) = language.as_ref() {
@ -2670,15 +2737,24 @@ impl CodegenAlternative {
.generate_inline_transformation_prompt(user_prompt, language_name, buffer, range)
.map_err(|e| anyhow::anyhow!("Failed to generate content prompt: {}", e))?;
let mut request_message = LanguageModelRequestMessage {
role: Role::User,
content: Vec::new(),
cache: false,
};
if let Some(context_store) = &self.context_store {
let context = context_store.update(cx, |this, _cx| this.context().clone());
attach_context_to_message(&mut request_message, context);
}
request_message.content.push(prompt.into());
Ok(LanguageModelRequest {
tools: Vec::new(),
stop: Vec::new(),
temperature: None,
messages: vec![LanguageModelRequestMessage {
role: Role::User,
content: vec![prompt.into()],
cache: false,
}],
messages: vec![request_message],
})
}
@ -3273,6 +3349,7 @@ where
struct AssistantCodeActionProvider {
editor: WeakView<Editor>,
workspace: WeakView<Workspace>,
thread_store: Option<WeakModel<ThreadStore>>,
}
impl CodeActionProvider for AssistantCodeActionProvider {
@ -3337,6 +3414,7 @@ impl CodeActionProvider for AssistantCodeActionProvider {
) -> Task<Result<ProjectTransaction>> {
let editor = self.editor.clone();
let workspace = self.workspace.clone();
let thread_store = self.thread_store.clone();
cx.spawn(|mut cx| async move {
let editor = editor.upgrade().context("editor was released")?;
let range = editor
@ -3384,6 +3462,7 @@ impl CodeActionProvider for AssistantCodeActionProvider {
None,
true,
workspace,
thread_store,
cx,
);
assistant.start_assist(assist_id, cx);
@ -3469,6 +3548,7 @@ mod tests {
range.clone(),
true,
None,
None,
prompt_builder,
cx,
)
@ -3533,6 +3613,7 @@ mod tests {
range.clone(),
true,
None,
None,
prompt_builder,
cx,
)
@ -3600,6 +3681,7 @@ mod tests {
range.clone(),
true,
None,
None,
prompt_builder,
cx,
)
@ -3666,6 +3748,7 @@ mod tests {
range.clone(),
true,
None,
None,
prompt_builder,
cx,
)
@ -3721,6 +3804,7 @@ mod tests {
range.clone(),
false,
None,
None,
prompt_builder,
cx,
)