editor-assistant.tsx 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424
  1. import {
  2. useCallback, useEffect, useState, useRef, useMemo,
  3. } from 'react';
  4. import { GlobalCodeMirrorEditorKey } from '@growi/editor';
  5. import {
  6. acceptAllChunks, useTextSelectionEffect,
  7. } from '@growi/editor/dist/client/services/unified-merge-view';
  8. import { useCodeMirrorEditorIsolated } from '@growi/editor/dist/client/stores/codemirror-editor';
  9. import { useSecondaryYdocs } from '@growi/editor/dist/client/stores/use-secondary-ydocs';
  10. import { useForm, type UseFormReturn } from 'react-hook-form';
  11. import { useTranslation } from 'react-i18next';
  12. import { type Text as YText } from 'yjs';
  13. import { apiv3Post } from '~/client/util/apiv3-client';
  14. import {
  15. SseMessageSchema,
  16. SseDetectedDiffSchema,
  17. SseFinalizedSchema,
  18. isReplaceDiff,
  19. // isInsertDiff,
  20. // isDeleteDiff,
  21. // isRetainDiff,
  22. type SseMessage,
  23. type SseDetectedDiff,
  24. type SseFinalized,
  25. } from '~/features/openai/interfaces/editor-assistant/sse-schemas';
  26. import { handleIfSuccessfullyParsed } from '~/features/openai/utils/handle-if-successfully-parsed';
  27. import { useIsEnableUnifiedMergeView } from '~/stores-universal/context';
  28. import { EditorMode, useEditorMode } from '~/stores-universal/ui';
  29. import { useCurrentPageId } from '~/stores/page';
  30. import type { AiAssistantHasId } from '../../interfaces/ai-assistant';
  31. import type { MessageLog } from '../../interfaces/message';
  32. import type { IThreadRelationHasId } from '../../interfaces/thread-relation';
  33. import { ThreadType } from '../../interfaces/thread-relation';
  34. import { AiAssistantDropdown } from '../components/AiAssistant/AiAssistantSidebar/AiAssistantDropdown';
  35. // import { type FormData } from '../components/AiAssistant/AiAssistantSidebar/AiAssistantSidebar';
  36. import { MessageCard, type MessageCardRole } from '../components/AiAssistant/AiAssistantSidebar/MessageCard';
  37. import { QuickMenuList } from '../components/AiAssistant/AiAssistantSidebar/QuickMenuList';
  38. import { useAiAssistantSidebar } from '../stores/ai-assistant';
  39. interface CreateThread {
  40. (): Promise<IThreadRelationHasId>;
  41. }
  42. interface PostMessage {
  43. (threadId: string, formData: FormData): Promise<Response>;
  44. }
  45. interface ProcessMessage {
  46. (data: unknown, handler: {
  47. onMessage: (data: SseMessage) => void;
  48. onDetectedDiff: (data: SseDetectedDiff) => void;
  49. onFinalized: (data: SseFinalized) => void;
  50. }): void;
  51. }
  52. interface GenerateInitialView {
  53. (onSubmit: (data: FormData) => Promise<void>): JSX.Element;
  54. }
  55. interface GenerateMessageCard {
  56. (role: MessageCardRole, children: string, messageId: string, messageLogs: MessageLog[], generatingAnswerMessage?: MessageLog): JSX.Element;
  57. }
  58. export interface FormData {
  59. input: string,
  60. markdownType?: 'full' | 'selected' | 'none'
  61. }
  62. type DetectedDiff = Array<{
  63. data: SseDetectedDiff,
  64. applied: boolean,
  65. id: string,
  66. }>
  67. type UseEditorAssistant = () => {
  68. createThread: CreateThread,
  69. postMessage: PostMessage,
  70. processMessage: ProcessMessage,
  71. form: UseFormReturn<FormData>
  72. resetForm: () => void
  73. isTextSelected: boolean,
  74. isGeneratingEditorText: boolean,
  75. // Views
  76. generateInitialView: GenerateInitialView,
  77. generateMessageCard: GenerateMessageCard,
  78. headerIcon: JSX.Element,
  79. headerText: JSX.Element,
  80. placeHolder: string,
  81. }
  82. const insertTextAtLine = (yText: YText, lineNumber: number, textToInsert: string): void => {
  83. // Get the entire text content
  84. const content = yText.toString();
  85. // Split by newlines to get all lines
  86. const lines = content.split('\n');
  87. // Calculate the index position for insertion
  88. let insertPosition = 0;
  89. // Sum the length of all lines before the target line (plus newline characters)
  90. for (let i = 0; i < lineNumber && i < lines.length; i++) {
  91. insertPosition += lines[i].length + 1; // +1 for the newline character
  92. }
  93. // Insert the text at the calculated position
  94. yText.insert(insertPosition, textToInsert);
  95. };
  96. const appendTextLastLine = (yText: YText, textToAppend: string) => {
  97. const content = yText.toString();
  98. const insertPosition = content.length;
  99. yText.insert(insertPosition, `\n\n${textToAppend}`);
  100. };
  101. const getLineInfo = (yText: YText, lineNumber: number): { text: string, startIndex: number } | null => {
  102. // Get the entire text content
  103. const content = yText.toString();
  104. // Split by newlines to get all lines
  105. const lines = content.split('\n');
  106. // Check if the requested line exists
  107. if (lineNumber < 0 || lineNumber >= lines.length) {
  108. return null; // Line doesn't exist
  109. }
  110. // Get the text of the specified line
  111. const text = lines[lineNumber];
  112. // Calculate the start index of the line
  113. let startIndex = 0;
  114. for (let i = 0; i < lineNumber; i++) {
  115. startIndex += lines[i].length + 1; // +1 for the newline character
  116. }
  117. // Return comprehensive line information
  118. return {
  119. text,
  120. startIndex,
  121. };
  122. };
  123. export const useEditorAssistant: UseEditorAssistant = () => {
  124. // Refs
  125. // const positionRef = useRef<number>(0);
  126. const lineRef = useRef<number>(0);
  127. // States
  128. const [detectedDiff, setDetectedDiff] = useState<DetectedDiff>();
  129. const [selectedAiAssistant, setSelectedAiAssistant] = useState<AiAssistantHasId>();
  130. const [selectedText, setSelectedText] = useState<string>();
  131. const [isGeneratingEditorText, setIsGeneratingEditorText] = useState<boolean>(false);
  132. const isTextSelected = useMemo(() => selectedText != null && selectedText.length !== 0, [selectedText]);
  133. // Hooks
  134. const { t } = useTranslation();
  135. const { data: currentPageId } = useCurrentPageId();
  136. const { data: isEnableUnifiedMergeView, mutate: mutateIsEnableUnifiedMergeView } = useIsEnableUnifiedMergeView();
  137. const { data: codeMirrorEditor } = useCodeMirrorEditorIsolated(GlobalCodeMirrorEditorKey.MAIN);
  138. const yDocs = useSecondaryYdocs(isEnableUnifiedMergeView ?? false, { pageId: currentPageId ?? undefined, useSecondary: isEnableUnifiedMergeView ?? false });
  139. const { data: aiAssistantSidebarData } = useAiAssistantSidebar();
  140. const form = useForm<FormData>({
  141. defaultValues: {
  142. input: '',
  143. },
  144. });
  145. // Functions
  146. const resetForm = useCallback(() => {
  147. form.reset({ input: '' });
  148. }, [form]);
  149. const createThread: CreateThread = useCallback(async() => {
  150. const response = await apiv3Post<IThreadRelationHasId>('/openai/thread', {
  151. type: ThreadType.EDITOR,
  152. aiAssistantId: selectedAiAssistant?._id,
  153. });
  154. return response.data;
  155. }, [selectedAiAssistant?._id]);
  156. const postMessage: PostMessage = useCallback(async(threadId, formData) => {
  157. const getMarkdown = (): string | undefined => {
  158. if (formData.markdownType === 'none') {
  159. return undefined;
  160. }
  161. if (formData.markdownType === 'selected') {
  162. return selectedText;
  163. }
  164. if (formData.markdownType === 'full') {
  165. return codeMirrorEditor?.getDoc();
  166. }
  167. };
  168. const response = await fetch('/_api/v3/openai/edit', {
  169. method: 'POST',
  170. headers: { 'Content-Type': 'application/json' },
  171. body: JSON.stringify({
  172. threadId,
  173. userMessage: formData.input,
  174. markdown: getMarkdown(),
  175. }),
  176. });
  177. return response;
  178. }, [codeMirrorEditor, selectedText]);
  179. const processMessage: ProcessMessage = useCallback((data, handler) => {
  180. handleIfSuccessfullyParsed(data, SseMessageSchema, (data: SseMessage) => {
  181. handler.onMessage(data);
  182. setIsGeneratingEditorText(true);
  183. });
  184. handleIfSuccessfullyParsed(data, SseDetectedDiffSchema, (data: SseDetectedDiff) => {
  185. mutateIsEnableUnifiedMergeView(true);
  186. setDetectedDiff((prev) => {
  187. const newData = { data, applied: false, id: crypto.randomUUID() };
  188. if (prev == null) {
  189. return [newData];
  190. }
  191. return [...prev, newData];
  192. });
  193. handler.onDetectedDiff(data);
  194. });
  195. handleIfSuccessfullyParsed(data, SseFinalizedSchema, (data: SseFinalized) => {
  196. setIsGeneratingEditorText(false);
  197. handler.onFinalized(data);
  198. });
  199. }, [mutateIsEnableUnifiedMergeView]);
  200. const selectTextHandler = useCallback((selectedText: string, selectedTextFirstLineNumber: number) => {
  201. setSelectedText(selectedText);
  202. lineRef.current = selectedTextFirstLineNumber;
  203. }, []);
  204. // Effects
  205. useTextSelectionEffect(codeMirrorEditor, selectTextHandler);
  206. useEffect(() => {
  207. const pendingDetectedDiff: DetectedDiff | undefined = detectedDiff?.filter(diff => diff.applied === false);
  208. if (yDocs?.secondaryDoc != null && pendingDetectedDiff != null && pendingDetectedDiff.length > 0) {
  209. // For debug
  210. // const testDetectedDiff = [
  211. // {
  212. // data: { diff: { retain: 9 } },
  213. // applied: false,
  214. // id: crypto.randomUUID(),
  215. // },
  216. // {
  217. // data: { diff: { delete: 5 } },
  218. // applied: false,
  219. // id: crypto.randomUUID(),
  220. // },
  221. // {
  222. // data: { diff: { insert: 'growi' } },
  223. // applied: false,
  224. // id: crypto.randomUUID(),
  225. // },
  226. // ];
  227. const yText = yDocs.secondaryDoc.getText('codemirror');
  228. yDocs.secondaryDoc.transact(() => {
  229. pendingDetectedDiff.forEach((detectedDiff) => {
  230. if (isReplaceDiff(detectedDiff.data)) {
  231. if (isTextSelected) {
  232. const lineInfo = getLineInfo(yText, lineRef.current);
  233. if (lineInfo != null && lineInfo.text !== detectedDiff.data.diff.replace) {
  234. yText.delete(lineInfo.startIndex, lineInfo.text.length);
  235. insertTextAtLine(yText, lineRef.current, detectedDiff.data.diff.replace);
  236. }
  237. lineRef.current += 1;
  238. }
  239. else {
  240. appendTextLastLine(yText, detectedDiff.data.diff.replace);
  241. }
  242. }
  243. // if (isInsertDiff(detectedDiff.data)) {
  244. // yText.insert(positionRef.current, detectedDiff.data.diff.insert);
  245. // }
  246. // if (isDeleteDiff(detectedDiff.data)) {
  247. // yText.delete(positionRef.current, detectedDiff.data.diff.delete);
  248. // }
  249. // if (isRetainDiff(detectedDiff.data)) {
  250. // positionRef.current += detectedDiff.data.diff.retain;
  251. // }
  252. });
  253. });
  254. // Mark items as applied after applying to secondaryDoc
  255. setDetectedDiff((prev) => {
  256. if (!prev) return prev;
  257. const pendingDetectedDiffIds = pendingDetectedDiff.map(diff => diff.id);
  258. return prev.map((diff) => {
  259. if (pendingDetectedDiffIds.includes(diff.id)) {
  260. return { ...diff, applied: true };
  261. }
  262. return diff;
  263. });
  264. });
  265. }
  266. }, [codeMirrorEditor, detectedDiff, isTextSelected, selectedText, yDocs?.secondaryDoc]);
  267. // Set detectedDiff to undefined after applying all detectedDiff to secondaryDoc
  268. useEffect(() => {
  269. if (detectedDiff?.filter(detectedDiff => detectedDiff.applied === false).length === 0) {
  270. setSelectedText(undefined);
  271. setDetectedDiff(undefined);
  272. lineRef.current = 0;
  273. // positionRef.current = 0;
  274. }
  275. }, [detectedDiff]);
  276. // Views
  277. const headerIcon = useMemo(() => {
  278. return <span className="material-symbols-outlined growi-ai-chat-icon me-3 fs-4">support_agent</span>;
  279. }, []);
  280. const headerText = useMemo(() => {
  281. return <>{t('Editor Assistant')}</>;
  282. }, [t]);
  283. const placeHolder = useMemo(() => { return 'sidebar_ai_assistant.editor_assistant_placeholder' }, []);
  284. const generateInitialView: GenerateInitialView = useCallback((onSubmit) => {
  285. const selectAiAssistantHandler = (aiAssistant?: AiAssistantHasId) => {
  286. setSelectedAiAssistant(aiAssistant);
  287. };
  288. const clickQuickMenuHandler = async(quickMenu: string) => {
  289. await onSubmit({ input: quickMenu, markdownType: 'full' });
  290. };
  291. return (
  292. <>
  293. <div className="py-2">
  294. <AiAssistantDropdown
  295. selectedAiAssistant={selectedAiAssistant}
  296. onSelect={selectAiAssistantHandler}
  297. />
  298. </div>
  299. <QuickMenuList
  300. onClick={clickQuickMenuHandler}
  301. />
  302. </>
  303. );
  304. }, [selectedAiAssistant]);
  305. const generateMessageCard: GenerateMessageCard = useCallback((role, children, messageId, messageLogs, generatingAnswerMessage) => {
  306. const isActionButtonShown = (() => {
  307. if (!aiAssistantSidebarData?.isEditorAssistant) {
  308. return false;
  309. }
  310. if (generatingAnswerMessage != null) {
  311. return false;
  312. }
  313. const latestAssistantMessageLogId = messageLogs
  314. .filter(message => !message.isUserMessage)
  315. .slice(-1)[0];
  316. if (messageId === latestAssistantMessageLogId?.id) {
  317. return true;
  318. }
  319. return false;
  320. })();
  321. const accept = () => {
  322. if (codeMirrorEditor?.view == null) {
  323. return;
  324. }
  325. acceptAllChunks(codeMirrorEditor.view);
  326. mutateIsEnableUnifiedMergeView(false);
  327. };
  328. const reject = () => {
  329. mutateIsEnableUnifiedMergeView(false);
  330. };
  331. return (
  332. <MessageCard
  333. role={role}
  334. showActionButtons={isActionButtonShown}
  335. onAccept={accept}
  336. onDiscard={reject}
  337. >
  338. {children}
  339. </MessageCard>
  340. );
  341. }, [aiAssistantSidebarData?.isEditorAssistant, codeMirrorEditor?.view, mutateIsEnableUnifiedMergeView]);
  342. return {
  343. createThread,
  344. postMessage,
  345. processMessage,
  346. form,
  347. resetForm,
  348. isTextSelected,
  349. isGeneratingEditorText,
  350. // Views
  351. generateInitialView,
  352. generateMessageCard,
  353. headerIcon,
  354. headerText,
  355. placeHolder,
  356. };
  357. };
  358. // type guard
  359. export const isEditorAssistantFormData = (formData): formData is FormData => {
  360. return 'markdownType' in formData;
  361. };