diff --git a/src/routes/ChatRoute.tsx b/src/routes/ChatRoute.tsx index 9f467ce..57903a3 100644 --- a/src/routes/ChatRoute.tsx +++ b/src/routes/ChatRoute.tsx @@ -13,9 +13,8 @@ import { } from "@mantine/core"; import { notifications } from "@mantine/notifications"; import { useLiveQuery } from "dexie-react-hooks"; -import { findLast } from "lodash"; import { nanoid } from "nanoid"; -import { useState } from "react"; +import { KeyboardEvent, useState, type ChangeEvent } from "react"; import { AiOutlineSend } from "react-icons/ai"; import { MessageItem } from "../components/MessageItem"; import { db } from "../db"; @@ -37,7 +36,13 @@ export function ChatRoute() { if (!chatId) return []; return db.messages.where("chatId").equals(chatId).sortBy("createdAt"); }, [chatId]); + const userMessages = + messages + ?.filter((message) => message.role === "user") + .map((message) => message.content) || []; + const [userMsgIndex, setUserMsgIndex] = useState(0); const [content, setContent] = useState(""); + const [contentDraft, setContentDraft] = useState(""); const [submitting, setSubmitting] = useState(false); const chat = useLiveQuery(async () => { @@ -183,6 +188,38 @@ export function ChatRoute() { } }; + const onUserMsgToggle = (event: KeyboardEvent) => { + const { selectionStart, selectionEnd } = event.currentTarget; + if ( + !["ArrowUp", "ArrowDown"].includes(event.code) || + selectionStart !== selectionEnd || + (event.code === "ArrowUp" && selectionStart !== 0) || + (event.code === "ArrowDown" && + selectionStart !== event.currentTarget.value.length) + ) { + // do nothing + return; + } + event.preventDefault(); + + const newMsgIndex = userMsgIndex + (event.code === "ArrowUp" ? 1 : -1); + const allMessages = [contentDraft, ...userMessages.reverse()]; + + if (newMsgIndex < 0 || newMsgIndex >= allMessages.length) { + // index out of range, do nothing + return; + } + setContent(allMessages.at(newMsgIndex) || ""); + setUserMsgIndex(newMsgIndex); + }; + + const onContentChange = (event: ChangeEvent) => { + const { value } = event.currentTarget; + setContent(value); + setContentDraft(value); + setUserMsgIndex(0); + }; + if (!chatId) return null; return ( @@ -282,36 +319,18 @@ export function ChatRoute() { minRows={1} maxRows={5} value={content} - onChange={(event) => setContent(event.currentTarget.value)} + onChange={onContentChange} onKeyDown={async (event) => { if (event.code === "Enter" && !event.shiftKey) { event.preventDefault(); submit(); + setUserMsgIndex(0); } if (event.code === "ArrowUp") { - const { selectionStart, selectionEnd } = event.currentTarget; - if (selectionStart !== selectionEnd) return; - if (selectionStart !== 0) return; - event.preventDefault(); - const nextUserMessage = findLast( - messages, - (message) => message.role === "user" - ); - setContent(nextUserMessage?.content ?? ""); + onUserMsgToggle(event); } if (event.code === "ArrowDown") { - const { selectionStart, selectionEnd } = event.currentTarget; - if (selectionStart !== selectionEnd) return; - if (selectionStart !== event.currentTarget.value.length) - return; - event.preventDefault(); - const lastUserMessage = findLast( - messages, - (message) => message.role === "user" - ); - if (lastUserMessage?.content === content) { - setContent(""); - } + onUserMsgToggle(event); } }} />