Allow customization of the model used for tool calling (#15479)
We also eliminate the `completion` crate and moved its logic into `LanguageModelRegistry`. Release Notes: - N/A --------- Co-authored-by: Nathan <nathan@zed.dev>
This commit is contained in:
parent
1bfea9d443
commit
99bc90a372
32 changed files with 478 additions and 691 deletions
|
@ -1,24 +1,24 @@
|
|||
mod model;
|
||||
pub mod provider;
|
||||
mod rate_limiter;
|
||||
mod registry;
|
||||
mod request;
|
||||
mod role;
|
||||
pub mod settings;
|
||||
|
||||
use std::sync::Arc;
|
||||
|
||||
use anyhow::Result;
|
||||
use client::Client;
|
||||
use futures::{future::BoxFuture, stream::BoxStream};
|
||||
use gpui::{AnyView, AppContext, AsyncAppContext, SharedString, Task, WindowContext};
|
||||
|
||||
pub use model::*;
|
||||
use project::Fs;
|
||||
pub(crate) use rate_limiter::*;
|
||||
pub use registry::*;
|
||||
pub use request::*;
|
||||
pub use role::*;
|
||||
use schemars::JsonSchema;
|
||||
use serde::de::DeserializeOwned;
|
||||
use std::{future::Future, sync::Arc};
|
||||
|
||||
pub fn init(client: Arc<Client>, fs: Arc<dyn Fs>, cx: &mut AppContext) {
|
||||
settings::init(fs, cx);
|
||||
|
@ -46,7 +46,7 @@ pub trait LanguageModel: Send + Sync {
|
|||
cx: &AsyncAppContext,
|
||||
) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>>;
|
||||
|
||||
fn use_tool(
|
||||
fn use_any_tool(
|
||||
&self,
|
||||
request: LanguageModelRequest,
|
||||
name: String,
|
||||
|
@ -56,6 +56,22 @@ pub trait LanguageModel: Send + Sync {
|
|||
) -> BoxFuture<'static, Result<serde_json::Value>>;
|
||||
}
|
||||
|
||||
impl dyn LanguageModel {
|
||||
pub fn use_tool<T: LanguageModelTool>(
|
||||
&self,
|
||||
request: LanguageModelRequest,
|
||||
cx: &AsyncAppContext,
|
||||
) -> impl 'static + Future<Output = Result<T>> {
|
||||
let schema = schemars::schema_for!(T);
|
||||
let schema_json = serde_json::to_value(&schema).unwrap();
|
||||
let request = self.use_any_tool(request, T::name(), T::description(), schema_json, cx);
|
||||
async move {
|
||||
let response = request.await?;
|
||||
Ok(serde_json::from_value(response)?)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub trait LanguageModelTool: 'static + DeserializeOwned + JsonSchema {
|
||||
fn name() -> String;
|
||||
fn description() -> String;
|
||||
|
@ -67,9 +83,9 @@ pub trait LanguageModelProvider: 'static {
|
|||
fn provided_models(&self, cx: &AppContext) -> Vec<Arc<dyn LanguageModel>>;
|
||||
fn load_model(&self, _model: Arc<dyn LanguageModel>, _cx: &AppContext) {}
|
||||
fn is_authenticated(&self, cx: &AppContext) -> bool;
|
||||
fn authenticate(&self, cx: &AppContext) -> Task<Result<()>>;
|
||||
fn authenticate(&self, cx: &mut AppContext) -> Task<Result<()>>;
|
||||
fn authentication_prompt(&self, cx: &mut WindowContext) -> AnyView;
|
||||
fn reset_credentials(&self, cx: &AppContext) -> Task<Result<()>>;
|
||||
fn reset_credentials(&self, cx: &mut AppContext) -> Task<Result<()>>;
|
||||
}
|
||||
|
||||
pub trait LanguageModelProviderState: 'static {
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue