pglue/wire.ts

1573 lines
41 KiB
TypeScript
Raw Normal View History

2025-01-07 22:12:30 +11:00
import {
type BinaryLike,
2025-01-11 06:02:32 +11:00
buf_concat,
2025-01-07 22:12:30 +11:00
buf_concat_fast,
buf_eq,
buf_xor,
channel,
from_base64,
from_utf8,
jit,
2025-01-11 06:02:32 +11:00
type Receiver,
2025-01-07 22:12:30 +11:00
semaphore,
2025-01-11 06:02:32 +11:00
type Sender,
2025-01-07 22:12:30 +11:00
to_base64,
to_utf8,
TypedEmitter,
} from "./lstd.ts";
import {
array,
byten,
2025-01-10 04:25:50 +11:00
bytes_lp,
bytes,
2025-01-07 22:12:30 +11:00
char,
cstring,
type Encoder,
object,
type ObjectEncoder,
type ObjectShape,
oneof,
ser_decode,
ser_encode,
2025-01-09 04:37:30 +11:00
i16,
2025-01-07 22:12:30 +11:00
i32,
2025-01-09 04:37:30 +11:00
i8,
type EncoderType,
2025-01-07 22:12:30 +11:00
} from "./ser.ts";
import {
type CommandResult,
2025-01-11 02:00:05 +11:00
format,
is_sql,
2025-01-07 22:12:30 +11:00
Query,
type ResultStream,
type Row,
sql,
type SqlFragment,
type SqlTypeMap,
text,
2025-01-07 22:12:30 +11:00
} from "./query.ts";
import { join } from "jsr:@std/path@^1.0.8";
/**
 * Base class for every error raised by the wire protocol layer.
 *
 * `name` is resolved from the runtime constructor, so subclasses (such as
 * `PostgresError`) report their own class name without overriding it.
 */
export class WireError extends Error {
  override get name() {
    const ctor = this.constructor;
    return ctor.name;
  }
}
/**
 * An error or notice reported by the server, built from the single-letter field codes of an
 * ErrorResponse/NoticeResponse message.
 *
 * Field codes: https://www.postgresql.org/docs/current/protocol-error-fields.html#PROTOCOL-ERROR-FIELDS
 * Optional fields that the server did not send are `null`.
 */
export class PostgresError extends WireError {
  readonly severity: string;
  readonly code: string;
  readonly detail: string | null;
  readonly hint: string | null;
  readonly position: string | null;
  readonly where: string | null;
  readonly schema: string | null;
  readonly table: string | null;
  readonly column: string | null;
  readonly data_type: string | null;
  readonly constraint: string | null;
  readonly file: string | null;
  readonly line: number | null;
  readonly routine: string | null;

  constructor(fields: Partial<Record<string, string>>) {
    super(fields.M ?? "unknown error");
    const opt = (key: string) => fields[key] ?? null;
    // V (when present) takes precedence over S
    this.severity = fields.V ?? fields.S ?? "ERROR";
    this.code = fields.C ?? "XX000";
    this.detail = opt("D");
    this.hint = opt("H");
    this.position = opt("P");
    this.where = opt("W");
    this.schema = opt("s");
    this.table = opt("t");
    this.column = opt("c");
    this.data_type = opt("d");
    this.constraint = opt("n");
    this.file = opt("F");
    // an empty or missing line field becomes null
    const line = fields.L;
    this.line = line ? parseInt(line, 10) : null;
    this.routine = opt("R");
  }
}
/**
 * Maps a server severity string to a log level.
 * Unrecognised severities (as well as LOG, INFO and NOTICE) map to "info".
 */
function severity_level(s: string): LogLevel {
  if (s === "DEBUG") return "debug";
  if (s === "WARNING") return "warn";
  if (s === "ERROR") return "error";
  if (s === "FATAL" || s === "PANIC") return "fatal";
  return "info";
}
/** An object encoder for one protocol message, tagged with its message type character. */
interface MessageEncoder<T extends string, S extends ObjectShape>
  extends ObjectEncoder<S> {
  readonly type: T;
}

/**
 * Builds the encoder for one protocol message: an optional one-byte type tag, an int32
 * length (counting itself but not the tag), then the fields described by `shape`.
 *
 * An empty `type` produces an untagged message (used by StartupMessage and CancelRequest).
 */
function msg<T extends string, S extends ObjectShape>(
  type: T,
  shape: S
): MessageEncoder<T, S> {
  const tagged = type !== "";
  const tag = tagged ? oneof(char(i8), type) : null;
  const header_size = tagged ? 5 : 4;
  const body = object(shape);
  return {
    const_size: null,
    get type() {
      return type;
    },
    allocs(value) {
      return body.allocs(value) + header_size;
    },
    encode(buf, cur, value) {
      if (tag !== null) tag.encode(buf, cur, type);
      // reserve the length slot, write the body, then backfill the length
      const start = cur.i;
      cur.i = start + 4;
      body.encode(buf, cur, value);
      i32.encode(buf, { i: start }, cur.i - start);
    },
    decode(buf, cur) {
      if (tag !== null) tag.decode(buf, cur);
      const body_len = i32.decode(buf, cur) - 4;
      const start = cur.i;
      cur.i += body_len;
      return body.decode(buf.subarray(start, cur.i), { i: 0 });
    },
  };
}
/** Returns the type character of a raw message, or `""` when its first byte is zero. */
function msg_type(msg: Uint8Array) {
  const code = msg[0];
  if (code === 0) return "";
  return String.fromCharCode(code);
}
/**
 * Passes `msg` through unchanged unless it is an ErrorResponse.
 * @throws PostgresError built from the ErrorResponse fields
 */
function msg_check_err(msg: Uint8Array) {
  if (msg_type(msg) !== ErrorResponse.type) return msg;
  throw new PostgresError(ser_decode(ErrorResponse, msg).fields);
}
// https://www.postgresql.org/docs/current/protocol-message-formats.html#PROTOCOL-MESSAGE-FORMATS

/** Common message prefix: a type byte, then an int32 length that counts itself but not the type byte. */
export const Header = object({ type: char(i8), length: i32 });

// Authentication messages all use tag "R" and are told apart by their int32 status.
// `Authentication` decodes just the status so a caller can dispatch on it.
export const Authentication = msg("R", { status: i32 });
export const AuthenticationOk = msg("R", { status: oneof(i32, 0 as const) });
export const AuthenticationKerberosV5 = msg("R", {
  status: oneof(i32, 2 as const),
});
export const AuthenticationCleartextPassword = msg("R", {
  status: oneof(i32, 3 as const),
});
export const AuthenticationMD5Password = msg("R", {
  status: oneof(i32, 5 as const),
  salt: byten(4),
});
export const AuthenticationGSS = msg("R", { status: oneof(i32, 7 as const) });
export const AuthenticationGSSContinue = msg("R", {
  status: oneof(i32, 8 as const),
  data: bytes,
});
export const AuthenticationSSPI = msg("R", { status: oneof(i32, 9 as const) });
export const AuthenticationSASL = msg("R", {
  status: oneof(i32, 10 as const),
  // mechanism names, each a cstring, terminated by an empty cstring
  mechanisms: {
    const_size: null,
    allocs(names) {
      // the initial 1 accounts for the empty terminator
      return names.reduce((size, name) => size + cstring.allocs(name), 1);
    },
    encode(buf, cur, names) {
      for (const name of [...names, ""]) cstring.encode(buf, cur, name);
    },
    decode(buf, cur) {
      const names: string[] = [];
      for (;;) {
        const name = cstring.decode(buf, cur);
        if (name === "") return names;
        names.push(name);
      }
    },
  } satisfies Encoder<string[]>,
});
export const AuthenticationSASLContinue = msg("R", {
  status: oneof(i32, 11 as const),
  data: bytes,
});
export const AuthenticationSASLFinal = msg("R", {
  status: oneof(i32, 12 as const),
  data: bytes,
});

/** Process id and secret key identifying this session (the same pair CancelRequest carries). */
export const BackendKeyData = msg("K", { process_id: i32, secret_key: i32 });
/** Binds parameter values to a prepared statement, creating a portal. */
export const Bind = msg("B", {
  portal: cstring,
  statement: cstring,
  param_formats: array(i16, i16),
  param_values: array(i16, bytes_lp),
  column_formats: array(i16, i16),
});
export const BindComplete = msg("2", {});

/** Untagged request, identified by its fixed leading code 80877102. */
export const CancelRequest = msg("", {
  code: oneof(i32, 80877102 as const),
  process_id: i32,
  secret_key: i32,
});

/** Closes a prepared statement ("S") or a portal ("P") by name. */
export const Close = msg("C", {
  which: oneof(char(i8), "S" as const, "P" as const),
  name: cstring,
});
export const CloseComplete = msg("3", {});
export const CommandComplete = msg("C", { tag: cstring });

// COPY sub-protocol
export const CopyData = msg("d", { data: bytes });
export const CopyDone = msg("c", {});
export const CopyFail = msg("f", { cause: cstring });
export const CopyInResponse = msg("G", {
  format: i8,
  column_formats: array(i16, i16),
});
export const CopyOutResponse = msg("H", {
  format: i8,
  column_formats: array(i16, i16),
});
export const CopyBothResponse = msg("W", {
  format: i8,
  column_formats: array(i16, i16),
});

/** One result row; each column value is length-prefixed (or null). */
export const DataRow = msg("D", { column_values: array(i16, bytes_lp) });

/** Requests the description of a prepared statement ("S") or a portal ("P"). */
export const Describe = msg("D", {
  which: oneof(char(i8), "S" as const, "P" as const),
  name: cstring,
});
export const EmptyQueryResponse = msg("I", {});
// each field is a one-character code followed by a cstring value
const err_field = char(i8);

/**
 * Field list shared by ErrorResponse and NoticeResponse: (code, value) pairs terminated by
 * an empty code.
 */
const err_fields: Encoder<Record<string, string>> = {
  const_size: null,
  allocs(fields) {
    // the initial 1 accounts for the terminator
    return Object.entries(fields).reduce(
      (size, [code, value]) =>
        size + err_field.allocs(code) + cstring.allocs(value),
      1
    );
  },
  encode(buf, cur, fields) {
    for (const [code, value] of Object.entries(fields)) {
      err_field.encode(buf, cur, code);
      cstring.encode(buf, cur, value);
    }
    err_field.encode(buf, cur, "");
  },
  decode(buf, cur) {
    const fields: Record<string, string> = {};
    for (;;) {
      const code = err_field.decode(buf, cur);
      if (code === "") return fields;
      fields[code] = cstring.decode(buf, cur);
    }
  },
};

export const ErrorResponse = msg("E", { fields: err_fields });
/** Executes a bound portal; a `row_limit` of 0 means no limit. */
export const Execute = msg("E", { portal: cstring, row_limit: i32 });
export const Flush = msg("H", {});
export const FunctionCall = msg("F", {
  oid: i32,
  arg_formats: array(i16, i16),
  arg_values: array(i16, bytes_lp),
  result_format: i16,
});
export const FunctionCallResponse = msg("V", { result_value: bytes_lp });
export const NegotiateProtocolVersion = msg("v", {
  minor_ver: i32,
  bad_options: array(i32, cstring),
});
export const NoData = msg("n", {});
export const NoticeResponse = msg("N", { fields: err_fields });
export const NotificationResponse = msg("A", {
  process_id: i32,
  channel: cstring,
  payload: cstring,
});
export const ParameterDescription = msg("t", { param_types: array(i16, i32) });
export const ParameterStatus = msg("S", { name: cstring, value: cstring });
export const Parse = msg("P", {
  statement: cstring,
  query: cstring,
  param_types: array(i16, i32),
});
export const ParseComplete = msg("1", {});
export const PasswordMessage = msg("p", { password: cstring });
export const PortalSuspended = msg("s", {});
export const QueryMessage = msg("Q", { query: cstring });

/** Sent when the server is ready for a new query, with the current transaction status. */
export const ReadyForQuery = msg("Z", {
  tx_status: oneof(char(i8), "I" as const, "T" as const, "E" as const),
});

/** Column metadata for a result set. */
export const RowDescription = msg("T", {
  columns: array(
    i16,
    object({
      name: cstring,
      table_oid: i32,
      table_column: i16,
      type_oid: i32,
      type_size: i16,
      type_modifier: i32,
      format: i16,
    })
  ),
});

export const SASLInitialResponse = msg("p", {
  mechanism: cstring,
  data: bytes_lp,
});
export const SASLResponse = msg("p", { data: bytes });

/** Untagged first message of a session: protocol version plus name/value startup parameters. */
export const StartupMessage = msg("", {
  version: oneof(i32, 196608 as const),
  // name/value cstring pairs terminated by a single zero byte
  params: {
    const_size: null,
    allocs(record) {
      // the initial 1 accounts for the terminating zero byte
      return Object.entries(record).reduce(
        (size, [key, value]) =>
          size + cstring.allocs(key) + cstring.allocs(value),
        1
      );
    },
    encode(buf, cur, record) {
      for (const [key, value] of Object.entries(record)) {
        cstring.encode(buf, cur, key);
        cstring.encode(buf, cur, value);
      }
      i8.encode(buf, cur, 0);
    },
    decode(buf, cur) {
      const record: Record<string, string> = {};
      for (;;) {
        const key = cstring.decode(buf, cur);
        if (key === "") return record;
        record[key] = cstring.decode(buf, cur);
      }
    },
  } satisfies Encoder<Record<string, string>>,
});
export const Sync = msg("S", {});
export const Terminate = msg("X", {});
/** Severity of a "log" event. */
export type LogLevel = "debug" | "info" | "warn" | "error" | "fatal";

/** Run-time parameters reported by the server, keyed by parameter name. */
export interface Parameters extends Readonly<Partial<Record<string, string>>> {}

/** Connection settings for a {@link Wire}. */
export interface WireOptions {
  /** Server hostname, or (when it starts with "/") the directory containing the unix socket. */
  readonly host: string;
  /** Server port; for unix sockets it also selects the socket file `.s.PGSQL.<port>`. */
  readonly port: number;
  /** Role to connect as. */
  readonly user: string;
  /** Password used for cleartext and SCRAM-SHA-256 authentication. */
  readonly password: string;
  /** Database to connect to; `null` means the database named after `user`. */
  readonly database: string | null;
  /**
   * Extra startup parameters. `user`, `database`, `bytea_output`, `client_encoding` and
   * `DateStyle` are always set by the client and cannot be overridden here.
   */
  readonly runtime_params: Record<string, string>;
  /** Codecs keyed by type oid; types without an entry are handled as text. */
  readonly types: SqlTypeMap;
}

/** Events emitted by a {@link Wire}. */
export type WireEvents = {
  log(level: LogLevel, ctx: object, msg: string): void;
  /** A NoticeResponse sent by the server. */
  notice(notice: PostgresError): void;
  /** A ParameterStatus update; `prev` is null if the parameter was not known before. */
  parameter(name: string, value: string, prev: string | null): void;
  /** A NotificationResponse for any channel. */
  notify(channel: string, payload: string, process_id: number): void;
  // presumably emitted when the wire is closed — TODO confirm against the rest of the file
  close(reason?: unknown): void;
};

/** An open transaction; `open` becomes false once it has been committed or rolled back. */
export interface Transaction extends CommandResult, AsyncDisposable {
  readonly open: boolean;
  commit(): Promise<CommandResult>;
  rollback(): Promise<CommandResult>;
}

export type ChannelEvents = { notify: NotificationHandler };
export type NotificationHandler = (payload: string, process_id: number) => void;

/** A LISTEN subscription to one notification channel. */
export interface Channel
  extends TypedEmitter<ChannelEvents>,
    CommandResult,
    AsyncDisposable {
  readonly name: string;
  readonly open: boolean;
  notify(payload: string): Promise<CommandResult>;
  unlisten(): Promise<CommandResult>;
}
/** Creates a {@link Wire} for `options` and resolves with it once `connect()` has completed. */
export async function wire_connect(options: WireOptions) {
  const wire = new Wire(options);
  await wire.connect();
  return wire;
}
/**
 * A PostgreSQL connection speaking the frontend/backend wire protocol.
 *
 * All protocol state lives in the closure created by `wire_impl`; this class exposes it
 * through a typed event emitter and a small query API.
 */
