Restructure workflow step resolution and fix inserting newlines (#15720)
Release Notes: - N/A --------- Co-authored-by: Nathan <nathan@zed.dev>
This commit is contained in:
parent
49e736d8ef
commit
0ec29d6866
18 changed files with 1316 additions and 815 deletions
|
@ -37,6 +37,7 @@ log.workspace = true
|
|||
menu.workspace = true
|
||||
ollama = { workspace = true, features = ["schemars"] }
|
||||
open_ai = { workspace = true, features = ["schemars"] }
|
||||
parking_lot.workspace = true
|
||||
proto = { workspace = true, features = ["test-support"] }
|
||||
project.workspace = true
|
||||
schemars.workspace = true
|
||||
|
|
|
@ -75,6 +75,11 @@ pub trait LanguageModel: Send + Sync {
|
|||
schema: serde_json::Value,
|
||||
cx: &AsyncAppContext,
|
||||
) -> BoxFuture<'static, Result<serde_json::Value>>;
|
||||
|
||||
#[cfg(any(test, feature = "test-support"))]
|
||||
fn as_fake(&self) -> &provider::fake::FakeLanguageModel {
|
||||
unimplemented!()
|
||||
}
|
||||
}
|
||||
|
||||
impl dyn LanguageModel {
|
||||
|
|
|
@ -3,15 +3,17 @@ use crate::{
|
|||
LanguageModelProviderId, LanguageModelProviderName, LanguageModelProviderState,
|
||||
LanguageModelRequest,
|
||||
};
|
||||
use anyhow::anyhow;
|
||||
use collections::HashMap;
|
||||
use futures::{channel::mpsc, future::BoxFuture, stream::BoxStream, FutureExt, StreamExt};
|
||||
use anyhow::Context as _;
|
||||
use futures::{
|
||||
channel::{mpsc, oneshot},
|
||||
future::BoxFuture,
|
||||
stream::BoxStream,
|
||||
FutureExt, StreamExt,
|
||||
};
|
||||
use gpui::{AnyView, AppContext, AsyncAppContext, Task};
|
||||
use http_client::Result;
|
||||
use std::{
|
||||
future,
|
||||
sync::{Arc, Mutex},
|
||||
};
|
||||
use parking_lot::Mutex;
|
||||
use std::sync::Arc;
|
||||
use ui::WindowContext;
|
||||
|
||||
pub fn language_model_id() -> LanguageModelId {
|
||||
|
@ -31,9 +33,7 @@ pub fn provider_name() -> LanguageModelProviderName {
|
|||
}
|
||||
|
||||
#[derive(Clone, Default)]
|
||||
pub struct FakeLanguageModelProvider {
|
||||
current_completion_txs: Arc<Mutex<HashMap<String, mpsc::UnboundedSender<String>>>>,
|
||||
}
|
||||
pub struct FakeLanguageModelProvider;
|
||||
|
||||
impl LanguageModelProviderState for FakeLanguageModelProvider {
|
||||
type ObservableEntity = ();
|
||||
|
@ -53,9 +53,7 @@ impl LanguageModelProvider for FakeLanguageModelProvider {
|
|||
}
|
||||
|
||||
fn provided_models(&self, _: &AppContext) -> Vec<Arc<dyn LanguageModel>> {
|
||||
vec![Arc::new(FakeLanguageModel {
|
||||
current_completion_txs: self.current_completion_txs.clone(),
|
||||
})]
|
||||
vec![Arc::new(FakeLanguageModel::default())]
|
||||
}
|
||||
|
||||
fn is_authenticated(&self, _: &AppContext) -> bool {
|
||||
|
@ -77,55 +75,80 @@ impl LanguageModelProvider for FakeLanguageModelProvider {
|
|||
|
||||
impl FakeLanguageModelProvider {
|
||||
pub fn test_model(&self) -> FakeLanguageModel {
|
||||
FakeLanguageModel {
|
||||
current_completion_txs: self.current_completion_txs.clone(),
|
||||
}
|
||||
FakeLanguageModel::default()
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, PartialEq)]
|
||||
pub struct ToolUseRequest {
|
||||
pub request: LanguageModelRequest,
|
||||
pub name: String,
|
||||
pub description: String,
|
||||
pub schema: serde_json::Value,
|
||||
}
|
||||
|
||||
#[derive(Default)]
|
||||
pub struct FakeLanguageModel {
|
||||
current_completion_txs: Arc<Mutex<HashMap<String, mpsc::UnboundedSender<String>>>>,
|
||||
current_completion_txs: Mutex<Vec<(LanguageModelRequest, mpsc::UnboundedSender<String>)>>,
|
||||
current_tool_use_txs: Mutex<Vec<(ToolUseRequest, oneshot::Sender<Result<serde_json::Value>>)>>,
|
||||
}
|
||||
|
||||
impl FakeLanguageModel {
|
||||
pub fn pending_completions(&self) -> Vec<LanguageModelRequest> {
|
||||
self.current_completion_txs
|
||||
.lock()
|
||||
.unwrap()
|
||||
.keys()
|
||||
.map(|k| serde_json::from_str(k).unwrap())
|
||||
.iter()
|
||||
.map(|(request, _)| request.clone())
|
||||
.collect()
|
||||
}
|
||||
|
||||
pub fn completion_count(&self) -> usize {
|
||||
self.current_completion_txs.lock().unwrap().len()
|
||||
self.current_completion_txs.lock().len()
|
||||
}
|
||||
|
||||
pub fn send_completion_chunk(&self, request: &LanguageModelRequest, chunk: String) {
|
||||
let json = serde_json::to_string(request).unwrap();
|
||||
pub fn stream_completion_response(&self, request: &LanguageModelRequest, chunk: String) {
|
||||
let current_completion_txs = self.current_completion_txs.lock();
|
||||
let tx = current_completion_txs
|
||||
.iter()
|
||||
.find(|(req, _)| req == request)
|
||||
.map(|(_, tx)| tx)
|
||||
.unwrap();
|
||||
tx.unbounded_send(chunk).unwrap();
|
||||
}
|
||||
|
||||
pub fn end_completion_stream(&self, request: &LanguageModelRequest) {
|
||||
self.current_completion_txs
|
||||
.lock()
|
||||
.unwrap()
|
||||
.get(&json)
|
||||
.unwrap()
|
||||
.unbounded_send(chunk)
|
||||
.unwrap();
|
||||
.retain(|(req, _)| req != request);
|
||||
}
|
||||
|
||||
pub fn send_last_completion_chunk(&self, chunk: String) {
|
||||
self.send_completion_chunk(self.pending_completions().last().unwrap(), chunk);
|
||||
pub fn stream_last_completion_response(&self, chunk: String) {
|
||||
self.stream_completion_response(self.pending_completions().last().unwrap(), chunk);
|
||||
}
|
||||
|
||||
pub fn finish_completion(&self, request: &LanguageModelRequest) {
|
||||
self.current_completion_txs
|
||||
.lock()
|
||||
.unwrap()
|
||||
.remove(&serde_json::to_string(request).unwrap())
|
||||
.unwrap();
|
||||
pub fn end_last_completion_stream(&self) {
|
||||
self.end_completion_stream(self.pending_completions().last().unwrap());
|
||||
}
|
||||
|
||||
pub fn finish_last_completion(&self) {
|
||||
self.finish_completion(self.pending_completions().last().unwrap());
|
||||
pub fn respond_to_tool_use(
|
||||
&self,
|
||||
tool_call: &ToolUseRequest,
|
||||
response: Result<serde_json::Value>,
|
||||
) {
|
||||
let mut current_tool_call_txs = self.current_tool_use_txs.lock();
|
||||
if let Some(index) = current_tool_call_txs
|
||||
.iter()
|
||||
.position(|(call, _)| call == tool_call)
|
||||
{
|
||||
let (_, tx) = current_tool_call_txs.remove(index);
|
||||
tx.send(response).unwrap();
|
||||
}
|
||||
}
|
||||
|
||||
pub fn respond_to_last_tool_use(&self, response: Result<serde_json::Value>) {
|
||||
let mut current_tool_call_txs = self.current_tool_use_txs.lock();
|
||||
let (_, tx) = current_tool_call_txs.pop().unwrap();
|
||||
tx.send(response).unwrap();
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -168,21 +191,30 @@ impl LanguageModel for FakeLanguageModel {
|
|||
_: &AsyncAppContext,
|
||||
) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
|
||||
let (tx, rx) = mpsc::unbounded();
|
||||
self.current_completion_txs
|
||||
.lock()
|
||||
.unwrap()
|
||||
.insert(serde_json::to_string(&request).unwrap(), tx);
|
||||
self.current_completion_txs.lock().push((request, tx));
|
||||
async move { Ok(rx.map(Ok).boxed()) }.boxed()
|
||||
}
|
||||
|
||||
fn use_any_tool(
|
||||
&self,
|
||||
_request: LanguageModelRequest,
|
||||
_name: String,
|
||||
_description: String,
|
||||
_schema: serde_json::Value,
|
||||
request: LanguageModelRequest,
|
||||
name: String,
|
||||
description: String,
|
||||
schema: serde_json::Value,
|
||||
_cx: &AsyncAppContext,
|
||||
) -> BoxFuture<'static, Result<serde_json::Value>> {
|
||||
future::ready(Err(anyhow!("not implemented"))).boxed()
|
||||
let (tx, rx) = oneshot::channel();
|
||||
let tool_call = ToolUseRequest {
|
||||
request,
|
||||
name,
|
||||
description,
|
||||
schema,
|
||||
};
|
||||
self.current_tool_use_txs.lock().push((tool_call, tx));
|
||||
async move { rx.await.context("FakeLanguageModel was dropped")? }.boxed()
|
||||
}
|
||||
|
||||
fn as_fake(&self) -> &Self {
|
||||
self
|
||||
}
|
||||
}
|
||||
|
|
|
@ -103,7 +103,7 @@ impl LanguageModelRegistry {
|
|||
|
||||
#[cfg(any(test, feature = "test-support"))]
|
||||
pub fn test(cx: &mut AppContext) -> crate::provider::fake::FakeLanguageModelProvider {
|
||||
let fake_provider = crate::provider::fake::FakeLanguageModelProvider::default();
|
||||
let fake_provider = crate::provider::fake::FakeLanguageModelProvider;
|
||||
let registry = cx.new_model(|cx| {
|
||||
let mut registry = Self::default();
|
||||
registry.register_provider(fake_provider.clone(), cx);
|
||||
|
@ -239,7 +239,7 @@ mod tests {
|
|||
let registry = cx.new_model(|_| LanguageModelRegistry::default());
|
||||
|
||||
registry.update(cx, |registry, cx| {
|
||||
registry.register_provider(FakeLanguageModelProvider::default(), cx);
|
||||
registry.register_provider(FakeLanguageModelProvider, cx);
|
||||
});
|
||||
|
||||
let providers = registry.read(cx).providers();
|
||||
|
|
|
@ -1,13 +1,13 @@
|
|||
use crate::role::Role;
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
|
||||
#[derive(Clone, Serialize, Deserialize, Debug, PartialEq, Hash)]
|
||||
pub struct LanguageModelRequestMessage {
|
||||
pub role: Role,
|
||||
pub content: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Default, Serialize, Deserialize)]
|
||||
#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq)]
|
||||
pub struct LanguageModelRequest {
|
||||
pub messages: Vec<LanguageModelRequestMessage>,
|
||||
pub stop: Vec<String>,
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
use serde::{Deserialize, Serialize};
|
||||
use std::fmt::{self, Display};
|
||||
|
||||
#[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)]
|
||||
#[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq, Hash)]
|
||||
#[serde(rename_all = "lowercase")]
|
||||
pub enum Role {
|
||||
User,
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue