Merge branch 'main' into edit-file-tool

This commit is contained in:
Agus Zubiaga 2025-08-07 20:49:52 -03:00
commit 26cf0cd9df
20 changed files with 318 additions and 241 deletions

View file

@ -13,22 +13,17 @@ self-hosted-runner:
- windows-2025-16 - windows-2025-16
- windows-2025-32 - windows-2025-32
- windows-2025-64 - windows-2025-64
# Buildjet Ubuntu 20.04 - AMD x86_64 # Namespace Ubuntu 20.04 (Release builds)
- buildjet-2vcpu-ubuntu-2004 - namespace-profile-16x32-ubuntu-2004
- buildjet-4vcpu-ubuntu-2004 - namespace-profile-32x64-ubuntu-2004
- buildjet-8vcpu-ubuntu-2004 - namespace-profile-16x32-ubuntu-2004-arm
- buildjet-16vcpu-ubuntu-2004 - namespace-profile-32x64-ubuntu-2004-arm
- buildjet-32vcpu-ubuntu-2004 # Namespace Ubuntu 22.04 (Everything else)
# Buildjet Ubuntu 22.04 - AMD x86_64 - namespace-profile-2x4-ubuntu-2204
- buildjet-2vcpu-ubuntu-2204 - namespace-profile-4x8-ubuntu-2204
- buildjet-4vcpu-ubuntu-2204 - namespace-profile-8x16-ubuntu-2204
- buildjet-8vcpu-ubuntu-2204 - namespace-profile-16x32-ubuntu-2204
- buildjet-16vcpu-ubuntu-2204 - namespace-profile-32x64-ubuntu-2204
- buildjet-32vcpu-ubuntu-2204
# Buildjet Ubuntu 22.04 - Graviton aarch64
- buildjet-8vcpu-ubuntu-2204-arm
- buildjet-16vcpu-ubuntu-2204-arm
- buildjet-32vcpu-ubuntu-2204-arm
# Self Hosted Runners # Self Hosted Runners
- self-mini-macos - self-mini-macos
- self-32vcpu-windows-2022 - self-32vcpu-windows-2022

View file

@ -16,7 +16,7 @@ jobs:
bump_patch_version: bump_patch_version:
if: github.repository_owner == 'zed-industries' if: github.repository_owner == 'zed-industries'
runs-on: runs-on:
- github-16vcpu-ubuntu-2204 - namespace-profile-16x32-ubuntu-2204
steps: steps:
- name: Checkout code - name: Checkout code
uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4 uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4

View file

@ -137,7 +137,7 @@ jobs:
github.repository_owner == 'zed-industries' && github.repository_owner == 'zed-industries' &&
needs.job_spec.outputs.run_tests == 'true' needs.job_spec.outputs.run_tests == 'true'
runs-on: runs-on:
- github-8vcpu-ubuntu-2204 - namespace-profile-8x16-ubuntu-2204
steps: steps:
- name: Checkout repo - name: Checkout repo
uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4 uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4
@ -168,7 +168,7 @@ jobs:
needs: [job_spec] needs: [job_spec]
if: github.repository_owner == 'zed-industries' if: github.repository_owner == 'zed-industries'
runs-on: runs-on:
- github-8vcpu-ubuntu-2204 - namespace-profile-4x8-ubuntu-2204
steps: steps:
- name: Checkout repo - name: Checkout repo
uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4 uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4
@ -221,7 +221,7 @@ jobs:
github.repository_owner == 'zed-industries' && github.repository_owner == 'zed-industries' &&
(needs.job_spec.outputs.run_tests == 'true' || needs.job_spec.outputs.run_docs == 'true') (needs.job_spec.outputs.run_tests == 'true' || needs.job_spec.outputs.run_docs == 'true')
runs-on: runs-on:
- github-8vcpu-ubuntu-2204 - namespace-profile-8x16-ubuntu-2204
steps: steps:
- name: Checkout repo - name: Checkout repo
uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4 uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4
@ -328,7 +328,7 @@ jobs:
github.repository_owner == 'zed-industries' && github.repository_owner == 'zed-industries' &&
needs.job_spec.outputs.run_tests == 'true' needs.job_spec.outputs.run_tests == 'true'
runs-on: runs-on:
- github-16vcpu-ubuntu-2204 - namespace-profile-16x32-ubuntu-2204
steps: steps:
- name: Add Rust to the PATH - name: Add Rust to the PATH
run: echo "$HOME/.cargo/bin" >> "$GITHUB_PATH" run: echo "$HOME/.cargo/bin" >> "$GITHUB_PATH"
@ -380,7 +380,7 @@ jobs:
github.repository_owner == 'zed-industries' && github.repository_owner == 'zed-industries' &&
needs.job_spec.outputs.run_tests == 'true' needs.job_spec.outputs.run_tests == 'true'
runs-on: runs-on:
- github-8vcpu-ubuntu-2204 - namespace-profile-16x32-ubuntu-2204
steps: steps:
- name: Add Rust to the PATH - name: Add Rust to the PATH
run: echo "$HOME/.cargo/bin" >> "$GITHUB_PATH" run: echo "$HOME/.cargo/bin" >> "$GITHUB_PATH"
@ -597,8 +597,7 @@ jobs:
timeout-minutes: 60 timeout-minutes: 60
name: Linux x86_x64 release bundle name: Linux x86_x64 release bundle
runs-on: runs-on:
- github-16vcpu-ubuntu-2204 - namespace-profile-16x32-ubuntu-2004 # ubuntu 20.04 for minimal glibc
# - buildjet-16vcpu-ubuntu-2004 # ubuntu 20.04 for minimal glibc
if: | if: |
startsWith(github.ref, 'refs/tags/v') startsWith(github.ref, 'refs/tags/v')
|| contains(github.event.pull_request.labels.*.name, 'run-bundling') || contains(github.event.pull_request.labels.*.name, 'run-bundling')
@ -651,7 +650,7 @@ jobs:
timeout-minutes: 60 timeout-minutes: 60
name: Linux arm64 release bundle name: Linux arm64 release bundle
runs-on: runs-on:
- github-16vcpu-ubuntu-2204-arm - namespace-profile-32x64-ubuntu-2004-arm # ubuntu 20.04 for minimal glibc
if: | if: |
startsWith(github.ref, 'refs/tags/v') startsWith(github.ref, 'refs/tags/v')
|| contains(github.event.pull_request.labels.*.name, 'run-bundling') || contains(github.event.pull_request.labels.*.name, 'run-bundling')

View file

@ -9,7 +9,7 @@ jobs:
deploy-docs: deploy-docs:
name: Deploy Docs name: Deploy Docs
if: github.repository_owner == 'zed-industries' if: github.repository_owner == 'zed-industries'
runs-on: github-16vcpu-ubuntu-2204 runs-on: namespace-profile-16x32-ubuntu-2204
steps: steps:
- name: Checkout repo - name: Checkout repo

View file

@ -61,7 +61,7 @@ jobs:
- style - style
- tests - tests
runs-on: runs-on:
- github-16vcpu-ubuntu-2204 - namespace-profile-16x32-ubuntu-2204
steps: steps:
- name: Install doctl - name: Install doctl
uses: digitalocean/action-doctl@v2 uses: digitalocean/action-doctl@v2
@ -94,7 +94,7 @@ jobs:
needs: needs:
- publish - publish
runs-on: runs-on:
- github-16vcpu-ubuntu-2204 - namespace-profile-16x32-ubuntu-2204
steps: steps:
- name: Checkout repo - name: Checkout repo

View file

@ -32,7 +32,7 @@ jobs:
github.repository_owner == 'zed-industries' && github.repository_owner == 'zed-industries' &&
(github.event_name != 'pull_request' || contains(github.event.pull_request.labels.*.name, 'run-eval')) (github.event_name != 'pull_request' || contains(github.event.pull_request.labels.*.name, 'run-eval'))
runs-on: runs-on:
- github-16vcpu-ubuntu-2204 - namespace-profile-16x32-ubuntu-2204
steps: steps:
- name: Add Rust to the PATH - name: Add Rust to the PATH
run: echo "$HOME/.cargo/bin" >> "$GITHUB_PATH" run: echo "$HOME/.cargo/bin" >> "$GITHUB_PATH"

View file

@ -20,7 +20,7 @@ jobs:
matrix: matrix:
system: system:
- os: x86 Linux - os: x86 Linux
runner: github-16vcpu-ubuntu-2204 runner: namespace-profile-16x32-ubuntu-2204
install_nix: true install_nix: true
- os: arm Mac - os: arm Mac
runner: [macOS, ARM64, test] runner: [macOS, ARM64, test]

View file

@ -20,7 +20,7 @@ jobs:
name: Run randomized tests name: Run randomized tests
if: github.repository_owner == 'zed-industries' if: github.repository_owner == 'zed-industries'
runs-on: runs-on:
- github-16vcpu-ubuntu-2204 - namespace-profile-16x32-ubuntu-2204
steps: steps:
- name: Install Node - name: Install Node
uses: actions/setup-node@49933ea5288caeca8642d1e84afbd3f7d6820020 # v4 uses: actions/setup-node@49933ea5288caeca8642d1e84afbd3f7d6820020 # v4

View file