export class Wire extends TypedEmitter<WireEvents> implements Disposable {
  readonly #params;
  readonly #connect;
  readonly #query;
  readonly #begin;
  readonly #listen;
  readonly #notify;
  readonly #close;

  /** Run-time parameters reported by the server for the current session. */
  get params() {
    return this.#params;
  }

  constructor(options: WireOptions) {
    super();
    ({
      params: this.#params,
      connect: this.#connect,
      query: this.#query,
      begin: this.#begin,
      listen: this.#listen,
      notify: this.#notify,
      close: this.#close,
    } = wire_impl(this, options));
  }

  /** Opens (or re-opens) the connection using the routine provided by `wire_impl`. */
  connect() {
    return this.#connect();
  }

  /** Builds a query from a prepared fragment or from a tagged template with interpolated parameters. */
  query<T = Row>(sql: SqlFragment): Query<T>;
  query<T = Row>(s: TemplateStringsArray, ...xs: unknown[]): Query<T>;
  query(s: TemplateStringsArray | SqlFragment, ...xs: unknown[]) {
    return this.#query(is_sql(s) ? s : sql(s, ...xs));
  }

  /**
   * Starts a transaction.
   *
   * Without `f`, resolves with the transaction for the caller to manage. With `f`, runs it
   * inside the transaction, commits if the transaction is still open when `f` resolves, and
   * disposes the transaction on exit (including when `f` throws).
   */
  begin(): Promise<Transaction>;
  begin<T>(f: (wire: this, tx: Transaction) => T | PromiseLike<T>): Promise<T>;
  async begin(f?: (wire: this, tx: Transaction) => unknown) {
    if (typeof f !== "undefined") {
      await using tx = await this.#begin();
      const value = await f(this, tx);
      if (tx.open) await tx.commit();
      return value;
    } else {
      return this.#begin();
    }
  }

  /** Subscribes to `channel` and attaches each handler to its "notify" event. */
  async listen(channel: string, ...fs: NotificationHandler[]) {
    const ch = await this.#listen(channel);
    for (const f of fs) ch.on("notify", f);
    return ch;
  }

  notify(channel: string, payload: string) {
    return this.#notify(channel, payload);
  }

  /**
   * Reads a run-time parameter via `current_setting(param, true)`.
   *
   * With `missing_ok = true` the server returns SQL NULL for an unrecognised parameter;
   * that is reported as `null` (previously `String(null)` turned it into the string `"null"`).
   */
  async get(param: string) {
    return (
      await this.query`select current_setting(${param}, true)`
        .map(([s]) => (s === null ? null : String(s)))
        .first_or(null)
    )[0];
  }

  /** Sets a run-time parameter via `set_config`; `local` limits it to the current transaction. */
  async set(param: string, value: string, local = false) {
    return await this
      .query`select set_config(${param}, ${value}, ${local})`.execute();
  }

  close(reason?: unknown) {
    this.#close(reason);
  }

  [Symbol.dispose]() {
    this.close();
  }
}
2025-01-11 06:02:32 +11:00
/**
 * Opens the transport to the server.
 *
 * A `hostname` starting with "/" is treated as the directory holding the server's
 * unix-domain socket (`.s.PGSQL.<port>`). Anything else is a TCP host, for which
 * no-delay and keep-alive are enabled.
 */
