Allow customization of the model used for tool calling (#15479)

We also eliminate the `completion` crate and moved its logic into
`LanguageModelRegistry`.

Release Notes:

- N/A

---------

Co-authored-by: Nathan <nathan@zed.dev>
This commit is contained in:
Antonio Scandurra 2024-07-30 16:18:53 +02:00 committed by GitHub
parent 1bfea9d443
commit 99bc90a372
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
32 changed files with 478 additions and 691 deletions

View file

@ -1,24 +1,24 @@
mod model;
pub mod provider;
mod rate_limiter;
mod registry;
mod request;
mod role;
pub mod settings;
use std::sync::Arc;
use anyhow::Result;
use client::Client;
use futures::{future::BoxFuture, stream::BoxStream};
use gpui::{AnyView, AppContext, AsyncAppContext, SharedString, Task, WindowContext};
pub use model::*;
use project::Fs;
pub(crate) use rate_limiter::*;
pub use registry::*;
pub use request::*;
pub use role::*;
use schemars::JsonSchema;
use serde::de::DeserializeOwned;
use std::{future::Future, sync::Arc};
pub fn init(client: Arc<Client>, fs: Arc<dyn Fs>, cx: &mut AppContext) {
settings::init(fs, cx);
@ -46,7 +46,7 @@ pub trait LanguageModel: Send + Sync {
cx: &AsyncAppContext,
) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>>;
fn use_tool(
fn use_any_tool(
&self,
request: LanguageModelRequest,
name: String,
@ -56,6 +56,22 @@ pub trait LanguageModel: Send + Sync {
) -> BoxFuture<'static, Result<serde_json::Value>>;
}
impl dyn LanguageModel {
pub fn use_tool<T: LanguageModelTool>(
&self,
request: LanguageModelRequest,
cx: &AsyncAppContext,
) -> impl 'static + Future<Output = Result<T>> {
let schema = schemars::schema_for!(T);
let schema_json = serde_json::to_value(&schema).unwrap();
let request = self.use_any_tool(request, T::name(), T::description(), schema_json, cx);
async move {
let response = request.await?;
Ok(serde_json::from_value(response)?)
}
}
}
pub trait LanguageModelTool: 'static + DeserializeOwned + JsonSchema {
fn name() -> String;
fn description() -> String;
@ -67,9 +83,9 @@ pub trait LanguageModelProvider: 'static {
fn provided_models(&self, cx: &AppContext) -> Vec<Arc<dyn LanguageModel>>;
fn load_model(&self, _model: Arc<dyn LanguageModel>, _cx: &AppContext) {}
fn is_authenticated(&self, cx: &AppContext) -> bool;
fn authenticate(&self, cx: &AppContext) -> Task<Result<()>>;
fn authenticate(&self, cx: &mut AppContext) -> Task<Result<()>>;
fn authentication_prompt(&self, cx: &mut WindowContext) -> AnyView;
fn reset_credentials(&self, cx: &AppContext) -> Task<Result<()>>;
fn reset_credentials(&self, cx: &mut AppContext) -> Task<Result<()>>;
}
pub trait LanguageModelProviderState: 'static {