Rewrite options handling

This commit is contained in:
luaneko 2025-01-12 06:36:15 +11:00
parent d959a80678
commit 119c06565c
Signed by: luaneko
GPG Key ID: 406809B8763FF07A
7 changed files with 150 additions and 129 deletions

8
deno.lock generated
View File

@ -465,6 +465,12 @@
"https://git.lua.re/luaneko/lstd/raw/tag/v0.2.0/events.ts": "28d395b8eea87f9bf7908a44b351d2d3c609ba7eab62bcecd0d43be8ee603438", "https://git.lua.re/luaneko/lstd/raw/tag/v0.2.0/events.ts": "28d395b8eea87f9bf7908a44b351d2d3c609ba7eab62bcecd0d43be8ee603438",
"https://git.lua.re/luaneko/lstd/raw/tag/v0.2.0/func.ts": "f1935f673365cd68939531d65ef18fe81b5d43dc795b03c34bb5ad821ab1c9ff", "https://git.lua.re/luaneko/lstd/raw/tag/v0.2.0/func.ts": "f1935f673365cd68939531d65ef18fe81b5d43dc795b03c34bb5ad821ab1c9ff",
"https://git.lua.re/luaneko/lstd/raw/tag/v0.2.0/jit.ts": "c1db7820de95c48521b057c7cdf9aa41f7eaba77462407c29d3932e7da252d53", "https://git.lua.re/luaneko/lstd/raw/tag/v0.2.0/jit.ts": "c1db7820de95c48521b057c7cdf9aa41f7eaba77462407c29d3932e7da252d53",
"https://git.lua.re/luaneko/lstd/raw/tag/v0.2.0/mod.ts": "95d8b15048a54cb82391825831f695b74e7c8b206317264a99c906ce25c63f13" "https://git.lua.re/luaneko/lstd/raw/tag/v0.2.0/mod.ts": "95d8b15048a54cb82391825831f695b74e7c8b206317264a99c906ce25c63f13",
"https://git.lua.re/luaneko/lstd/raw/tag/v0.2.1/async.ts": "20bc54c7260c2d2cd27ffcca33b903dde57a3a3635386d8e0c6baca4b253ae4e",
"https://git.lua.re/luaneko/lstd/raw/tag/v0.2.1/bytes.ts": "94f4809b375800bb2c949e31082dfdf08d022db56c5b5c9c7dfe6f399285da6f",
"https://git.lua.re/luaneko/lstd/raw/tag/v0.2.1/events.ts": "28d395b8eea87f9bf7908a44b351d2d3c609ba7eab62bcecd0d43be8ee603438",
"https://git.lua.re/luaneko/lstd/raw/tag/v0.2.1/func.ts": "f1935f673365cd68939531d65ef18fe81b5d43dc795b03c34bb5ad821ab1c9ff",
"https://git.lua.re/luaneko/lstd/raw/tag/v0.2.1/jit.ts": "c1db7820de95c48521b057c7cdf9aa41f7eaba77462407c29d3932e7da252d53",
"https://git.lua.re/luaneko/lstd/raw/tag/v0.2.1/mod.ts": "589763be8ab18e7d6c5f5921e74ab44580f466c92acead401b2903d42d94112a"
} }
} }

View File

@ -1 +1 @@
export * from "https://git.lua.re/luaneko/lstd/raw/tag/v0.2.0/mod.ts"; export * from "https://git.lua.re/luaneko/lstd/raw/tag/v0.2.1/mod.ts";

132
mod.ts
View File

@ -1,15 +1,13 @@
import pg_conn_string from "npm:pg-connection-string@^2.7.0"; import pg_conn_str from "npm:pg-connection-string@^2.7.0";
import type * as v from "./valita.ts";
import { import {
type Infer, Pool,
number, PoolOptions,
object, SubscribeOptions,
record, Subscription,
string, Wire,
union, WireOptions,
unknown, } from "./wire.ts";
} from "./valita.ts";
import { Pool, wire_connect } from "./wire.ts";
import { sql_types, type SqlTypeMap } from "./query.ts";
export { export {
WireError, WireError,
@ -33,85 +31,77 @@ export {
type RowStream, type RowStream,
} from "./query.ts"; } from "./query.ts";
export type Options = { export default function postgres(s: string, options: Partial<Options> = {}) {
host?: string; return new Postgres(Options.parse(parse_conn(s, options), { mode: "strip" }));
port?: number | string; }
user?: string;
password?: string;
database?: string | null;
max_connections?: number;
idle_timeout?: number;
runtime_params?: Record<string, string>;
types?: SqlTypeMap;
};
type ParsedOptions = Infer<typeof ParsedOptions>; function parse_conn(s: string, options: Partial<WireOptions>) {
const ParsedOptions = object({
host: string().optional(() => "localhost"),
port: union(
number(),
string().map((s) => parseInt(s, 10))
).optional(() => 5432),
user: string().optional(() => "postgres"),
password: string().optional(() => "postgres"),
database: string()
.nullable()
.optional(() => null),
runtime_params: record(string()).optional(() => ({})),
max_connections: number().optional(() => 10),
idle_timeout: number().optional(() => 20),
reconnect_delay: number().optional(() => 5),
types: record(unknown())
.optional(() => ({}))
.map((types): SqlTypeMap => ({ ...sql_types, ...types })),
});
function parse_opts(s: string, opts: Options) {
const { const {
host, host,
port, port,
user, user,
password, password,
database, database,
ssl: _ssl, // TODO: ssl: _ssl, // TODO: ssl support
...runtime_params ...runtime_params
} = pg_conn_string.parse(s); } = s ? pg_conn_str.parse(s) : {};
const { PGHOST, PGPORT, PGUSER, PGPASSWORD, PGDATABASE, USER } = return {
Deno.env.toObject(); ...options,
host: options.host ?? host,
return ParsedOptions.parse({ port: options.port ?? port,
...opts, user: options.user ?? user,
host: opts.host ?? host ?? PGHOST ?? undefined, password: options.password ?? password,
port: opts.port ?? port ?? PGPORT ?? undefined, database: options.database ?? database,
user: opts.user ?? user ?? PGUSER ?? USER ?? undefined, runtime_params: { ...runtime_params, ...options.runtime_params },
password: opts.password ?? password ?? PGPASSWORD ?? undefined, };
database: opts.database ?? database ?? PGDATABASE ?? undefined,
runtime_params: { ...runtime_params, ...opts.runtime_params },
});
}
export default function postgres(s: string, options: Options = {}) {
return new Postgres(parse_opts(s, options));
}
export function connect(s: string, options: Options = {}) {
return wire_connect(parse_opts(s, options));
} }
postgres.connect = connect; postgres.connect = connect;
postgres.subscribe = subscribe;
export async function connect(s: string, options: Partial<WireOptions> = {}) {
return await new Wire(
WireOptions.parse(parse_conn(s, options), { mode: "strip" })
).connect();
}
export async function subscribe(
s: string,
options: Partial<SubscribeOptions> = {}
) {
return await new Subscription(
SubscribeOptions.parse(parse_conn(s, options), { mode: "strip" })
).connect();
}
export type Options = v.Infer<typeof Options>;
export const Options = PoolOptions;
export class Postgres extends Pool { export class Postgres extends Pool {
readonly #options; readonly #options;
constructor(options: ParsedOptions) { constructor(options: Options) {
super(options); super(options);
this.#options = options; this.#options = options;
} }
async connect(options: Options = {}) { async connect(options: Partial<WireOptions> = {}) {
const opts = ParsedOptions.parse({ ...this.#options, ...options }); return await new Wire(
const wire = await wire_connect(opts); WireOptions.parse({ ...this.#options, ...options }, { mode: "strip" })
return wire.on("log", (l, c, s) => this.emit("log", l, c, s)); )
.on("log", (l, c, s) => this.emit("log", l, c, s))
.connect();
}
async subscribe(options: Partial<SubscribeOptions> = {}) {
return await new Subscription(
SubscribeOptions.parse(
{ ...this.#options, ...options },
{ mode: "strip" }
)
)
.on("log", (l, c, s) => this.emit("log", l, c, s))
.connect();
} }
} }

View File

@ -1,4 +1,4 @@
import type { ObjectType } from "./valita.ts"; import type * as v from "./valita.ts";
import { from_hex, to_hex, to_utf8 } from "./lstd.ts"; import { from_hex, to_hex, to_utf8 } from "./lstd.ts";
export const sql_format = Symbol.for(`re.lua.pglue.sql_format`); export const sql_format = Symbol.for(`re.lua.pglue.sql_format`);
@ -470,7 +470,7 @@ export class Query<T = Row> implements PromiseLike<Rows<T>>, RowStream<T> {
}); });
} }
parse<S extends ObjectType>( parse<S extends v.ObjectType>(
type: S, type: S,
{ mode = "strip" }: { mode?: "passthrough" | "strict" | "strip" } = {} { mode = "strip" }: { mode?: "passthrough" | "strict" | "strip" } = {}
) { ) {

2
ser.ts
View File

@ -11,7 +11,7 @@ import {
write_i8, write_i8,
} from "./lstd.ts"; } from "./lstd.ts";
export class EncoderError extends Error { export class EncoderError extends TypeError {
override get name() { override get name() {
return this.constructor.name; return this.constructor.name;
} }

30
test.ts
View File

@ -2,18 +2,14 @@ import pglue, { PostgresError, SqlTypeError } from "./mod.ts";
import { expect } from "jsr:@std/expect"; import { expect } from "jsr:@std/expect";
import { toText } from "jsr:@std/streams"; import { toText } from "jsr:@std/streams";
async function connect(params?: Record<string, string>) { const pool = pglue(`postgres://test:test@localhost:5432/test`, {
const pg = await pglue.connect(`postgres://test:test@localhost:5432/test`, { runtime_params: { client_min_messages: "INFO" },
runtime_params: { client_min_messages: "INFO", ...params }, });
});
return pg.on("log", (_level, ctx, msg) => { pool.on("log", (level, ctx, msg) => console.info(`${level}: ${msg}`, ctx));
console.info(`${msg}`, ctx);
});
}
Deno.test(`integers`, async () => { Deno.test(`integers`, async () => {
await using pg = await connect(); await using pg = await pool.connect();
await using _tx = await pg.begin(); await using _tx = await pg.begin();
const { a, b, c } = await pg.query` const { a, b, c } = await pg.query`
@ -44,7 +40,7 @@ Deno.test(`integers`, async () => {
}); });
Deno.test(`boolean`, async () => { Deno.test(`boolean`, async () => {
await using pg = await connect(); await using pg = await pool.connect();
await using _tx = await pg.begin(); await using _tx = await pg.begin();
const { a, b, c } = await pg.query` const { a, b, c } = await pg.query`
@ -60,7 +56,7 @@ Deno.test(`boolean`, async () => {
}); });
Deno.test(`bytea`, async () => { Deno.test(`bytea`, async () => {
await using pg = await connect(); await using pg = await pool.connect();
await using _tx = await pg.begin(); await using _tx = await pg.begin();
const { string, array, buffer } = await pg.query` const { string, array, buffer } = await pg.query`
@ -76,7 +72,7 @@ Deno.test(`bytea`, async () => {
}); });
Deno.test(`row`, async () => { Deno.test(`row`, async () => {
await using pg = await connect(); await using pg = await pool.connect();
await using _tx = await pg.begin(); await using _tx = await pg.begin();
expect( expect(
@ -119,7 +115,7 @@ Deno.test(`row`, async () => {
}); });
Deno.test(`sql injection`, async () => { Deno.test(`sql injection`, async () => {
await using pg = await connect(); await using pg = await pool.connect();
await using _tx = await pg.begin(); await using _tx = await pg.begin();
const input = `injection'); drop table users; --`; const input = `injection'); drop table users; --`;
@ -140,7 +136,7 @@ Deno.test(`sql injection`, async () => {
}); });
Deno.test(`listen/notify`, async () => { Deno.test(`listen/notify`, async () => {
await using pg = await connect(); await using pg = await pool.connect();
const sent: string[] = []; const sent: string[] = [];
await using ch = await pg.listen(`my channel`, (payload) => { await using ch = await pg.listen(`my channel`, (payload) => {
@ -157,7 +153,7 @@ Deno.test(`listen/notify`, async () => {
}); });
Deno.test(`transactions`, async () => { Deno.test(`transactions`, async () => {
await using pg = await connect(); await using pg = await pool.connect();
await pg.begin(async (pg) => { await pg.begin(async (pg) => {
await pg.begin(async (pg, tx) => { await pg.begin(async (pg, tx) => {
@ -192,7 +188,7 @@ Deno.test(`transactions`, async () => {
}); });
Deno.test(`streaming`, async () => { Deno.test(`streaming`, async () => {
await using pg = await connect(); await using pg = await pool.connect();
await using _tx = await pg.begin(); await using _tx = await pg.begin();
await pg.query`create table my_table (field text not null)`; await pg.query`create table my_table (field text not null)`;
@ -211,7 +207,7 @@ Deno.test(`streaming`, async () => {
}); });
Deno.test(`simple`, async () => { Deno.test(`simple`, async () => {
await using pg = await connect(); await using pg = await pool.connect();
await using _tx = await pg.begin(); await using _tx = await pg.begin();
const rows = await pg.query` const rows = await pg.query`

101
wire.ts
View File

@ -1,3 +1,5 @@
import * as v from "./valita.ts";
import { join } from "jsr:@std/path@^1.0.8";
import { import {
type BinaryLike, type BinaryLike,
buf_concat, buf_concat,
@ -11,6 +13,7 @@ import {
type Receiver, type Receiver,
semaphore, semaphore,
type Sender, type Sender,
to_base58,
to_base64, to_base64,
to_utf8, to_utf8,
TypedEmitter, TypedEmitter,
@ -45,8 +48,8 @@ import {
type SqlFragment, type SqlFragment,
type SqlTypeMap, type SqlTypeMap,
text, text,
sql_types,
} from "./query.ts"; } from "./query.ts";
import { join } from "jsr:@std/path@^1.0.8";
export class WireError extends Error { export class WireError extends Error {
override get name() { override get name() {
@ -437,29 +440,53 @@ export const StartupMessage = msg("", {
export const Sync = msg("S", {}); export const Sync = msg("S", {});
export const Terminate = msg("X", {}); export const Terminate = msg("X", {});
export type LogLevel = "debug" | "info" | "warn" | "error" | "fatal"; function getenv(name: string) {
return Deno.env.get(name);
export interface Parameters extends Readonly<Partial<Record<string, string>>> {}
export interface WireOptions {
readonly host: string;
readonly port: number;
readonly user: string;
readonly password: string;
readonly database: string | null;
readonly runtime_params: Record<string, string>;
readonly reconnect_delay: number;
readonly types: SqlTypeMap;
} }
export type WireOptions = v.Infer<typeof WireOptions>;
export const WireOptions = v.object({
host: v.string().optional(() => getenv("PGHOST") ?? "localhost"),
port: v
.union(v.string(), v.number())
.optional(() => getenv("PGPORT") ?? 5432)
.map(Number)
.assert(Number.isSafeInteger, `invalid number`),
user: v
.string()
.optional(() => getenv("PGUSER") ?? getenv("USER") ?? "postgres"),
password: v.string().optional(() => getenv("PGPASSWORD") ?? "postgres"),
database: v
.string()
.nullable()
.optional(() => getenv("PGDATABASE") ?? null),
runtime_params: v
.record(v.string())
.map((p) => ((p.application_name ??= "pglue"), p)),
reconnect_delay: v
.number()
.optional(() => 5)
.assert(Number.isSafeInteger, `invalid number`)
.nullable(),
types: v
.record(v.unknown())
.optional(() => ({}))
.map((types): SqlTypeMap => ({ ...sql_types, ...types })),
});
export type WireEvents = { export type WireEvents = {
log(level: LogLevel, ctx: object, msg: string): void; log(level: LogLevel, ctx: object, msg: string): void;
connect(): void;
notice(notice: PostgresError): void; notice(notice: PostgresError): void;
notify(channel: string, payload: string, process_id: number): void; notify(channel: string, payload: string, process_id: number): void;
parameter(name: string, value: string, prev: string | null): void; parameter(name: string, value: string, prev: string | null): void;
close(reason?: unknown): void; close(reason?: unknown): void;
}; };
export type LogLevel = "debug" | "info" | "warn" | "error" | "fatal";
export interface Parameters extends Readonly<Partial<Record<string, string>>> {}
export interface Transaction extends Result, AsyncDisposable { export interface Transaction extends Result, AsyncDisposable {
readonly open: boolean; readonly open: boolean;
commit(): Promise<Result>; commit(): Promise<Result>;
@ -478,15 +505,11 @@ export interface Channel
unlisten(): Promise<Result>; unlisten(): Promise<Result>;
} }
export async function wire_connect(options: WireOptions) {
const wire = new Wire(options);
return await wire.connect(), wire;
}
export class Wire<V extends WireEvents = WireEvents> export class Wire<V extends WireEvents = WireEvents>
extends TypedEmitter<V> extends TypedEmitter<V>
implements Disposable implements Disposable
{ {
readonly #options;
readonly #params; readonly #params;
readonly #connect; readonly #connect;
readonly #query; readonly #query;
@ -509,11 +532,11 @@ export class Wire<V extends WireEvents = WireEvents>
listen: this.#listen, listen: this.#listen,
notify: this.#notify, notify: this.#notify,
close: this.#close, close: this.#close,
} = wire_impl(this, options)); } = wire_impl(this, (this.#options = options)));
} }
connect() { async connect() {
return this.#connect(); return await this.#connect(), this;
} }
query<T = Row>(sql: SqlFragment): Query<T>; query<T = Row>(sql: SqlFragment): Query<T>;
@ -855,16 +878,18 @@ function wire_impl(
read_pop = channel.receiver((push) => read_socket(s, push)); read_pop = channel.receiver((push) => read_socket(s, push));
write_push = channel.sender((pop) => write_socket(s, pop)); write_push = channel.sender((pop) => write_socket(s, pop));
await handle_auth(); // run auth with rw lock await handle_auth(); // run auth with rw lock
(connected = true), (should_reconnect = reconnect_delay !== 0); (connected = true), (should_reconnect = reconnect_delay !== null);
wire.emit("connect");
} catch (e) { } catch (e) {
throw (close(e), e); throw (close(e), e);
} }
} }
function reconnect() { function reconnect() {
if (should_reconnect) return;
connect().catch((err) => { connect().catch((err) => {
log("warn", err as Error, `reconnect failed`); log("warn", err as Error, `reconnect failed`);
setTimeout(reconnect, reconnect_delay); if (reconnect_delay !== null) setTimeout(reconnect, reconnect_delay);
}); });
} }
@ -882,7 +907,7 @@ function wire_impl(
delete (params as Record<string, string>)[name]; delete (params as Record<string, string>)[name];
st_cache.clear(), (st_ids = 0); st_cache.clear(), (st_ids = 0);
(tx_status = "I"), (tx_stack.length = 0); (tx_status = "I"), (tx_stack.length = 0);
should_reconnect &&= (setTimeout(reconnect, reconnect_delay), false); should_reconnect &&= (reconnect(), false);
wire.emit("close", reason); wire.emit("close", reason);
} }
@ -1063,9 +1088,7 @@ function wire_impl(
const cbind_data = ``; const cbind_data = ``;
const cbind_input = `${gs2_header}${cbind_data}`; const cbind_input = `${gs2_header}${cbind_data}`;
const channel_binding = `c=${to_base64(cbind_input)}`; const channel_binding = `c=${to_base64(cbind_input)}`;
const initial_nonce = `r=${to_base64( const initial_nonce = `r=${randstr(20)}`;
crypto.getRandomValues(new Uint8Array(18))
)}`;
const client_first_message_bare = `${username},${initial_nonce}`; const client_first_message_bare = `${username},${initial_nonce}`;
const client_first_message = `${gs2_header}${client_first_message_bare}`; const client_first_message = `${gs2_header}${client_first_message_bare}`;
write(SASLInitialResponse, { mechanism, data: client_first_message }); write(SASLInitialResponse, { mechanism, data: client_first_message });
@ -1550,10 +1573,17 @@ function wire_impl(
return { params, connect, query, begin, listen, notify, close }; return { params, connect, query, begin, listen, notify, close };
} }
export interface PoolOptions extends WireOptions { export type PoolOptions = v.Infer<typeof PoolOptions>;
max_connections: number; export const PoolOptions = WireOptions.extend({
idle_timeout: number; max_connections: v
} .number()
.optional(() => 10)
.assert(Number.isSafeInteger, `invalid number`),
idle_timeout: v
.number()
.optional(() => 30)
.assert(Number.isSafeInteger, `invalid number`),
});
export type PoolEvents = { export type PoolEvents = {
log(level: LogLevel, ctx: object, msg: string): void; log(level: LogLevel, ctx: object, msg: string): void;
@ -1706,12 +1736,11 @@ function pool_impl(
}; };
async function connect() { async function connect() {
const wire = new PoolWire(options); const wire = new PoolWire({ ...options, reconnect_delay: null });
await wire.connect(), all.add(wire); const { connection_id } = wire
const { connection_id } = wire;
return wire
.on("log", (l, c, s) => pool.emit("log", l, { ...c, connection_id }, s)) .on("log", (l, c, s) => pool.emit("log", l, { ...c, connection_id }, s))
.on("close", () => forget(wire)); .on("close", () => forget(wire));
return await wire.connect(), all.add(wire), wire;
} }
async function acquire() { async function acquire() {