More resilient eval (#32257)
Bubbles up rate-limit information so that callers higher up in the stack can retry after the indicated duration when needed. Also caps the number of evals running concurrently, which further reduces the chance of hitting rate limits. Release Notes: - N/A
This commit is contained in:
parent
fa54fa80d0
commit
e4bd115a63
22 changed files with 147 additions and 56 deletions
2
.github/workflows/unit_evals.yml
vendored
2
.github/workflows/unit_evals.yml
vendored
|
@ -62,7 +62,7 @@ jobs:
|
||||||
|
|
||||||
- name: Run unit evals
|
- name: Run unit evals
|
||||||
shell: bash -euxo pipefail {0}
|
shell: bash -euxo pipefail {0}
|
||||||
run: cargo nextest run --workspace --no-fail-fast --features eval --no-capture -E 'test(::eval_)' --test-threads 1
|
run: cargo nextest run --workspace --no-fail-fast --features eval --no-capture -E 'test(::eval_)'
|
||||||
env:
|
env:
|
||||||
ANTHROPIC_API_KEY: ${{ secrets.ANTHROPIC_API_KEY }}
|
ANTHROPIC_API_KEY: ${{ secrets.ANTHROPIC_API_KEY }}
|
||||||
|
|
||||||
|
|
1
Cargo.lock
generated
1
Cargo.lock
generated
|
@ -705,6 +705,7 @@ dependencies = [
|
||||||
"serde_json",
|
"serde_json",
|
||||||
"settings",
|
"settings",
|
||||||
"smallvec",
|
"smallvec",
|
||||||
|
"smol",
|
||||||
"streaming_diff",
|
"streaming_diff",
|
||||||
"strsim",
|
"strsim",
|
||||||
"task",
|
"task",
|
||||||
|
|
|
@ -386,7 +386,9 @@ impl CodegenAlternative {
|
||||||
async { Ok(LanguageModelTextStream::default()) }.boxed_local()
|
async { Ok(LanguageModelTextStream::default()) }.boxed_local()
|
||||||
} else {
|
} else {
|
||||||
let request = self.build_request(&model, user_prompt, cx)?;
|
let request = self.build_request(&model, user_prompt, cx)?;
|
||||||
cx.spawn(async move |_, cx| model.stream_completion_text(request.await, &cx).await)
|
cx.spawn(async move |_, cx| {
|
||||||
|
Ok(model.stream_completion_text(request.await, &cx).await?)
|
||||||
|
})
|
||||||
.boxed_local()
|
.boxed_local()
|
||||||
};
|
};
|
||||||
self.handle_stream(telemetry_id, provider_id.to_string(), api_key, stream, cx);
|
self.handle_stream(telemetry_id, provider_id.to_string(), api_key, stream, cx);
|
||||||
|
|
|
@ -1563,6 +1563,9 @@ impl Thread {
|
||||||
Err(LanguageModelCompletionError::Other(error)) => {
|
Err(LanguageModelCompletionError::Other(error)) => {
|
||||||
return Err(error);
|
return Err(error);
|
||||||
}
|
}
|
||||||
|
Err(err @ LanguageModelCompletionError::RateLimit(..)) => {
|
||||||
|
return Err(err.into());
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
match event {
|
match event {
|
||||||
|
|
|
@ -1,4 +1,5 @@
|
||||||
use std::str::FromStr;
|
use std::str::FromStr;
|
||||||
|
use std::time::Duration;
|
||||||
|
|
||||||
use anyhow::{Context as _, Result, anyhow};
|
use anyhow::{Context as _, Result, anyhow};
|
||||||
use chrono::{DateTime, Utc};
|
use chrono::{DateTime, Utc};
|
||||||
|
@ -406,6 +407,7 @@ impl RateLimit {
|
||||||
/// <https://docs.anthropic.com/en/api/rate-limits#response-headers>
|
/// <https://docs.anthropic.com/en/api/rate-limits#response-headers>
|
||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
pub struct RateLimitInfo {
|
pub struct RateLimitInfo {
|
||||||
|
pub retry_after: Option<Duration>,
|
||||||
pub requests: Option<RateLimit>,
|
pub requests: Option<RateLimit>,
|
||||||
pub tokens: Option<RateLimit>,
|
pub tokens: Option<RateLimit>,
|
||||||
pub input_tokens: Option<RateLimit>,
|
pub input_tokens: Option<RateLimit>,
|
||||||
|
@ -417,10 +419,11 @@ impl RateLimitInfo {
|
||||||
// Check if any rate limit headers exist
|
// Check if any rate limit headers exist
|
||||||
let has_rate_limit_headers = headers
|
let has_rate_limit_headers = headers
|
||||||
.keys()
|
.keys()
|
||||||
.any(|k| k.as_str().starts_with("anthropic-ratelimit-"));
|
.any(|k| k == "retry-after" || k.as_str().starts_with("anthropic-ratelimit-"));
|
||||||
|
|
||||||
if !has_rate_limit_headers {
|
if !has_rate_limit_headers {
|
||||||
return Self {
|
return Self {
|
||||||
|
retry_after: None,
|
||||||
requests: None,
|
requests: None,
|
||||||
tokens: None,
|
tokens: None,
|
||||||
input_tokens: None,
|
input_tokens: None,
|
||||||
|
@ -429,6 +432,11 @@ impl RateLimitInfo {
|
||||||
}
|
}
|
||||||
|
|
||||||
Self {
|
Self {
|
||||||
|
retry_after: headers
|
||||||
|
.get("retry-after")
|
||||||
|
.and_then(|v| v.to_str().ok())
|
||||||
|
.and_then(|v| v.parse::<u64>().ok())
|
||||||
|
.map(Duration::from_secs),
|
||||||
requests: RateLimit::from_headers("requests", headers).ok(),
|
requests: RateLimit::from_headers("requests", headers).ok(),
|
||||||
tokens: RateLimit::from_headers("tokens", headers).ok(),
|
tokens: RateLimit::from_headers("tokens", headers).ok(),
|
||||||
input_tokens: RateLimit::from_headers("input-tokens", headers).ok(),
|
input_tokens: RateLimit::from_headers("input-tokens", headers).ok(),
|
||||||
|
@ -481,8 +489,8 @@ pub async fn stream_completion_with_rate_limit_info(
|
||||||
.send(request)
|
.send(request)
|
||||||
.await
|
.await
|
||||||
.context("failed to send request to Anthropic")?;
|
.context("failed to send request to Anthropic")?;
|
||||||
if response.status().is_success() {
|
|
||||||
let rate_limits = RateLimitInfo::from_headers(response.headers());
|
let rate_limits = RateLimitInfo::from_headers(response.headers());
|
||||||
|
if response.status().is_success() {
|
||||||
let reader = BufReader::new(response.into_body());
|
let reader = BufReader::new(response.into_body());
|
||||||
let stream = reader
|
let stream = reader
|
||||||
.lines()
|
.lines()
|
||||||
|
@ -500,6 +508,8 @@ pub async fn stream_completion_with_rate_limit_info(
|
||||||
})
|
})
|
||||||
.boxed();
|
.boxed();
|
||||||
Ok((stream, Some(rate_limits)))
|
Ok((stream, Some(rate_limits)))
|
||||||
|
} else if let Some(retry_after) = rate_limits.retry_after {
|
||||||
|
Err(AnthropicError::RateLimit(retry_after))
|
||||||
} else {
|
} else {
|
||||||
let mut body = Vec::new();
|
let mut body = Vec::new();
|
||||||
response
|
response
|
||||||
|
@ -769,6 +779,8 @@ pub struct MessageDelta {
|
||||||
|
|
||||||
#[derive(Error, Debug)]
|
#[derive(Error, Debug)]
|
||||||
pub enum AnthropicError {
|
pub enum AnthropicError {
|
||||||
|
#[error("rate limit exceeded, retry after {0:?}")]
|
||||||
|
RateLimit(Duration),
|
||||||
#[error("an error occurred while interacting with the Anthropic API: {error_type}: {message}", error_type = .0.error_type, message = .0.message)]
|
#[error("an error occurred while interacting with the Anthropic API: {error_type}: {message}", error_type = .0.error_type, message = .0.message)]
|
||||||
ApiError(ApiError),
|
ApiError(ApiError),
|
||||||
#[error("{0}")]
|
#[error("{0}")]
|
||||||
|
|
|
@ -682,11 +682,12 @@ mod tests {
|
||||||
_: &AsyncApp,
|
_: &AsyncApp,
|
||||||
) -> BoxFuture<
|
) -> BoxFuture<
|
||||||
'static,
|
'static,
|
||||||
http_client::Result<
|
Result<
|
||||||
BoxStream<
|
BoxStream<
|
||||||
'static,
|
'static,
|
||||||
http_client::Result<LanguageModelCompletionEvent, LanguageModelCompletionError>,
|
Result<LanguageModelCompletionEvent, LanguageModelCompletionError>,
|
||||||
>,
|
>,
|
||||||
|
LanguageModelCompletionError,
|
||||||
>,
|
>,
|
||||||
> {
|
> {
|
||||||
unimplemented!()
|
unimplemented!()
|
||||||
|
|
|
@ -80,6 +80,7 @@ rand.workspace = true
|
||||||
pretty_assertions.workspace = true
|
pretty_assertions.workspace = true
|
||||||
reqwest_client.workspace = true
|
reqwest_client.workspace = true
|
||||||
settings = { workspace = true, features = ["test-support"] }
|
settings = { workspace = true, features = ["test-support"] }
|
||||||
|
smol.workspace = true
|
||||||
task = { workspace = true, features = ["test-support"]}
|
task = { workspace = true, features = ["test-support"]}
|
||||||
tempfile.workspace = true
|
tempfile.workspace = true
|
||||||
theme.workspace = true
|
theme.workspace = true
|
||||||
|
|
|
@ -11,7 +11,7 @@ use client::{Client, UserStore};
|
||||||
use collections::HashMap;
|
use collections::HashMap;
|
||||||
use fs::FakeFs;
|
use fs::FakeFs;
|
||||||
use futures::{FutureExt, future::LocalBoxFuture};
|
use futures::{FutureExt, future::LocalBoxFuture};
|
||||||
use gpui::{AppContext, TestAppContext};
|
use gpui::{AppContext, TestAppContext, Timer};
|
||||||
use indoc::{formatdoc, indoc};
|
use indoc::{formatdoc, indoc};
|
||||||
use language_model::{
|
use language_model::{
|
||||||
LanguageModelRegistry, LanguageModelRequestTool, LanguageModelToolResult,
|
LanguageModelRegistry, LanguageModelRequestTool, LanguageModelToolResult,
|
||||||
|
@ -1255,8 +1255,11 @@ impl EvalAssertion {
|
||||||
}],
|
}],
|
||||||
..Default::default()
|
..Default::default()
|
||||||
};
|
};
|
||||||
let mut response = judge
|
let mut response = retry_on_rate_limit(async || {
|
||||||
.stream_completion_text(request, &cx.to_async())
|
Ok(judge
|
||||||
|
.stream_completion_text(request.clone(), &cx.to_async())
|
||||||
|
.await?)
|
||||||
|
})
|
||||||
.await?;
|
.await?;
|
||||||
let mut output = String::new();
|
let mut output = String::new();
|
||||||
while let Some(chunk) = response.stream.next().await {
|
while let Some(chunk) = response.stream.next().await {
|
||||||
|
@ -1308,10 +1311,17 @@ fn eval(
|
||||||
run_eval(eval.clone(), tx.clone());
|
run_eval(eval.clone(), tx.clone());
|
||||||
|
|
||||||
let executor = gpui::background_executor();
|
let executor = gpui::background_executor();
|
||||||
|
let semaphore = Arc::new(smol::lock::Semaphore::new(32));
|
||||||
for _ in 1..iterations {
|
for _ in 1..iterations {
|
||||||
let eval = eval.clone();
|
let eval = eval.clone();
|
||||||
let tx = tx.clone();
|
let tx = tx.clone();
|
||||||
executor.spawn(async move { run_eval(eval, tx) }).detach();
|
let semaphore = semaphore.clone();
|
||||||
|
executor
|
||||||
|
.spawn(async move {
|
||||||
|
let _guard = semaphore.acquire().await;
|
||||||
|
run_eval(eval, tx)
|
||||||
|
})
|
||||||
|
.detach();
|
||||||
}
|
}
|
||||||
drop(tx);
|
drop(tx);
|
||||||
|
|
||||||
|
@ -1577,21 +1587,31 @@ impl EditAgentTest {
|
||||||
if let Some(input_content) = eval.input_content.as_deref() {
|
if let Some(input_content) = eval.input_content.as_deref() {
|
||||||
buffer.update(cx, |buffer, cx| buffer.set_text(input_content, cx));
|
buffer.update(cx, |buffer, cx| buffer.set_text(input_content, cx));
|
||||||
}
|
}
|
||||||
let (edit_output, _) = self.agent.edit(
|
retry_on_rate_limit(async || {
|
||||||
|
self.agent
|
||||||
|
.edit(
|
||||||
buffer.clone(),
|
buffer.clone(),
|
||||||
eval.edit_file_input.display_description,
|
eval.edit_file_input.display_description.clone(),
|
||||||
&conversation,
|
&conversation,
|
||||||
&mut cx.to_async(),
|
&mut cx.to_async(),
|
||||||
);
|
)
|
||||||
edit_output.await?
|
.0
|
||||||
|
.await
|
||||||
|
})
|
||||||
|
.await?
|
||||||
} else {
|
} else {
|
||||||
let (edit_output, _) = self.agent.overwrite(
|
retry_on_rate_limit(async || {
|
||||||
|
self.agent
|
||||||
|
.overwrite(
|
||||||
buffer.clone(),
|
buffer.clone(),
|
||||||
eval.edit_file_input.display_description,
|
eval.edit_file_input.display_description.clone(),
|
||||||
&conversation,
|
&conversation,
|
||||||
&mut cx.to_async(),
|
&mut cx.to_async(),
|
||||||
);
|
)
|
||||||
edit_output.await?
|
.0
|
||||||
|
.await
|
||||||
|
})
|
||||||
|
.await?
|
||||||
};
|
};
|
||||||
|
|
||||||
let buffer_text = buffer.read_with(cx, |buffer, _| buffer.text());
|
let buffer_text = buffer.read_with(cx, |buffer, _| buffer.text());
|
||||||
|
@ -1613,6 +1633,26 @@ impl EditAgentTest {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
async fn retry_on_rate_limit<R>(mut request: impl AsyncFnMut() -> Result<R>) -> Result<R> {
|
||||||
|
loop {
|
||||||
|
match request().await {
|
||||||
|
Ok(result) => return Ok(result),
|
||||||
|
Err(err) => match err.downcast::<LanguageModelCompletionError>() {
|
||||||
|
Ok(err) => match err {
|
||||||
|
LanguageModelCompletionError::RateLimit(duration) => {
|
||||||
|
// Wait until after we are allowed to try again
|
||||||
|
eprintln!("Rate limit exceeded. Waiting for {duration:?}...",);
|
||||||
|
Timer::after(duration).await;
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
_ => return Err(err.into()),
|
||||||
|
},
|
||||||
|
Err(err) => return Err(err),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
#[derive(Clone, Debug, Eq, PartialEq, Hash)]
|
#[derive(Clone, Debug, Eq, PartialEq, Hash)]
|
||||||
struct EvalAssertionOutcome {
|
struct EvalAssertionOutcome {
|
||||||
score: usize,
|
score: usize,
|
||||||
|
|
|
@ -185,6 +185,7 @@ impl LanguageModel for FakeLanguageModel {
|
||||||
'static,
|
'static,
|
||||||
Result<
|
Result<
|
||||||
BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>,
|
BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>,
|
||||||
|
LanguageModelCompletionError,
|
||||||
>,
|
>,
|
||||||
> {
|
> {
|
||||||
let (tx, rx) = mpsc::unbounded();
|
let (tx, rx) = mpsc::unbounded();
|
||||||
|
|
|
@ -22,6 +22,7 @@ use std::fmt;
|
||||||
use std::ops::{Add, Sub};
|
use std::ops::{Add, Sub};
|
||||||
use std::str::FromStr as _;
|
use std::str::FromStr as _;
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
|
use std::time::Duration;
|
||||||
use thiserror::Error;
|
use thiserror::Error;
|
||||||
use util::serde::is_default;
|
use util::serde::is_default;
|
||||||
use zed_llm_client::{
|
use zed_llm_client::{
|
||||||
|
@ -74,6 +75,8 @@ pub enum LanguageModelCompletionEvent {
|
||||||
|
|
||||||
#[derive(Error, Debug)]
|
#[derive(Error, Debug)]
|
||||||
pub enum LanguageModelCompletionError {
|
pub enum LanguageModelCompletionError {
|
||||||
|
#[error("rate limit exceeded, retry after {0:?}")]
|
||||||
|
RateLimit(Duration),
|
||||||
#[error("received bad input JSON")]
|
#[error("received bad input JSON")]
|
||||||
BadInputJson {
|
BadInputJson {
|
||||||
id: LanguageModelToolUseId,
|
id: LanguageModelToolUseId,
|
||||||
|
@ -270,6 +273,7 @@ pub trait LanguageModel: Send + Sync {
|
||||||
'static,
|
'static,
|
||||||
Result<
|
Result<
|
||||||
BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>,
|
BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>,
|
||||||
|
LanguageModelCompletionError,
|
||||||
>,
|
>,
|
||||||
>;
|
>;
|
||||||
|
|
||||||
|
@ -277,7 +281,7 @@ pub trait LanguageModel: Send + Sync {
|
||||||
&self,
|
&self,
|
||||||
request: LanguageModelRequest,
|
request: LanguageModelRequest,
|
||||||
cx: &AsyncApp,
|
cx: &AsyncApp,
|
||||||
) -> BoxFuture<'static, Result<LanguageModelTextStream>> {
|
) -> BoxFuture<'static, Result<LanguageModelTextStream, LanguageModelCompletionError>> {
|
||||||
let future = self.stream_completion(request, cx);
|
let future = self.stream_completion(request, cx);
|
||||||
|
|
||||||
async move {
|
async move {
|
||||||
|
|
|
@ -1,4 +1,3 @@
|
||||||
use anyhow::Result;
|
|
||||||
use futures::Stream;
|
use futures::Stream;
|
||||||
use smol::lock::{Semaphore, SemaphoreGuardArc};
|
use smol::lock::{Semaphore, SemaphoreGuardArc};
|
||||||
use std::{
|
use std::{
|
||||||
|
@ -8,6 +7,8 @@ use std::{
|
||||||
task::{Context, Poll},
|
task::{Context, Poll},
|
||||||
};
|
};
|
||||||
|
|
||||||
|
use crate::LanguageModelCompletionError;
|
||||||
|
|
||||||
#[derive(Clone)]
|
#[derive(Clone)]
|
||||||
pub struct RateLimiter {
|
pub struct RateLimiter {
|
||||||
semaphore: Arc<Semaphore>,
|
semaphore: Arc<Semaphore>,
|
||||||
|
@ -36,9 +37,12 @@ impl RateLimiter {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn run<'a, Fut, T>(&self, future: Fut) -> impl 'a + Future<Output = Result<T>>
|
pub fn run<'a, Fut, T>(
|
||||||
|
&self,
|
||||||
|
future: Fut,
|
||||||
|
) -> impl 'a + Future<Output = Result<T, LanguageModelCompletionError>>
|
||||||
where
|
where
|
||||||
Fut: 'a + Future<Output = Result<T>>,
|
Fut: 'a + Future<Output = Result<T, LanguageModelCompletionError>>,
|
||||||
{
|
{
|
||||||
let guard = self.semaphore.acquire_arc();
|
let guard = self.semaphore.acquire_arc();
|
||||||
async move {
|
async move {
|
||||||
|
@ -52,9 +56,12 @@ impl RateLimiter {
|
||||||
pub fn stream<'a, Fut, T>(
|
pub fn stream<'a, Fut, T>(
|
||||||
&self,
|
&self,
|
||||||
future: Fut,
|
future: Fut,
|
||||||
) -> impl 'a + Future<Output = Result<impl Stream<Item = T::Item> + use<Fut, T>>>
|
) -> impl 'a
|
||||||
|
+ Future<
|
||||||
|
Output = Result<impl Stream<Item = T::Item> + use<Fut, T>, LanguageModelCompletionError>,
|
||||||
|
>
|
||||||
where
|
where
|
||||||
Fut: 'a + Future<Output = Result<T>>,
|
Fut: 'a + Future<Output = Result<T, LanguageModelCompletionError>>,
|
||||||
T: Stream,
|
T: Stream,
|
||||||
{
|
{
|
||||||
let guard = self.semaphore.acquire_arc();
|
let guard = self.semaphore.acquire_arc();
|
||||||
|
|
|
@ -387,22 +387,34 @@ impl AnthropicModel {
|
||||||
&self,
|
&self,
|
||||||
request: anthropic::Request,
|
request: anthropic::Request,
|
||||||
cx: &AsyncApp,
|
cx: &AsyncApp,
|
||||||
) -> BoxFuture<'static, Result<BoxStream<'static, Result<anthropic::Event, AnthropicError>>>>
|
) -> BoxFuture<
|
||||||
{
|
'static,
|
||||||
|
Result<
|
||||||
|
BoxStream<'static, Result<anthropic::Event, AnthropicError>>,
|
||||||
|
LanguageModelCompletionError,
|
||||||
|
>,
|
||||||
|
> {
|
||||||
let http_client = self.http_client.clone();
|
let http_client = self.http_client.clone();
|
||||||
|
|
||||||
let Ok((api_key, api_url)) = cx.read_entity(&self.state, |state, cx| {
|
let Ok((api_key, api_url)) = cx.read_entity(&self.state, |state, cx| {
|
||||||
let settings = &AllLanguageModelSettings::get_global(cx).anthropic;
|
let settings = &AllLanguageModelSettings::get_global(cx).anthropic;
|
||||||
(state.api_key.clone(), settings.api_url.clone())
|
(state.api_key.clone(), settings.api_url.clone())
|
||||||
}) else {
|
}) else {
|
||||||
return futures::future::ready(Err(anyhow!("App state dropped"))).boxed();
|
return futures::future::ready(Err(anyhow!("App state dropped").into())).boxed();
|
||||||
};
|
};
|
||||||
|
|
||||||
async move {
|
async move {
|
||||||
let api_key = api_key.context("Missing Anthropic API Key")?;
|
let api_key = api_key.context("Missing Anthropic API Key")?;
|
||||||
let request =
|
let request =
|
||||||
anthropic::stream_completion(http_client.as_ref(), &api_url, &api_key, request);
|
anthropic::stream_completion(http_client.as_ref(), &api_url, &api_key, request);
|
||||||
request.await.context("failed to stream completion")
|
request.await.map_err(|err| match err {
|
||||||
|
AnthropicError::RateLimit(duration) => {
|
||||||
|
LanguageModelCompletionError::RateLimit(duration)
|
||||||
|
}
|
||||||
|
err @ (AnthropicError::ApiError(..) | AnthropicError::Other(..)) => {
|
||||||
|
LanguageModelCompletionError::Other(anthropic_err_to_anyhow(err))
|
||||||
|
}
|
||||||
|
})
|
||||||
}
|
}
|
||||||
.boxed()
|
.boxed()
|
||||||
}
|
}
|
||||||
|
@ -473,6 +485,7 @@ impl LanguageModel for AnthropicModel {
|
||||||
'static,
|
'static,
|
||||||
Result<
|
Result<
|
||||||
BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>,
|
BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>,
|
||||||
|
LanguageModelCompletionError,
|
||||||
>,
|
>,
|
||||||
> {
|
> {
|
||||||
let request = into_anthropic(
|
let request = into_anthropic(
|
||||||
|
@ -484,12 +497,7 @@ impl LanguageModel for AnthropicModel {
|
||||||
);
|
);
|
||||||
let request = self.stream_completion(request, cx);
|
let request = self.stream_completion(request, cx);
|
||||||
let future = self.request_limiter.stream(async move {
|
let future = self.request_limiter.stream(async move {
|
||||||
let response = request
|
let response = request.await?;
|
||||||
.await
|
|
||||||
.map_err(|err| match err.downcast::<AnthropicError>() {
|
|
||||||
Ok(anthropic_err) => anthropic_err_to_anyhow(anthropic_err),
|
|
||||||
Err(err) => anyhow!(err),
|
|
||||||
})?;
|
|
||||||
Ok(AnthropicEventMapper::new().map_stream(response))
|
Ok(AnthropicEventMapper::new().map_stream(response))
|
||||||
});
|
});
|
||||||
async move { Ok(future.await?.boxed()) }.boxed()
|
async move { Ok(future.await?.boxed()) }.boxed()
|
||||||
|
|
|
@ -527,6 +527,7 @@ impl LanguageModel for BedrockModel {
|
||||||
'static,
|
'static,
|
||||||
Result<
|
Result<
|
||||||
BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>,
|
BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>,
|
||||||
|
LanguageModelCompletionError,
|
||||||
>,
|
>,
|
||||||
> {
|
> {
|
||||||
let Ok(region) = cx.read_entity(&self.state, |state, _cx| {
|
let Ok(region) = cx.read_entity(&self.state, |state, _cx| {
|
||||||
|
@ -539,16 +540,13 @@ impl LanguageModel for BedrockModel {
|
||||||
.or(settings_region)
|
.or(settings_region)
|
||||||
.unwrap_or(String::from("us-east-1"))
|
.unwrap_or(String::from("us-east-1"))
|
||||||
}) else {
|
}) else {
|
||||||
return async move {
|
return async move { Err(anyhow::anyhow!("App State Dropped").into()) }.boxed();
|
||||||
anyhow::bail!("App State Dropped");
|
|
||||||
}
|
|
||||||
.boxed();
|
|
||||||
};
|
};
|
||||||
|
|
||||||
let model_id = match self.model.cross_region_inference_id(&region) {
|
let model_id = match self.model.cross_region_inference_id(&region) {
|
||||||
Ok(s) => s,
|
Ok(s) => s,
|
||||||
Err(e) => {
|
Err(e) => {
|
||||||
return async move { Err(e) }.boxed();
|
return async move { Err(e.into()) }.boxed();
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -560,7 +558,7 @@ impl LanguageModel for BedrockModel {
|
||||||
self.model.mode(),
|
self.model.mode(),
|
||||||
) {
|
) {
|
||||||
Ok(request) => request,
|
Ok(request) => request,
|
||||||
Err(err) => return futures::future::ready(Err(err)).boxed(),
|
Err(err) => return futures::future::ready(Err(err.into())).boxed(),
|
||||||
};
|
};
|
||||||
|
|
||||||
let owned_handle = self.handler.clone();
|
let owned_handle = self.handler.clone();
|
||||||
|
|
|
@ -807,6 +807,7 @@ impl LanguageModel for CloudLanguageModel {
|
||||||
'static,
|
'static,
|
||||||
Result<
|
Result<
|
||||||
BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>,
|
BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>,
|
||||||
|
LanguageModelCompletionError,
|
||||||
>,
|
>,
|
||||||
> {
|
> {
|
||||||
let thread_id = request.thread_id.clone();
|
let thread_id = request.thread_id.clone();
|
||||||
|
@ -848,7 +849,8 @@ impl LanguageModel for CloudLanguageModel {
|
||||||
mode,
|
mode,
|
||||||
provider: zed_llm_client::LanguageModelProvider::Anthropic,
|
provider: zed_llm_client::LanguageModelProvider::Anthropic,
|
||||||
model: request.model.clone(),
|
model: request.model.clone(),
|
||||||
provider_request: serde_json::to_value(&request)?,
|
provider_request: serde_json::to_value(&request)
|
||||||
|
.map_err(|e| anyhow!(e))?,
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
.await
|
.await
|
||||||
|
@ -884,7 +886,7 @@ impl LanguageModel for CloudLanguageModel {
|
||||||
let client = self.client.clone();
|
let client = self.client.clone();
|
||||||
let model = match open_ai::Model::from_id(&self.model.id.0) {
|
let model = match open_ai::Model::from_id(&self.model.id.0) {
|
||||||
Ok(model) => model,
|
Ok(model) => model,
|
||||||
Err(err) => return async move { Err(anyhow!(err)) }.boxed(),
|
Err(err) => return async move { Err(anyhow!(err).into()) }.boxed(),
|
||||||
};
|
};
|
||||||
let request = into_open_ai(request, &model, None);
|
let request = into_open_ai(request, &model, None);
|
||||||
let llm_api_token = self.llm_api_token.clone();
|
let llm_api_token = self.llm_api_token.clone();
|
||||||
|
@ -905,7 +907,8 @@ impl LanguageModel for CloudLanguageModel {
|
||||||
mode,
|
mode,
|
||||||
provider: zed_llm_client::LanguageModelProvider::OpenAi,
|
provider: zed_llm_client::LanguageModelProvider::OpenAi,
|
||||||
model: request.model.clone(),
|
model: request.model.clone(),
|
||||||
provider_request: serde_json::to_value(&request)?,
|
provider_request: serde_json::to_value(&request)
|
||||||
|
.map_err(|e| anyhow!(e))?,
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
.await?;
|
.await?;
|
||||||
|
@ -944,7 +947,8 @@ impl LanguageModel for CloudLanguageModel {
|
||||||
mode,
|
mode,
|
||||||
provider: zed_llm_client::LanguageModelProvider::Google,
|
provider: zed_llm_client::LanguageModelProvider::Google,
|
||||||
model: request.model.model_id.clone(),
|
model: request.model.model_id.clone(),
|
||||||
provider_request: serde_json::to_value(&request)?,
|
provider_request: serde_json::to_value(&request)
|
||||||
|
.map_err(|e| anyhow!(e))?,
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
.await?;
|
.await?;
|
||||||
|
|
|
@ -265,13 +265,15 @@ impl LanguageModel for CopilotChatLanguageModel {
|
||||||
'static,
|
'static,
|
||||||
Result<
|
Result<
|
||||||
BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>,
|
BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>,
|
||||||
|
LanguageModelCompletionError,
|
||||||
>,
|
>,
|
||||||
> {
|
> {
|
||||||
if let Some(message) = request.messages.last() {
|
if let Some(message) = request.messages.last() {
|
||||||
if message.contents_empty() {
|
if message.contents_empty() {
|
||||||
const EMPTY_PROMPT_MSG: &str =
|
const EMPTY_PROMPT_MSG: &str =
|
||||||
"Empty prompts aren't allowed. Please provide a non-empty prompt.";
|
"Empty prompts aren't allowed. Please provide a non-empty prompt.";
|
||||||
return futures::future::ready(Err(anyhow::anyhow!(EMPTY_PROMPT_MSG))).boxed();
|
return futures::future::ready(Err(anyhow::anyhow!(EMPTY_PROMPT_MSG).into()))
|
||||||
|
.boxed();
|
||||||
}
|
}
|
||||||
|
|
||||||
// Copilot Chat has a restriction that the final message must be from the user.
|
// Copilot Chat has a restriction that the final message must be from the user.
|
||||||
|
@ -279,13 +281,13 @@ impl LanguageModel for CopilotChatLanguageModel {
|
||||||
// and provide a more helpful error message.
|
// and provide a more helpful error message.
|
||||||
if !matches!(message.role, Role::User) {
|
if !matches!(message.role, Role::User) {
|
||||||
const USER_ROLE_MSG: &str = "The final message must be from the user. To provide a system prompt, you must provide the system prompt followed by a user prompt.";
|
const USER_ROLE_MSG: &str = "The final message must be from the user. To provide a system prompt, you must provide the system prompt followed by a user prompt.";
|
||||||
return futures::future::ready(Err(anyhow::anyhow!(USER_ROLE_MSG))).boxed();
|
return futures::future::ready(Err(anyhow::anyhow!(USER_ROLE_MSG).into())).boxed();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
let copilot_request = match into_copilot_chat(&self.model, request) {
|
let copilot_request = match into_copilot_chat(&self.model, request) {
|
||||||
Ok(request) => request,
|
Ok(request) => request,
|
||||||
Err(err) => return futures::future::ready(Err(err)).boxed(),
|
Err(err) => return futures::future::ready(Err(err.into())).boxed(),
|
||||||
};
|
};
|
||||||
let is_streaming = copilot_request.stream;
|
let is_streaming = copilot_request.stream;
|
||||||
|
|
||||||
|
|
|
@ -348,6 +348,7 @@ impl LanguageModel for DeepSeekLanguageModel {
|
||||||
'static,
|
'static,
|
||||||
Result<
|
Result<
|
||||||
BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>,
|
BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>,
|
||||||
|
LanguageModelCompletionError,
|
||||||
>,
|
>,
|
||||||
> {
|
> {
|
||||||
let request = into_deepseek(request, &self.model, self.max_output_tokens());
|
let request = into_deepseek(request, &self.model, self.max_output_tokens());
|
||||||
|
|
|
@ -409,6 +409,7 @@ impl LanguageModel for GoogleLanguageModel {
|
||||||
'static,
|
'static,
|
||||||
Result<LanguageModelCompletionEvent, LanguageModelCompletionError>,
|
Result<LanguageModelCompletionEvent, LanguageModelCompletionError>,
|
||||||
>,
|
>,
|
||||||
|
LanguageModelCompletionError,
|
||||||
>,
|
>,
|
||||||
> {
|
> {
|
||||||
let request = into_google(
|
let request = into_google(
|
||||||
|
|
|
@ -420,6 +420,7 @@ impl LanguageModel for LmStudioLanguageModel {
|
||||||
'static,
|
'static,
|
||||||
Result<
|
Result<
|
||||||
BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>,
|
BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>,
|
||||||
|
LanguageModelCompletionError,
|
||||||
>,
|
>,
|
||||||
> {
|
> {
|
||||||
let request = self.to_lmstudio_request(request);
|
let request = self.to_lmstudio_request(request);
|
||||||
|
|
|
@ -364,6 +364,7 @@ impl LanguageModel for MistralLanguageModel {
|
||||||
'static,
|
'static,
|
||||||
Result<
|
Result<
|
||||||
BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>,
|
BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>,
|
||||||
|
LanguageModelCompletionError,
|
||||||
>,
|
>,
|
||||||
> {
|
> {
|
||||||
let request = into_mistral(
|
let request = into_mistral(
|
||||||
|
|
|
@ -406,6 +406,7 @@ impl LanguageModel for OllamaLanguageModel {
|
||||||
'static,
|
'static,
|
||||||
Result<
|
Result<
|
||||||
BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>,
|
BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>,
|
||||||
|
LanguageModelCompletionError,
|
||||||
>,
|
>,
|
||||||
> {
|
> {
|
||||||
let request = self.to_ollama_request(request);
|
let request = self.to_ollama_request(request);
|
||||||
|
@ -415,7 +416,7 @@ impl LanguageModel for OllamaLanguageModel {
|
||||||
let settings = &AllLanguageModelSettings::get_global(cx).ollama;
|
let settings = &AllLanguageModelSettings::get_global(cx).ollama;
|
||||||
settings.api_url.clone()
|
settings.api_url.clone()
|
||||||
}) else {
|
}) else {
|
||||||
return futures::future::ready(Err(anyhow!("App state dropped"))).boxed();
|
return futures::future::ready(Err(anyhow!("App state dropped").into())).boxed();
|
||||||
};
|
};
|
||||||
|
|
||||||
let future = self.request_limiter.stream(async move {
|
let future = self.request_limiter.stream(async move {
|
||||||
|
|
|
@ -339,6 +339,7 @@ impl LanguageModel for OpenAiLanguageModel {
|
||||||
'static,
|
'static,
|
||||||
Result<LanguageModelCompletionEvent, LanguageModelCompletionError>,
|
Result<LanguageModelCompletionEvent, LanguageModelCompletionError>,
|
||||||
>,
|
>,
|
||||||
|
LanguageModelCompletionError,
|
||||||
>,
|
>,
|
||||||
> {
|
> {
|
||||||
let request = into_open_ai(request, &self.model, self.max_output_tokens());
|
let request = into_open_ai(request, &self.model, self.max_output_tokens());
|
||||||
|
|
|
@ -367,6 +367,7 @@ impl LanguageModel for OpenRouterLanguageModel {
|
||||||
'static,
|
'static,
|
||||||
Result<LanguageModelCompletionEvent, LanguageModelCompletionError>,
|
Result<LanguageModelCompletionEvent, LanguageModelCompletionError>,
|
||||||
>,
|
>,
|
||||||
|
LanguageModelCompletionError,
|
||||||
>,
|
>,
|
||||||
> {
|
> {
|
||||||
let request = into_open_router(request, &self.model, self.max_output_tokens());
|
let request = into_open_router(request, &self.model, self.max_output_tokens());
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue