Compare commits

...

6 Commits

9 changed files with 501 additions and 279 deletions

View File

@ -14,9 +14,9 @@ The glue for TypeScript to PostgreSQL.
## Installation ## Installation
```ts ```ts
import pglue from "https://git.lua.re/luaneko/pglue/raw/tag/v0.3.0/mod.ts"; import pglue from "https://git.lua.re/luaneko/pglue/raw/tag/v0.3.1/mod.ts";
// ...or from github: // ...or from github:
import pglue from "https://raw.githubusercontent.com/luaneko/pglue/refs/tags/v0.3.0/mod.ts"; import pglue from "https://raw.githubusercontent.com/luaneko/pglue/refs/tags/v0.3.1/mod.ts";
``` ```
## Documentation ## Documentation

View File

@ -1,5 +1,5 @@
{ {
"name": "@luaneko/pglue", "name": "@luaneko/pglue",
"version": "0.3.0", "version": "0.3.1",
"exports": "./mod.ts" "exports": "./mod.ts"
} }

8
deno.lock generated
View File

@ -465,6 +465,12 @@
"https://git.lua.re/luaneko/lstd/raw/tag/v0.2.0/events.ts": "28d395b8eea87f9bf7908a44b351d2d3c609ba7eab62bcecd0d43be8ee603438", "https://git.lua.re/luaneko/lstd/raw/tag/v0.2.0/events.ts": "28d395b8eea87f9bf7908a44b351d2d3c609ba7eab62bcecd0d43be8ee603438",
"https://git.lua.re/luaneko/lstd/raw/tag/v0.2.0/func.ts": "f1935f673365cd68939531d65ef18fe81b5d43dc795b03c34bb5ad821ab1c9ff", "https://git.lua.re/luaneko/lstd/raw/tag/v0.2.0/func.ts": "f1935f673365cd68939531d65ef18fe81b5d43dc795b03c34bb5ad821ab1c9ff",
"https://git.lua.re/luaneko/lstd/raw/tag/v0.2.0/jit.ts": "c1db7820de95c48521b057c7cdf9aa41f7eaba77462407c29d3932e7da252d53", "https://git.lua.re/luaneko/lstd/raw/tag/v0.2.0/jit.ts": "c1db7820de95c48521b057c7cdf9aa41f7eaba77462407c29d3932e7da252d53",
"https://git.lua.re/luaneko/lstd/raw/tag/v0.2.0/mod.ts": "95d8b15048a54cb82391825831f695b74e7c8b206317264a99c906ce25c63f13" "https://git.lua.re/luaneko/lstd/raw/tag/v0.2.0/mod.ts": "95d8b15048a54cb82391825831f695b74e7c8b206317264a99c906ce25c63f13",
"https://git.lua.re/luaneko/lstd/raw/tag/v0.2.1/async.ts": "20bc54c7260c2d2cd27ffcca33b903dde57a3a3635386d8e0c6baca4b253ae4e",
"https://git.lua.re/luaneko/lstd/raw/tag/v0.2.1/bytes.ts": "94f4809b375800bb2c949e31082dfdf08d022db56c5b5c9c7dfe6f399285da6f",
"https://git.lua.re/luaneko/lstd/raw/tag/v0.2.1/events.ts": "28d395b8eea87f9bf7908a44b351d2d3c609ba7eab62bcecd0d43be8ee603438",
"https://git.lua.re/luaneko/lstd/raw/tag/v0.2.1/func.ts": "f1935f673365cd68939531d65ef18fe81b5d43dc795b03c34bb5ad821ab1c9ff",
"https://git.lua.re/luaneko/lstd/raw/tag/v0.2.1/jit.ts": "c1db7820de95c48521b057c7cdf9aa41f7eaba77462407c29d3932e7da252d53",
"https://git.lua.re/luaneko/lstd/raw/tag/v0.2.1/mod.ts": "589763be8ab18e7d6c5f5921e74ab44580f466c92acead401b2903d42d94112a"
} }
} }

View File

@ -1 +1 @@
export * from "https://git.lua.re/luaneko/lstd/raw/tag/v0.2.0/mod.ts"; export * from "https://git.lua.re/luaneko/lstd/raw/tag/v0.2.1/mod.ts";

106
mod.ts
View File

