ZIm/crates/assistant2/src/completion_provider.rs
Marshall Bowers d633a0da78
gpui: Fix Global trait (#11187)
This PR restores the `Global` trait's status as a marker trait.

This was the original intent when the trait was added in #7095, but it had
been lost in #9777.

The purpose of the `Global` trait is to statically convey which types can
and can't be accessed as global state, as well as to provide a way of
restricting access to said globals.

For example, in the case of the `ThemeRegistry` we have a private
`GlobalThemeRegistry` that is marked as `Global`:
91b3c24ed3/crates/theme/src/registry.rs (L25-L34)

We're then able to permit reading the `ThemeRegistry` from the
`GlobalThemeRegistry` via a custom getter, while still restricting which
callers are able to mutate the global:
91b3c24ed3/crates/theme/src/registry.rs (L46-L61)
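
Condensed, the pattern looks roughly like this (a sketch only: the struct contents and exact signatures are illustrative, not copied from the linked registry file):

use std::sync::Arc;

use gpui::{AppContext, Global};

pub struct ThemeRegistry { /* ... */ }

// The wrapper is private, so only this module can construct or replace it.
struct GlobalThemeRegistry(Arc<ThemeRegistry>);

impl Global for GlobalThemeRegistry {}

impl ThemeRegistry {
    /// Anyone holding an `AppContext` may read the registry...
    pub fn global(cx: &AppContext) -> Arc<ThemeRegistry> {
        cx.global::<GlobalThemeRegistry>().0.clone()
    }

    /// ...but only callers this module chooses to expose can swap it out.
    pub(crate) fn set_global(registry: Arc<ThemeRegistry>, cx: &mut AppContext) {
        cx.set_global(GlobalThemeRegistry(registry));
    }
}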

Release Notes:

- N/A

use anyhow::Result;
use assistant_tooling::ToolFunctionDefinition;
use client::{proto, Client};
use futures::{future::BoxFuture, stream::BoxStream, FutureExt, StreamExt};
use gpui::{AppContext, Global};
use std::sync::Arc;
pub use open_ai::RequestMessage as CompletionMessage;
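
/// Shared handle to the active completion backend. It is `Clone` (the backend
/// lives behind an `Arc`) and is registered as a GPUI global below, so any
/// caller with an `AppContext` can reach it.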
#[derive(Clone)]
pub struct CompletionProvider(Arc<dyn CompletionProviderBackend>);

impl CompletionProvider {
    pub fn get(cx: &AppContext) -> &Self {
        cx.global::<CompletionProvider>()
    }

    pub fn new(backend: impl CompletionProviderBackend) -> Self {
        Self(Arc::new(backend))
    }

    pub fn default_model(&self) -> String {
        self.0.default_model()
    }

    pub fn available_models(&self) -> Vec<String> {
        self.0.available_models()
    }

    pub fn complete(
        &self,
        model: String,
        messages: Vec<CompletionMessage>,
        stop: Vec<String>,
        temperature: f32,
        tools: &[ToolFunctionDefinition],
    ) -> BoxFuture<'static, Result<BoxStream<'static, Result<proto::LanguageModelResponseMessage>>>>
    {
        self.0.complete(model, messages, stop, temperature, tools)
    }
}
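
// Marker-trait impl (`Global` carries no methods): this is what permits the
// `cx.global::<CompletionProvider>()` call above.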
impl Global for CompletionProvider {}
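
/// Interface a completion backend must implement; `CompletionProvider`
/// forwards all of its calls to one of these behind an `Arc`.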
pub trait CompletionProviderBackend: 'static {
    fn default_model(&self) -> String;
    fn available_models(&self) -> Vec<String>;
    fn complete(
        &self,
        model: String,
        messages: Vec<CompletionMessage>,
        stop: Vec<String>,
        temperature: f32,
        tools: &[ToolFunctionDefinition],
    ) -> BoxFuture<'static, Result<BoxStream<'static, Result<proto::LanguageModelResponseMessage>>>>;
}
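
/// Backend that issues `CompleteWithLanguageModel` RPCs through the shared
/// `Client` and streams the responses back.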
pub struct CloudCompletionProvider {
    client: Arc<Client>,
}

impl CloudCompletionProvider {
    pub fn new(client: Arc<Client>) -> Self {
        Self { client }
    }
}

impl CompletionProviderBackend for CloudCompletionProvider {
    fn default_model(&self) -> String {
        "gpt-4-turbo".into()
    }

    fn available_models(&self) -> Vec<String> {
        vec!["gpt-4-turbo".into(), "gpt-4".into(), "gpt-3.5-turbo".into()]
    }

    fn complete(
        &self,
        model: String,
        messages: Vec<CompletionMessage>,
        stop: Vec<String>,
        temperature: f32,
        tools: &[ToolFunctionDefinition],
    ) -> BoxFuture<'static, Result<BoxStream<'static, Result<proto::LanguageModelResponseMessage>>>>
    {
        let client = self.client.clone();

        // Convert the tool definitions to their protobuf form, dropping any
        // tool whose parameters fail to serialize to JSON.
        let tools: Vec<proto::ChatCompletionTool> = tools
            .iter()
            .filter_map(|tool| {
                Some(proto::ChatCompletionTool {
                    variant: Some(proto::chat_completion_tool::Variant::Function(
                        proto::chat_completion_tool::FunctionObject {
                            name: tool.name.clone(),
                            description: Some(tool.description.clone()),
                            parameters: Some(serde_json::to_string(&tool.parameters).ok()?),
                        },
                    )),
                })
            })
            .collect();

        // Let the model decide when to invoke tools, but only if any exist.
        let tool_choice = match tools.is_empty() {
            true => None,
            false => Some("auto".into()),
        };

        async move {
            let stream = client
                .request_stream(proto::CompleteWithLanguageModel {
                    model,
                    // Translate each OpenAI-style message into the equivalent
                    // RPC message, preserving roles and any tool calls.
                    messages: messages
                        .into_iter()
                        .map(|message| match message {
                            CompletionMessage::Assistant {
                                content,
                                tool_calls,
                            } => proto::LanguageModelRequestMessage {
                                role: proto::LanguageModelRole::LanguageModelAssistant as i32,
                                content: content.unwrap_or_default(),
                                tool_call_id: None,
                                tool_calls: tool_calls
                                    .into_iter()
                                    .map(|tool_call| match tool_call.content {
                                        open_ai::ToolCallContent::Function { function } => {
                                            proto::ToolCall {
                                                id: tool_call.id,
                                                variant: Some(proto::tool_call::Variant::Function(
                                                    proto::tool_call::FunctionCall {
                                                        name: function.name,
                                                        arguments: function.arguments,
                                                    },
                                                )),
                                            }
                                        }
                                    })
                                    .collect(),
                            },
                            CompletionMessage::User { content } => {
                                proto::LanguageModelRequestMessage {
                                    role: proto::LanguageModelRole::LanguageModelUser as i32,
                                    content,
                                    tool_call_id: None,
                                    tool_calls: Vec::new(),
                                }
                            }
                            CompletionMessage::System { content } => {
                                proto::LanguageModelRequestMessage {
                                    role: proto::LanguageModelRole::LanguageModelSystem as i32,
                                    content,
                                    tool_calls: Vec::new(),
                                    tool_call_id: None,
                                }
                            }
                            CompletionMessage::Tool {
                                content,
                                tool_call_id,
                            } => proto::LanguageModelRequestMessage {
                                role: proto::LanguageModelRole::LanguageModelTool as i32,
                                content,
                                tool_call_id: Some(tool_call_id),
                                tool_calls: Vec::new(),
                            },
                        })
                        .collect(),
                    stop,
                    temperature,
                    tool_choice,
                    tools,
                })
                .await?;

            // Reduce each streamed response to the delta of its last
            // (typically only) choice; responses without one are skipped.
            Ok(stream
                .filter_map(|response| async move {
                    match response {
                        Ok(mut response) => Some(Ok(response.choices.pop()?.delta?)),
                        Err(error) => Some(Err(error)),
                    }
                })
                .boxed())
        }
        .boxed()
    }
}
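
Nothing in this file actually installs the provider as a global. For illustration, a minimal wiring sketch; the `init` and `read_default_model` names, and the assumption that an `Arc<Client>` is already in hand, are hypothetical:

use std::sync::Arc;

use client::Client;
use gpui::AppContext;

// Hypothetical registration: make a cloud-backed provider the global one.
fn init(client: Arc<Client>, cx: &mut AppContext) {
    cx.set_global(CompletionProvider::new(CloudCompletionProvider::new(client)));
}

// Any caller with an `AppContext` can then read it back:
fn read_default_model(cx: &AppContext) -> String {
    CompletionProvider::get(cx).default_model()
}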