async function socket_connect(hostname: string, port: number) {
  if (!hostname.startsWith("/")) {
    const tcp = await Deno.connect({ transport: "tcp", hostname, port });
    tcp.setNoDelay();
    tcp.setKeepAlive();
    return tcp;
  }
  const path = join(hostname, `.s.PGSQL.${port}`);
  return await Deno.connect({ transport: "unix", path });
}
2025-01-11 00:46:17 +11:00
2025-01-07 22:12:30 +11:00
function wire_impl(
wire: Wire,
2025-01-11 06:02:32 +11:00
{ host, port, user, database, password, runtime_params, types }: WireOptions
2025-01-07 22:12:30 +11:00
) {
2025-01-11 06:02:32 +11:00
// run-time parameters as most recently reported by the server (ParameterStatus); emptied by close()
const params: Parameters = Object.create(null);
/** Forwards a log record to the owning wire's "log" event. */
function log(level: LogLevel, ctx: object, msg: string): void {
  const event = "log";
  wire.emit(event, level, ctx, msg);
}
2025-01-11 06:02:32 +11:00
// The wire supports reconnecting: connect() installs a new socket and read/write channels,
// and close() resets all three to null.
let socket: Deno.Conn | null = null;
let read_pop: Receiver<Uint8Array> | null = null;
let write_push: Sender<Uint8Array> | null = null;

/**
 * Reads the next message and decodes it as `type`.
 * @throws PostgresError if the message is an ErrorResponse
 * @throws WireError if there is no open connection
 */
async function read<T>(type: Encoder<T>) {
  return ser_decode(type, msg_check_err(await read_msg()));
}

/**
 * Waits for the next raw message that handle_msg did not consume.
 * @throws WireError if there is no read channel or it yields null
 */
async function read_msg() {
  const msg = read_pop === null ? null : await read_pop();
  if (msg === null) throw new WireError(`connection closed`);
  return msg;
}
2025-01-11 06:02:32 +11:00
/**
 * Socket reader worker: splits the byte stream into complete messages and pushes every
 * message that handle_msg does not consume onto the read channel.
 * @throws WireError if the stream ends in the middle of a message
 */
async function read_socket(socket: Deno.Conn, push: Sender<Uint8Array>) {
  const header_size = 5;
  const chunk = new Uint8Array(64 * 1024); // reused for every socket read
  let pending = new Uint8Array(); // bytes received but not yet split into messages
  for (;;) {
    const n = await socket.read(chunk);
    if (n === null) break;
    pending = buf_concat_fast(pending, chunk.subarray(0, n));
    while (pending.length >= header_size) {
      // Header.length excludes the type byte, hence the + 1
      const size = ser_decode(Header, pending).length + 1;
      if (pending.length < size) break;
      const msg = pending.subarray(0, size);
      pending = pending.subarray(size);
      if (!handle_msg(msg)) push(msg);
    }
  }
  // a clean end of stream leaves no partial message behind
  if (pending.length !== 0) throw new WireError(`unexpected end of stream`);
}
2025-01-11 06:02:32 +11:00
/**
 * Handles asynchronous server messages (notices, parameter changes, notifications) that may
 * arrive at any time. Returns true if `msg` was consumed here and must not be queued for
 * readers.
 *
 * https://www.postgresql.org/docs/current/protocol-flow.html#PROTOCOL-ASYNC
 */
function handle_msg(msg: Uint8Array) {
  const type = msg_type(msg);
  if (type === NoticeResponse.type) {
    const notice = new PostgresError(ser_decode(NoticeResponse, msg).fields);
    log(severity_level(notice.severity), notice, notice.message);
    wire.emit("notice", notice);
  } else if (type === ParameterStatus.type) {
    const { name, value } = ser_decode(ParameterStatus, msg);
    const prev = params[name] ?? null;
    // params is typed read-only, so the new value is installed with defineProperty
    Object.defineProperty(params, name, {
      configurable: true,
      enumerable: true,
      value,
    });
    wire.emit("parameter", name, value, prev);
  } else if (type === NotificationResponse.type) {
    const note = ser_decode(NotificationResponse, msg);
    wire.emit("notify", note.channel, note.payload, note.process_id);
    channels.get(note.channel)?.emit("notify", note.payload, note.process_id);
  } else {
    return false;
  }
  return true;
}
2025-01-07 22:12:30 +11:00
2025-01-11 06:02:32 +11:00
/** Encodes `value` as `type` and queues it on the write channel. */
function write<T>(type: Encoder<T>, value: T) {
  const buf = ser_encode(type, value);
  write_msg(buf);
}

/**
 * Queues an already-encoded message on the write channel.
 * @throws WireError if there is no open connection
 */
function write_msg(buf: Uint8Array) {
  if (write_push === null) throw new WireError(`connection closed`);
  write_push(buf);
}
2025-01-07 22:12:30 +11:00
2025-01-11 06:02:32 +11:00
/**
 * Socket writer worker: drains the write channel until it yields null, writing each batch
 * of queued messages to the socket.
 */
async function write_socket(socket: Deno.Conn, pop: Receiver<Uint8Array>) {
  for (;;) {
    const first = await pop();
    if (first === null) break;
    // also take whatever is already queued so it goes out as one write (fewer syscalls)
    const queued = [first];
    for (let next = pop.try(); next !== null; next = pop.try()) queued.push(next);
    const out = queued.length === 1 ? first : buf_concat(queued);
    // socket.write may write only part of the buffer
    let written = 0;
    while (written < out.length) written += await socket.write(out.subarray(written));
  }
}
2025-01-11 06:02:32 +11:00
/**
 * (Re)connects: closes any previous connection, opens a new socket with its reader/writer
 * workers, and runs the startup handshake. Both locks are held throughout. On failure the
 * new connection is closed with the error, which is then rethrown.
 */
async function connect() {
  using _rlock = await rlock();
  using _wlock = await wlock();
  close(new WireError(`reconnecting`));
  try {
    const conn = await socket_connect(host, port);
    socket = conn;
    read_pop = channel.receiver((push) => read_socket(conn, push));
    write_push = channel.sender((pop) => write_socket(conn, pop));
    await handle_auth(); // runs with both locks held
  } catch (e) {
    close(e);
    throw e;
  }
}
/**
 * Tears down the current connection, if any, and resets all per-connection state.
 * `reason` is passed on to the read and write channels.
 */
function close(reason?: unknown) {
  socket?.close();
  socket = null;
  read_pop?.close(reason);
  read_pop = null;
  write_push?.close(reason);
  write_push = null;
  // server parameters, cached statements and transaction tracking belong to the old session
  for (const name of Object.keys(params)) {
    delete (params as Record<string, string>)[name];
  }
  st_cache.clear();
  st_ids = 0;
  tx_status = "I";
  tx_stack.length = 0;
}
// https://www.postgresql.org/docs/current/protocol-flow.html#PROTOCOL-FLOW-PIPELINING
// Separate read and write locks let one request's messages be written while an earlier
// request's response is still being read. This assumes semaphore() admits waiters in FIFO
// order, so responses are read in the order requests were written — TODO confirm in lstd.
const rlock = semaphore();
const wlock = semaphore();

/**
 * Runs one pipelined request: `w` writes its messages (pipeline_write appends Sync) under
 * the write lock while `r` reads the response under the read lock. Resolves with `r`'s
 * result and rejects if either side fails.
 */
function pipeline<T>(
  w: () => void | PromiseLike<void>,
  r: () => T | PromiseLike<T>
) {
  const writing = pipeline_write(w);
  const reading = pipeline_read(r);
  return new Promise<T>((resolve, reject) => {
    writing.catch(reject);
    reading.then(resolve, reject);
  });
}
/**
 * Runs `r` under the read lock. Afterwards, whether `r` succeeded or threw, consumes
 * messages up to and including the next ReadyForQuery and records its transaction status,
 * so the next reader starts in sync. Errors while draining are ignored.
 */
async function pipeline_read<T>(r: () => T | PromiseLike<T>) {
  using _lock = await rlock();
  try {
    return await r();
  } finally {
    try {
      for (;;) {
        const msg = await read_msg();
        if (msg_type(msg) === ReadyForQuery.type) {
          tx_status = ser_decode(ReadyForQuery, msg).tx_status;
          break;
        }
      }
    } catch {
      // ignored
    }
  }
}
/**
 * Runs `w` under the write lock and always follows it with a Sync message, even if `w`
 * threw (a failure to write the Sync is ignored). pipeline_read later drains up to the
 * ReadyForQuery that answers it.
 */
async function pipeline_write<T>(w: () => T | PromiseLike<T>) {
  using _lock = await wlock();
  try {
    return await w();
  } finally {
    try {
      write(Sync, {});
    } catch {
      // ignored
    }
  }
}
// https://www.postgresql.org/docs/current/protocol-flow.html#PROTOCOL-FLOW-START-UP
/**
 * Startup handshake on a freshly opened connection: sends StartupMessage, completes
 * authentication, waits for ReadyForQuery, then schedules LISTEN for every channel that
 * was registered before a reconnect.
 *
 * Must run while holding both the read and write locks (see connect()).
 *
 * @throws PostgresError if the server rejects startup or authentication
 * @throws WireError for unsupported authentication methods or a closed connection
 */
async function handle_auth() {
  write(StartupMessage, {
    version: 196608,
    params: {
      application_name: "pglue",
      idle_session_timeout: "0",
      ...runtime_params,
      // the values below always override runtime_params
      user,
      database: database ?? user,
      bytea_output: "hex",
      client_encoding: "utf8",
      DateStyle: "ISO",
    },
  });

  auth: for (;;) {
    const msg = msg_check_err(await read_msg());
    switch (msg_type(msg)) {
      case NegotiateProtocolVersion.type: {
        const { bad_options } = ser_decode(NegotiateProtocolVersion, msg);
        log("info", { bad_options }, `unrecognised protocol options`);
        continue;
      }
    }

    const { status } = ser_decode(Authentication, msg);
    switch (status) {
      case 0: // AuthenticationOk
        break auth;
      case 2: // AuthenticationKerberosV5
        throw new WireError(`kerberos authentication is deprecated`);
      case 3: // AuthenticationCleartextPassword
        write(PasswordMessage, { password });
        continue;
      case 5: // AuthenticationMD5Password
        throw new WireError(
          `md5 password authentication is deprecated (prefer scram-sha-256 instead)`
        );
      case 7: // AuthenticationGSS
        throw new WireError(`gssapi authentication is not supported`);
      case 9: // AuthenticationSSPI
        throw new WireError(`sspi authentication is not supported`);
      case 10: // AuthenticationSASL
        await handle_auth_sasl();
        continue;
      default:
        throw new WireError(`invalid authentication status ${status}`);
    }
  }

  // wait for ready
  ready: for (;;) {
    const msg = msg_check_err(await read_msg());
    switch (msg_type(msg)) {
      case BackendKeyData.type:
        continue; // ignored
      default:
        ser_decode(ReadyForQuery, msg);
        break ready;
    }
  }

  // Re-listen previously registered channels. These queries run through pipeline(), which
  // acquires the same read/write locks that connect() holds while awaiting this function,
  // so awaiting them here deadlocks every reconnect that has active channels. Instead they
  // are queued now and run as soon as connect() releases the locks; failures are logged.
  const names = Array.from(channels.keys());
  if (names.length !== 0) {
    void Promise.all(
      names.map((name) => query(sql`listen ${sql.ident(name)}`).execute())
    ).catch((e) =>
      log("warn", { channels: names, error: e }, `failed to re-listen channels`)
    );
  }
}
// https://www.postgresql.org/docs/current/sasl-authentication.html#SASL-SCRAM-SHA-256
// https://datatracker.ietf.org/doc/html/rfc5802
/**
 * Runs a SCRAM-SHA-256 exchange (no channel binding) after the server sent
 * AuthenticationSASL, and verifies the server's final signature.
 *
 * @throws WireError on a malformed server challenge or a server signature that does not verify
 * @throws PostgresError if the server rejects the exchange
 */
async function handle_auth_sasl() {
  const bits = 256;
  const hash = `SHA-${bits}`;
  const mechanism = `SCRAM-${hash}`;

  async function hmac(key: Uint8Array, str: string | Uint8Array) {
    return new Uint8Array(
      await crypto.subtle.sign(
        "HMAC",
        await crypto.subtle.importKey(
          "raw",
          key,
          { name: "HMAC", hash },
          false,
          ["sign"]
        ),
        to_utf8(str)
      )
    );
  }

  async function h(str: string | Uint8Array) {
    return new Uint8Array(await crypto.subtle.digest(hash, to_utf8(str)));
  }

  async function hi(str: string | Uint8Array, salt: Uint8Array, i: number) {
    return new Uint8Array(
      await crypto.subtle.deriveBits(
        { name: "PBKDF2", hash, salt, iterations: i },
        await crypto.subtle.importKey("raw", to_utf8(str), "PBKDF2", false, [
          "deriveBits",
        ]),
        bits
      )
    );
  }

  // Parses a "k=v,k=v" attribute list. Each entry is split at its FIRST "=": values such as
  // the base64 salt (s=) and server signature (v=) may end in "=" padding, and JS
  // `split("=", 2)` silently drops everything after a second "=". A null-prototype object
  // keeps server-chosen keys from touching Object.prototype.
  function parse_attrs(s: string) {
    const attrs: Partial<Record<string, string>> = Object.create(null);
    for (const entry of s.split(",")) {
      const eq = entry.indexOf("=");
      if (eq === -1) attrs[entry] = "";
      else attrs[entry.slice(0, eq)] = entry.slice(eq + 1);
    }
    return attrs;
  }

  const gs2_cbind_flag = `n`;
  const gs2_header = `${gs2_cbind_flag},,`;
  const username = `n=*`;
  const cbind_data = ``;
  const cbind_input = `${gs2_header}${cbind_data}`;
  const channel_binding = `c=${to_base64(cbind_input)}`;
  const initial_nonce = `r=${to_base64(
    crypto.getRandomValues(new Uint8Array(18))
  )}`;
  const client_first_message_bare = `${username},${initial_nonce}`;
  const client_first_message = `${gs2_header}${client_first_message_bare}`;
  write(SASLInitialResponse, { mechanism, data: client_first_message });

  const server_first_message_str = from_utf8(
    (await read(AuthenticationSASLContinue)).data
  );
  const server_first_message = parse_attrs(server_first_message_str);
  const nonce = `r=${server_first_message.r ?? ""}`;
  // the server nonce must extend (not merely echo) the client nonce
  if (!nonce.startsWith(initial_nonce) || nonce.length === initial_nonce.length)
    throw new WireError(`bad nonce`);
  const salt = from_base64(server_first_message.s ?? "");
  const iters = parseInt(server_first_message.i ?? "", 10) || 0;
  // PBKDF2 rejects a zero iteration count; report a bad challenge as a protocol error instead
  if (iters <= 0) throw new WireError(`bad iteration count`);
  const salted_password = await hi(password, salt, iters);
  const client_key = await hmac(salted_password, "Client Key");
  const stored_key = await h(client_key);
  const client_final_message_without_proof = `${channel_binding},${nonce}`;
  const auth_message = `${client_first_message_bare},${server_first_message_str},${client_final_message_without_proof}`;
  const client_signature = await hmac(stored_key, auth_message);
  const client_proof = buf_xor(client_key, client_signature);
  const proof = `p=${to_base64(client_proof)}`;
  const client_final_message = `${client_final_message_without_proof},${proof}`;
  write(SASLResponse, { data: client_final_message });

  const server_key = await hmac(salted_password, "Server Key");
  const server_signature = await hmac(server_key, auth_message);
  const server_final_message = parse_attrs(
    from_utf8((await read(AuthenticationSASLFinal)).data)
  );
  if (!buf_eq(from_base64(server_final_message.v ?? ""), server_signature))
    throw new WireError(`SASL server signature mismatch`);
}
// https://www.postgresql.org/docs/current/protocol-flow.html#PROTOCOL-FLOW-EXT-QUERY
// prepared statements (presumably keyed by query text — TODO confirm at the use site);
// close() clears the cache and restarts the name counter
const st_cache = new Map<string, Statement>();
let st_ids = 0;

/** A named server-side prepared statement for one query string. */
class Statement {
  readonly name = `__st${st_ids++}`;
  #ready: Promise<{
    ser_params: ParameterSerializer;
    Row: RowConstructor;
  }> | null = null;
  #next_portal = 0;

  constructor(readonly query: string) {}

  /**
   * Parses and describes the statement on first use; later and concurrent calls share the
   * same promise. If that attempt fails, the cached promise is dropped so the next call retries.
   */
  parse() {
    const ready = this.#ready ?? this.#prepare();
    this.#ready = ready;
    return ready;
  }

  async #prepare() {
    const { name, query } = this;
    try {
      return await pipeline(
        () => {
          write(Parse, { statement: name, query, param_types: [] });
          write(Describe, { which: "S", name });
        },
        async () => {
          await read(ParseComplete);
          const description = await read(ParameterDescription);
          const ser_params = make_param_ser(description);
          // the statement description is either NoData or a RowDescription
          const msg = msg_check_err(await read_msg());
          const Row =
            msg_type(msg) === NoData.type
              ? EmptyRow
              : make_row_ctor(ser_decode(RowDescription, msg));
          return { ser_params, Row };
        }
      );
    } catch (e) {
      this.#ready = null;
      throw e;
    }
  }

  /** Returns a fresh portal name derived from this statement's name. */
  portal() {
    const n = this.#next_portal++;
    return `${this.name}_${n}`;
  }
}
type ParameterDescription = EncoderType<typeof ParameterDescription>;
interface ParameterSerializer {
  (params: unknown[]): (string | null)[];
}

/**
 * Compiles a function that serialises query parameters: parameter i is passed through the
 * `output` of the codec registered for its type oid (falling back to `text`).
 */
function make_param_ser({ param_types }: ParameterDescription) {
  const outputs = jit.map(", ", param_types, (type_oid, i) => {
    return jit`${types[type_oid] ?? text}.output(xs[${i}])`;
  });
  return jit.compiled<ParameterSerializer>`function ser_params(xs) {
return [
${outputs}
];
}`;
}
type RowDescription = EncoderType<typeof RowDescription>;
interface RowConstructor {
  new (columns: (BinaryLike | null)[]): Row;
}

// Row constructor for statements that describe no result columns
const EmptyRow = make_row_ctor({ columns: [] });

/**
 * Compiles a Row constructor for one result shape.
 *
 * The constructor assigns each column to the property named after it. A null column stays
 * null; otherwise the bytes are decoded as UTF-8 and passed to the `input` of the codec for
 * the column's type oid (falling back to `text`). Rows have a null prototype, iterate their
 * values in column order, and convert to a tab-separated string.
 */
function make_row_ctor({ columns }: RowDescription) {
  const assigns = jit.map(" ", columns, ({ name, type_oid }, i) => {
    const type = types[type_oid] ?? text;
    return jit`this[${name}] = xs[${i}] === null ? null : ${type}.input(${from_utf8}(xs[${i}]));`;
  });
  const Row = jit.compiled<RowConstructor>`function Row(xs) {
${assigns}
}`;

  const yields = jit.map(" ", columns, ({ name }) => {
    return jit`yield this[${name}];`;
  });
  const iter = jit.compiled`function* iter() {
${yields}
}`;

  const proto = Object.create(null);
  Object.defineProperties(proto, {
    [Symbol.toStringTag]: {
      configurable: true,
      value: `Row`,
    },
    [Symbol.toPrimitive]: {
      configurable: true,
      value: function format() {
        return [...this].join("\t");
      },
    },
    [Symbol.iterator]: {
      configurable: true,
      value: iter,
    },
  });
  Row.prototype = proto;
  return Row;
}
2025-01-07 22:12:30 +11:00
/**
 * Reads the result of one executed statement or portal.
 *
 * DataRow messages are decoded with `Row` (a RowDescription seen midway replaces it).
 * Reading stops at:
 * - CommandComplete: `done: true`, with the command tag;
 * - EmptyQueryResponse: `done: true`, empty tag;
 * - PortalSuspended: `done: false`, empty tag.
 *
 * COPY TO STDOUT data goes to `stdout`, which is used for the first COPY only. Message
 * types not listed here are decoded as DataRow.
 *
 * @throws PostgresError on an ErrorResponse
 */
