diff --git a/mod.ts b/mod.ts index a6ebf31..c21392f 100644 --- a/mod.ts +++ b/mod.ts @@ -9,7 +9,7 @@ import { unknown, } from "./valita.ts"; import { Pool, wire_connect, type LogLevel } from "./wire.ts"; -import { type FromSql, type ToSql, from_sql, to_sql } from "./sql.ts"; +import { sql_types, type SqlType, type SqlTypeMap } from "./query.ts"; export { WireError, @@ -21,13 +21,11 @@ export { } from "./wire.ts"; export { type SqlFragment, - type FromSql, - type ToSql, - SqlValue, + type SqlType, + type SqlTypeMap, + SqlTypeError, sql, is_sql, -} from "./sql.ts"; -export { Query, type Row, type CommandResult, @@ -45,8 +43,7 @@ export type Options = { max_connections?: number; idle_timeout?: number; runtime_params?: Record; - from_sql?: FromSql; - to_sql?: ToSql; + types?: SqlTypeMap; }; type ParsedOptions = Infer; @@ -64,15 +61,12 @@ const ParsedOptions = object({ runtime_params: record(string()).optional(() => ({})), max_connections: number().optional(() => 10), idle_timeout: number().optional(() => 20), - from_sql: unknown() - .assert((s): s is FromSql => typeof s === "function") - .optional(() => from_sql), - to_sql: unknown() - .assert((s): s is ToSql => typeof s === "function") - .optional(() => to_sql), + types: record(unknown()) + .optional(() => ({})) + .map((types): SqlTypeMap => ({ ...sql_types, ...types })), }); -function parse_opts(s: string, options: Options) { +function parse_opts(s: string, opts: Options) { const { host, port, @@ -87,13 +81,13 @@ function parse_opts(s: string, options: Options) { Deno.env.toObject(); return ParsedOptions.parse({ - ...options, - host: options.host ?? host ?? PGHOST ?? undefined, - port: options.port ?? port ?? PGPORT ?? undefined, - user: options.user ?? user ?? PGUSER ?? USER ?? undefined, - password: options.password ?? password ?? PGPASSWORD ?? undefined, - database: options.database ?? database ?? PGDATABASE ?? undefined, - runtime_params: { ...runtime_params, ...options.runtime_params }, + ...opts, + host: opts.host ?? host ?? PGHOST ?? undefined, + port: opts.port ?? port ?? PGPORT ?? undefined, + user: opts.user ?? user ?? PGUSER ?? USER ?? undefined, + password: opts.password ?? password ?? PGPASSWORD ?? undefined, + database: opts.database ?? database ?? PGDATABASE ?? undefined, + runtime_params: { ...runtime_params, ...opts.runtime_params }, }); } diff --git a/mod_test.ts b/mod_test.ts index fbecb5f..f55213e 100644 --- a/mod_test.ts +++ b/mod_test.ts @@ -6,8 +6,18 @@ await using pool = postgres(`postgres://test:test@localhost:5432/test`, { pool.on("log", (level, ctx, msg) => console.info(`${level}: ${msg}`, ctx)); -await pool.begin(async (pg) => { - await pg.begin(async (pg) => { - console.log(await pg.query`select * from pg_user`); - }); +await pool.begin(async (pg, tx) => { + await pg.query` + create table my_test ( + key integer primary key generated always as identity, + data text not null + ) + `; + + await pg.query` + insert into my_test (data) values (${[1, 2, 3]}::bytea) + `; + + console.log(await pg.query`select * from my_test`); + await tx.rollback(); }); diff --git a/query.ts b/query.ts index 9596447..1c7956b 100644 --- a/query.ts +++ b/query.ts @@ -1,62 +1,302 @@ import type { ObjectType } from "./valita.ts"; -import { from_utf8, jit, to_utf8 } from "./lstd.ts"; -import { type FromSql, SqlValue } from "./sql.ts"; +import { from_hex, to_hex, to_utf8 } from "./lstd.ts"; -export interface Row extends Iterable { - [column: string]: unknown; +export const sql_format = Symbol.for(`re.lua.pglue.sql_format`); + +export interface SqlFragment { + [sql_format](f: SqlFormatter): void; } -export interface RowConstructor { - new (columns: (Uint8Array | string | null)[]): Row; +export interface SqlFormatter { + query: string; + params: unknown[]; } -export interface RowDescription extends ReadonlyArray {} - -export interface ColumnDescription { - readonly name: string; - readonly table_oid: number; - readonly table_column: number; - readonly type_oid: number; - readonly type_size: number; - readonly type_modifier: number; +export function is_sql(x: unknown): x is SqlFragment { + return typeof x === "object" && x !== null && sql_format in x; } -export function row_ctor(from_sql: FromSql, columns: RowDescription) { - function parse(s: Uint8Array | string | null | undefined) { - if (!s && s !== "") return null; - else return from_utf8(s); +export function sql( + { raw: s }: TemplateStringsArray, + ...xs: unknown[] +): SqlFragment { + return { + [sql_format](fmt) { + for (let i = 0, n = s.length; i < n; i++) { + if (i !== 0) fmt_format(fmt, xs[i - 1]); + fmt.query += s[i]; + } + }, + }; +} + +export function fmt_write(fmt: SqlFormatter, s: string | SqlFragment) { + is_sql(s) ? s[sql_format](fmt) : (fmt.query += s); +} + +export function fmt_format(fmt: SqlFormatter, x: unknown) { + is_sql(x) ? x[sql_format](fmt) : fmt_enclose(fmt, x); +} + +export function fmt_enclose(fmt: SqlFormatter, x: unknown) { + const { params } = fmt; + params.push(x), (fmt.query += `$` + params.length); +} + +sql.format = format; +sql.raw = raw; +sql.ident = ident; +sql.fragment = fragment; +sql.map = map; +sql.array = array; +sql.row = row; + +export function format(sql: SqlFragment) { + const fmt: SqlFormatter = { query: "", params: [] }; + return sql[sql_format](fmt), fmt; +} + +export function raw(s: string): SqlFragment; +export function raw(s: TemplateStringsArray, ...xs: unknown[]): SqlFragment; +export function raw( + s: TemplateStringsArray | string, + ...xs: unknown[] +): SqlFragment { + s = typeof s === "string" ? s : String.raw(s, ...xs); + return { + [sql_format](fmt) { + fmt.query += s; + }, + }; +} + +export function ident(s: string): SqlFragment; +export function ident(s: TemplateStringsArray, ...xs: unknown[]): SqlFragment; +export function ident(s: TemplateStringsArray | string, ...xs: unknown[]) { + s = typeof s === "string" ? s : String.raw(s, ...xs); + return raw`"${s.replaceAll('"', '""')}"`; +} + +export function fragment( + sep: string | SqlFragment, + ...xs: unknown[] +): SqlFragment { + return { + [sql_format](fmt) { + for (let i = 0, n = xs.length; i < n; i++) { + if (i !== 0) fmt_write(fmt, sep); + fmt_format(fmt, xs[i]); + } + }, + }; +} + +export function map( + sep: string | SqlFragment, + xs: Iterable, + f: (value: T, index: number) => unknown +) { + return fragment(sep, ...Iterator.from(xs).map(f)); +} + +export function array(...xs: unknown[]) { + return sql`array[${fragment(", ", ...xs)}]`; +} + +export function row(...xs: unknown[]) { + return sql`row(${fragment(", ", ...xs)})`; +} + +export interface SqlType { + input(value: string): unknown; + output(value: unknown): string | null; +} + +export interface SqlTypeMap { + readonly [oid: number]: SqlType | undefined; +} + +export class SqlTypeError extends TypeError { + override get name() { + return this.constructor.name; } - - const Row = jit.compiled`function Row(xs) { - ${jit.map(" ", columns, ({ name, type_oid }, i) => { - return jit`this[${name}] = ${from_sql}(new ${SqlValue}(${type_oid}, ${parse}(xs[${i}])));`; - })} - }`; - - Row.prototype = Object.create(null, { - [Symbol.toStringTag]: { - configurable: true, - value: `Row`, - }, - [Symbol.toPrimitive]: { - configurable: true, - value: function format() { - return [...this].join("\t"); - }, - }, - [Symbol.iterator]: { - configurable: true, - value: jit.compiled`function* iter() { - ${jit.map(" ", columns, ({ name }) => { - return jit`yield this[${name}];`; - })} - }`, - }, - }); - - return Row; } +export const bool: SqlType = { + input(s) { + return s !== "f"; + }, + output(x) { + return typeof x === "undefined" || x === null ? null : x ? "t" : "f"; + }, +}; + +export const text: SqlType = { + input(s) { + return s; + }, + output(x) { + if (typeof x === "undefined" || x === null) return null; + else if (typeof x === "string") return x; + else return String(x); + }, +}; + +export const int2: SqlType = { + input(s) { + const n = Number(s); + if (Number.isInteger(n) && -32768 <= n && n <= 32767) return n; + else throw new SqlTypeError(`invalid int2 input '${s}'`); + }, + output(x) { + let n: number; + if (typeof x === "undefined" || x === null) return null; + else if (typeof x === "number") n = x; + else n = Number(x); + if (Number.isInteger(n) && -32768 <= n && n <= 32767) return n.toString(); + else throw new SqlTypeError(`invalid int2 output '${x}'`); + }, +}; + +export const int4: SqlType = { + input(s) { + const n = Number(s); + if (Number.isInteger(n) && -2147483648 <= n && n <= 2147483647) return n; + else throw new SqlTypeError(`invalid int4 input '${s}'`); + }, + output(x) { + let n: number; + if (typeof x === "undefined" || x === null) return null; + else if (typeof x === "number") n = x; + else n = Number(x); + if (Number.isInteger(n) && -2147483648 <= n && n <= 2147483647) + return n.toString(); + else throw new SqlTypeError(`invalid int4 output '${x}'`); + }, +}; + +export const int8: SqlType = { + input(s) { + const n = BigInt(s); + if (-9007199254740991n <= n && n <= 9007199254740991n) return Number(n); + else if (-9223372036854775808n <= n && n <= 9223372036854775807n) return n; + else throw new SqlTypeError(`invalid int8 input '${s}'`); + }, + output(x) { + let n: number | bigint; + if (typeof x === "undefined" || x === null) return null; + else if (typeof x === "number" || typeof x === "bigint") n = x; + else if (typeof x === "string") n = BigInt(x); + else n = Number(x); + if (Number.isInteger(n)) { + if (-9007199254740991 <= n && n <= 9007199254740991) return n.toString(); + else throw new SqlTypeError(`unsafe int8 output '${x}'`); + } else if (typeof n === "bigint") { + if (-9223372036854775808n <= n && n <= 9223372036854775807n) + return n.toString(); + } + throw new SqlTypeError(`invalid int8 output '${x}'`); + }, +}; + +export const float4: SqlType = { + input(s) { + return Math.fround(Number(s)); + }, + output(x) { + let n: number; + if (typeof x === "undefined" || x === null) return null; + else if (typeof x === "number") n = x; + else { + n = Number(x); + if (Number.isNaN(n)) + throw new SqlTypeError(`invalid float4 output '${x}'`); + } + return Math.fround(n).toString(); + }, +}; + +export const float8: SqlType = { + input(s) { + return Number(s); + }, + output(x) { + let n: number; + if (typeof x === "undefined" || x === null) return null; + else if (typeof x === "number") n = x; + else { + n = Number(x); + if (Number.isNaN(n)) + throw new SqlTypeError(`invalid float8 output '${x}'`); + } + return n.toString(); + }, +}; + +export const timestamptz: SqlType = { + input(s) { + const t = Date.parse(s); + if (!Number.isNaN(t)) return new Date(t); + else throw new SqlTypeError(`invalid timestamptz input '${s}'`); + }, + output(x) { + let t: Date; + if (typeof x === "undefined" || x === null) return null; + else if (x instanceof Date) t = x; + else if (typeof x === "number" || typeof x === "bigint") + t = new Date(Number(x) * 1000); // unix epoch seconds + else t = new Date(String(x)); + if (Number.isFinite(t.getTime())) return t.toISOString(); + else throw new SqlTypeError(`invalid timestamptz output '${x}'`); + }, +}; + +export const bytea: SqlType = { + input(s) { + if (s.startsWith(`\\x`)) return from_hex(s.slice(2)); + else throw new SqlTypeError(`invalid bytea input '${s}'`); + }, + output(x) { + let buf: Uint8Array; + if (typeof x === "undefined" || x === null) return null; + else if (typeof x === "string") buf = to_utf8(x); + else if (x instanceof Uint8Array) buf = x; + else if (x instanceof ArrayBuffer || x instanceof SharedArrayBuffer) + buf = new Uint8Array(x); + else if (Array.isArray(x) || x instanceof Array) buf = Uint8Array.from(x); + else throw new SqlTypeError(`invalid bytea output '${x}'`); + return `\\x` + to_hex(buf); + }, +}; + +export const json: SqlType = { + input(s) { + return JSON.parse(s); + }, + output(x) { + return typeof x === "undefined" ? null : JSON.stringify(x); + }, +}; + +export const sql_types: SqlTypeMap = { + 16: bool, // bool + 25: text, // text + 21: int2, // int2 + 23: int4, // int4 + 20: int8, // int8 + 26: int8, // oid + 700: float4, // float4 + 701: float8, // float8 + 1082: timestamptz, // date + 1114: timestamptz, // timestamp + 1184: timestamptz, // timestamptz + 17: bytea, // bytea + 114: json, // json + 3802: json, // jsonb +}; + +sql.types = sql_types; + type ReadonlyTuple = readonly [...T]; export interface CommandResult { @@ -74,6 +314,10 @@ export interface Results extends CommandResult, ReadonlyArray { export interface ResultStream extends AsyncIterable {} +export interface Row extends Iterable { + [column: string]: unknown; +} + export interface QueryOptions { readonly chunk_size: number; readonly stdin: ReadableStream | null; diff --git a/sql.ts b/sql.ts deleted file mode 100644 index 9d7bdc9..0000000 --- a/sql.ts +++ /dev/null @@ -1,386 +0,0 @@ -import { from_hex, to_hex } from "./lstd.ts"; - -export const sql_format = Symbol.for(`re.lua.pglue.sql_format`); - -export interface SqlFragment { - [sql_format](f: SqlFormatter): void; -} - -export function is_sql(x: unknown): x is SqlFragment { - return typeof x === "object" && x !== null && sql_format in x; -} - -export interface FromSql { - (x: SqlValue): unknown; -} - -export interface ToSql { - (x: unknown): SqlFragment; -} - -export const from_sql = function from_sql(x) { - const { type, value } = x; - if (value === null) return null; - - switch (type) { - case 16: // boolean - return boolean.parse(value); - case 25: // text - return text.parse(value); - case 21: // int2 - return int2.parse(value); - case 23: // int4 - return int4.parse(value); - case 20: // int8 - case 26: // oid - return int8.parse(value); - case 700: // float4 - return float4.parse(value); - case 701: // float8 - return float8.parse(value); - case 1082: // date - case 1114: // timestamp - case 1184: // timestamptz - return timestamptz.parse(value); - case 17: // bytea - return bytea.parse(value); - case 114: // json - case 3802: // jsonb - return json.parse(value); - default: - return x; - } -} as FromSql; - -export const to_sql = function to_sql(x) { - switch (typeof x) { - case "undefined": - return nil(); - case "boolean": - return boolean(x); - case "number": - return float8(x); - case "bigint": - return int8(x); - case "string": - case "symbol": - case "function": - return text(x); - } - - switch (true) { - case x === null: - return nil(); - - case is_sql(x): - return x; - - case Array.isArray(x): - return array(...(x instanceof Array ? x : Array.from(x))); - - case x instanceof Date: - return timestamptz(x); - - case x instanceof Uint8Array: - case x instanceof ArrayBuffer: - case x instanceof SharedArrayBuffer: - return bytea(x); - } - - throw new TypeError(`cannot convert input '${x}' to sql`); -} as ToSql; - -export class SqlValue implements SqlFragment { - constructor( - readonly type: number, - readonly value: string | null - ) {} - - [sql_format](f: SqlFormatter) { - f.write_param(this.type, this.value); - } - - [Symbol.toStringTag]() { - return `${this.constructor.name}<${this.type}>`; - } - - [Symbol.toPrimitive]() { - return this.value; - } - - toString() { - return String(this.value); - } - - toJSON() { - return this.value; - } -} - -export function value(type: number, x: unknown) { - const s = x === null || typeof x === "undefined" ? null : String(x); - return new SqlValue(type, s); -} - -export class SqlFormatter { - readonly #ser; - #query = ""; - #params = { - types: [] as number[], - values: [] as (string | null)[], - }; - - get query() { - return this.#query.trim(); - } - - get params() { - return this.#params; - } - - constructor(serializer: ToSql) { - this.#ser = serializer; - } - - write(s: string | SqlFragment) { - if (is_sql(s)) s[sql_format](this); - else this.#query += s; - } - - write_param(type: number, s: string | null) { - const { types, values } = this.#params; - types.push(type), values.push(s), this.write(`$` + values.length); - } - - format(x: unknown) { - this.write(is_sql(x) ? x : this.#ser(x)); - } -} - -export function format(sql: SqlFragment, serializer = to_sql) { - const fmt = new SqlFormatter(serializer); - return fmt.write(sql), fmt; -} - -export function sql( - { raw: s }: TemplateStringsArray, - ...xs: unknown[] -): SqlFragment { - return { - [sql_format](f) { - for (let i = 0, n = s.length; i < n; i++) { - if (i !== 0) f.format(xs[i - 1]); - f.write(s[i]); - } - }, - }; -} - -sql.value = value; -sql.format = format; -sql.raw = raw; -sql.ident = ident; -sql.fragment = fragment; -sql.map = map; -sql.array = array; -sql.row = row; -sql.null = nil; -sql.boolean = boolean; -sql.text = text; -sql.int2 = int2; -sql.int4 = int4; -sql.int8 = int8; -sql.float4 = float4; -sql.float8 = float8; -sql.timestamptz = timestamptz; -sql.bytea = bytea; -sql.json = json; - -export function raw(s: TemplateStringsArray, ...xs: unknown[]): SqlFragment; -export function raw(s: string): SqlFragment; -export function raw( - s: TemplateStringsArray | string, - ...xs: unknown[] -): SqlFragment { - s = typeof s === "string" ? s : String.raw(s, ...xs); - return { - [sql_format](f) { - f.write(s); - }, - }; -} - -export function ident(s: TemplateStringsArray, ...xs: unknown[]): SqlFragment; -export function ident(s: string): SqlFragment; -export function ident(s: TemplateStringsArray | string, ...xs: unknown[]) { - s = typeof s === "string" ? s : String.raw(s, ...xs); - return raw`"${s.replaceAll('"', '""')}"`; -} - -export function fragment( - sep: string | SqlFragment, - ...xs: unknown[] -): SqlFragment { - return { - [sql_format](f) { - for (let i = 0, n = xs.length; i < n; i++) { - if (i !== 0) f.write(sep); - f.format(xs[i]); - } - }, - }; -} - -export function map( - sep: string | SqlFragment, - xs: Iterable, - f: (value: T, index: number) => unknown -): SqlFragment { - return fragment(sep, ...Iterator.from(xs).map(f)); -} - -export function array(...xs: unknown[]): SqlFragment { - return sql`array[${fragment(", ", ...xs)}]`; -} - -export function row(...xs: unknown[]): SqlFragment { - return sql`row(${fragment(", ", ...xs)})`; -} - -boolean.oid = 16 as const; -text.oid = 25 as const; -int2.oid = 21 as const; -int4.oid = 23 as const; -int8.oid = 20 as const; -float4.oid = 700 as const; -float8.oid = 701 as const; -timestamptz.oid = 1184 as const; -bytea.oid = 17 as const; -json.oid = 114 as const; - -export function nil() { - return value(0, null); -} - -Object.defineProperty(nil, "name", { configurable: true, value: "null" }); - -export function boolean(x: unknown) { - return value( - boolean.oid, - x === null || typeof x === "undefined" ? null : x ? "t" : "f" - ); -} - -boolean.parse = function parse_boolean(s: string) { - return s === "t"; -}; - -export function text(x: unknown) { - return value(text.oid, x); -} - -text.parse = function parse_text(s: string) { - return s; -}; - -const i2_min = -32768; -const i2_max = 32767; - -export function int2(x: unknown) { - return value(int2.oid, x); -} - -int2.parse = function parse_int2(s: string) { - const n = Number(s); - if (Number.isInteger(n) && i2_min <= n && n <= i2_max) return n; - else throw new TypeError(`input '${s}' is not a valid int2 value`); -}; - -const i4_min = -2147483648; -const i4_max = 2147483647; - -export function int4(x: unknown) { - return value(int4.oid, x); -} - -int4.parse = function parse_int4(s: string) { - const n = Number(s); - if (Number.isInteger(n) && i4_min <= n && n <= i4_max) return n; - else throw new TypeError(`input '${s}' is not a valid int4 value`); -}; - -const i8_min = -9223372036854775808n; -const i8_max = 9223372036854775807n; - -export function int8(x: unknown) { - return value(int8.oid, x); -} - -function to_int8(n: number | bigint) { - if (typeof n === "bigint") return i8_min <= n && n <= i8_max ? n : null; - else return Number.isSafeInteger(n) ? BigInt(n) : null; -} - -int8.parse = function parse_int8(s: string) { - const n = to_int8(BigInt(s)); - if (n !== null) return to_float8(n) ?? n; - else throw new TypeError(`input '${s}' is not a valid int8 value`); -}; - -const f8_min = -9007199254740991n; -const f8_max = 9007199254740991n; - -export function float4(x: unknown) { - return value(float4.oid, x); -} - -export function float8(x: unknown) { - return value(float8.oid, x); -} - -function to_float8(n: number | bigint) { - if (typeof n === "bigint") - return f8_min <= n && n <= f8_max ? Number(n) : null; - else return Number.isNaN(n) ? null : n; -} - -float4.parse = float8.parse = function parse_float8(s: string) { - const n = to_float8(Number(s)); - if (n !== null) return n; - else throw new TypeError(`input '${s}' is not a valid float8 value`); -}; - -export function timestamptz(x: unknown) { - if (x instanceof Date) x = x.toISOString(); - else if (typeof x === "number" || typeof x === "bigint") - x = new Date(Number(x) * 1000).toISOString(); // unix epoch - return value(timestamptz.oid, x); -} - -timestamptz.parse = function parse_timestamptz(s: string) { - const t = Date.parse(s); - if (!Number.isNaN(t)) return new Date(t); - else throw new TypeError(`input '${s}' is not a valid timestamptz value`); -}; - -export function bytea(x: Uint8Array | ArrayBufferLike | Iterable) { - let buf; - if (x instanceof Uint8Array) buf = x; - else if (x instanceof ArrayBuffer || x instanceof SharedArrayBuffer) - buf = new Uint8Array(x); - else buf = Uint8Array.from(x); - return value(bytea.oid, `\\x` + to_hex(buf)); -} - -bytea.parse = function parse_bytea(s: string) { - if (s.startsWith(`\\x`)) return from_hex(s.slice(2)); - else throw new TypeError(`input is not a valid bytea value`); -}; - -export function json(x: unknown) { - return value(json.oid, JSON.stringify(x) ?? null); -} - -json.parse = function parse_json(s: string): unknown { - return JSON.parse(s); -}; diff --git a/wire.ts b/wire.ts index ccd8003..ea24504 100644 --- a/wire.ts +++ b/wire.ts @@ -1,10 +1,12 @@ import { + type BinaryLike, buf_concat_fast, buf_eq, buf_xor, channel, from_base64, from_utf8, + jit, semaphore, semaphore_fast, to_base64, @@ -28,21 +30,18 @@ import { i16, i32, i8, + type EncoderType, } from "./ser.ts"; -import { - is_sql, - sql, - type FromSql, - type SqlFragment, - type ToSql, -} from "./sql.ts"; import { type CommandResult, + is_sql, Query, type ResultStream, type Row, - row_ctor, - type RowConstructor, + sql, + type SqlFragment, + type SqlTypeMap, + text, } from "./query.ts"; import { join } from "jsr:@std/path@^1.0.8"; @@ -446,8 +445,7 @@ export interface WireOptions { readonly password: string; readonly database: string | null; readonly runtime_params: Record; - readonly from_sql: FromSql; - readonly to_sql: ToSql; + readonly types: SqlTypeMap; } export type WireEvents = { @@ -542,7 +540,8 @@ export class Wire extends TypedEmitter implements Disposable { if (typeof f !== "undefined") { await using tx = await this.#begin(); const value = await f(this, tx); - return await tx.commit(), value; + if (tx.open) await tx.commit(); + return value; } else { return this.#begin(); } @@ -583,7 +582,7 @@ export class Wire extends TypedEmitter implements Disposable { function wire_impl( wire: Wire, socket: Deno.Conn, - { user, database, password, runtime_params, from_sql, to_sql }: WireOptions + { user, database, password, runtime_params, types }: WireOptions ) { const params: Parameters = Object.create(null); @@ -878,45 +877,42 @@ function wire_impl( const st_cache = new Map(); let st_ids = 0; - function st_get(query: string, param_types: number[]) { - const key = JSON.stringify({ q: query, p: param_types }); - let st = st_cache.get(key); - if (!st) st_cache.set(key, (st = new Statement(query, param_types))); - return st; - } - class Statement { readonly name = `__st${st_ids++}`; + constructor(readonly query: string) {} - constructor( - readonly query: string, - readonly param_types: number[] - ) {} + parse_task: Promise<{ + ser_params: ParameterSerializer; + Row: RowConstructor; + }> | null = null; - parse_task: Promise | null = null; parse() { return (this.parse_task ??= this.#parse()); } async #parse() { try { - const { name, query, param_types } = this; - return row_ctor( - from_sql, - await pipeline( - async () => { - await write(Parse, { statement: name, query, param_types }); - await write(Describe, { which: "S", name }); - }, - async () => { - await read(ParseComplete); - await read(ParameterDescription); + const { name, query } = this; + return await pipeline( + async () => { + await write(Parse, { statement: name, query, param_types: [] }); + await write(Describe, { which: "S", name }); + }, + async () => { + await read(ParseComplete); + const param_desc = await read(ParameterDescription); - const msg = msg_check_err(await read_raw()); - if (msg_type(msg) === NoData.type) return []; - else return ser_decode(RowDescription, msg).columns; - } - ) + const msg = msg_check_err(await read_raw()); + const row_desc = + msg_type(msg) === NoData.type + ? { columns: [] } + : ser_decode(RowDescription, msg); + + return { + ser_params: param_ser(param_desc), + Row: row_ctor(row_desc), + }; + } ); } catch (e) { throw ((this.parse_task = null), e); @@ -929,6 +925,59 @@ function wire_impl( } } + type ParameterDescription = EncoderType; + interface ParameterSerializer { + (params: unknown[]): (string | null)[]; + } + + function param_ser({ param_types }: ParameterDescription) { + return jit.compiled`function ser_params(xs) { + return [ + ${jit.map(", ", param_types, (type_oid, i) => { + const type = types[type_oid] ?? text; + return jit`${type}.output(xs[${i}])`; + })} + ]; + }`; + } + + type RowDescription = EncoderType; + interface RowConstructor { + new (columns: (BinaryLike | null)[]): Row; + } + + function row_ctor({ columns }: RowDescription) { + const Row = jit.compiled`function Row(xs) { + ${jit.map(" ", columns, ({ name, type_oid }, i) => { + const type = types[type_oid] ?? text; + return jit`this[${name}] = xs[${i}] === null ? null : ${type}.input(${from_utf8}(xs[${i}]));`; + })} + }`; + + Row.prototype = Object.create(null, { + [Symbol.toStringTag]: { + configurable: true, + value: `Row`, + }, + [Symbol.toPrimitive]: { + configurable: true, + value: function format() { + return [...this].join("\t"); + }, + }, + [Symbol.iterator]: { + configurable: true, + value: jit.compiled`function* iter() { + ${jit.map(" ", columns, ({ name }) => { + return jit`yield this[${name}];`; + })} + }`, + }, + }); + + return Row; + } + async function read_rows( Row: RowConstructor, stdout: WritableStream | null @@ -1002,7 +1051,7 @@ function wire_impl( async function* execute_fast( st: Statement, - params: { types: number[]; values: (string | null)[] }, + params: unknown[], stdin: ReadableStream | null, stdout: WritableStream | null ): ResultStream { @@ -1012,7 +1061,8 @@ function wire_impl( `executing query` ); - const Row = await st.parse(); + const { ser_params, Row } = await st.parse(); + const param_values = ser_params(params); const portal = st.portal(); try { @@ -1022,7 +1072,7 @@ function wire_impl( portal, statement: st.name, param_formats: [], - param_values: params.values, + param_values, column_formats: [], }); await write(Execute, { portal, row_limit: 0 }); @@ -1049,7 +1099,7 @@ function wire_impl( async function* execute_chunked( st: Statement, - params: { types: number[]; values: (string | null)[] }, + params: unknown[], chunk_size: number, stdin: ReadableStream | null, stdout: WritableStream | null @@ -1060,7 +1110,8 @@ function wire_impl( `executing chunked query` ); - const Row = await st.parse(); + const { ser_params, Row } = await st.parse(); + const param_values = ser_params(params); const portal = st.portal(); try { @@ -1070,7 +1121,7 @@ function wire_impl( portal, statement: st.name, param_formats: [], - param_values: params.values, + param_values, column_formats: [], }); await write(Execute, { portal, row_limit: chunk_size }); @@ -1103,8 +1154,9 @@ function wire_impl( } function query(s: SqlFragment) { - const { query, params } = sql.format(s, to_sql); - const st = st_get(query, params.types); + const { query, params } = sql.format(s); + let st = st_cache.get(query); + if (!st) st_cache.set(query, (st = new Statement(query))); return new Query(({ chunk_size = 0, stdin = null, stdout = null }) => chunk_size !== 0 @@ -1287,7 +1339,8 @@ export class Pool if (typeof f !== "undefined") { await using tx = await this.#begin(); const value = await f(tx.wire, tx); - return await tx.commit(), value; + if (tx.open) await tx.commit(); + return value; } else { return this.#begin(); }