Start on adding support for editing via the assistant panel (#14795)

Note that this shouldn't have any visible user-facing behavior yet. The
feature is incomplete but we wanna merge early to avoid a long-running
branch.

Release Notes:

- N/A

---------

Co-authored-by: Nathan <nathan@zed.dev>
This commit is contained in:
Antonio Scandurra 2024-07-19 11:13:15 +02:00 committed by GitHub
parent 87457f9ae8
commit 4d177918c1
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
44 changed files with 1999 additions and 968 deletions

View file

@ -20,11 +20,10 @@ use crate::{
};
use anyhow::Result;
use client::Client;
use futures::{future::BoxFuture, stream::BoxStream};
use futures::{future::BoxFuture, stream::BoxStream, StreamExt};
use gpui::{AnyView, AppContext, BorrowAppContext, Task, WindowContext};
use settings::{Settings, SettingsStore};
use std::time::Duration;
use std::{any::Any, sync::Arc};
use std::{any::Any, pin::Pin, sync::Arc, task::Poll, time::Duration};
/// Choose which model to use for openai provider.
/// If the model is not available, try to use the first available model, or fallback to the original model.
@ -55,10 +54,21 @@ pub fn init(client: Arc<Client>, cx: &mut AppContext) {
}
pub struct CompletionResponse {
pub inner: BoxFuture<'static, Result<BoxStream<'static, Result<String>>>>,
inner: BoxStream<'static, Result<String>>,
_lock: SemaphoreGuardArc,
}
impl futures::Stream for CompletionResponse {
type Item = Result<String>;
fn poll_next(
mut self: Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> Poll<Option<Self::Item>> {
Pin::new(&mut self.inner).poll_next(cx)
}
}
pub trait LanguageModelCompletionProvider: Send + Sync {
fn available_models(&self, cx: &AppContext) -> Vec<LanguageModel>;
fn settings_version(&self) -> usize;
@ -72,7 +82,7 @@ pub trait LanguageModelCompletionProvider: Send + Sync {
request: LanguageModelRequest,
cx: &AppContext,
) -> BoxFuture<'static, Result<usize>>;
fn complete(
fn stream_completion(
&self,
request: LanguageModelRequest,
) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>>;
@ -136,20 +146,34 @@ impl CompletionProvider {
self.provider.read().count_tokens(request, cx)
}
pub fn complete(
pub fn stream_completion(
&self,
request: LanguageModelRequest,
cx: &AppContext,
) -> Task<CompletionResponse> {
) -> Task<Result<CompletionResponse>> {
let rate_limiter = self.request_limiter.clone();
let provider = self.provider.clone();
cx.background_executor().spawn(async move {
cx.foreground_executor().spawn(async move {
let lock = rate_limiter.acquire_arc().await;
let response = provider.read().complete(request);
CompletionResponse {
let response = provider.read().stream_completion(request);
let response = response.await?;
Ok(CompletionResponse {
inner: response,
_lock: lock,
})
})
}
pub fn complete(&self, request: LanguageModelRequest, cx: &AppContext) -> Task<Result<String>> {
let response = self.stream_completion(request, cx);
cx.foreground_executor().spawn(async move {
let mut chunks = response.await?;
let mut completion = String::new();
while let Some(chunk) = chunks.next().await {
let chunk = chunk?;
completion.push_str(&chunk);
}
Ok(completion)
})
}
}
@ -300,7 +324,7 @@ mod tests {
// Enqueue some requests
for i in 0..MAX_CONCURRENT_COMPLETION_REQUESTS * 2 {
let response = provider.complete(
let response = provider.stream_completion(
LanguageModelRequest {
temperature: i as f32 / 10.0,
..Default::default()
@ -309,8 +333,7 @@ mod tests {
);
cx.background_executor()
.spawn(async move {
let response = response.await;
let mut stream = response.inner.await.unwrap();
let mut stream = response.await.unwrap();
while let Some(message) = stream.next().await {
message.unwrap();
}
@ -326,7 +349,7 @@ mod tests {
// Get the first completion request that is in flight and mark it as completed.
let completion = fake_provider
.running_completions()
.pending_completions()
.into_iter()
.next()
.unwrap();
@ -347,7 +370,7 @@ mod tests {
);
// Mark all completion requests as finished that are in flight.
for request in fake_provider.running_completions() {
for request in fake_provider.pending_completions() {
fake_provider.finish_completion(&request);
}
@ -362,7 +385,7 @@ mod tests {
);
// Finish all remaining completion requests.
for request in fake_provider.running_completions() {
for request in fake_provider.pending_completions() {
fake_provider.finish_completion(&request);
}