Compare commits

...

5 Commits

Author SHA1 Message Date
eeed8b2f66
Rewrite type handling to be more performant 2025-01-10 17:30:04 +11:00
b72c548c33
Rename byten_lp to bytes_lp 2025-01-10 04:25:50 +11:00
60899d1a41
Update lstd to 0.2.0 2025-01-10 04:11:17 +11:00
bdebb22a0e
Use BinaryLike interface 2025-01-09 04:48:19 +11:00
5dadd7c5a2
Fix integer encoding and decoding 2025-01-09 04:37:30 +11:00
8 changed files with 519 additions and 614 deletions

18
deno.lock generated
View File

@ -433,17 +433,11 @@
"https://deno.land/x/postgresjs@v3.4.5/src/result.js": "001ff5e0c8d634674f483d07fbcd620a797e3101f842d6c20ca3ace936260465", "https://deno.land/x/postgresjs@v3.4.5/src/result.js": "001ff5e0c8d634674f483d07fbcd620a797e3101f842d6c20ca3ace936260465",
"https://deno.land/x/postgresjs@v3.4.5/src/subscribe.js": "9e4d0c3e573a6048e77ee2f15abbd5bcd17da9ca85a78c914553472c6d6c169b", "https://deno.land/x/postgresjs@v3.4.5/src/subscribe.js": "9e4d0c3e573a6048e77ee2f15abbd5bcd17da9ca85a78c914553472c6d6c169b",
"https://deno.land/x/postgresjs@v3.4.5/src/types.js": "471f4a6c35412aa202a7c177c0a7e5a7c3bd225f01bbde67c947894c1b8bf6ed", "https://deno.land/x/postgresjs@v3.4.5/src/types.js": "471f4a6c35412aa202a7c177c0a7e5a7c3bd225f01bbde67c947894c1b8bf6ed",
"https://git.lua.re/luaneko/lstd/raw/tag/0.1.2/async.ts": "ec1a2d25af2320f136b8648b25b590b7b6603525474f0d10b3ebf2215a5c23e5", "https://git.lua.re/luaneko/lstd/raw/tag/0.2.0/async.ts": "20bc54c7260c2d2cd27ffcca33b903dde57a3a3635386d8e0c6baca4b253ae4e",
"https://git.lua.re/luaneko/lstd/raw/tag/0.1.2/bytes.ts": "39d4c08f6446041f1d078bbf285187c337d49f853b20ec637cf1516fae8b3729", "https://git.lua.re/luaneko/lstd/raw/tag/0.2.0/bytes.ts": "5ffb12787dc3f9ef9680b6e2e4f5f9903783aa4c33b69e725b5df1d1c116bfe6",
"https://git.lua.re/luaneko/lstd/raw/tag/0.1.2/events.ts": "51bf13b819d1c4af792a40ff5d8d08407502d3f01d94f6b6866156f52cbe5d64", "https://git.lua.re/luaneko/lstd/raw/tag/0.2.0/events.ts": "28d395b8eea87f9bf7908a44b351d2d3c609ba7eab62bcecd0d43be8ee603438",
"https://git.lua.re/luaneko/lstd/raw/tag/0.1.2/func.ts": "f1935f673365cd68939531d65ef18fe81b5d43dc795b03c34bb5ad821ab1c9ff", "https://git.lua.re/luaneko/lstd/raw/tag/0.2.0/func.ts": "f1935f673365cd68939531d65ef18fe81b5d43dc795b03c34bb5ad821ab1c9ff",
"https://git.lua.re/luaneko/lstd/raw/tag/0.1.2/jit.ts": "1b7eec61ece15c05146446972a59d8d5787d4ba53ca1194f4450134d66a65f91", "https://git.lua.re/luaneko/lstd/raw/tag/0.2.0/jit.ts": "c1db7820de95c48521b057c7cdf9aa41f7eaba77462407c29d3932e7da252d53",
"https://git.lua.re/luaneko/lstd/raw/tag/0.1.2/mod.ts": "d7ef832245676b097c4fb7829c5cb2df80c02d2bd28767168c4f83bc309c9b1a", "https://git.lua.re/luaneko/lstd/raw/tag/0.2.0/mod.ts": "95d8b15048a54cb82391825831f695b74e7c8b206317264a99c906ce25c63f13"
"https://git.lua.re/luaneko/lstd/raw/tag/0.1.3/async.ts": "20bc54c7260c2d2cd27ffcca33b903dde57a3a3635386d8e0c6baca4b253ae4e",
"https://git.lua.re/luaneko/lstd/raw/tag/0.1.3/bytes.ts": "39d4c08f6446041f1d078bbf285187c337d49f853b20ec637cf1516fae8b3729",
"https://git.lua.re/luaneko/lstd/raw/tag/0.1.3/events.ts": "c4f2c856cbc7ac5d93b9af9b83d9550db7427cead32514a10424082e492005ae",
"https://git.lua.re/luaneko/lstd/raw/tag/0.1.3/func.ts": "f1935f673365cd68939531d65ef18fe81b5d43dc795b03c34bb5ad821ab1c9ff",
"https://git.lua.re/luaneko/lstd/raw/tag/0.1.3/jit.ts": "260ab418fbc55a5dec594f023c84d36f8d420fd3239e3d27648cba1b9a0e05b1",
"https://git.lua.re/luaneko/lstd/raw/tag/0.1.3/mod.ts": "dd9271f4e5aae4bfb1ec6b0800697ded12e4178af915acb2b96b97614ae8c8d9"
} }
} }

View File

@ -1 +1 @@
export * from "https://git.lua.re/luaneko/lstd/raw/tag/0.1.3/mod.ts"; export * from "https://git.lua.re/luaneko/lstd/raw/tag/0.2.0/mod.ts";

38
mod.ts
View File

@ -9,7 +9,7 @@ import {
unknown, unknown,
} from "./valita.ts"; } from "./valita.ts";
import { Pool, wire_connect, type LogLevel } from "./wire.ts"; import { Pool, wire_connect, type LogLevel } from "./wire.ts";
import { type FromSql, type ToSql, from_sql, to_sql } from "./sql.ts"; import { sql_types, type SqlType, type SqlTypeMap } from "./query.ts";
export { export {
WireError, WireError,
@ -21,13 +21,11 @@ export {
} from "./wire.ts"; } from "./wire.ts";
export { export {
type SqlFragment, type SqlFragment,
type FromSql, type SqlType,
type ToSql, type SqlTypeMap,
SqlValue, SqlTypeError,
sql, sql,
is_sql, is_sql,
} from "./sql.ts";
export {
Query, Query,
type Row, type Row,
type CommandResult, type CommandResult,
@ -45,8 +43,7 @@ export type Options = {
max_connections?: number; max_connections?: number;
idle_timeout?: number; idle_timeout?: number;
runtime_params?: Record<string, string>; runtime_params?: Record<string, string>;
from_sql?: FromSql; types?: SqlTypeMap;
to_sql?: ToSql;
}; };
type ParsedOptions = Infer<typeof ParsedOptions>; type ParsedOptions = Infer<typeof ParsedOptions>;
@ -64,15 +61,12 @@ const ParsedOptions = object({
runtime_params: record(string()).optional(() => ({})), runtime_params: record(string()).optional(() => ({})),
max_connections: number().optional(() => 10), max_connections: number().optional(() => 10),
idle_timeout: number().optional(() => 20), idle_timeout: number().optional(() => 20),
from_sql: unknown() types: record(unknown())
.assert((s): s is FromSql => typeof s === "function") .optional(() => ({}))
.optional(() => from_sql), .map((types): SqlTypeMap => ({ ...sql_types, ...types })),
to_sql: unknown()
.assert((s): s is ToSql => typeof s === "function")
.optional(() => to_sql),
}); });
function parse_opts(s: string, options: Options) { function parse_opts(s: string, opts: Options) {
const { const {
host, host,
port, port,
@ -87,13 +81,13 @@ function parse_opts(s: string, options: Options) {
Deno.env.toObject(); Deno.env.toObject();
return ParsedOptions.parse({ return ParsedOptions.parse({
...options, ...opts,
host: options.host ?? host ?? PGHOST ?? undefined, host: opts.host ?? host ?? PGHOST ?? undefined,
port: options.port ?? port ?? PGPORT ?? undefined, port: opts.port ?? port ?? PGPORT ?? undefined,
user: options.user ?? user ?? PGUSER ?? USER ?? undefined, user: opts.user ?? user ?? PGUSER ?? USER ?? undefined,
password: options.password ?? password ?? PGPASSWORD ?? undefined, password: opts.password ?? password ?? PGPASSWORD ?? undefined,
database: options.database ?? database ?? PGDATABASE ?? undefined, database: opts.database ?? database ?? PGDATABASE ?? undefined,
runtime_params: { ...runtime_params, ...options.runtime_params }, runtime_params: { ...runtime_params, ...opts.runtime_params },
}); });
} }

View File

@ -6,8 +6,18 @@ await using pool = postgres(`postgres://test:test@localhost:5432/test`, {
pool.on("log", (level, ctx, msg) => console.info(`${level}: ${msg}`, ctx)); pool.on("log", (level, ctx, msg) => console.info(`${level}: ${msg}`, ctx));
await pool.begin(async (pg) => { await pool.begin(async (pg, tx) => {
await pg.begin(async (pg) => { await pg.query`
console.log(await pg.query`select * from pg_user`); create table my_test (
}); key integer primary key generated always as identity,
data text not null
)
`;
await pg.query`
insert into my_test (data) values (${[1, 2, 3]}::bytea)
`;
console.log(await pg.query`select * from my_test`);
await tx.rollback();
}); });

342
query.ts
View File

@ -1,64 +1,302 @@
import type { ObjectType } from "./valita.ts"; import type { ObjectType } from "./valita.ts";
import { from_utf8, jit, to_utf8 } from "./lstd.ts"; import { from_hex, to_hex, to_utf8 } from "./lstd.ts";
import { type FromSql, SqlValue } from "./sql.ts";
export interface Row extends Iterable<unknown, void, void> { export const sql_format = Symbol.for(`re.lua.pglue.sql_format`);
[column: string]: unknown;
export interface SqlFragment {
[sql_format](f: SqlFormatter): void;
} }
export interface RowConstructor { export interface SqlFormatter {
new (columns: (Uint8Array | string | null)[]): Row; query: string;
params: unknown[];
} }
export interface RowDescription extends ReadonlyArray<ColumnDescription> {} export function is_sql(x: unknown): x is SqlFragment {
return typeof x === "object" && x !== null && sql_format in x;
export interface ColumnDescription {
readonly name: string;
readonly table_oid: number;
readonly table_column: number;
readonly type_oid: number;
readonly type_size: number;
readonly type_modifier: number;
} }
export function row_ctor(from_sql: FromSql, columns: RowDescription) { export function sql(
function parse(s: Uint8Array | string | null | undefined) { { raw: s }: TemplateStringsArray,
if (!s && s !== "") return null; ...xs: unknown[]
else return from_utf8(s); ): SqlFragment {
return {
[sql_format](fmt) {
for (let i = 0, n = s.length; i < n; i++) {
if (i !== 0) fmt_format(fmt, xs[i - 1]);
fmt.query += s[i];
}
},
};
}
export function fmt_write(fmt: SqlFormatter, s: string | SqlFragment) {
is_sql(s) ? s[sql_format](fmt) : (fmt.query += s);
}
export function fmt_format(fmt: SqlFormatter, x: unknown) {
is_sql(x) ? x[sql_format](fmt) : fmt_enclose(fmt, x);
}
export function fmt_enclose(fmt: SqlFormatter, x: unknown) {
const { params } = fmt;
params.push(x), (fmt.query += `$` + params.length);
}
sql.format = format;
sql.raw = raw;
sql.ident = ident;
sql.fragment = fragment;
sql.map = map;
sql.array = array;
sql.row = row;
export function format(sql: SqlFragment) {
const fmt: SqlFormatter = { query: "", params: [] };
return sql[sql_format](fmt), fmt;
}
export function raw(s: string): SqlFragment;
export function raw(s: TemplateStringsArray, ...xs: unknown[]): SqlFragment;
export function raw(
s: TemplateStringsArray | string,
...xs: unknown[]
): SqlFragment {
s = typeof s === "string" ? s : String.raw(s, ...xs);
return {
[sql_format](fmt) {
fmt.query += s;
},
};
}
export function ident(s: string): SqlFragment;
export function ident(s: TemplateStringsArray, ...xs: unknown[]): SqlFragment;
export function ident(s: TemplateStringsArray | string, ...xs: unknown[]) {
s = typeof s === "string" ? s : String.raw(s, ...xs);
return raw`"${s.replaceAll('"', '""')}"`;
}
export function fragment(
sep: string | SqlFragment,
...xs: unknown[]
): SqlFragment {
return {
[sql_format](fmt) {
for (let i = 0, n = xs.length; i < n; i++) {
if (i !== 0) fmt_write(fmt, sep);
fmt_format(fmt, xs[i]);
}
},
};
}
export function map<T>(
sep: string | SqlFragment,
xs: Iterable<T>,
f: (value: T, index: number) => unknown
) {
return fragment(sep, ...Iterator.from(xs).map(f));
}
export function array(...xs: unknown[]) {
return sql`array[${fragment(", ", ...xs)}]`;
}
export function row(...xs: unknown[]) {
return sql`row(${fragment(", ", ...xs)})`;
}
export interface SqlType {
input(value: string): unknown;
output(value: unknown): string | null;
}
export interface SqlTypeMap {
readonly [oid: number]: SqlType | undefined;
}
export class SqlTypeError extends TypeError {
override get name() {
return this.constructor.name;
} }
const Row = jit.compiled<RowConstructor>`function Row(xs) {
${jit.map(" ", columns, ({ name, type_oid }, i) => {
return jit`this[${jit.literal(name)}] = ${from_sql}(
new ${SqlValue}(${jit.literal(type_oid)}, ${parse}(xs[${jit.literal(i)}]))
);`;
})}
}`;
Row.prototype = Object.create(null, {
[Symbol.toStringTag]: {
configurable: true,
value: `Row`,
},
[Symbol.toPrimitive]: {
configurable: true,
value: function format() {
return [...this].join("\t");
},
},
[Symbol.iterator]: {
configurable: true,
value: jit.compiled`function* iter() {
${jit.map(" ", columns, ({ name }) => {
return jit`yield this[${jit.literal(name)}];`;
})}
}`,
},
});
return Row;
} }
export const bool: SqlType = {
input(s) {
return s !== "f";
},
output(x) {
return typeof x === "undefined" || x === null ? null : x ? "t" : "f";
},
};
export const text: SqlType = {
input(s) {
return s;
},
output(x) {
if (typeof x === "undefined" || x === null) return null;
else if (typeof x === "string") return x;
else return String(x);
},
};
export const int2: SqlType = {
input(s) {
const n = Number(s);
if (Number.isInteger(n) && -32768 <= n && n <= 32767) return n;
else throw new SqlTypeError(`invalid int2 input '${s}'`);
},
output(x) {
let n: number;
if (typeof x === "undefined" || x === null) return null;
else if (typeof x === "number") n = x;
else n = Number(x);
if (Number.isInteger(n) && -32768 <= n && n <= 32767) return n.toString();
else throw new SqlTypeError(`invalid int2 output '${x}'`);
},
};
export const int4: SqlType = {
input(s) {
const n = Number(s);
if (Number.isInteger(n) && -2147483648 <= n && n <= 2147483647) return n;
else throw new SqlTypeError(`invalid int4 input '${s}'`);
},
output(x) {
let n: number;
if (typeof x === "undefined" || x === null) return null;
else if (typeof x === "number") n = x;
else n = Number(x);
if (Number.isInteger(n) && -2147483648 <= n && n <= 2147483647)
return n.toString();
else throw new SqlTypeError(`invalid int4 output '${x}'`);
},
};
export const int8: SqlType = {
input(s) {
const n = BigInt(s);
if (-9007199254740991n <= n && n <= 9007199254740991n) return Number(n);
else if (-9223372036854775808n <= n && n <= 9223372036854775807n) return n;
else throw new SqlTypeError(`invalid int8 input '${s}'`);
},
output(x) {
let n: number | bigint;
if (typeof x === "undefined" || x === null) return null;
else if (typeof x === "number" || typeof x === "bigint") n = x;
else if (typeof x === "string") n = BigInt(x);
else n = Number(x);
if (Number.isInteger(n)) {
if (-9007199254740991 <= n && n <= 9007199254740991) return n.toString();
else throw new SqlTypeError(`unsafe int8 output '${x}'`);
} else if (typeof n === "bigint") {
if (-9223372036854775808n <= n && n <= 9223372036854775807n)
return n.toString();
}
throw new SqlTypeError(`invalid int8 output '${x}'`);
},
};
export const float4: SqlType = {
input(s) {
return Math.fround(Number(s));
},
output(x) {
let n: number;
if (typeof x === "undefined" || x === null) return null;
else if (typeof x === "number") n = x;
else {
n = Number(x);
if (Number.isNaN(n))
throw new SqlTypeError(`invalid float4 output '${x}'`);
}
return Math.fround(n).toString();
},
};
export const float8: SqlType = {
input(s) {
return Number(s);
},
output(x) {
let n: number;
if (typeof x === "undefined" || x === null) return null;
else if (typeof x === "number") n = x;
else {
n = Number(x);
if (Number.isNaN(n))
throw new SqlTypeError(`invalid float8 output '${x}'`);
}
return n.toString();
},
};
export const timestamptz: SqlType = {
input(s) {
const t = Date.parse(s);
if (!Number.isNaN(t)) return new Date(t);
else throw new SqlTypeError(`invalid timestamptz input '${s}'`);
},
output(x) {
let t: Date;
if (typeof x === "undefined" || x === null) return null;
else if (x instanceof Date) t = x;
else if (typeof x === "number" || typeof x === "bigint")
t = new Date(Number(x) * 1000); // unix epoch seconds
else t = new Date(String(x));
if (Number.isFinite(t.getTime())) return t.toISOString();
else throw new SqlTypeError(`invalid timestamptz output '${x}'`);
},
};
export const bytea: SqlType = {
input(s) {
if (s.startsWith(`\\x`)) return from_hex(s.slice(2));
else throw new SqlTypeError(`invalid bytea input '${s}'`);
},
output(x) {
let buf: Uint8Array;
if (typeof x === "undefined" || x === null) return null;
else if (typeof x === "string") buf = to_utf8(x);
else if (x instanceof Uint8Array) buf = x;
else if (x instanceof ArrayBuffer || x instanceof SharedArrayBuffer)
buf = new Uint8Array(x);
else if (Array.isArray(x) || x instanceof Array) buf = Uint8Array.from(x);
else throw new SqlTypeError(`invalid bytea output '${x}'`);
return `\\x` + to_hex(buf);
},
};
export const json: SqlType = {
input(s) {
return JSON.parse(s);
},
output(x) {
return typeof x === "undefined" ? null : JSON.stringify(x);
},
};
export const sql_types: SqlTypeMap = {
16: bool, // bool
25: text, // text
21: int2, // int2
23: int4, // int4
20: int8, // int8
26: int8, // oid
700: float4, // float4
701: float8, // float8
1082: timestamptz, // date
1114: timestamptz, // timestamp
1184: timestamptz, // timestamptz
17: bytea, // bytea
114: json, // json
3802: json, // jsonb
};
sql.types = sql_types;
type ReadonlyTuple<T extends readonly unknown[]> = readonly [...T]; type ReadonlyTuple<T extends readonly unknown[]> = readonly [...T];
export interface CommandResult { export interface CommandResult {
@ -76,6 +314,10 @@ export interface Results<T> extends CommandResult, ReadonlyArray<T> {
export interface ResultStream<T> export interface ResultStream<T>
extends AsyncIterable<T[], CommandResult, void> {} extends AsyncIterable<T[], CommandResult, void> {}
export interface Row extends Iterable<unknown, void, void> {
[column: string]: unknown;
}
export interface QueryOptions { export interface QueryOptions {
readonly chunk_size: number; readonly chunk_size: number;
readonly stdin: ReadableStream<Uint8Array> | null; readonly stdin: ReadableStream<Uint8Array> | null;

99
ser.ts
View File

@ -1,4 +1,15 @@
import { encode_utf8, from_utf8, jit } from "./lstd.ts"; import {
type BinaryLike,
encode_utf8,
from_utf8,
jit,
read_i16_be,
read_i32_be,
read_i8,
write_i16_be,
write_i32_be,
write_i8,
} from "./lstd.ts";
export class EncoderError extends Error { export class EncoderError extends Error {
override get name() { override get name() {
@ -30,44 +41,31 @@ export interface Encoder<T> {
export type EncoderType<E extends Encoder<unknown>> = export type EncoderType<E extends Encoder<unknown>> =
E extends Encoder<infer T> ? T : never; E extends Encoder<infer T> ? T : never;
export function sum_const_size(...ns: (number | null)[]) {
let sum = 0;
for (const n of ns) {
if (n !== null) sum += n;
else return null;
}
return sum;
}
// https://www.postgresql.org/docs/current/protocol-message-types.html#PROTOCOL-MESSAGE-TYPES // https://www.postgresql.org/docs/current/protocol-message-types.html#PROTOCOL-MESSAGE-TYPES
export const u8: Encoder<number> = { export const i8: Encoder<number> = {
const_size: 1, const_size: 1,
allocs() { allocs() {
return 1; return 1;
}, },
encode(buf, cur, n) { encode(buf, cur, n) {
buf[cur.i++] = n & 0xff; write_i8(buf, n, cur.i++);
}, },
decode(buf, cur) { decode(buf, cur) {
return buf[cur.i++]; return read_i8(buf, cur.i++);
}, },
}; };
export const u16: Encoder<number> = { export const i16: Encoder<number> = {
const_size: 2, const_size: 2,
allocs() { allocs() {
return 2; return 2;
}, },
encode(buf, cur, n) { encode(buf, cur, n) {
let { i } = cur; write_i16_be(buf, n, cur.i), (cur.i += 2);
buf[i++] = (n >>> 8) & 0xff;
buf[i++] = n & 0xff;
cur.i = i;
}, },
decode(buf, cur) { decode(buf, cur) {
let { i } = cur; const n = read_i16_be(buf, cur.i);
const n = (buf[i++] << 8) + buf[i++]; return (cur.i += 2), n;
return (cur.i = i), n;
}, },
}; };
@ -77,17 +75,11 @@ export const i32: Encoder<number> = {
return 4; return 4;
}, },
encode(buf, cur, n) { encode(buf, cur, n) {
let { i } = cur; write_i32_be(buf, n, cur.i), (cur.i += 4);
buf[i++] = (n >>> 24) & 0xff;
buf[i++] = (n >>> 16) & 0xff;
buf[i++] = (n >>> 8) & 0xff;
buf[i++] = n & 0xff;
cur.i = i;
}, },
decode(buf, cur) { decode(buf, cur) {
let { i } = cur; const n = read_i32_be(buf, cur.i);
const n = (buf[i++] << 24) + (buf[i++] << 16) + (buf[i++] << 8) + buf[i++]; return (cur.i += 4), n;
return (cur.i = i), n;
}, },
}; };
@ -118,7 +110,21 @@ export function byten(n: number): Encoder<Uint8Array> {
}; };
} }
export const byten_lp: Encoder<Uint8Array | string | null> = { export const bytes: Encoder<BinaryLike> = {
const_size: null,
allocs(s) {
if (typeof s === "string") return s.length * 3;
else return s.length;
},
encode(buf, cur, s) {
cur.i += encode_utf8(s, buf.subarray(cur.i));
},
decode(buf, cur) {
return buf.subarray(cur.i, (cur.i = buf.length));
},
};
export const bytes_lp: Encoder<BinaryLike | null> = {
const_size: null, const_size: null,
allocs(s) { allocs(s) {
let size = 4; let size = 4;
@ -140,20 +146,6 @@ export const byten_lp: Encoder<Uint8Array | string | null> = {
}, },
}; };
export const byten_rest: Encoder<Uint8Array | string> = {
const_size: null,
allocs(s) {
if (typeof s === "string") return s.length * 3;
else return s.length;
},
encode(buf, cur, s) {
cur.i += encode_utf8(s, buf.subarray(cur.i));
},
decode(buf, cur) {
return buf.subarray(cur.i, (cur.i = buf.length));
},
};
export const cstring: Encoder<string> = { export const cstring: Encoder<string> = {
const_size: null, const_size: null,
allocs(s) { allocs(s) {
@ -215,7 +207,7 @@ export function array<T>(
): ArrayEncoder<T> { ): ArrayEncoder<T> {
const { const_size } = type; const { const_size } = type;
return { return {
const_size, const_size: null,
allocs: allocs:
const_size !== null const_size !== null
? function allocs(xs: T[]) { ? function allocs(xs: T[]) {
@ -249,21 +241,28 @@ export interface ObjectEncoder<S extends ObjectShape>
export function object<S extends ObjectShape>(shape: S): ObjectEncoder<S> { export function object<S extends ObjectShape>(shape: S): ObjectEncoder<S> {
const keys = Object.keys(shape); const keys = Object.keys(shape);
return jit.compiled`{ return jit.compiled`{
const_size: ${jit.literal(sum_const_size(...keys.map((k) => shape[k].const_size)))}, const_size: null,
allocs(x) { allocs(x) {
return ${jit.if(
keys.length === 0,
jit`0`,
jit.map(" + ", keys, (k) => {
return shape[k].const_size ?? jit`${shape[k]}.allocs(x[${k}])`;
})
)};
return 0${jit.map("", keys, (k) => { return 0${jit.map("", keys, (k) => {
return jit` + ${shape[k]}.allocs(x[${jit.literal(k)}])`; return jit` + ${shape[k]}.allocs(x[${k}])`;
})}; })};
}, },
encode(buf, cur, x) { encode(buf, cur, x) {
${jit.map(" ", keys, (k) => { ${jit.map(" ", keys, (k) => {
return jit`${shape[k]}.encode(buf, cur, x[${jit.literal(k)}]);`; return jit`${shape[k]}.encode(buf, cur, x[${k}]);`;
})} })}
}, },
decode(buf, cur) { decode(buf, cur) {
return { return {
${jit.map(", ", keys, (k) => { ${jit.map(", ", keys, (k) => {
return jit`[${jit.literal(k)}]: ${shape[k]}.decode(buf, cur)`; return jit`[${k}]: ${shape[k]}.decode(buf, cur)`;
})} })}
}; };
}, },

386
sql.ts
View File

@ -1,386 +0,0 @@
import { from_hex, to_hex } from "./lstd.ts";
export const sql_format = Symbol.for(`re.lua.pglue.sql_format`);
export interface SqlFragment {
[sql_format](f: SqlFormatter): void;
}
export function is_sql(x: unknown): x is SqlFragment {
return typeof x === "object" && x !== null && sql_format in x;
}
export interface FromSql {
(x: SqlValue): unknown;
}
export interface ToSql {
(x: unknown): SqlFragment;
}
export const from_sql = function from_sql(x) {
const { type, value } = x;
if (value === null) return null;
switch (type) {
case 16: // boolean
return boolean.parse(value);
case 25: // text
return text.parse(value);
case 21: // int2
return int2.parse(value);
case 23: // int4
return int4.parse(value);
case 20: // int8
case 26: // oid
return int8.parse(value);
case 700: // float4
return float4.parse(value);
case 701: // float8
return float8.parse(value);
case 1082: // date
case 1114: // timestamp
case 1184: // timestamptz
return timestamptz.parse(value);
case 17: // bytea
return bytea.parse(value);
case 114: // json
case 3802: // jsonb
return json.parse(value);
default:
return x;
}
} as FromSql;
export const to_sql = function to_sql(x) {
switch (typeof x) {
case "undefined":
return nil();
case "boolean":
return boolean(x);
case "number":
return float8(x);
case "bigint":
return int8(x);
case "string":
case "symbol":
case "function":
return text(x);
}
switch (true) {
case x === null:
return nil();
case is_sql(x):
return x;
case Array.isArray(x):
return array(...(x instanceof Array ? x : Array.from(x)));
case x instanceof Date:
return timestamptz(x);
case x instanceof Uint8Array:
case x instanceof ArrayBuffer:
case x instanceof SharedArrayBuffer:
return bytea(x);
}
throw new TypeError(`cannot convert input '${x}' to sql`);
} as ToSql;
export class SqlValue implements SqlFragment {
constructor(
readonly type: number,
readonly value: string | null
) {}
[sql_format](f: SqlFormatter) {
f.write_param(this.type, this.value);
}
[Symbol.toStringTag]() {
return `${this.constructor.name}<${this.type}>`;
}
[Symbol.toPrimitive]() {
return this.value;
}
toString() {
return String(this.value);
}
toJSON() {
return this.value;
}
}
export function value(type: number, x: unknown) {
const s = x === null || typeof x === "undefined" ? null : String(x);
return new SqlValue(type, s);
}
export class SqlFormatter {
readonly #ser;
#query = "";
#params = {
types: [] as number[],
values: [] as (string | null)[],
};
get query() {
return this.#query.trim();
}
get params() {
return this.#params;
}
constructor(serializer: ToSql) {
this.#ser = serializer;
}
write(s: string | SqlFragment) {
if (is_sql(s)) s[sql_format](this);
else this.#query += s;
}
write_param(type: number, s: string | null) {
const { types, values } = this.#params;
types.push(type), values.push(s), this.write(`$` + values.length);
}
format(x: unknown) {
this.write(is_sql(x) ? x : this.#ser(x));
}
}
export function format(sql: SqlFragment, serializer = to_sql) {
const fmt = new SqlFormatter(serializer);
return fmt.write(sql), fmt;
}
export function sql(
{ raw: s }: TemplateStringsArray,
...xs: unknown[]
): SqlFragment {
return {
[sql_format](f) {
for (let i = 0, n = s.length; i < n; i++) {
if (i !== 0) f.format(xs[i - 1]);
f.write(s[i]);
}
},
};
}
sql.value = value;
sql.format = format;
sql.raw = raw;
sql.ident = ident;
sql.fragment = fragment;
sql.map = map;
sql.array = array;
sql.row = row;
sql.null = nil;
sql.boolean = boolean;
sql.text = text;
sql.int2 = int2;
sql.int4 = int4;
sql.int8 = int8;
sql.float4 = float4;
sql.float8 = float8;
sql.timestamptz = timestamptz;
sql.bytea = bytea;
sql.json = json;
export function raw(s: TemplateStringsArray, ...xs: unknown[]): SqlFragment;
export function raw(s: string): SqlFragment;
export function raw(
s: TemplateStringsArray | string,
...xs: unknown[]
): SqlFragment {
s = typeof s === "string" ? s : String.raw(s, ...xs);
return {
[sql_format](f) {
f.write(s);
},
};
}
export function ident(s: TemplateStringsArray, ...xs: unknown[]): SqlFragment;
export function ident(s: string): SqlFragment;
export function ident(s: TemplateStringsArray | string, ...xs: unknown[]) {
s = typeof s === "string" ? s : String.raw(s, ...xs);
return raw`"${s.replaceAll('"', '""')}"`;
}
export function fragment(
sep: string | SqlFragment,
...xs: unknown[]
): SqlFragment {
return {
[sql_format](f) {
for (let i = 0, n = xs.length; i < n; i++) {
if (i !== 0) f.write(sep);
f.format(xs[i]);
}
},
};
}
export function map<T>(
sep: string | SqlFragment,
xs: Iterable<T>,
f: (value: T, index: number) => unknown
): SqlFragment {
return fragment(sep, ...Iterator.from(xs).map(f));
}
export function array(...xs: unknown[]): SqlFragment {
return sql`array[${fragment(", ", ...xs)}]`;
}
export function row(...xs: unknown[]): SqlFragment {
return sql`row(${fragment(", ", ...xs)})`;
}
boolean.oid = 16 as const;
text.oid = 25 as const;
int2.oid = 21 as const;
int4.oid = 23 as const;
int8.oid = 20 as const;
float4.oid = 700 as const;
float8.oid = 701 as const;
timestamptz.oid = 1184 as const;
bytea.oid = 17 as const;
json.oid = 114 as const;
export function nil() {
return value(0, null);
}
Object.defineProperty(nil, "name", { configurable: true, value: "null" });
export function boolean(x: unknown) {
return value(
boolean.oid,
x === null || typeof x === "undefined" ? null : x ? "t" : "f"
);
}
boolean.parse = function parse_boolean(s: string) {
return s === "t";
};
export function text(x: unknown) {
return value(text.oid, x);
}
text.parse = function parse_text(s: string) {
return s;
};
const i2_min = -32768;
const i2_max = 32767;
export function int2(x: unknown) {
return value(int2.oid, x);
}
int2.parse = function parse_int2(s: string) {
const n = Number(s);
if (Number.isInteger(n) && i2_min <= n && n <= i2_max) return n;
else throw new TypeError(`input '${s}' is not a valid int2 value`);
};
const i4_min = -2147483648;
const i4_max = 2147483647;
export function int4(x: unknown) {
return value(int4.oid, x);
}
int4.parse = function parse_int4(s: string) {
const n = Number(s);
if (Number.isInteger(n) && i4_min <= n && n <= i4_max) return n;
else throw new TypeError(`input '${s}' is not a valid int4 value`);
};
const i8_min = -9223372036854775808n;
const i8_max = 9223372036854775807n;
export function int8(x: unknown) {
return value(int8.oid, x);
}
function to_int8(n: number | bigint) {
if (typeof n === "bigint") return i8_min <= n && n <= i8_max ? n : null;
else return Number.isSafeInteger(n) ? BigInt(n) : null;
}
int8.parse = function parse_int8(s: string) {
const n = to_int8(BigInt(s));
if (n !== null) return to_float8(n) ?? n;
else throw new TypeError(`input '${s}' is not a valid int8 value`);
};
const f8_min = -9007199254740991n;
const f8_max = 9007199254740991n;
export function float4(x: unknown) {
return value(float4.oid, x);
}
export function float8(x: unknown) {
return value(float8.oid, x);
}
function to_float8(n: number | bigint) {
if (typeof n === "bigint")
return f8_min <= n && n <= f8_max ? Number(n) : null;
else return Number.isNaN(n) ? null : n;
}
float4.parse = float8.parse = function parse_float8(s: string) {
const n = to_float8(Number(s));
if (n !== null) return n;
else throw new TypeError(`input '${s}' is not a valid float8 value`);
};
export function timestamptz(x: unknown) {
if (x instanceof Date) x = x.toISOString();
else if (typeof x === "number" || typeof x === "bigint")
x = new Date(Number(x) * 1000).toISOString(); // unix epoch
return value(timestamptz.oid, x);
}
timestamptz.parse = function parse_timestamptz(s: string) {
const t = Date.parse(s);
if (!Number.isNaN(t)) return new Date(t);
else throw new TypeError(`input '${s}' is not a valid timestamptz value`);
};
export function bytea(x: Uint8Array | ArrayBufferLike | Iterable<number>) {
let buf;
if (x instanceof Uint8Array) buf = x;
else if (x instanceof ArrayBuffer || x instanceof SharedArrayBuffer)
buf = new Uint8Array(x);
else buf = Uint8Array.from(x);
return value(bytea.oid, `\\x` + to_hex(buf));
}
bytea.parse = function parse_bytea(s: string) {
if (s.startsWith(`\\x`)) return from_hex(s.slice(2));
else throw new TypeError(`input is not a valid bytea value`);
};
export function json(x: unknown) {
return value(json.oid, JSON.stringify(x) ?? null);
}
json.parse = function parse_json(s: string): unknown {
return JSON.parse(s);
};

230
wire.ts
View File

@ -1,10 +1,12 @@
import { import {
type BinaryLike,
buf_concat_fast, buf_concat_fast,
buf_eq, buf_eq,
buf_xor, buf_xor,
channel, channel,
from_base64, from_base64,
from_utf8, from_utf8,
jit,
semaphore, semaphore,
semaphore_fast, semaphore_fast,
to_base64, to_base64,
@ -14,8 +16,8 @@ import {
import { import {
array, array,
byten, byten,
byten_lp, bytes_lp,
byten_rest, bytes,
char, char,
cstring, cstring,
type Encoder, type Encoder,
@ -25,25 +27,21 @@ import {
oneof, oneof,
ser_decode, ser_decode,
ser_encode, ser_encode,
u16, i16,
i32, i32,
u8, i8,
sum_const_size, type EncoderType,
} from "./ser.ts"; } from "./ser.ts";
import {
is_sql,
sql,
type FromSql,
type SqlFragment,
type ToSql,
} from "./sql.ts";
import { import {
type CommandResult, type CommandResult,
is_sql,
Query, Query,
type ResultStream, type ResultStream,
type Row, type Row,
row_ctor, sql,
type RowConstructor, type SqlFragment,
type SqlTypeMap,
text,
} from "./query.ts"; } from "./query.ts";
import { join } from "jsr:@std/path@^1.0.8"; import { join } from "jsr:@std/path@^1.0.8";
@ -119,11 +117,11 @@ function msg<T extends string, S extends ObjectShape>(
shape: S shape: S
): MessageEncoder<T, S> { ): MessageEncoder<T, S> {
const header_size = type !== "" ? 5 : 4; const header_size = type !== "" ? 5 : 4;
const ty = type !== "" ? oneof(char(u8), type) : null; const ty = type !== "" ? oneof(char(i8), type) : null;
const fields = object(shape); const fields = object(shape);
return { return {
const_size: sum_const_size(header_size, fields.const_size), const_size: null,
get type() { get type() {
return type; return type;
}, },
@ -160,7 +158,7 @@ function msg_check_err(msg: Uint8Array) {
// https://www.postgresql.org/docs/current/protocol-message-formats.html#PROTOCOL-MESSAGE-FORMATS // https://www.postgresql.org/docs/current/protocol-message-formats.html#PROTOCOL-MESSAGE-FORMATS
export const Header = object({ export const Header = object({
type: char(u8), type: char(i8),
length: i32, length: i32,
}); });
@ -191,7 +189,7 @@ export const AuthenticationGSS = msg("R", {
export const AuthenticationGSSContinue = msg("R", { export const AuthenticationGSSContinue = msg("R", {
status: oneof(i32, 8 as const), status: oneof(i32, 8 as const),
data: byten_rest, data: bytes,
}); });
export const AuthenticationSSPI = msg("R", { export const AuthenticationSSPI = msg("R", {
@ -221,12 +219,12 @@ export const AuthenticationSASL = msg("R", {
export const AuthenticationSASLContinue = msg("R", { export const AuthenticationSASLContinue = msg("R", {
status: oneof(i32, 11 as const), status: oneof(i32, 11 as const),
data: byten_rest, data: bytes,
}); });
export const AuthenticationSASLFinal = msg("R", { export const AuthenticationSASLFinal = msg("R", {
status: oneof(i32, 12 as const), status: oneof(i32, 12 as const),
data: byten_rest, data: bytes,
}); });
export const BackendKeyData = msg("K", { export const BackendKeyData = msg("K", {
@ -237,9 +235,9 @@ export const BackendKeyData = msg("K", {
export const Bind = msg("B", { export const Bind = msg("B", {
portal: cstring, portal: cstring,
statement: cstring, statement: cstring,
param_formats: array(u16, u16), param_formats: array(i16, i16),
param_values: array(u16, byten_lp), param_values: array(i16, bytes_lp),
column_formats: array(u16, u16), column_formats: array(i16, i16),
}); });
export const BindComplete = msg("2", {}); export const BindComplete = msg("2", {});
@ -251,43 +249,43 @@ export const CancelRequest = msg("", {
}); });
export const Close = msg("C", { export const Close = msg("C", {
which: oneof(char(u8), "S" as const, "P" as const), which: oneof(char(i8), "S" as const, "P" as const),
name: cstring, name: cstring,
}); });
export const CloseComplete = msg("3", {}); export const CloseComplete = msg("3", {});
export const CommandComplete = msg("C", { tag: cstring }); export const CommandComplete = msg("C", { tag: cstring });
export const CopyData = msg("d", { data: byten_rest }); export const CopyData = msg("d", { data: bytes });
export const CopyDone = msg("c", {}); export const CopyDone = msg("c", {});
export const CopyFail = msg("f", { cause: cstring }); export const CopyFail = msg("f", { cause: cstring });
export const CopyInResponse = msg("G", { export const CopyInResponse = msg("G", {
format: u8, format: i8,
column_formats: array(u16, u16), column_formats: array(i16, i16),
}); });
export const CopyOutResponse = msg("H", { export const CopyOutResponse = msg("H", {
format: u8, format: i8,
column_formats: array(u16, u16), column_formats: array(i16, i16),
}); });
export const CopyBothResponse = msg("W", { export const CopyBothResponse = msg("W", {
format: u8, format: i8,
column_formats: array(u16, u16), column_formats: array(i16, i16),
}); });
export const DataRow = msg("D", { export const DataRow = msg("D", {
column_values: array(u16, byten_lp), column_values: array(i16, bytes_lp),
}); });
export const Describe = msg("D", { export const Describe = msg("D", {
which: oneof(char(u8), "S" as const, "P" as const), which: oneof(char(i8), "S" as const, "P" as const),
name: cstring, name: cstring,
}); });
export const EmptyQueryResponse = msg("I", {}); export const EmptyQueryResponse = msg("I", {});
const err_field = char(u8); const err_field = char(i8);
const err_fields: Encoder<Record<string, string>> = { const err_fields: Encoder<Record<string, string>> = {
const_size: null, const_size: null,
allocs(x) { allocs(x) {
@ -325,13 +323,13 @@ export const Flush = msg("H", {});
export const FunctionCall = msg("F", { export const FunctionCall = msg("F", {
oid: i32, oid: i32,
arg_formats: array(u16, u16), arg_formats: array(i16, i16),
arg_values: array(u16, byten_lp), arg_values: array(i16, bytes_lp),
result_format: u16, result_format: i16,
}); });
export const FunctionCallResponse = msg("V", { export const FunctionCallResponse = msg("V", {
result_value: byten_lp, result_value: bytes_lp,
}); });
export const NegotiateProtocolVersion = msg("v", { export const NegotiateProtocolVersion = msg("v", {
@ -352,7 +350,7 @@ export const NotificationResponse = msg("A", {
}); });
export const ParameterDescription = msg("t", { export const ParameterDescription = msg("t", {
param_types: array(u16, i32), param_types: array(i16, i32),
}); });
export const ParameterStatus = msg("S", { export const ParameterStatus = msg("S", {
@ -363,7 +361,7 @@ export const ParameterStatus = msg("S", {
export const Parse = msg("P", { export const Parse = msg("P", {
statement: cstring, statement: cstring,
query: cstring, query: cstring,
param_types: array(u16, i32), param_types: array(i16, i32),
}); });
export const ParseComplete = msg("1", {}); export const ParseComplete = msg("1", {});
@ -379,31 +377,31 @@ export const QueryMessage = msg("Q", {
}); });
export const ReadyForQuery = msg("Z", { export const ReadyForQuery = msg("Z", {
tx_status: oneof(char(u8), "I" as const, "T" as const, "E" as const), tx_status: oneof(char(i8), "I" as const, "T" as const, "E" as const),
}); });
export const RowDescription = msg("T", { export const RowDescription = msg("T", {
columns: array( columns: array(
u16, i16,
object({ object({
name: cstring, name: cstring,
table_oid: i32, table_oid: i32,
table_column: u16, table_column: i16,
type_oid: i32, type_oid: i32,
type_size: u16, type_size: i16,
type_modifier: i32, type_modifier: i32,
format: u16, format: i16,
}) })
), ),
}); });
export const SASLInitialResponse = msg("p", { export const SASLInitialResponse = msg("p", {
mechanism: cstring, mechanism: cstring,
data: byten_lp, data: bytes_lp,
}); });
export const SASLResponse = msg("p", { export const SASLResponse = msg("p", {
data: byten_rest, data: bytes,
}); });
export const StartupMessage = msg("", { export const StartupMessage = msg("", {
@ -421,7 +419,7 @@ export const StartupMessage = msg("", {
for (const { 0: key, 1: value } of Object.entries(x)) { for (const { 0: key, 1: value } of Object.entries(x)) {
cstring.encode(buf, cur, key), cstring.encode(buf, cur, value); cstring.encode(buf, cur, key), cstring.encode(buf, cur, value);
} }
u8.encode(buf, cur, 0); i8.encode(buf, cur, 0);
}, },
decode(buf, cur) { decode(buf, cur) {
const x: Record<string, string> = {}; const x: Record<string, string> = {};
@ -447,8 +445,7 @@ export interface WireOptions {
readonly password: string; readonly password: string;
readonly database: string | null; readonly database: string | null;
readonly runtime_params: Record<string, string>; readonly runtime_params: Record<string, string>;
readonly from_sql: FromSql; readonly types: SqlTypeMap;
readonly to_sql: ToSql;
} }
export type WireEvents = { export type WireEvents = {
@ -543,7 +540,8 @@ export class Wire extends TypedEmitter<WireEvents> implements Disposable {
if (typeof f !== "undefined") { if (typeof f !== "undefined") {
await using tx = await this.#begin(); await using tx = await this.#begin();
const value = await f(this, tx); const value = await f(this, tx);
return await tx.commit(), value; if (tx.open) await tx.commit();
return value;
} else { } else {
return this.#begin(); return this.#begin();
} }
@ -584,7 +582,7 @@ export class Wire extends TypedEmitter<WireEvents> implements Disposable {
function wire_impl( function wire_impl(
wire: Wire, wire: Wire,
socket: Deno.Conn, socket: Deno.Conn,
{ user, database, password, runtime_params, from_sql, to_sql }: WireOptions { user, database, password, runtime_params, types }: WireOptions
) { ) {
const params: Parameters = Object.create(null); const params: Parameters = Object.create(null);
@ -879,45 +877,42 @@ function wire_impl(
const st_cache = new Map<string, Statement>(); const st_cache = new Map<string, Statement>();
let st_ids = 0; let st_ids = 0;
function st_get(query: string, param_types: number[]) {
const key = JSON.stringify({ q: query, p: param_types });
let st = st_cache.get(key);
if (!st) st_cache.set(key, (st = new Statement(query, param_types)));
return st;
}
class Statement { class Statement {
readonly name = `__st${st_ids++}`; readonly name = `__st${st_ids++}`;
constructor(readonly query: string) {}
constructor( parse_task: Promise<{
readonly query: string, ser_params: ParameterSerializer;
readonly param_types: number[] Row: RowConstructor;
) {} }> | null = null;
parse_task: Promise<RowConstructor> | null = null;
parse() { parse() {
return (this.parse_task ??= this.#parse()); return (this.parse_task ??= this.#parse());
} }
async #parse() { async #parse() {
try { try {
const { name, query, param_types } = this; const { name, query } = this;
return row_ctor( return await pipeline(
from_sql, async () => {
await pipeline( await write(Parse, { statement: name, query, param_types: [] });
async () => { await write(Describe, { which: "S", name });
await write(Parse, { statement: name, query, param_types }); },
await write(Describe, { which: "S", name }); async () => {
}, await read(ParseComplete);
async () => { const param_desc = await read(ParameterDescription);
await read(ParseComplete);
await read(ParameterDescription);
const msg = msg_check_err(await read_raw()); const msg = msg_check_err(await read_raw());
if (msg_type(msg) === NoData.type) return []; const row_desc =
else return ser_decode(RowDescription, msg).columns; msg_type(msg) === NoData.type
} ? { columns: [] }
) : ser_decode(RowDescription, msg);
return {
ser_params: param_ser(param_desc),
Row: row_ctor(row_desc),
};
}
); );
} catch (e) { } catch (e) {
throw ((this.parse_task = null), e); throw ((this.parse_task = null), e);
@ -930,6 +925,59 @@ function wire_impl(
} }
} }
type ParameterDescription = EncoderType<typeof ParameterDescription>;
interface ParameterSerializer {
(params: unknown[]): (string | null)[];
}
function param_ser({ param_types }: ParameterDescription) {
return jit.compiled<ParameterSerializer>`function ser_params(xs) {
return [
${jit.map(", ", param_types, (type_oid, i) => {
const type = types[type_oid] ?? text;
return jit`${type}.output(xs[${i}])`;
})}
];
}`;
}
type RowDescription = EncoderType<typeof RowDescription>;
interface RowConstructor {
new (columns: (BinaryLike | null)[]): Row;
}
function row_ctor({ columns }: RowDescription) {
const Row = jit.compiled<RowConstructor>`function Row(xs) {
${jit.map(" ", columns, ({ name, type_oid }, i) => {
const type = types[type_oid] ?? text;
return jit`this[${name}] = xs[${i}] === null ? null : ${type}.input(${from_utf8}(xs[${i}]));`;
})}
}`;
Row.prototype = Object.create(null, {
[Symbol.toStringTag]: {
configurable: true,
value: `Row`,
},
[Symbol.toPrimitive]: {
configurable: true,
value: function format() {
return [...this].join("\t");
},
},
[Symbol.iterator]: {
configurable: true,
value: jit.compiled`function* iter() {
${jit.map(" ", columns, ({ name }) => {
return jit`yield this[${name}];`;
})}
}`,
},
});
return Row;
}
async function read_rows( async function read_rows(
Row: RowConstructor, Row: RowConstructor,
stdout: WritableStream<Uint8Array> | null stdout: WritableStream<Uint8Array> | null
@ -1003,7 +1051,7 @@ function wire_impl(
async function* execute_fast( async function* execute_fast(
st: Statement, st: Statement,
params: { types: number[]; values: (string | null)[] }, params: unknown[],
stdin: ReadableStream<Uint8Array> | null, stdin: ReadableStream<Uint8Array> | null,
stdout: WritableStream<Uint8Array> | null stdout: WritableStream<Uint8Array> | null
): ResultStream<Row> { ): ResultStream<Row> {
@ -1013,7 +1061,8 @@ function wire_impl(
`executing query` `executing query`
); );
const Row = await st.parse(); const { ser_params, Row } = await st.parse();
const param_values = ser_params(params);
const portal = st.portal(); const portal = st.portal();
try { try {
@ -1023,7 +1072,7 @@ function wire_impl(
portal, portal,
statement: st.name, statement: st.name,
param_formats: [], param_formats: [],
param_values: params.values, param_values,
column_formats: [], column_formats: [],
}); });
await write(Execute, { portal, row_limit: 0 }); await write(Execute, { portal, row_limit: 0 });
@ -1050,7 +1099,7 @@ function wire_impl(
async function* execute_chunked( async function* execute_chunked(
st: Statement, st: Statement,
params: { types: number[]; values: (string | null)[] }, params: unknown[],
chunk_size: number, chunk_size: number,
stdin: ReadableStream<Uint8Array> | null, stdin: ReadableStream<Uint8Array> | null,
stdout: WritableStream<Uint8Array> | null stdout: WritableStream<Uint8Array> | null
@ -1061,7 +1110,8 @@ function wire_impl(
`executing chunked query` `executing chunked query`
); );
const Row = await st.parse(); const { ser_params, Row } = await st.parse();
const param_values = ser_params(params);
const portal = st.portal(); const portal = st.portal();
try { try {
@ -1071,7 +1121,7 @@ function wire_impl(
portal, portal,
statement: st.name, statement: st.name,
param_formats: [], param_formats: [],
param_values: params.values, param_values,
column_formats: [], column_formats: [],
}); });
await write(Execute, { portal, row_limit: chunk_size }); await write(Execute, { portal, row_limit: chunk_size });
@ -1104,8 +1154,9 @@ function wire_impl(
} }
function query(s: SqlFragment) { function query(s: SqlFragment) {
const { query, params } = sql.format(s, to_sql); const { query, params } = sql.format(s);
const st = st_get(query, params.types); let st = st_cache.get(query);
if (!st) st_cache.set(query, (st = new Statement(query)));
return new Query(({ chunk_size = 0, stdin = null, stdout = null }) => return new Query(({ chunk_size = 0, stdin = null, stdout = null }) =>
chunk_size !== 0 chunk_size !== 0
@ -1288,7 +1339,8 @@ export class Pool
if (typeof f !== "undefined") { if (typeof f !== "undefined") {
await using tx = await this.#begin(); await using tx = await this.#begin();
const value = await f(tx.wire, tx); const value = await f(tx.wire, tx);
return await tx.commit(), value; if (tx.open) await tx.commit();
return value;
} else { } else {
return this.#begin(); return this.#begin();
} }