async function read_rows(
  Row: RowConstructor,
  stdout: WritableStream<Uint8Array> | null
) {
  const rows: Row[] = [];
  for (;;) {
    const msg = msg_check_err(await read_msg());
    const type = msg_type(msg);
    if (type === CommandComplete.type) {
      const { tag } = ser_decode(CommandComplete, msg);
      return { done: true as const, rows, tag };
    } else if (type === PortalSuspended.type) {
      return { done: false as const, rows, tag: "" };
    } else if (type === EmptyQueryResponse.type) {
      return { done: true as const, rows, tag: "" };
    } else if (type === RowDescription.type) {
      Row = make_row_ctor(ser_decode(RowDescription, msg));
    } else if (type === CopyOutResponse.type) {
      await read_copy_out(stdout);
      stdout = null;
    } else if (type !== CopyInResponse.type) {
      // COPY FROM STDIN data is already being sent by write_copy_in
      rows.push(new Row(ser_decode(DataRow, msg).column_values));
    }
  }
}
/**
 * Consumes COPY TO STDOUT data up to CopyDone.
 *
 * With a `stream`, each CopyData payload is written to it and the stream is closed at the
 * end (or aborted with the error, which is rethrown). Without one, the data is discarded.
 * Either way an ErrorResponse raises PostgresError.
 */
async function read_copy_out(stream: WritableStream<Uint8Array> | null) {
  if (stream === null) {
    while (msg_type(msg_check_err(await read_msg())) !== CopyDone.type);
    return;
  }
  const writer = stream.getWriter();
  try {
    for (;;) {
      const msg = await read_msg();
      if (msg_type(msg) === CopyDone.type) break;
      const { data } = ser_decode(CopyData, msg_check_err(msg));
      await writer.write(to_utf8(data));
    }
    await writer.close();
  } catch (e) {
    await writer.abort(e);
    throw e;
  } finally {
    writer.releaseLock();
  }
}
/**
 * Sends `stream` as COPY FROM STDIN data: one CopyData per chunk, then CopyDone. If reading
 * the stream (or queuing a message) fails, sends CopyFail with the error text and rethrows.
 * Without a stream, only CopyDone is sent.
 */
async function write_copy_in(stream: ReadableStream<Uint8Array> | null) {
  if (stream === null) {
    write(CopyDone, {});
    return;
  }
  const reader = stream.getReader();
  try {
    for (;;) {
      const chunk = await reader.read();
      if (chunk.done) break;
      write(CopyData, { data: chunk.value });
    }
    write(CopyDone, {});
  } catch (e) {
    write(CopyFail, { cause: String(e) });
    throw e;
  } finally {
    reader.releaseLock();
  }
}
2025-01-11 02:00:05 +11:00
/**
 * Runs `query` with the simple query protocol (which may contain several statements).
 *
 * Each statement's row set is yielded in order. A server error does not stop reading: the
 * rows collected before it are still yielded and the error is thrown afterwards.
 * `stdin` feeds COPY FROM STDIN and `stdout` receives COPY TO STDOUT; `stdout` is used by at
 * most one statement.
 *
 * The reader always stops exactly at this query's own ReadyForQuery. pipeline_read then
 * drains to the ReadyForQuery answering the Sync that pipeline_write appends. Stopping
 * anywhere else (an early return, or an exception thrown mid-stream) leaves a stray
 * ReadyForQuery queued for the next request, so:
 * - COPY responses are handled here;
 * - errors raised while reading a result are recorded like ErrorResponse;
 * - any other failure first resynchronises to ReadyForQuery before propagating.
 */
async function* execute_simple(
  query: string,
  stdin: ReadableStream<Uint8Array> | null,
  stdout: WritableStream<Uint8Array> | null
): ResultStream<Row> {
  log("debug", { query: query }, `executing simple query`);
  const { chunks, err } = await pipeline(
    () => (write(QueryMessage, { query }), write_copy_in(stdin)),
    async () => {
      const chunks: Row[][] = [];
      let err: unknown = null;
      try {
        for (;;) {
          const msg = await read_msg();
          const type = msg_type(msg);
          if (type === ReadyForQuery.type) return { chunks, err };
          try {
            switch (type) {
              case RowDescription.type: {
                const ctor = make_row_ctor(ser_decode(RowDescription, msg));
                const { rows } = await read_rows(ctor, stdout);
                chunks.push(rows);
                stdout = null;
                break;
              }
              case CopyInResponse.type:
                break; // stdin is already being sent by write_copy_in
              case CopyOutResponse.type:
                await read_copy_out(stdout);
                stdout = null;
                break;
              case ErrorResponse.type:
                err = new PostgresError(ser_decode(ErrorResponse, msg).fields);
                break;
              default:
                // CommandComplete, EmptyQueryResponse and anything unexpected: nothing to collect
                break;
            }
          } catch (e) {
            // an ErrorResponse consumed inside read_rows/read_copy_out ends the query; keep
            // reading until its ReadyForQuery
            if (!(e instanceof PostgresError)) throw e;
            err = e;
          }
        }
      } catch (e) {
        // resynchronise to this query's ReadyForQuery before propagating
        try {
          while (msg_type(await read_msg()) !== ReadyForQuery.type);
        } catch {
          // ignored: the connection is unusable anyway
        }
        throw e;
      }
    }
  );
  yield* chunks;
  if (err) throw err;
  return { tag: "" };
}
/**
 * Execute a prepared statement in a single round trip, fetching all rows at once.
 *
 * One pipeline writes Bind, Execute (`row_limit: 0`, i.e. no limit), any `stdin`
 * copy payload, and a Close for the portal. It then reads BindComplete followed
 * by the rows via `read_rows`.
 *
 * On failure, a best-effort Close of the portal is attempted, with its errors
 * ignored, and the original error is rethrown.
 */
async function* execute_fast(
  st: Statement,
  params: unknown[],
  stdin: ReadableStream<Uint8Array> | null,
  stdout: WritableStream<Uint8Array> | null
): ResultStream<Row> {
  log(
    "debug",
    { query: st.query, statement: st.name, params },
    `executing query`
  );
  const { ser_params, Row } = await st.parse();
  const param_values = ser_params(params);
  const portal = st.portal();
  const close_portal = () => write(Close, { which: "P" as const, name: portal });

  try {
    const result = await pipeline(
      async () => {
        write(Bind, {
          portal,
          statement: st.name,
          param_formats: [],
          param_values,
          column_formats: [],
        });
        write(Execute, { portal, row_limit: 0 });
        await write_copy_in(stdin);
        close_portal();
      },
      async () => {
        await read(BindComplete);
        return read_rows(Row, stdout);
      }
    );
    if (result.rows.length) yield result.rows;
    return { tag: result.tag };
  } catch (e) {
    try {
      await pipeline(close_portal, () => read(CloseComplete));
    } catch {
      // ignored: the original error is what the caller should see
    }
    throw e;
  }
}
/**
 * Execute a prepared statement through a portal, `chunk_size` rows at a time.
 *
 * The first round trip writes Bind, then Execute with `row_limit: chunk_size`,
 * then any `stdin` copy payload. While `read_rows` reports the portal is not
 * `done`, further Execute round trips fetch the next batch. Each non-empty batch
 * is yielded as soon as it is read, so the consumer may stop early.
 *
 * The portal is always closed afterwards, including when the consumer stops
 * iterating. If the body already failed, an error from closing the portal is
 * swallowed, so the original error is the one that propagates. If the body
 * succeeded, a close error is still reported.
 */
async function* execute_chunked(
  st: Statement,
  params: unknown[],
  chunk_size: number,
  stdin: ReadableStream<Uint8Array> | null,
  stdout: WritableStream<Uint8Array> | null
): ResultStream<Row> {
  log(
    "debug",
    { query: st.query, statement: st.name, params },
    `executing chunked query`
  );
  const { ser_params, Row } = await st.parse();
  const param_values = ser_params(params);
  const portal = st.portal();

  // Wrapped in an async function so every failure surfaces as a rejection.
  const close_portal = async () => {
    await pipeline(
      () => write(Close, { which: "P" as const, name: portal }),
      () => read(CloseComplete)
    );
  };

  let failed = false;
  try {
    let { done, rows, tag } = await pipeline(
      () => {
        write(Bind, {
          portal,
          statement: st.name,
          param_formats: [],
          param_values,
          column_formats: [],
        });
        write(Execute, { portal, row_limit: chunk_size });
        return write_copy_in(stdin);
      },
      async () => {
        await read(BindComplete);
        return read_rows(Row, stdout);
      }
    );
    if (rows.length) yield rows;
    while (!done) {
      ({ done, rows, tag } = await pipeline(
        () => write(Execute, { portal, row_limit: chunk_size }),
        () => read_rows(Row, stdout)
      ));
      if (rows.length) yield rows;
    }
    return { tag };
  } catch (e) {
    failed = true;
    throw e;
  } finally {
    if (!failed) {
      await close_portal();
    } else {
      try {
        await close_portal();
      } catch {
        // ignored: a close failure must not replace the error already propagating
      }
    }
  }
}
/**
 * Wrap `sql` in a `Query` whose executor runs on this connection.
 *
 * The fragment is formatted inside the executor callback. NOTE(review): this is
 * presumably once per execution; confirm against `Query`.
 *
 * - `simple: true` uses the simple-query protocol, which rejects parameters.
 * - Otherwise a `Statement` is looked up in `st_cache` by query text (and cached
 *   if missing). It is then executed in one go (`chunk_size` 0, the default) or
 *   in batches of `chunk_size` rows.
 */
