ZIm/crates/client/src/test.rs
Marshall Bowers bd1fda6782
proto: Remove GetPrivateUserInfo message (#36265)
This PR removes the `GetPrivateUserInfo` RPC message.

We're no longer using the message after
https://github.com/zed-industries/zed/pull/36255.

Release Notes:

- N/A
2025-08-15 18:27:31 +00:00

290 lines
9.8 KiB
Rust

use crate::{Client, Connection, Credentials, EstablishConnectionError, UserStore};
use anyhow::{Context as _, Result, anyhow};
use cloud_api_client::{AuthenticatedUser, GetAuthenticatedUserResponse, PlanInfo};
use cloud_llm_client::{CurrentUsage, Plan, UsageData, UsageLimit};
use futures::{StreamExt, stream::BoxStream};
use gpui::{AppContext as _, BackgroundExecutor, Entity, TestAppContext};
use http_client::{AsyncBody, Method, Request, http};
use parking_lot::Mutex;
use rpc::{ConnectionId, Peer, Receipt, TypedEnvelope, proto};
use std::sync::Arc;
pub struct FakeServer {
peer: Arc<Peer>,
state: Arc<Mutex<FakeServerState>>,
user_id: u64,
executor: BackgroundExecutor,
}
#[derive(Default)]
struct FakeServerState {
incoming: Option<BoxStream<'static, Box<dyn proto::AnyTypedEnvelope>>>,
connection_id: Option<ConnectionId>,
forbid_connections: bool,
auth_count: usize,
access_token: usize,
}
impl FakeServer {
pub async fn for_client(
client_user_id: u64,
client: &Arc<Client>,
cx: &TestAppContext,
) -> Self {
let server = Self {
peer: Peer::new(0),
state: Default::default(),
user_id: client_user_id,
executor: cx.executor(),
};
client.http_client().as_fake().replace_handler({
let state = server.state.clone();
move |old_handler, req| {
let state = state.clone();
let old_handler = old_handler.clone();
async move {
match (req.method(), req.uri().path()) {
(&Method::GET, "/client/users/me") => {
let credentials = parse_authorization_header(&req);
if credentials
!= Some(Credentials {
user_id: client_user_id,
access_token: state.lock().access_token.to_string(),
})
{
return Ok(http_client::Response::builder()
.status(401)
.body("Unauthorized".into())
.unwrap());
}
Ok(http_client::Response::builder()
.status(200)
.body(
serde_json::to_string(&make_get_authenticated_user_response(
client_user_id as i32,
format!("user-{client_user_id}"),
))
.unwrap()
.into(),
)
.unwrap())
}
_ => old_handler(req).await,
}
}
}
});
client
.override_authenticate({
let state = Arc::downgrade(&server.state);
move |cx| {
let state = state.clone();
cx.spawn(async move |_| {
let state = state.upgrade().context("server dropped")?;
let mut state = state.lock();
state.auth_count += 1;
let access_token = state.access_token.to_string();
Ok(Credentials {
user_id: client_user_id,
access_token,
})
})
}
})
.override_establish_connection({
let peer = Arc::downgrade(&server.peer);
let state = Arc::downgrade(&server.state);
move |credentials, cx| {
let peer = peer.clone();
let state = state.clone();
let credentials = credentials.clone();
cx.spawn(async move |cx| {
let state = state.upgrade().context("server dropped")?;
let peer = peer.upgrade().context("server dropped")?;
if state.lock().forbid_connections {
Err(EstablishConnectionError::Other(anyhow!(
"server is forbidding connections"
)))?
}
if credentials
!= (Credentials {
user_id: client_user_id,
access_token: state.lock().access_token.to_string(),
})
{
Err(EstablishConnectionError::Unauthorized)?
}
let (client_conn, server_conn, _) =
Connection::in_memory(cx.background_executor().clone());
let (connection_id, io, incoming) =
peer.add_test_connection(server_conn, cx.background_executor().clone());
cx.background_spawn(io).detach();
{
let mut state = state.lock();
state.connection_id = Some(connection_id);
state.incoming = Some(incoming);
}
peer.send(
connection_id,
proto::Hello {
peer_id: Some(connection_id.into()),
},
)
.unwrap();
Ok(client_conn)
})
}
});
client
.connect(false, &cx.to_async())
.await
.into_response()
.unwrap();
server
}
pub fn disconnect(&self) {
if self.state.lock().connection_id.is_some() {
self.peer.disconnect(self.connection_id());
let mut state = self.state.lock();
state.connection_id.take();
state.incoming.take();
}
}
pub fn auth_count(&self) -> usize {
self.state.lock().auth_count
}
pub fn roll_access_token(&self) {
self.state.lock().access_token += 1;
}
pub fn forbid_connections(&self) {
self.state.lock().forbid_connections = true;
}
pub fn allow_connections(&self) {
self.state.lock().forbid_connections = false;
}
pub fn send<T: proto::EnvelopedMessage>(&self, message: T) {
self.peer.send(self.connection_id(), message).unwrap();
}
#[allow(clippy::await_holding_lock)]
pub async fn receive<M: proto::EnvelopedMessage>(&self) -> Result<TypedEnvelope<M>> {
self.executor.start_waiting();
let message = self
.state
.lock()
.incoming
.as_mut()
.expect("not connected")
.next()
.await
.context("other half hung up")?;
self.executor.finish_waiting();
let type_name = message.payload_type_name();
let message = message.into_any();
if message.is::<TypedEnvelope<M>>() {
return Ok(*message.downcast().unwrap());
}
panic!(
"fake server received unexpected message type: {:?}",
type_name
);
}
pub fn respond<T: proto::RequestMessage>(&self, receipt: Receipt<T>, response: T::Response) {
self.peer.respond(receipt, response).unwrap()
}
fn connection_id(&self) -> ConnectionId {
self.state.lock().connection_id.expect("not connected")
}
pub async fn build_user_store(
&self,
client: Arc<Client>,
cx: &mut TestAppContext,
) -> Entity<UserStore> {
let user_store = cx.new(|cx| UserStore::new(client, cx));
assert_eq!(
self.receive::<proto::GetUsers>()
.await
.unwrap()
.payload
.user_ids,
&[self.user_id]
);
user_store
}
}
impl Drop for FakeServer {
fn drop(&mut self) {
self.disconnect();
}
}
pub fn parse_authorization_header(req: &Request<AsyncBody>) -> Option<Credentials> {
let mut auth_header = req
.headers()
.get(http::header::AUTHORIZATION)?
.to_str()
.ok()?
.split_whitespace();
let user_id = auth_header.next()?.parse().ok()?;
let access_token = auth_header.next()?;
Some(Credentials {
user_id,
access_token: access_token.to_string(),
})
}
pub fn make_get_authenticated_user_response(
user_id: i32,
github_login: String,
) -> GetAuthenticatedUserResponse {
GetAuthenticatedUserResponse {
user: AuthenticatedUser {
id: user_id,
metrics_id: format!("metrics-id-{user_id}"),
avatar_url: "".to_string(),
github_login,
name: None,
is_staff: false,
accepted_tos_at: None,
},
feature_flags: vec![],
plan: PlanInfo {
plan: Plan::ZedPro,
subscription_period: None,
usage: CurrentUsage {
model_requests: UsageData {
used: 0,
limit: UsageLimit::Limited(500),
},
edit_predictions: UsageData {
used: 250,
limit: UsageLimit::Unlimited,
},
},
trial_started_at: None,
is_usage_based_billing_enabled: false,
is_account_too_young: false,
has_overdue_invoices: false,
},
}
}