Skip to content

Instantly share code, notes, and snippets.

@mrchaofan
Created July 9, 2024 09:00
Show Gist options
  • Select an option

  • Save mrchaofan/3c474f4005712e071ec9e30601ba418d to your computer and use it in GitHub Desktop.

Select an option

Save mrchaofan/3c474f4005712e071ec9e30601ba418d to your computer and use it in GitHub Desktop.
import net from 'net';
import fs from 'fs';
import https from 'https';
import http from 'http';
import stream from 'stream';
import url from 'url';
import { WebSocket, WebSocketServer } from 'ws';
import { HttpProxyAgent } from 'http-proxy-agent';
import { HttpsProxyAgent } from 'https-proxy-agent';
import { SyncHook } from '~/frameworks/utils/hook';
import { once } from 'lodash';
import { getResourcePath } from '~/main-process/utils/directory';
/**
 * Hook invoked by the proxy servers in this file for each incoming request,
 * instead of forwarding the request upstream directly.
 *
 * @param req - The incoming client request.
 * @param res - The response object belonging to `req`.
 * @param options - The upstream request options the proxy has assembled
 *   (target, method, headers). `next` reads this same object when it sends
 *   the upstream request.
 * @param next - Sends the upstream request described by `options` and pipes
 *   the upstream response into `res`.
 */
export type OnRequestCallback = (
req: http.IncomingMessage,
res: http.ServerResponse<http.IncomingMessage> & {
req: http.IncomingMessage;
},
options: http.RequestOptions,
next: () => void,
) => void;
/** Upstream endpoint that an HttpsServerWrap forwards decrypted traffic to. */
interface ForwardInfo {
// Upstream host name.
hostname: string;
// Upstream port; consumers in this file fall back to 443 when it is undefined.
port?: number;
}
/** Construction options for HttpsServerWrap. */
interface HttpsServerWrapOptions {
// Where requests received by the local HTTPS server are forwarded.
forwardInfo: ForwardInfo;
// Called from HttpsServerWrap.dec() whenever the reference count is <= 0.
onZeroRefCallback: () => void;
}
/**
 * Local TLS-terminating HTTPS server bound to 127.0.0.1 on an ephemeral port.
 *
 * Every decrypted HTTP request and WebSocket connection it receives is
 * re-issued against the upstream host described by `options.forwardInfo`.
 * The instance is reference counted through `inc()` / `dec()`; when the count
 * drops to zero `options.onZeroRefCallback` is invoked.
 */
class HttpsServerWrap {
  /** Creation time in epoch milliseconds. */
  public startTs = Date.now();
  public server: https.Server;
  /** Optional upstream HTTP proxy used for forwarded traffic. */
  public proxyOptions: ProxyOptions | undefined = undefined;
  /** When true, upstream certificate verification is disabled. */
  public disableTLS = false;
  /** Optional hook that may intercept a request before it is forwarded. */
  public onRequestCallback: OnRequestCallback | undefined = undefined;
  private _port = 0;
  private _rc = 0;

  constructor(private options: HttpsServerWrapOptions) {
    this.server = https.createServer({
      key: fs.readFileSync(getResourcePath(['app.key'])),
      cert: fs.readFileSync(getResourcePath(['app.crt'])),
    });

    const wss = new WebSocketServer({ noServer: true });
    // Bridge every accepted WebSocket connection to the upstream host.
    wss.on('connection', (ws, req) => this.bridgeWebSocket(ws, req));

    // Listen for upgrade requests on the proxy server.
    this.server.on('upgrade', (req, socket, head) => {
      // Header values are case-insensitive, so normalise before comparing.
      if (req.headers.upgrade?.toLowerCase() === 'websocket') {
        // Hand the handshake over to the WebSocket server.
        wss.handleUpgrade(req, socket, head, (ws) => {
          wss.emit('connection', ws, req);
        });
        return;
      }
      // Unsupported upgrade protocol: release the socket instead of leaving
      // the client hanging forever.
      socket.destroy();
    });

    this.server.on('request', (req, res) => this.handleRequest(req, res));

    // Persistent listener: an 'error' event without a listener would be
    // thrown and take the whole process down.
    this.server.on('error', (er) => {
      console.error('server error', er);
    });
    this.server.listen(undefined, '127.0.0.1');
  }

