Merge branch 'main' into edit-file-tool
This commit is contained in:
commit
26cf0cd9df
20 changed files with 318 additions and 241 deletions
27
.github/actionlint.yml
vendored
27
.github/actionlint.yml
vendored
|
@ -13,22 +13,17 @@ self-hosted-runner:
|
||||||
- windows-2025-16
|
- windows-2025-16
|
||||||
- windows-2025-32
|
- windows-2025-32
|
||||||
- windows-2025-64
|
- windows-2025-64
|
||||||
# Buildjet Ubuntu 20.04 - AMD x86_64
|
# Namespace Ubuntu 20.04 (Release builds)
|
||||||
- buildjet-2vcpu-ubuntu-2004
|
- namespace-profile-16x32-ubuntu-2004
|
||||||
- buildjet-4vcpu-ubuntu-2004
|
- namespace-profile-32x64-ubuntu-2004
|
||||||
- buildjet-8vcpu-ubuntu-2004
|
- namespace-profile-16x32-ubuntu-2004-arm
|
||||||
- buildjet-16vcpu-ubuntu-2004
|
- namespace-profile-32x64-ubuntu-2004-arm
|
||||||
- buildjet-32vcpu-ubuntu-2004
|
# Namespace Ubuntu 22.04 (Everything else)
|
||||||
# Buildjet Ubuntu 22.04 - AMD x86_64
|
- namespace-profile-2x4-ubuntu-2204
|
||||||
- buildjet-2vcpu-ubuntu-2204
|
- namespace-profile-4x8-ubuntu-2204
|
||||||
- buildjet-4vcpu-ubuntu-2204
|
- namespace-profile-8x16-ubuntu-2204
|
||||||
- buildjet-8vcpu-ubuntu-2204
|
- namespace-profile-16x32-ubuntu-2204
|
||||||
- buildjet-16vcpu-ubuntu-2204
|
- namespace-profile-32x64-ubuntu-2204
|
||||||
- buildjet-32vcpu-ubuntu-2204
|
|
||||||
# Buildjet Ubuntu 22.04 - Graviton aarch64
|
|
||||||
- buildjet-8vcpu-ubuntu-2204-arm
|
|
||||||
- buildjet-16vcpu-ubuntu-2204-arm
|
|
||||||
- buildjet-32vcpu-ubuntu-2204-arm
|
|
||||||
# Self Hosted Runners
|
# Self Hosted Runners
|
||||||
- self-mini-macos
|
- self-mini-macos
|
||||||
- self-32vcpu-windows-2022
|
- self-32vcpu-windows-2022
|
||||||
|
|
2
.github/workflows/bump_patch_version.yml
vendored
2
.github/workflows/bump_patch_version.yml
vendored
|
@ -16,7 +16,7 @@ jobs:
|
||||||
bump_patch_version:
|
bump_patch_version:
|
||||||
if: github.repository_owner == 'zed-industries'
|
if: github.repository_owner == 'zed-industries'
|
||||||
runs-on:
|
runs-on:
|
||||||
- github-16vcpu-ubuntu-2204
|
- namespace-profile-16x32-ubuntu-2204
|
||||||
steps:
|
steps:
|
||||||
- name: Checkout code
|
- name: Checkout code
|
||||||
uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4
|
uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4
|
||||||
|
|
15
.github/workflows/ci.yml
vendored
15
.github/workflows/ci.yml
vendored
|
@ -137,7 +137,7 @@ jobs:
|
||||||
github.repository_owner == 'zed-industries' &&
|
github.repository_owner == 'zed-industries' &&
|
||||||
needs.job_spec.outputs.run_tests == 'true'
|
needs.job_spec.outputs.run_tests == 'true'
|
||||||
runs-on:
|
runs-on:
|
||||||
- github-8vcpu-ubuntu-2204
|
- namespace-profile-8x16-ubuntu-2204
|
||||||
steps:
|
steps:
|
||||||
- name: Checkout repo
|
- name: Checkout repo
|
||||||
uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4
|
uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4
|
||||||
|
@ -168,7 +168,7 @@ jobs:
|
||||||
needs: [job_spec]
|
needs: [job_spec]
|
||||||
if: github.repository_owner == 'zed-industries'
|
if: github.repository_owner == 'zed-industries'
|
||||||
runs-on:
|
runs-on:
|
||||||
- github-8vcpu-ubuntu-2204
|
- namespace-profile-4x8-ubuntu-2204
|
||||||
steps:
|
steps:
|
||||||
- name: Checkout repo
|
- name: Checkout repo
|
||||||
uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4
|
uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4
|
||||||
|
@ -221,7 +221,7 @@ jobs:
|
||||||
github.repository_owner == 'zed-industries' &&
|
github.repository_owner == 'zed-industries' &&
|
||||||
(needs.job_spec.outputs.run_tests == 'true' || needs.job_spec.outputs.run_docs == 'true')
|
(needs.job_spec.outputs.run_tests == 'true' || needs.job_spec.outputs.run_docs == 'true')
|
||||||
runs-on:
|
runs-on:
|
||||||
- github-8vcpu-ubuntu-2204
|
- namespace-profile-8x16-ubuntu-2204
|
||||||
steps:
|
steps:
|
||||||
- name: Checkout repo
|
- name: Checkout repo
|
||||||
uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4
|
uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4
|
||||||
|
@ -328,7 +328,7 @@ jobs:
|
||||||
github.repository_owner == 'zed-industries' &&
|
github.repository_owner == 'zed-industries' &&
|
||||||
needs.job_spec.outputs.run_tests == 'true'
|
needs.job_spec.outputs.run_tests == 'true'
|
||||||
runs-on:
|
runs-on:
|
||||||
- github-16vcpu-ubuntu-2204
|
- namespace-profile-16x32-ubuntu-2204
|
||||||
steps:
|
steps:
|
||||||
- name: Add Rust to the PATH
|
- name: Add Rust to the PATH
|
||||||
run: echo "$HOME/.cargo/bin" >> "$GITHUB_PATH"
|
run: echo "$HOME/.cargo/bin" >> "$GITHUB_PATH"
|
||||||
|
@ -380,7 +380,7 @@ jobs:
|
||||||
github.repository_owner == 'zed-industries' &&
|
github.repository_owner == 'zed-industries' &&
|
||||||
needs.job_spec.outputs.run_tests == 'true'
|
needs.job_spec.outputs.run_tests == 'true'
|
||||||
runs-on:
|
runs-on:
|
||||||
- github-8vcpu-ubuntu-2204
|
- namespace-profile-16x32-ubuntu-2204
|
||||||
steps:
|
steps:
|
||||||
- name: Add Rust to the PATH
|
- name: Add Rust to the PATH
|
||||||
run: echo "$HOME/.cargo/bin" >> "$GITHUB_PATH"
|
run: echo "$HOME/.cargo/bin" >> "$GITHUB_PATH"
|
||||||
|
@ -597,8 +597,7 @@ jobs:
|
||||||
timeout-minutes: 60
|
timeout-minutes: 60
|
||||||
name: Linux x86_x64 release bundle
|
name: Linux x86_x64 release bundle
|
||||||
runs-on:
|
runs-on:
|
||||||
- github-16vcpu-ubuntu-2204
|
- namespace-profile-16x32-ubuntu-2004 # ubuntu 20.04 for minimal glibc
|
||||||
# - buildjet-16vcpu-ubuntu-2004 # ubuntu 20.04 for minimal glibc
|
|
||||||
if: |
|
if: |
|
||||||
startsWith(github.ref, 'refs/tags/v')
|
startsWith(github.ref, 'refs/tags/v')
|
||||||
|| contains(github.event.pull_request.labels.*.name, 'run-bundling')
|
|| contains(github.event.pull_request.labels.*.name, 'run-bundling')
|
||||||
|
@ -651,7 +650,7 @@ jobs:
|
||||||
timeout-minutes: 60
|
timeout-minutes: 60
|
||||||
name: Linux arm64 release bundle
|
name: Linux arm64 release bundle
|
||||||
runs-on:
|
runs-on:
|
||||||
- github-16vcpu-ubuntu-2204-arm
|
- namespace-profile-32x64-ubuntu-2004-arm # ubuntu 20.04 for minimal glibc
|
||||||
if: |
|
if: |
|
||||||
startsWith(github.ref, 'refs/tags/v')
|
startsWith(github.ref, 'refs/tags/v')
|
||||||
|| contains(github.event.pull_request.labels.*.name, 'run-bundling')
|
|| contains(github.event.pull_request.labels.*.name, 'run-bundling')
|
||||||
|
|
2
.github/workflows/deploy_cloudflare.yml
vendored
2
.github/workflows/deploy_cloudflare.yml
vendored
|
@ -9,7 +9,7 @@ jobs:
|
||||||
deploy-docs:
|
deploy-docs:
|
||||||
name: Deploy Docs
|
name: Deploy Docs
|
||||||
if: github.repository_owner == 'zed-industries'
|
if: github.repository_owner == 'zed-industries'
|
||||||
runs-on: github-16vcpu-ubuntu-2204
|
runs-on: namespace-profile-16x32-ubuntu-2204
|
||||||
|
|
||||||
steps:
|
steps:
|
||||||
- name: Checkout repo
|
- name: Checkout repo
|
||||||
|
|
4
.github/workflows/deploy_collab.yml
vendored
4
.github/workflows/deploy_collab.yml
vendored
|
@ -61,7 +61,7 @@ jobs:
|
||||||
- style
|
- style
|
||||||
- tests
|
- tests
|
||||||
runs-on:
|
runs-on:
|
||||||
- github-16vcpu-ubuntu-2204
|
- namespace-profile-16x32-ubuntu-2204
|
||||||
steps:
|
steps:
|
||||||
- name: Install doctl
|
- name: Install doctl
|
||||||
uses: digitalocean/action-doctl@v2
|
uses: digitalocean/action-doctl@v2
|
||||||
|
@ -94,7 +94,7 @@ jobs:
|
||||||
needs:
|
needs:
|
||||||
- publish
|
- publish
|
||||||
runs-on:
|
runs-on:
|
||||||
- github-16vcpu-ubuntu-2204
|
- namespace-profile-16x32-ubuntu-2204
|
||||||
|
|
||||||
steps:
|
steps:
|
||||||
- name: Checkout repo
|
- name: Checkout repo
|
||||||
|
|
2
.github/workflows/eval.yml
vendored
2
.github/workflows/eval.yml
vendored
|
@ -32,7 +32,7 @@ jobs:
|
||||||
github.repository_owner == 'zed-industries' &&
|
github.repository_owner == 'zed-industries' &&
|
||||||
(github.event_name != 'pull_request' || contains(github.event.pull_request.labels.*.name, 'run-eval'))
|
(github.event_name != 'pull_request' || contains(github.event.pull_request.labels.*.name, 'run-eval'))
|
||||||
runs-on:
|
runs-on:
|
||||||
- github-16vcpu-ubuntu-2204
|
- namespace-profile-16x32-ubuntu-2204
|
||||||
steps:
|
steps:
|
||||||
- name: Add Rust to the PATH
|
- name: Add Rust to the PATH
|
||||||
run: echo "$HOME/.cargo/bin" >> "$GITHUB_PATH"
|
run: echo "$HOME/.cargo/bin" >> "$GITHUB_PATH"
|
||||||
|
|
2
.github/workflows/nix.yml
vendored
2
.github/workflows/nix.yml
vendored
|
@ -20,7 +20,7 @@ jobs:
|
||||||
matrix:
|
matrix:
|
||||||
system:
|
system:
|
||||||
- os: x86 Linux
|
- os: x86 Linux
|
||||||
runner: github-16vcpu-ubuntu-2204
|
runner: namespace-profile-16x32-ubuntu-2204
|
||||||
install_nix: true
|
install_nix: true
|
||||||
- os: arm Mac
|
- os: arm Mac
|
||||||
runner: [macOS, ARM64, test]
|
runner: [macOS, ARM64, test]
|
||||||
|
|
2
.github/workflows/randomized_tests.yml
vendored
2
.github/workflows/randomized_tests.yml
vendored
|
@ -20,7 +20,7 @@ jobs:
|
||||||
name: Run randomized tests
|
name: Run randomized tests
|
||||||
if: github.repository_owner == 'zed-industries'
|
if: github.repository_owner == 'zed-industries'
|
||||||
runs-on:
|
runs-on:
|
||||||
- github-16vcpu-ubuntu-2204
|
- namespace-profile-16x32-ubuntu-2204
|
||||||
steps:
|
steps:
|
||||||
- name: Install Node
|
- name: Install Node
|
||||||
uses: actions/setup-node@49933ea5288caeca8642d1e84afbd3f7d6820020 # v4
|
uses: actions/setup-node@49933ea5288caeca8642d1e84afbd3f7d6820020 # v4
|
||||||
|
|
5
.github/workflows/release_nightly.yml
vendored
5
.github/workflows/release_nightly.yml
vendored
|
@ -128,8 +128,7 @@ jobs:
|
||||||
name: Create a Linux *.tar.gz bundle for x86
|
name: Create a Linux *.tar.gz bundle for x86
|
||||||
if: github.repository_owner == 'zed-industries'
|
if: github.repository_owner == 'zed-industries'
|
||||||
runs-on:
|
runs-on:
|
||||||
- github-16vcpu-ubuntu-2204
|
- namespace-profile-16x32-ubuntu-2004 # ubuntu 20.04 for minimal glibc
|
||||||
# - buildjet-16vcpu-ubuntu-2004
|
|
||||||
needs: tests
|
needs: tests
|
||||||
steps:
|
steps:
|
||||||
- name: Checkout repo
|
- name: Checkout repo
|
||||||
|
@ -169,7 +168,7 @@ jobs:
|
||||||
name: Create a Linux *.tar.gz bundle for ARM
|
name: Create a Linux *.tar.gz bundle for ARM
|
||||||
if: github.repository_owner == 'zed-industries'
|
if: github.repository_owner == 'zed-industries'
|
||||||
runs-on:
|
runs-on:
|
||||||
- github-16vcpu-ubuntu-2204-arm
|
- namespace-profile-32x64-ubuntu-2004-arm # ubuntu 20.04 for minimal glibc
|
||||||
needs: tests
|
needs: tests
|
||||||
steps:
|
steps:
|
||||||
- name: Checkout repo
|
- name: Checkout repo
|
||||||
|
|
2
.github/workflows/unit_evals.yml
vendored
2
.github/workflows/unit_evals.yml
vendored
|
@ -23,7 +23,7 @@ jobs:
|
||||||
timeout-minutes: 60
|
timeout-minutes: 60
|
||||||
name: Run unit evals
|
name: Run unit evals
|
||||||
runs-on:
|
runs-on:
|
||||||
- github-16vcpu-ubuntu-2204
|
- namespace-profile-16x32-ubuntu-2204
|
||||||
steps:
|
steps:
|
||||||
- name: Add Rust to the PATH
|
- name: Add Rust to the PATH
|
||||||
run: echo "$HOME/.cargo/bin" >> "$GITHUB_PATH"
|
run: echo "$HOME/.cargo/bin" >> "$GITHUB_PATH"
|
||||||
|
|
1
Cargo.lock
generated
1
Cargo.lock
generated
|
@ -9132,6 +9132,7 @@ dependencies = [
|
||||||
"anyhow",
|
"anyhow",
|
||||||
"base64 0.22.1",
|
"base64 0.22.1",
|
||||||
"client",
|
"client",
|
||||||
|
"cloud_api_types",
|
||||||
"cloud_llm_client",
|
"cloud_llm_client",
|
||||||
"collections",
|
"collections",
|
||||||
"futures 0.3.31",
|
"futures 0.3.31",
|
||||||
|
|
|
@ -6,7 +6,7 @@ use context_server::listener::McpServerTool;
|
||||||
use project::Project;
|
use project::Project;
|
||||||
use settings::SettingsStore;
|
use settings::SettingsStore;
|
||||||
use smol::process::Child;
|
use smol::process::Child;
|
||||||
use std::cell::{Cell, RefCell};
|
use std::cell::RefCell;
|
||||||
use std::fmt::Display;
|
use std::fmt::Display;
|
||||||
use std::path::Path;
|
use std::path::Path;
|
||||||
use std::rc::Rc;
|
use std::rc::Rc;
|
||||||
|
@ -153,20 +153,17 @@ impl AgentConnection for ClaudeAgentConnection {
|
||||||
})
|
})
|
||||||
.detach();
|
.detach();
|
||||||
|
|
||||||
let pending_cancellation = Rc::new(Cell::new(PendingCancellation::None));
|
let turn_state = Rc::new(RefCell::new(TurnState::None));
|
||||||
|
|
||||||
let end_turn_tx = Rc::new(RefCell::new(None));
|
|
||||||
let handler_task = cx.spawn({
|
let handler_task = cx.spawn({
|
||||||
let end_turn_tx = end_turn_tx.clone();
|
let turn_state = turn_state.clone();
|
||||||
let mut thread_rx = thread_rx.clone();
|
let mut thread_rx = thread_rx.clone();
|
||||||
let cancellation_state = pending_cancellation.clone();
|
|
||||||
async move |cx| {
|
async move |cx| {
|
||||||
while let Some(message) = incoming_message_rx.next().await {
|
while let Some(message) = incoming_message_rx.next().await {
|
||||||
ClaudeAgentSession::handle_message(
|
ClaudeAgentSession::handle_message(
|
||||||
thread_rx.clone(),
|
thread_rx.clone(),
|
||||||
message,
|
message,
|
||||||
end_turn_tx.clone(),
|
turn_state.clone(),
|
||||||
cancellation_state.clone(),
|
|
||||||
cx,
|
cx,
|
||||||
)
|
)
|
||||||
.await
|
.await
|
||||||
|
@ -192,8 +189,7 @@ impl AgentConnection for ClaudeAgentConnection {
|
||||||
|
|
||||||
let session = ClaudeAgentSession {
|
let session = ClaudeAgentSession {
|
||||||
outgoing_tx,
|
outgoing_tx,
|
||||||
end_turn_tx,
|
turn_state,
|
||||||
pending_cancellation,
|
|
||||||
_handler_task: handler_task,
|
_handler_task: handler_task,
|
||||||
_mcp_server: Some(permission_mcp_server),
|
_mcp_server: Some(permission_mcp_server),
|
||||||
};
|
};
|
||||||
|
@ -225,8 +221,8 @@ impl AgentConnection for ClaudeAgentConnection {
|
||||||
)));
|
)));
|
||||||
};
|
};
|
||||||
|
|
||||||
let (tx, rx) = oneshot::channel();
|
let (end_tx, end_rx) = oneshot::channel();
|
||||||
session.end_turn_tx.borrow_mut().replace(tx);
|
session.turn_state.replace(TurnState::InProgress { end_tx });
|
||||||
|
|
||||||
let mut content = String::new();
|
let mut content = String::new();
|
||||||
for chunk in params.prompt {
|
for chunk in params.prompt {
|
||||||
|
@ -260,12 +256,7 @@ impl AgentConnection for ClaudeAgentConnection {
|
||||||
return Task::ready(Err(anyhow!(err)));
|
return Task::ready(Err(anyhow!(err)));
|
||||||
}
|
}
|
||||||
|
|
||||||
let cancellation_state = session.pending_cancellation.clone();
|
cx.foreground_executor().spawn(async move { end_rx.await? })
|
||||||
cx.foreground_executor().spawn(async move {
|
|
||||||
let result = rx.await??;
|
|
||||||
cancellation_state.set(PendingCancellation::None);
|
|
||||||
Ok(result)
|
|
||||||
})
|
|
||||||
}
|
}
|
||||||
|
|
||||||
fn cancel(&self, session_id: &acp::SessionId, _cx: &mut App) {
|
fn cancel(&self, session_id: &acp::SessionId, _cx: &mut App) {
|
||||||
|
@ -277,7 +268,15 @@ impl AgentConnection for ClaudeAgentConnection {
|
||||||
|
|
||||||
let request_id = new_request_id();
|
let request_id = new_request_id();
|
||||||
|
|
||||||
session.pending_cancellation.set(PendingCancellation::Sent {
|
let turn_state = session.turn_state.take();
|
||||||
|
let TurnState::InProgress { end_tx } = turn_state else {
|
||||||
|
// Already cancelled or idle, put it back
|
||||||
|
session.turn_state.replace(turn_state);
|
||||||
|
return;
|
||||||
|
};
|
||||||
|
|
||||||
|
session.turn_state.replace(TurnState::CancelRequested {
|
||||||
|
end_tx,
|
||||||
request_id: request_id.clone(),
|
request_id: request_id.clone(),
|
||||||
});
|
});
|
||||||
|
|
||||||
|
@ -349,28 +348,56 @@ fn spawn_claude(
|
||||||
|
|
||||||
struct ClaudeAgentSession {
|
struct ClaudeAgentSession {
|
||||||
outgoing_tx: UnboundedSender<SdkMessage>,
|
outgoing_tx: UnboundedSender<SdkMessage>,
|
||||||
end_turn_tx: Rc<RefCell<Option<oneshot::Sender<Result<acp::PromptResponse>>>>>,
|
turn_state: Rc<RefCell<TurnState>>,
|
||||||
pending_cancellation: Rc<Cell<PendingCancellation>>,
|
|
||||||
_mcp_server: Option<ClaudeZedMcpServer>,
|
_mcp_server: Option<ClaudeZedMcpServer>,
|
||||||
_handler_task: Task<()>,
|
_handler_task: Task<()>,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Default, PartialEq)]
|
#[derive(Debug, Default)]
|
||||||
enum PendingCancellation {
|
enum TurnState {
|
||||||
#[default]
|
#[default]
|
||||||
None,
|
None,
|
||||||
Sent {
|
InProgress {
|
||||||
|
end_tx: oneshot::Sender<Result<acp::PromptResponse>>,
|
||||||
|
},
|
||||||
|
CancelRequested {
|
||||||
|
end_tx: oneshot::Sender<Result<acp::PromptResponse>>,
|
||||||
request_id: String,
|
request_id: String,
|
||||||
},
|
},
|
||||||
Confirmed,
|
CancelConfirmed {
|
||||||
|
end_tx: oneshot::Sender<Result<acp::PromptResponse>>,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
impl TurnState {
|
||||||
|
fn is_cancelled(&self) -> bool {
|
||||||
|
matches!(self, TurnState::CancelConfirmed { .. })
|
||||||
|
}
|
||||||
|
|
||||||
|
fn end_tx(self) -> Option<oneshot::Sender<Result<acp::PromptResponse>>> {
|
||||||
|
match self {
|
||||||
|
TurnState::None => None,
|
||||||
|
TurnState::InProgress { end_tx, .. } => Some(end_tx),
|
||||||
|
TurnState::CancelRequested { end_tx, .. } => Some(end_tx),
|
||||||
|
TurnState::CancelConfirmed { end_tx } => Some(end_tx),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn confirm_cancellation(self, id: &str) -> Self {
|
||||||
|
match self {
|
||||||
|
TurnState::CancelRequested { request_id, end_tx } if request_id == id => {
|
||||||
|
TurnState::CancelConfirmed { end_tx }
|
||||||
|
}
|
||||||
|
_ => self,
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl ClaudeAgentSession {
|
impl ClaudeAgentSession {
|
||||||
async fn handle_message(
|
async fn handle_message(
|
||||||
mut thread_rx: watch::Receiver<WeakEntity<AcpThread>>,
|
mut thread_rx: watch::Receiver<WeakEntity<AcpThread>>,
|
||||||
message: SdkMessage,
|
message: SdkMessage,
|
||||||
end_turn_tx: Rc<RefCell<Option<oneshot::Sender<Result<acp::PromptResponse>>>>>,
|
turn_state: Rc<RefCell<TurnState>>,
|
||||||
pending_cancellation: Rc<Cell<PendingCancellation>>,
|
|
||||||
cx: &mut AsyncApp,
|
cx: &mut AsyncApp,
|
||||||
) {
|
) {
|
||||||
match message {
|
match message {
|
||||||
|
@ -393,15 +420,13 @@ impl ClaudeAgentSession {
|
||||||
for chunk in message.content.chunks() {
|
for chunk in message.content.chunks() {
|
||||||
match chunk {
|
match chunk {
|
||||||
ContentChunk::Text { text } | ContentChunk::UntaggedText(text) => {
|
ContentChunk::Text { text } | ContentChunk::UntaggedText(text) => {
|
||||||
let state = pending_cancellation.take();
|
if !turn_state.borrow().is_cancelled() {
|
||||||
if state != PendingCancellation::Confirmed {
|
|
||||||
thread
|
thread
|
||||||
.update(cx, |thread, cx| {
|
.update(cx, |thread, cx| {
|
||||||
thread.push_user_content_block(text.into(), cx)
|
thread.push_user_content_block(text.into(), cx)
|
||||||
})
|
})
|
||||||
.log_err();
|
.log_err();
|
||||||
}
|
}
|
||||||
pending_cancellation.set(state);
|
|
||||||
}
|
}
|
||||||
ContentChunk::ToolResult {
|
ContentChunk::ToolResult {
|
||||||
content,
|
content,
|
||||||
|
@ -414,7 +439,12 @@ impl ClaudeAgentSession {
|
||||||
acp::ToolCallUpdate {
|
acp::ToolCallUpdate {
|
||||||
id: acp::ToolCallId(tool_use_id.into()),
|
id: acp::ToolCallId(tool_use_id.into()),
|
||||||
fields: acp::ToolCallUpdateFields {
|
fields: acp::ToolCallUpdateFields {
|
||||||
status: Some(acp::ToolCallStatus::Completed),
|
status: if turn_state.borrow().is_cancelled() {
|
||||||
|
// Do not set to completed if turn was cancelled
|
||||||
|
None
|
||||||
|
} else {
|
||||||
|
Some(acp::ToolCallStatus::Completed)
|
||||||
|
},
|
||||||
content: (!content.is_empty())
|
content: (!content.is_empty())
|
||||||
.then(|| vec![content.into()]),
|
.then(|| vec![content.into()]),
|
||||||
..Default::default()
|
..Default::default()
|
||||||
|
@ -541,40 +571,38 @@ impl ClaudeAgentSession {
|
||||||
result,
|
result,
|
||||||
..
|
..
|
||||||
} => {
|
} => {
|
||||||
if let Some(end_turn_tx) = end_turn_tx.borrow_mut().take() {
|
let turn_state = turn_state.take();
|
||||||
if is_error
|
let was_cancelled = turn_state.is_cancelled();
|
||||||
|| (subtype == ResultErrorType::ErrorDuringExecution
|
let Some(end_turn_tx) = turn_state.end_tx() else {
|
||||||
&& pending_cancellation.take() != PendingCancellation::Confirmed)
|
debug_panic!("Received `SdkMessage::Result` but there wasn't an active turn");
|
||||||
{
|
return;
|
||||||
end_turn_tx
|
};
|
||||||
.send(Err(anyhow!(
|
|
||||||
"Error: {}",
|
if is_error || (!was_cancelled && subtype == ResultErrorType::ErrorDuringExecution)
|
||||||
result.unwrap_or_else(|| subtype.to_string())
|
{
|
||||||
)))
|
end_turn_tx
|
||||||
.ok();
|
.send(Err(anyhow!(
|
||||||
} else {
|
"Error: {}",
|
||||||
let stop_reason = match subtype {
|
result.unwrap_or_else(|| subtype.to_string())
|
||||||
ResultErrorType::Success => acp::StopReason::EndTurn,
|
)))
|
||||||
ResultErrorType::ErrorMaxTurns => acp::StopReason::MaxTurnRequests,
|
.ok();
|
||||||
ResultErrorType::ErrorDuringExecution => acp::StopReason::Cancelled,
|
} else {
|
||||||
};
|
let stop_reason = match subtype {
|
||||||
end_turn_tx
|
ResultErrorType::Success => acp::StopReason::EndTurn,
|
||||||
.send(Ok(acp::PromptResponse { stop_reason }))
|
ResultErrorType::ErrorMaxTurns => acp::StopReason::MaxTurnRequests,
|
||||||
.ok();
|
ResultErrorType::ErrorDuringExecution => acp::StopReason::Cancelled,
|
||||||
}
|
};
|
||||||
|
end_turn_tx
|
||||||
|
.send(Ok(acp::PromptResponse { stop_reason }))
|
||||||
|
.ok();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
SdkMessage::ControlResponse { response } => {
|
SdkMessage::ControlResponse { response } => {
|
||||||
if matches!(response.subtype, ResultErrorType::Success) {
|
if matches!(response.subtype, ResultErrorType::Success) {
|
||||||
let pending_cancellation_value = pending_cancellation.take();
|
let new_state = turn_state.take().confirm_cancellation(&response.request_id);
|
||||||
|
turn_state.replace(new_state);
|
||||||
if let PendingCancellation::Sent { request_id } = &pending_cancellation_value
|
} else {
|
||||||
&& request_id == &response.request_id
|
log::error!("Control response error: {:?}", response);
|
||||||
{
|
|
||||||
pending_cancellation.set(PendingCancellation::Confirmed);
|
|
||||||
} else {
|
|
||||||
pending_cancellation.set(pending_cancellation_value);
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
SdkMessage::System { .. } => {}
|
SdkMessage::System { .. } => {}
|
||||||
|
|
|
@ -246,7 +246,7 @@ pub async fn test_cancel(server: impl AgentServer + 'static, cx: &mut TestAppCon
|
||||||
|
|
||||||
let project = Project::test(fs, [path!("/private/tmp").as_ref()], cx).await;
|
let project = Project::test(fs, [path!("/private/tmp").as_ref()], cx).await;
|
||||||
let thread = new_test_thread(server, project.clone(), "/private/tmp", cx).await;
|
let thread = new_test_thread(server, project.clone(), "/private/tmp", cx).await;
|
||||||
let full_turn = thread.update(cx, |thread, cx| {
|
let _ = thread.update(cx, |thread, cx| {
|
||||||
thread.send_raw(
|
thread.send_raw(
|
||||||
r#"Run exactly `touch hello.txt && echo "Hello, world!" | tee hello.txt` in the terminal."#,
|
r#"Run exactly `touch hello.txt && echo "Hello, world!" | tee hello.txt` in the terminal."#,
|
||||||
cx,
|
cx,
|
||||||
|
@ -285,9 +285,8 @@ pub async fn test_cancel(server: impl AgentServer + 'static, cx: &mut TestAppCon
|
||||||
id.clone()
|
id.clone()
|
||||||
});
|
});
|
||||||
|
|
||||||
let _ = thread.update(cx, |thread, cx| thread.cancel(cx));
|
thread.update(cx, |thread, cx| thread.cancel(cx)).await;
|
||||||
full_turn.await.unwrap();
|
thread.read_with(cx, |thread, _cx| {
|
||||||
thread.read_with(cx, |thread, _| {
|
|
||||||
let AgentThreadEntry::ToolCall(ToolCall {
|
let AgentThreadEntry::ToolCall(ToolCall {
|
||||||
status: ToolCallStatus::Canceled,
|
status: ToolCallStatus::Canceled,
|
||||||
..
|
..
|
||||||
|
|
|
@ -193,7 +193,7 @@ pub fn init(client: &Arc<Client>, cx: &mut App) {
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
pub type MessageToClientHandler = Box<dyn Fn(&MessageToClient, &App) + Send + Sync + 'static>;
|
pub type MessageToClientHandler = Box<dyn Fn(&MessageToClient, &mut App) + Send + Sync + 'static>;
|
||||||
|
|
||||||
struct GlobalClient(Arc<Client>);
|
struct GlobalClient(Arc<Client>);
|
||||||
|
|
||||||
|
@ -1684,7 +1684,7 @@ impl Client {
|
||||||
|
|
||||||
pub fn add_message_to_client_handler(
|
pub fn add_message_to_client_handler(
|
||||||
self: &Arc<Client>,
|
self: &Arc<Client>,
|
||||||
handler: impl Fn(&MessageToClient, &App) + Send + Sync + 'static,
|
handler: impl Fn(&MessageToClient, &mut App) + Send + Sync + 'static,
|
||||||
) {
|
) {
|
||||||
self.message_to_client_handlers
|
self.message_to_client_handlers
|
||||||
.lock()
|
.lock()
|
||||||
|
|
|
@ -41,9 +41,11 @@ use chrono::Utc;
|
||||||
use collections::{HashMap, HashSet};
|
use collections::{HashMap, HashSet};
|
||||||
pub use connection_pool::{ConnectionPool, ZedVersion};
|
pub use connection_pool::{ConnectionPool, ZedVersion};
|
||||||
use core::fmt::{self, Debug, Formatter};
|
use core::fmt::{self, Debug, Formatter};
|
||||||
|
use futures::TryFutureExt as _;
|
||||||
use reqwest_client::ReqwestClient;
|
use reqwest_client::ReqwestClient;
|
||||||
use rpc::proto::{MultiLspQuery, split_repository_update};
|
use rpc::proto::{MultiLspQuery, split_repository_update};
|
||||||
use supermaven_api::{CreateExternalUserRequest, SupermavenAdminApi};
|
use supermaven_api::{CreateExternalUserRequest, SupermavenAdminApi};
|
||||||
|
use tracing::Span;
|
||||||
|
|
||||||
use futures::{
|
use futures::{
|
||||||
FutureExt, SinkExt, StreamExt, TryStreamExt, channel::oneshot, future::BoxFuture,
|
FutureExt, SinkExt, StreamExt, TryStreamExt, channel::oneshot, future::BoxFuture,
|
||||||
|
@ -94,8 +96,13 @@ const MAX_CONCURRENT_CONNECTIONS: usize = 512;
|
||||||
|
|
||||||
static CONCURRENT_CONNECTIONS: AtomicUsize = AtomicUsize::new(0);
|
static CONCURRENT_CONNECTIONS: AtomicUsize = AtomicUsize::new(0);
|
||||||
|
|
||||||
|
const TOTAL_DURATION_MS: &str = "total_duration_ms";
|
||||||
|
const PROCESSING_DURATION_MS: &str = "processing_duration_ms";
|
||||||
|
const QUEUE_DURATION_MS: &str = "queue_duration_ms";
|
||||||
|
const HOST_WAITING_MS: &str = "host_waiting_ms";
|
||||||
|
|
||||||
type MessageHandler =
|
type MessageHandler =
|
||||||
Box<dyn Send + Sync + Fn(Box<dyn AnyTypedEnvelope>, Session) -> BoxFuture<'static, ()>>;
|
Box<dyn Send + Sync + Fn(Box<dyn AnyTypedEnvelope>, Session, Span) -> BoxFuture<'static, ()>>;
|
||||||
|
|
||||||
pub struct ConnectionGuard;
|
pub struct ConnectionGuard;
|
||||||
|
|
||||||
|
@ -163,6 +170,42 @@ impl Principal {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[derive(Clone)]
|
||||||
|
struct MessageContext {
|
||||||
|
session: Session,
|
||||||
|
span: tracing::Span,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Deref for MessageContext {
|
||||||
|
type Target = Session;
|
||||||
|
|
||||||
|
fn deref(&self) -> &Self::Target {
|
||||||
|
&self.session
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl MessageContext {
|
||||||
|
pub fn forward_request<T: RequestMessage>(
|
||||||
|
&self,
|
||||||
|
receiver_id: ConnectionId,
|
||||||
|
request: T,
|
||||||
|
) -> impl Future<Output = anyhow::Result<T::Response>> {
|
||||||
|
let request_start_time = Instant::now();
|
||||||
|
let span = self.span.clone();
|
||||||
|
tracing::info!("start forwarding request");
|
||||||
|
self.peer
|
||||||
|
.forward_request(self.connection_id, receiver_id, request)
|
||||||
|
.inspect(move |_| {
|
||||||
|
span.record(
|
||||||
|
HOST_WAITING_MS,
|
||||||
|
request_start_time.elapsed().as_micros() as f64 / 1000.0,
|
||||||
|
);
|
||||||
|
})
|
||||||
|
.inspect_err(|_| tracing::error!("error forwarding request"))
|
||||||
|
.inspect_ok(|_| tracing::info!("finished forwarding request"))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
#[derive(Clone)]
|
#[derive(Clone)]
|
||||||
struct Session {
|
struct Session {
|
||||||
principal: Principal,
|
principal: Principal,
|
||||||
|
@ -646,40 +689,37 @@ impl Server {
|
||||||
|
|
||||||
fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
|
fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
|
||||||
where
|
where
|
||||||
F: 'static + Send + Sync + Fn(TypedEnvelope<M>, Session) -> Fut,
|
F: 'static + Send + Sync + Fn(TypedEnvelope<M>, MessageContext) -> Fut,
|
||||||
Fut: 'static + Send + Future<Output = Result<()>>,
|
Fut: 'static + Send + Future<Output = Result<()>>,
|
||||||
M: EnvelopedMessage,
|
M: EnvelopedMessage,
|
||||||
{
|
{
|
||||||
let prev_handler = self.handlers.insert(
|
let prev_handler = self.handlers.insert(
|
||||||
TypeId::of::<M>(),
|
TypeId::of::<M>(),
|
||||||
Box::new(move |envelope, session| {
|
Box::new(move |envelope, session, span| {
|
||||||
let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
|
let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
|
||||||
let received_at = envelope.received_at;
|
let received_at = envelope.received_at;
|
||||||
tracing::info!("message received");
|
tracing::info!("message received");
|
||||||
let start_time = Instant::now();
|
let start_time = Instant::now();
|
||||||
let future = (handler)(*envelope, session);
|
let future = (handler)(
|
||||||
|
*envelope,
|
||||||
|
MessageContext {
|
||||||
|
session,
|
||||||
|
span: span.clone(),
|
||||||
|
},
|
||||||
|
);
|
||||||
async move {
|
async move {
|
||||||
let result = future.await;
|
let result = future.await;
|
||||||
let total_duration_ms = received_at.elapsed().as_micros() as f64 / 1000.0;
|
let total_duration_ms = received_at.elapsed().as_micros() as f64 / 1000.0;
|
||||||
let processing_duration_ms = start_time.elapsed().as_micros() as f64 / 1000.0;
|
let processing_duration_ms = start_time.elapsed().as_micros() as f64 / 1000.0;
|
||||||
let queue_duration_ms = total_duration_ms - processing_duration_ms;
|
let queue_duration_ms = total_duration_ms - processing_duration_ms;
|
||||||
|
span.record(TOTAL_DURATION_MS, total_duration_ms);
|
||||||
|
span.record(PROCESSING_DURATION_MS, processing_duration_ms);
|
||||||
|
span.record(QUEUE_DURATION_MS, queue_duration_ms);
|
||||||
match result {
|
match result {
|
||||||
Err(error) => {
|
Err(error) => {
|
||||||
tracing::error!(
|
tracing::error!(?error, "error handling message")
|
||||||
?error,
|
|
||||||
total_duration_ms,
|
|
||||||
processing_duration_ms,
|
|
||||||
queue_duration_ms,
|
|
||||||
"error handling message"
|
|
||||||
)
|
|
||||||
}
|
}
|
||||||
Ok(()) => tracing::info!(
|
Ok(()) => tracing::info!("finished handling message"),
|
||||||
total_duration_ms,
|
|
||||||
processing_duration_ms,
|
|
||||||
queue_duration_ms,
|
|
||||||
"finished handling message"
|
|
||||||
),
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
.boxed()
|
.boxed()
|
||||||
|
@ -693,7 +733,7 @@ impl Server {
|
||||||
|
|
||||||
fn add_message_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
|
fn add_message_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
|
||||||
where
|
where
|
||||||
F: 'static + Send + Sync + Fn(M, Session) -> Fut,
|
F: 'static + Send + Sync + Fn(M, MessageContext) -> Fut,
|
||||||
Fut: 'static + Send + Future<Output = Result<()>>,
|
Fut: 'static + Send + Future<Output = Result<()>>,
|
||||||
M: EnvelopedMessage,
|
M: EnvelopedMessage,
|
||||||
{
|
{
|
||||||
|
@ -703,7 +743,7 @@ impl Server {
|
||||||
|
|
||||||
fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
|
fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
|
||||||
where
|
where
|
||||||
F: 'static + Send + Sync + Fn(M, Response<M>, Session) -> Fut,
|
F: 'static + Send + Sync + Fn(M, Response<M>, MessageContext) -> Fut,
|
||||||
Fut: Send + Future<Output = Result<()>>,
|
Fut: Send + Future<Output = Result<()>>,
|
||||||
M: RequestMessage,
|
M: RequestMessage,
|
||||||
{
|
{
|
||||||
|
@ -889,12 +929,16 @@ impl Server {
|
||||||
login=field::Empty,
|
login=field::Empty,
|
||||||
impersonator=field::Empty,
|
impersonator=field::Empty,
|
||||||
multi_lsp_query_request=field::Empty,
|
multi_lsp_query_request=field::Empty,
|
||||||
|
{ TOTAL_DURATION_MS }=field::Empty,
|
||||||
|
{ PROCESSING_DURATION_MS }=field::Empty,
|
||||||
|
{ QUEUE_DURATION_MS }=field::Empty,
|
||||||
|
{ HOST_WAITING_MS }=field::Empty
|
||||||
);
|
);
|
||||||
principal.update_span(&span);
|
principal.update_span(&span);
|
||||||
let span_enter = span.enter();
|
let span_enter = span.enter();
|
||||||
if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
|
if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
|
||||||
let is_background = message.is_background();
|
let is_background = message.is_background();
|
||||||
let handle_message = (handler)(message, session.clone());
|
let handle_message = (handler)(message, session.clone(), span.clone());
|
||||||
drop(span_enter);
|
drop(span_enter);
|
||||||
|
|
||||||
let handle_message = async move {
|
let handle_message = async move {
|
||||||
|
@ -1386,7 +1430,11 @@ async fn connection_lost(
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Acknowledges a ping from a client, used to keep the connection alive.
|
/// Acknowledges a ping from a client, used to keep the connection alive.
|
||||||
async fn ping(_: proto::Ping, response: Response<proto::Ping>, _session: Session) -> Result<()> {
|
async fn ping(
|
||||||
|
_: proto::Ping,
|
||||||
|
response: Response<proto::Ping>,
|
||||||
|
_session: MessageContext,
|
||||||
|
) -> Result<()> {
|
||||||
response.send(proto::Ack {})?;
|
response.send(proto::Ack {})?;
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
@ -1395,7 +1443,7 @@ async fn ping(_: proto::Ping, response: Response<proto::Ping>, _session: Session
|
||||||
async fn create_room(
|
async fn create_room(
|
||||||
_request: proto::CreateRoom,
|
_request: proto::CreateRoom,
|
||||||
response: Response<proto::CreateRoom>,
|
response: Response<proto::CreateRoom>,
|
||||||
session: Session,
|
session: MessageContext,
|
||||||
) -> Result<()> {
|
) -> Result<()> {
|
||||||
let livekit_room = nanoid::nanoid!(30);
|
let livekit_room = nanoid::nanoid!(30);
|
||||||
|
|
||||||
|
@ -1435,7 +1483,7 @@ async fn create_room(
|
||||||
async fn join_room(
|
async fn join_room(
|
||||||
request: proto::JoinRoom,
|
request: proto::JoinRoom,
|
||||||
response: Response<proto::JoinRoom>,
|
response: Response<proto::JoinRoom>,
|
||||||
session: Session,
|
session: MessageContext,
|
||||||
) -> Result<()> {
|
) -> Result<()> {
|
||||||
let room_id = RoomId::from_proto(request.id);
|
let room_id = RoomId::from_proto(request.id);
|
||||||
|
|
||||||
|
@ -1502,7 +1550,7 @@ async fn join_room(
|
||||||
async fn rejoin_room(
|
async fn rejoin_room(
|
||||||
request: proto::RejoinRoom,
|
request: proto::RejoinRoom,
|
||||||
response: Response<proto::RejoinRoom>,
|
response: Response<proto::RejoinRoom>,
|
||||||
session: Session,
|
session: MessageContext,
|
||||||
) -> Result<()> {
|
) -> Result<()> {
|
||||||
let room;
|
let room;
|
||||||
let channel;
|
let channel;
|
||||||
|
@ -1679,7 +1727,7 @@ fn notify_rejoined_projects(
|
||||||
async fn leave_room(
|
async fn leave_room(
|
||||||
_: proto::LeaveRoom,
|
_: proto::LeaveRoom,
|
||||||
response: Response<proto::LeaveRoom>,
|
response: Response<proto::LeaveRoom>,
|
||||||
session: Session,
|
session: MessageContext,
|
||||||
) -> Result<()> {
|
) -> Result<()> {
|
||||||
leave_room_for_session(&session, session.connection_id).await?;
|
leave_room_for_session(&session, session.connection_id).await?;
|
||||||
response.send(proto::Ack {})?;
|
response.send(proto::Ack {})?;
|
||||||
|
@ -1690,7 +1738,7 @@ async fn leave_room(
|
||||||
async fn set_room_participant_role(
|
async fn set_room_participant_role(
|
||||||
request: proto::SetRoomParticipantRole,
|
request: proto::SetRoomParticipantRole,
|
||||||
response: Response<proto::SetRoomParticipantRole>,
|
response: Response<proto::SetRoomParticipantRole>,
|
||||||
session: Session,
|
session: MessageContext,
|
||||||
) -> Result<()> {
|
) -> Result<()> {
|
||||||
let user_id = UserId::from_proto(request.user_id);
|
let user_id = UserId::from_proto(request.user_id);
|
||||||
let role = ChannelRole::from(request.role());
|
let role = ChannelRole::from(request.role());
|
||||||
|
@ -1738,7 +1786,7 @@ async fn set_room_participant_role(
|
||||||
async fn call(
|
async fn call(
|
||||||
request: proto::Call,
|
request: proto::Call,
|
||||||
response: Response<proto::Call>,
|
response: Response<proto::Call>,
|
||||||
session: Session,
|
session: MessageContext,
|
||||||
) -> Result<()> {
|
) -> Result<()> {
|
||||||
let room_id = RoomId::from_proto(request.room_id);
|
let room_id = RoomId::from_proto(request.room_id);
|
||||||
let calling_user_id = session.user_id();
|
let calling_user_id = session.user_id();
|
||||||
|
@ -1807,7 +1855,7 @@ async fn call(
|
||||||
async fn cancel_call(
|
async fn cancel_call(
|
||||||
request: proto::CancelCall,
|
request: proto::CancelCall,
|
||||||
response: Response<proto::CancelCall>,
|
response: Response<proto::CancelCall>,
|
||||||
session: Session,
|
session: MessageContext,
|
||||||
) -> Result<()> {
|
) -> Result<()> {
|
||||||
let called_user_id = UserId::from_proto(request.called_user_id);
|
let called_user_id = UserId::from_proto(request.called_user_id);
|
||||||
let room_id = RoomId::from_proto(request.room_id);
|
let room_id = RoomId::from_proto(request.room_id);
|
||||||
|
@ -1842,7 +1890,7 @@ async fn cancel_call(
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Decline an incoming call.
|
/// Decline an incoming call.
|
||||||
async fn decline_call(message: proto::DeclineCall, session: Session) -> Result<()> {
|
async fn decline_call(message: proto::DeclineCall, session: MessageContext) -> Result<()> {
|
||||||
let room_id = RoomId::from_proto(message.room_id);
|
let room_id = RoomId::from_proto(message.room_id);
|
||||||
{
|
{
|
||||||
let room = session
|
let room = session
|
||||||
|
@ -1877,7 +1925,7 @@ async fn decline_call(message: proto::DeclineCall, session: Session) -> Result<(
|
||||||
async fn update_participant_location(
|
async fn update_participant_location(
|
||||||
request: proto::UpdateParticipantLocation,
|
request: proto::UpdateParticipantLocation,
|
||||||
response: Response<proto::UpdateParticipantLocation>,
|
response: Response<proto::UpdateParticipantLocation>,
|
||||||
session: Session,
|
session: MessageContext,
|
||||||
) -> Result<()> {
|
) -> Result<()> {
|
||||||
let room_id = RoomId::from_proto(request.room_id);
|
let room_id = RoomId::from_proto(request.room_id);
|
||||||
let location = request.location.context("invalid location")?;
|
let location = request.location.context("invalid location")?;
|
||||||
|
@ -1896,7 +1944,7 @@ async fn update_participant_location(
|
||||||
async fn share_project(
|
async fn share_project(
|
||||||
request: proto::ShareProject,
|
request: proto::ShareProject,
|
||||||
response: Response<proto::ShareProject>,
|
response: Response<proto::ShareProject>,
|
||||||
session: Session,
|
session: MessageContext,
|
||||||
) -> Result<()> {
|
) -> Result<()> {
|
||||||
let (project_id, room) = &*session
|
let (project_id, room) = &*session
|
||||||
.db()
|
.db()
|
||||||
|
@ -1917,7 +1965,7 @@ async fn share_project(
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Unshare a project from the room.
|
/// Unshare a project from the room.
|
||||||
async fn unshare_project(message: proto::UnshareProject, session: Session) -> Result<()> {
|
async fn unshare_project(message: proto::UnshareProject, session: MessageContext) -> Result<()> {
|
||||||
let project_id = ProjectId::from_proto(message.project_id);
|
let project_id = ProjectId::from_proto(message.project_id);
|
||||||
unshare_project_internal(project_id, session.connection_id, &session).await
|
unshare_project_internal(project_id, session.connection_id, &session).await
|
||||||
}
|
}
|
||||||
|
@ -1964,7 +2012,7 @@ async fn unshare_project_internal(
|
||||||
async fn join_project(
|
async fn join_project(
|
||||||
request: proto::JoinProject,
|
request: proto::JoinProject,
|
||||||
response: Response<proto::JoinProject>,
|
response: Response<proto::JoinProject>,
|
||||||
session: Session,
|
session: MessageContext,
|
||||||
) -> Result<()> {
|
) -> Result<()> {
|
||||||
let project_id = ProjectId::from_proto(request.project_id);
|
let project_id = ProjectId::from_proto(request.project_id);
|
||||||
|
|
||||||
|
@ -2111,7 +2159,7 @@ async fn join_project(
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Leave someone elses shared project.
|
/// Leave someone elses shared project.
|
||||||
async fn leave_project(request: proto::LeaveProject, session: Session) -> Result<()> {
|
async fn leave_project(request: proto::LeaveProject, session: MessageContext) -> Result<()> {
|
||||||
let sender_id = session.connection_id;
|
let sender_id = session.connection_id;
|
||||||
let project_id = ProjectId::from_proto(request.project_id);
|
let project_id = ProjectId::from_proto(request.project_id);
|
||||||
let db = session.db().await;
|
let db = session.db().await;
|
||||||
|
@ -2134,7 +2182,7 @@ async fn leave_project(request: proto::LeaveProject, session: Session) -> Result
|
||||||
async fn update_project(
|
async fn update_project(
|
||||||
request: proto::UpdateProject,
|
request: proto::UpdateProject,
|
||||||
response: Response<proto::UpdateProject>,
|
response: Response<proto::UpdateProject>,
|
||||||
session: Session,
|
session: MessageContext,
|
||||||
) -> Result<()> {
|
) -> Result<()> {
|
||||||
let project_id = ProjectId::from_proto(request.project_id);
|
let project_id = ProjectId::from_proto(request.project_id);
|
||||||
let (room, guest_connection_ids) = &*session
|
let (room, guest_connection_ids) = &*session
|
||||||
|
@ -2163,7 +2211,7 @@ async fn update_project(
|
||||||
async fn update_worktree(
|
async fn update_worktree(
|
||||||
request: proto::UpdateWorktree,
|
request: proto::UpdateWorktree,
|
||||||
response: Response<proto::UpdateWorktree>,
|
response: Response<proto::UpdateWorktree>,
|
||||||
session: Session,
|
session: MessageContext,
|
||||||
) -> Result<()> {
|
) -> Result<()> {
|
||||||
let guest_connection_ids = session
|
let guest_connection_ids = session
|
||||||
.db()
|
.db()
|
||||||
|
@ -2187,7 +2235,7 @@ async fn update_worktree(
|
||||||
async fn update_repository(
|
async fn update_repository(
|
||||||
request: proto::UpdateRepository,
|
request: proto::UpdateRepository,
|
||||||
response: Response<proto::UpdateRepository>,
|
response: Response<proto::UpdateRepository>,
|
||||||
session: Session,
|
session: MessageContext,
|
||||||
) -> Result<()> {
|
) -> Result<()> {
|
||||||
let guest_connection_ids = session
|
let guest_connection_ids = session
|
||||||
.db()
|
.db()
|
||||||
|
@ -2211,7 +2259,7 @@ async fn update_repository(
|
||||||
async fn remove_repository(
|
async fn remove_repository(
|
||||||
request: proto::RemoveRepository,
|
request: proto::RemoveRepository,
|
||||||
response: Response<proto::RemoveRepository>,
|
response: Response<proto::RemoveRepository>,
|
||||||
session: Session,
|
session: MessageContext,
|
||||||
) -> Result<()> {
|
) -> Result<()> {
|
||||||
let guest_connection_ids = session
|
let guest_connection_ids = session
|
||||||
.db()
|
.db()
|
||||||
|
@ -2235,7 +2283,7 @@ async fn remove_repository(
|
||||||
/// Updates other participants with changes to the diagnostics
|
/// Updates other participants with changes to the diagnostics
|
||||||
async fn update_diagnostic_summary(
|
async fn update_diagnostic_summary(
|
||||||
message: proto::UpdateDiagnosticSummary,
|
message: proto::UpdateDiagnosticSummary,
|
||||||
session: Session,
|
session: MessageContext,
|
||||||
) -> Result<()> {
|
) -> Result<()> {
|
||||||
let guest_connection_ids = session
|
let guest_connection_ids = session
|
||||||
.db()
|
.db()
|
||||||
|
@ -2259,7 +2307,7 @@ async fn update_diagnostic_summary(
|
||||||
/// Updates other participants with changes to the worktree settings
|
/// Updates other participants with changes to the worktree settings
|
||||||
async fn update_worktree_settings(
|
async fn update_worktree_settings(
|
||||||
message: proto::UpdateWorktreeSettings,
|
message: proto::UpdateWorktreeSettings,
|
||||||
session: Session,
|
session: MessageContext,
|
||||||
) -> Result<()> {
|
) -> Result<()> {
|
||||||
let guest_connection_ids = session
|
let guest_connection_ids = session
|
||||||
.db()
|
.db()
|
||||||
|
@ -2283,7 +2331,7 @@ async fn update_worktree_settings(
|
||||||
/// Notify other participants that a language server has started.
|
/// Notify other participants that a language server has started.
|
||||||
async fn start_language_server(
|
async fn start_language_server(
|
||||||
request: proto::StartLanguageServer,
|
request: proto::StartLanguageServer,
|
||||||
session: Session,
|
session: MessageContext,
|
||||||
) -> Result<()> {
|
) -> Result<()> {
|
||||||
let guest_connection_ids = session
|
let guest_connection_ids = session
|
||||||
.db()
|
.db()
|
||||||
|
@ -2306,7 +2354,7 @@ async fn start_language_server(
|
||||||
/// Notify other participants that a language server has changed.
|
/// Notify other participants that a language server has changed.
|
||||||
async fn update_language_server(
|
async fn update_language_server(
|
||||||
request: proto::UpdateLanguageServer,
|
request: proto::UpdateLanguageServer,
|
||||||
session: Session,
|
session: MessageContext,
|
||||||
) -> Result<()> {
|
) -> Result<()> {
|
||||||
let project_id = ProjectId::from_proto(request.project_id);
|
let project_id = ProjectId::from_proto(request.project_id);
|
||||||
let db = session.db().await;
|
let db = session.db().await;
|
||||||
|
@ -2339,7 +2387,7 @@ async fn update_language_server(
|
||||||
async fn forward_read_only_project_request<T>(
|
async fn forward_read_only_project_request<T>(
|
||||||
request: T,
|
request: T,
|
||||||
response: Response<T>,
|
response: Response<T>,
|
||||||
session: Session,
|
session: MessageContext,
|
||||||
) -> Result<()>
|
) -> Result<()>
|
||||||
where
|
where
|
||||||
T: EntityMessage + RequestMessage,
|
T: EntityMessage + RequestMessage,
|
||||||
|
@ -2350,10 +2398,7 @@ where
|
||||||
.await
|
.await
|
||||||
.host_for_read_only_project_request(project_id, session.connection_id)
|
.host_for_read_only_project_request(project_id, session.connection_id)
|
||||||
.await?;
|
.await?;
|
||||||
let payload = session
|
let payload = session.forward_request(host_connection_id, request).await?;
|
||||||
.peer
|
|
||||||
.forward_request(session.connection_id, host_connection_id, request)
|
|
||||||
.await?;
|
|
||||||
response.send(payload)?;
|
response.send(payload)?;
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
@ -2363,7 +2408,7 @@ where
|
||||||
async fn forward_mutating_project_request<T>(
|
async fn forward_mutating_project_request<T>(
|
||||||
request: T,
|
request: T,
|
||||||
response: Response<T>,
|
response: Response<T>,
|
||||||
session: Session,
|
session: MessageContext,
|
||||||
) -> Result<()>
|
) -> Result<()>
|
||||||
where
|
where
|
||||||
T: EntityMessage + RequestMessage,
|
T: EntityMessage + RequestMessage,
|
||||||
|
@ -2375,10 +2420,7 @@ where
|
||||||
.await
|
.await
|
||||||
.host_for_mutating_project_request(project_id, session.connection_id)
|
.host_for_mutating_project_request(project_id, session.connection_id)
|
||||||
.await?;
|
.await?;
|
||||||
let payload = session
|
let payload = session.forward_request(host_connection_id, request).await?;
|
||||||
.peer
|
|
||||||
.forward_request(session.connection_id, host_connection_id, request)
|
|
||||||
.await?;
|
|
||||||
response.send(payload)?;
|
response.send(payload)?;
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
@ -2386,7 +2428,7 @@ where
|
||||||
async fn multi_lsp_query(
|
async fn multi_lsp_query(
|
||||||
request: MultiLspQuery,
|
request: MultiLspQuery,
|
||||||
response: Response<MultiLspQuery>,
|
response: Response<MultiLspQuery>,
|
||||||
session: Session,
|
session: MessageContext,
|
||||||
) -> Result<()> {
|
) -> Result<()> {
|
||||||
tracing::Span::current().record("multi_lsp_query_request", request.request_str());
|
tracing::Span::current().record("multi_lsp_query_request", request.request_str());
|
||||||
tracing::info!("multi_lsp_query message received");
|
tracing::info!("multi_lsp_query message received");
|
||||||
|
@ -2396,7 +2438,7 @@ async fn multi_lsp_query(
|
||||||
/// Notify other participants that a new buffer has been created
|
/// Notify other participants that a new buffer has been created
|
||||||
async fn create_buffer_for_peer(
|
async fn create_buffer_for_peer(
|
||||||
request: proto::CreateBufferForPeer,
|
request: proto::CreateBufferForPeer,
|
||||||
session: Session,
|
session: MessageContext,
|
||||||
) -> Result<()> {
|
) -> Result<()> {
|
||||||
session
|
session
|
||||||
.db()
|
.db()
|
||||||
|
@ -2418,7 +2460,7 @@ async fn create_buffer_for_peer(
|
||||||
async fn update_buffer(
|
async fn update_buffer(
|
||||||
request: proto::UpdateBuffer,
|
request: proto::UpdateBuffer,
|
||||||
response: Response<proto::UpdateBuffer>,
|
response: Response<proto::UpdateBuffer>,
|
||||||
session: Session,
|
session: MessageContext,
|
||||||
) -> Result<()> {
|
) -> Result<()> {
|
||||||
let project_id = ProjectId::from_proto(request.project_id);
|
let project_id = ProjectId::from_proto(request.project_id);
|
||||||
let mut capability = Capability::ReadOnly;
|
let mut capability = Capability::ReadOnly;
|
||||||
|
@ -2453,17 +2495,14 @@ async fn update_buffer(
|
||||||
};
|
};
|
||||||
|
|
||||||
if host != session.connection_id {
|
if host != session.connection_id {
|
||||||
session
|
session.forward_request(host, request.clone()).await?;
|
||||||
.peer
|
|
||||||
.forward_request(session.connection_id, host, request.clone())
|
|
||||||
.await?;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
response.send(proto::Ack {})?;
|
response.send(proto::Ack {})?;
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn update_context(message: proto::UpdateContext, session: Session) -> Result<()> {
|
async fn update_context(message: proto::UpdateContext, session: MessageContext) -> Result<()> {
|
||||||
let project_id = ProjectId::from_proto(message.project_id);
|
let project_id = ProjectId::from_proto(message.project_id);
|
||||||
|
|
||||||
let operation = message.operation.as_ref().context("invalid operation")?;
|
let operation = message.operation.as_ref().context("invalid operation")?;
|
||||||
|
@ -2508,7 +2547,7 @@ async fn update_context(message: proto::UpdateContext, session: Session) -> Resu
|
||||||
/// Notify other participants that a project has been updated.
|
/// Notify other participants that a project has been updated.
|
||||||
async fn broadcast_project_message_from_host<T: EntityMessage<Entity = ShareProject>>(
|
async fn broadcast_project_message_from_host<T: EntityMessage<Entity = ShareProject>>(
|
||||||
request: T,
|
request: T,
|
||||||
session: Session,
|
session: MessageContext,
|
||||||
) -> Result<()> {
|
) -> Result<()> {
|
||||||
let project_id = ProjectId::from_proto(request.remote_entity_id());
|
let project_id = ProjectId::from_proto(request.remote_entity_id());
|
||||||
let project_connection_ids = session
|
let project_connection_ids = session
|
||||||
|
@ -2533,7 +2572,7 @@ async fn broadcast_project_message_from_host<T: EntityMessage<Entity = ShareProj
|
||||||
async fn follow(
|
async fn follow(
|
||||||
request: proto::Follow,
|
request: proto::Follow,
|
||||||
response: Response<proto::Follow>,
|
response: Response<proto::Follow>,
|
||||||
session: Session,
|
session: MessageContext,
|
||||||
) -> Result<()> {
|
) -> Result<()> {
|
||||||
let room_id = RoomId::from_proto(request.room_id);
|
let room_id = RoomId::from_proto(request.room_id);
|
||||||
let project_id = request.project_id.map(ProjectId::from_proto);
|
let project_id = request.project_id.map(ProjectId::from_proto);
|
||||||
|
@ -2546,10 +2585,7 @@ async fn follow(
|
||||||
.check_room_participants(room_id, leader_id, session.connection_id)
|
.check_room_participants(room_id, leader_id, session.connection_id)
|
||||||
.await?;
|
.await?;
|
||||||
|
|
||||||
let response_payload = session
|
let response_payload = session.forward_request(leader_id, request).await?;
|
||||||
.peer
|
|
||||||
.forward_request(session.connection_id, leader_id, request)
|
|
||||||
.await?;
|
|
||||||
response.send(response_payload)?;
|
response.send(response_payload)?;
|
||||||
|
|
||||||
if let Some(project_id) = project_id {
|
if let Some(project_id) = project_id {
|
||||||
|
@ -2565,7 +2601,7 @@ async fn follow(
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Stop following another user in a call.
|
/// Stop following another user in a call.
|
||||||
async fn unfollow(request: proto::Unfollow, session: Session) -> Result<()> {
|
async fn unfollow(request: proto::Unfollow, session: MessageContext) -> Result<()> {
|
||||||
let room_id = RoomId::from_proto(request.room_id);
|
let room_id = RoomId::from_proto(request.room_id);
|
||||||
let project_id = request.project_id.map(ProjectId::from_proto);
|
let project_id = request.project_id.map(ProjectId::from_proto);
|
||||||
let leader_id = request.leader_id.context("invalid leader id")?.into();
|
let leader_id = request.leader_id.context("invalid leader id")?.into();
|
||||||
|
@ -2594,7 +2630,7 @@ async fn unfollow(request: proto::Unfollow, session: Session) -> Result<()> {
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Notify everyone following you of your current location.
|
/// Notify everyone following you of your current location.
|
||||||
async fn update_followers(request: proto::UpdateFollowers, session: Session) -> Result<()> {
|
async fn update_followers(request: proto::UpdateFollowers, session: MessageContext) -> Result<()> {
|
||||||
let room_id = RoomId::from_proto(request.room_id);
|
let room_id = RoomId::from_proto(request.room_id);
|
||||||
let database = session.db.lock().await;
|
let database = session.db.lock().await;
|
||||||
|
|
||||||
|
@ -2629,7 +2665,7 @@ async fn update_followers(request: proto::UpdateFollowers, session: Session) ->
|
||||||
async fn get_users(
|
async fn get_users(
|
||||||
request: proto::GetUsers,
|
request: proto::GetUsers,
|
||||||
response: Response<proto::GetUsers>,
|
response: Response<proto::GetUsers>,
|
||||||
session: Session,
|
session: MessageContext,
|
||||||
) -> Result<()> {
|
) -> Result<()> {
|
||||||
let user_ids = request
|
let user_ids = request
|
||||||
.user_ids
|
.user_ids
|
||||||
|
@ -2657,7 +2693,7 @@ async fn get_users(
|
||||||
async fn fuzzy_search_users(
|
async fn fuzzy_search_users(
|
||||||
request: proto::FuzzySearchUsers,
|
request: proto::FuzzySearchUsers,
|
||||||
response: Response<proto::FuzzySearchUsers>,
|
response: Response<proto::FuzzySearchUsers>,
|
||||||
session: Session,
|
session: MessageContext,
|
||||||
) -> Result<()> {
|
) -> Result<()> {
|
||||||
let query = request.query;
|
let query = request.query;
|
||||||
let users = match query.len() {
|
let users = match query.len() {
|
||||||
|
@ -2689,7 +2725,7 @@ async fn fuzzy_search_users(
|
||||||
async fn request_contact(
|
async fn request_contact(
|
||||||
request: proto::RequestContact,
|
request: proto::RequestContact,
|
||||||
response: Response<proto::RequestContact>,
|
response: Response<proto::RequestContact>,
|
||||||
session: Session,
|
session: MessageContext,
|
||||||
) -> Result<()> {
|
) -> Result<()> {
|
||||||
let requester_id = session.user_id();
|
let requester_id = session.user_id();
|
||||||
let responder_id = UserId::from_proto(request.responder_id);
|
let responder_id = UserId::from_proto(request.responder_id);
|
||||||
|
@ -2736,7 +2772,7 @@ async fn request_contact(
|
||||||
async fn respond_to_contact_request(
|
async fn respond_to_contact_request(
|
||||||
request: proto::RespondToContactRequest,
|
request: proto::RespondToContactRequest,
|
||||||
response: Response<proto::RespondToContactRequest>,
|
response: Response<proto::RespondToContactRequest>,
|
||||||
session: Session,
|
session: MessageContext,
|
||||||
) -> Result<()> {
|
) -> Result<()> {
|
||||||
let responder_id = session.user_id();
|
let responder_id = session.user_id();
|
||||||
let requester_id = UserId::from_proto(request.requester_id);
|
let requester_id = UserId::from_proto(request.requester_id);
|
||||||
|
@ -2794,7 +2830,7 @@ async fn respond_to_contact_request(
|
||||||
async fn remove_contact(
|
async fn remove_contact(
|
||||||
request: proto::RemoveContact,
|
request: proto::RemoveContact,
|
||||||
response: Response<proto::RemoveContact>,
|
response: Response<proto::RemoveContact>,
|
||||||
session: Session,
|
session: MessageContext,
|
||||||
) -> Result<()> {
|
) -> Result<()> {
|
||||||
let requester_id = session.user_id();
|
let requester_id = session.user_id();
|
||||||
let responder_id = UserId::from_proto(request.user_id);
|
let responder_id = UserId::from_proto(request.user_id);
|
||||||
|
@ -3053,7 +3089,10 @@ async fn update_user_plan(session: &Session) -> Result<()> {
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn subscribe_to_channels(_: proto::SubscribeToChannels, session: Session) -> Result<()> {
|
async fn subscribe_to_channels(
|
||||||
|
_: proto::SubscribeToChannels,
|
||||||
|
session: MessageContext,
|
||||||
|
) -> Result<()> {
|
||||||
subscribe_user_to_channels(session.user_id(), &session).await?;
|
subscribe_user_to_channels(session.user_id(), &session).await?;
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
@ -3079,7 +3118,7 @@ async fn subscribe_user_to_channels(user_id: UserId, session: &Session) -> Resul
|
||||||
async fn create_channel(
|
async fn create_channel(
|
||||||
request: proto::CreateChannel,
|
request: proto::CreateChannel,
|
||||||
response: Response<proto::CreateChannel>,
|
response: Response<proto::CreateChannel>,
|
||||||
session: Session,
|
session: MessageContext,
|
||||||
) -> Result<()> {
|
) -> Result<()> {
|
||||||
let db = session.db().await;
|
let db = session.db().await;
|
||||||
|
|
||||||
|
@ -3134,7 +3173,7 @@ async fn create_channel(
|
||||||
async fn delete_channel(
|
async fn delete_channel(
|
||||||
request: proto::DeleteChannel,
|
request: proto::DeleteChannel,
|
||||||
response: Response<proto::DeleteChannel>,
|
response: Response<proto::DeleteChannel>,
|
||||||
session: Session,
|
session: MessageContext,
|
||||||
) -> Result<()> {
|
) -> Result<()> {
|
||||||
let db = session.db().await;
|
let db = session.db().await;
|
||||||
|
|
||||||
|
@ -3162,7 +3201,7 @@ async fn delete_channel(
|
||||||
async fn invite_channel_member(
|
async fn invite_channel_member(
|
||||||
request: proto::InviteChannelMember,
|
request: proto::InviteChannelMember,
|
||||||
response: Response<proto::InviteChannelMember>,
|
response: Response<proto::InviteChannelMember>,
|
||||||
session: Session,
|
session: MessageContext,
|
||||||
) -> Result<()> {
|
) -> Result<()> {
|
||||||
let db = session.db().await;
|
let db = session.db().await;
|
||||||
let channel_id = ChannelId::from_proto(request.channel_id);
|
let channel_id = ChannelId::from_proto(request.channel_id);
|
||||||
|
@ -3199,7 +3238,7 @@ async fn invite_channel_member(
|
||||||
async fn remove_channel_member(
|
async fn remove_channel_member(
|
||||||
request: proto::RemoveChannelMember,
|
request: proto::RemoveChannelMember,
|
||||||
response: Response<proto::RemoveChannelMember>,
|
response: Response<proto::RemoveChannelMember>,
|
||||||
session: Session,
|
session: MessageContext,
|
||||||
) -> Result<()> {
|
) -> Result<()> {
|
||||||
let db = session.db().await;
|
let db = session.db().await;
|
||||||
let channel_id = ChannelId::from_proto(request.channel_id);
|
let channel_id = ChannelId::from_proto(request.channel_id);
|
||||||
|
@ -3243,7 +3282,7 @@ async fn remove_channel_member(
|
||||||
async fn set_channel_visibility(
|
async fn set_channel_visibility(
|
||||||
request: proto::SetChannelVisibility,
|
request: proto::SetChannelVisibility,
|
||||||
response: Response<proto::SetChannelVisibility>,
|
response: Response<proto::SetChannelVisibility>,
|
||||||
session: Session,
|
session: MessageContext,
|
||||||
) -> Result<()> {
|
) -> Result<()> {
|
||||||
let db = session.db().await;
|
let db = session.db().await;
|
||||||
let channel_id = ChannelId::from_proto(request.channel_id);
|
let channel_id = ChannelId::from_proto(request.channel_id);
|
||||||
|
@ -3288,7 +3327,7 @@ async fn set_channel_visibility(
|
||||||
async fn set_channel_member_role(
|
async fn set_channel_member_role(
|
||||||
request: proto::SetChannelMemberRole,
|
request: proto::SetChannelMemberRole,
|
||||||
response: Response<proto::SetChannelMemberRole>,
|
response: Response<proto::SetChannelMemberRole>,
|
||||||
session: Session,
|
session: MessageContext,
|
||||||
) -> Result<()> {
|
) -> Result<()> {
|
||||||
let db = session.db().await;
|
let db = session.db().await;
|
||||||
let channel_id = ChannelId::from_proto(request.channel_id);
|
let channel_id = ChannelId::from_proto(request.channel_id);
|
||||||
|
@ -3336,7 +3375,7 @@ async fn set_channel_member_role(
|
||||||
async fn rename_channel(
|
async fn rename_channel(
|
||||||
request: proto::RenameChannel,
|
request: proto::RenameChannel,
|
||||||
response: Response<proto::RenameChannel>,
|
response: Response<proto::RenameChannel>,
|
||||||
session: Session,
|
session: MessageContext,
|
||||||
) -> Result<()> {
|
) -> Result<()> {
|
||||||
let db = session.db().await;
|
let db = session.db().await;
|
||||||
let channel_id = ChannelId::from_proto(request.channel_id);
|
let channel_id = ChannelId::from_proto(request.channel_id);
|
||||||
|
@ -3368,7 +3407,7 @@ async fn rename_channel(
|
||||||
async fn move_channel(
|
async fn move_channel(
|
||||||
request: proto::MoveChannel,
|
request: proto::MoveChannel,
|
||||||
response: Response<proto::MoveChannel>,
|
response: Response<proto::MoveChannel>,
|
||||||
session: Session,
|
session: MessageContext,
|
||||||
) -> Result<()> {
|
) -> Result<()> {
|
||||||
let channel_id = ChannelId::from_proto(request.channel_id);
|
let channel_id = ChannelId::from_proto(request.channel_id);
|
||||||
let to = ChannelId::from_proto(request.to);
|
let to = ChannelId::from_proto(request.to);
|
||||||
|
@ -3410,7 +3449,7 @@ async fn move_channel(
|
||||||
async fn reorder_channel(
|
async fn reorder_channel(
|
||||||
request: proto::ReorderChannel,
|
request: proto::ReorderChannel,
|
||||||
response: Response<proto::ReorderChannel>,
|
response: Response<proto::ReorderChannel>,
|
||||||
session: Session,
|
session: MessageContext,
|
||||||
) -> Result<()> {
|
) -> Result<()> {
|
||||||
let channel_id = ChannelId::from_proto(request.channel_id);
|
let channel_id = ChannelId::from_proto(request.channel_id);
|
||||||
let direction = request.direction();
|
let direction = request.direction();
|
||||||
|
@ -3456,7 +3495,7 @@ async fn reorder_channel(
|
||||||
async fn get_channel_members(
|
async fn get_channel_members(
|
||||||
request: proto::GetChannelMembers,
|
request: proto::GetChannelMembers,
|
||||||
response: Response<proto::GetChannelMembers>,
|
response: Response<proto::GetChannelMembers>,
|
||||||
session: Session,
|
session: MessageContext,
|
||||||
) -> Result<()> {
|
) -> Result<()> {
|
||||||
let db = session.db().await;
|
let db = session.db().await;
|
||||||
let channel_id = ChannelId::from_proto(request.channel_id);
|
let channel_id = ChannelId::from_proto(request.channel_id);
|
||||||
|
@ -3476,7 +3515,7 @@ async fn get_channel_members(
|
||||||
async fn respond_to_channel_invite(
|
async fn respond_to_channel_invite(
|
||||||
request: proto::RespondToChannelInvite,
|
request: proto::RespondToChannelInvite,
|
||||||
response: Response<proto::RespondToChannelInvite>,
|
response: Response<proto::RespondToChannelInvite>,
|
||||||
session: Session,
|
session: MessageContext,
|
||||||
) -> Result<()> {
|
) -> Result<()> {
|
||||||
let db = session.db().await;
|
let db = session.db().await;
|
||||||
let channel_id = ChannelId::from_proto(request.channel_id);
|
let channel_id = ChannelId::from_proto(request.channel_id);
|
||||||
|
@ -3517,7 +3556,7 @@ async fn respond_to_channel_invite(
|
||||||
async fn join_channel(
|
async fn join_channel(
|
||||||
request: proto::JoinChannel,
|
request: proto::JoinChannel,
|
||||||
response: Response<proto::JoinChannel>,
|
response: Response<proto::JoinChannel>,
|
||||||
session: Session,
|
session: MessageContext,
|
||||||
) -> Result<()> {
|
) -> Result<()> {
|
||||||
let channel_id = ChannelId::from_proto(request.channel_id);
|
let channel_id = ChannelId::from_proto(request.channel_id);
|
||||||
join_channel_internal(channel_id, Box::new(response), session).await
|
join_channel_internal(channel_id, Box::new(response), session).await
|
||||||
|
@ -3540,7 +3579,7 @@ impl JoinChannelInternalResponse for Response<proto::JoinRoom> {
|
||||||
async fn join_channel_internal(
|
async fn join_channel_internal(
|
||||||
channel_id: ChannelId,
|
channel_id: ChannelId,
|
||||||
response: Box<impl JoinChannelInternalResponse>,
|
response: Box<impl JoinChannelInternalResponse>,
|
||||||
session: Session,
|
session: MessageContext,
|
||||||
) -> Result<()> {
|
) -> Result<()> {
|
||||||
let joined_room = {
|
let joined_room = {
|
||||||
let mut db = session.db().await;
|
let mut db = session.db().await;
|
||||||
|
@ -3635,7 +3674,7 @@ async fn join_channel_internal(
|
||||||
async fn join_channel_buffer(
|
async fn join_channel_buffer(
|
||||||
request: proto::JoinChannelBuffer,
|
request: proto::JoinChannelBuffer,
|
||||||
response: Response<proto::JoinChannelBuffer>,
|
response: Response<proto::JoinChannelBuffer>,
|
||||||
session: Session,
|
session: MessageContext,
|
||||||
) -> Result<()> {
|
) -> Result<()> {
|
||||||
let db = session.db().await;
|
let db = session.db().await;
|
||||||
let channel_id = ChannelId::from_proto(request.channel_id);
|
let channel_id = ChannelId::from_proto(request.channel_id);
|
||||||
|
@ -3666,7 +3705,7 @@ async fn join_channel_buffer(
|
||||||
/// Edit the channel notes
|
/// Edit the channel notes
|
||||||
async fn update_channel_buffer(
|
async fn update_channel_buffer(
|
||||||
request: proto::UpdateChannelBuffer,
|
request: proto::UpdateChannelBuffer,
|
||||||
session: Session,
|
session: MessageContext,
|
||||||
) -> Result<()> {
|
) -> Result<()> {
|
||||||
let db = session.db().await;
|
let db = session.db().await;
|
||||||
let channel_id = ChannelId::from_proto(request.channel_id);
|
let channel_id = ChannelId::from_proto(request.channel_id);
|
||||||
|
@ -3718,7 +3757,7 @@ async fn update_channel_buffer(
|
||||||
async fn rejoin_channel_buffers(
|
async fn rejoin_channel_buffers(
|
||||||
request: proto::RejoinChannelBuffers,
|
request: proto::RejoinChannelBuffers,
|
||||||
response: Response<proto::RejoinChannelBuffers>,
|
response: Response<proto::RejoinChannelBuffers>,
|
||||||
session: Session,
|
session: MessageContext,
|
||||||
) -> Result<()> {
|
) -> Result<()> {
|
||||||
let db = session.db().await;
|
let db = session.db().await;
|
||||||
let buffers = db
|
let buffers = db
|
||||||
|
@ -3753,7 +3792,7 @@ async fn rejoin_channel_buffers(
|
||||||
async fn leave_channel_buffer(
|
async fn leave_channel_buffer(
|
||||||
request: proto::LeaveChannelBuffer,
|
request: proto::LeaveChannelBuffer,
|
||||||
response: Response<proto::LeaveChannelBuffer>,
|
response: Response<proto::LeaveChannelBuffer>,
|
||||||
session: Session,
|
session: MessageContext,
|
||||||
) -> Result<()> {
|
) -> Result<()> {
|
||||||
let db = session.db().await;
|
let db = session.db().await;
|
||||||
let channel_id = ChannelId::from_proto(request.channel_id);
|
let channel_id = ChannelId::from_proto(request.channel_id);
|
||||||
|
@ -3815,7 +3854,7 @@ fn send_notifications(
|
||||||
async fn send_channel_message(
|
async fn send_channel_message(
|
||||||
request: proto::SendChannelMessage,
|
request: proto::SendChannelMessage,
|
||||||
response: Response<proto::SendChannelMessage>,
|
response: Response<proto::SendChannelMessage>,
|
||||||
session: Session,
|
session: MessageContext,
|
||||||
) -> Result<()> {
|
) -> Result<()> {
|
||||||
// Validate the message body.
|
// Validate the message body.
|
||||||
let body = request.body.trim().to_string();
|
let body = request.body.trim().to_string();
|
||||||
|
@ -3908,7 +3947,7 @@ async fn send_channel_message(
|
||||||
async fn remove_channel_message(
|
async fn remove_channel_message(
|
||||||
request: proto::RemoveChannelMessage,
|
request: proto::RemoveChannelMessage,
|
||||||
response: Response<proto::RemoveChannelMessage>,
|
response: Response<proto::RemoveChannelMessage>,
|
||||||
session: Session,
|
session: MessageContext,
|
||||||
) -> Result<()> {
|
) -> Result<()> {
|
||||||
let channel_id = ChannelId::from_proto(request.channel_id);
|
let channel_id = ChannelId::from_proto(request.channel_id);
|
||||||
let message_id = MessageId::from_proto(request.message_id);
|
let message_id = MessageId::from_proto(request.message_id);
|
||||||
|
@ -3943,7 +3982,7 @@ async fn remove_channel_message(
|
||||||
async fn update_channel_message(
|
async fn update_channel_message(
|
||||||
request: proto::UpdateChannelMessage,
|
request: proto::UpdateChannelMessage,
|
||||||
response: Response<proto::UpdateChannelMessage>,
|
response: Response<proto::UpdateChannelMessage>,
|
||||||
session: Session,
|
session: MessageContext,
|
||||||
) -> Result<()> {
|
) -> Result<()> {
|
||||||
let channel_id = ChannelId::from_proto(request.channel_id);
|
let channel_id = ChannelId::from_proto(request.channel_id);
|
||||||
let message_id = MessageId::from_proto(request.message_id);
|
let message_id = MessageId::from_proto(request.message_id);
|
||||||
|
@ -4027,7 +4066,7 @@ async fn update_channel_message(
|
||||||
/// Mark a channel message as read
|
/// Mark a channel message as read
|
||||||
async fn acknowledge_channel_message(
|
async fn acknowledge_channel_message(
|
||||||
request: proto::AckChannelMessage,
|
request: proto::AckChannelMessage,
|
||||||
session: Session,
|
session: MessageContext,
|
||||||
) -> Result<()> {
|
) -> Result<()> {
|
||||||
let channel_id = ChannelId::from_proto(request.channel_id);
|
let channel_id = ChannelId::from_proto(request.channel_id);
|
||||||
let message_id = MessageId::from_proto(request.message_id);
|
let message_id = MessageId::from_proto(request.message_id);
|
||||||
|
@ -4047,7 +4086,7 @@ async fn acknowledge_channel_message(
|
||||||
/// Mark a buffer version as synced
|
/// Mark a buffer version as synced
|
||||||
async fn acknowledge_buffer_version(
|
async fn acknowledge_buffer_version(
|
||||||
request: proto::AckBufferOperation,
|
request: proto::AckBufferOperation,
|
||||||
session: Session,
|
session: MessageContext,
|
||||||
) -> Result<()> {
|
) -> Result<()> {
|
||||||
let buffer_id = BufferId::from_proto(request.buffer_id);
|
let buffer_id = BufferId::from_proto(request.buffer_id);
|
||||||
session
|
session
|
||||||
|
@ -4067,7 +4106,7 @@ async fn acknowledge_buffer_version(
|
||||||
async fn get_supermaven_api_key(
|
async fn get_supermaven_api_key(
|
||||||
_request: proto::GetSupermavenApiKey,
|
_request: proto::GetSupermavenApiKey,
|
||||||
response: Response<proto::GetSupermavenApiKey>,
|
response: Response<proto::GetSupermavenApiKey>,
|
||||||
session: Session,
|
session: MessageContext,
|
||||||
) -> Result<()> {
|
) -> Result<()> {
|
||||||
let user_id: String = session.user_id().to_string();
|
let user_id: String = session.user_id().to_string();
|
||||||
if !session.is_staff() {
|
if !session.is_staff() {
|
||||||
|
@ -4096,7 +4135,7 @@ async fn get_supermaven_api_key(
|
||||||
async fn join_channel_chat(
|
async fn join_channel_chat(
|
||||||
request: proto::JoinChannelChat,
|
request: proto::JoinChannelChat,
|
||||||
response: Response<proto::JoinChannelChat>,
|
response: Response<proto::JoinChannelChat>,
|
||||||
session: Session,
|
session: MessageContext,
|
||||||
) -> Result<()> {
|
) -> Result<()> {
|
||||||
let channel_id = ChannelId::from_proto(request.channel_id);
|
let channel_id = ChannelId::from_proto(request.channel_id);
|
||||||
|
|
||||||
|
@ -4114,7 +4153,10 @@ async fn join_channel_chat(
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Stop receiving chat updates for a channel
|
/// Stop receiving chat updates for a channel
|
||||||
async fn leave_channel_chat(request: proto::LeaveChannelChat, session: Session) -> Result<()> {
|
async fn leave_channel_chat(
|
||||||
|
request: proto::LeaveChannelChat,
|
||||||
|
session: MessageContext,
|
||||||
|
) -> Result<()> {
|
||||||
let channel_id = ChannelId::from_proto(request.channel_id);
|
let channel_id = ChannelId::from_proto(request.channel_id);
|
||||||
session
|
session
|
||||||
.db()
|
.db()
|
||||||
|
@ -4128,7 +4170,7 @@ async fn leave_channel_chat(request: proto::LeaveChannelChat, session: Session)
|
||||||
async fn get_channel_messages(
|
async fn get_channel_messages(
|
||||||
request: proto::GetChannelMessages,
|
request: proto::GetChannelMessages,
|
||||||
response: Response<proto::GetChannelMessages>,
|
response: Response<proto::GetChannelMessages>,
|
||||||
session: Session,
|
session: MessageContext,
|
||||||
) -> Result<()> {
|
) -> Result<()> {
|
||||||
let channel_id = ChannelId::from_proto(request.channel_id);
|
let channel_id = ChannelId::from_proto(request.channel_id);
|
||||||
let messages = session
|
let messages = session
|
||||||
|
@ -4152,7 +4194,7 @@ async fn get_channel_messages(
|
||||||
async fn get_channel_messages_by_id(
|
async fn get_channel_messages_by_id(
|
||||||
request: proto::GetChannelMessagesById,
|
request: proto::GetChannelMessagesById,
|
||||||
response: Response<proto::GetChannelMessagesById>,
|
response: Response<proto::GetChannelMessagesById>,
|
||||||
session: Session,
|
session: MessageContext,
|
||||||
) -> Result<()> {
|
) -> Result<()> {
|
||||||
let message_ids = request
|
let message_ids = request
|
||||||
.message_ids
|
.message_ids
|
||||||
|
@ -4175,7 +4217,7 @@ async fn get_channel_messages_by_id(
|
||||||
async fn get_notifications(
|
async fn get_notifications(
|
||||||
request: proto::GetNotifications,
|
request: proto::GetNotifications,
|
||||||
response: Response<proto::GetNotifications>,
|
response: Response<proto::GetNotifications>,
|
||||||
session: Session,
|
session: MessageContext,
|
||||||
) -> Result<()> {
|
) -> Result<()> {
|
||||||
let notifications = session
|
let notifications = session
|
||||||
.db()
|
.db()
|
||||||
|
@ -4197,7 +4239,7 @@ async fn get_notifications(
|
||||||
async fn mark_notification_as_read(
|
async fn mark_notification_as_read(
|
||||||
request: proto::MarkNotificationRead,
|
request: proto::MarkNotificationRead,
|
||||||
response: Response<proto::MarkNotificationRead>,
|
response: Response<proto::MarkNotificationRead>,
|
||||||
session: Session,
|
session: MessageContext,
|
||||||
) -> Result<()> {
|
) -> Result<()> {
|
||||||
let database = &session.db().await;
|
let database = &session.db().await;
|
||||||
let notifications = database
|
let notifications = database
|
||||||
|
@ -4219,7 +4261,7 @@ async fn mark_notification_as_read(
|
||||||
async fn get_private_user_info(
|
async fn get_private_user_info(
|
||||||
_request: proto::GetPrivateUserInfo,
|
_request: proto::GetPrivateUserInfo,
|
||||||
response: Response<proto::GetPrivateUserInfo>,
|
response: Response<proto::GetPrivateUserInfo>,
|
||||||
session: Session,
|
session: MessageContext,
|
||||||
) -> Result<()> {
|
) -> Result<()> {
|
||||||
let db = session.db().await;
|
let db = session.db().await;
|
||||||
|
|
||||||
|
@ -4243,7 +4285,7 @@ async fn get_private_user_info(
|
||||||
async fn accept_terms_of_service(
|
async fn accept_terms_of_service(
|
||||||
_request: proto::AcceptTermsOfService,
|
_request: proto::AcceptTermsOfService,
|
||||||
response: Response<proto::AcceptTermsOfService>,
|
response: Response<proto::AcceptTermsOfService>,
|
||||||
session: Session,
|
session: MessageContext,
|
||||||
) -> Result<()> {
|
) -> Result<()> {
|
||||||
let db = session.db().await;
|
let db = session.db().await;
|
||||||
|
|
||||||
|
@ -4267,7 +4309,7 @@ async fn accept_terms_of_service(
|
||||||
async fn get_llm_api_token(
|
async fn get_llm_api_token(
|
||||||
_request: proto::GetLlmToken,
|
_request: proto::GetLlmToken,
|
||||||
response: Response<proto::GetLlmToken>,
|
response: Response<proto::GetLlmToken>,
|
||||||
session: Session,
|
session: MessageContext,
|
||||||
) -> Result<()> {
|
) -> Result<()> {
|
||||||
let db = session.db().await;
|
let db = session.db().await;
|
||||||
|
|
||||||
|
|
|
@ -20,6 +20,7 @@ anthropic = { workspace = true, features = ["schemars"] }
|
||||||
anyhow.workspace = true
|
anyhow.workspace = true
|
||||||
base64.workspace = true
|
base64.workspace = true
|
||||||
client.workspace = true
|
client.workspace = true
|
||||||
|
cloud_api_types.workspace = true
|
||||||
cloud_llm_client.workspace = true
|
cloud_llm_client.workspace = true
|
||||||
collections.workspace = true
|
collections.workspace = true
|
||||||
futures.workspace = true
|
futures.workspace = true
|
||||||
|
|
|
@ -3,11 +3,9 @@ use std::sync::Arc;
|
||||||
|
|
||||||
use anyhow::Result;
|
use anyhow::Result;
|
||||||
use client::Client;
|
use client::Client;
|
||||||
|
use cloud_api_types::websocket_protocol::MessageToClient;
|
||||||
use cloud_llm_client::Plan;
|
use cloud_llm_client::Plan;
|
||||||
use gpui::{
|
use gpui::{App, AppContext as _, Context, Entity, EventEmitter, Global, ReadGlobal as _};
|
||||||
App, AppContext as _, AsyncApp, Context, Entity, EventEmitter, Global, ReadGlobal as _,
|
|
||||||
};
|
|
||||||
use proto::TypedEnvelope;
|
|
||||||
use smol::lock::{RwLock, RwLockUpgradableReadGuard, RwLockWriteGuard};
|
use smol::lock::{RwLock, RwLockUpgradableReadGuard, RwLockWriteGuard};
|
||||||
use thiserror::Error;
|
use thiserror::Error;
|
||||||
|
|
||||||
|
@ -82,9 +80,7 @@ impl Global for GlobalRefreshLlmTokenListener {}
|
||||||
|
|
||||||
pub struct RefreshLlmTokenEvent;
|
pub struct RefreshLlmTokenEvent;
|
||||||
|
|
||||||
pub struct RefreshLlmTokenListener {
|
pub struct RefreshLlmTokenListener;
|
||||||
_llm_token_subscription: client::Subscription,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl EventEmitter<RefreshLlmTokenEvent> for RefreshLlmTokenListener {}
|
impl EventEmitter<RefreshLlmTokenEvent> for RefreshLlmTokenListener {}
|
||||||
|
|
||||||
|
@ -99,17 +95,21 @@ impl RefreshLlmTokenListener {
|
||||||
}
|
}
|
||||||
|
|
||||||
fn new(client: Arc<Client>, cx: &mut Context<Self>) -> Self {
|
fn new(client: Arc<Client>, cx: &mut Context<Self>) -> Self {
|
||||||
Self {
|
client.add_message_to_client_handler({
|
||||||
_llm_token_subscription: client
|
let this = cx.entity();
|
||||||
.add_message_handler(cx.weak_entity(), Self::handle_refresh_llm_token),
|
move |message, cx| {
|
||||||
}
|
Self::handle_refresh_llm_token(this.clone(), message, cx);
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
Self
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn handle_refresh_llm_token(
|
fn handle_refresh_llm_token(this: Entity<Self>, message: &MessageToClient, cx: &mut App) {
|
||||||
this: Entity<Self>,
|
match message {
|
||||||
_: TypedEnvelope<proto::RefreshLlmToken>,
|
MessageToClient::UserUpdated => {
|
||||||
mut cx: AsyncApp,
|
this.update(cx, |_this, cx| cx.emit(RefreshLlmTokenEvent));
|
||||||
) -> Result<()> {
|
}
|
||||||
this.update(&mut cx, |_this, cx| cx.emit(RefreshLlmTokenEvent))
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -674,6 +674,10 @@ pub fn count_open_ai_tokens(
|
||||||
| Model::O3
|
| Model::O3
|
||||||
| Model::O3Mini
|
| Model::O3Mini
|
||||||
| Model::O4Mini => tiktoken_rs::num_tokens_from_messages(model.id(), &messages),
|
| Model::O4Mini => tiktoken_rs::num_tokens_from_messages(model.id(), &messages),
|
||||||
|
// GPT-5 models don't have tiktoken support yet; fall back on gpt-4o tokenizer
|
||||||
|
Model::Five | Model::FiveMini | Model::FiveNano => {
|
||||||
|
tiktoken_rs::num_tokens_from_messages("gpt-4o", &messages)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
.map(|tokens| tokens as u64)
|
.map(|tokens| tokens as u64)
|
||||||
})
|
})
|
||||||
|
|
|
@ -74,6 +74,12 @@ pub enum Model {
|
||||||
O3,
|
O3,
|
||||||
#[serde(rename = "o4-mini")]
|
#[serde(rename = "o4-mini")]
|
||||||
O4Mini,
|
O4Mini,
|
||||||
|
#[serde(rename = "gpt-5")]
|
||||||
|
Five,
|
||||||
|
#[serde(rename = "gpt-5-mini")]
|
||||||
|
FiveMini,
|
||||||
|
#[serde(rename = "gpt-5-nano")]
|
||||||
|
FiveNano,
|
||||||
|
|
||||||
#[serde(rename = "custom")]
|
#[serde(rename = "custom")]
|
||||||
Custom {
|
Custom {
|
||||||
|
@ -105,6 +111,9 @@ impl Model {
|
||||||
"o3-mini" => Ok(Self::O3Mini),
|
"o3-mini" => Ok(Self::O3Mini),
|
||||||
"o3" => Ok(Self::O3),
|
"o3" => Ok(Self::O3),
|
||||||
"o4-mini" => Ok(Self::O4Mini),
|
"o4-mini" => Ok(Self::O4Mini),
|
||||||
|
"gpt-5" => Ok(Self::Five),
|
||||||
|
"gpt-5-mini" => Ok(Self::FiveMini),
|
||||||
|
"gpt-5-nano" => Ok(Self::FiveNano),
|
||||||
invalid_id => anyhow::bail!("invalid model id '{invalid_id}'"),
|
invalid_id => anyhow::bail!("invalid model id '{invalid_id}'"),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -123,6 +132,9 @@ impl Model {
|
||||||
Self::O3Mini => "o3-mini",
|
Self::O3Mini => "o3-mini",
|
||||||
Self::O3 => "o3",
|
Self::O3 => "o3",
|
||||||
Self::O4Mini => "o4-mini",
|
Self::O4Mini => "o4-mini",
|
||||||
|
Self::Five => "gpt-5",
|
||||||
|
Self::FiveMini => "gpt-5-mini",
|
||||||
|
Self::FiveNano => "gpt-5-nano",
|
||||||
Self::Custom { name, .. } => name,
|
Self::Custom { name, .. } => name,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -141,6 +153,9 @@ impl Model {
|
||||||
Self::O3Mini => "o3-mini",
|
Self::O3Mini => "o3-mini",
|
||||||
Self::O3 => "o3",
|
Self::O3 => "o3",
|
||||||
Self::O4Mini => "o4-mini",
|
Self::O4Mini => "o4-mini",
|
||||||
|
Self::Five => "gpt-5",
|
||||||
|
Self::FiveMini => "gpt-5-mini",
|
||||||
|
Self::FiveNano => "gpt-5-nano",
|
||||||
Self::Custom {
|
Self::Custom {
|
||||||
name, display_name, ..
|
name, display_name, ..
|
||||||
} => display_name.as_ref().unwrap_or(name),
|
} => display_name.as_ref().unwrap_or(name),
|
||||||
|
@ -161,6 +176,9 @@ impl Model {
|
||||||
Self::O3Mini => 200_000,
|
Self::O3Mini => 200_000,
|
||||||
Self::O3 => 200_000,
|
Self::O3 => 200_000,
|
||||||
Self::O4Mini => 200_000,
|
Self::O4Mini => 200_000,
|
||||||
|
Self::Five => 272_000,
|
||||||
|
Self::FiveMini => 272_000,
|
||||||
|
Self::FiveNano => 272_000,
|
||||||
Self::Custom { max_tokens, .. } => *max_tokens,
|
Self::Custom { max_tokens, .. } => *max_tokens,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -182,6 +200,9 @@ impl Model {
|
||||||
Self::O3Mini => Some(100_000),
|
Self::O3Mini => Some(100_000),
|
||||||
Self::O3 => Some(100_000),
|
Self::O3 => Some(100_000),
|
||||||
Self::O4Mini => Some(100_000),
|
Self::O4Mini => Some(100_000),
|
||||||
|
Self::Five => Some(128_000),
|
||||||
|
Self::FiveMini => Some(128_000),
|
||||||
|
Self::FiveNano => Some(128_000),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -197,7 +218,10 @@ impl Model {
|
||||||
| Self::FourOmniMini
|
| Self::FourOmniMini
|
||||||
| Self::FourPointOne
|
| Self::FourPointOne
|
||||||
| Self::FourPointOneMini
|
| Self::FourPointOneMini
|
||||||
| Self::FourPointOneNano => true,
|
| Self::FourPointOneNano
|
||||||
|
| Self::Five
|
||||||
|
| Self::FiveMini
|
||||||
|
| Self::FiveNano => true,
|
||||||
Self::O1 | Self::O3 | Self::O3Mini | Self::O4Mini | Model::Custom { .. } => false,
|
Self::O1 | Self::O3 | Self::O3Mini | Self::O4Mini | Model::Custom { .. } => false,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -422,23 +422,8 @@ impl Peer {
|
||||||
receiver_id: ConnectionId,
|
receiver_id: ConnectionId,
|
||||||
request: T,
|
request: T,
|
||||||
) -> impl Future<Output = Result<T::Response>> {
|
) -> impl Future<Output = Result<T::Response>> {
|
||||||
let request_start_time = Instant::now();
|
|
||||||
let elapsed_time = move || request_start_time.elapsed().as_millis();
|
|
||||||
tracing::info!("start forwarding request");
|
|
||||||
self.request_internal(Some(sender_id), receiver_id, request)
|
self.request_internal(Some(sender_id), receiver_id, request)
|
||||||
.map_ok(|envelope| envelope.payload)
|
.map_ok(|envelope| envelope.payload)
|
||||||
.inspect_err(move |_| {
|
|
||||||
tracing::error!(
|
|
||||||
waiting_for_host_ms = elapsed_time(),
|
|
||||||
"error forwarding request"
|
|
||||||
)
|
|
||||||
})
|
|
||||||
.inspect_ok(move |_| {
|
|
||||||
tracing::info!(
|
|
||||||
waiting_for_host_ms = elapsed_time(),
|
|
||||||
"finished forwarding request"
|
|
||||||
)
|
|
||||||
})
|
|
||||||
}
|
}
|
||||||
|
|
||||||
fn request_internal<T: RequestMessage>(
|
fn request_internal<T: RequestMessage>(
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue