diff --git a/new-ui/src/pages/compact/CompactLocationsPage/CompactLocationsPage.tsx b/new-ui/src/pages/compact/CompactLocationsPage/CompactLocationsPage.tsx index 51b026702..8bba51700 100644 --- a/new-ui/src/pages/compact/CompactLocationsPage/CompactLocationsPage.tsx +++ b/new-ui/src/pages/compact/CompactLocationsPage/CompactLocationsPage.tsx @@ -9,6 +9,7 @@ import { Divider } from '../../../shared/components/Divider/Divider'; import { LocationCard } from '../../../shared/components/LocationCard/LocationCard'; import { ScrollContainer } from '../../../shared/components/ScrollContainer/ScrollContainer'; import { WindowHeader } from '../../../shared/components/WindowHeader/WindowHeader'; +import { useAppData } from '../../../shared/providers/AppDataContext'; import { api } from '../../../shared/rust-api/api'; import { getInstancesQueryOptions, @@ -21,7 +22,7 @@ import { CompactPage } from '../CompactPage/CompactPage'; import { InstanceSwitcher } from './components/InstanceSwitcher'; export const CompactLocationsPage = () => { - const selection = useAppStore((s) => s.compactViewSelection); + const { viewSelection: selection, setViewSelection } = useAppData(); const openLocation = useAppStore((s) => s.expandedLocation); const routeData = useLoaderData({ from: '/compact/' }); @@ -53,11 +54,9 @@ export const CompactLocationsPage = () => { useEffect(() => { if (selection === null || instanceInfo === undefined) { - useAppStore.setState({ - compactViewSelection: { kind: 'instance', data: routeData.instances[0] }, - }); + setViewSelection({ kind: 'instance', data: routeData.instances[0] }); } - }, [routeData.instances, instanceInfo, selection]); + }, [routeData.instances, instanceInfo, selection, setViewSelection]); return ( { - const selectedInstance = useAppStore((s) => s.compactViewSelection); + const { viewSelection: selectedInstance, setViewSelection } = useAppData(); const { data: tunnels } = useQuery(getTunnelsQueryOptions); const { data: instances } = useQuery(getInstancesQueryOptions); - const groups = useMemo((): readonly SelectOptionGroup[] => { + const groups = useMemo((): readonly SelectOptionGroup[] => { if (!isPresent(instances) || !isPresent(tunnels)) return []; - const instanceGroup: SelectOptionGroup = { + const instanceGroup: SelectOptionGroup = { key: 'instances', label: 'Instances', options: instances.map((instance) => ({ @@ -34,7 +32,7 @@ export const InstanceSwitcher = () => { })), }; - const tunnelGroup: SelectOptionGroup = { + const tunnelGroup: SelectOptionGroup = { key: 'tunnels', label: 'Tunnels', options: tunnels.map((tunnel) => ({ @@ -52,7 +50,7 @@ export const InstanceSwitcher = () => { [groups], ); - const selectedOption = useMemo((): SelectOption | undefined => { + const selectedOption = useMemo((): SelectOption | undefined => { if (!isPresent(selectedInstance)) return undefined; for (const group of groups) { const found = group.options.find((o) => { @@ -77,7 +75,7 @@ export const InstanceSwitcher = () => { groups={groups} value={selectedOption as never} onChange={(option) => { - useAppStore.setState({ compactViewSelection: option.value }); + setViewSelection(option.value); }} /> ); diff --git a/new-ui/src/pages/full/OverviewPage/OverviewPage.tsx b/new-ui/src/pages/full/OverviewPage/OverviewPage.tsx index c2bea3309..50f3ae7b5 100644 --- a/new-ui/src/pages/full/OverviewPage/OverviewPage.tsx +++ b/new-ui/src/pages/full/OverviewPage/OverviewPage.tsx @@ -11,7 +11,6 @@ import { FullPage } from '../../../shared/layouts/FullPage/FullPage'; import { useAppData } from '../../../shared/providers/AppDataContext'; import { getLocationsQueryOptions } from '../../../shared/rust-api/query'; import type { InstanceInfo } from '../../../shared/rust-api/types'; -import { useAppStore } from '../../../shared/store/useAppStore'; import { ThemeSpacing } from '../../../shared/types'; import { isPresent } from '../../../shared/utils/isPresent'; import { ConnectModal } from './components/ConnectModal/ConnectModal'; @@ -23,7 +22,7 @@ const isWindows = platform() === 'windows'; export const OverviewPage = () => { const [detailsOpen, setDetailsOpen] = useState(false); const { instances, tunnels } = useAppData(); - const selection = useAppStore((s) => s.compactViewSelection); + const { viewSelection: selection } = useAppData(); const queryInstanceId = useMemo(() => { if (!isPresent(selection)) return instances[0].id; diff --git a/new-ui/src/pages/full/OverviewPage/components/ConnectModal/components/ConnectModalPostureCheckLoading/ConnectModalPostureCheckLoading.tsx b/new-ui/src/pages/full/OverviewPage/components/ConnectModal/components/ConnectModalPostureCheckLoading/ConnectModalPostureCheckLoading.tsx new file mode 100644 index 000000000..e8afb4028 --- /dev/null +++ b/new-ui/src/pages/full/OverviewPage/components/ConnectModal/components/ConnectModalPostureCheckLoading/ConnectModalPostureCheckLoading.tsx @@ -0,0 +1,11 @@ +import './style.scss'; +import { LoaderSpinner } from '../../../../../../../shared/components/LoaderSpinner/LoaderSpinner'; + +export const ConnectModalPostureCheckLoading = () => { + return ( +
+ +

Checking device requirements...

+
+ ); +}; diff --git a/new-ui/src/pages/full/OverviewPage/components/ConnectModal/components/ConnectModalPostureCheckLoading/style.scss b/new-ui/src/pages/full/OverviewPage/components/ConnectModal/components/ConnectModalPostureCheckLoading/style.scss new file mode 100644 index 000000000..d0cdd68c1 --- /dev/null +++ b/new-ui/src/pages/full/OverviewPage/components/ConnectModal/components/ConnectModalPostureCheckLoading/style.scss @@ -0,0 +1,14 @@ +.connect-modal-posture-check-loading { + display: flex; + flex-flow: column; + align-items: center; + justify-content: center; + row-gap: var(--spacing-lg); + min-height: 220px; + + p { + font: var(--t-body-xs-400); + color: var(--fg-white-100); + text-align: center; + } +} diff --git a/new-ui/src/pages/full/OverviewPage/components/ConnectModal/hooks/useConnectModalMfaOidc.ts b/new-ui/src/pages/full/OverviewPage/components/ConnectModal/hooks/useConnectModalMfaOidc.ts index 9c352a9bf..de844cd13 100644 --- a/new-ui/src/pages/full/OverviewPage/components/ConnectModal/hooks/useConnectModalMfaOidc.ts +++ b/new-ui/src/pages/full/OverviewPage/components/ConnectModal/hooks/useConnectModalMfaOidc.ts @@ -20,7 +20,7 @@ type MfaFinishResponse = { preshared_key: string }; type MfaErrorResponse = { error: string }; type Options = { - onPostureError?: () => void; + onPostureError?: (msg: string) => void; onSessionExpired?: () => void; }; @@ -146,7 +146,7 @@ export const useConnectModalMfaOidc = ({ startPolling(response.token, instance.proxy_url, headers); } catch (e) { if (shouldShowPostureError(e, location)) { - onPostureError?.(); + onPostureError?.(e.message); return; } setStartError( diff --git a/new-ui/src/pages/full/OverviewPage/components/ConnectModal/style.scss b/new-ui/src/pages/full/OverviewPage/components/ConnectModal/style.scss index 844e9a1ce..db7fbaece 100644 --- a/new-ui/src/pages/full/OverviewPage/components/ConnectModal/style.scss +++ b/new-ui/src/pages/full/OverviewPage/components/ConnectModal/style.scss @@ -6,6 +6,8 @@ } .controls { + width: 100%; + .full { width: 100%; } diff --git a/new-ui/src/pages/full/OverviewPage/components/ConnectModal/views/ConnectModalMfaEmail/ConnectModalMfaEmail.tsx b/new-ui/src/pages/full/OverviewPage/components/ConnectModal/views/ConnectModalMfaEmail/ConnectModalMfaEmail.tsx index 756c4c26f..009f7ee76 100644 --- a/new-ui/src/pages/full/OverviewPage/components/ConnectModal/views/ConnectModalMfaEmail/ConnectModalMfaEmail.tsx +++ b/new-ui/src/pages/full/OverviewPage/components/ConnectModal/views/ConnectModalMfaEmail/ConnectModalMfaEmail.tsx @@ -11,6 +11,7 @@ import { MfaStartMethod } from '../../../../../../../shared/components/LocationC import { useMfaConnect } from '../../../../../../../shared/components/LocationCard/hooks/useMfaConnect'; import type { LocationInfo } from '../../../../../../../shared/rust-api/types'; import { isPresent } from '../../../../../../../shared/utils/isPresent'; +import { ConnectModalPostureCheckLoading } from '../../components/ConnectModalPostureCheckLoading/ConnectModalPostureCheckLoading'; import { ConnectModalView } from '../../hooks/types'; import { useConnectModal } from '../../hooks/useConnectModal'; @@ -28,8 +29,10 @@ export const ConnectModalMfaEmail = () => { debounceMs: location?.posture_check_required ? MIN_POSTURE_LOADER_MS : 0, onSessionExpired: () => useConnectModal.getState().setView(perviousView ?? ConnectModalView.MfaSettings), - onPostureError: () => - useConnectModal.getState().setView(ConnectModalView.PostureCheckFail), + onPostureError: (msg) => { + useConnectModal.setState({ postureError: msg }); + useConnectModal.getState().setView(ConnectModalView.PostureCheckFail); + }, }, ); @@ -57,6 +60,10 @@ export const ConnectModalMfaEmail = () => { if (verifyError) setError(verifyError); }, [verifyError]); + if (isStarting && location?.posture_check_required && !startError) { + return ; + } + return (
{ const { start, isStarting, startError, qrValue, connectionError } = useMfaMobileConnect( location as LocationInfo, { - onPostureError: () => - useConnectModal.getState().setView(ConnectModalView.PostureCheckFail), + onPostureError: (msg) => { + useConnectModal.setState({ postureError: msg }); + useConnectModal.getState().setView(ConnectModalView.PostureCheckFail); + }, }, ); @@ -46,6 +49,10 @@ export const ConnectModalMfaMobile = () => { const errorMessage = startError ?? connectionError; + if (isStarting && location?.posture_check_required && !startError) { + return ; + } + return (

diff --git a/new-ui/src/pages/full/OverviewPage/components/ConnectModal/views/ConnectModalMfaOidc/ConnectModalMfaOidc.tsx b/new-ui/src/pages/full/OverviewPage/components/ConnectModal/views/ConnectModalMfaOidc/ConnectModalMfaOidc.tsx index 6c69b75ad..8b1769615 100644 --- a/new-ui/src/pages/full/OverviewPage/components/ConnectModal/views/ConnectModalMfaOidc/ConnectModalMfaOidc.tsx +++ b/new-ui/src/pages/full/OverviewPage/components/ConnectModal/views/ConnectModalMfaOidc/ConnectModalMfaOidc.tsx @@ -3,6 +3,7 @@ import { useShallow } from 'zustand/shallow'; import { Button } from '../../../../../../../shared/components/Button/Button'; import { ButtonVariant } from '../../../../../../../shared/components/Button/types'; import { Controls } from '../../../../../../../shared/components/Controls/Controls'; +import { ConnectModalPostureCheckLoading } from '../../components/ConnectModalPostureCheckLoading/ConnectModalPostureCheckLoading'; import { ConnectModalView } from '../../hooks/types'; import { useConnectModal } from '../../hooks/useConnectModal'; import { useConnectModalMfaOidc } from '../../hooks/useConnectModalMfaOidc'; @@ -10,13 +11,17 @@ import { useConnectModalMfaOidc } from '../../hooks/useConnectModalMfaOidc'; type Screen = 'idle' | 'polling' | 'error'; export const ConnectModalMfaOidc = () => { - const perviousView = useConnectModal(useShallow((s) => s.perviousView)); + const [perviousView, location] = useConnectModal( + useShallow((s) => [s.perviousView, s.location]), + ); const { start, isStarting, startError, isPolling, pollError } = useConnectModalMfaOidc({ onSessionExpired: () => useConnectModal.getState().setView(perviousView ?? ConnectModalView.MfaSettings), - onPostureError: () => - useConnectModal.getState().setView(ConnectModalView.PostureCheckFail), + onPostureError: (msg) => { + useConnectModal.setState({ postureError: msg }); + useConnectModal.getState().setView(ConnectModalView.PostureCheckFail); + }, }); const [screen, setScreen] = useState('idle'); @@ -36,6 +41,10 @@ export const ConnectModalMfaOidc = () => { const errorMessage = startError ?? pollError; + if (isStarting && location?.posture_check_required && !startError) { + return ; + } + return (

{screen === 'idle' && ( diff --git a/new-ui/src/pages/full/OverviewPage/components/ConnectModal/views/ConnectModalMfaSettings/ConnectModalMfaSettings.tsx b/new-ui/src/pages/full/OverviewPage/components/ConnectModal/views/ConnectModalMfaSettings/ConnectModalMfaSettings.tsx index 6ca736e7a..f7c4d6b26 100644 --- a/new-ui/src/pages/full/OverviewPage/components/ConnectModal/views/ConnectModalMfaSettings/ConnectModalMfaSettings.tsx +++ b/new-ui/src/pages/full/OverviewPage/components/ConnectModal/views/ConnectModalMfaSettings/ConnectModalMfaSettings.tsx @@ -11,6 +11,7 @@ import { IconButton } from '../../../../../../../shared/components/IconButton/Ic import { IconButtonVariant } from '../../../../../../../shared/components/IconButton/types'; import { MfaSelector } from '../../../../../../../shared/components/LocationCard/components/MfaSelector/MfaSelector'; import { SizedBox } from '../../../../../../../shared/components/SizedBox/SizedBox'; +import { useAppData } from '../../../../../../../shared/providers/AppDataContext'; import { api } from '../../../../../../../shared/rust-api/api'; import { LocationMfaMode, @@ -18,6 +19,7 @@ import { type MfaMethodValue, } from '../../../../../../../shared/rust-api/types'; import { ThemeSpacing } from '../../../../../../../shared/types'; +import { ConnectModalView } from '../../hooks/types'; import { useConnectModal } from '../../hooks/useConnectModal'; export const ConnectModalMfaSettings = () => { @@ -26,6 +28,8 @@ export const ConnectModalMfaSettings = () => { meta: { invalidate: [['locations']] }, }); + const { locationMfaPreference, setLocationMfaPreference } = useAppData(); + const [perviousView, location] = useConnectModal( useShallow((s) => [s.perviousView, s.location]), ); @@ -33,7 +37,9 @@ export const ConnectModalMfaSettings = () => { const locationDefaultMfaMethod = location?.mfa_method ?? MfaMethod.Totp; const [selectedMethod, setSelectedMethod] = useState( - location?.mfa_method ?? MfaMethod.Totp, + location + ? (locationMfaPreference[String(location.id)] ?? MfaMethod.Totp) + : MfaMethod.Totp, ); const [setAsDefault, setSetAsDefault] = useState(true); @@ -45,13 +51,32 @@ export const ConnectModalMfaSettings = () => { }, [location?.location_mfa_mode]); const handleSubmit = () => { + if (!location) return; + setLocationMfaPreference(location.id, selectedMethod); if (setAsDefault && selectedMethod !== locationDefaultMfaMethod && location) { setMfaMethod({ locationId: location.id, mfaMethod: selectedMethod }); + } else { } if (perviousView === null) { useConnectModal.setState({ visible: false }); } else { - useConnectModal.getState().setView(perviousView); + switch (selectedMethod) { + case 'totp': + useConnectModal.setState({ view: ConnectModalView.MfaTotp }); + break; + case 'email': + useConnectModal.setState({ view: ConnectModalView.MfaEmail }); + break; + case 'mobileapprove': + useConnectModal.setState({ view: ConnectModalView.MfaMobile }); + break; + case 'oidc': + useConnectModal.setState({ view: ConnectModalView.MfaOidc }); + break; + default: + useConnectModal.setState({ visible: false }); + break; + } } }; diff --git a/new-ui/src/pages/full/OverviewPage/components/ConnectModal/views/ConnectModalMfaTotp/ConnectModalMfaTotp.tsx b/new-ui/src/pages/full/OverviewPage/components/ConnectModal/views/ConnectModalMfaTotp/ConnectModalMfaTotp.tsx index 9e51717d4..f3d920299 100644 --- a/new-ui/src/pages/full/OverviewPage/components/ConnectModal/views/ConnectModalMfaTotp/ConnectModalMfaTotp.tsx +++ b/new-ui/src/pages/full/OverviewPage/components/ConnectModal/views/ConnectModalMfaTotp/ConnectModalMfaTotp.tsx @@ -11,6 +11,7 @@ import { MfaStartMethod } from '../../../../../../../shared/components/LocationC import { useMfaConnect } from '../../../../../../../shared/components/LocationCard/hooks/useMfaConnect'; import type { LocationInfo } from '../../../../../../../shared/rust-api/types'; import { isPresent } from '../../../../../../../shared/utils/isPresent'; +import { ConnectModalPostureCheckLoading } from '../../components/ConnectModalPostureCheckLoading/ConnectModalPostureCheckLoading'; import { ConnectModalView } from '../../hooks/types'; import { useConnectModal } from '../../hooks/useConnectModal'; @@ -28,8 +29,10 @@ export const ConnectModalMfaTotp = () => { debounceMs: location?.posture_check_required ? MIN_POSTURE_LOADER_MS : 0, onSessionExpired: () => useConnectModal.getState().setView(perviousView ?? ConnectModalView.MfaSettings), - onPostureError: () => - useConnectModal.getState().setView(ConnectModalView.PostureCheckFail), + onPostureError: (err) => { + useConnectModal.setState({ postureError: err }); + useConnectModal.getState().setView(ConnectModalView.PostureCheckFail); + }, }, ); @@ -57,6 +60,10 @@ export const ConnectModalMfaTotp = () => { if (verifyError) setError(verifyError); }, [verifyError]); + if (isStarting && location?.posture_check_required && !startError) { + return ; + } + return (
( ); export const OverviewSelection = ({ instances, tunnels }: Props) => { - const selection = useAppStore((s) => s.compactViewSelection); + const { viewSelection: selection, setViewSelection } = useAppData(); - const setSelection = (value: CompactViewSelection) => { - useAppStore.setState({ compactViewSelection: value }); + const setSelection = (value: OverviewViewSelection) => { + setViewSelection(value); }; - const isSelected = (candidate: CompactViewSelection): boolean => { + const isSelected = (candidate: OverviewViewSelection): boolean => { if (!selection) return false; if (candidate.kind !== selection.kind) return false; return candidate.data.id === selection.data.id; @@ -44,7 +44,7 @@ export const OverviewSelection = ({ instances, tunnels }: Props) => {

Instances

{instances.map((instance) => { - const value: CompactViewSelection = { kind: 'instance', data: instance }; + const value: OverviewViewSelection = { kind: 'instance', data: instance }; return ( {

Tunnels

{tunnels.map((tunnel) => { - const value: CompactViewSelection = { kind: 'tunnel', data: tunnel }; + const value: OverviewViewSelection = { kind: 'tunnel', data: tunnel }; return ( ()({ function RootComponent() { return ( - - - + + + + + ); } diff --git a/new-ui/src/routes/compact/index.tsx b/new-ui/src/routes/compact/index.tsx index 441551d7c..5f302f0b8 100644 --- a/new-ui/src/routes/compact/index.tsx +++ b/new-ui/src/routes/compact/index.tsx @@ -1,12 +1,13 @@ import { createFileRoute, redirect } from '@tanstack/react-router'; import { CompactLocationsPage } from '../../pages/compact/CompactLocationsPage/CompactLocationsPage'; +import { api } from '../../shared/rust-api/api'; import { getInstancesQueryOptions, getLocationsQueryOptions, + getSessionStateQueryOptions, getTunnelsQueryOptions, } from '../../shared/rust-api/query'; -import type { LocationInfo } from '../../shared/rust-api/types'; -import { useAppStore } from '../../shared/store/useAppStore'; +import type { LocationInfo, OverviewViewSelection } from '../../shared/rust-api/types'; export const Route = createFileRoute('/compact/')({ loader: async ({ context }) => { @@ -19,7 +20,10 @@ export const Route = createFileRoute('/compact/')({ throw redirect({ to: '/empty' }); } - const stored = useAppStore.getState().compactViewSelection; + const sessionState = await context.queryClient.fetchQuery( + getSessionStateQueryOptions, + ); + const stored = sessionState?.view_selection ?? null; let storedIsValid: boolean; if (stored === null) { @@ -30,7 +34,7 @@ export const Route = createFileRoute('/compact/')({ storedIsValid = tunnels.some((t) => t.id === stored.data.id); } - let selected: NonNullable; + let selected: OverviewViewSelection; if (storedIsValid && stored !== null) { selected = stored; } else if (instances.length > 0) { @@ -40,7 +44,8 @@ export const Route = createFileRoute('/compact/')({ } if (!storedIsValid) { - useAppStore.setState({ compactViewSelection: selected }); + await api.patchSessionState({ view_selection: selected }); + await context.queryClient.invalidateQueries({ queryKey: ['session-state'] }); } let locations: LocationInfo[]; diff --git a/new-ui/src/routes/full.tsx b/new-ui/src/routes/full.tsx index f9e4c062f..6321128ae 100644 --- a/new-ui/src/routes/full.tsx +++ b/new-ui/src/routes/full.tsx @@ -1,14 +1,9 @@ import { createFileRoute, Outlet } from '@tanstack/react-router'; -import { AppDataProvider } from '../shared/providers/AppDataContext'; export const Route = createFileRoute('/full')({ component: RouteComponent, }); function RouteComponent() { - return ( - - - - ); + return ; } diff --git a/new-ui/src/routes/full/_default/overview.tsx b/new-ui/src/routes/full/_default/overview.tsx index 711692515..4c4341386 100644 --- a/new-ui/src/routes/full/_default/overview.tsx +++ b/new-ui/src/routes/full/_default/overview.tsx @@ -1,10 +1,11 @@ import { createFileRoute, redirect } from '@tanstack/react-router'; import { OverviewPage } from '../../../pages/full/OverviewPage/OverviewPage'; +import { api } from '../../../shared/rust-api/api'; import { getInstancesQueryOptions, + getSessionStateQueryOptions, getTunnelsQueryOptions, } from '../../../shared/rust-api/query'; -import { useAppStore } from '../../../shared/store/useAppStore'; export const Route = createFileRoute('/full/_default/overview')({ loader: async ({ context }) => { @@ -17,7 +18,10 @@ export const Route = createFileRoute('/full/_default/overview')({ throw redirect({ to: '/empty' }); } - const stored = useAppStore.getState().compactViewSelection; + const sessionState = await context.queryClient.fetchQuery( + getSessionStateQueryOptions, + ); + const stored = sessionState?.view_selection ?? null; let storedIsValid: boolean; if (stored === null) { @@ -33,7 +37,8 @@ export const Route = createFileRoute('/full/_default/overview')({ instances.length > 0 ? { kind: 'instance' as const, data: instances[0] } : { kind: 'tunnel' as const, data: tunnels[0] }; - useAppStore.setState({ compactViewSelection: selected }); + await api.patchSessionState({ view_selection: selected }); + await context.queryClient.invalidateQueries({ queryKey: ['session-state'] }); } }, component: OverviewPage, diff --git a/new-ui/src/routes/full/index.tsx b/new-ui/src/routes/full/index.tsx index f068f8f48..1ad5e6e87 100644 --- a/new-ui/src/routes/full/index.tsx +++ b/new-ui/src/routes/full/index.tsx @@ -1,6 +1,22 @@ -import { createFileRoute, Navigate } from '@tanstack/react-router'; +import { createFileRoute, Navigate, redirect } from '@tanstack/react-router'; +import { + getInstancesQueryOptions, + getTunnelsQueryOptions, +} from '../../shared/rust-api/query'; export const Route = createFileRoute('/full/')({ + beforeLoad: async ({ context }) => { + const [instances, tunnels] = await Promise.all([ + context.queryClient.fetchQuery(getInstancesQueryOptions), + context.queryClient.fetchQuery(getTunnelsQueryOptions), + ]); + + if (instances.length === 0 && tunnels.length === 0) { + throw redirect({ to: '/full/add' }); + } else { + throw redirect({ to: '/full/overview' }); + } + }, component: RouteComponent, }); diff --git a/new-ui/src/routes/index.tsx b/new-ui/src/routes/index.tsx index e3da6280b..5c016c86a 100644 --- a/new-ui/src/routes/index.tsx +++ b/new-ui/src/routes/index.tsx @@ -1,58 +1,9 @@ -import { createFileRoute, redirect } from '@tanstack/react-router'; -import { CompactLocationsPage } from '../pages/compact/CompactLocationsPage/CompactLocationsPage'; -import { - getInstancesQueryOptions, - getLocationsQueryOptions, - getTunnelsQueryOptions, -} from '../shared/rust-api/query'; -import type { LocationInfo } from '../shared/rust-api/types'; -import { useAppStore } from '../shared/store/useAppStore'; +import { createFileRoute, Navigate } from '@tanstack/react-router'; export const Route = createFileRoute('/')({ - loader: async ({ context }) => { - const [instances, tunnels] = await Promise.all([ - context.queryClient.fetchQuery(getInstancesQueryOptions), - context.queryClient.fetchQuery(getTunnelsQueryOptions), - ]); - - if (instances.length === 0 && tunnels.length === 0) { - throw redirect({ to: '/empty' }); - } - - const stored = useAppStore.getState().compactViewSelection; - - let storedIsValid: boolean; - if (stored === null) { - storedIsValid = false; - } else if (stored.kind === 'instance') { - storedIsValid = instances.some((i) => i.id === stored.data.id); - } else { - storedIsValid = tunnels.some((t) => t.id === stored.data.id); - } - - let selected: NonNullable; - if (storedIsValid && stored !== null) { - selected = stored; - } else if (instances.length > 0) { - selected = { kind: 'instance', data: instances[0] }; - } else { - selected = { kind: 'tunnel', data: tunnels[0] }; - } - - if (!storedIsValid) { - useAppStore.setState({ compactViewSelection: selected }); - } - - let locations: LocationInfo[]; - if (selected.kind === 'instance') { - locations = await context.queryClient.fetchQuery( - getLocationsQueryOptions(selected.data.id), - ); - } else { - locations = []; - } - - return { instances, tunnels, locations }; - }, - component: CompactLocationsPage, + component: Component, }); + +function Component() { + return ; +} diff --git a/new-ui/src/shared/components/LocationCard/components/LocationCardMfaEdit/LocationCardMfaEdit.tsx b/new-ui/src/shared/components/LocationCard/components/LocationCardMfaEdit/LocationCardMfaEdit.tsx index b2ead6c68..6e47a71b2 100644 --- a/new-ui/src/shared/components/LocationCard/components/LocationCardMfaEdit/LocationCardMfaEdit.tsx +++ b/new-ui/src/shared/components/LocationCard/components/LocationCardMfaEdit/LocationCardMfaEdit.tsx @@ -1,5 +1,7 @@ import './style.scss'; import clsx from 'clsx'; +import { useMemo } from 'react'; +import { useAppData } from '../../../../providers/AppDataContext'; import { type LocationInfo, MfaMethod } from '../../../../rust-api/types'; import { mfaToText } from '../../../../utils/mfa'; import { IconButton } from '../../../IconButton/IconButton'; @@ -12,7 +14,11 @@ interface Props { } export const LocationCardMfaEdit = ({ location, onEdit, variant }: Props) => { - const mfaMethod = location.mfa_method ?? MfaMethod.Totp; + const { locationMfaPreference } = useAppData(); + const mfaMethod = useMemo( + () => locationMfaPreference[String(location.id)] ?? MfaMethod.Totp, + [locationMfaPreference, location.id], + ); if (location.location_mfa_mode === 'disabled') return null; @@ -22,7 +28,7 @@ export const LocationCardMfaEdit = ({ location, onEdit, variant }: Props) => {

MFA

{mfaToText(mfaMethod)}

- {location.location_mfa_mode === 'internal' && ( + {location.location_mfa_mode === 'internal' && !location.active && ( void; setPostureError: (error: string | null) => void; startMfa: () => void; - localMfaMethod: MfaMethodValue; - setLocalMfaMethod: (method: MfaMethodValue) => void; } const LocationCardContext = createContext(null); @@ -42,9 +41,6 @@ export const LocationCardProvider = ({ const [currentView, setCurrentView] = useState( location.active ? LocationCardViews.Connected : LocationCardViews.Default, ); - const [localMfaMethod, setLocalMfaMethod] = useState( - location.mfa_method ?? MfaMethod.Totp, - ); const setView = useCallback( (view: LocationCardViewsValue) => { @@ -54,8 +50,12 @@ export const LocationCardProvider = ({ [currentView], ); + const { locationMfaPreference } = useAppData(); + const startMfa = useCallback(() => { - switch (localMfaMethod) { + const mfaMethod = locationMfaPreference[String(location.id)] ?? MfaMethod.Totp; + + switch (mfaMethod) { case MfaMethod.Totp: setView(LocationCardViews.MfaTotp); break; @@ -69,7 +69,7 @@ export const LocationCardProvider = ({ setView(LocationCardViews.MfaMobile); break; } - }, [localMfaMethod, setView]); + }, [setView, location.id, locationMfaPreference]); return ( {children} diff --git a/new-ui/src/shared/components/LocationCard/views/LocationCardMfaSettings/LocationCardMfaSettings.tsx b/new-ui/src/shared/components/LocationCard/views/LocationCardMfaSettings/LocationCardMfaSettings.tsx index 04b066f63..134bb0a7d 100644 --- a/new-ui/src/shared/components/LocationCard/views/LocationCardMfaSettings/LocationCardMfaSettings.tsx +++ b/new-ui/src/shared/components/LocationCard/views/LocationCardMfaSettings/LocationCardMfaSettings.tsx @@ -1,6 +1,7 @@ import './style.scss'; import { useMutation } from '@tanstack/react-query'; import { useMemo, useState } from 'react'; +import { useAppData } from '../../../../providers/AppDataContext'; import { api } from '../../../../rust-api/api'; import { LocationMfaMode, @@ -30,13 +31,13 @@ export const LocationCardMfaSettings = () => { }, }); - const { previousView, setView, location, localMfaMethod, setLocalMfaMethod } = - useLocationCardContext(); + const { locationMfaPreference, setLocationMfaPreference } = useAppData(); + const { previousView, setView, location } = useLocationCardContext(); const locationDefaultMfaMethod = location.mfa_method ?? MfaMethod.Totp; const [selectedMethod, setSelectedPref] = useState( - localMfaMethod ?? MfaMethod.Totp, + locationMfaPreference[String(location.id)] ?? MfaMethod.Totp, ); const isFromDefault = previousView === LocationCardViews.Default; @@ -50,7 +51,7 @@ export const LocationCardMfaSettings = () => { }, [location.location_mfa_mode]); const handleSubmit = () => { - setLocalMfaMethod(selectedMethod); + setLocationMfaPreference(location.id, selectedMethod); if ((isFromDefault || setAsDefault) && selectedMethod !== locationDefaultMfaMethod) { setMfaMethod({ locationId: location.id, diff --git a/new-ui/src/shared/components/OverviewLocationCard/style.scss b/new-ui/src/shared/components/OverviewLocationCard/style.scss index a5d15ae04..226f40488 100644 --- a/new-ui/src/shared/components/OverviewLocationCard/style.scss +++ b/new-ui/src/shared/components/OverviewLocationCard/style.scss @@ -29,5 +29,6 @@ display: flex; flex-flow: row nowrap; column-gap: var(--spacing-5xl); + min-height: 24px; } } diff --git a/new-ui/src/shared/providers/AppDataContext.tsx b/new-ui/src/shared/providers/AppDataContext.tsx index fb4f1f2a0..6d7c8049b 100644 --- a/new-ui/src/shared/providers/AppDataContext.tsx +++ b/new-ui/src/shared/providers/AppDataContext.tsx @@ -1,12 +1,25 @@ -import { useQuery } from '@tanstack/react-query'; -import { createContext, type PropsWithChildren, useContext } from 'react'; -import { getInstancesQueryOptions, getTunnelsQueryOptions } from '../rust-api/query'; -import type { InstanceInfo, LocationInfo } from '../rust-api/types'; +import { useQuery, useQueryClient } from '@tanstack/react-query'; +import { createContext, type PropsWithChildren, useCallback, useContext } from 'react'; +import { api } from '../rust-api/api'; +import { + getInstancesQueryOptions, + getSessionStateQueryOptions, + getTunnelsQueryOptions, +} from '../rust-api/query'; +import type { + InstanceInfo, + LocationInfo, + MfaMethodValue, + OverviewViewSelection, +} from '../rust-api/types'; +import type { SharedSessionStorage } from './types'; -interface AppDataContextValue { +interface AppDataContextValue extends SharedSessionStorage { instances: InstanceInfo[]; tunnels: LocationInfo[]; isEmpty: boolean; + setViewSelection: (selection: OverviewViewSelection | null) => void; + setLocationMfaPreference: (locationId: number, method: MfaMethodValue) => void; } const AppDataContext = createContext(null); @@ -18,11 +31,45 @@ export const useAppData = (): AppDataContextValue => { }; export const AppDataProvider = ({ children }: PropsWithChildren) => { + const queryClient = useQueryClient(); const { data: instances = [] } = useQuery(getInstancesQueryOptions); const { data: tunnels = [] } = useQuery(getTunnelsQueryOptions); + const { data: sessionState } = useQuery(getSessionStateQueryOptions); const isEmpty = instances.length === 0 && tunnels.length === 0; + + const setViewSelection = useCallback( + (selection: OverviewViewSelection | null) => { + api + .patchSessionState({ view_selection: selection }) + .then(() => queryClient.invalidateQueries({ queryKey: ['session-state'] })); + }, + [queryClient], + ); + + const setLocationMfaPreference = useCallback( + (locationId: number, method: MfaMethodValue) => { + const current = sessionState?.location_mfa_preference ?? {}; + api + .patchSessionState({ + location_mfa_preference: { ...current, [String(locationId)]: method }, + }) + .then(() => queryClient.invalidateQueries({ queryKey: ['session-state'] })); + }, + [queryClient, sessionState?.location_mfa_preference], + ); + return ( - + {children} ); diff --git a/new-ui/src/shared/providers/TauriEventProvider.tsx b/new-ui/src/shared/providers/TauriEventProvider.tsx index e9c1f4184..4d7faa5a7 100644 --- a/new-ui/src/shared/providers/TauriEventProvider.tsx +++ b/new-ui/src/shared/providers/TauriEventProvider.tsx @@ -2,6 +2,7 @@ import { useQueryClient } from '@tanstack/react-query'; import { useNavigate } from '@tanstack/react-router'; import { listen } from '@tauri-apps/api/event'; import { getCurrentWindow } from '@tauri-apps/api/window'; +import { debug } from '@tauri-apps/plugin-log'; import { Fragment, type PropsWithChildren, useEffect } from 'react'; import { WindowId } from '../consts'; import { @@ -18,7 +19,7 @@ export const TauriEventProvider = ({ children }: PropsWithChildren) => { useEffect(() => { const unlisteners = Promise.all([ listen(TauriEvent.AddInstance, (event) => { - console.log('[TauriEvent] AddInstance', event.payload); + void debug(`UI Received event AddInstance: ${JSON.stringify(event.payload)}`); const windowLabel = getCurrentWindow().label; if (windowLabel === WindowId.FullView) { const { token, url } = event.payload; @@ -32,7 +33,9 @@ export const TauriEventProvider = ({ children }: PropsWithChildren) => { } }), listen(TauriEvent.ConnectionChanged, (event) => { - console.log('[TauriEvent] ConnectionChanged', event.payload); + void debug( + `UI Received event ConnectionChanged: ${JSON.stringify(event.payload)}`, + ); void queryClient.invalidateQueries({ queryKey: ['alive-connection'] }); void queryClient.invalidateQueries({ queryKey: ['active-connection'] }); void queryClient.invalidateQueries({ queryKey: ['locations'] }); @@ -42,26 +45,26 @@ export const TauriEventProvider = ({ children }: PropsWithChildren) => { }), listen(TauriEvent.InstanceUpdate, (event) => { - console.log('[TauriEvent] InstanceUpdate', event.payload); + void debug(`UI Received event InstanceUpdate: ${JSON.stringify(event.payload)}`); void queryClient.invalidateQueries({ queryKey: ['instances'] }); void queryClient.invalidateQueries({ queryKey: ['locations'] }); void queryClient.invalidateQueries({ queryKey: ['has-any-visible-locations'] }); }), listen(TauriEvent.LocationUpdate, (event) => { - console.log('[TauriEvent] LocationUpdate', event.payload); + void debug(`UI Received event LocationUpdate: ${JSON.stringify(event.payload)}`); void queryClient.invalidateQueries({ queryKey: ['locations'] }); void queryClient.invalidateQueries({ queryKey: ['location-details'] }); void queryClient.invalidateQueries({ queryKey: ['has-any-visible-locations'] }); }), listen(TauriEvent.AppVersionFetch, (event) => { - console.log('[TauriEvent] AppVersionFetch', event.payload); + void debug(`UI Received event AppVersionFetch: ${JSON.stringify(event.payload)}`); void queryClient.invalidateQueries({ queryKey: ['latest-app-version'] }); }), listen(TauriEvent.ConfigChanged, (event) => { - console.log('[TauriEvent] ConfigChanged', event.payload); + void debug(`UI Received event ConfigChanged: ${JSON.stringify(event.payload)}`); void queryClient.invalidateQueries({ queryKey: ['settings'] }); void queryClient.invalidateQueries({ queryKey: ['provisioning-config'] }); void queryClient.invalidateQueries({ queryKey: ['instances'] }); @@ -69,7 +72,9 @@ export const TauriEventProvider = ({ children }: PropsWithChildren) => { }), listen(TauriEvent.DeadConnectionDropped, (event) => { - console.log('[TauriEvent] DeadConnectionDropped', event.payload); + void debug( + `UI Received event DeadConnectionDropped: ${JSON.stringify(event.payload)}`, + ); void queryClient.invalidateQueries({ queryKey: ['alive-connection'] }); void queryClient.invalidateQueries({ queryKey: ['active-connection'] }); void queryClient.invalidateQueries({ queryKey: ['locations'] }); @@ -79,7 +84,9 @@ export const TauriEventProvider = ({ children }: PropsWithChildren) => { listen( TauriEvent.DeadConnectionReconnected, (event) => { - console.log('[TauriEvent] DeadConnectionReconnected', event.payload); + void debug( + `UI Received event DeadConnectionReconnected: ${JSON.stringify(event.payload)}`, + ); void queryClient.invalidateQueries({ queryKey: ['alive-connection'] }); void queryClient.invalidateQueries({ queryKey: ['active-connection'] }); void queryClient.invalidateQueries({ queryKey: ['locations'] }); @@ -88,19 +95,25 @@ export const TauriEventProvider = ({ children }: PropsWithChildren) => { ), listen(TauriEvent.ApplicationConfigChanged, (event) => { - console.log('[TauriEvent] ApplicationConfigChanged', event.payload); + void debug( + `UI Received event ApplicationConfigChanged: ${JSON.stringify(event.payload)}`, + ); void queryClient.invalidateQueries({ queryKey: ['settings'] }); }), listen(TauriEvent.AddInstance, (event) => { - console.log('[TauriEvent] AddInstance (instances invalidation)', event.payload); + void debug(`UI Received event AddInstance: ${JSON.stringify(event.payload)}`); void queryClient.invalidateQueries({ queryKey: ['instances'] }); }), listen(TauriEvent.UuidMismatch, (event) => { - console.log('[TauriEvent] UuidMismatch', event.payload); + void debug(`UI Received event UuidMismatch: ${JSON.stringify(event.payload)}`); void queryClient.invalidateQueries({ queryKey: ['instances'] }); }), + + listen(TauriEvent.SessionStateChanged, () => { + void queryClient.invalidateQueries({ queryKey: ['session-state'] }); + }), ]); return () => { diff --git a/new-ui/src/shared/providers/types.ts b/new-ui/src/shared/providers/types.ts new file mode 100644 index 000000000..afc490534 --- /dev/null +++ b/new-ui/src/shared/providers/types.ts @@ -0,0 +1,8 @@ +import type { MfaMethodValue, OverviewViewSelection } from '../rust-api/types'; + +export type { OverviewViewSelection }; + +export type SharedSessionStorage = { + viewSelection: OverviewViewSelection | null; + locationMfaPreference: Record; +}; diff --git a/new-ui/src/shared/rust-api/api.ts b/new-ui/src/shared/rust-api/api.ts index 7e66f95ea..ab6f75891 100644 --- a/new-ui/src/shared/rust-api/api.ts +++ b/new-ui/src/shared/rust-api/api.ts @@ -15,6 +15,8 @@ import type { RoutingArgs, SaveConfigArgs, SaveDeviceConfigResponse, + SessionState, + SessionStatePatch, SetLocationMfaMethodArgs, StatsArgs, TunnelInfo, @@ -124,6 +126,11 @@ const swapToTray = async () => invoke(TauriCommand.SwapToTray); const closeTrayWindow = async () => invoke(TauriCommand.CloseTrayWindow); +const getSessionState = (): Promise => invoke(TauriCommand.GetSessionState); + +const patchSessionState = (patch: SessionStatePatch): Promise => + invoke(TauriCommand.PatchSessionState, { patch }); + export const api = { // Instances getInstances, @@ -167,4 +174,7 @@ export const api = { swapToFullView, swapToTray, closeTrayWindow, + // Session state + getSessionState, + patchSessionState, }; diff --git a/new-ui/src/shared/rust-api/query.ts b/new-ui/src/shared/rust-api/query.ts index 0d44f3ce2..ca5ea023d 100644 --- a/new-ui/src/shared/rust-api/query.ts +++ b/new-ui/src/shared/rust-api/query.ts @@ -91,3 +91,8 @@ export const getPlatformHeaderQueryOptions = queryOptions({ queryKey: ['platform-header'] as const, queryFn: () => api.getPlatformHeader(), }); + +export const getSessionStateQueryOptions = queryOptions({ + queryKey: ['session-state'] as const, + queryFn: () => api.getSessionState(), +}); diff --git a/new-ui/src/shared/rust-api/types.ts b/new-ui/src/shared/rust-api/types.ts index 67f019b02..176afe790 100644 --- a/new-ui/src/shared/rust-api/types.ts +++ b/new-ui/src/shared/rust-api/types.ts @@ -123,6 +123,9 @@ export const TauriCommand = { SwapToFullView: 'swap_to_full_view', SwapToTray: 'swap_to_tray', CloseTrayWindow: 'close_tray_window', + // Session state + GetSessionState: 'get_session_state', + PatchSessionState: 'patch_session_state', } as const; export type TauriCommand = (typeof TauriCommand)[keyof typeof TauriCommand]; @@ -142,6 +145,8 @@ export const TauriEvent = { VersionMismatch: 'version-mismatch', UuidMismatch: 'uuid-mismatch', GlobalLogUpdate: 'log-update-global', + WindowSwapped: 'window-swapped', + SessionStateChanged: 'session-state-changed', } as const; export type TauriEventValue = (typeof TauriEvent)[keyof typeof TauriEvent]; @@ -361,3 +366,16 @@ export type SetLocationMfaMethodArgs = { locationId: number; mfaMethod: MfaMethodValue; }; + +export type OverviewViewSelection = + | { kind: 'instance'; data: InstanceInfo } + | { kind: 'tunnel'; data: LocationInfo }; + +/** Mirrors `SessionState` in src/session_state.rs. Fields are snake_case (raw serde output). */ +export type SessionState = { + view_selection: OverviewViewSelection | null; + /** Keys are location IDs serialized as strings (JSON object keys are always strings). */ + location_mfa_preference: Record; +}; + +export type SessionStatePatch = Partial; diff --git a/new-ui/src/shared/store/useAppStore.tsx b/new-ui/src/shared/store/useAppStore.tsx index 99d839cb4..d013881c3 100644 --- a/new-ui/src/shared/store/useAppStore.tsx +++ b/new-ui/src/shared/store/useAppStore.tsx @@ -1,13 +1,8 @@ import { create } from 'zustand'; import { createJSONStorage, persist } from 'zustand/middleware'; -import type { InstanceInfo, LocationInfo } from '../rust-api/types'; - -export type CompactViewSelection = - | { kind: 'instance'; data: InstanceInfo } - | { kind: 'tunnel'; data: LocationInfo }; interface StoreValues { - compactViewSelection: CompactViewSelection | null; + // only used in compact mode expandedLocation: number | null; } @@ -16,13 +11,12 @@ interface Store extends StoreValues {} export const useAppStore = create()( persist( (_) => ({ - compactViewSelection: null, expandedLocation: null, }), { name: 'app-store', storage: createJSONStorage(() => localStorage), - version: 3, + version: 4, }, ), ); diff --git a/src-tauri/core/src/database/models/instance.rs b/src-tauri/core/src/database/models/instance.rs index fcfc7d78c..1b3a91e2a 100644 --- a/src-tauri/core/src/database/models/instance.rs +++ b/src-tauri/core/src/database/models/instance.rs @@ -205,7 +205,7 @@ impl Instance { } } -#[derive(Debug, Serialize, Deserialize)] +#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)] pub struct InstanceInfo { pub id: I, pub name: String, @@ -226,7 +226,7 @@ impl fmt::Display for InstanceInfo { } /// Describes allowed traffic options for clients connecting to an instance. -#[derive(Debug, Serialize, Deserialize, PartialEq, Eq, Hash, Type)] +#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, Eq, Hash, Type)] #[repr(u32)] #[serde(rename_all = "snake_case")] pub enum ClientTrafficPolicy { diff --git a/src-tauri/core/src/events.rs b/src-tauri/core/src/events.rs index 1a3504e80..cc8802a24 100644 --- a/src-tauri/core/src/events.rs +++ b/src-tauri/core/src/events.rs @@ -14,6 +14,8 @@ pub enum EventKey { MfaTrigger, VersionMismatch, UuidMismatch, + WindowSwapped, + SessionStateChanged, } impl From for &'static str { @@ -31,6 +33,8 @@ impl From for &'static str { EventKey::MfaTrigger => "mfa-trigger", EventKey::VersionMismatch => "version-mismatch", EventKey::UuidMismatch => "uuid-mismatch", + EventKey::WindowSwapped => "window-swapped", + EventKey::SessionStateChanged => "session-state-changed", } } } diff --git a/src-tauri/permissions/default.toml b/src-tauri/permissions/default.toml index b1b216f9e..18b028e8e 100644 --- a/src-tauri/permissions/default.toml +++ b/src-tauri/permissions/default.toml @@ -39,4 +39,6 @@ commands.allow = [ "all_active_connections", "disconnect_locations", "get_posture_data", + "get_session_state", + "patch_session_state", ] diff --git a/src-tauri/src/appstate.rs b/src-tauri/src/appstate.rs index e9a98113d..1072c15fd 100644 --- a/src-tauri/src/appstate.rs +++ b/src-tauri/src/appstate.rs @@ -10,6 +10,7 @@ use crate::{ app_config::AppConfig, database::models::{connection::ActiveConnection, Id}, enterprise::provisioning::ProvisioningConfig, + session_state::SessionState, utils::stats_handler, ConnectionType, }; @@ -21,6 +22,7 @@ pub struct AppState { pub tray_click_position: Mutex>>, stat_threads: Mutex>>, // location ID is the key pub provisioning_config: Mutex>, + pub session_state: Mutex, } impl AppState { @@ -32,6 +34,7 @@ impl AppState { tray_click_position: Mutex::new(None), stat_threads: Mutex::new(HashMap::new()), provisioning_config: Mutex::new(provisioning_config), + session_state: Mutex::new(SessionState::default()), } } diff --git a/src-tauri/src/bin/defguard-client.rs b/src-tauri/src/bin/defguard-client.rs index fcdf4d895..919c12330 100644 --- a/src-tauri/src/bin/defguard-client.rs +++ b/src-tauri/src/bin/defguard-client.rs @@ -26,7 +26,7 @@ use defguard_client::{ enterprise::provisioning::handle_client_initialization, events::handle_deep_link, periodic::run_periodic_tasks, - service, + service, session_state, tray::{configure_tray_icon, setup_tray}, utils::load_log_targets, window_manager::*, @@ -208,6 +208,8 @@ fn main() { close_tray_window, all_active_connections, disconnect_locations, + session_state::get_session_state, + session_state::patch_session_state, ]) .on_window_event(|window, event| { if let WindowEvent::CloseRequested { api, .. } = event { diff --git a/src-tauri/src/commands.rs b/src-tauri/src/commands.rs index 2a413c152..9b262267b 100644 --- a/src-tauri/src/commands.rs +++ b/src-tauri/src/commands.rs @@ -520,7 +520,7 @@ pub async fn all_instances() -> Result>, Error> { Ok(instance_info) } -#[derive(Debug, Serialize)] +#[derive(Clone, Debug, Deserialize, PartialEq, Serialize)] pub struct LocationInfo { pub id: Id, pub instance_id: Id, diff --git a/src-tauri/src/lib.rs b/src-tauri/src/lib.rs index 335ad4b5c..6b54ce62d 100644 --- a/src-tauri/src/lib.rs +++ b/src-tauri/src/lib.rs @@ -10,6 +10,7 @@ pub mod events; pub mod log_watcher; pub mod periodic; pub mod service; +pub mod session_state; pub mod tray; pub mod utils; pub mod window_manager; diff --git a/src-tauri/src/session_state.rs b/src-tauri/src/session_state.rs new file mode 100644 index 000000000..8928918d8 --- /dev/null +++ b/src-tauri/src/session_state.rs @@ -0,0 +1,48 @@ +use std::collections::HashMap; + +use serde::{Deserialize, Serialize}; +use struct_patch::Patch; +use tauri::{AppHandle, Emitter, Manager, State}; + +use defguard_client_core::{ + database::models::{instance::InstanceInfo, location::LocationMfaMethod, Id}, + events::EventKey, +}; + +use crate::{appstate::AppState, commands::LocationInfo}; + +#[derive(Clone, Debug, Deserialize, PartialEq, Serialize)] +#[serde(tag = "kind", content = "data", rename_all = "lowercase")] +pub enum OverviewViewSelection { + Instance(InstanceInfo), + Tunnel(LocationInfo), +} + +#[derive(Clone, Debug, Default, Deserialize, Patch, Serialize)] +#[patch(attribute(derive(Debug, Deserialize, Serialize)))] +pub struct SessionState { + pub view_selection: Option, + pub location_mfa_preference: HashMap, +} + +#[tauri::command] +pub fn get_session_state(app_state: State<'_, AppState>) -> SessionState { + app_state.session_state.lock().unwrap().clone() +} + +#[tauri::command(async)] +pub async fn patch_session_state( + patch: SessionStatePatch, + app_handle: AppHandle, +) -> Result { + let app_state = app_handle.state::(); + let updated = { + let mut session_state = app_state.session_state.lock().unwrap(); + session_state.apply(patch); + session_state.clone() + }; + if let Err(err) = app_handle.emit(EventKey::SessionStateChanged.into(), ()) { + error!("Failed to emit session-state-changed event: {err}"); + } + Ok(updated) +} diff --git a/src-tauri/src/window_manager/mod.rs b/src-tauri/src/window_manager/mod.rs index 6e70b235c..d600f1c65 100644 --- a/src-tauri/src/window_manager/mod.rs +++ b/src-tauri/src/window_manager/mod.rs @@ -1,8 +1,11 @@ #[cfg(not(target_os = "windows"))] use tauri::Manager; -use tauri::{AppHandle, WebviewUrl, WebviewWindow, WebviewWindowBuilder}; +use tauri::{AppHandle, Emitter, WebviewUrl, WebviewWindow, WebviewWindowBuilder}; -use crate::database::{models::location::Location, DB_POOL}; +use crate::{ + database::{models::location::Location, DB_POOL}, + events::EventKey, +}; /// Returns `true` if there are any non-service locations in the database. pub async fn has_non_service_locations() -> bool { @@ -141,6 +144,8 @@ pub fn swap_to_full_view(app: AppHandle) { } if let Err(err) = WindowManager::open_full_view(&app) { error!("swap_to_full_view task: Failed to open full view: {err:?}"); + } else if let Err(err) = app.emit(EventKey::WindowSwapped.into(), ()) { + error!("swap_to_full_view task: Failed to emit window swapped event: {err:?}"); } } @@ -167,4 +172,7 @@ pub fn swap_to_tray(app: AppHandle) { error!("swap_to_tray task: Failed to hide full-view window: {err:?}"); } } + if let Err(err) = app.emit(EventKey::WindowSwapped.into(), ()) { + error!("swap_to_tray task: Failed to emit window swapped event: {err:?}"); + } }