function query(sql: SqlFragment) {
  return new Query(
    ({ simple = false, chunk_size = 0, stdin = null, stdout = null }) => {
      const { query, params } = format(sql);

      if (simple) {
        if (params.length) {
          throw new WireError(`simple query cannot be parameterised`);
        }
        return execute_simple(query, stdin, stdout);
      }

      let stmt = st_cache.get(query);
      if (!stmt) {
        stmt = new Statement(query);
        st_cache.set(query, stmt);
      }
      return chunk_size
        ? execute_chunked(stmt, params, chunk_size, stdin, stdout)
        : execute_fast(stmt, params, stdin, stdout);
    }
  );
}
// Transaction support.
// https://www.postgresql.org/docs/current/sql-begin.html
// https://www.postgresql.org/docs/current/sql-savepoint.html
// NOTE(review): tx_status is only read in this section (by listen); presumably
// it is updated elsewhere from ReadyForQuery. Confirm.
let tx_status: "I" | "T" | "E" = "I";
// Open transactions, outermost first. Index 0 was started with BEGIN; every
// later entry is a savepoint, and all savepoints share the name `__tx`.
const tx_stack: Transaction[] = [];
const tx_begin = query(sql`begin`);
const tx_commit = query(sql`commit`);
const tx_rollback = query(sql`rollback`);
const sp_savepoint = query(sql`savepoint __tx`);
const sp_release = query(sql`release __tx`);
const sp_rollback_to = query(sql`rollback to __tx`);

/**
 * Open a transaction: BEGIN if none is open on this connection, otherwise a
 * nested SAVEPOINT. The new transaction is pushed onto `tx_stack` and returned.
 */
async function begin() {
  const nested = tx_stack.length > 0;
  const started = await (nested ? sp_savepoint : tx_begin).execute();
  const tx = new Transaction(started);
  tx_stack.push(tx);
  return tx;
}
/**
 * A transaction, or nested savepoint, returned by `begin()`.
 *
 * A transaction is open while it is still on `tx_stack`. Committing or rolling
 * back a transaction also ends every transaction nested inside it: the stack is
 * truncated to its index.
 *
 * All savepoints share the name `__tx`. PostgreSQL's RELEASE and ROLLBACK TO act
 * on the most recently defined savepoint with a given name. Ending a savepoint
 * that still has savepoints nested above it must therefore unwind those newer
 * ones first, one command per level. Otherwise the wrong savepoint is targeted
 * and this transaction's own savepoint is left behind.
 */
const Transaction = class implements Transaction {
  readonly tag!: string;

  get open(): boolean {
    return tx_stack.indexOf(this) !== -1;
  }

  constructor(begin: CommandResult) {
    Object.assign(this, begin);
  }

  /**
   * COMMIT (outermost transaction) or RELEASE this savepoint together with any
   * savepoints nested inside it. Throws `WireError` if the transaction is not open.
   */
  async commit() {
    const i = tx_stack.indexOf(this);
    if (i === -1) throw new WireError(`transaction is not open`);
    const depth = tx_stack.length;
    tx_stack.length = i;
    if (i === 0) return await tx_commit.execute();
    // Release the savepoints nested above this one, newest first...
    for (let k = depth - 1; k > i; k--) await sp_release.execute();
    // ...then this transaction's own savepoint.
    return await sp_release.execute();
  }

  /**
   * ROLLBACK (outermost transaction), or roll back to this savepoint and release
   * it, discarding any savepoints nested inside it. Throws `WireError` if the
   * transaction is not open.
   */
  async rollback() {
    const i = tx_stack.indexOf(this);
    if (i === -1) throw new WireError(`transaction is not open`);
    const depth = tx_stack.length;
    tx_stack.length = i;
    if (i === 0) return await tx_rollback.execute();
    // Unwind nested savepoints newest first. ROLLBACK TO is accepted even in an
    // aborted transaction (RELEASE is not), so each level rolls back before it
    // is released.
    for (let k = depth - 1; k > i; k--) {
      await sp_rollback_to.execute();
      await sp_release.execute();
    }
    const res = await sp_rollback_to.execute();
    return await sp_release.execute(), res;
  }

  /** Roll back if still open; used by `await using`. */
  async [Symbol.asyncDispose]() {
    if (this.open) await this.rollback();
  }
};
// LISTEN / NOTIFY support.
// https://www.postgresql.org/docs/current/sql-listen.html
// https://www.postgresql.org/docs/current/sql-notify.html
// Active subscriptions on this connection, keyed by channel name.
const channels = new Map<string, Channel>();

/**
 * Subscribe to `channel`, reusing the existing subscription if there is one.
 * Logs a warning when LISTEN ran while `tx_status` was not idle ("I").
 */
async function listen(channel: string) {
  const existing = channels.get(channel);
  if (existing) return existing;

  const res = await query(sql`listen ${sql.ident(channel)}`).execute();
  if (tx_status !== "I")
    log("warn", {}, `LISTEN executed inside transaction`);

  // A concurrent listen() for the same channel may have registered first.
  const winner = channels.get(channel);
  if (winner) return winner;

  const created = new Channel(channel, res);
  channels.set(channel, created);
  return created;
}

/** Notify `channel` with `payload` by running `select pg_notify(...)`. */
async function notify(channel: string, payload: string) {
  const notification = query(sql`select pg_notify(${channel}, ${payload})`);
  return await notification.execute();
}
/**
 * A LISTEN subscription returned by `listen()`; an emitter of `ChannelEvents`.
 * It is open while it is the entry registered for its name in `channels`.
 */
