import {
	DOMConversion,
	DOMConversionMap,
	DOMConversionOutput,
	LexicalNode,
	NodeKey,
	SerializedTextNode,
	TextNode,
} from "lexical";
import { nodeStyleToString } from "../utils";

/**
 * A TextNode variant whose HTML import keeps inline `style` attributes from
 * formatting tags (`<span style="...">`, `<b style="...">`, ...). Stock
 * TextNode import drops those styles.
 */
export class ExtendedTextNode extends TextNode {
	constructor(text: string, key?: NodeKey) {
		super(text, key);
	}

	static getType(): string {
		return "extended-text";
	}

	static clone(node: ExtendedTextNode): ExtendedTextNode {
		return new ExtendedTextNode(node.__text, node.__key);
	}

	/**
	 * Starts from TextNode's built-in DOM importers and overrides each
	 * formatting tag listed below. The override wraps the stock converter with
	 * `patchStyleConversion` so the element's inline style reaches the
	 * resulting text nodes. Priority 1 lets these entries win over
	 * priority-0 converters for the same tag.
	 */
	static importDOM(): DOMConversionMap | null {
		const importers = TextNode.importDOM();
		const patched: DOMConversionMap = {};

		// One table instead of ten hand-written entries, so a tag key can
		// never drift out of sync with the importer it wraps.
		const styledTags = ["b", "code", "em", "i", "span", "strong", "sub", "sup", "s", "u"];

		for (const tag of styledTags) {
			patched[tag] = () => ({
				conversion: patchStyleConversion(importers?.[tag]),
				priority: 1,
			});
		}

		return { ...importers, ...patched };
	}

	/**
	 * Delegates to `TextNode.importJSON`, so the declared result is a plain
	 * `TextNode`.
	 * NOTE(review): this yields an ExtendedTextNode only if the editor config
	 * registers a TextNode -> ExtendedTextNode replacement — confirm.
	 */
	static importJSON(serializedNode: SerializedTextNode): TextNode {
		return TextNode.importJSON(serializedNode);
	}

	/**
	 * TextNode's own check compares `__type` against "text", which this
	 * subclass never has. Without this override, unstyled, normal-mode
	 * ExtendedTextNodes would never count as simple text, and adjacent ones
	 * would not be merged.
	 */
	isSimpleText(): boolean {
		return this.__type === "extended-text" && this.__mode === 0;
	}
}

/**
 * Wraps one of TextNode's stock DOM converters so the inline `style` of the
 * source element is copied onto the TextNodes produced from its children.
 *
 * @param originalDOMConverter - The converter being wrapped (for example the
 *   `span` entry of `TextNode.importDOM()`). It may be undefined, in which
 *   case the wrapper never converts anything.
 * @returns A conversion function with these results:
 *   - `null` when the original converter declines the element.
 *   - Otherwise, the original conversion output with its `forChild` hook
 *     replaced. The new hook runs the original hook (or passes the node
 *     through unchanged if there was none), then sets the element's style on
 *     the result when it is a TextNode and the style is non-empty.
 */
export function patchStyleConversion(
	originalDOMConverter?: (node: HTMLElement) => DOMConversion | null
): (node: HTMLElement) => DOMConversionOutput | null {
	return (node) => {
		const original = originalDOMConverter?.(node);

		// The stock importer does not handle this element, so neither do we.
		if (!original) return null;

		const originalOutput = original.conversion(node);

		if (!originalOutput) return originalOutput;

		// Computed once per element; every converted child gets the same style.
		const style = nodeStyleToString(node.style);
		// Hoisted out of the per-child closure: it does not change between children.
		const originalForChild = originalOutput.forChild;

		return {
			...originalOutput,
			forChild: (lexicalNode, parent) => {
				const result = originalForChild ? originalForChild(lexicalNode, parent) : lexicalNode;

				// Style only real TextNodes (same check as `$isTextNode`). Duck-typing
				// on `setStyle` would also restyle any other node type that has such
				// a method.
				if (result instanceof TextNode && style.length) {
					return result.setStyle(style);
				}

				return result;
			},
		};
	};
}
