zeta: Refresh LLM token in case it expired (#21796)

Release Notes:

- N/A

---------

Co-authored-by: Antonio <antonio@zed.dev>
Co-authored-by: Bennet <bennet@zed.dev>
This commit is contained in:
Thorsten Ball 2024-12-10 14:12:49 +01:00 committed by GitHub
parent 09006aaee9
commit 96499b7b25
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 209 additions and 132 deletions

View file

@ -59,6 +59,11 @@ impl FeatureFlag for ToolUseFeatureFlag {
}
}
/// Marker type for the "zeta" feature flag.
///
/// Elsewhere in this change it is queried with `cx.has_flag::<ZetaFeatureFlag>()`
/// and watched with `cx.observe_flag::<ZetaFeatureFlag, _>(..)` to decide whether
/// the Zeta inline completion provider is shown or registered.
pub struct ZetaFeatureFlag;
impl FeatureFlag for ZetaFeatureFlag {
// String name associated with this flag through the `FeatureFlag` trait.
const NAME: &'static str = "zeta";
}
pub struct Remoting {}
impl FeatureFlag for Remoting {
const NAME: &'static str = "remoting";

View file

@ -1,7 +1,7 @@
use anyhow::Result;
use copilot::{Copilot, Status};
use editor::{scroll::Autoscroll, Editor};
use feature_flags::FeatureFlagAppExt;
use feature_flags::{FeatureFlagAppExt, ZetaFeatureFlag};
use fs::Fs;
use gpui::{
div, Action, AnchorCorner, AppContext, AsyncWindowContext, Entity, IntoElement, ParentElement,
@ -199,7 +199,7 @@ impl Render for InlineCompletionButton {
}
InlineCompletionProvider::Zeta => {
if !cx.is_staff() {
if !cx.has_flag::<ZetaFeatureFlag>() {
return div();
}

View file

@ -4,9 +4,9 @@ use client::Client;
use collections::HashMap;
use copilot::{Copilot, CopilotCompletionProvider};
use editor::{Editor, EditorMode};
use feature_flags::FeatureFlagAppExt;
use feature_flags::{FeatureFlagAppExt, ZetaFeatureFlag};
use gpui::{AnyWindowHandle, AppContext, Context, ViewContext, WeakView};
use language::language_settings::all_language_settings;
use language::language_settings::{all_language_settings, InlineCompletionProvider};
use settings::SettingsStore;
use supermaven::{Supermaven, SupermavenCompletionProvider};
@ -49,22 +49,45 @@ pub fn init(client: Arc<Client>, cx: &mut AppContext) {
});
}
cx.observe_global::<SettingsStore>(move |cx| {
let new_provider = all_language_settings(None, cx).inline_completions.provider;
if new_provider != provider {
provider = new_provider;
for (editor, window) in editors.borrow().iter() {
_ = window.update(cx, |_window, cx| {
_ = editor.update(cx, |editor, cx| {
assign_inline_completion_provider(editor, provider, &client, cx);
})
});
cx.observe_flag::<ZetaFeatureFlag, _>({
let editors = editors.clone();
let client = client.clone();
move |_flag, cx| {
let provider = all_language_settings(None, cx).inline_completions.provider;
assign_inline_completion_providers(&editors, provider, &client, cx)
}
})
.detach();
cx.observe_global::<SettingsStore>({
let editors = editors.clone();
let client = client.clone();
move |cx| {
let new_provider = all_language_settings(None, cx).inline_completions.provider;
if new_provider != provider {
provider = new_provider;
assign_inline_completion_providers(&editors, provider, &client, cx)
}
}
})
.detach();
}
/// Re-applies `provider` to every editor tracked in `editors`.
///
/// For each `(editor, window)` entry the window is updated and, inside that
/// update, `assign_inline_completion_provider` is invoked on the editor. The
/// results of both `update` calls are deliberately discarded, so an entry whose
/// update fails is skipped without aborting the loop.
///
/// * `editors` - shared map from weak editor handles to the window that hosts them.
/// * `provider` - the inline completion provider to assign.
/// * `client` - client handle forwarded to the per-editor assignment.
/// * `cx` - application context used to drive the window updates.
fn assign_inline_completion_providers(
    editors: &Rc<RefCell<HashMap<WeakView<Editor>, AnyWindowHandle>>>,
    provider: InlineCompletionProvider,
    client: &Arc<Client>,
    cx: &mut AppContext,
) {
    // NOTE(review): the `RefCell` borrow is held for the whole loop; this assumes
    // nothing reached from the updates below mutably borrows `editors` — confirm.
    for (editor, window) in editors.borrow().iter() {
        _ = window.update(cx, |_window, cx| {
            _ = editor.update(cx, |editor, cx| {
                // `client` is already a `&Arc<Client>`; pass it through as-is
                // instead of taking a second reference to it.
                assign_inline_completion_provider(editor, provider, client, cx);
            })
        });
    }
}
fn register_backward_compatible_actions(editor: &mut Editor, cx: &ViewContext<Editor>) {
// We renamed some of these actions to not be copilot-specific, but that
// would have not been backwards-compatible. So here we are re-registering
@ -129,7 +152,7 @@ fn assign_inline_completion_provider(
}
}
language::language_settings::InlineCompletionProvider::Zeta => {
if cx.is_staff() {
if cx.has_flag::<ZetaFeatureFlag>() {
let zeta = zeta::Zeta::register(client.clone(), cx);
if let Some(buffer) = editor.buffer().read(cx).as_singleton() {
if buffer.read(cx).file().is_some() {

View file

@ -13,7 +13,7 @@ use language::{
Point, ToOffset, ToPoint,
};
use language_models::LlmApiToken;
use rpc::{PredictEditsParams, PredictEditsResponse};
use rpc::{PredictEditsParams, PredictEditsResponse, EXPIRED_LLM_TOKEN_HEADER_NAME};
use std::{
borrow::Cow,
cmp,
@ -269,8 +269,6 @@ impl Zeta {
cx.spawn(|this, mut cx| async move {
let start = std::time::Instant::now();
let token = llm_token.acquire(&client).await?;
let mut input_events = String::new();
for event in events {
if !input_events.is_empty() {
@ -283,130 +281,26 @@ impl Zeta {
log::debug!("Events:\n{}\nExcerpt:\n{}", input_events, input_excerpt);
let http_client = client.http_client();
let body = PredictEditsParams {
input_events: input_events.clone(),
input_excerpt: input_excerpt.clone(),
};
let request_builder = http_client::Request::builder();
let request = request_builder
.method(Method::POST)
.uri(
client
.http_client()
.build_zed_llm_url("/predict_edits", &[])?
.as_ref(),
)
.header("Content-Type", "application/json")
.header("Authorization", format!("Bearer {}", token))
.body(serde_json::to_string(&body)?.into())?;
let mut response = http_client.send(request).await?;
let mut body = String::new();
response.body_mut().read_to_string(&mut body).await?;
if !response.status().is_success() {
return Err(anyhow!(
"error predicting edits.\nStatus: {:?}\nBody: {}",
response.status(),
body
));
}
let response = serde_json::from_str::<PredictEditsResponse>(&body)?;
let response = Self::perform_predict_edits(&client, llm_token, body).await?;
let output_excerpt = response.output_excerpt;
log::debug!("prediction took: {:?}", start.elapsed());
log::debug!("completion response: {}", output_excerpt);
let content = output_excerpt.replace(CURSOR_MARKER, "");
let mut new_text = content.as_str();
let codefence_start = new_text
.find(EDITABLE_REGION_START_MARKER)
.context("could not find start marker")?;
new_text = &new_text[codefence_start..];
let newline_ix = new_text.find('\n').context("could not find newline")?;
new_text = &new_text[newline_ix + 1..];
let codefence_end = new_text
.rfind(&format!("\n{EDITABLE_REGION_END_MARKER}"))
.context("could not find end marker")?;
new_text = &new_text[..codefence_end];
log::debug!("sanitized completion response: {}", new_text);
let old_text = snapshot
.text_for_range(excerpt_range.clone())
.collect::<String>();
let diff = similar::TextDiff::from_chars(old_text.as_str(), new_text);
let mut edits: Vec<(Range<usize>, String)> = Vec::new();
let mut old_start = excerpt_range.start;
for change in diff.iter_all_changes() {
let value = change.value();
match change.tag() {
similar::ChangeTag::Equal => {
old_start += value.len();
}
similar::ChangeTag::Delete => {
let old_end = old_start + value.len();
if let Some((last_old_range, _)) = edits.last_mut() {
if last_old_range.end == old_start {
last_old_range.end = old_end;
} else {
edits.push((old_start..old_end, String::new()));
}
} else {
edits.push((old_start..old_end, String::new()));
}
old_start = old_end;
}
similar::ChangeTag::Insert => {
if let Some((last_old_range, last_new_text)) = edits.last_mut() {
if last_old_range.end == old_start {
last_new_text.push_str(value);
} else {
edits.push((old_start..old_start, value.into()));
}
} else {
edits.push((old_start..old_start, value.into()));
}
}
}
}
let edits = edits
.into_iter()
.map(|(mut old_range, new_text)| {
let prefix_len = common_prefix(
snapshot.chars_for_range(old_range.clone()),
new_text.chars(),
);
old_range.start += prefix_len;
let suffix_len = common_prefix(
snapshot.reversed_chars_for_range(old_range.clone()),
new_text[prefix_len..].chars().rev(),
);
old_range.end = old_range.end.saturating_sub(suffix_len);
let new_text = new_text[prefix_len..new_text.len() - suffix_len].to_string();
(
snapshot.anchor_after(old_range.start)
..snapshot.anchor_before(old_range.end),
new_text,
)
})
.collect();
let inline_completion = InlineCompletion {
id: InlineCompletionId::new(),
path,
let inline_completion = Self::process_completion_response(
output_excerpt,
&snapshot,
excerpt_range,
edits,
snapshot,
input_events: input_events.into(),
input_excerpt: input_excerpt.into(),
output_excerpt: output_excerpt.into(),
};
path,
input_events,
input_excerpt,
)?;
this.update(&mut cx, |this, cx| {
this.recent_completions
.push_front(inline_completion.clone());
@ -420,6 +314,161 @@ impl Zeta {
})
}
/// Posts `body` to the `/predict_edits` LLM endpoint and decodes the JSON reply.
///
/// A token is acquired from `llm_token` before the first attempt. When the server
/// answers with a non-success status and the response carries the
/// `EXPIRED_LLM_TOKEN_HEADER_NAME` header, the token is refreshed and the request
/// is sent again; this retry happens at most once per call.
///
/// # Errors
///
/// Returns an error if the token cannot be acquired or refreshed, if the request
/// cannot be built or sent, if the response body cannot be read or deserialized,
/// or if the server replies with a non-success status that does not qualify for
/// the single retry (the status and body are included in the message).
async fn perform_predict_edits(
    client: &Arc<Client>,
    llm_token: LlmApiToken,
    body: PredictEditsParams,
) -> Result<PredictEditsResponse> {
    let http_client = client.http_client();
    let mut token = llm_token.acquire(client).await?;
    // Guards the refresh path so an expired-token reply is retried only once.
    let mut refreshed_token = false;

    loop {
        let url = http_client.build_zed_llm_url("/predict_edits", &[])?;
        let payload = serde_json::to_string(&body)?;
        let request = http_client::Request::builder()
            .method(Method::POST)
            .uri(url.as_ref())
            .header("Content-Type", "application/json")
            .header("Authorization", format!("Bearer {}", token))
            .body(payload.into())?;

        let mut response = http_client.send(request).await?;
        let status = response.status();

        if !status.is_success() {
            let token_expired = response
                .headers()
                .get(EXPIRED_LLM_TOKEN_HEADER_NAME)
                .is_some();
            if token_expired && !refreshed_token {
                refreshed_token = true;
                token = llm_token.refresh(client).await?;
                continue;
            }
        }

        // Both remaining outcomes need the full response body as text.
        let mut response_text = String::new();
        response
            .body_mut()
            .read_to_string(&mut response_text)
            .await?;

        return if status.is_success() {
            Ok(serde_json::from_str(&response_text)?)
        } else {
            Err(anyhow!(
                "error predicting edits.\nStatus: {:?}\nBody: {}",
                status,
                response_text
            ))
        };
    }
}
/// Turns the raw model output into an `InlineCompletion` for `snapshot`.
///
/// Every occurrence of `CURSOR_MARKER` is removed from `output_excerpt`. The new
/// text is then taken from the line after the first `EDITABLE_REGION_START_MARKER`
/// up to (not including) the last newline that is immediately followed by
/// `EDITABLE_REGION_END_MARKER`. That text is diffed against the buffer contents
/// in `excerpt_range` via `Self::compute_edits`.
///
/// # Errors
///
/// Returns an error when the start marker, the newline following it, or the end
/// marker cannot be found in the output.
fn process_completion_response(
    output_excerpt: String,
    snapshot: &BufferSnapshot,
    excerpt_range: Range<usize>,
    path: Arc<Path>,
    input_events: String,
    input_excerpt: String,
) -> Result<InlineCompletion> {
    let without_cursor = output_excerpt.replace(CURSOR_MARKER, "");

    let region_start = without_cursor
        .find(EDITABLE_REGION_START_MARKER)
        .context("could not find start marker")?;
    // Drop the remainder of the line that holds the start marker.
    let (_, after_marker_line) = without_cursor[region_start..]
        .split_once('\n')
        .context("could not find newline")?;

    let end_pattern = format!("\n{EDITABLE_REGION_END_MARKER}");
    let region_end = after_marker_line
        .rfind(end_pattern.as_str())
        .context("could not find end marker")?;
    let new_text = &after_marker_line[..region_end];

    let old_text: String = snapshot.text_for_range(excerpt_range.clone()).collect();
    let edits = Self::compute_edits(old_text, new_text, excerpt_range.start, snapshot);

    Ok(InlineCompletion {
        id: InlineCompletionId::new(),
        path,
        excerpt_range,
        edits: edits.into(),
        snapshot: snapshot.clone(),
        input_events: input_events.into(),
        input_excerpt: input_excerpt.into(),
        output_excerpt: output_excerpt.into(),
    })
}
/// Diffs `old_text` against `new_text` and returns the edits as anchored ranges.
///
/// The diff is computed character by character. `offset` is the position in
/// `snapshot` at which `old_text` starts; every produced range is expressed
/// relative to it. Consecutive deletions and insertions that touch the end of the
/// previous edit are merged into that edit. Afterwards each edit is narrowed by
/// the lengths that `common_prefix` reports for its leading and trailing overlap
/// with the buffer, and its bounds are converted to anchors in `snapshot`.
fn compute_edits(
    old_text: String,
    new_text: &str,
    offset: usize,
    snapshot: &BufferSnapshot,
) -> Vec<(Range<Anchor>, String)> {
    let char_diff = similar::TextDiff::from_chars(old_text.as_str(), new_text);

    let mut raw_edits: Vec<(Range<usize>, String)> = Vec::new();
    // Position in the old text (shifted by `offset`) of the next unprocessed change.
    let mut cursor = offset;
    for change in char_diff.iter_all_changes() {
        let text = change.value();
        match change.tag() {
            similar::ChangeTag::Equal => cursor += text.len(),
            similar::ChangeTag::Delete => {
                let deleted_end = cursor + text.len();
                match raw_edits.last_mut() {
                    // Adjacent to the previous edit: grow its replaced range.
                    Some((range, _)) if range.end == cursor => range.end = deleted_end,
                    _ => raw_edits.push((cursor..deleted_end, String::new())),
                }
                cursor = deleted_end;
            }
            similar::ChangeTag::Insert => match raw_edits.last_mut() {
                // Adjacent to the previous edit: append to its replacement text.
                Some((range, inserted)) if range.end == cursor => inserted.push_str(text),
                _ => raw_edits.push((cursor..cursor, text.into())),
            },
        }
    }

    raw_edits
        .into_iter()
        .map(|(mut range, replacement)| {
            // The values returned by `common_prefix` are used as byte offsets below.
            let shared_prefix = common_prefix(
                snapshot.chars_for_range(range.clone()),
                replacement.chars(),
            );
            range.start += shared_prefix;

            let shared_suffix = common_prefix(
                snapshot.reversed_chars_for_range(range.clone()),
                replacement[shared_prefix..].chars().rev(),
            );
            range.end = range.end.saturating_sub(shared_suffix);

            let trimmed =
                replacement[shared_prefix..replacement.len() - shared_suffix].to_string();
            let anchors = snapshot.anchor_after(range.start)..snapshot.anchor_before(range.end);
            (anchors, trimmed)
        })
        .collect()
}
/// Returns `true` if `completion_id` is contained in `self.rated_completions`.
pub fn is_completion_rated(&self, completion_id: InlineCompletionId) -> bool {
self.rated_completions.contains(&completion_id)
}