From 119c06565c3352a39ef49d47c980d8f1a15e5cd9 Mon Sep 17 00:00:00 2001 From: luaneko Date: Sun, 12 Jan 2025 06:36:15 +1100 Subject: [PATCH] Rewrite options handling --- deno.lock | 8 +++- lstd.ts | 2 +- mod.ts | 132 +++++++++++++++++++++++++----------------------------- query.ts | 4 +- ser.ts | 2 +- test.ts | 30 ++++++------- wire.ts | 101 ++++++++++++++++++++++++++--------------- 7 files changed, 150 insertions(+), 129 deletions(-) diff --git a/deno.lock b/deno.lock index 7aa15e4..36fdf31 100644 --- a/deno.lock +++ b/deno.lock @@ -465,6 +465,12 @@ "https://git.lua.re/luaneko/lstd/raw/tag/v0.2.0/events.ts": "28d395b8eea87f9bf7908a44b351d2d3c609ba7eab62bcecd0d43be8ee603438", "https://git.lua.re/luaneko/lstd/raw/tag/v0.2.0/func.ts": "f1935f673365cd68939531d65ef18fe81b5d43dc795b03c34bb5ad821ab1c9ff", "https://git.lua.re/luaneko/lstd/raw/tag/v0.2.0/jit.ts": "c1db7820de95c48521b057c7cdf9aa41f7eaba77462407c29d3932e7da252d53", - "https://git.lua.re/luaneko/lstd/raw/tag/v0.2.0/mod.ts": "95d8b15048a54cb82391825831f695b74e7c8b206317264a99c906ce25c63f13" + "https://git.lua.re/luaneko/lstd/raw/tag/v0.2.0/mod.ts": "95d8b15048a54cb82391825831f695b74e7c8b206317264a99c906ce25c63f13", + "https://git.lua.re/luaneko/lstd/raw/tag/v0.2.1/async.ts": "20bc54c7260c2d2cd27ffcca33b903dde57a3a3635386d8e0c6baca4b253ae4e", + "https://git.lua.re/luaneko/lstd/raw/tag/v0.2.1/bytes.ts": "94f4809b375800bb2c949e31082dfdf08d022db56c5b5c9c7dfe6f399285da6f", + "https://git.lua.re/luaneko/lstd/raw/tag/v0.2.1/events.ts": "28d395b8eea87f9bf7908a44b351d2d3c609ba7eab62bcecd0d43be8ee603438", + "https://git.lua.re/luaneko/lstd/raw/tag/v0.2.1/func.ts": "f1935f673365cd68939531d65ef18fe81b5d43dc795b03c34bb5ad821ab1c9ff", + "https://git.lua.re/luaneko/lstd/raw/tag/v0.2.1/jit.ts": "c1db7820de95c48521b057c7cdf9aa41f7eaba77462407c29d3932e7da252d53", + "https://git.lua.re/luaneko/lstd/raw/tag/v0.2.1/mod.ts": "589763be8ab18e7d6c5f5921e74ab44580f466c92acead401b2903d42d94112a" } } diff --git a/lstd.ts b/lstd.ts index c7101be..e70c8d3 100644 --- a/lstd.ts +++ b/lstd.ts @@ -1 +1 @@ -export * from "https://git.lua.re/luaneko/lstd/raw/tag/v0.2.0/mod.ts"; +export * from "https://git.lua.re/luaneko/lstd/raw/tag/v0.2.1/mod.ts"; diff --git a/mod.ts b/mod.ts index 95d58f1..56ec622 100644 --- a/mod.ts +++ b/mod.ts @@ -1,15 +1,13 @@ -import pg_conn_string from "npm:pg-connection-string@^2.7.0"; +import pg_conn_str from "npm:pg-connection-string@^2.7.0"; +import type * as v from "./valita.ts"; import { - type Infer, - number, - object, - record, - string, - union, - unknown, -} from "./valita.ts"; -import { Pool, wire_connect } from "./wire.ts"; -import { sql_types, type SqlTypeMap } from "./query.ts"; + Pool, + PoolOptions, + SubscribeOptions, + Subscription, + Wire, + WireOptions, +} from "./wire.ts"; export { WireError, @@ -33,85 +31,77 @@ export { type RowStream, } from "./query.ts"; -export type Options = { - host?: string; - port?: number | string; - user?: string; - password?: string; - database?: string | null; - max_connections?: number; - idle_timeout?: number; - runtime_params?: Record; - types?: SqlTypeMap; -}; +export default function postgres(s: string, options: Partial = {}) { + return new Postgres(Options.parse(parse_conn(s, options), { mode: "strip" })); +} -type ParsedOptions = Infer; -const ParsedOptions = object({ - host: string().optional(() => "localhost"), - port: union( - number(), - string().map((s) => parseInt(s, 10)) - ).optional(() => 5432), - user: string().optional(() => "postgres"), - password: string().optional(() => "postgres"), - database: string() - .nullable() - .optional(() => null), - runtime_params: record(string()).optional(() => ({})), - max_connections: number().optional(() => 10), - idle_timeout: number().optional(() => 20), - reconnect_delay: number().optional(() => 5), - types: record(unknown()) - .optional(() => ({})) - .map((types): SqlTypeMap => ({ ...sql_types, ...types })), -}); - -function parse_opts(s: string, opts: Options) { +function parse_conn(s: string, options: Partial) { const { host, port, user, password, database, - ssl: _ssl, // TODO: + ssl: _ssl, // TODO: ssl support ...runtime_params - } = pg_conn_string.parse(s); + } = s ? pg_conn_str.parse(s) : {}; - const { PGHOST, PGPORT, PGUSER, PGPASSWORD, PGDATABASE, USER } = - Deno.env.toObject(); - - return ParsedOptions.parse({ - ...opts, - host: opts.host ?? host ?? PGHOST ?? undefined, - port: opts.port ?? port ?? PGPORT ?? undefined, - user: opts.user ?? user ?? PGUSER ?? USER ?? undefined, - password: opts.password ?? password ?? PGPASSWORD ?? undefined, - database: opts.database ?? database ?? PGDATABASE ?? undefined, - runtime_params: { ...runtime_params, ...opts.runtime_params }, - }); -} - -export default function postgres(s: string, options: Options = {}) { - return new Postgres(parse_opts(s, options)); -} - -export function connect(s: string, options: Options = {}) { - return wire_connect(parse_opts(s, options)); + return { + ...options, + host: options.host ?? host, + port: options.port ?? port, + user: options.user ?? user, + password: options.password ?? password, + database: options.database ?? database, + runtime_params: { ...runtime_params, ...options.runtime_params }, + }; } postgres.connect = connect; +postgres.subscribe = subscribe; + +export async function connect(s: string, options: Partial = {}) { + return await new Wire( + WireOptions.parse(parse_conn(s, options), { mode: "strip" }) + ).connect(); +} + +export async function subscribe( + s: string, + options: Partial = {} +) { + return await new Subscription( + SubscribeOptions.parse(parse_conn(s, options), { mode: "strip" }) + ).connect(); +} + +export type Options = v.Infer; +export const Options = PoolOptions; export class Postgres extends Pool { readonly #options; - constructor(options: ParsedOptions) { + constructor(options: Options) { super(options); this.#options = options; } - async connect(options: Options = {}) { - const opts = ParsedOptions.parse({ ...this.#options, ...options }); - const wire = await wire_connect(opts); - return wire.on("log", (l, c, s) => this.emit("log", l, c, s)); + async connect(options: Partial = {}) { + return await new Wire( + WireOptions.parse({ ...this.#options, ...options }, { mode: "strip" }) + ) + .on("log", (l, c, s) => this.emit("log", l, c, s)) + .connect(); + } + + async subscribe(options: Partial = {}) { + return await new Subscription( + SubscribeOptions.parse( + { ...this.#options, ...options }, + { mode: "strip" } + ) + ) + .on("log", (l, c, s) => this.emit("log", l, c, s)) + .connect(); } } diff --git a/query.ts b/query.ts index b35fccf..96f1a87 100644 --- a/query.ts +++ b/query.ts @@ -1,4 +1,4 @@ -import type { ObjectType } from "./valita.ts"; +import type * as v from "./valita.ts"; import { from_hex, to_hex, to_utf8 } from "./lstd.ts"; export const sql_format = Symbol.for(`re.lua.pglue.sql_format`); @@ -470,7 +470,7 @@ export class Query implements PromiseLike>, RowStream { }); } - parse( + parse( type: S, { mode = "strip" }: { mode?: "passthrough" | "strict" | "strip" } = {} ) { diff --git a/ser.ts b/ser.ts index 99c8b93..30444e4 100644 --- a/ser.ts +++ b/ser.ts @@ -11,7 +11,7 @@ import { write_i8, } from "./lstd.ts"; -export class EncoderError extends Error { +export class EncoderError extends TypeError { override get name() { return this.constructor.name; } diff --git a/test.ts b/test.ts index 428be4e..7654a13 100644 --- a/test.ts +++ b/test.ts @@ -2,18 +2,14 @@ import pglue, { PostgresError, SqlTypeError } from "./mod.ts"; import { expect } from "jsr:@std/expect"; import { toText } from "jsr:@std/streams"; -async function connect(params?: Record) { - const pg = await pglue.connect(`postgres://test:test@localhost:5432/test`, { - runtime_params: { client_min_messages: "INFO", ...params }, - }); +const pool = pglue(`postgres://test:test@localhost:5432/test`, { + runtime_params: { client_min_messages: "INFO" }, +}); - return pg.on("log", (_level, ctx, msg) => { - console.info(`${msg}`, ctx); - }); -} +pool.on("log", (level, ctx, msg) => console.info(`${level}: ${msg}`, ctx)); Deno.test(`integers`, async () => { - await using pg = await connect(); + await using pg = await pool.connect(); await using _tx = await pg.begin(); const { a, b, c } = await pg.query` @@ -44,7 +40,7 @@ Deno.test(`integers`, async () => { }); Deno.test(`boolean`, async () => { - await using pg = await connect(); + await using pg = await pool.connect(); await using _tx = await pg.begin(); const { a, b, c } = await pg.query` @@ -60,7 +56,7 @@ Deno.test(`boolean`, async () => { }); Deno.test(`bytea`, async () => { - await using pg = await connect(); + await using pg = await pool.connect(); await using _tx = await pg.begin(); const { string, array, buffer } = await pg.query` @@ -76,7 +72,7 @@ Deno.test(`bytea`, async () => { }); Deno.test(`row`, async () => { - await using pg = await connect(); + await using pg = await pool.connect(); await using _tx = await pg.begin(); expect( @@ -119,7 +115,7 @@ Deno.test(`row`, async () => { }); Deno.test(`sql injection`, async () => { - await using pg = await connect(); + await using pg = await pool.connect(); await using _tx = await pg.begin(); const input = `injection'); drop table users; --`; @@ -140,7 +136,7 @@ Deno.test(`sql injection`, async () => { }); Deno.test(`listen/notify`, async () => { - await using pg = await connect(); + await using pg = await pool.connect(); const sent: string[] = []; await using ch = await pg.listen(`my channel`, (payload) => { @@ -157,7 +153,7 @@ Deno.test(`listen/notify`, async () => { }); Deno.test(`transactions`, async () => { - await using pg = await connect(); + await using pg = await pool.connect(); await pg.begin(async (pg) => { await pg.begin(async (pg, tx) => { @@ -192,7 +188,7 @@ Deno.test(`transactions`, async () => { }); Deno.test(`streaming`, async () => { - await using pg = await connect(); + await using pg = await pool.connect(); await using _tx = await pg.begin(); await pg.query`create table my_table (field text not null)`; @@ -211,7 +207,7 @@ Deno.test(`streaming`, async () => { }); Deno.test(`simple`, async () => { - await using pg = await connect(); + await using pg = await pool.connect(); await using _tx = await pg.begin(); const rows = await pg.query` diff --git a/wire.ts b/wire.ts index 3143f99..74c3464 100644 --- a/wire.ts +++ b/wire.ts @@ -1,3 +1,5 @@ +import * as v from "./valita.ts"; +import { join } from "jsr:@std/path@^1.0.8"; import { type BinaryLike, buf_concat, @@ -11,6 +13,7 @@ import { type Receiver, semaphore, type Sender, + to_base58, to_base64, to_utf8, TypedEmitter, @@ -45,8 +48,8 @@ import { type SqlFragment, type SqlTypeMap, text, + sql_types, } from "./query.ts"; -import { join } from "jsr:@std/path@^1.0.8"; export class WireError extends Error { override get name() { @@ -437,29 +440,53 @@ export const StartupMessage = msg("", { export const Sync = msg("S", {}); export const Terminate = msg("X", {}); -export type LogLevel = "debug" | "info" | "warn" | "error" | "fatal"; - -export interface Parameters extends Readonly>> {} - -export interface WireOptions { - readonly host: string; - readonly port: number; - readonly user: string; - readonly password: string; - readonly database: string | null; - readonly runtime_params: Record; - readonly reconnect_delay: number; - readonly types: SqlTypeMap; +function getenv(name: string) { + return Deno.env.get(name); } +export type WireOptions = v.Infer; +export const WireOptions = v.object({ + host: v.string().optional(() => getenv("PGHOST") ?? "localhost"), + port: v + .union(v.string(), v.number()) + .optional(() => getenv("PGPORT") ?? 5432) + .map(Number) + .assert(Number.isSafeInteger, `invalid number`), + user: v + .string() + .optional(() => getenv("PGUSER") ?? getenv("USER") ?? "postgres"), + password: v.string().optional(() => getenv("PGPASSWORD") ?? "postgres"), + database: v + .string() + .nullable() + .optional(() => getenv("PGDATABASE") ?? null), + runtime_params: v + .record(v.string()) + .map((p) => ((p.application_name ??= "pglue"), p)), + reconnect_delay: v + .number() + .optional(() => 5) + .assert(Number.isSafeInteger, `invalid number`) + .nullable(), + types: v + .record(v.unknown()) + .optional(() => ({})) + .map((types): SqlTypeMap => ({ ...sql_types, ...types })), +}); + export type WireEvents = { log(level: LogLevel, ctx: object, msg: string): void; + connect(): void; notice(notice: PostgresError): void; notify(channel: string, payload: string, process_id: number): void; parameter(name: string, value: string, prev: string | null): void; close(reason?: unknown): void; }; +export type LogLevel = "debug" | "info" | "warn" | "error" | "fatal"; + +export interface Parameters extends Readonly>> {} + export interface Transaction extends Result, AsyncDisposable { readonly open: boolean; commit(): Promise; @@ -478,15 +505,11 @@ export interface Channel unlisten(): Promise; } -export async function wire_connect(options: WireOptions) { - const wire = new Wire(options); - return await wire.connect(), wire; -} - export class Wire extends TypedEmitter implements Disposable { + readonly #options; readonly #params; readonly #connect; readonly #query; @@ -509,11 +532,11 @@ export class Wire listen: this.#listen, notify: this.#notify, close: this.#close, - } = wire_impl(this, options)); + } = wire_impl(this, (this.#options = options))); } - connect() { - return this.#connect(); + async connect() { + return await this.#connect(), this; } query(sql: SqlFragment): Query; @@ -855,16 +878,18 @@ function wire_impl( read_pop = channel.receiver((push) => read_socket(s, push)); write_push = channel.sender((pop) => write_socket(s, pop)); await handle_auth(); // run auth with rw lock - (connected = true), (should_reconnect = reconnect_delay !== 0); + (connected = true), (should_reconnect = reconnect_delay !== null); + wire.emit("connect"); } catch (e) { throw (close(e), e); } } function reconnect() { + if (should_reconnect) return; connect().catch((err) => { log("warn", err as Error, `reconnect failed`); - setTimeout(reconnect, reconnect_delay); + if (reconnect_delay !== null) setTimeout(reconnect, reconnect_delay); }); } @@ -882,7 +907,7 @@ function wire_impl( delete (params as Record)[name]; st_cache.clear(), (st_ids = 0); (tx_status = "I"), (tx_stack.length = 0); - should_reconnect &&= (setTimeout(reconnect, reconnect_delay), false); + should_reconnect &&= (reconnect(), false); wire.emit("close", reason); } @@ -1063,9 +1088,7 @@ function wire_impl( const cbind_data = ``; const cbind_input = `${gs2_header}${cbind_data}`; const channel_binding = `c=${to_base64(cbind_input)}`; - const initial_nonce = `r=${to_base64( - crypto.getRandomValues(new Uint8Array(18)) - )}`; + const initial_nonce = `r=${randstr(20)}`; const client_first_message_bare = `${username},${initial_nonce}`; const client_first_message = `${gs2_header}${client_first_message_bare}`; write(SASLInitialResponse, { mechanism, data: client_first_message }); @@ -1550,10 +1573,17 @@ function wire_impl( return { params, connect, query, begin, listen, notify, close }; } -export interface PoolOptions extends WireOptions { - max_connections: number; - idle_timeout: number; -} +export type PoolOptions = v.Infer; +export const PoolOptions = WireOptions.extend({ + max_connections: v + .number() + .optional(() => 10) + .assert(Number.isSafeInteger, `invalid number`), + idle_timeout: v + .number() + .optional(() => 30) + .assert(Number.isSafeInteger, `invalid number`), +}); export type PoolEvents = { log(level: LogLevel, ctx: object, msg: string): void; @@ -1706,12 +1736,11 @@ function pool_impl( }; async function connect() { - const wire = new PoolWire(options); - await wire.connect(), all.add(wire); - const { connection_id } = wire; - return wire + const wire = new PoolWire({ ...options, reconnect_delay: null }); + const { connection_id } = wire .on("log", (l, c, s) => pool.emit("log", l, { ...c, connection_id }, s)) .on("close", () => forget(wire)); + return await wire.connect(), all.add(wire), wire; } async function acquire() {