const Channel = class extends TypedEmitter<ChannelEvents> implements Channel {
  readonly #name;
  readonly tag!: string;

  get name() {
    return this.#name;
  }

  get open(): boolean {
    return channels.get(this.#name) === this;
  }

  constructor(name: string, listen: CommandResult) {
    super();
    // Copy the LISTEN command result's fields (e.g. `tag`) onto this object.
    Object.assign(this, listen);
    this.#name = name;
  }

  /** Send `payload` on this channel via `notify()`. */
  notify(payload: string) {
    return notify(this.#name, payload);
  }

  /** Unregister and run UNLISTEN; throws `WireError` if no longer listening. */
  async unlisten() {
    const name = this.#name;
    if (channels.get(name) !== this) {
      throw new WireError(`channel is not listening`);
    }
    channels.delete(name);
    return await query(sql`unlisten ${sql.ident(name)}`).execute();
  }

  /** UNLISTEN if still open; used by `await using`. */
  async [Symbol.asyncDispose]() {
    if (!this.open) return;
    await this.unlisten();
  }
};
2025-01-11 06:02:32 +11:00
return { params, connect, query, begin, listen, notify, close };
2025-01-07 22:12:30 +11:00
}
/**
 * Options for `Pool`: every `WireOptions` field, which is passed to each pooled
 * connection, plus pool sizing.
 */
export interface PoolOptions extends WireOptions {
  /** Maximum number of connections that may be borrowed at the same time. */
  max_connections: number;
  /**
   * NOTE(review): currently ignored. `pool_impl` destructures it as `_` and never
   * reads it, so idle connections are not closed. The unit is not established
   * here (presumably milliseconds); confirm before relying on it.
   */
  idle_timeout: number;
}
/** Events emitted by `Pool`. */
export type PoolEvents = {
  /**
   * Re-emitted from each pooled connection's own `log` event, with that
   * connection's `connection_id` merged into `ctx`.
   */
  log(level: LogLevel, ctx: object, msg: string): void;
};
/** A connection owned by a `Pool`. */
export interface PoolWire extends Wire {
  /** Per-pool sequence number (0, 1, 2, ...) assigned when the connection is created. */
  readonly connection_id: number;
  /** `true` unless the connection is currently idle in the pool's free list. */
  readonly borrowed: boolean;
  /**
   * Return the connection to the pool. This is a no-op if it is already idle or
   * has been dropped from the pool. Disposing the wire (`using`) calls this
   * instead of closing the connection.
   */
  release(): void;
}
/** A transaction started through `Pool.begin`. */
export interface PoolTransaction extends Transaction {
  /**
   * The pooled connection the transaction runs on. It is released back to the
   * pool when the transaction is committed or rolled back through this object.
   */
  readonly wire: PoolWire;
}
/**
 * A bounded pool of connections.
 *
 * `max_connections` caps how many connections can be borrowed at once; the
 * remaining options are passed to every connection the pool opens. `log` events
 * from pooled connections are re-emitted on the pool.
 *
 * NOTE: the pool is thenable, so `await pool` borrows a connection (see `then`).
 * As a consequence, returning a `Pool` from an async function resolves to a
 * borrowed `PoolWire`, not to the pool itself.
 */
export class Pool
  extends TypedEmitter<PoolEvents>
  implements PromiseLike<PoolWire>, Disposable
{
  readonly #acquire;
  readonly #begin;
  readonly #close;

  constructor(options: PoolOptions) {
    super();
    const impl = pool_impl(this, options);
    this.#acquire = impl.acquire;
    this.#begin = impl.begin;
    this.#close = impl.close;
  }

  /**
   * Borrow a connection.
   *
   * Without `f`, the caller owns the connection and must `release()` it (or bind
   * it with `using`). With `f`, the connection is released once `f` settles, and
   * `f`'s result is returned.
   */
  get(): Promise<PoolWire>;
  get<T>(f: (wire: PoolWire) => T | PromiseLike<T>): Promise<T>;
  async get(f?: (wire: PoolWire) => unknown) {
    if (typeof f === "undefined") return this.#acquire();
    using wire = await this.#acquire();
    return await f(wire);
  }

  /**
   * Build a query that borrows a connection only while it is streaming, and
   * releases it when the stream finishes.
   */
  query<T = Row>(sql: SqlFragment): Query<T>;
  query<T = Row>(s: TemplateStringsArray, ...xs: unknown[]): Query<T>;
  query(s: TemplateStringsArray | SqlFragment, ...xs: unknown[]) {
    const fragment = is_sql(s) ? s : sql(s, ...xs);
    const acquire = this.#acquire;
    return new Query(async function* stream(options) {
      using wire = await acquire();
      return yield* wire.query(fragment).stream(options);
    });
  }

  /**
   * Begin a transaction on a borrowed connection.
   *
   * Without `f`, the caller must commit or roll back the returned transaction.
   * With `f`, the transaction is committed after `f` resolves (if `f` left it
   * open). If `f` throws, it is rolled back on disposal.
   */
  begin(): Promise<PoolTransaction>;
  begin<T>(
    f: (wire: PoolWire, tx: PoolTransaction) => T | PromiseLike<T>
  ): Promise<T>;
  async begin(f?: (wire: PoolWire, tx: PoolTransaction) => unknown) {
    if (typeof f === "undefined") return this.#begin();
    await using tx = await this.#begin();
    const value = await f(tx.wire, tx);
    if (tx.open) await tx.commit();
    return value;
  }

  /** `PromiseLike` hook: awaiting the pool borrows a connection via `get()`. */
  then<T = PoolWire, U = never>(
    f?: ((wire: PoolWire) => T | PromiseLike<T>) | null,
    g?: ((reason?: unknown) => U | PromiseLike<U>) | null
  ) {
    const borrowed = this.get();
    return borrowed.then(f, g);
  }

  /** Close every pooled connection and reset the pool's capacity. */
  close() {
    this.#close();
  }

  [Symbol.dispose]() {
    this.close();
  }
}
/**
 * Build the state and operations behind a `Pool`.
 *
 * Borrowing is gated by a semaphore with `max_connections` permits. `all` holds
 * every live pooled connection and `free` the idle ones. An idle connection is
 * reused (most recently released first) before a new one is opened. A
 * connection that emits `close` is dropped from the pool, and its permit is
 * returned if it was borrowed at the time.
 *
 * NOTE(review): `idle_timeout` is destructured away and never used, so idle
 * connections are not closed by the pool.
 */
function pool_impl(
  pool: Pool,
  { max_connections, idle_timeout: _, ...options }: PoolOptions
) {
  const lock = semaphore(max_connections);
  const all = new Set<PoolWire>();
  const free: PoolWire[] = [];
  let ids = 0;

  const PoolWire = class extends Wire implements PoolWire {
    readonly #id = ids++;
    get connection_id() {
      return this.#id;
    }
    get borrowed(): boolean {
      return free.indexOf(this) === -1;
    }
    /** Return to the free list; no-op if already idle or no longer pooled. */
    release() {
      if (all.has(this) && free.indexOf(this) === -1)
        free.push(this), lock.release();
    }
    // Disposing a pooled wire returns it to the pool instead of closing it.
    override [Symbol.dispose]() {
      this.release();
    }
  };

  const PoolTransaction = class implements Transaction {
    readonly #wire;
    readonly #tx;
    get wire() {
      return this.#wire;
    }
    get tag() {
      return this.#tx.tag;
    }
    get open() {
      return this.#tx.open;
    }
    constructor(wire: PoolWire, tx: Transaction) {
      this.#wire = wire;
      this.#tx = tx;
    }
    /**
     * Run `end` (commit or rollback), then release the wire even if `end` fails.
     * A failed COMMIT/ROLLBACK on a live connection would otherwise keep the wire
     * and its permit borrowed forever. Only a call that finds the transaction
     * still open releases, so a redundant commit/rollback cannot release a wire
     * that has since been lent to someone else. `end` closes the transaction
     * synchronously before its first await, so concurrent calls cannot both
     * observe it as open.
     */
    async #settle<R>(end: () => PromiseLike<R>): Promise<R> {
      const owns_wire = this.#tx.open;
      try {
        return await end();
      } finally {
        if (owns_wire) this.#wire.release();
      }
    }
    async commit() {
      return await this.#settle(() => this.#tx.commit());
    }
    async rollback() {
      return await this.#settle(() => this.#tx.rollback());
    }
    async [Symbol.asyncDispose]() {
      if (this.open) await this.rollback();
    }
  };

  /** Open a new pooled connection and forward its events to the pool. */
  async function connect() {
    const wire = new PoolWire(options);
    await wire.connect(), all.add(wire);
    const { connection_id } = wire;
    return wire
      .on("log", (l, c, s) => pool.emit("log", l, { ...c, connection_id }, s))
      .on("close", () => forget(wire));
  }

  /** Take a permit, then reuse an idle connection or open a new one. */
  async function acquire() {
    await lock();
    try {
      return free.pop() ?? (await connect());
    } catch (e) {
      throw (lock.release(), e);
    }
  }

  /** Drop a closed connection; return its permit if it was borrowed. */
  function forget(wire: PoolWire) {
    if (all.delete(wire)) {
      const i = free.indexOf(wire);
      if (i !== -1) free.splice(i, 1);
      else lock.release();
    }
  }

  /** Borrow a connection and BEGIN on it; release it again if BEGIN fails. */
  async function begin() {
    const wire = await acquire();
    try {
      return new PoolTransaction(wire, await wire.begin());
    } catch (e) {
      throw (wire.release(), e);
    }
  }

  /** Close every pooled connection and reset the pool to empty. */
  function close() {
    for (const wire of all) wire.close();
    all.clear(), (free.length = 0), lock.reset(max_connections);
  }

  return { acquire, begin, close };
}