@ -1,15 +1,6 @@
import pg_conn_string from "npm:pg-connection-string@^2.7.0"; import pg_conn_str from "npm:pg-connection-string@^2.7.0";
import { import type * as v from "./valita.ts";
type Infer, import { Pool, PoolOptions, Wire, WireOptions } from "./wire.ts";
number,
object,
record,
string,
union,
unknown,
} from "./valita.ts";
import { Pool, wire_connect } from "./wire.ts";
import { sql_types, type SqlTypeMap } from "./query.ts";
export { export {
WireError, WireError,
@ -33,85 +24,56 @@ export {
type RowStream, type RowStream,
} from "./query.ts"; } from "./query.ts";
export type Options = { export default function postgres(s: string, options: Partial<Options> = {}) {
host?: string; return new Postgres(Options.parse(parse_conn(s, options), { mode: "strip" }));
port?: number | string; }
user?: string;
password?: string;
database?: string | null;
max_connections?: number;
idle_timeout?: number;
runtime_params?: Record<string, string>;
types?: SqlTypeMap;
};
type ParsedOptions = Infer<typeof ParsedOptions>; function parse_conn(s: string, options: Partial<WireOptions>) {
const ParsedOptions = object({
host: string().optional(() => "localhost"),
port: union(
number(),
string().map((s) => parseInt(s, 10))
).optional(() => 5432),
user: string().optional(() => "postgres"),
password: string().optional(() => "postgres"),
database: string()
.nullable()
.optional(() => null),
runtime_params: record(string()).optional(() => ({})),
max_connections: number().optional(() => 10),
idle_timeout: number().optional(() => 20),
reconnect_delay: number().optional(() => 5),
types: record(unknown())
.optional(() => ({}))
.map((types): SqlTypeMap => ({ ...sql_types, ...types })),
});
function parse_opts(s: string, opts: Options) {
const { const {
host, host,
port, port,
user, user,
password, password,
database, database,
ssl: _ssl, // TODO: ssl: _ssl, // TODO: ssl support
...runtime_params ...runtime_params
} = pg_conn_string.parse(s); } = s ? pg_conn_str.parse(s) : {};
const { PGHOST, PGPORT, PGUSER, PGPASSWORD, PGDATABASE, USER } = return {
Deno.env.toObject(); ...options,
host: options.host ?? host,
return ParsedOptions.parse({ port: options.port ?? port,
...opts, user: options.user ?? user,
host: opts.host ?? host ?? PGHOST ?? undefined, password: options.password ?? password,
port: opts.port ?? port ?? PGPORT ?? undefined, database: options.database ?? database,
user: opts.user ?? user ?? PGUSER ?? USER ?? undefined, runtime_params: { ...runtime_params, ...options.runtime_params },
password: opts.password ?? password ?? PGPASSWORD ?? undefined, };
database: opts.database ?? database ?? PGDATABASE ?? undefined,
runtime_params: { ...runtime_params, ...opts.runtime_params },
});
}
export default function postgres(s: string, options: Options = {}) {
return new Postgres(parse_opts(s, options));
}
export function connect(s: string, options: Options = {}) {
return wire_connect(parse_opts(s, options));
} }
postgres.connect = connect; postgres.connect = connect;
export async function connect(s: string, options: Partial<WireOptions> = {}) {
return await new Wire(
WireOptions.parse(parse_conn(s, options), { mode: "strip" })
).connect();
}
export type Options = v.Infer<typeof Options>;
export const Options = PoolOptions;
export class Postgres extends Pool { export class Postgres extends Pool {
readonly #options; readonly #options;
constructor(options: ParsedOptions) { constructor(options: Options) {
super(options); super(options);
this.#options = options; this.#options = options;
} }
async connect(options: Options = {}) { async connect(options: Partial<WireOptions> = {}) {
const opts = ParsedOptions.parse({ ...this.#options, ...options }); return await new Wire(
const wire = await wire_connect(opts); WireOptions.parse({ ...this.#options, ...options }, { mode: "strip" })
return wire.on("log", (l, c, s) => this.emit("log", l, c, s)); )
.on("log", (l, c, s) => this.emit("log", l, c, s))
.connect();
} }
} }

View File

@ -1,4 +1,4 @@
import type { ObjectType } from "./valita.ts"; import type * as v from "./valita.ts";
import { from_hex, to_hex, to_utf8 } from "./lstd.ts"; import { from_hex, to_hex, to_utf8 } from "./lstd.ts";
export const sql_format = Symbol.for(`re.lua.pglue.sql_format`); export const sql_format = Symbol.for(`re.lua.pglue.sql_format`);
@ -168,6 +168,23 @@ export const text: SqlType = {
}, },
}; };
export const char: SqlType = {
input(c) {
const n = c.charCodeAt(0);
if (c.length === 1 && 0 <= n && n <= 255) return c;
throw new SqlTypeError(`invalid char input '${c}'`);
},
output(x) {
let c: string;
if (typeof x === "undefined" || x === null) return null;
else if (typeof x === "number") c = String.fromCharCode(x);
else c = String(x);
const n = c.charCodeAt(0);
if (c.length === 1 && 0 <= n && n <= 255) return c;
else throw new SqlTypeError(`invalid char output '${x}'`);
},
};
export const int2: SqlType = { export const int2: SqlType = {
input(s) { input(s) {
const n = Number(s); const n = Number(s);
@ -201,6 +218,22 @@ export const int4: SqlType = {
}, },
}; };
export const uint4: SqlType = {
input(s) {
const n = Number(s);
if (Number.isInteger(n) && 0 <= n && n <= 4294967295) return n;
else throw new SqlTypeError(`invalid uint4 input '${s}'`);
},
output(x) {
let n: number;
if (typeof x === "undefined" || x === null) return null;
else if (typeof x === "number") n = x;
else n = Number(x);
if (Number.isInteger(n) && 0 <= n && n <= 4294967295) return n.toString();
else throw new SqlTypeError(`invalid uint4 output '${x}'`);
},
};
export const int8: SqlType = { export const int8: SqlType = {
input(s) { input(s) {
const n = BigInt(s); const n = BigInt(s);
@ -214,14 +247,36 @@ export const int8: SqlType = {
else if (typeof x === "number" || typeof x === "bigint") n = x; else if (typeof x === "number" || typeof x === "bigint") n = x;
else if (typeof x === "string") n = BigInt(x); else if (typeof x === "string") n = BigInt(x);
else n = Number(x); else n = Number(x);
if (Number.isInteger(n)) { if (
if (-9007199254740991 <= n && n <= 9007199254740991) return n.toString(); (typeof n === "number" && Number.isSafeInteger(n)) ||
else throw new SqlTypeError(`unsafe int8 output '${x}'`); (typeof n === "bigint" &&
} else if (typeof n === "bigint") { -9223372036854775808n <= n &&
if (-9223372036854775808n <= n && n <= 9223372036854775807n) n <= 9223372036854775807n)
return n.toString(); ) {
} return n.toString();
throw new SqlTypeError(`invalid int8 output '${x}'`); } else throw new SqlTypeError(`invalid int8 output '${x}'`);
},
};
export const uint8: SqlType = {
input(s) {
const n = BigInt(s);
if (0n <= n && n <= 9007199254740991n) return Number(n);
else if (0n <= n && n <= 18446744073709551615n) return n;
else throw new SqlTypeError(`invalid uint8 input '${s}'`);
},
output(x) {
let n: number | bigint;
if (typeof x === "undefined" || x === null) return null;
else if (typeof x === "number" || typeof x === "bigint") n = x;
else if (typeof x === "string") n = BigInt(x);
else n = Number(x);
if (
(typeof n === "number" && Number.isSafeInteger(n) && 0 <= n) ||
(typeof n === "bigint" && 0n <= n && n <= 18446744073709551615n)
) {
return n.toString();
} else throw new SqlTypeError(`invalid uint8 output '${x}'`);
}, },
}; };
@ -305,20 +360,26 @@ export const json: SqlType = {
}; };
export const sql_types: SqlTypeMap = { export const sql_types: SqlTypeMap = {
0: text,
16: bool, // bool 16: bool, // bool
25: text, // text 17: bytea, // bytea
18: char, // char
19: text, // name
20: int8, // int8
21: int2, // int2 21: int2, // int2
23: int4, // int4 23: int4, // int4
20: int8, // int8 25: text, // text
26: int8, // oid 26: uint4, // oid
28: uint4, // xid
29: uint4, // cid
114: json, // json
700: float4, // float4 700: float4, // float4
701: float8, // float8 701: float8, // float8
1082: timestamptz, // date 1082: timestamptz, // date
1114: timestamptz, // timestamp 1114: timestamptz, // timestamp
1184: timestamptz, // timestamptz 1184: timestamptz, // timestamptz
17: bytea, // bytea
114: json, // json
3802: json, // jsonb 3802: json, // jsonb
5069: uint8, // xid8
}; };
sql.types = sql_types; sql.types = sql_types;
@ -409,7 +470,7 @@ export class Query<T = Row> implements PromiseLike<Rows<T>>, RowStream<T> {
}); });
} }
parse<S extends ObjectType>( parse<S extends v.ObjectType>(
type: S, type: S,
{ mode = "strip" }: { mode?: "passthrough" | "strict" | "strip" } = {} { mode = "strip" }: { mode?: "passthrough" | "strict" | "strip" } = {}
) { ) {

2
ser.ts
View File

@ -11,7 +11,7 @@ import {
write_i8, write_i8,
} from "./lstd.ts"; } from "./lstd.ts";
export class EncoderError extends Error { export class EncoderError extends TypeError {
override get name() { override get name() {
return this.constructor.name; return this.constructor.name;
} }

31
test.ts
View File

@ -2,18 +2,15 @@ import pglue, { PostgresError, SqlTypeError } from "./mod.ts";
import { expect } from "jsr:@std/expect"; import { expect } from "jsr:@std/expect";
import { toText } from "jsr:@std/streams"; import { toText } from "jsr:@std/streams";
async function connect(params?: Record<string, string>) { const pool = pglue(`postgres://test:test@localhost:5432/test`, {
const pg = await pglue.connect(`postgres://test:test@localhost:5432/test`, { runtime_params: { client_min_messages: "INFO" },
runtime_params: { client_min_messages: "INFO", ...params }, verbose: true,
}); });
return pg.on("log", (_level, ctx, msg) => { pool.on("log", (level, ctx, msg) => console.info(`${level}: ${msg}`, ctx));
console.info(`${msg}`, ctx);
});
}
Deno.test(`integers`, async () => { Deno.test(`integers`, async () => {
await using pg = await connect(); await using pg = await pool.connect();
await using _tx = await pg.begin(); await using _tx = await pg.begin();
const { a, b, c } = await pg.query` const { a, b, c } = await pg.query`
@ -44,7 +41,7 @@ Deno.test(`integers`, async () => {
}); });
Deno.test(`boolean`, async () => { Deno.test(`boolean`, async () => {
await using pg = await connect(); await using pg = await pool.connect();
await using _tx = await pg.begin(); await using _tx = await pg.begin();
const { a, b, c } = await pg.query` const { a, b, c } = await pg.query`
@ -60,7 +57,7 @@ Deno.test(`boolean`, async () => {
}); });
Deno.test(`bytea`, async () => { Deno.test(`bytea`, async () => {
await using pg = await connect(); await using pg = await pool.connect();
await using _tx = await pg.begin(); await using _tx = await pg.begin();
const { string, array, buffer } = await pg.query` const { string, array, buffer } = await pg.query`
@ -76,7 +73,7 @@ Deno.test(`bytea`, async () => {
}); });
Deno.test(`row`, async () => { Deno.test(`row`, async () => {
await using pg = await connect(); await using pg = await pool.connect();
await using _tx = await pg.begin(); await using _tx = await pg.begin();
expect( expect(
@ -119,7 +116,7 @@ Deno.test(`row`, async () => {
}); });
Deno.test(`sql injection`, async () => { Deno.test(`sql injection`, async () => {
await using pg = await connect(); await using pg = await pool.connect();
await using _tx = await pg.begin(); await using _tx = await pg.begin();
const input = `injection'); drop table users; --`; const input = `injection'); drop table users; --`;
@ -140,7 +137,7 @@ Deno.test(`sql injection`, async () => {
}); });
Deno.test(`listen/notify`, async () => { Deno.test(`listen/notify`, async () => {
await using pg = await connect(); await using pg = await pool.connect();
const sent: string[] = []; const sent: string[] = [];
await using ch = await pg.listen(`my channel`, (payload) => { await using ch = await pg.listen(`my channel`, (payload) => {
@ -157,7 +154,7 @@ Deno.test(`listen/notify`, async () => {
}); });
Deno.test(`transactions`, async () => { Deno.test(`transactions`, async () => {
await using pg = await connect(); await using pg = await pool.connect();
await pg.begin(async (pg) => { await pg.begin(async (pg) => {
await pg.begin(async (pg, tx) => { await pg.begin(async (pg, tx) => {
@ -192,7 +189,7 @@ Deno.test(`transactions`, async () => {
}); });
Deno.test(`streaming`, async () => { Deno.test(`streaming`, async () => {
await using pg = await connect(); await using pg = await pool.connect();
await using _tx = await pg.begin(); await using _tx = await pg.begin();
await pg.query`create table my_table (field text not null)`; await pg.query`create table my_table (field text not null)`;
@ -211,7 +208,7 @@ Deno.test(`streaming`, async () => {
}); });
Deno.test(`simple`, async () => { Deno.test(`simple`, async () => {
await using pg = await connect(); await using pg = await pool.connect();
await using _tx = await pg.begin(); await using _tx = await pg.begin();
const rows = await pg.query` const rows = await pg.query`

534
wire.ts
View File

@ -1,3 +1,5 @@
import * as v from "./valita.ts";
import { join } from "jsr:@std/path@^1.0.8";
import { import {
type BinaryLike, type BinaryLike,
buf_concat, buf_concat,
@ -11,7 +13,9 @@ import {
type Receiver, type Receiver,
semaphore, semaphore,
type Sender, type Sender,
to_base58,
to_base64, to_base64,
to_hex,
to_utf8, to_utf8,
TypedEmitter, TypedEmitter,
} from "./lstd.ts"; } from "./lstd.ts";
@ -45,8 +49,8 @@ import {
type SqlFragment, type SqlFragment,
type SqlTypeMap, type SqlTypeMap,
text, text,
sql_types,
} from "./query.ts"; } from "./query.ts";
import { join } from "jsr:@std/path@^1.0.8";
export class WireError extends Error { export class WireError extends Error {
override get name() { override get name() {
@ -437,29 +441,54 @@ export const StartupMessage = msg("", {
export const Sync = msg("S", {}); export const Sync = msg("S", {});
export const Terminate = msg("X", {}); export const Terminate = msg("X", {});
export type LogLevel = "debug" | "info" | "warn" | "error" | "fatal"; function getenv(name: string) {
return Deno.env.get(name);
export interface Parameters extends Readonly<Partial<Record<string, string>>> {}
export interface WireOptions {
readonly host: string;
readonly port: number;
readonly user: string;
readonly password: string;
readonly database: string | null;
readonly runtime_params: Record<string, string>;
readonly reconnect_delay: number;
readonly types: SqlTypeMap;
} }
export type WireOptions = v.Infer<typeof WireOptions>;
export const WireOptions = v.object({
host: v.string().optional(() => getenv("PGHOST") ?? "localhost"),
port: v
.union(v.string(), v.number())
.optional(() => getenv("PGPORT") ?? 5432)
.map(Number)
.assert(Number.isSafeInteger, `invalid number`),
user: v
.string()
.optional(() => getenv("PGUSER") ?? getenv("USER") ?? "postgres"),
password: v.string().optional(() => getenv("PGPASSWORD") ?? "postgres"),
database: v
.string()
.nullable()
.optional(() => getenv("PGDATABASE") ?? null),
runtime_params: v
.record(v.string())
.map((p) => ((p.application_name ??= "pglue"), p)),
reconnect_delay: v
.number()
.optional(() => 5)
.assert(Number.isSafeInteger, `invalid number`)
.nullable(),
types: v
.record(v.unknown())
.optional(() => ({}))
.map((types): SqlTypeMap => ({ ...sql_types, ...types })),
verbose: v.boolean().optional(() => false),
});
export type WireEvents = { export type WireEvents = {
log(level: LogLevel, ctx: object, msg: string): void; log(level: LogLevel, ctx: object, msg: string): void;
connect(): void;
notice(notice: PostgresError): void; notice(notice: PostgresError): void;
notify(channel: string, payload: string, process_id: number): void; notify(channel: string, payload: string, process_id: number): void;
parameter(name: string, value: string, prev: string | null): void; parameter(name: string, value: string, prev: string | null): void;
close(reason?: unknown): void; close(reason?: unknown): void;
}; };
export type LogLevel = "trace" | "debug" | "info" | "warn" | "error" | "fatal";
export interface Parameters extends Readonly<Partial<Record<string, string>>> {}
export interface Transaction extends Result, AsyncDisposable { export interface Transaction extends Result, AsyncDisposable {
readonly open: boolean; readonly open: boolean;
commit(): Promise<Result>; commit(): Promise<Result>;
@ -478,15 +507,11 @@ export interface Channel
unlisten(): Promise<Result>; unlisten(): Promise<Result>;
} }
export async function wire_connect(options: WireOptions) {
const wire = new Wire(options);
return await wire.connect(), wire;
}
export class Wire<V extends WireEvents = WireEvents> export class Wire<V extends WireEvents = WireEvents>
extends TypedEmitter<V> extends TypedEmitter<V>
implements Disposable implements Disposable
{ {
readonly #options;
readonly #params; readonly #params;
readonly #connect; readonly #connect;
readonly #query; readonly #query;
@ -509,11 +534,11 @@ export class Wire<V extends WireEvents = WireEvents>
listen: this.#listen, listen: this.#listen,
notify: this.#notify, notify: this.#notify,
close: this.#close, close: this.#close,
} = wire_impl(this, options)); } = wire_impl(this, (this.#options = options)));
} }
connect() { async connect() {
return this.#connect(); return await this.#connect(), this;
} }
query<T = Row>(sql: SqlFragment): Query<T>; query<T = Row>(sql: SqlFragment): Query<T>;
@ -545,18 +570,163 @@ export class Wire<V extends WireEvents = WireEvents>
return this.#notify(channel, payload); return this.#notify(channel, payload);
} }
async get(param: string) { async current_setting(name: string) {
return await this.query`select current_setting(${param}, true)` return await this.query<
.map(([s]) => String(s)) [string]
>`select current_setting(${name}::text, true)`
.map(([x]) => x)
.first_or(null); .first_or(null);
} }
async set(param: string, value: string, local = false) { async set_config(name: string, value: string, local = false) {
return await this.query`select set_config(${param}, ${value}, ${local})` return await this.query<
.map(([s]) => String(s)) [string]
>`select set_config(${name}::text, ${value}::text, ${local}::boolean)`
.map(([x]) => x)
.first(); .first();
} }
async cancel_backend(pid: number) {
return await this.query<
[boolean]
>`select pg_cancel_backend(${pid}::integer)`
.map(([x]) => x)
.first();
}
async terminate_backend(pid: number, timeout = 0) {
return await this.query<
[boolean]
>`select pg_terminate_backend(${pid}::integer, ${timeout}::bigint)`
.map(([x]) => x)
.first();
}
async inet() {
return await this.query<{
client_addr: string;
client_port: number;
server_addr: string;
server_port: number;
}>`
select
inet_client_addr() as client_addr,
inet_client_port() as client_port,
inet_server_addr() as server_addr,
inet_server_port() as server_por
`.first();
}
async listening_channels() {
return await this.query<[string]>`select pg_listening_channels()`
.map(([x]) => x)
.collect();
}
async notification_queue_usage() {
return await this.query<[number]>`select pg_notification_queue_usage()`
.map(([x]) => x)
.first();
}
async postmaster_start_time() {
return await this.query<[Date]>`select pg_postmaster_start_time()`
.map(([x]) => x)
.first();
}
async current_wal() {
return await this.query<{
lsn: string;
insert_lsn: string;
flush_lsn: string;
}>`
select
pg_current_wal_lsn() as lsn,
pg_current_wal_insert_lsn() as insert_lsn,
pg_current_wal_flush_lsn() as flush_lsn
`.first();
}
async switch_wal() {
return await this.query<[string]>`select pg_switch_wal()`
.map(([x]) => x)
.first();
}
async nextval(seq: string) {
return await this.query<[number | bigint]>`select nextval(${seq}::regclass)`
.map(([x]) => x)
.first();
}
async setval(seq: string, value: number | bigint, is_called = true) {
return await this.query<
[number | bigint]
>`select setval(${seq}::regclass, ${value}::bigint, ${is_called}::boolean)`
.map(([x]) => x)
.first();
}
async currval(seq: string) {
return await this.query<[number]>`select currval(${seq}::regclass)`
.map(([x]) => x)
.first();
}
async lastval() {
return await this.query<[number]>`select lastval()`.map(([x]) => x).first();
}
async validate_input(s: string, type: string) {
return await this.query<{
message: string | null;
detail: string | null;
hint: string | null;
sql_error_code: string | null;
}>`select * from pg_input_error_info(${s}::text, ${type}::text)`.first();
}
async current_xid() {
return await this.query<[number | bigint]>`select pg_current_xact_id()`
.map(([x]) => x)
.first();
}
async current_xid_if_assigned() {
return await this.query<
[number | bigint | null]
>`select pg_current_xact_id_if_assigned()`
.map(([x]) => x)
.first();
}
async xact_info(xid: number | bigint) {
return await this.query<{
status: "progress" | "committed" | "aborted";
age: number;
mxid_age: number;
}>`
select
pg_xact_status(${xid}::xid8) as status,
age(${xid}::xid) as age,
mxid_age(${xid}::xid) as mxid_age
`;
}
async version() {
return await this.query<{
postgres: string;
unicode: string;
icu_unicode: string | null;
}>`
select
version() as postgres,
unicode_version() as unicode,
icu_unicode_version() as icu_unicode
`.first();
}
close(reason?: unknown) { close(reason?: unknown) {
this.#close(reason); this.#close(reason);
} }
@ -566,6 +736,10 @@ export class Wire<V extends WireEvents = WireEvents>
} }
} }
function randstr(entropy: number) {
return to_base58(crypto.getRandomValues(new Uint8Array(entropy)));
}
async function socket_connect(hostname: string, port: number) { async function socket_connect(hostname: string, port: number) {
if (hostname.startsWith("/")) { if (hostname.startsWith("/")) {
const path = join(hostname, `.s.PGSQL.${port}`); const path = join(hostname, `.s.PGSQL.${port}`);
@ -587,59 +761,105 @@ function wire_impl(
runtime_params, runtime_params,
reconnect_delay, reconnect_delay,
types, types,
verbose,
}: WireOptions }: WireOptions
) { ) {
// current runtime parameters as reported by postgres
const params: Parameters = Object.create(null); const params: Parameters = Object.create(null);
function log(level: LogLevel, ctx: object, msg: string) { function log(level: LogLevel, ctx: object, msg: string) {
wire.emit("log", level, ctx, msg); wire.emit("log", level, ctx, msg);
} }
// wire supports re-connection; socket and read/write channels are null when closed
let connected = false; let connected = false;
let should_reconnect = false; let close_requested = false;
let socket: Deno.Conn | null = null; let read_queue: Receiver<Uint8Array> | null = null;
let read_pop: Receiver<Uint8Array> | null = null; let write_queue: Sender<Uint8Array> | null = null;
let write_push: Sender<Uint8Array> | null = null;
async function connect() {
using _rlock = await rlock();
using _wlock = await wlock();
if (connected) return;
else close_requested = false;
let socket: Deno.Conn | undefined;
let closed = false;
try {
const read = channel<Uint8Array>();
const write = channel<Uint8Array>();
socket = await socket_connect(host, port);
read_queue?.close(), (read_queue = read.recv);
write_queue?.close(), (write_queue = write.send);
read_socket(socket, read.send).then(onclose, onclose);
write_socket(socket, write.recv).then(onclose, onclose);
await handle_auth(); // run auth with rw lock
if (close_requested) throw new WireError(`close requested`);
else (connected = true), wire.emit("connect");
} catch (e) {
throw (onclose(e), e);
}
function onclose(reason?: unknown) {
if (closed) return;
else closed = true;
socket?.close();
for (const name of Object.keys(params))
delete (params as Record<string, string>)[name];
st_cache.clear(), (st_ids = 0);
(tx_status = "I"), (tx_stack.length = 0);
connected &&= (wire.emit("close", reason), reconnect(), false);
}
}
let reconnect_timer = -1;
function reconnect() {
if (close_requested || reconnect_delay === null) return;
connect().catch((err) => {
log("warn", err, `reconnect failed`);
clearTimeout(reconnect_timer);
reconnect_timer = setTimeout(reconnect, reconnect_delay);
});
}
function close(reason?: unknown) {
close_requested = true;
clearTimeout(reconnect_timer);
read_queue?.close(reason), (read_queue = null);
write_queue?.close(reason), (write_queue = null);
}
async function read<T>(type: Encoder<T>) { async function read<T>(type: Encoder<T>) {
const msg = read_pop !== null ? await read_pop() : null; const msg = read_queue !== null ? await read_queue() : null;
if (msg !== null) return ser_decode(type, msg_check_err(msg)); if (msg !== null) return ser_decode(type, msg_check_err(msg));
else throw new WireError(`connection closed`); else throw new WireError(`connection closed`);
} }
async function read_msg() { async function read_any() {
const msg = read_pop !== null ? await read_pop() : null; const msg = read_queue !== null ? await read_queue() : null;
if (msg !== null) return msg; if (msg !== null) return msg;
else throw new WireError(`connection closed`); else throw new WireError(`connection closed`);
} }
async function read_socket(socket: Deno.Conn, push: Sender<Uint8Array>) { async function read_socket(socket: Deno.Conn, send: Sender<Uint8Array>) {
let err; const header_size = 5;
try { const read_buf = new Uint8Array(64 * 1024); // shared buffer for all socket reads
const header_size = 5; let buf = new Uint8Array(); // concatenated messages read so far
const read_buf = new Uint8Array(64 * 1024); // shared buffer for all socket reads
let buf = new Uint8Array(); // concatenated messages read so far
for (let read; (read = await socket.read(read_buf)) !== null; ) { for (let read; (read = await socket.read(read_buf)) !== null; ) {
buf = buf_concat_fast(buf, read_buf.subarray(0, read)); // push read bytes to buf buf = buf_concat_fast(buf, read_buf.subarray(0, read)); // push read bytes to buf
while (buf.length >= header_size) { while (buf.length >= header_size) {
const size = ser_decode(Header, buf).length + 1; const size = ser_decode(Header, buf).length + 1;
if (buf.length < size) break; if (buf.length < size) break;
const msg = buf.subarray(0, size); // shift one message from buf const msg = buf.subarray(0, size); // shift one message from buf
buf = buf.subarray(size); buf = buf.subarray(size);
if (!handle_msg(msg)) push(msg); if (verbose)
} log("trace", {}, `RECV <- ${msg_type(msg)} ${to_hex(msg)}`);
if (!handle_msg(msg)) send(msg);
} }
// there should be nothing left in buf if we gracefully exited
if (buf.length !== 0) throw new WireError(`unexpected end of stream`);
} catch (e) {
throw (err = e);
} finally {
onclose(err);
} }
// there should be nothing left in buf if we gracefully exited
if (buf.length !== 0) throw new WireError(`unexpected end of stream`);
} }
function handle_msg(msg: Uint8Array) { function handle_msg(msg: Uint8Array) {
@ -674,77 +894,31 @@ function wire_impl(
wire.emit("parameter", name, value, prev); wire.emit("parameter", name, value, prev);
return true; return true;
} }
}
return false; default:
return false;
}
} }
function write<T>(type: Encoder<T>, value: T) { function write<T>(type: Encoder<T>, value: T) {
write_msg(ser_encode(type, value)); if (write_queue !== null) write_queue(ser_encode(type, value));
}
function write_msg(buf: Uint8Array) {
if (write_push !== null) write_push(buf);
else throw new WireError(`connection closed`); else throw new WireError(`connection closed`);
} }
async function write_socket(socket: Deno.Conn, pop: Receiver<Uint8Array>) { async function write_socket(socket: Deno.Conn, recv: Receiver<Uint8Array>) {
let err; for (let buf; (buf = await recv()) !== null; ) {
try { const msgs = [buf]; // proactively dequeue more queued msgs synchronously, if any
for (let buf; (buf = await pop()) !== null; ) { for (let i = 1, buf; (buf = recv.try()) !== null; ) msgs[i++] = buf;
const bufs = [buf]; // proactively dequeue more queued msgs synchronously, if any if (verbose) {
for (let i = 1, buf; (buf = pop.try()) !== null; ) bufs[i++] = buf; for (const msg of msgs)
if (bufs.length !== 1) buf = buf_concat(bufs); // write queued msgs concatenated, reduce write syscalls log("trace", {}, `SEND -> ${msg_type(msg)} ${to_hex(msg)}`);
for (let i = 0, n = buf.length; i < n; )
i += await socket.write(buf.subarray(i));
} }
} catch (e) { if (msgs.length !== 1) buf = buf_concat(msgs); // write queued msgs concatenated, reduce write syscalls
throw (err = e); for (let i = 0, n = buf.length; i < n; )
} finally { i += await socket.write(buf.subarray(i));
onclose(err);
} }
} }
async function connect() {
using _rlock = await rlock();
using _wlock = await wlock();
if (connected) return;
try {
const s = (socket = await socket_connect(host, port));
read_pop = channel.receiver((push) => read_socket(s, push));
write_push = channel.sender((pop) => write_socket(s, pop));
await handle_auth(); // run auth with rw lock
(connected = true), (should_reconnect = reconnect_delay !== 0);
} catch (e) {
throw (close(e), e);
}
}
function reconnect() {
connect().catch((err) => {
log("warn", err as Error, `reconnect failed`);
setTimeout(reconnect, reconnect_delay);
});
}
function close(reason?: unknown) {
(should_reconnect = false), onclose(reason);
}
function onclose(reason?: unknown) {
if (!connected) return;
else connected = false;
socket?.close(), (socket = null);
read_pop?.close(reason), (read_pop = null);
write_push?.close(reason), (write_push = null);
for (const name of Object.keys(params))
delete (params as Record<string, string>)[name];
st_cache.clear(), (st_ids = 0);
(tx_status = "I"), (tx_stack.length = 0);
should_reconnect &&= (setTimeout(reconnect, reconnect_delay), false);
wire.emit("close", reason);
}
// https://www.postgresql.org/docs/current/protocol-flow.html#PROTOCOL-FLOW-PIPELINING // https://www.postgresql.org/docs/current/protocol-flow.html#PROTOCOL-FLOW-PIPELINING
const rlock = semaphore(); const rlock = semaphore();
const wlock = semaphore(); const wlock = semaphore();
@ -766,7 +940,7 @@ function wire_impl(
} finally { } finally {
try { try {
let msg; let msg;
while (msg_type((msg = await read_msg())) !== ReadyForQuery.type); while (msg_type((msg = await read_any())) !== ReadyForQuery.type);
({ tx_status } = ser_decode(ReadyForQuery, msg)); ({ tx_status } = ser_decode(ReadyForQuery, msg));
} catch { } catch {
// ignored // ignored
@ -805,7 +979,7 @@ function wire_impl(
}); });
auth: for (;;) { auth: for (;;) {
const msg = msg_check_err(await read_msg()); const msg = msg_check_err(await read_any());
switch (msg_type(msg)) { switch (msg_type(msg)) {
case NegotiateProtocolVersion.type: { case NegotiateProtocolVersion.type: {
const { bad_options } = ser_decode(NegotiateProtocolVersion, msg); const { bad_options } = ser_decode(NegotiateProtocolVersion, msg);
@ -849,7 +1023,7 @@ function wire_impl(
// wait for ready // wait for ready
ready: for (;;) { ready: for (;;) {
const msg = msg_check_err(await read_msg()); const msg = msg_check_err(await read_any());
switch (msg_type(msg)) { switch (msg_type(msg)) {
case BackendKeyData.type: case BackendKeyData.type:
continue; // ignored continue; // ignored
@ -922,9 +1096,7 @@ function wire_impl(
const cbind_data = ``; const cbind_data = ``;
const cbind_input = `${gs2_header}${cbind_data}`; const cbind_input = `${gs2_header}${cbind_data}`;
const channel_binding = `c=${to_base64(cbind_input)}`; const channel_binding = `c=${to_base64(cbind_input)}`;
const initial_nonce = `r=${to_base64( const initial_nonce = `r=${randstr(20)}`;
crypto.getRandomValues(new Uint8Array(18))
)}`;
const client_first_message_bare = `${username},${initial_nonce}`; const client_first_message_bare = `${username},${initial_nonce}`;
const client_first_message = `${gs2_header}${client_first_message_bare}`; const client_first_message = `${gs2_header}${client_first_message_bare}`;
write(SASLInitialResponse, { mechanism, data: client_first_message }); write(SASLInitialResponse, { mechanism, data: client_first_message });
@ -987,7 +1159,7 @@ function wire_impl(
await read(ParseComplete); await read(ParseComplete);
const ser_params = make_param_ser(await read(ParameterDescription)); const ser_params = make_param_ser(await read(ParameterDescription));
const msg = msg_check_err(await read_msg()); const msg = msg_check_err(await read_any());
const Row = const Row =
msg_type(msg) === NoData.type msg_type(msg) === NoData.type
? EmptyRow ? EmptyRow
@ -1017,7 +1189,7 @@ function wire_impl(
return jit.compiled<ParameterSerializer>`function ser_params(xs) { return jit.compiled<ParameterSerializer>`function ser_params(xs) {
return [ return [
${jit.map(", ", param_types, (type_oid, i) => { ${jit.map(", ", param_types, (type_oid, i) => {
const type = types[type_oid] ?? text; const type = types[type_oid] ?? types[0] ?? text;
return jit`${type}.output(xs[${i}])`; return jit`${type}.output(xs[${i}])`;
})} })}
]; ];
@ -1034,7 +1206,7 @@ function wire_impl(
function make_row_ctor({ columns }: RowDescription) { function make_row_ctor({ columns }: RowDescription) {
const Row = jit.compiled<RowConstructor>`function Row(xs) { const Row = jit.compiled<RowConstructor>`function Row(xs) {
${jit.map(" ", columns, ({ name, type_oid }, i) => { ${jit.map(" ", columns, ({ name, type_oid }, i) => {
const type = types[type_oid] ?? text; const type = types[type_oid] ?? types[0] ?? text;
return jit`this[${name}] = xs[${i}] === null ? null : ${type}.input(${from_utf8}(xs[${i}]));`; return jit`this[${name}] = xs[${i}] === null ? null : ${type}.input(${from_utf8}(xs[${i}]));`;
})} })}
}`; }`;
@ -1068,7 +1240,7 @@ function wire_impl(
stdout: WritableStream<Uint8Array> | null stdout: WritableStream<Uint8Array> | null
) { ) {
for (let rows = [], i = 0; ; ) { for (let rows = [], i = 0; ; ) {
const msg = msg_check_err(await read_msg()); const msg = msg_check_err(await read_any());
switch (msg_type(msg)) { switch (msg_type(msg)) {
default: default:
case DataRow.type: case DataRow.type:
@ -1094,6 +1266,7 @@ function wire_impl(
continue; continue;
case CopyOutResponse.type: case CopyOutResponse.type:
case CopyBothResponse.type:
await read_copy_out(stdout), (stdout = null); await read_copy_out(stdout), (stdout = null);
continue; continue;
} }
@ -1101,40 +1274,47 @@ function wire_impl(
} }
async function read_copy_out(stream: WritableStream<Uint8Array> | null) { async function read_copy_out(stream: WritableStream<Uint8Array> | null) {
if (stream !== null) { const writer = stream?.getWriter();
const writer = stream.getWriter(); try {
try { copy: for (;;) {
for (let msg; msg_type((msg = await read_msg())) !== CopyDone.type; ) { const msg = msg_check_err(await read_any());
const { data } = ser_decode(CopyData, msg_check_err(msg)); switch (msg_type(msg)) {
await writer.write(to_utf8(data)); default:
case CopyData.type: {
const { data } = ser_decode(CopyData, msg);
console.log(`COPY OUT`, to_hex(data));
await writer?.write(to_utf8(data));
continue;
}
case CopyDone.type:
case CommandComplete.type: // walsender sends 'C' to end of CopyBothResponse
await writer?.close();
break copy;
} }
await writer.close();
} catch (e) {
await writer.abort(e);
throw e;
} finally {
writer.releaseLock();
} }
} else { } catch (e) {
while (msg_type(msg_check_err(await read_msg())) !== CopyDone.type); await writer?.abort(e);
throw e;
} finally {
writer?.releaseLock();
} }
} }
async function write_copy_in(stream: ReadableStream<Uint8Array> | null) { async function write_copy_in(stream: ReadableStream<Uint8Array> | null) {
if (stream !== null) { const reader = stream?.getReader();
const reader = stream.getReader(); try {
try { if (reader) {
for (let next; !(next = await reader.read()).done; ) for (let next; !(next = await reader.read()).done; )
write(CopyData, { data: next.value }); write(CopyData, { data: next.value });
write(CopyDone, {});
} catch (e) {
write(CopyFail, { cause: String(e) });
throw e;
} finally {
reader.releaseLock();
} }
} else {
write(CopyDone, {}); write(CopyDone, {});
} catch (e) {
write(CopyFail, { cause: String(e) });
reader?.cancel(e);
throw e;
} finally {
reader?.releaseLock();
} }
} }
@ -1147,27 +1327,34 @@ function wire_impl(
() => { () => {
log("debug", { query }, `executing simple query`); log("debug", { query }, `executing simple query`);
write(QueryMessage, { query }); write(QueryMessage, { query });
write_copy_in(stdin); return write_copy_in(stdin);
}, },
async () => { async () => {
for (let chunks = [], err; ; ) { for (let chunks = [], err; ; ) {
const msg = await read_msg(); const msg = await read_any();
switch (msg_type(msg)) { switch (msg_type(msg)) {
default: default:
case ReadyForQuery.type: case ReadyForQuery.type:
ser_decode(ReadyForQuery, msg);
if (err) throw err; if (err) throw err;
else return chunks; else return chunks;
case RowDescription.type: { case RowDescription.type: {
const Row = make_row_ctor(ser_decode(RowDescription, msg)); const Row = make_row_ctor(ser_decode(RowDescription, msg));
const { rows } = await read_rows(Row, stdout); const { rows } = await read_rows(Row, stdout);
chunks.push(rows); chunks.push(rows), (stdout = null);
stdout = null;
continue; continue;
} }
case EmptyQueryResponse.type: case EmptyQueryResponse.type:
case CommandComplete.type: case CommandComplete.type:
case CopyInResponse.type:
case CopyDone.type:
continue;
case CopyOutResponse.type:
case CopyBothResponse.type:
await read_copy_out(stdout), (stdout = null);
continue; continue;
case ErrorResponse.type: { case ErrorResponse.type: {
@ -1307,9 +1494,10 @@ function wire_impl(
const tx_begin = query(sql`begin`); const tx_begin = query(sql`begin`);
const tx_commit = query(sql`commit`); const tx_commit = query(sql`commit`);
const tx_rollback = query(sql`rollback`); const tx_rollback = query(sql`rollback`);
const sp_savepoint = query(sql`savepoint __tx`); const sp_name = sql.ident`__pglue_tx`;
const sp_release = query(sql`release __tx`); const sp_savepoint = query(sql`savepoint ${sp_name}`);
const sp_rollback_to = query(sql`rollback to __tx`); const sp_release = query(sql`release ${sp_name}`);
const sp_rollback_to = query(sql`rollback to ${sp_name}`);
async function begin() { async function begin() {
const tx = new Transaction( const tx = new Transaction(
@ -1368,7 +1556,9 @@ function wire_impl(
} }
async function notify(channel: string, payload: string) { async function notify(channel: string, payload: string) {
return await query(sql`select pg_notify(${channel}, ${payload})`).execute(); return await query(
sql`select pg_notify(${channel}::text, ${payload}::text)`
).execute();
} }
const Channel = class extends TypedEmitter<ChannelEvents> implements Channel { const Channel = class extends TypedEmitter<ChannelEvents> implements Channel {
@ -1408,10 +1598,17 @@ function wire_impl(
return { params, connect, query, begin, listen, notify, close }; return { params, connect, query, begin, listen, notify, close };
} }
export interface PoolOptions extends WireOptions { export type PoolOptions = v.Infer<typeof PoolOptions>;
max_connections: number; export const PoolOptions = WireOptions.extend({
idle_timeout: number; max_connections: v
} .number()
.optional(() => 10)
.assert(Number.isSafeInteger, `invalid number`),
idle_timeout: v
.number()
.optional(() => 30)
.assert(Number.isSafeInteger, `invalid number`),
});
export type PoolEvents = { export type PoolEvents = {
log(level: LogLevel, ctx: object, msg: string): void; log(level: LogLevel, ctx: object, msg: string): void;
@ -1564,12 +1761,11 @@ function pool_impl(
}; };
async function connect() { async function connect() {
const wire = new PoolWire(options); const wire = new PoolWire({ ...options, reconnect_delay: null });
await wire.connect(), all.add(wire); const { connection_id } = wire
const { connection_id } = wire;
return wire
.on("log", (l, c, s) => pool.emit("log", l, { ...c, connection_id }, s)) .on("log", (l, c, s) => pool.emit("log", l, { ...c, connection_id }, s))
.on("close", () => forget(wire)); .on("close", () => forget(wire));
return await wire.connect(), all.add(wire), wire;
} }
async function acquire() { async function acquire() {