@ -128,8 +128,7 @@ jobs:
name: Create a Linux *.tar.gz bundle for x86 name: Create a Linux *.tar.gz bundle for x86
if: github.repository_owner == 'zed-industries' if: github.repository_owner == 'zed-industries'
runs-on: runs-on:
- github-16vcpu-ubuntu-2204 - namespace-profile-16x32-ubuntu-2004 # ubuntu 20.04 for minimal glibc
# - buildjet-16vcpu-ubuntu-2004
needs: tests needs: tests
steps: steps:
- name: Checkout repo - name: Checkout repo
@ -169,7 +168,7 @@ jobs:
name: Create a Linux *.tar.gz bundle for ARM name: Create a Linux *.tar.gz bundle for ARM
if: github.repository_owner == 'zed-industries' if: github.repository_owner == 'zed-industries'
runs-on: runs-on:
- github-16vcpu-ubuntu-2204-arm - namespace-profile-32x64-ubuntu-2004-arm # ubuntu 20.04 for minimal glibc
needs: tests needs: tests
steps: steps:
- name: Checkout repo - name: Checkout repo

View file

@ -23,7 +23,7 @@ jobs:
timeout-minutes: 60 timeout-minutes: 60
name: Run unit evals name: Run unit evals
runs-on: runs-on:
- github-16vcpu-ubuntu-2204 - namespace-profile-16x32-ubuntu-2204
steps: steps:
- name: Add Rust to the PATH - name: Add Rust to the PATH
run: echo "$HOME/.cargo/bin" >> "$GITHUB_PATH" run: echo "$HOME/.cargo/bin" >> "$GITHUB_PATH"

1
Cargo.lock generated
View file

