Allow customization of the model used for tool calling (#15479)

We also eliminated the `completion` crate and moved its logic into
`LanguageModelRegistry`.

Release Notes:

- N/A

---------

Co-authored-by: Nathan <nathan@zed.dev>
This commit is contained in:
Antonio Scandurra 2024-07-30 16:18:53 +02:00 committed by GitHub
parent 1bfea9d443
commit 99bc90a372
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
32 changed files with 478 additions and 691 deletions

View file

@ -1,7 +1,7 @@
use crate::{
settings::AllLanguageModelSettings, LanguageModel, LanguageModelId, LanguageModelName,
LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderName,
LanguageModelProviderState, LanguageModelRequest, Role,
LanguageModelProviderState, LanguageModelRequest, RateLimiter, Role,
};
use anyhow::{anyhow, Context as _, Result};
use collections::BTreeMap;
@ -36,6 +36,7 @@ pub struct AnthropicSettings {
pub struct AvailableModel {
pub name: String,
pub max_tokens: usize,
pub tool_override: Option<String>,
}
pub struct AnthropicLanguageModelProvider {
@ -98,6 +99,7 @@ impl LanguageModelProvider for AnthropicLanguageModelProvider {
anthropic::Model::Custom {
name: model.name.clone(),
max_tokens: model.max_tokens,
tool_override: model.tool_override.clone(),
},
);
}
@ -110,6 +112,7 @@ impl LanguageModelProvider for AnthropicLanguageModelProvider {
model,
state: self.state.clone(),
http_client: self.http_client.clone(),
request_limiter: RateLimiter::new(4),
}) as Arc<dyn LanguageModel>
})
.collect()
@ -119,7 +122,7 @@ impl LanguageModelProvider for AnthropicLanguageModelProvider {
self.state.read(cx).api_key.is_some()
}
fn authenticate(&self, cx: &AppContext) -> Task<Result<()>> {
fn authenticate(&self, cx: &mut AppContext) -> Task<Result<()>> {
if self.is_authenticated(cx) {
Task::ready(Ok(()))
} else {
@ -152,7 +155,7 @@ impl LanguageModelProvider for AnthropicLanguageModelProvider {
.into()
}
fn reset_credentials(&self, cx: &AppContext) -> Task<Result<()>> {
fn reset_credentials(&self, cx: &mut AppContext) -> Task<Result<()>> {
let state = self.state.clone();
let delete_credentials =
cx.delete_credentials(&AllLanguageModelSettings::get_global(cx).anthropic.api_url);
@ -171,6 +174,7 @@ pub struct AnthropicModel {
model: anthropic::Model,
state: gpui::Model<State>,
http_client: Arc<dyn HttpClient>,
request_limiter: RateLimiter,
}
pub fn count_anthropic_tokens(
@ -296,14 +300,14 @@ impl LanguageModel for AnthropicModel {
) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
let request = request.into_anthropic(self.model.id().into());
let request = self.stream_completion(request, cx);
async move {
let future = self.request_limiter.stream(async move {
let response = request.await?;
Ok(anthropic::extract_text_from_events(response).boxed())
}
.boxed()
Ok(anthropic::extract_text_from_events(response))
});
async move { Ok(future.await?.boxed()) }.boxed()
}
fn use_tool(
fn use_any_tool(
&self,
request: LanguageModelRequest,
tool_name: String,
@ -311,7 +315,7 @@ impl LanguageModel for AnthropicModel {
input_schema: serde_json::Value,
cx: &AsyncAppContext,
) -> BoxFuture<'static, Result<serde_json::Value>> {
let mut request = request.into_anthropic(self.model.id().into());
let mut request = request.into_anthropic(self.model.tool_model_id().into());
request.tool_choice = Some(anthropic::ToolChoice::Tool {
name: tool_name.clone(),
});
@ -322,25 +326,26 @@ impl LanguageModel for AnthropicModel {
}];
let response = self.request_completion(request, cx);
async move {
let response = response.await?;
response
.content
.into_iter()
.find_map(|content| {
if let anthropic::Content::ToolUse { name, input, .. } = content {
if name == tool_name {
Some(input)
self.request_limiter
.run(async move {
let response = response.await?;
response
.content
.into_iter()
.find_map(|content| {
if let anthropic::Content::ToolUse { name, input, .. } = content {
if name == tool_name {
Some(input)
} else {
None
}
} else {
None
}
} else {
None
}
})
.context("tool not used")
}
.boxed()
})
.context("tool not used")
})
.boxed()
}
}

View file

@ -2,7 +2,7 @@ use super::open_ai::count_open_ai_tokens;
use crate::{
settings::AllLanguageModelSettings, CloudModel, LanguageModel, LanguageModelId,
LanguageModelName, LanguageModelProviderId, LanguageModelProviderName,
LanguageModelProviderState, LanguageModelRequest,
LanguageModelProviderState, LanguageModelRequest, RateLimiter,
};
use anyhow::{anyhow, Context as _, Result};
use client::Client;
@ -41,6 +41,7 @@ pub struct AvailableModel {
provider: AvailableProvider,
name: String,
max_tokens: usize,
tool_override: Option<String>,
}
pub struct CloudLanguageModelProvider {
@ -56,7 +57,7 @@ struct State {
}
impl State {
fn authenticate(&self, cx: &AppContext) -> Task<Result<()>> {
fn authenticate(&self, cx: &mut AppContext) -> Task<Result<()>> {
let client = self.client.clone();
cx.spawn(move |cx| async move { client.authenticate_and_connect(true, &cx).await })
}
@ -142,6 +143,7 @@ impl LanguageModelProvider for CloudLanguageModelProvider {
AvailableProvider::Anthropic => CloudModel::Anthropic(anthropic::Model::Custom {
name: model.name.clone(),
max_tokens: model.max_tokens,
tool_override: model.tool_override.clone(),
}),
AvailableProvider::OpenAi => CloudModel::OpenAi(open_ai::Model::Custom {
name: model.name.clone(),
@ -162,6 +164,7 @@ impl LanguageModelProvider for CloudLanguageModelProvider {
id: LanguageModelId::from(model.id().to_string()),
model,
client: self.client.clone(),
request_limiter: RateLimiter::new(4),
}) as Arc<dyn LanguageModel>
})
.collect()
@ -171,8 +174,8 @@ impl LanguageModelProvider for CloudLanguageModelProvider {
self.state.read(cx).status.is_connected()
}
fn authenticate(&self, cx: &AppContext) -> Task<Result<()>> {
self.state.read(cx).authenticate(cx)
fn authenticate(&self, cx: &mut AppContext) -> Task<Result<()>> {
self.state.update(cx, |state, cx| state.authenticate(cx))
}
fn authentication_prompt(&self, cx: &mut WindowContext) -> AnyView {
@ -182,7 +185,7 @@ impl LanguageModelProvider for CloudLanguageModelProvider {
.into()
}
fn reset_credentials(&self, _cx: &AppContext) -> Task<Result<()>> {
fn reset_credentials(&self, _cx: &mut AppContext) -> Task<Result<()>> {
Task::ready(Ok(()))
}
}
@ -191,6 +194,7 @@ pub struct CloudLanguageModel {
id: LanguageModelId,
model: CloudModel,
client: Arc<Client>,
request_limiter: RateLimiter,
}
impl LanguageModel for CloudLanguageModel {
@ -256,7 +260,7 @@ impl LanguageModel for CloudLanguageModel {
CloudModel::Anthropic(model) => {
let client = self.client.clone();
let request = request.into_anthropic(model.id().into());
async move {
let future = self.request_limiter.stream(async move {
let request = serde_json::to_string(&request)?;
let stream = client
.request_stream(proto::StreamCompleteWithLanguageModel {
@ -266,15 +270,14 @@ impl LanguageModel for CloudLanguageModel {
.await?;
Ok(anthropic::extract_text_from_events(
stream.map(|item| Ok(serde_json::from_str(&item?.event)?)),
)
.boxed())
}
.boxed()
))
});
async move { Ok(future.await?.boxed()) }.boxed()
}
CloudModel::OpenAi(model) => {
let client = self.client.clone();
let request = request.into_open_ai(model.id().into());
async move {
let future = self.request_limiter.stream(async move {
let request = serde_json::to_string(&request)?;
let stream = client
.request_stream(proto::StreamCompleteWithLanguageModel {
@ -284,15 +287,14 @@ impl LanguageModel for CloudLanguageModel {
.await?;
Ok(open_ai::extract_text_from_events(
stream.map(|item| Ok(serde_json::from_str(&item?.event)?)),
)
.boxed())
}
.boxed()
))
});
async move { Ok(future.await?.boxed()) }.boxed()
}
CloudModel::Google(model) => {
let client = self.client.clone();
let request = request.into_google(model.id().into());
async move {
let future = self.request_limiter.stream(async move {
let request = serde_json::to_string(&request)?;
let stream = client
.request_stream(proto::StreamCompleteWithLanguageModel {
@ -302,15 +304,14 @@ impl LanguageModel for CloudLanguageModel {
.await?;
Ok(google_ai::extract_text_from_events(
stream.map(|item| Ok(serde_json::from_str(&item?.event)?)),
)
.boxed())
}
.boxed()
))
});
async move { Ok(future.await?.boxed()) }.boxed()
}
}
}
fn use_tool(
fn use_any_tool(
&self,
request: LanguageModelRequest,
tool_name: String,
@ -321,7 +322,7 @@ impl LanguageModel for CloudLanguageModel {
match &self.model {
CloudModel::Anthropic(model) => {
let client = self.client.clone();
let mut request = request.into_anthropic(model.id().into());
let mut request = request.into_anthropic(model.tool_model_id().into());
request.tool_choice = Some(anthropic::ToolChoice::Tool {
name: tool_name.clone(),
});
@ -331,32 +332,34 @@ impl LanguageModel for CloudLanguageModel {
input_schema,
}];
async move {
let request = serde_json::to_string(&request)?;
let response = client
.request(proto::CompleteWithLanguageModel {
provider: proto::LanguageModelProvider::Anthropic as i32,
request,
})
.await?;
let response: anthropic::Response = serde_json::from_str(&response.completion)?;
response
.content
.into_iter()
.find_map(|content| {
if let anthropic::Content::ToolUse { name, input, .. } = content {
if name == tool_name {
Some(input)
self.request_limiter
.run(async move {
let request = serde_json::to_string(&request)?;
let response = client
.request(proto::CompleteWithLanguageModel {
provider: proto::LanguageModelProvider::Anthropic as i32,
request,
})
.await?;
let response: anthropic::Response =
serde_json::from_str(&response.completion)?;
response
.content
.into_iter()
.find_map(|content| {
if let anthropic::Content::ToolUse { name, input, .. } = content {
if name == tool_name {
Some(input)
} else {
None
}
} else {
None
}
} else {
None
}
})
.context("tool not used")
}
.boxed()
})
.context("tool not used")
})
.boxed()
}
CloudModel::OpenAi(_) => {
future::ready(Err(anyhow!("tool use not implemented for OpenAI"))).boxed()

View file

@ -27,7 +27,7 @@ use crate::settings::AllLanguageModelSettings;
use crate::LanguageModelProviderState;
use crate::{
LanguageModel, LanguageModelId, LanguageModelName, LanguageModelProvider,
LanguageModelProviderId, LanguageModelProviderName, LanguageModelRequest, Role,
LanguageModelProviderId, LanguageModelProviderName, LanguageModelRequest, RateLimiter, Role,
};
use super::open_ai::count_open_ai_tokens;
@ -85,7 +85,12 @@ impl LanguageModelProvider for CopilotChatLanguageModelProvider {
fn provided_models(&self, _cx: &AppContext) -> Vec<Arc<dyn LanguageModel>> {
CopilotChatModel::iter()
.map(|model| Arc::new(CopilotChatLanguageModel { model }) as Arc<dyn LanguageModel>)
.map(|model| {
Arc::new(CopilotChatLanguageModel {
model,
request_limiter: RateLimiter::new(4),
}) as Arc<dyn LanguageModel>
})
.collect()
}
@ -95,7 +100,7 @@ impl LanguageModelProvider for CopilotChatLanguageModelProvider {
.unwrap_or(false)
}
fn authenticate(&self, cx: &AppContext) -> Task<Result<()>> {
fn authenticate(&self, cx: &mut AppContext) -> Task<Result<()>> {
let result = if self.is_authenticated(cx) {
Ok(())
} else if let Some(copilot) = Copilot::global(cx) {
@ -121,7 +126,7 @@ impl LanguageModelProvider for CopilotChatLanguageModelProvider {
cx.new_view(|cx| AuthenticationPrompt::new(cx)).into()
}
fn reset_credentials(&self, cx: &AppContext) -> Task<Result<()>> {
fn reset_credentials(&self, cx: &mut AppContext) -> Task<Result<()>> {
let Some(copilot) = Copilot::global(cx) else {
return Task::ready(Err(anyhow::anyhow!(
"Copilot is not available. Please ensure Copilot is enabled and running and try again."
@ -145,6 +150,7 @@ impl LanguageModelProvider for CopilotChatLanguageModelProvider {
pub struct CopilotChatLanguageModel {
model: CopilotChatModel,
request_limiter: RateLimiter,
}
impl LanguageModel for CopilotChatLanguageModel {
@ -215,30 +221,35 @@ impl LanguageModel for CopilotChatLanguageModel {
return futures::future::ready(Err(anyhow::anyhow!("App state dropped"))).boxed();
};
cx.spawn(|mut cx| async move {
let response = CopilotChat::stream_completion(request, low_speed_timeout, &mut cx).await?;
let stream = response
.filter_map(|response| async move {
match response {
Ok(result) => {
let choice = result.choices.first();
match choice {
Some(choice) => Some(Ok(choice.delta.content.clone().unwrap_or_default())),
None => Some(Err(anyhow::anyhow!(
"The Copilot Chat API returned a response with no choices, but hadn't finished the message yet. Please try again."
))),
let request_limiter = self.request_limiter.clone();
let future = cx.spawn(|cx| async move {
let response = CopilotChat::stream_completion(request, low_speed_timeout, cx);
request_limiter.stream(async move {
let response = response.await?;
let stream = response
.filter_map(|response| async move {
match response {
Ok(result) => {
let choice = result.choices.first();
match choice {
Some(choice) => Some(Ok(choice.delta.content.clone().unwrap_or_default())),
None => Some(Err(anyhow::anyhow!(
"The Copilot Chat API returned a response with no choices, but hadn't finished the message yet. Please try again."
))),
}
}
Err(err) => Some(Err(err)),
}
Err(err) => Some(Err(err)),
}
})
.boxed();
Ok(stream)
})
.boxed()
})
.boxed();
Ok(stream)
}).await
});
async move { Ok(future.await?.boxed()) }.boxed()
}
fn use_tool(
fn use_any_tool(
&self,
_request: LanguageModelRequest,
_name: String,

View file

@ -60,7 +60,7 @@ impl LanguageModelProvider for FakeLanguageModelProvider {
true
}
fn authenticate(&self, _: &AppContext) -> Task<Result<()>> {
fn authenticate(&self, _: &mut AppContext) -> Task<Result<()>> {
Task::ready(Ok(()))
}
@ -68,7 +68,7 @@ impl LanguageModelProvider for FakeLanguageModelProvider {
unimplemented!()
}
fn reset_credentials(&self, _: &AppContext) -> Task<Result<()>> {
fn reset_credentials(&self, _: &mut AppContext) -> Task<Result<()>> {
Task::ready(Ok(()))
}
}
@ -173,7 +173,7 @@ impl LanguageModel for FakeLanguageModel {
async move { Ok(rx.map(Ok).boxed()) }.boxed()
}
fn use_tool(
fn use_any_tool(
&self,
_request: LanguageModelRequest,
_name: String,

View file

@ -20,7 +20,7 @@ use util::ResultExt;
use crate::{
settings::AllLanguageModelSettings, LanguageModel, LanguageModelId, LanguageModelName,
LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderName,
LanguageModelProviderState, LanguageModelRequest,
LanguageModelProviderState, LanguageModelRequest, RateLimiter,
};
const PROVIDER_ID: &str = "google";
@ -111,6 +111,7 @@ impl LanguageModelProvider for GoogleLanguageModelProvider {
model,
state: self.state.clone(),
http_client: self.http_client.clone(),
rate_limiter: RateLimiter::new(4),
}) as Arc<dyn LanguageModel>
})
.collect()
@ -120,7 +121,7 @@ impl LanguageModelProvider for GoogleLanguageModelProvider {
self.state.read(cx).api_key.is_some()
}
fn authenticate(&self, cx: &AppContext) -> Task<Result<()>> {
fn authenticate(&self, cx: &mut AppContext) -> Task<Result<()>> {
if self.is_authenticated(cx) {
Task::ready(Ok(()))
} else {
@ -153,7 +154,7 @@ impl LanguageModelProvider for GoogleLanguageModelProvider {
.into()
}
fn reset_credentials(&self, cx: &AppContext) -> Task<Result<()>> {
fn reset_credentials(&self, cx: &mut AppContext) -> Task<Result<()>> {
let state = self.state.clone();
let delete_credentials =
cx.delete_credentials(&AllLanguageModelSettings::get_global(cx).google.api_url);
@ -172,6 +173,7 @@ pub struct GoogleLanguageModel {
model: google_ai::Model,
state: gpui::Model<State>,
http_client: Arc<dyn HttpClient>,
rate_limiter: RateLimiter,
}
impl LanguageModel for GoogleLanguageModel {
@ -243,17 +245,17 @@ impl LanguageModel for GoogleLanguageModel {
return futures::future::ready(Err(anyhow!("App state dropped"))).boxed();
};
async move {
let future = self.rate_limiter.stream(async move {
let api_key = api_key.ok_or_else(|| anyhow!("missing api key"))?;
let response =
stream_generate_content(http_client.as_ref(), &api_url, &api_key, request);
let events = response.await?;
Ok(google_ai::extract_text_from_events(events).boxed())
}
.boxed()
});
async move { Ok(future.await?.boxed()) }.boxed()
}
fn use_tool(
fn use_any_tool(
&self,
_request: LanguageModelRequest,
_name: String,

View file

@ -12,7 +12,7 @@ use ui::{prelude::*, ButtonLike, ElevationIndex};
use crate::{
settings::AllLanguageModelSettings, LanguageModel, LanguageModelId, LanguageModelName,
LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderName,
LanguageModelProviderState, LanguageModelRequest, Role,
LanguageModelProviderState, LanguageModelRequest, RateLimiter, Role,
};
const OLLAMA_DOWNLOAD_URL: &str = "https://ollama.com/download";
@ -39,7 +39,7 @@ struct State {
}
impl State {
fn fetch_models(&self, cx: &ModelContext<Self>) -> Task<Result<()>> {
fn fetch_models(&mut self, cx: &mut ModelContext<Self>) -> Task<Result<()>> {
let settings = &AllLanguageModelSettings::get_global(cx).ollama;
let http_client = self.http_client.clone();
let api_url = settings.api_url.clone();
@ -80,37 +80,10 @@ impl OllamaLanguageModelProvider {
}),
}),
};
this.fetch_models(cx).detach();
this.state
.update(cx, |state, cx| state.fetch_models(cx).detach());
this
}
fn fetch_models(&self, cx: &AppContext) -> Task<Result<()>> {
let settings = &AllLanguageModelSettings::get_global(cx).ollama;
let http_client = self.http_client.clone();
let api_url = settings.api_url.clone();
let state = self.state.clone();
// As a proxy for the server being "authenticated", we'll check if its up by fetching the models
cx.spawn(|mut cx| async move {
let models = get_models(http_client.as_ref(), &api_url, None).await?;
let mut models: Vec<ollama::Model> = models
.into_iter()
// Since there is no metadata from the Ollama API
// indicating which models are embedding models,
// simply filter out models with "-embed" in their name
.filter(|model| !model.name.contains("-embed"))
.map(|model| ollama::Model::new(&model.name))
.collect();
models.sort_by(|a, b| a.name.cmp(&b.name));
state.update(&mut cx, |this, cx| {
this.available_models = models;
cx.notify();
})
})
}
}
impl LanguageModelProviderState for OllamaLanguageModelProvider {
@ -140,6 +113,7 @@ impl LanguageModelProvider for OllamaLanguageModelProvider {
id: LanguageModelId::from(model.name.clone()),
model: model.clone(),
http_client: self.http_client.clone(),
request_limiter: RateLimiter::new(4),
}) as Arc<dyn LanguageModel>
})
.collect()
@ -158,11 +132,11 @@ impl LanguageModelProvider for OllamaLanguageModelProvider {
!self.state.read(cx).available_models.is_empty()
}
fn authenticate(&self, cx: &AppContext) -> Task<Result<()>> {
fn authenticate(&self, cx: &mut AppContext) -> Task<Result<()>> {
if self.is_authenticated(cx) {
Task::ready(Ok(()))
} else {
self.fetch_models(cx)
self.state.update(cx, |state, cx| state.fetch_models(cx))
}
}
@ -176,8 +150,8 @@ impl LanguageModelProvider for OllamaLanguageModelProvider {
.into()
}
fn reset_credentials(&self, cx: &AppContext) -> Task<Result<()>> {
self.fetch_models(cx)
fn reset_credentials(&self, cx: &mut AppContext) -> Task<Result<()>> {
self.state.update(cx, |state, cx| state.fetch_models(cx))
}
}
@ -185,6 +159,7 @@ pub struct OllamaLanguageModel {
id: LanguageModelId,
model: ollama::Model,
http_client: Arc<dyn HttpClient>,
request_limiter: RateLimiter,
}
impl OllamaLanguageModel {
@ -235,14 +210,14 @@ impl LanguageModel for OllamaLanguageModel {
LanguageModelProviderName(PROVIDER_NAME.into())
}
fn max_token_count(&self) -> usize {
self.model.max_token_count()
}
fn telemetry_id(&self) -> String {
format!("ollama/{}", self.model.id())
}
fn max_token_count(&self) -> usize {
self.model.max_token_count()
}
fn count_tokens(
&self,
request: LanguageModelRequest,
@ -275,10 +250,10 @@ impl LanguageModel for OllamaLanguageModel {
return futures::future::ready(Err(anyhow!("App state dropped"))).boxed();
};
async move {
let request =
stream_chat_completion(http_client.as_ref(), &api_url, request, low_speed_timeout);
let response = request.await?;
let future = self.request_limiter.stream(async move {
let response =
stream_chat_completion(http_client.as_ref(), &api_url, request, low_speed_timeout)
.await?;
let stream = response
.filter_map(|response| async move {
match response {
@ -295,11 +270,12 @@ impl LanguageModel for OllamaLanguageModel {
})
.boxed();
Ok(stream)
}
.boxed()
});
async move { Ok(future.await?.boxed()) }.boxed()
}
fn use_tool(
fn use_any_tool(
&self,
_request: LanguageModelRequest,
_name: String,

View file

@ -20,7 +20,7 @@ use util::ResultExt;
use crate::{
settings::AllLanguageModelSettings, LanguageModel, LanguageModelId, LanguageModelName,
LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderName,
LanguageModelProviderState, LanguageModelRequest, Role,
LanguageModelProviderState, LanguageModelRequest, RateLimiter, Role,
};
const PROVIDER_ID: &str = "openai";
@ -112,6 +112,7 @@ impl LanguageModelProvider for OpenAiLanguageModelProvider {
model,
state: self.state.clone(),
http_client: self.http_client.clone(),
request_limiter: RateLimiter::new(4),
}) as Arc<dyn LanguageModel>
})
.collect()
@ -121,7 +122,7 @@ impl LanguageModelProvider for OpenAiLanguageModelProvider {
self.state.read(cx).api_key.is_some()
}
fn authenticate(&self, cx: &AppContext) -> Task<Result<()>> {
fn authenticate(&self, cx: &mut AppContext) -> Task<Result<()>> {
if self.is_authenticated(cx) {
Task::ready(Ok(()))
} else {
@ -153,7 +154,7 @@ impl LanguageModelProvider for OpenAiLanguageModelProvider {
.into()
}
fn reset_credentials(&self, cx: &AppContext) -> Task<Result<()>> {
fn reset_credentials(&self, cx: &mut AppContext) -> Task<Result<()>> {
let settings = &AllLanguageModelSettings::get_global(cx).openai;
let delete_credentials = cx.delete_credentials(&settings.api_url);
let state = self.state.clone();
@ -172,6 +173,7 @@ pub struct OpenAiLanguageModel {
model: open_ai::Model,
state: gpui::Model<State>,
http_client: Arc<dyn HttpClient>,
request_limiter: RateLimiter,
}
impl LanguageModel for OpenAiLanguageModel {
@ -226,7 +228,7 @@ impl LanguageModel for OpenAiLanguageModel {
return futures::future::ready(Err(anyhow!("App state dropped"))).boxed();
};
async move {
let future = self.request_limiter.stream(async move {
let api_key = api_key.ok_or_else(|| anyhow!("missing api key"))?;
let request = stream_completion(
http_client.as_ref(),
@ -237,11 +239,12 @@ impl LanguageModel for OpenAiLanguageModel {
);
let response = request.await?;
Ok(open_ai::extract_text_from_events(response).boxed())
}
.boxed()
});
async move { Ok(future.await?.boxed()) }.boxed()
}
fn use_tool(
fn use_any_tool(
&self,
_request: LanguageModelRequest,
_name: String,