  /**
   * Resolves with the local port the server is listening on, waiting for the
   * 'listening' event if necessary.
   *
   * @returns The bound TCP port.
   * @throws Rejects with the server error if the server fails before it
   *   starts listening.
   */
  public async getPort(): Promise<number> {
    if (!this.server.listening) {
      await new Promise<void>((resolve, reject) => {
        const onListening = () => {
          this.server.off('error', onError);
          resolve();
        };
        const onError = (er: Error) => {
          // The constructor's listener already logs the error.
          this.server.off('listening', onListening);
          reject(er);
        };
        this.server.once('listening', onListening);
        this.server.once('error', onError);
      });
    }
    if (this._port === 0) {
      this._port = (this.server.address() as net.AddressInfo).port;
    }
    return this._port;
  }

  /** Registers one more user of this server. */
  public inc() {
    this._rc += 1;
  }

  /**
   * Releases one user of this server and notifies the owner once no users
   * remain.
   */
  public dec() {
    this._rc -= 1;
    if (this._rc <= 0) {
      this.options.onZeroRefCallback?.();
    }
  }

  /**
   * Connects `client` to the matching upstream WebSocket endpoint and relays
   * frames in both directions.
   */
  private bridgeWebSocket(client: WebSocket, req: http.IncomingMessage): void {
    const { hostname, port } = this.options.forwardInfo;
    // IPv6 literals must be bracketed inside a URL authority.
    const host = net.isIPv6(hostname) ? `[${hostname}]` : hostname;
    const upstreamUrl = `wss://${host}:${port ?? 443}${req.url ?? '/'}`;

    // Apply the same TLS / proxy settings the plain HTTPS path uses.
    const clientOptions: { rejectUnauthorized?: boolean; agent?: http.Agent } = {};
    if (this.disableTLS) {
      clientOptions.rejectUnauthorized = false;
    }
    if (this.proxyOptions?.http) {
      clientOptions.agent = new HttpsProxyAgent(this.proxyOptions.http);
    }
    const upstream = new WebSocket(upstreamUrl, clientOptions);

    // Frames received from the client before the upstream handshake is done.
    const pending: Array<() => void> = [];

    const closePeer = (peer: WebSocket) => {
      if (peer.readyState === WebSocket.OPEN) {
        peer.close();
      } else if (peer.readyState === WebSocket.CONNECTING) {
        peer.terminate();
      }
    };

    client.on('message', (data, isBinary) => {
      // Preserve the frame type; stringifying would corrupt binary payloads.
      const forward = () => upstream.send(data, { binary: isBinary });
      if (upstream.readyState === WebSocket.OPEN) {
        forward();
      } else if (upstream.readyState === WebSocket.CONNECTING) {
        pending.push(forward);
      }
    });
    upstream.on('message', (data, isBinary) => {
      if (client.readyState === WebSocket.OPEN) {
        client.send(data, { binary: isBinary });
      }
    });
    upstream.once('open', () => {
      pending.splice(0).forEach((forward) => forward());
    });

    // Whichever side goes away first takes the other one with it.
    client.once('close', () => closePeer(upstream));
    upstream.once('close', () => closePeer(client));

    // Without 'error' listeners a failed handshake or reset connection would
    // surface as an uncaught exception. A rejected upstream handshake
    // (non-101 response) is reported through 'error' as well.
    client.on('error', (er) => {
      console.error('client websocket error', er);
      closePeer(upstream);
    });
    upstream.on('error', (er) => {
      console.error('upstream websocket error', er);
      closePeer(client);
    });
  }

  /**
   * Builds the upstream request options for `req` and either forwards the
   * request directly or hands it to `onRequestCallback`.
   */
  private handleRequest(
    req: Parameters<OnRequestCallback>[0],
    res: Parameters<OnRequestCallback>[1],
  ): void {
    if (!req.url) {
      res.destroy();
      return;
    }
    let options: https.RequestOptions;
    if (req.url.startsWith('/')) {
      // Origin-form request: the target is the host this wrapper was created for.
      options = {
        hostname: this.options.forwardInfo.hostname ?? req.headers.host,
        port: this.options.forwardInfo.port ?? 443,
        path: req.url,
      };
    } else {
      // Absolute-form request: the URL itself names the target.
      try {
        options = url.parse(req.url);
      } catch (er) {
        res.destroy();
        return;
      }
    }
    options.method = req.method;
    options.headers = req.headers;

    const forward = () => this.forwardRequest(req, res, options);
    if (this.onRequestCallback != null) {
      this.onRequestCallback(req, res, options, forward);
    } else {
      forward();
    }
  }