@ -9132,6 +9132,7 @@ dependencies = [
"anyhow", "anyhow",
"base64 0.22.1", "base64 0.22.1",
"client", "client",
"cloud_api_types",
"cloud_llm_client", "cloud_llm_client",
"collections", "collections",
"futures 0.3.31", "futures 0.3.31",

View file

@ -6,7 +6,7 @@ use context_server::listener::McpServerTool;
use project::Project; use project::Project;
use settings::SettingsStore; use settings::SettingsStore;
use smol::process::Child; use smol::process::Child;
use std::cell::{Cell, RefCell}; use std::cell::RefCell;
use std::fmt::Display; use std::fmt::Display;
use std::path::Path; use std::path::Path;
use std::rc::Rc; use std::rc::Rc;
@ -153,20 +153,17 @@ impl AgentConnection for ClaudeAgentConnection {
}) })
.detach(); .detach();
let pending_cancellation = Rc::new(Cell::new(PendingCancellation::None)); let turn_state = Rc::new(RefCell::new(TurnState::None));
let end_turn_tx = Rc::new(RefCell::new(None));
let handler_task = cx.spawn({ let handler_task = cx.spawn({
let end_turn_tx = end_turn_tx.clone(); let turn_state = turn_state.clone();
let mut thread_rx = thread_rx.clone(); let mut thread_rx = thread_rx.clone();
let cancellation_state = pending_cancellation.clone();
async move |cx| { async move |cx| {
while let Some(message) = incoming_message_rx.next().await { while let Some(message) = incoming_message_rx.next().await {
ClaudeAgentSession::handle_message( ClaudeAgentSession::handle_message(
thread_rx.clone(), thread_rx.clone(),
message, message,
end_turn_tx.clone(), turn_state.clone(),
cancellation_state.clone(),
cx, cx,
) )
.await .await
@ -192,8 +189,7 @@ impl AgentConnection for ClaudeAgentConnection {
let session = ClaudeAgentSession { let session = ClaudeAgentSession {
outgoing_tx, outgoing_tx,
end_turn_tx, turn_state,
pending_cancellation,
_handler_task: handler_task, _handler_task: handler_task,
_mcp_server: Some(permission_mcp_server), _mcp_server: Some(permission_mcp_server),
}; };
@ -225,8 +221,8 @@ impl AgentConnection for ClaudeAgentConnection {
))); )));
}; };
let (tx, rx) = oneshot::channel(); let (end_tx, end_rx) = oneshot::channel();
session.end_turn_tx.borrow_mut().replace(tx); session.turn_state.replace(TurnState::InProgress { end_tx });
let mut content = String::new(); let mut content = String::new();
for chunk in params.prompt { for chunk in params.prompt {
@ -260,12 +256,7 @@ impl AgentConnection for ClaudeAgentConnection {
return Task::ready(Err(anyhow!(err))); return Task::ready(Err(anyhow!(err)));
} }
let cancellation_state = session.pending_cancellation.clone(); cx.foreground_executor().spawn(async move { end_rx.await? })
cx.foreground_executor().spawn(async move {
let result = rx.await??;
cancellation_state.set(PendingCancellation::None);
Ok(result)
})
} }
fn cancel(&self, session_id: &acp::SessionId, _cx: &mut App) { fn cancel(&self, session_id: &acp::SessionId, _cx: &mut App) {
@ -277,7 +268,15 @@ impl AgentConnection for ClaudeAgentConnection {
let request_id = new_request_id(); let request_id = new_request_id();
session.pending_cancellation.set(PendingCancellation::Sent { let turn_state = session.turn_state.take();
let TurnState::InProgress { end_tx } = turn_state else {
// Already cancelled or idle, put it back
session.turn_state.replace(turn_state);
return;
};
session.turn_state.replace(TurnState::CancelRequested {
end_tx,
request_id: request_id.clone(), request_id: request_id.clone(),
}); });
@ -349,28 +348,56 @@ fn spawn_claude(
struct ClaudeAgentSession { struct ClaudeAgentSession {
outgoing_tx: UnboundedSender<SdkMessage>, outgoing_tx: UnboundedSender<SdkMessage>,
end_turn_tx: Rc<RefCell<Option<oneshot::Sender<Result<acp::PromptResponse>>>>>, turn_state: Rc<RefCell<TurnState>>,
pending_cancellation: Rc<Cell<PendingCancellation>>,
_mcp_server: Option<ClaudeZedMcpServer>, _mcp_server: Option<ClaudeZedMcpServer>,
_handler_task: Task<()>, _handler_task: Task<()>,
} }
#[derive(Debug, Default, PartialEq)] #[derive(Debug, Default)]
enum PendingCancellation { enum TurnState {
#[default] #[default]
None, None,
Sent { InProgress {
end_tx: oneshot::Sender<Result<acp::PromptResponse>>,
},
CancelRequested {
end_tx: oneshot::Sender<Result<acp::PromptResponse>>,
request_id: String, request_id: String,
}, },
Confirmed, CancelConfirmed {
end_tx: oneshot::Sender<Result<acp::PromptResponse>>,
},
}
impl TurnState {
fn is_cancelled(&self) -> bool {
matches!(self, TurnState::CancelConfirmed { .. })
}
fn end_tx(self) -> Option<oneshot::Sender<Result<acp::PromptResponse>>> {
match self {
TurnState::None => None,
TurnState::InProgress { end_tx, .. } => Some(end_tx),
TurnState::CancelRequested { end_tx, .. } => Some(end_tx),
TurnState::CancelConfirmed { end_tx } => Some(end_tx),
}
}
fn confirm_cancellation(self, id: &str) -> Self {
match self {
TurnState::CancelRequested { request_id, end_tx } if request_id == id => {
TurnState::CancelConfirmed { end_tx }
}
_ => self,
}
}
} }
impl ClaudeAgentSession { impl ClaudeAgentSession {
async fn handle_message( async fn handle_message(
mut thread_rx: watch::Receiver<WeakEntity<AcpThread>>, mut thread_rx: watch::Receiver<WeakEntity<AcpThread>>,
message: SdkMessage, message: SdkMessage,
end_turn_tx: Rc<RefCell<Option<oneshot::Sender<Result<acp::PromptResponse>>>>>, turn_state: Rc<RefCell<TurnState>>,
pending_cancellation: Rc<Cell<PendingCancellation>>,
cx: &mut AsyncApp, cx: &mut AsyncApp,
) { ) {
match message { match message {
@ -393,15 +420,13 @@ impl ClaudeAgentSession {
for chunk in message.content.chunks() { for chunk in message.content.chunks() {
match chunk { match chunk {
ContentChunk::Text { text } | ContentChunk::UntaggedText(text) => { ContentChunk::Text { text } | ContentChunk::UntaggedText(text) => {
let state = pending_cancellation.take(); if !turn_state.borrow().is_cancelled() {
if state != PendingCancellation::Confirmed {
thread thread
.update(cx, |thread, cx| { .update(cx, |thread, cx| {
thread.push_user_content_block(text.into(), cx) thread.push_user_content_block(text.into(), cx)
}) })
.log_err(); .log_err();
} }
pending_cancellation.set(state);
} }
ContentChunk::ToolResult { ContentChunk::ToolResult {
content, content,
@ -414,7 +439,12 @@ impl ClaudeAgentSession {
acp::ToolCallUpdate { acp::ToolCallUpdate {
id: acp::ToolCallId(tool_use_id.into()), id: acp::ToolCallId(tool_use_id.into()),
fields: acp::ToolCallUpdateFields { fields: acp::ToolCallUpdateFields {
status: Some(acp::ToolCallStatus::Completed), status: if turn_state.borrow().is_cancelled() {
// Do not set to completed if turn was cancelled
None
} else {
Some(acp::ToolCallStatus::Completed)
},
content: (!content.is_empty()) content: (!content.is_empty())
.then(|| vec![content.into()]), .then(|| vec![content.into()]),
..Default::default() ..Default::default()
@ -541,40 +571,38 @@ impl ClaudeAgentSession {
result, result,
.. ..
} => { } => {
if let Some(end_turn_tx) = end_turn_tx.borrow_mut().take() { let turn_state = turn_state.take();
if is_error let was_cancelled = turn_state.is_cancelled();
|| (subtype == ResultErrorType::ErrorDuringExecution let Some(end_turn_tx) = turn_state.end_tx() else {
&& pending_cancellation.take() != PendingCancellation::Confirmed) debug_panic!("Received `SdkMessage::Result` but there wasn't an active turn");
{ return;
end_turn_tx };
.send(Err(anyhow!(
"Error: {}", if is_error || (!was_cancelled && subtype == ResultErrorType::ErrorDuringExecution)
result.unwrap_or_else(|| subtype.to_string()) {
))) end_turn_tx
.ok(); .send(Err(anyhow!(
} else { "Error: {}",
let stop_reason = match subtype { result.unwrap_or_else(|| subtype.to_string())
ResultErrorType::Success => acp::StopReason::EndTurn, )))
ResultErrorType::ErrorMaxTurns => acp::StopReason::MaxTurnRequests, .ok();
ResultErrorType::ErrorDuringExecution => acp::StopReason::Cancelled, } else {
}; let stop_reason = match subtype {
end_turn_tx ResultErrorType::Success => acp::StopReason::EndTurn,
.send(Ok(acp::PromptResponse { stop_reason })) ResultErrorType::ErrorMaxTurns => acp::StopReason::MaxTurnRequests,
.ok(); ResultErrorType::ErrorDuringExecution => acp::StopReason::Cancelled,
} };
end_turn_tx
.send(Ok(acp::PromptResponse { stop_reason }))
.ok();
} }
} }
SdkMessage::ControlResponse { response } => { SdkMessage::ControlResponse { response } => {
if matches!(response.subtype, ResultErrorType::Success) { if matches!(response.subtype, ResultErrorType::Success) {
let pending_cancellation_value = pending_cancellation.take(); let new_state = turn_state.take().confirm_cancellation(&response.request_id);
turn_state.replace(new_state);
if let PendingCancellation::Sent { request_id } = &pending_cancellation_value } else {
&& request_id == &response.request_id log::error!("Control response error: {:?}", response);
{
pending_cancellation.set(PendingCancellation::Confirmed);
} else {
pending_cancellation.set(pending_cancellation_value);
}
} }
} }
SdkMessage::System { .. } => {} SdkMessage::System { .. } => {}

View file

@ -246,7 +246,7 @@ pub async fn test_cancel(server: impl AgentServer + 'static, cx: &mut TestAppCon
let project = Project::test(fs, [path!("/private/tmp").as_ref()], cx).await; let project = Project::test(fs, [path!("/private/tmp").as_ref()], cx).await;
let thread = new_test_thread(server, project.clone(), "/private/tmp", cx).await; let thread = new_test_thread(server, project.clone(), "/private/tmp", cx).await;
let full_turn = thread.update(cx, |thread, cx| { let _ = thread.update(cx, |thread, cx| {
thread.send_raw( thread.send_raw(
r#"Run exactly `touch hello.txt && echo "Hello, world!" | tee hello.txt` in the terminal."#, r#"Run exactly `touch hello.txt && echo "Hello, world!" | tee hello.txt` in the terminal."#,
cx, cx,
@ -285,9 +285,8 @@ pub async fn test_cancel(server: impl AgentServer + 'static, cx: &mut TestAppCon
id.clone() id.clone()
}); });
let _ = thread.update(cx, |thread, cx| thread.cancel(cx)); thread.update(cx, |thread, cx| thread.cancel(cx)).await;
full_turn.await.unwrap(); thread.read_with(cx, |thread, _cx| {
thread.read_with(cx, |thread, _| {
let AgentThreadEntry::ToolCall(ToolCall { let AgentThreadEntry::ToolCall(ToolCall {
status: ToolCallStatus::Canceled, status: ToolCallStatus::Canceled,
.. ..

View file

@ -193,7 +193,7 @@ pub fn init(client: &Arc<Client>, cx: &mut App) {
}); });
} }
pub type MessageToClientHandler = Box<dyn Fn(&MessageToClient, &App) + Send + Sync + 'static>; pub type MessageToClientHandler = Box<dyn Fn(&MessageToClient, &mut App) + Send + Sync + 'static>;
struct GlobalClient(Arc<Client>); struct GlobalClient(Arc<Client>);
@ -1684,7 +1684,7 @@ impl Client {
pub fn add_message_to_client_handler( pub fn add_message_to_client_handler(
self: &Arc<Client>, self: &Arc<Client>,
handler: impl Fn(&MessageToClient, &App) + Send + Sync + 'static, handler: impl Fn(&MessageToClient, &mut App) + Send + Sync + 'static,
) { ) {
self.message_to_client_handlers self.message_to_client_handlers
.lock() .lock()

View file

@ -41,9 +41,11 @@ use chrono::Utc;
use collections::{HashMap, HashSet}; use collections::{HashMap, HashSet};
pub use connection_pool::{ConnectionPool, ZedVersion}; pub use connection_pool::{ConnectionPool, ZedVersion};
use core::fmt::{self, Debug, Formatter}; use core::fmt::{self, Debug, Formatter};
use futures::TryFutureExt as _;
use reqwest_client::ReqwestClient; use reqwest_client::ReqwestClient;
use rpc::proto::{MultiLspQuery, split_repository_update}; use rpc::proto::{MultiLspQuery, split_repository_update};
use supermaven_api::{CreateExternalUserRequest, SupermavenAdminApi}; use supermaven_api::{CreateExternalUserRequest, SupermavenAdminApi};
use tracing::Span;
use futures::{ use futures::{
FutureExt, SinkExt, StreamExt, TryStreamExt, channel::oneshot, future::BoxFuture, FutureExt, SinkExt, StreamExt, TryStreamExt, channel::oneshot, future::BoxFuture,
@ -94,8 +96,13 @@ const MAX_CONCURRENT_CONNECTIONS: usize = 512;
static CONCURRENT_CONNECTIONS: AtomicUsize = AtomicUsize::new(0); static CONCURRENT_CONNECTIONS: AtomicUsize = AtomicUsize::new(0);
const TOTAL_DURATION_MS: &str = "total_duration_ms";
const PROCESSING_DURATION_MS: &str = "processing_duration_ms";
const QUEUE_DURATION_MS: &str = "queue_duration_ms";
const HOST_WAITING_MS: &str = "host_waiting_ms";
type MessageHandler = type MessageHandler =
Box<dyn Send + Sync + Fn(Box<dyn AnyTypedEnvelope>, Session) -> BoxFuture<'static, ()>>; Box<dyn Send + Sync + Fn(Box<dyn AnyTypedEnvelope>, Session, Span) -> BoxFuture<'static, ()>>;
pub struct ConnectionGuard; pub struct ConnectionGuard;
@ -163,6 +170,42 @@ impl Principal {
} }
} }
#[derive(Clone)]
struct MessageContext {
session: Session,
span: tracing::Span,
}
impl Deref for MessageContext {
type Target = Session;
fn deref(&self) -> &Self::Target {
&self.session
}
}
impl MessageContext {
pub fn forward_request<T: RequestMessage>(
&self,
receiver_id: ConnectionId,
request: T,
) -> impl Future<Output = anyhow::Result<T::Response>> {
let request_start_time = Instant::now();
let span = self.span.clone();
tracing::info!("start forwarding request");
self.peer
.forward_request(self.connection_id, receiver_id, request)
.inspect(move |_| {
span.record(
HOST_WAITING_MS,
request_start_time.elapsed().as_micros() as f64 / 1000.0,
);
})
.inspect_err(|_| tracing::error!("error forwarding request"))
.inspect_ok(|_| tracing::info!("finished forwarding request"))
}
}
#[derive(Clone)] #[derive(Clone)]
struct Session { struct Session {
principal: Principal, principal: Principal,
@ -646,40 +689,37 @@ impl Server {
fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
where where
F: 'static + Send + Sync + Fn(TypedEnvelope<M>, Session) -> Fut, F: 'static + Send + Sync + Fn(TypedEnvelope<M>, MessageContext) -> Fut,
Fut: 'static + Send + Future<Output = Result<()>>, Fut: 'static + Send + Future<Output = Result<()>>,
M: EnvelopedMessage, M: EnvelopedMessage,
{ {
let prev_handler = self.handlers.insert( let prev_handler = self.handlers.insert(
TypeId::of::<M>(), TypeId::of::<M>(),
Box::new(move |envelope, session| { Box::new(move |envelope, session, span| {
let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap(); let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
let received_at = envelope.received_at; let received_at = envelope.received_at;
tracing::info!("message received"); tracing::info!("message received");
let start_time = Instant::now(); let start_time = Instant::now();
let future = (handler)(*envelope, session); let future = (handler)(
*envelope,
MessageContext {
session,
span: span.clone(),
},
);
async move { async move {
let result = future.await; let result = future.await;
let total_duration_ms = received_at.elapsed().as_micros() as f64 / 1000.0; let total_duration_ms = received_at.elapsed().as_micros() as f64 / 1000.0;
let processing_duration_ms = start_time.elapsed().as_micros() as f64 / 1000.0; let processing_duration_ms = start_time.elapsed().as_micros() as f64 / 1000.0;
let queue_duration_ms = total_duration_ms - processing_duration_ms; let queue_duration_ms = total_duration_ms - processing_duration_ms;
span.record(TOTAL_DURATION_MS, total_duration_ms);
span.record(PROCESSING_DURATION_MS, processing_duration_ms);
span.record(QUEUE_DURATION_MS, queue_duration_ms);
match result { match result {
Err(error) => { Err(error) => {
tracing::error!( tracing::error!(?error, "error handling message")
?error,
total_duration_ms,
processing_duration_ms,
queue_duration_ms,
"error handling message"
)
} }
Ok(()) => tracing::info!( Ok(()) => tracing::info!("finished handling message"),
total_duration_ms,
processing_duration_ms,
queue_duration_ms,
"finished handling message"
),
} }
} }
.boxed() .boxed()
@ -693,7 +733,7 @@ impl Server {
fn add_message_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self fn add_message_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
where where
F: 'static + Send + Sync + Fn(M, Session) -> Fut, F: 'static + Send + Sync + Fn(M, MessageContext) -> Fut,
Fut: 'static + Send + Future<Output = Result<()>>, Fut: 'static + Send + Future<Output = Result<()>>,
M: EnvelopedMessage, M: EnvelopedMessage,
{ {
@ -703,7 +743,7 @@ impl Server {
fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
where where
F: 'static + Send + Sync + Fn(M, Response<M>, Session) -> Fut, F: 'static + Send + Sync + Fn(M, Response<M>, MessageContext) -> Fut,
Fut: Send + Future<Output = Result<()>>, Fut: Send + Future<Output = Result<()>>,
M: RequestMessage, M: RequestMessage,
{ {
@ -889,12 +929,16 @@ impl Server {
login=field::Empty, login=field::Empty,
impersonator=field::Empty, impersonator=field::Empty,
multi_lsp_query_request=field::Empty, multi_lsp_query_request=field::Empty,
{ TOTAL_DURATION_MS }=field::Empty,
{ PROCESSING_DURATION_MS }=field::Empty,
{ QUEUE_DURATION_MS }=field::Empty,
{ HOST_WAITING_MS }=field::Empty
); );
principal.update_span(&span); principal.update_span(&span);
let span_enter = span.enter(); let span_enter = span.enter();
if let Some(handler) = this.handlers.get(&message.payload_type_id()) { if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
let is_background = message.is_background(); let is_background = message.is_background();
let handle_message = (handler)(message, session.clone()); let handle_message = (handler)(message, session.clone(), span.clone());
drop(span_enter); drop(span_enter);
let handle_message = async move { let handle_message = async move {
@ -1386,7 +1430,11 @@ async fn connection_lost(
} }
/// Acknowledges a ping from a client, used to keep the connection alive. /// Acknowledges a ping from a client, used to keep the connection alive.
async fn ping(_: proto::Ping, response: Response<proto::Ping>, _session: Session) -> Result<()> { async fn ping(
_: proto::Ping,
response: Response<proto::Ping>,
_session: MessageContext,
) -> Result<()> {
response.send(proto::Ack {})?; response.send(proto::Ack {})?;
Ok(()) Ok(())
} }
@ -1395,7 +1443,7 @@ async fn ping(_: proto::Ping, response: Response<proto::Ping>, _session: Session
async fn create_room( async fn create_room(
_request: proto::CreateRoom, _request: proto::CreateRoom,
response: Response<proto::CreateRoom>, response: Response<proto::CreateRoom>,
session: Session, session: MessageContext,
) -> Result<()> { ) -> Result<()> {
let livekit_room = nanoid::nanoid!(30); let livekit_room = nanoid::nanoid!(30);
@ -1435,7 +1483,7 @@ async fn create_room(
async fn join_room( async fn join_room(
request: proto::JoinRoom, request: proto::JoinRoom,
response: Response<proto::JoinRoom>, response: Response<proto::JoinRoom>,
session: Session, session: MessageContext,
) -> Result<()> { ) -> Result<()> {
let room_id = RoomId::from_proto(request.id); let room_id = RoomId::from_proto(request.id);
@ -1502,7 +1550,7 @@ async fn join_room(
async fn rejoin_room( async fn rejoin_room(
request: proto::RejoinRoom, request: proto::RejoinRoom,
response: Response<proto::RejoinRoom>, response: Response<proto::RejoinRoom>,
session: Session, session: MessageContext,
) -> Result<()> { ) -> Result<()> {
let room; let room;
let channel; let channel;
@ -1679,7 +1727,7 @@ fn notify_rejoined_projects(
async fn leave_room( async fn leave_room(
_: proto::LeaveRoom, _: proto::LeaveRoom,
response: Response<proto::LeaveRoom>, response: Response<proto::LeaveRoom>,
session: Session, session: MessageContext,
) -> Result<()> { ) -> Result<()> {
leave_room_for_session(&session, session.connection_id).await?; leave_room_for_session(&session, session.connection_id).await?;
response.send(proto::Ack {})?; response.send(proto::Ack {})?;
@ -1690,7 +1738,7 @@ async fn leave_room(
async fn set_room_participant_role( async fn set_room_participant_role(
request: proto::SetRoomParticipantRole, request: proto::SetRoomParticipantRole,
response: Response<proto::SetRoomParticipantRole>, response: Response<proto::SetRoomParticipantRole>,
session: Session, session: MessageContext,
) -> Result<()> { ) -> Result<()> {
let user_id = UserId::from_proto(request.user_id); let user_id = UserId::from_proto(request.user_id);
let role = ChannelRole::from(request.role()); let role = ChannelRole::from(request.role());
@ -1738,7 +1786,7 @@ async fn set_room_participant_role(
async fn call( async fn call(
request: proto::Call, request: proto::Call,
response: Response<proto::Call>, response: Response<proto::Call>,
session: Session, session: MessageContext,
) -> Result<()> { ) -> Result<()> {
let room_id = RoomId::from_proto(request.room_id); let room_id = RoomId::from_proto(request.room_id);
let calling_user_id = session.user_id(); let calling_user_id = session.user_id();
@ -1807,7 +1855,7 @@ async fn call(
async fn cancel_call( async fn cancel_call(
request: proto::CancelCall, request: proto::CancelCall,
response: Response<proto::CancelCall>, response: Response<proto::CancelCall>,
session: Session, session: MessageContext,
) -> Result<()> { ) -> Result<()> {
let called_user_id = UserId::from_proto(request.called_user_id); let called_user_id = UserId::from_proto(request.called_user_id);
let room_id = RoomId::from_proto(request.room_id); let room_id = RoomId::from_proto(request.room_id);
@ -1842,7 +1890,7 @@ async fn cancel_call(
} }
/// Decline an incoming call. /// Decline an incoming call.
async fn decline_call(message: proto::DeclineCall, session: Session) -> Result<()> { async fn decline_call(message: proto::DeclineCall, session: MessageContext) -> Result<()> {
let room_id = RoomId::from_proto(message.room_id); let room_id = RoomId::from_proto(message.room_id);
{ {
let room = session let room = session
@ -1877,7 +1925,7 @@ async fn decline_call(message: proto::DeclineCall, session: Session) -> Result<(
async fn update_participant_location( async fn update_participant_location(
request: proto::UpdateParticipantLocation, request: proto::UpdateParticipantLocation,
response: Response<proto::UpdateParticipantLocation>, response: Response<proto::UpdateParticipantLocation>,
session: Session, session: MessageContext,
) -> Result<()> { ) -> Result<()> {
let room_id = RoomId::from_proto(request.room_id); let room_id = RoomId::from_proto(request.room_id);
let location = request.location.context("invalid location")?; let location = request.location.context("invalid location")?;
@ -1896,7 +1944,7 @@ async fn update_participant_location(
async fn share_project( async fn share_project(
request: proto::ShareProject, request: proto::ShareProject,
response: Response<proto::ShareProject>, response: Response<proto::ShareProject>,
session: Session, session: MessageContext,
) -> Result<()> { ) -> Result<()> {
let (project_id, room) = &*session let (project_id, room) = &*session
.db() .db()
@ -1917,7 +1965,7 @@ async fn share_project(
} }
/// Unshare a project from the room. /// Unshare a project from the room.
async fn unshare_project(message: proto::UnshareProject, session: Session) -> Result<()> { async fn unshare_project(message: proto::UnshareProject, session: MessageContext) -> Result<()> {
let project_id = ProjectId::from_proto(message.project_id); let project_id = ProjectId::from_proto(message.project_id);
unshare_project_internal(project_id, session.connection_id, &session).await unshare_project_internal(project_id, session.connection_id, &session).await
} }
@ -1964,7 +2012,7 @@ async fn unshare_project_internal(
async fn join_project( async fn join_project(
request: proto::JoinProject, request: proto::JoinProject,
response: Response<proto::JoinProject>, response: Response<proto::JoinProject>,
session: Session, session: MessageContext,
) -> Result<()> { ) -> Result<()> {
let project_id = ProjectId::from_proto(request.project_id); let project_id = ProjectId::from_proto(request.project_id);
@ -2111,7 +2159,7 @@ async fn join_project(
} }
/// Leave someone elses shared project. /// Leave someone elses shared project.
async fn leave_project(request: proto::LeaveProject, session: Session) -> Result<()> { async fn leave_project(request: proto::LeaveProject, session: MessageContext) -> Result<()> {
let sender_id = session.connection_id; let sender_id = session.connection_id;
let project_id = ProjectId::from_proto(request.project_id); let project_id = ProjectId::from_proto(request.project_id);
let db = session.db().await; let db = session.db().await;
@ -2134,7 +2182,7 @@ async fn leave_project(request: proto::LeaveProject, session: Session) -> Result
async fn update_project( async fn update_project(
request: proto::UpdateProject, request: proto::UpdateProject,
response: Response<proto::UpdateProject>, response: Response<proto::UpdateProject>,
session: Session, session: MessageContext,
) -> Result<()> { ) -> Result<()> {
let project_id = ProjectId::from_proto(request.project_id); let project_id = ProjectId::from_proto(request.project_id);
let (room, guest_connection_ids) = &*session let (room, guest_connection_ids) = &*session
@ -2163,7 +2211,7 @@ async fn update_project(
async fn update_worktree( async fn update_worktree(
request: proto::UpdateWorktree, request: proto::UpdateWorktree,
response: Response<proto::UpdateWorktree>, response: Response<proto::UpdateWorktree>,
session: Session, session: MessageContext,
) -> Result<()> { ) -> Result<()> {
let guest_connection_ids = session let guest_connection_ids = session
.db() .db()
@ -2187,7 +2235,7 @@ async fn update_worktree(
async fn update_repository( async fn update_repository(
request: proto::UpdateRepository, request: proto::UpdateRepository,
response: Response<proto::UpdateRepository>, response: Response<proto::UpdateRepository>,
session: Session, session: MessageContext,
) -> Result<()> { ) -> Result<()> {
let guest_connection_ids = session let guest_connection_ids = session
.db() .db()
@ -2211,7 +2259,7 @@ async fn update_repository(
async fn remove_repository( async fn remove_repository(
request: proto::RemoveRepository, request: proto::RemoveRepository,
response: Response<proto::RemoveRepository>, response: Response<proto::RemoveRepository>,
session: Session, session: MessageContext,
) -> Result<()> { ) -> Result<()> {
let guest_connection_ids = session let guest_connection_ids = session
.db() .db()
@ -2235,7 +2283,7 @@ async fn remove_repository(
/// Updates other participants with changes to the diagnostics /// Updates other participants with changes to the diagnostics
async fn update_diagnostic_summary( async fn update_diagnostic_summary(
message: proto::UpdateDiagnosticSummary, message: proto::UpdateDiagnosticSummary,
session: Session, session: MessageContext,
) -> Result<()> { ) -> Result<()> {
let guest_connection_ids = session let guest_connection_ids = session
.db() .db()
@ -2259,7 +2307,7 @@ async fn update_diagnostic_summary(
/// Updates other participants with changes to the worktree settings /// Updates other participants with changes to the worktree settings
async fn update_worktree_settings( async fn update_worktree_settings(
message: proto::UpdateWorktreeSettings, message: proto::UpdateWorktreeSettings,
session: Session, session: MessageContext,
) -> Result<()> { ) -> Result<()> {
let guest_connection_ids = session let guest_connection_ids = session
.db() .db()
@ -2283,7 +2331,7 @@ async fn update_worktree_settings(
/// Notify other participants that a language server has started. /// Notify other participants that a language server has started.
async fn start_language_server( async fn start_language_server(
request: proto::StartLanguageServer, request: proto::StartLanguageServer,
session: Session, session: MessageContext,
) -> Result<()> { ) -> Result<()> {
let guest_connection_ids = session let guest_connection_ids = session
.db() .db()
@ -2306,7 +2354,7 @@ async fn start_language_server(
/// Notify other participants that a language server has changed. /// Notify other participants that a language server has changed.
async fn update_language_server( async fn update_language_server(
request: proto::UpdateLanguageServer, request: proto::UpdateLanguageServer,
session: Session, session: MessageContext,
) -> Result<()> { ) -> Result<()> {
let project_id = ProjectId::from_proto(request.project_id); let project_id = ProjectId::from_proto(request.project_id);
let db = session.db().await; let db = session.db().await;
@ -2339,7 +2387,7 @@ async fn update_language_server(
async fn forward_read_only_project_request<T>( async fn forward_read_only_project_request<T>(
request: T, request: T,
response: Response<T>, response: Response<T>,
session: Session, session: MessageContext,
) -> Result<()> ) -> Result<()>
where where
T: EntityMessage + RequestMessage, T: EntityMessage + RequestMessage,
@ -2350,10 +2398,7 @@ where
.await .await
.host_for_read_only_project_request(project_id, session.connection_id) .host_for_read_only_project_request(project_id, session.connection_id)
.await?; .await?;
let payload = session let payload = session.forward_request(host_connection_id, request).await?;
.peer
.forward_request(session.connection_id, host_connection_id, request)
.await?;
response.send(payload)?; response.send(payload)?;
Ok(()) Ok(())
} }
@ -2363,7 +2408,7 @@ where
async fn forward_mutating_project_request<T>( async fn forward_mutating_project_request<T>(
request: T, request: T,
response: Response<T>, response: Response<T>,
session: Session, session: MessageContext,
) -> Result<()> ) -> Result<()>
where where
T: EntityMessage + RequestMessage, T: EntityMessage + RequestMessage,
@ -2375,10 +2420,7 @@ where
.await .await
.host_for_mutating_project_request(project_id, session.connection_id) .host_for_mutating_project_request(project_id, session.connection_id)
.await?; .await?;
let payload = session let payload = session.forward_request(host_connection_id, request).await?;
.peer
.forward_request(session.connection_id, host_connection_id, request)
.await?;
response.send(payload)?; response.send(payload)?;
Ok(()) Ok(())
} }
@ -2386,7 +2428,7 @@ where
async fn multi_lsp_query( async fn multi_lsp_query(
request: MultiLspQuery, request: MultiLspQuery,
response: Response<MultiLspQuery>, response: Response<MultiLspQuery>,
session: Session, session: MessageContext,
) -> Result<()> { ) -> Result<()> {
tracing::Span::current().record("multi_lsp_query_request", request.request_str()); tracing::Span::current().record("multi_lsp_query_request", request.request_str());
tracing::info!("multi_lsp_query message received"); tracing::info!("multi_lsp_query message received");
@ -2396,7 +2438,7 @@ async fn multi_lsp_query(
/// Notify other participants that a new buffer has been created /// Notify other participants that a new buffer has been created
async fn create_buffer_for_peer( async fn create_buffer_for_peer(
request: proto::CreateBufferForPeer, request: proto::CreateBufferForPeer,
session: Session, session: MessageContext,
) -> Result<()> { ) -> Result<()> {
session session
.db() .db()
@ -2418,7 +2460,7 @@ async fn create_buffer_for_peer(
async fn update_buffer( async fn update_buffer(
request: proto::UpdateBuffer, request: proto::UpdateBuffer,
response: Response<proto::UpdateBuffer>, response: Response<proto::UpdateBuffer>,
session: Session, session: MessageContext,
) -> Result<()> { ) -> Result<()> {
let project_id = ProjectId::from_proto(request.project_id); let project_id = ProjectId::from_proto(request.project_id);
let mut capability = Capability::ReadOnly; let mut capability = Capability::ReadOnly;
@ -2453,17 +2495,14 @@ async fn update_buffer(
}; };
if host != session.connection_id { if host != session.connection_id {
session session.forward_request(host, request.clone()).await?;
.peer
.forward_request(session.connection_id, host, request.clone())
.await?;
} }
response.send(proto::Ack {})?; response.send(proto::Ack {})?;
Ok(()) Ok(())
} }
async fn update_context(message: proto::UpdateContext, session: Session) -> Result<()> { async fn update_context(message: proto::UpdateContext, session: MessageContext) -> Result<()> {
let project_id = ProjectId::from_proto(message.project_id); let project_id = ProjectId::from_proto(message.project_id);
let operation = message.operation.as_ref().context("invalid operation")?; let operation = message.operation.as_ref().context("invalid operation")?;
@ -2508,7 +2547,7 @@ async fn update_context(message: proto::UpdateContext, session: Session) -> Resu
/// Notify other participants that a project has been updated. /// Notify other participants that a project has been updated.
async fn broadcast_project_message_from_host<T: EntityMessage<Entity = ShareProject>>( async fn broadcast_project_message_from_host<T: EntityMessage<Entity = ShareProject>>(
request: T, request: T,
session: Session, session: MessageContext,
) -> Result<()> { ) -> Result<()> {
let project_id = ProjectId::from_proto(request.remote_entity_id()); let project_id = ProjectId::from_proto(request.remote_entity_id());
let project_connection_ids = session let project_connection_ids = session
@ -2533,7 +2572,7 @@ async fn broadcast_project_message_from_host<T: EntityMessage<Entity = ShareProj
async fn follow( async fn follow(
request: proto::Follow, request: proto::Follow,
response: Response<proto::Follow>, response: Response<proto::Follow>,
session: Session, session: MessageContext,
) -> Result<()> { ) -> Result<()> {
let room_id = RoomId::from_proto(request.room_id); let room_id = RoomId::from_proto(request.room_id);
let project_id = request.project_id.map(ProjectId::from_proto); let project_id = request.project_id.map(ProjectId::from_proto);
@ -2546,10 +2585,7 @@ async fn follow(
.check_room_participants(room_id, leader_id, session.connection_id) .check_room_participants(room_id, leader_id, session.connection_id)
.await?; .await?;
let response_payload = session let response_payload = session.forward_request(leader_id, request).await?;
.peer
.forward_request(session.connection_id, leader_id, request)
.await?;
response.send(response_payload)?; response.send(response_payload)?;
if let Some(project_id) = project_id { if let Some(project_id) = project_id {
@ -2565,7 +2601,7 @@ async fn follow(
} }
/// Stop following another user in a call. /// Stop following another user in a call.
async fn unfollow(request: proto::Unfollow, session: Session) -> Result<()> { async fn unfollow(request: proto::Unfollow, session: MessageContext) -> Result<()> {
let room_id = RoomId::from_proto(request.room_id); let room_id = RoomId::from_proto(request.room_id);
let project_id = request.project_id.map(ProjectId::from_proto); let project_id = request.project_id.map(ProjectId::from_proto);
let leader_id = request.leader_id.context("invalid leader id")?.into(); let leader_id = request.leader_id.context("invalid leader id")?.into();
@ -2594,7 +2630,7 @@ async fn unfollow(request: proto::Unfollow, session: Session) -> Result<()> {
} }
/// Notify everyone following you of your current location. /// Notify everyone following you of your current location.
async fn update_followers(request: proto::UpdateFollowers, session: Session) -> Result<()> { async fn update_followers(request: proto::UpdateFollowers, session: MessageContext) -> Result<()> {
let room_id = RoomId::from_proto(request.room_id); let room_id = RoomId::from_proto(request.room_id);
let database = session.db.lock().await; let database = session.db.lock().await;
@ -2629,7 +2665,7 @@ async fn update_followers(request: proto::UpdateFollowers, session: Session) ->
async fn get_users( async fn get_users(
request: proto::GetUsers, request: proto::GetUsers,
response: Response<proto::GetUsers>, response: Response<proto::GetUsers>,
session: Session, session: MessageContext,
) -> Result<()> { ) -> Result<()> {
let user_ids = request let user_ids = request
.user_ids .user_ids
@ -2657,7 +2693,7 @@ async fn get_users(
async fn fuzzy_search_users( async fn fuzzy_search_users(
request: proto::FuzzySearchUsers, request: proto::FuzzySearchUsers,
response: Response<proto::FuzzySearchUsers>, response: Response<proto::FuzzySearchUsers>,
session: Session, session: MessageContext,
) -> Result<()> { ) -> Result<()> {
let query = request.query; let query = request.query;
let users = match query.len() { let users = match query.len() {
@ -2689,7 +2725,7 @@ async fn fuzzy_search_users(
async fn request_contact( async fn request_contact(
request: proto::RequestContact, request: proto::RequestContact,
response: Response<proto::RequestContact>, response: Response<proto::RequestContact>,
session: Session, session: MessageContext,
) -> Result<()> { ) -> Result<()> {
let requester_id = session.user_id(); let requester_id = session.user_id();
let responder_id = UserId::from_proto(request.responder_id); let responder_id = UserId::from_proto(request.responder_id);
@ -2736,7 +2772,7 @@ async fn request_contact(
async fn respond_to_contact_request( async fn respond_to_contact_request(
request: proto::RespondToContactRequest, request: proto::RespondToContactRequest,
response: Response<proto::RespondToContactRequest>, response: Response<proto::RespondToContactRequest>,
session: Session, session: MessageContext,
) -> Result<()> { ) -> Result<()> {
let responder_id = session.user_id(); let responder_id = session.user_id();
let requester_id = UserId::from_proto(request.requester_id); let requester_id = UserId::from_proto(request.requester_id);
@ -2794,7 +2830,7 @@ async fn respond_to_contact_request(
async fn remove_contact( async fn remove_contact(
request: proto::RemoveContact, request: proto::RemoveContact,
response: Response<proto::RemoveContact>, response: Response<proto::RemoveContact>,
session: Session, session: MessageContext,
) -> Result<()> { ) -> Result<()> {
let requester_id = session.user_id(); let requester_id = session.user_id();
let responder_id = UserId::from_proto(request.user_id); let responder_id = UserId::from_proto(request.user_id);
@ -3053,7 +3089,10 @@ async fn update_user_plan(session: &Session) -> Result<()> {
Ok(()) Ok(())
} }
async fn subscribe_to_channels(_: proto::SubscribeToChannels, session: Session) -> Result<()> { async fn subscribe_to_channels(
_: proto::SubscribeToChannels,
session: MessageContext,
) -> Result<()> {
subscribe_user_to_channels(session.user_id(), &session).await?; subscribe_user_to_channels(session.user_id(), &session).await?;
Ok(()) Ok(())
} }
@ -3079,7 +3118,7 @@ async fn subscribe_user_to_channels(user_id: UserId, session: &Session) -> Resul
async fn create_channel( async fn create_channel(
request: proto::CreateChannel, request: proto::CreateChannel,
response: Response<proto::CreateChannel>, response: Response<proto::CreateChannel>,
session: Session, session: MessageContext,
) -> Result<()> { ) -> Result<()> {
let db = session.db().await; let db = session.db().await;
@ -3134,7 +3173,7 @@ async fn create_channel(
async fn delete_channel( async fn delete_channel(
request: proto::DeleteChannel, request: proto::DeleteChannel,
response: Response<proto::DeleteChannel>, response: Response<proto::DeleteChannel>,
session: Session, session: MessageContext,
) -> Result<()> { ) -> Result<()> {
let db = session.db().await; let db = session.db().await;
@ -3162,7 +3201,7 @@ async fn delete_channel(
async fn invite_channel_member( async fn invite_channel_member(
request: proto::InviteChannelMember, request: proto::InviteChannelMember,
response: Response<proto::InviteChannelMember>, response: Response<proto::InviteChannelMember>,
session: Session, session: MessageContext,
) -> Result<()> { ) -> Result<()> {
let db = session.db().await; let db = session.db().await;
let channel_id = ChannelId::from_proto(request.channel_id); let channel_id = ChannelId::from_proto(request.channel_id);
@ -3199,7 +3238,7 @@ async fn invite_channel_member(
async fn remove_channel_member( async fn remove_channel_member(
request: proto::RemoveChannelMember, request: proto::RemoveChannelMember,
response: Response<proto::RemoveChannelMember>, response: Response<proto::RemoveChannelMember>,
session: Session, session: MessageContext,
) -> Result<()> { ) -> Result<()> {
let db = session.db().await; let db = session.db().await;
let channel_id = ChannelId::from_proto(request.channel_id); let channel_id = ChannelId::from_proto(request.channel_id);
@ -3243,7 +3282,7 @@ async fn remove_channel_member(
async fn set_channel_visibility( async fn set_channel_visibility(
request: proto::SetChannelVisibility, request: proto::SetChannelVisibility,
response: Response<proto::SetChannelVisibility>, response: Response<proto::SetChannelVisibility>,
session: Session, session: MessageContext,
) -> Result<()> { ) -> Result<()> {
let db = session.db().await; let db = session.db().await;
let channel_id = ChannelId::from_proto(request.channel_id); let channel_id = ChannelId::from_proto(request.channel_id);
@ -3288,7 +3327,7 @@ async fn set_channel_visibility(
async fn set_channel_member_role( async fn set_channel_member_role(
request: proto::SetChannelMemberRole, request: proto::SetChannelMemberRole,
response: Response<proto::SetChannelMemberRole>, response: Response<proto::SetChannelMemberRole>,
session: Session, session: MessageContext,
) -> Result<()> { ) -> Result<()> {
let db = session.db().await; let db = session.db().await;
let channel_id = ChannelId::from_proto(request.channel_id); let channel_id = ChannelId::from_proto(request.channel_id);
@ -3336,7 +3375,7 @@ async fn set_channel_member_role(
async fn rename_channel( async fn rename_channel(
request: proto::RenameChannel, request: proto::RenameChannel,
response: Response<proto::RenameChannel>, response: Response<proto::RenameChannel>,
session: Session, session: MessageContext,
) -> Result<()> { ) -> Result<()> {
let db = session.db().await; let db = session.db().await;
let channel_id = ChannelId::from_proto(request.channel_id); let channel_id = ChannelId::from_proto(request.channel_id);
@ -3368,7 +3407,7 @@ async fn rename_channel(
async fn move_channel( async fn move_channel(
request: proto::MoveChannel, request: proto::MoveChannel,
response: Response<proto::MoveChannel>, response: Response<proto::MoveChannel>,
session: Session, session: MessageContext,
) -> Result<()> { ) -> Result<()> {
let channel_id = ChannelId::from_proto(request.channel_id); let channel_id = ChannelId::from_proto(request.channel_id);
let to = ChannelId::from_proto(request.to); let to = ChannelId::from_proto(request.to);
@ -3410,7 +3449,7 @@ async fn move_channel(
async fn reorder_channel( async fn reorder_channel(
request: proto::ReorderChannel, request: proto::ReorderChannel,
response: Response<proto::ReorderChannel>, response: Response<proto::ReorderChannel>,
session: Session, session: MessageContext,
) -> Result<()> { ) -> Result<()> {
let channel_id = ChannelId::from_proto(request.channel_id); let channel_id = ChannelId::from_proto(request.channel_id);
let direction = request.direction(); let direction = request.direction();
@ -3456,7 +3495,7 @@ async fn reorder_channel(
async fn get_channel_members( async fn get_channel_members(
request: proto::GetChannelMembers, request: proto::GetChannelMembers,
response: Response<proto::GetChannelMembers>, response: Response<proto::GetChannelMembers>,
session: Session, session: MessageContext,
) -> Result<()> { ) -> Result<()> {
let db = session.db().await; let db = session.db().await;
let channel_id = ChannelId::from_proto(request.channel_id); let channel_id = ChannelId::from_proto(request.channel_id);
@ -3476,7 +3515,7 @@ async fn get_channel_members(
async fn respond_to_channel_invite( async fn respond_to_channel_invite(
request: proto::RespondToChannelInvite, request: proto::RespondToChannelInvite,
response: Response<proto::RespondToChannelInvite>, response: Response<proto::RespondToChannelInvite>,
session: Session, session: MessageContext,
) -> Result<()> { ) -> Result<()> {
let db = session.db().await; let db = session.db().await;
let channel_id = ChannelId::from_proto(request.channel_id); let channel_id = ChannelId::from_proto(request.channel_id);
@ -3517,7 +3556,7 @@ async fn respond_to_channel_invite(
async fn join_channel( async fn join_channel(
request: proto::JoinChannel, request: proto::JoinChannel,
response: Response<proto::JoinChannel>, response: Response<proto::JoinChannel>,
session: Session, session: MessageContext,
) -> Result<()> { ) -> Result<()> {
let channel_id = ChannelId::from_proto(request.channel_id); let channel_id = ChannelId::from_proto(request.channel_id);
join_channel_internal(channel_id, Box::new(response), session).await join_channel_internal(channel_id, Box::new(response), session).await
@ -3540,7 +3579,7 @@ impl JoinChannelInternalResponse for Response<proto::JoinRoom> {
async fn join_channel_internal( async fn join_channel_internal(
channel_id: ChannelId, channel_id: ChannelId,
response: Box<impl JoinChannelInternalResponse>, response: Box<impl JoinChannelInternalResponse>,
session: Session, session: MessageContext,
) -> Result<()> { ) -> Result<()> {
let joined_room = { let joined_room = {
let mut db = session.db().await; let mut db = session.db().await;
@ -3635,7 +3674,7 @@ async fn join_channel_internal(
async fn join_channel_buffer( async fn join_channel_buffer(
request: proto::JoinChannelBuffer, request: proto::JoinChannelBuffer,
response: Response<proto::JoinChannelBuffer>, response: Response<proto::JoinChannelBuffer>,
session: Session, session: MessageContext,
) -> Result<()> { ) -> Result<()> {
let db = session.db().await; let db = session.db().await;
let channel_id = ChannelId::from_proto(request.channel_id); let channel_id = ChannelId::from_proto(request.channel_id);
@ -3666,7 +3705,7 @@ async fn join_channel_buffer(
/// Edit the channel notes /// Edit the channel notes
async fn update_channel_buffer( async fn update_channel_buffer(
request: proto::UpdateChannelBuffer, request: proto::UpdateChannelBuffer,
session: Session, session: MessageContext,
) -> Result<()> { ) -> Result<()> {
let db = session.db().await; let db = session.db().await;
let channel_id = ChannelId::from_proto(request.channel_id); let channel_id = ChannelId::from_proto(request.channel_id);
@ -3718,7 +3757,7 @@ async fn update_channel_buffer(
async fn rejoin_channel_buffers( async fn rejoin_channel_buffers(
request: proto::RejoinChannelBuffers, request: proto::RejoinChannelBuffers,
response: Response<proto::RejoinChannelBuffers>, response: Response<proto::RejoinChannelBuffers>,
session: Session, session: MessageContext,
) -> Result<()> { ) -> Result<()> {
let db = session.db().await; let db = session.db().await;
let buffers = db let buffers = db
@ -3753,7 +3792,7 @@ async fn rejoin_channel_buffers(
async fn leave_channel_buffer( async fn leave_channel_buffer(
request: proto::LeaveChannelBuffer, request: proto::LeaveChannelBuffer,
response: Response<proto::LeaveChannelBuffer>, response: Response<proto::LeaveChannelBuffer>,
session: Session, session: MessageContext,
) -> Result<()> { ) -> Result<()> {
let db = session.db().await; let db = session.db().await;
let channel_id = ChannelId::from_proto(request.channel_id); let channel_id = ChannelId::from_proto(request.channel_id);
@ -3815,7 +3854,7 @@ fn send_notifications(
async fn send_channel_message( async fn send_channel_message(
request: proto::SendChannelMessage, request: proto::SendChannelMessage,
response: Response<proto::SendChannelMessage>, response: Response<proto::SendChannelMessage>,
session: Session, session: MessageContext,
) -> Result<()> { ) -> Result<()> {
// Validate the message body. // Validate the message body.
let body = request.body.trim().to_string(); let body = request.body.trim().to_string();
@ -3908,7 +3947,7 @@ async fn send_channel_message(
async fn remove_channel_message( async fn remove_channel_message(
request: proto::RemoveChannelMessage, request: proto::RemoveChannelMessage,
response: Response<proto::RemoveChannelMessage>, response: Response<proto::RemoveChannelMessage>,
session: Session, session: MessageContext,
) -> Result<()> { ) -> Result<()> {
let channel_id = ChannelId::from_proto(request.channel_id); let channel_id = ChannelId::from_proto(request.channel_id);
let message_id = MessageId::from_proto(request.message_id); let message_id = MessageId::from_proto(request.message_id);
@ -3943,7 +3982,7 @@ async fn remove_channel_message(
async fn update_channel_message( async fn update_channel_message(
request: proto::UpdateChannelMessage, request: proto::UpdateChannelMessage,
response: Response<proto::UpdateChannelMessage>, response: Response<proto::UpdateChannelMessage>,
session: Session, session: MessageContext,
) -> Result<()> { ) -> Result<()> {
let channel_id = ChannelId::from_proto(request.channel_id); let channel_id = ChannelId::from_proto(request.channel_id);
let message_id = MessageId::from_proto(request.message_id); let message_id = MessageId::from_proto(request.message_id);
@ -4027,7 +4066,7 @@ async fn update_channel_message(
/// Mark a channel message as read /// Mark a channel message as read
async fn acknowledge_channel_message( async fn acknowledge_channel_message(
request: proto::AckChannelMessage, request: proto::AckChannelMessage,
session: Session, session: MessageContext,
) -> Result<()> { ) -> Result<()> {
let channel_id = ChannelId::from_proto(request.channel_id); let channel_id = ChannelId::from_proto(request.channel_id);
let message_id = MessageId::from_proto(request.message_id); let message_id = MessageId::from_proto(request.message_id);
@ -4047,7 +4086,7 @@ async fn acknowledge_channel_message(
/// Mark a buffer version as synced /// Mark a buffer version as synced
async fn acknowledge_buffer_version( async fn acknowledge_buffer_version(
request: proto::AckBufferOperation, request: proto::AckBufferOperation,
session: Session, session: MessageContext,
) -> Result<()> { ) -> Result<()> {
let buffer_id = BufferId::from_proto(request.buffer_id); let buffer_id = BufferId::from_proto(request.buffer_id);
session session
@ -4067,7 +4106,7 @@ async fn acknowledge_buffer_version(
async fn get_supermaven_api_key( async fn get_supermaven_api_key(
_request: proto::GetSupermavenApiKey, _request: proto::GetSupermavenApiKey,
response: Response<proto::GetSupermavenApiKey>, response: Response<proto::GetSupermavenApiKey>,
session: Session, session: MessageContext,
) -> Result<()> { ) -> Result<()> {
let user_id: String = session.user_id().to_string(); let user_id: String = session.user_id().to_string();
if !session.is_staff() { if !session.is_staff() {
@ -4096,7 +4135,7 @@ async fn get_supermaven_api_key(
async fn join_channel_chat( async fn join_channel_chat(
request: proto::JoinChannelChat, request: proto::JoinChannelChat,
response: Response<proto::JoinChannelChat>, response: Response<proto::JoinChannelChat>,
session: Session, session: MessageContext,
) -> Result<()> { ) -> Result<()> {
let channel_id = ChannelId::from_proto(request.channel_id); let channel_id = ChannelId::from_proto(request.channel_id);
@ -4114,7 +4153,10 @@ async fn join_channel_chat(
} }
/// Stop receiving chat updates for a channel /// Stop receiving chat updates for a channel
async fn leave_channel_chat(request: proto::LeaveChannelChat, session: Session) -> Result<()> { async fn leave_channel_chat(
request: proto::LeaveChannelChat,
session: MessageContext,
) -> Result<()> {
let channel_id = ChannelId::from_proto(request.channel_id); let channel_id = ChannelId::from_proto(request.channel_id);
session session
.db() .db()
@ -4128,7 +4170,7 @@ async fn leave_channel_chat(request: proto::LeaveChannelChat, session: Session)
async fn get_channel_messages( async fn get_channel_messages(
request: proto::GetChannelMessages, request: proto::GetChannelMessages,
response: Response<proto::GetChannelMessages>, response: Response<proto::GetChannelMessages>,
session: Session, session: MessageContext,
) -> Result<()> { ) -> Result<()> {
let channel_id = ChannelId::from_proto(request.channel_id); let channel_id = ChannelId::from_proto(request.channel_id);
let messages = session let messages = session
@ -4152,7 +4194,7 @@ async fn get_channel_messages(
async fn get_channel_messages_by_id( async fn get_channel_messages_by_id(
request: proto::GetChannelMessagesById, request: proto::GetChannelMessagesById,
response: Response<proto::GetChannelMessagesById>, response: Response<proto::GetChannelMessagesById>,
session: Session, session: MessageContext,
) -> Result<()> { ) -> Result<()> {
let message_ids = request let message_ids = request
.message_ids .message_ids
@ -4175,7 +4217,7 @@ async fn get_channel_messages_by_id(
async fn get_notifications( async fn get_notifications(
request: proto::GetNotifications, request: proto::GetNotifications,
response: Response<proto::GetNotifications>, response: Response<proto::GetNotifications>,
session: Session, session: MessageContext,
) -> Result<()> { ) -> Result<()> {
let notifications = session let notifications = session
.db() .db()
@ -4197,7 +4239,7 @@ async fn get_notifications(
async fn mark_notification_as_read( async fn mark_notification_as_read(
request: proto::MarkNotificationRead, request: proto::MarkNotificationRead,
response: Response<proto::MarkNotificationRead>, response: Response<proto::MarkNotificationRead>,
session: Session, session: MessageContext,
) -> Result<()> { ) -> Result<()> {
let database = &session.db().await; let database = &session.db().await;
let notifications = database let notifications = database
@ -4219,7 +4261,7 @@ async fn mark_notification_as_read(
async fn get_private_user_info( async fn get_private_user_info(
_request: proto::GetPrivateUserInfo, _request: proto::GetPrivateUserInfo,
response: Response<proto::GetPrivateUserInfo>, response: Response<proto::GetPrivateUserInfo>,
session: Session, session: MessageContext,
) -> Result<()> { ) -> Result<()> {
let db = session.db().await; let db = session.db().await;
@ -4243,7 +4285,7 @@ async fn get_private_user_info(
async fn accept_terms_of_service( async fn accept_terms_of_service(
_request: proto::AcceptTermsOfService, _request: proto::AcceptTermsOfService,
response: Response<proto::AcceptTermsOfService>, response: Response<proto::AcceptTermsOfService>,
session: Session, session: MessageContext,
) -> Result<()> { ) -> Result<()> {
let db = session.db().await; let db = session.db().await;
@ -4267,7 +4309,7 @@ async fn accept_terms_of_service(
async fn get_llm_api_token( async fn get_llm_api_token(
_request: proto::GetLlmToken, _request: proto::GetLlmToken,
response: Response<proto::GetLlmToken>, response: Response<proto::GetLlmToken>,
session: Session, session: MessageContext,
) -> Result<()> { ) -> Result<()> {
let db = session.db().await; let db = session.db().await;

View file

@ -20,6 +20,7 @@ anthropic = { workspace = true, features = ["schemars"] }
anyhow.workspace = true anyhow.workspace = true
base64.workspace = true base64.workspace = true
client.workspace = true client.workspace = true
cloud_api_types.workspace = true
cloud_llm_client.workspace = true cloud_llm_client.workspace = true
collections.workspace = true collections.workspace = true
futures.workspace = true futures.workspace = true

View file

@ -3,11 +3,9 @@ use std::sync::Arc;
use anyhow::Result; use anyhow::Result;
use client::Client; use client::Client;
use cloud_api_types::websocket_protocol::MessageToClient;
use cloud_llm_client::Plan; use cloud_llm_client::Plan;
use gpui::{ use gpui::{App, AppContext as _, Context, Entity, EventEmitter, Global, ReadGlobal as _};
App, AppContext as _, AsyncApp, Context, Entity, EventEmitter, Global, ReadGlobal as _,
};
use proto::TypedEnvelope;
use smol::lock::{RwLock, RwLockUpgradableReadGuard, RwLockWriteGuard}; use smol::lock::{RwLock, RwLockUpgradableReadGuard, RwLockWriteGuard};
use thiserror::Error; use thiserror::Error;
@ -82,9 +80,7 @@ impl Global for GlobalRefreshLlmTokenListener {}
pub struct RefreshLlmTokenEvent; pub struct RefreshLlmTokenEvent;
pub struct RefreshLlmTokenListener { pub struct RefreshLlmTokenListener;
_llm_token_subscription: client::Subscription,
}
impl EventEmitter<RefreshLlmTokenEvent> for RefreshLlmTokenListener {} impl EventEmitter<RefreshLlmTokenEvent> for RefreshLlmTokenListener {}
@ -99,17 +95,21 @@ impl RefreshLlmTokenListener {
} }
fn new(client: Arc<Client>, cx: &mut Context<Self>) -> Self { fn new(client: Arc<Client>, cx: &mut Context<Self>) -> Self {
Self { client.add_message_to_client_handler({
_llm_token_subscription: client let this = cx.entity();
.add_message_handler(cx.weak_entity(), Self::handle_refresh_llm_token), move |message, cx| {
} Self::handle_refresh_llm_token(this.clone(), message, cx);
}
});
Self
} }
async fn handle_refresh_llm_token( fn handle_refresh_llm_token(this: Entity<Self>, message: &MessageToClient, cx: &mut App) {
this: Entity<Self>, match message {
_: TypedEnvelope<proto::RefreshLlmToken>, MessageToClient::UserUpdated => {
mut cx: AsyncApp, this.update(cx, |_this, cx| cx.emit(RefreshLlmTokenEvent));
) -> Result<()> { }
this.update(&mut cx, |_this, cx| cx.emit(RefreshLlmTokenEvent)) }
} }
} }

View file

@ -674,6 +674,10 @@ pub fn count_open_ai_tokens(
| Model::O3 | Model::O3
| Model::O3Mini | Model::O3Mini
| Model::O4Mini => tiktoken_rs::num_tokens_from_messages(model.id(), &messages), | Model::O4Mini => tiktoken_rs::num_tokens_from_messages(model.id(), &messages),
// GPT-5 models don't have tiktoken support yet; fall back on gpt-4o tokenizer
Model::Five | Model::FiveMini | Model::FiveNano => {
tiktoken_rs::num_tokens_from_messages("gpt-4o", &messages)
}
} }
.map(|tokens| tokens as u64) .map(|tokens| tokens as u64)
}) })

View file

@ -74,6 +74,12 @@ pub enum Model {
O3, O3,
#[serde(rename = "o4-mini")] #[serde(rename = "o4-mini")]
O4Mini, O4Mini,
#[serde(rename = "gpt-5")]
Five,
#[serde(rename = "gpt-5-mini")]
FiveMini,
#[serde(rename = "gpt-5-nano")]
FiveNano,
#[serde(rename = "custom")] #[serde(rename = "custom")]
Custom { Custom {
@ -105,6 +111,9 @@ impl Model {
"o3-mini" => Ok(Self::O3Mini), "o3-mini" => Ok(Self::O3Mini),
"o3" => Ok(Self::O3), "o3" => Ok(Self::O3),
"o4-mini" => Ok(Self::O4Mini), "o4-mini" => Ok(Self::O4Mini),
"gpt-5" => Ok(Self::Five),
"gpt-5-mini" => Ok(Self::FiveMini),
"gpt-5-nano" => Ok(Self::FiveNano),
invalid_id => anyhow::bail!("invalid model id '{invalid_id}'"), invalid_id => anyhow::bail!("invalid model id '{invalid_id}'"),
} }
} }
@ -123,6 +132,9 @@ impl Model {
Self::O3Mini => "o3-mini", Self::O3Mini => "o3-mini",
Self::O3 => "o3", Self::O3 => "o3",
Self::O4Mini => "o4-mini", Self::O4Mini => "o4-mini",
Self::Five => "gpt-5",
Self::FiveMini => "gpt-5-mini",
Self::FiveNano => "gpt-5-nano",
Self::Custom { name, .. } => name, Self::Custom { name, .. } => name,
} }
} }
@ -141,6 +153,9 @@ impl Model {
Self::O3Mini => "o3-mini", Self::O3Mini => "o3-mini",
Self::O3 => "o3", Self::O3 => "o3",
Self::O4Mini => "o4-mini", Self::O4Mini => "o4-mini",
Self::Five => "gpt-5",
Self::FiveMini => "gpt-5-mini",
Self::FiveNano => "gpt-5-nano",
Self::Custom { Self::Custom {
name, display_name, .. name, display_name, ..
} => display_name.as_ref().unwrap_or(name), } => display_name.as_ref().unwrap_or(name),
@ -161,6 +176,9 @@ impl Model {
Self::O3Mini => 200_000, Self::O3Mini => 200_000,
Self::O3 => 200_000, Self::O3 => 200_000,
Self::O4Mini => 200_000, Self::O4Mini => 200_000,
Self::Five => 272_000,
Self::FiveMini => 272_000,
Self::FiveNano => 272_000,
Self::Custom { max_tokens, .. } => *max_tokens, Self::Custom { max_tokens, .. } => *max_tokens,
} }
} }
@ -182,6 +200,9 @@ impl Model {
Self::O3Mini => Some(100_000), Self::O3Mini => Some(100_000),
Self::O3 => Some(100_000), Self::O3 => Some(100_000),
Self::O4Mini => Some(100_000), Self::O4Mini => Some(100_000),
Self::Five => Some(128_000),
Self::FiveMini => Some(128_000),
Self::FiveNano => Some(128_000),
} }
} }
@ -197,7 +218,10 @@ impl Model {
| Self::FourOmniMini | Self::FourOmniMini
| Self::FourPointOne | Self::FourPointOne
| Self::FourPointOneMini | Self::FourPointOneMini
| Self::FourPointOneNano => true, | Self::FourPointOneNano
| Self::Five
| Self::FiveMini
| Self::FiveNano => true,
Self::O1 | Self::O3 | Self::O3Mini | Self::O4Mini | Model::Custom { .. } => false, Self::O1 | Self::O3 | Self::O3Mini | Self::O4Mini | Model::Custom { .. } => false,
} }
} }

View file

@ -422,23 +422,8 @@ impl Peer {
receiver_id: ConnectionId, receiver_id: ConnectionId,
request: T, request: T,
) -> impl Future<Output = Result<T::Response>> { ) -> impl Future<Output = Result<T::Response>> {
let request_start_time = Instant::now();
let elapsed_time = move || request_start_time.elapsed().as_millis();
tracing::info!("start forwarding request");
self.request_internal(Some(sender_id), receiver_id, request) self.request_internal(Some(sender_id), receiver_id, request)
.map_ok(|envelope| envelope.payload) .map_ok(|envelope| envelope.payload)
.inspect_err(move |_| {
tracing::error!(
waiting_for_host_ms = elapsed_time(),
"error forwarding request"
)
})
.inspect_ok(move |_| {
tracing::info!(
waiting_for_host_ms = elapsed_time(),
"finished forwarding request"
)
})
} }
fn request_internal<T: RequestMessage>( fn request_internal<T: RequestMessage>(