Implement wire reconnect support

This commit is contained in:
luaneko 2025-01-11 06:02:32 +11:00
parent 6f9e9770cf
commit da7f7e12f3
Signed by: luaneko
GPG Key ID: 406809B8763FF07A

353
wire.ts
View File

@ -1,5 +1,6 @@
import { import {
type BinaryLike, type BinaryLike,
buf_concat,
buf_concat_fast, buf_concat_fast,
buf_eq, buf_eq,
buf_xor, buf_xor,
@ -7,7 +8,9 @@ import {
from_base64, from_base64,
from_utf8, from_utf8,
jit, jit,
type Receiver,
semaphore, semaphore,
type Sender,
to_base64, to_base64,
to_utf8, to_utf8,
TypedEmitter, TypedEmitter,
@ -475,57 +478,38 @@ export interface Channel
} }
export async function wire_connect(options: WireOptions) { export async function wire_connect(options: WireOptions) {
const { host, port } = options; const wire = new Wire(options);
const wire = new Wire(await socket_connect(host, port), options); return await wire.connect(), wire;
return await wire.connected, wire;
}
async function socket_connect(hostname: string, port: number) {
if (hostname.startsWith("/")) {
const path = join(hostname, `.s.PGSQL.${port}`);
return await Deno.connect({ transport: "unix", path });
} else {
const socket = await Deno.connect({ transport: "tcp", hostname, port });
return socket.setNoDelay(), socket.setKeepAlive(), socket;
}
} }
export class Wire extends TypedEmitter<WireEvents> implements Disposable { export class Wire extends TypedEmitter<WireEvents> implements Disposable {
readonly #socket;
readonly #params; readonly #params;
readonly #auth; readonly #connect;
readonly #connected;
readonly #query; readonly #query;
readonly #begin; readonly #begin;
readonly #listen; readonly #listen;
readonly #notify; readonly #notify;
readonly #close; readonly #close;
get socket() {
return this.#socket;
}
get params() { get params() {
return this.#params; return this.#params;
} }
get connected() { constructor(options: WireOptions) {
return this.#connected;
}
constructor(socket: Deno.Conn, options: WireOptions) {
super(); super();
({ ({
params: this.#params, params: this.#params,
auth: this.#auth, connect: this.#connect,
query: this.#query, query: this.#query,
begin: this.#begin, begin: this.#begin,
listen: this.#listen, listen: this.#listen,
notify: this.#notify, notify: this.#notify,
close: this.#close, close: this.#close,
} = wire_impl(this, socket, options)); } = wire_impl(this, options));
this.#socket = socket; }
(this.#connected = this.#auth()).catch(close);
connect() {
return this.#connect();
} }
query<T = Row>(sql: SqlFragment): Query<T>; query<T = Row>(sql: SqlFragment): Query<T>;
@ -579,107 +563,144 @@ export class Wire extends TypedEmitter<WireEvents> implements Disposable {
} }
} }
const msg_PD = object({ P: Parse, D: Describe }); async function socket_connect(hostname: string, port: number) {
const msg_BE = object({ B: Bind, E: Execute }); if (hostname.startsWith("/")) {
const msg_BEc = object({ B: Bind, E: Execute, c: CopyDone }); const path = join(hostname, `.s.PGSQL.${port}`);
const msg_BEcC = object({ B: Bind, E: Execute, c: CopyDone, C: Close }); return await Deno.connect({ transport: "unix", path });
} else {
const socket = await Deno.connect({ transport: "tcp", hostname, port });
return socket.setNoDelay(), socket.setKeepAlive(), socket;
}
}
function wire_impl( function wire_impl(
wire: Wire, wire: Wire,
socket: Deno.Conn, { host, port, user, database, password, runtime_params, types }: WireOptions
{ user, database, password, runtime_params, types }: WireOptions
) { ) {
// current runtime parameters as reported by postgres
const params: Parameters = Object.create(null); const params: Parameters = Object.create(null);
function log(level: LogLevel, ctx: object, msg: string) { function log(level: LogLevel, ctx: object, msg: string) {
wire.emit("log", level, ctx, msg); wire.emit("log", level, ctx, msg);
} }
// wire supports re-connection; socket and read/write channels are null when closed
let socket: Deno.Conn | null = null;
let read_pop: Receiver<Uint8Array> | null = null;
let write_push: Sender<Uint8Array> | null = null;
async function read<T>(type: Encoder<T>) { async function read<T>(type: Encoder<T>) {
const msg = await read_recv(); const msg = read_pop !== null ? await read_pop() : null;
if (msg === null) throw new WireError(`connection closed`); if (msg !== null) return ser_decode(type, msg_check_err(msg));
else return ser_decode(type, msg_check_err(msg)); else throw new WireError(`connection closed`);
} }
async function read_raw() { async function read_msg() {
const msg = await read_recv(); const msg = read_pop !== null ? await read_pop() : null;
if (msg === null) throw new WireError(`connection closed`); if (msg !== null) return msg;
else return msg; else throw new WireError(`connection closed`);
} }
async function* read_socket() { // socket reader channel worker
const buf = new Uint8Array(64 * 1024); async function read_socket(socket: Deno.Conn, push: Sender<Uint8Array>) {
for (let n; (n = await socket.read(buf)) !== null; ) const header_size = 5;
yield buf.subarray(0, n); const read_buf = new Uint8Array(64 * 1024); // shared buffer for all socket reads
let buf = new Uint8Array(); // concatenated messages read so far
for (let read; (read = await socket.read(read_buf)) !== null; ) {
buf = buf_concat_fast(buf, read_buf.subarray(0, read)); // push read bytes to buf
while (buf.length >= header_size) {
const size = ser_decode(Header, buf).length + 1;
if (buf.length < size) break;
const msg = buf.subarray(0, size); // shift one message from buf
buf = buf.subarray(size);
if (!handle_msg(msg)) push(msg);
}
}
// there should be nothing left in buf if we gracefully exited
if (buf.length !== 0) throw new WireError(`unexpected end of stream`);
} }
const read_recv = channel.receiver<Uint8Array>(async function read(send) { function handle_msg(msg: Uint8Array) {
let err: unknown; switch (msg_type(msg)) {
try { // https://www.postgresql.org/docs/current/protocol-flow.html#PROTOCOL-ASYNC
let buf = new Uint8Array(); case NoticeResponse.type: {
for await (const chunk of read_socket()) { const { fields } = ser_decode(NoticeResponse, msg);
buf = buf_concat_fast(buf, chunk); const notice = new PostgresError(fields);
log(severity_level(notice.severity), notice, notice.message);
for (let n; (n = ser_decode(Header, buf).length + 1) <= buf.length; ) { wire.emit("notice", notice);
const msg = buf.subarray(0, n); return true;
buf = buf.subarray(n);
switch (msg_type(msg)) {
// https://www.postgresql.org/docs/current/protocol-flow.html#PROTOCOL-ASYNC
case NoticeResponse.type: {
const { fields } = ser_decode(NoticeResponse, msg);
const notice = new PostgresError(fields);
log(severity_level(notice.severity), notice, notice.message);
wire.emit("notice", notice);
continue;
}
case ParameterStatus.type: {
const { name, value } = ser_decode(ParameterStatus, msg);
const prev = params[name] ?? null;
Object.defineProperty(params, name, {
configurable: true,
enumerable: true,
value,
});
wire.emit("parameter", name, value, prev);
continue;
}
case NotificationResponse.type: {
const { channel, payload, process_id } = ser_decode(
NotificationResponse,
msg
);
wire.emit("notify", channel, payload, process_id);
channels.get(channel)?.emit("notify", payload, process_id);
continue;
}
}
send(msg);
}
} }
if (buf.length !== 0) throw new WireError(`unexpected end of stream`); case ParameterStatus.type: {
} catch (e) { const { name, value } = ser_decode(ParameterStatus, msg);
throw ((err = e), e); const prev = params[name] ?? null;
} finally { Object.defineProperty(params, name, {
wire.emit("close", err); configurable: true,
} enumerable: true,
}); value,
});
wire.emit("parameter", name, value, prev);
return true;
}
function write<T>(type: Encoder<T>, value: T) { case NotificationResponse.type: {
return write_raw(ser_encode(type, value)); const { channel, payload, process_id } = ser_decode(
NotificationResponse,
msg
);
wire.emit("notify", channel, payload, process_id);
channels.get(channel)?.emit("notify", payload, process_id);
return true;
}
}
return false;
} }
async function write_raw(buf: Uint8Array) { function write<T>(type: Encoder<T>, value: T) {
for (let i = 0, n = buf.length; i < n; ) write_msg(ser_encode(type, value));
i += await socket.write(buf.subarray(i)); }
function write_msg(buf: Uint8Array) {
if (write_push !== null) write_push(buf);
else throw new WireError(`connection closed`);
}
// socket writer channel worker
async function write_socket(socket: Deno.Conn, pop: Receiver<Uint8Array>) {
for (let buf; (buf = await pop()) !== null; ) {
const bufs = [buf]; // proactively dequeue more queued msgs synchronously, if any
for (let i = 1, buf; (buf = pop.try()) !== null; ) bufs[i++] = buf;
if (bufs.length !== 1) buf = buf_concat(bufs); // write queued msgs concatenated, reduce write syscalls
for (let i = 0, n = buf.length; i < n; )
i += await socket.write(buf.subarray(i));
}
}
async function connect() {
using _rlock = await rlock();
using _wlock = await wlock();
close(new WireError(`reconnecting`));
try {
const s = (socket = await socket_connect(host, port));
read_pop = channel.receiver((push) => read_socket(s, push));
write_push = channel.sender((pop) => write_socket(s, pop));
await handle_auth(); // run auth with rw lock
} catch (e) {
throw (close(e), e);
}
} }
function close(reason?: unknown) { function close(reason?: unknown) {
socket.close(), read_recv.close(reason); socket?.close(), (socket = null);
read_pop?.close(reason), (read_pop = null);
write_push?.close(reason), (write_push = null);
for (const name of Object.keys(params))
delete (params as Record<string, string>)[name];
st_cache.clear(), (st_ids = 0);
(tx_status = "I"), (tx_stack.length = 0);
} }
// https://www.postgresql.org/docs/current/protocol-flow.html#PROTOCOL-FLOW-PIPELINING // https://www.postgresql.org/docs/current/protocol-flow.html#PROTOCOL-FLOW-PIPELINING
@ -697,13 +718,13 @@ function wire_impl(
} }
async function pipeline_read<T>(r: () => T | PromiseLike<T>) { async function pipeline_read<T>(r: () => T | PromiseLike<T>) {
using _rlock = await rlock(); using _lock = await rlock();
try { try {
return await r(); return await r();
} finally { } finally {
try { try {
let msg; let msg;
while (msg_type((msg = await read_raw())) !== ReadyForQuery.type); while (msg_type((msg = await read_msg())) !== ReadyForQuery.type);
({ tx_status } = ser_decode(ReadyForQuery, msg)); ({ tx_status } = ser_decode(ReadyForQuery, msg));
} catch { } catch {
// ignored // ignored
@ -712,12 +733,12 @@ function wire_impl(
} }
async function pipeline_write<T>(w: () => T | PromiseLike<T>) { async function pipeline_write<T>(w: () => T | PromiseLike<T>) {
using _wlock = await wlock(); using _lock = await wlock();
try { try {
return await w(); return await w();
} finally { } finally {
try { try {
await write(Sync, {}); write(Sync, {});
} catch { } catch {
// ignored // ignored
} }
@ -725,11 +746,9 @@ function wire_impl(
} }
// https://www.postgresql.org/docs/current/protocol-flow.html#PROTOCOL-FLOW-START-UP // https://www.postgresql.org/docs/current/protocol-flow.html#PROTOCOL-FLOW-START-UP
async function auth() { async function handle_auth() {
using _rlock = await rlock(); // always run within rw lock (see connect())
using _wlock = await wlock(); write(StartupMessage, {
await write(StartupMessage, {
version: 196608, version: 196608,
params: { params: {
application_name: "pglue", application_name: "pglue",
@ -744,7 +763,7 @@ function wire_impl(
}); });
auth: for (;;) { auth: for (;;) {
const msg = msg_check_err(await read_raw()); const msg = msg_check_err(await read_msg());
switch (msg_type(msg)) { switch (msg_type(msg)) {
case NegotiateProtocolVersion.type: { case NegotiateProtocolVersion.type: {
const { bad_options } = ser_decode(NegotiateProtocolVersion, msg); const { bad_options } = ser_decode(NegotiateProtocolVersion, msg);
@ -762,7 +781,7 @@ function wire_impl(
throw new WireError(`kerberos authentication is deprecated`); throw new WireError(`kerberos authentication is deprecated`);
case 3: // AuthenticationCleartextPassword case 3: // AuthenticationCleartextPassword
await write(PasswordMessage, { password }); write(PasswordMessage, { password });
continue; continue;
case 5: // AuthenticationMD5Password case 5: // AuthenticationMD5Password
@ -778,7 +797,7 @@ function wire_impl(
// AuthenticationSASL // AuthenticationSASL
case 10: case 10:
await auth_sasl(); await handle_auth_sasl();
continue; continue;
default: default:
@ -786,8 +805,9 @@ function wire_impl(
} }
} }
// wait for ready
ready: for (;;) { ready: for (;;) {
const msg = msg_check_err(await read_raw()); const msg = msg_check_err(await read_msg());
switch (msg_type(msg)) { switch (msg_type(msg)) {
case BackendKeyData.type: case BackendKeyData.type:
continue; // ignored continue; // ignored
@ -797,11 +817,18 @@ function wire_impl(
break ready; break ready;
} }
} }
// re-listen previously registered channels
await Promise.all(
channels
.keys()
.map((name) => query(sql`listen ${sql.ident(name)}`).execute())
);
} }
// https://www.postgresql.org/docs/current/sasl-authentication.html#SASL-SCRAM-SHA-256 // https://www.postgresql.org/docs/current/sasl-authentication.html#SASL-SCRAM-SHA-256
// https://datatracker.ietf.org/doc/html/rfc5802 // https://datatracker.ietf.org/doc/html/rfc5802
async function auth_sasl() { async function handle_auth_sasl() {
const bits = 256; const bits = 256;
const hash = `SHA-${bits}`; const hash = `SHA-${bits}`;
const mechanism = `SCRAM-${hash}`; const mechanism = `SCRAM-${hash}`;
@ -858,7 +885,7 @@ function wire_impl(
)}`; )}`;
const client_first_message_bare = `${username},${initial_nonce}`; const client_first_message_bare = `${username},${initial_nonce}`;
const client_first_message = `${gs2_header}${client_first_message_bare}`; const client_first_message = `${gs2_header}${client_first_message_bare}`;
await write(SASLInitialResponse, { mechanism, data: client_first_message }); write(SASLInitialResponse, { mechanism, data: client_first_message });
const server_first_message_str = from_utf8( const server_first_message_str = from_utf8(
(await read(AuthenticationSASLContinue)).data (await read(AuthenticationSASLContinue)).data
@ -877,7 +904,7 @@ function wire_impl(
const client_proof = buf_xor(client_key, client_signature); const client_proof = buf_xor(client_key, client_signature);
const proof = `p=${to_base64(client_proof)}`; const proof = `p=${to_base64(client_proof)}`;
const client_final_message = `${client_final_message_without_proof},${proof}`; const client_final_message = `${client_final_message_without_proof},${proof}`;
await write(SASLResponse, { data: client_final_message }); write(SASLResponse, { data: client_final_message });
const server_key = await hmac(salted_password, "Server Key"); const server_key = await hmac(salted_password, "Server Key");
const server_signature = await hmac(server_key, auth_message); const server_signature = await hmac(server_key, auth_message);
@ -897,29 +924,28 @@ function wire_impl(
readonly name = `__st${st_ids++}`; readonly name = `__st${st_ids++}`;
constructor(readonly query: string) {} constructor(readonly query: string) {}
parse_task: Promise<{ #parse_task: Promise<{
ser_params: ParameterSerializer; ser_params: ParameterSerializer;
Row: RowConstructor; Row: RowConstructor;
}> | null = null; }> | null = null;
parse() { parse() {
return (this.parse_task ??= this.#parse()); return (this.#parse_task ??= this.#parse());
} }
async #parse() { async #parse() {
try { try {
const { name, query } = this; const { name, query } = this;
return await pipeline( return await pipeline(
() => () => {
write(msg_PD, { write(Parse, { statement: name, query, param_types: [] });
P: { statement: name, query, param_types: [] }, write(Describe, { which: "S", name });
D: { which: "S", name }, },
}),
async () => { async () => {
await read(ParseComplete); await read(ParseComplete);
const ser_params = make_param_ser(await read(ParameterDescription)); const ser_params = make_param_ser(await read(ParameterDescription));
const msg = msg_check_err(await read_raw()); const msg = msg_check_err(await read_msg());
const Row = const Row =
msg_type(msg) === NoData.type msg_type(msg) === NoData.type
? EmptyRow ? EmptyRow
@ -929,13 +955,13 @@ function wire_impl(
} }
); );
} catch (e) { } catch (e) {
throw ((this.parse_task = null), e); throw ((this.#parse_task = null), e);
} }
} }
portals = 0; #portals = 0;
portal() { portal() {
return `${this.name}_${this.portals++}`; return `${this.name}_${this.#portals++}`;
} }
} }
@ -944,6 +970,7 @@ function wire_impl(
(params: unknown[]): (string | null)[]; (params: unknown[]): (string | null)[];
} }
// makes function to serialize query parameters
function make_param_ser({ param_types }: ParameterDescription) { function make_param_ser({ param_types }: ParameterDescription) {
return jit.compiled<ParameterSerializer>`function ser_params(xs) { return jit.compiled<ParameterSerializer>`function ser_params(xs) {
return [ return [
@ -960,6 +987,7 @@ function wire_impl(
new (columns: (BinaryLike | null)[]): Row; new (columns: (BinaryLike | null)[]): Row;
} }
// makes function to create Row objects
const EmptyRow = make_row_ctor({ columns: [] }); const EmptyRow = make_row_ctor({ columns: [] });
function make_row_ctor({ columns }: RowDescription) { function make_row_ctor({ columns }: RowDescription) {
const Row = jit.compiled<RowConstructor>`function Row(xs) { const Row = jit.compiled<RowConstructor>`function Row(xs) {
@ -998,7 +1026,7 @@ function wire_impl(
stdout: WritableStream<Uint8Array> | null stdout: WritableStream<Uint8Array> | null
) { ) {
for (let rows = [], i = 0; ; ) { for (let rows = [], i = 0; ; ) {
const msg = msg_check_err(await read_raw()); const msg = msg_check_err(await read_msg());
switch (msg_type(msg)) { switch (msg_type(msg)) {
default: default:
case DataRow.type: case DataRow.type:
@ -1034,7 +1062,7 @@ function wire_impl(
if (stream !== null) { if (stream !== null) {
const writer = stream.getWriter(); const writer = stream.getWriter();
try { try {
for (let msg; msg_type((msg = await read_raw())) !== CopyDone.type; ) { for (let msg; msg_type((msg = await read_msg())) !== CopyDone.type; ) {
const { data } = ser_decode(CopyData, msg_check_err(msg)); const { data } = ser_decode(CopyData, msg_check_err(msg));
await writer.write(to_utf8(data)); await writer.write(to_utf8(data));
} }
@ -1046,7 +1074,7 @@ function wire_impl(
writer.releaseLock(); writer.releaseLock();
} }
} else { } else {
while (msg_type(msg_check_err(await read_raw())) !== CopyDone.type); while (msg_type(msg_check_err(await read_msg())) !== CopyDone.type);
} }
} }
@ -1055,16 +1083,16 @@ function wire_impl(
const reader = stream.getReader(); const reader = stream.getReader();
try { try {
for (let next; !(next = await reader.read()).done; ) for (let next; !(next = await reader.read()).done; )
await write(CopyData, { data: next.value }); write(CopyData, { data: next.value });
await write(CopyDone, {}); write(CopyDone, {});
} catch (e) { } catch (e) {
await write(CopyFail, { cause: String(e) }); write(CopyFail, { cause: String(e) });
throw e; throw e;
} finally { } finally {
reader.releaseLock(); reader.releaseLock();
} }
} else { } else {
await write(CopyDone, {}); write(CopyDone, {});
} }
} }
@ -1076,13 +1104,10 @@ function wire_impl(
log("debug", { query: query }, `executing simple query`); log("debug", { query: query }, `executing simple query`);
const { chunks, err } = await pipeline( const { chunks, err } = await pipeline(
async () => { () => (write(QueryMessage, { query }), write_copy_in(stdin)),
await write(QueryMessage, { query });
return write_copy_in(stdin);
},
async () => { async () => {
for (let chunks = [], err; ; ) { for (let chunks = [], err; ; ) {
const msg = await read_raw(); const msg = await read_msg();
switch (msg_type(msg)) { switch (msg_type(msg)) {
default: default:
case ReadyForQuery.type: case ReadyForQuery.type:
@ -1134,23 +1159,16 @@ function wire_impl(
try { try {
const { rows, tag } = await pipeline( const { rows, tag } = await pipeline(
async () => { async () => {
const B = { write(Bind, {
portal, portal,
statement: st.name, statement: st.name,
param_formats: [], param_formats: [],
param_values, param_values,
column_formats: [], column_formats: [],
}; });
const E = { portal, row_limit: 0 }; write(Execute, { portal, row_limit: 0 });
const C = { which: "P" as const, name: portal }; await write_copy_in(stdin);
write(Close, { which: "P" as const, name: portal });
if (stdin !== null) {
await write(msg_BE, { B, E });
await write_copy_in(stdin);
return write(Close, C);
} else {
return write(msg_BEcC, { B, E, c: {}, C });
}
}, },
async () => { async () => {
await read(BindComplete); await read(BindComplete);
@ -1193,22 +1211,16 @@ function wire_impl(
try { try {
let { done, rows, tag } = await pipeline( let { done, rows, tag } = await pipeline(
async () => { () => {
const B = { write(Bind, {
portal, portal,
statement: st.name, statement: st.name,
param_formats: [], param_formats: [],
param_values, param_values,
column_formats: [], column_formats: [],
}; });
const E = { portal, row_limit: chunk_size }; write(Execute, { portal, row_limit: chunk_size });
return write_copy_in(stdin);
if (stdin !== null) {
await write(msg_BE, { B, E });
return write_copy_in(stdin);
} else {
return write(msg_BEc, { B, E, c: {} });
}
}, },
async () => { async () => {
await read(BindComplete); await read(BindComplete);
@ -1358,7 +1370,7 @@ function wire_impl(
} }
}; };
return { params, auth, query, begin, listen, notify, close }; return { params, connect, query, begin, listen, notify, close };
} }
export interface PoolOptions extends WireOptions { export interface PoolOptions extends WireOptions {
@ -1517,9 +1529,8 @@ function pool_impl(
}; };
async function connect() { async function connect() {
const { host, port } = options; const wire = new PoolWire(options);
const wire = new PoolWire(await socket_connect(host, port), options); await wire.connect(), all.add(wire);
await wire.connected, all.add(wire);
const { connection_id } = wire; const { connection_id } = wire;
return wire return wire
.on("log", (l, c, s) => pool.emit("log", l, { ...c, connection_id }, s)) .on("log", (l, c, s) => pool.emit("log", l, { ...c, connection_id }, s))