  /** Sends the request described by `options` upstream and relays the response. */
  private forwardRequest(
    req: Parameters<OnRequestCallback>[0],
    res: Parameters<OnRequestCallback>[1],
    options: https.RequestOptions,
  ): void {
    if (this.disableTLS) {
      options.rejectUnauthorized = false;
    }
    if (this.proxyOptions?.http) {
      options.agent = new HttpsProxyAgent(this.proxyOptions.http);
    }
    const remoteReq = https.request(options, (remoteRes) => {
      res.statusCode = remoteRes.statusCode ?? 502;
      for (const [key, val] of Object.entries(remoteRes.headers)) {
        res.setHeader(key, val ?? '');
      }
      // pipe() does not forward errors; abort the client response instead of
      // leaving it half-written.
      remoteRes.once('error', () => {
        res.destroy();
      });
      remoteRes.pipe(res);
    });
    remoteReq.once('error', (er) => {
      if (res.destroyed) {
        // The client is already gone (this is usually our own abort below).
        return;
      }
      console.error('upstream request error', er);
      if (res.headersSent) {
        res.destroy();
        return;
      }
      res.statusCode = 502;
      res.end();
    });
    // Scoped to this response rather than the (possibly keep-alive) socket so
    // listeners do not pile up on a reused connection.
    res.once('close', () => {
      remoteReq.destroy();
    });
    req.once('error', () => {
      remoteReq.destroy();
      res.destroy();
    });
    req.pipe(remoteReq);
  }
}
/** Minimum lifetime (30 seconds, in milliseconds) of a pooled server. */
const HALF_MINUTES = 30 * 1000;

/**
 * Keeps one HttpsServerWrap per CONNECT target (`host[:port]`) and closes a
 * server once nobody references it any more.
 */
class HttpsServerPool {
  private pool = new Map<string, HttpsServerWrap>();

  /**
   * Splits a CONNECT authority into host and port.
   *
   * Accepts `host`, `host:port`, `[ipv6]` and `[ipv6]:port`. A missing or
   * malformed port is reported as `undefined` so consumers apply their
   * default instead of receiving `NaN`.
   */
  private static parseTarget(target: string): { hostname: string; port: number | undefined } {
    let hostname = target;
    let portText: string | undefined;
    const closingBracket = target.lastIndexOf(']');
    const separator = target.lastIndexOf(':');
    // Only a colon after the closing bracket (if any) separates the port;
    // colons inside the brackets belong to the IPv6 address.
    if (separator > closingBracket) {
      hostname = target.slice(0, separator);
      portText = target.slice(separator + 1);
    }
    if (hostname.startsWith('[') && hostname.endsWith(']')) {
      hostname = hostname.slice(1, -1);
    }
    const parsed = portText === undefined ? Number.NaN : Number.parseInt(portText, 10);
    const port = Number.isInteger(parsed) && parsed > 0 && parsed <= 0xffff ? parsed : undefined;
    return { hostname, port };
  }

  /**
   * Returns the pooled server for `target`, creating and registering one if
   * none exists yet.
   */
  public getOrCreateServer(target: string): HttpsServerWrap {
    return this.pool.get(target) ?? this.createServer(target);
  }

  /** Empties the pool and closes every server that is currently listening. */
  public clear(): void {
    const previous = this.pool;
    this.pool = new Map();
    previous.forEach((httpsServerWrap) => {
      if (httpsServerWrap.server.listening) {
        httpsServerWrap.server.close();
      }
    });
  }

  /**
   * Removes `wrap` from the pool, but only if it is still the registered
   * entry for `target`. A stale server (for example one closed after
   * `clear()`) must not evict a newer server created for the same target.
   */
  private forget(target: string, wrap: HttpsServerWrap): void {
    if (this.pool.get(target) === wrap) {
      this.pool.delete(target);
    }
  }

  /** Creates, registers and returns a new server for `target`. */
  private createServer(target: string): HttpsServerWrap {
    const httpsServerWrap = new HttpsServerWrap({
      forwardInfo: HttpsServerPool.parseTarget(target),
      onZeroRefCallback: () => {
        // Keep young servers alive for a grace period: hold an artificial
        // reference and release it later, which re-runs this callback.
        if (Date.now() - httpsServerWrap.startTs < HALF_MINUTES) {
          httpsServerWrap.inc();
          setTimeout(() => {
            httpsServerWrap.dec();
          }, HALF_MINUTES);
          return;
        }
        this.forget(target, httpsServerWrap);
        process.nextTick(() => {
          if (httpsServerWrap.server.listening) {
            httpsServerWrap.server.close();
          }
        });
      },
    });
    httpsServerWrap.server.once('close', () => {
      this.forget(target, httpsServerWrap);
    });
    this.pool.set(target, httpsServerWrap);
    return httpsServerWrap;
  }
}

/** Process-wide pool of local HTTPS interception servers. */
export const httpsServerPool = new HttpsServerPool();
/** Intentionally empty function: takes no arguments and returns `undefined`. */
function noop() {
  // Nothing to do.
}
/** Upstream proxy configuration shared by the tunnel and the HTTPS wrappers. */
export interface ProxyOptions {
// Proxy address handed to the HttpProxyAgent / HttpsProxyAgent constructors;
// when unset, upstream requests are made without a proxy agent.
http?: string;
}
/**
 * Local forward proxy bound to 127.0.0.1 on an ephemeral port.
 *
 * - Plain HTTP requests are re-issued upstream (optionally through
 *   `proxyOptions.http`).
 * - CONNECT requests are piped into a pooled local HTTPS server obtained from
 *   `httpsServerPool`, which terminates TLS and forwards the decrypted
 *   traffic.
 */
export class HttpTunnel {
  public server = http.createServer();
  /** Optional upstream HTTP proxy used for forwarded traffic. */
  public proxyOptions: ProxyOptions | undefined = undefined;
  /** Copied onto the pooled HTTPS server for every CONNECT tunnel. */
  public disableTLS = false;
  /** Optional hook that may intercept a request before it is forwarded. */
  public onRequestCallback: OnRequestCallback | undefined = undefined;
  private _port = 0;

  constructor() {
    this.server.on('connect', (req, socket) => {
      const target = req.url;
      if (!target || (req.method ?? 'CONNECT').toUpperCase() !== 'CONNECT') {
        socket.destroy();
        return;
      }
      this.handleConnectRequest(target, socket);
    });
    this.server.on('request', (req, res) => this.handleHttpRequest(req, res));
    this.server.listen(undefined, '127.0.0.1');
  }

  /**
   * Resolves with the local port the proxy is listening on, waiting for the
   * 'listening' event if necessary.
   *
   * @returns The bound TCP port.
   * @throws Rejects with the server error if the server fails before it
   *   starts listening.
   */
  public async getPort(): Promise<number> {
    if (!this.server.listening) {
      await new Promise<void>((resolve, reject) => {
        const onListening = () => {
          this.server.off('error', onError);
          resolve();
        };
        const onError = (er: Error) => {
          this.server.off('listening', onListening);
          console.error('server error', er);
          reject(er);
        };
        this.server.once('listening', onListening);
        this.server.once('error', onError);
      });
    }
    if (this._port === 0) {
      this._port = (this.server.address() as net.AddressInfo).port;
    }
    return this._port;
  }

  /**
   * Builds the upstream request options for a plain HTTP request and either
   * forwards it directly or hands it to `onRequestCallback`.
   */
  private handleHttpRequest(
    req: Parameters<OnRequestCallback>[0],
    res: Parameters<OnRequestCallback>[1],
  ): void {
    if (!req.url) {
      res.destroy();
      return;
    }
    let options: http.RequestOptions;
    try {
      if (req.url.startsWith('/')) {
        // Origin-form request: the target comes from the Host header, which
        // may carry an explicit port ("example.com:8080").
        const hostHeader = req.headers.host;
        if (hostHeader == null) {
          res.destroy();
          return;
        }
        const authority = url.parse(`http://${hostHeader}`);
        options = {
          hostname: authority.hostname ?? hostHeader,
          port: authority.port ?? 80,
          path: req.url,
        };
      } else {
        // Absolute-form request: the URL itself names the target.
        options = url.parse(req.url);
      }
    } catch (er) {
      res.destroy();
      return;
    }
    options.method = req.method;
    options.headers = req.headers;

    const forward = () => this.forwardHttpRequest(req, res, options);
    if (this.onRequestCallback) {
      this.onRequestCallback(req, res, options, forward);
    } else {
      forward();
    }
  }

