diff --git a/wire.ts b/wire.ts index 00d18c2..13744da 100644 --- a/wire.ts +++ b/wire.ts @@ -1,5 +1,6 @@ import { type BinaryLike, + buf_concat, buf_concat_fast, buf_eq, buf_xor, @@ -7,7 +8,9 @@ import { from_base64, from_utf8, jit, + type Receiver, semaphore, + type Sender, to_base64, to_utf8, TypedEmitter, @@ -475,57 +478,38 @@ export interface Channel } export async function wire_connect(options: WireOptions) { - const { host, port } = options; - const wire = new Wire(await socket_connect(host, port), options); - return await wire.connected, wire; -} - -async function socket_connect(hostname: string, port: number) { - if (hostname.startsWith("/")) { - const path = join(hostname, `.s.PGSQL.${port}`); - return await Deno.connect({ transport: "unix", path }); - } else { - const socket = await Deno.connect({ transport: "tcp", hostname, port }); - return socket.setNoDelay(), socket.setKeepAlive(), socket; - } + const wire = new Wire(options); + return await wire.connect(), wire; } export class Wire extends TypedEmitter implements Disposable { - readonly #socket; readonly #params; - readonly #auth; - readonly #connected; + readonly #connect; readonly #query; readonly #begin; readonly #listen; readonly #notify; readonly #close; - get socket() { - return this.#socket; - } - get params() { return this.#params; } - get connected() { - return this.#connected; - } - - constructor(socket: Deno.Conn, options: WireOptions) { + constructor(options: WireOptions) { super(); ({ params: this.#params, - auth: this.#auth, + connect: this.#connect, query: this.#query, begin: this.#begin, listen: this.#listen, notify: this.#notify, close: this.#close, - } = wire_impl(this, socket, options)); - this.#socket = socket; - (this.#connected = this.#auth()).catch(close); + } = wire_impl(this, options)); + } + + connect() { + return this.#connect(); } query(sql: SqlFragment): Query; @@ -579,107 +563,144 @@ export class Wire extends TypedEmitter implements Disposable { } } -const msg_PD = object({ P: Parse, D: Describe }); -const msg_BE = object({ B: Bind, E: Execute }); -const msg_BEc = object({ B: Bind, E: Execute, c: CopyDone }); -const msg_BEcC = object({ B: Bind, E: Execute, c: CopyDone, C: Close }); +async function socket_connect(hostname: string, port: number) { + if (hostname.startsWith("/")) { + const path = join(hostname, `.s.PGSQL.${port}`); + return await Deno.connect({ transport: "unix", path }); + } else { + const socket = await Deno.connect({ transport: "tcp", hostname, port }); + return socket.setNoDelay(), socket.setKeepAlive(), socket; + } +} function wire_impl( wire: Wire, - socket: Deno.Conn, - { user, database, password, runtime_params, types }: WireOptions + { host, port, user, database, password, runtime_params, types }: WireOptions ) { + // current runtime parameters as reported by postgres const params: Parameters = Object.create(null); function log(level: LogLevel, ctx: object, msg: string) { wire.emit("log", level, ctx, msg); } + // wire supports re-connection; socket and read/write channels are null when closed + let socket: Deno.Conn | null = null; + let read_pop: Receiver | null = null; + let write_push: Sender | null = null; + async function read(type: Encoder) { - const msg = await read_recv(); - if (msg === null) throw new WireError(`connection closed`); - else return ser_decode(type, msg_check_err(msg)); + const msg = read_pop !== null ? await read_pop() : null; + if (msg !== null) return ser_decode(type, msg_check_err(msg)); + else throw new WireError(`connection closed`); } - async function read_raw() { - const msg = await read_recv(); - if (msg === null) throw new WireError(`connection closed`); - else return msg; + async function read_msg() { + const msg = read_pop !== null ? await read_pop() : null; + if (msg !== null) return msg; + else throw new WireError(`connection closed`); } - async function* read_socket() { - const buf = new Uint8Array(64 * 1024); - for (let n; (n = await socket.read(buf)) !== null; ) - yield buf.subarray(0, n); + // socket reader channel worker + async function read_socket(socket: Deno.Conn, push: Sender) { + const header_size = 5; + const read_buf = new Uint8Array(64 * 1024); // shared buffer for all socket reads + let buf = new Uint8Array(); // concatenated messages read so far + + for (let read; (read = await socket.read(read_buf)) !== null; ) { + buf = buf_concat_fast(buf, read_buf.subarray(0, read)); // push read bytes to buf + while (buf.length >= header_size) { + const size = ser_decode(Header, buf).length + 1; + if (buf.length < size) break; + const msg = buf.subarray(0, size); // shift one message from buf + buf = buf.subarray(size); + if (!handle_msg(msg)) push(msg); + } + } + + // there should be nothing left in buf if we gracefully exited + if (buf.length !== 0) throw new WireError(`unexpected end of stream`); } - const read_recv = channel.receiver(async function read(send) { - let err: unknown; - try { - let buf = new Uint8Array(); - for await (const chunk of read_socket()) { - buf = buf_concat_fast(buf, chunk); - - for (let n; (n = ser_decode(Header, buf).length + 1) <= buf.length; ) { - const msg = buf.subarray(0, n); - buf = buf.subarray(n); - - switch (msg_type(msg)) { - // https://www.postgresql.org/docs/current/protocol-flow.html#PROTOCOL-ASYNC - case NoticeResponse.type: { - const { fields } = ser_decode(NoticeResponse, msg); - const notice = new PostgresError(fields); - log(severity_level(notice.severity), notice, notice.message); - wire.emit("notice", notice); - continue; - } - - case ParameterStatus.type: { - const { name, value } = ser_decode(ParameterStatus, msg); - const prev = params[name] ?? null; - Object.defineProperty(params, name, { - configurable: true, - enumerable: true, - value, - }); - wire.emit("parameter", name, value, prev); - continue; - } - - case NotificationResponse.type: { - const { channel, payload, process_id } = ser_decode( - NotificationResponse, - msg - ); - wire.emit("notify", channel, payload, process_id); - channels.get(channel)?.emit("notify", payload, process_id); - continue; - } - } - - send(msg); - } + function handle_msg(msg: Uint8Array) { + switch (msg_type(msg)) { + // https://www.postgresql.org/docs/current/protocol-flow.html#PROTOCOL-ASYNC + case NoticeResponse.type: { + const { fields } = ser_decode(NoticeResponse, msg); + const notice = new PostgresError(fields); + log(severity_level(notice.severity), notice, notice.message); + wire.emit("notice", notice); + return true; } - if (buf.length !== 0) throw new WireError(`unexpected end of stream`); - } catch (e) { - throw ((err = e), e); - } finally { - wire.emit("close", err); - } - }); + case ParameterStatus.type: { + const { name, value } = ser_decode(ParameterStatus, msg); + const prev = params[name] ?? null; + Object.defineProperty(params, name, { + configurable: true, + enumerable: true, + value, + }); + wire.emit("parameter", name, value, prev); + return true; + } - function write(type: Encoder, value: T) { - return write_raw(ser_encode(type, value)); + case NotificationResponse.type: { + const { channel, payload, process_id } = ser_decode( + NotificationResponse, + msg + ); + wire.emit("notify", channel, payload, process_id); + channels.get(channel)?.emit("notify", payload, process_id); + return true; + } + } + + return false; } - async function write_raw(buf: Uint8Array) { - for (let i = 0, n = buf.length; i < n; ) - i += await socket.write(buf.subarray(i)); + function write(type: Encoder, value: T) { + write_msg(ser_encode(type, value)); + } + + function write_msg(buf: Uint8Array) { + if (write_push !== null) write_push(buf); + else throw new WireError(`connection closed`); + } + + // socket writer channel worker + async function write_socket(socket: Deno.Conn, pop: Receiver) { + for (let buf; (buf = await pop()) !== null; ) { + const bufs = [buf]; // proactively dequeue more queued msgs synchronously, if any + for (let i = 1, buf; (buf = pop.try()) !== null; ) bufs[i++] = buf; + if (bufs.length !== 1) buf = buf_concat(bufs); // write queued msgs concatenated, reduce write syscalls + for (let i = 0, n = buf.length; i < n; ) + i += await socket.write(buf.subarray(i)); + } + } + + async function connect() { + using _rlock = await rlock(); + using _wlock = await wlock(); + close(new WireError(`reconnecting`)); + try { + const s = (socket = await socket_connect(host, port)); + read_pop = channel.receiver((push) => read_socket(s, push)); + write_push = channel.sender((pop) => write_socket(s, pop)); + await handle_auth(); // run auth with rw lock + } catch (e) { + throw (close(e), e); + } } function close(reason?: unknown) { - socket.close(), read_recv.close(reason); + socket?.close(), (socket = null); + read_pop?.close(reason), (read_pop = null); + write_push?.close(reason), (write_push = null); + for (const name of Object.keys(params)) + delete (params as Record)[name]; + st_cache.clear(), (st_ids = 0); + (tx_status = "I"), (tx_stack.length = 0); } // https://www.postgresql.org/docs/current/protocol-flow.html#PROTOCOL-FLOW-PIPELINING @@ -697,13 +718,13 @@ function wire_impl( } async function pipeline_read(r: () => T | PromiseLike) { - using _rlock = await rlock(); + using _lock = await rlock(); try { return await r(); } finally { try { let msg; - while (msg_type((msg = await read_raw())) !== ReadyForQuery.type); + while (msg_type((msg = await read_msg())) !== ReadyForQuery.type); ({ tx_status } = ser_decode(ReadyForQuery, msg)); } catch { // ignored @@ -712,12 +733,12 @@ function wire_impl( } async function pipeline_write(w: () => T | PromiseLike) { - using _wlock = await wlock(); + using _lock = await wlock(); try { return await w(); } finally { try { - await write(Sync, {}); + write(Sync, {}); } catch { // ignored } @@ -725,11 +746,9 @@ function wire_impl( } // https://www.postgresql.org/docs/current/protocol-flow.html#PROTOCOL-FLOW-START-UP - async function auth() { - using _rlock = await rlock(); - using _wlock = await wlock(); - - await write(StartupMessage, { + async function handle_auth() { + // always run within rw lock (see connect()) + write(StartupMessage, { version: 196608, params: { application_name: "pglue", @@ -744,7 +763,7 @@ function wire_impl( }); auth: for (;;) { - const msg = msg_check_err(await read_raw()); + const msg = msg_check_err(await read_msg()); switch (msg_type(msg)) { case NegotiateProtocolVersion.type: { const { bad_options } = ser_decode(NegotiateProtocolVersion, msg); @@ -762,7 +781,7 @@ function wire_impl( throw new WireError(`kerberos authentication is deprecated`); case 3: // AuthenticationCleartextPassword - await write(PasswordMessage, { password }); + write(PasswordMessage, { password }); continue; case 5: // AuthenticationMD5Password @@ -778,7 +797,7 @@ function wire_impl( // AuthenticationSASL case 10: - await auth_sasl(); + await handle_auth_sasl(); continue; default: @@ -786,8 +805,9 @@ function wire_impl( } } + // wait for ready ready: for (;;) { - const msg = msg_check_err(await read_raw()); + const msg = msg_check_err(await read_msg()); switch (msg_type(msg)) { case BackendKeyData.type: continue; // ignored @@ -797,11 +817,18 @@ function wire_impl( break ready; } } + + // re-listen previously registered channels + await Promise.all( + channels + .keys() + .map((name) => query(sql`listen ${sql.ident(name)}`).execute()) + ); } // https://www.postgresql.org/docs/current/sasl-authentication.html#SASL-SCRAM-SHA-256 // https://datatracker.ietf.org/doc/html/rfc5802 - async function auth_sasl() { + async function handle_auth_sasl() { const bits = 256; const hash = `SHA-${bits}`; const mechanism = `SCRAM-${hash}`; @@ -858,7 +885,7 @@ function wire_impl( )}`; const client_first_message_bare = `${username},${initial_nonce}`; const client_first_message = `${gs2_header}${client_first_message_bare}`; - await write(SASLInitialResponse, { mechanism, data: client_first_message }); + write(SASLInitialResponse, { mechanism, data: client_first_message }); const server_first_message_str = from_utf8( (await read(AuthenticationSASLContinue)).data @@ -877,7 +904,7 @@ function wire_impl( const client_proof = buf_xor(client_key, client_signature); const proof = `p=${to_base64(client_proof)}`; const client_final_message = `${client_final_message_without_proof},${proof}`; - await write(SASLResponse, { data: client_final_message }); + write(SASLResponse, { data: client_final_message }); const server_key = await hmac(salted_password, "Server Key"); const server_signature = await hmac(server_key, auth_message); @@ -897,29 +924,28 @@ function wire_impl( readonly name = `__st${st_ids++}`; constructor(readonly query: string) {} - parse_task: Promise<{ + #parse_task: Promise<{ ser_params: ParameterSerializer; Row: RowConstructor; }> | null = null; parse() { - return (this.parse_task ??= this.#parse()); + return (this.#parse_task ??= this.#parse()); } async #parse() { try { const { name, query } = this; return await pipeline( - () => - write(msg_PD, { - P: { statement: name, query, param_types: [] }, - D: { which: "S", name }, - }), + () => { + write(Parse, { statement: name, query, param_types: [] }); + write(Describe, { which: "S", name }); + }, async () => { await read(ParseComplete); const ser_params = make_param_ser(await read(ParameterDescription)); - const msg = msg_check_err(await read_raw()); + const msg = msg_check_err(await read_msg()); const Row = msg_type(msg) === NoData.type ? EmptyRow @@ -929,13 +955,13 @@ function wire_impl( } ); } catch (e) { - throw ((this.parse_task = null), e); + throw ((this.#parse_task = null), e); } } - portals = 0; + #portals = 0; portal() { - return `${this.name}_${this.portals++}`; + return `${this.name}_${this.#portals++}`; } } @@ -944,6 +970,7 @@ function wire_impl( (params: unknown[]): (string | null)[]; } + // makes function to serialize query parameters function make_param_ser({ param_types }: ParameterDescription) { return jit.compiled`function ser_params(xs) { return [ @@ -960,6 +987,7 @@ function wire_impl( new (columns: (BinaryLike | null)[]): Row; } + // makes function to create Row objects const EmptyRow = make_row_ctor({ columns: [] }); function make_row_ctor({ columns }: RowDescription) { const Row = jit.compiled`function Row(xs) { @@ -998,7 +1026,7 @@ function wire_impl( stdout: WritableStream | null ) { for (let rows = [], i = 0; ; ) { - const msg = msg_check_err(await read_raw()); + const msg = msg_check_err(await read_msg()); switch (msg_type(msg)) { default: case DataRow.type: @@ -1034,7 +1062,7 @@ function wire_impl( if (stream !== null) { const writer = stream.getWriter(); try { - for (let msg; msg_type((msg = await read_raw())) !== CopyDone.type; ) { + for (let msg; msg_type((msg = await read_msg())) !== CopyDone.type; ) { const { data } = ser_decode(CopyData, msg_check_err(msg)); await writer.write(to_utf8(data)); } @@ -1046,7 +1074,7 @@ function wire_impl( writer.releaseLock(); } } else { - while (msg_type(msg_check_err(await read_raw())) !== CopyDone.type); + while (msg_type(msg_check_err(await read_msg())) !== CopyDone.type); } } @@ -1055,16 +1083,16 @@ function wire_impl( const reader = stream.getReader(); try { for (let next; !(next = await reader.read()).done; ) - await write(CopyData, { data: next.value }); - await write(CopyDone, {}); + write(CopyData, { data: next.value }); + write(CopyDone, {}); } catch (e) { - await write(CopyFail, { cause: String(e) }); + write(CopyFail, { cause: String(e) }); throw e; } finally { reader.releaseLock(); } } else { - await write(CopyDone, {}); + write(CopyDone, {}); } } @@ -1076,13 +1104,10 @@ function wire_impl( log("debug", { query: query }, `executing simple query`); const { chunks, err } = await pipeline( - async () => { - await write(QueryMessage, { query }); - return write_copy_in(stdin); - }, + () => (write(QueryMessage, { query }), write_copy_in(stdin)), async () => { for (let chunks = [], err; ; ) { - const msg = await read_raw(); + const msg = await read_msg(); switch (msg_type(msg)) { default: case ReadyForQuery.type: @@ -1134,23 +1159,16 @@ function wire_impl( try { const { rows, tag } = await pipeline( async () => { - const B = { + write(Bind, { portal, statement: st.name, param_formats: [], param_values, column_formats: [], - }; - const E = { portal, row_limit: 0 }; - const C = { which: "P" as const, name: portal }; - - if (stdin !== null) { - await write(msg_BE, { B, E }); - await write_copy_in(stdin); - return write(Close, C); - } else { - return write(msg_BEcC, { B, E, c: {}, C }); - } + }); + write(Execute, { portal, row_limit: 0 }); + await write_copy_in(stdin); + write(Close, { which: "P" as const, name: portal }); }, async () => { await read(BindComplete); @@ -1193,22 +1211,16 @@ function wire_impl( try { let { done, rows, tag } = await pipeline( - async () => { - const B = { + () => { + write(Bind, { portal, statement: st.name, param_formats: [], param_values, column_formats: [], - }; - const E = { portal, row_limit: chunk_size }; - - if (stdin !== null) { - await write(msg_BE, { B, E }); - return write_copy_in(stdin); - } else { - return write(msg_BEc, { B, E, c: {} }); - } + }); + write(Execute, { portal, row_limit: chunk_size }); + return write_copy_in(stdin); }, async () => { await read(BindComplete); @@ -1358,7 +1370,7 @@ function wire_impl( } }; - return { params, auth, query, begin, listen, notify, close }; + return { params, connect, query, begin, listen, notify, close }; } export interface PoolOptions extends WireOptions { @@ -1517,9 +1529,8 @@ function pool_impl( }; async function connect() { - const { host, port } = options; - const wire = new PoolWire(await socket_connect(host, port), options); - await wire.connected, all.add(wire); + const wire = new PoolWire(options); + await wire.connect(), all.add(wire); const { connection_id } = wire; return wire .on("log", (l, c, s) => pool.emit("log", l, { ...c, connection_id }, s))