  /** Sends the request described by `options` upstream and relays the response. */
  private forwardHttpRequest(
    req: Parameters<OnRequestCallback>[0],
    res: Parameters<OnRequestCallback>[1],
    options: http.RequestOptions,
  ): void {
    if (this.proxyOptions?.http) {
      options.agent = new HttpProxyAgent(this.proxyOptions.http);
    }
    const remoteReq = http.request(options, (remoteRes) => {
      res.statusCode = remoteRes.statusCode ?? 502;
      for (const [key, val] of Object.entries(remoteRes.headers)) {
        res.setHeader(key, val ?? '');
      }
      // pipe() does not forward errors; abort the client response instead of
      // leaving it half-written.
      remoteRes.once('error', () => {
        res.destroy();
      });
      remoteRes.pipe(res);
    });
    // Without this listener a refused or reset upstream connection would be
    // thrown as an uncaught exception.
    remoteReq.once('error', (er) => {
      if (res.destroyed) {
        // The client is already gone (this is usually our own abort below).
        return;
      }
      console.error('upstream request error', er);
      if (res.headersSent) {
        res.destroy();
        return;
      }
      res.statusCode = 502;
      res.end();
    });
    // Scoped to this response rather than the (possibly keep-alive) socket so
    // listeners do not pile up on a reused connection.
    res.once('close', () => {
      remoteReq.destroy();
    });
    req.once('error', () => {
      remoteReq.destroy();
      res.destroy();
    });
    req.pipe(remoteReq);
  }

  /**
   * Serves a CONNECT request by piping the client socket into the pooled
   * local HTTPS server for `target`.
   *
   * The pooled server is referenced for exactly as long as the client socket
   * lives, so the pool can dispose of it once the last tunnel is gone.
   */
  private handleConnectRequest(target: string, clientSocket: stream.Duplex) {
    const server = httpsServerPool.getOrCreateServer(target);
    server.proxyOptions = this.proxyOptions;
    server.disableTLS = this.disableTLS;
    server.onRequestCallback = this.onRequestCallback;
    server.inc();

    // Guarantees that the reference taken above is dropped exactly once.
    let released = false;
    const release = () => {
      if (released) {
        return;
      }
      released = true;
      server.dec();
    };

    clientSocket.on('error', (er) => {
      console.error('clientSocket.error', er);
      clientSocket.destroy();
    });
    clientSocket.once('close', release);

    server.getPort().then(
      (port) => {
        if (clientSocket.destroyed) {
          release();
          return;
        }
        const serverSocket = net.connect(port, '127.0.0.1');
        serverSocket.on('error', (er) => {
          // 'close' always follows 'error' and tears down the client side.
          console.error('serverSocket.error', er);
        });
        serverSocket.once('close', () => {
          clientSocket.destroy();
        });
        clientSocket.once('close', () => {
          serverSocket.destroy();
        });
        // Only confirm the tunnel once the local server is known to be usable.
        clientSocket.write('HTTP/1.1 200 Connection Established\r\n\r\n');
        clientSocket.pipe(serverSocket);
        serverSocket.pipe(clientSocket);
      },
      () => {
        release();
        if (clientSocket.destroyed) {
          return;
        }
        clientSocket.end('HTTP/1.1 502 Bad Gateway\r\n\r\n');
      },
    );
  }
}
/**
 * Lazily creates a single shared HttpTunnel and forgets it once its server
 * closes or reports an error.
 */
export class GetHttpTunnel {
  public hooks = {
    /** Called once per tunnel, on its server's first 'close' or 'error'. */
    tunnelClosed: new SyncHook(),
  };

  private currentTunnel: HttpTunnel | undefined = undefined;

  /**
   * Returns the current tunnel, creating a new one when none is active.
   */
  public get(): HttpTunnel {
    if (this.currentTunnel != null) {
      return this.currentTunnel;
    }

    const tunnel = new HttpTunnel();
    this.currentTunnel = tunnel;

    // 'close' and 'error' may both fire for the same tunnel; reset only once.
    const resetTunnel = once(() => {
      this.currentTunnel = undefined;
      this.hooks.tunnelClosed.call();
    });
    tunnel.server.once('close', resetTunnel);
    tunnel.server.once('error', resetTunnel);

    return tunnel;
  }
}

/** Shared accessor for the process-wide HTTP tunnel. */
export const getHttpTunnel = new GetHttpTunnel();
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment