diff --git a/src/components/Process/Workers/Configurations/Fields/ModelField.vue b/src/components/Process/Workers/Configurations/Fields/ModelField.vue new file mode 100644 index 0000000000000000000000000000000000000000..2b91efeff5e096e64c2ac4aac1e5002785746001 --- /dev/null +++ b/src/components/Process/Workers/Configurations/Fields/ModelField.vue @@ -0,0 +1,78 @@ +<template> + <div class="field"> + <SearchableSelect + :model-value="modelValue" + results-name="models" + is-fixed + v-on:update:model-value="value => $emit('update:modelValue', value || undefined)" + /> + </div> +</template> + +<script lang="ts"> +import { errorParser } from '@/helpers' +import { useModelStore, useNotificationStore } from '@/stores' +import { ModelUserConfigurationField } from '@/types/workerConfiguration' +import { mapActions, mapState } from 'pinia' +import { defineComponent, PropType } from 'vue' +import { UUID } from '@/types' +import SearchableSelect, { GetSuggestionsFunc } from '@/components/SearchableSelect.vue' + +export default defineComponent({ + emits: ['update:modelValue'], + props: { + field: { + type: Object as PropType<ModelUserConfigurationField>, + required: true + }, + modelValue: { + type: String as PropType<UUID | null | undefined>, + default: null + } + }, + components: { + SearchableSelect + }, + data: () => ({ + loading: false, + page: 1 + }), + provide () { + const getSuggestions: GetSuggestionsFunc = async (search: string) => { + // Process models to turn them into { id: display_name } suggestions + const results = Object.entries(this.allModels).map(([id, model]) => [id, `${model.name} (${id})`]) + const suggestions = Object.fromEntries(results.filter( + ([, label]) => label.toLowerCase().includes(search.toLowerCase()) + )) + return { + suggestions, + count: Object.keys(suggestions).length + } + } + return { getSuggestions } + }, + methods: { + ...mapActions(useModelStore, ['listAllModels']), + ...mapActions(useNotificationStore, ['notify']), + async getModels () { + this.loading = true + try { + await this.listAllModels() + } catch (err) { + this.notify({ type: 'error', text: `An error occurred listing models: ${errorParser(err)}` }) + } finally { + this.loading = false + } + } + }, + computed: { + ...mapState(useModelStore, ['allModels']) + }, + watch: { + page: { + immediate: true, + handler: 'getModels' + } + } +}) +</script> diff --git a/src/components/Process/Workers/Configurations/Fields/index.ts b/src/components/Process/Workers/Configurations/Fields/index.ts index faf62f074f636ed589c9e2fe7168acc9cbc8d78b..b5a69978bd6018e8bc26a7c3bc91cb47842547ed 100644 --- a/src/components/Process/Workers/Configurations/Fields/index.ts +++ b/src/components/Process/Workers/Configurations/Fields/index.ts @@ -1,5 +1,6 @@ import { toNumber } from 'lodash' import { Component, markRaw } from 'vue' +import { UUID_REGEX } from '@/config' import IntegerField from './IntegerField.vue' import FloatField from './FloatField.vue' @@ -8,11 +9,15 @@ import StringField from './StringField.vue' import BooleanField from './BooleanField.vue' import DictField from './DictField.vue' import ListField from './ListField.vue' -import { EnumUserConfigurationField, ListUserConfigurationFieldValue, ListUserConfigurationField, UserConfigurationField, UserConfigurationFields, UserConfigurationListSubtypes } from '@/types/workerConfiguration' +import ModelField from './ModelField.vue' +import { EnumUserConfigurationField, ListUserConfigurationFieldValue, ListUserConfigurationField, UserConfigurationField, UserConfigurationFields } from '@/types/workerConfiguration' +import { UUID } from '@/types' +import { Model } from '@/types/model' +import { useModelStore } from '@/stores' -type ConfigurationFields = { [TypeName in UserConfigurationField["type"]] : { +type ConfigurationFields = { [TypeName in UserConfigurationField['type']] : { component: Component, - validate: (value: unknown, field?: UserConfigurationFields[TypeName]) => Exclude<UserConfigurationFields[TypeName]["default"], undefined> + validate: (value: unknown, field?: UserConfigurationFields[TypeName], models?: { [id: UUID]: Model }) => Exclude<UserConfigurationFields[TypeName]['default'], undefined> }} const FIELDS: ConfigurationFields = { @@ -74,6 +79,15 @@ const FIELDS: ConfigurationFields = { throw new Error(`Value must be a valid list of ${field.subtype}.`) } } + }, + model: { + component: markRaw(ModelField), + validate (value: unknown): UUID { + const modelStore = useModelStore() + if (typeof value !== 'string' || !UUID_REGEX.test(value)) throw new Error('Value must be a valid UUID') + if (!(value in modelStore.allModels)) throw new Error('This model does not exist') + return value + } } } diff --git a/src/components/SearchableSelect.vue b/src/components/SearchableSelect.vue index 68d949aaa85086fa4358f7de3aff0a9b68cc2a1c..da4ba9c9fe871f98b337cf5093387ba21c66986d 100644 --- a/src/components/SearchableSelect.vue +++ b/src/components/SearchableSelect.vue @@ -41,14 +41,23 @@ </fieldset> </template> -<script> +<script lang="ts"> +import { PropType, defineComponent, inject } from 'vue' import { SEARCHABLE_SELECT_MAX_MATCHES, SEARCHABLE_SELECT_SUGGESTION_DELAY } from '@/config' import { highlight } from '@/helpers' -export default { - inject: [ + +interface GetSuggestionsResult { + suggestions: Record<string, string> + count: number +} + +export type GetSuggestionsFunc = (value: string) => GetSuggestionsResult | Promise<GetSuggestionsResult> + +export default defineComponent({ + setup () { /* * This component requires an implementation of the getSuggestions method. * @@ -64,10 +73,11 @@ export default { * ... * * You can provide getSuggestions using the `provide` component option. - * https://vuejs.org/v2/guide/components-edge-cases.html#Dependency-Injection + * https://vuejs.org/guide/components/provide-inject.html */ - 'getSuggestions' - ], + const getSuggestions = inject('getSuggestions') as GetSuggestionsFunc + return { getSuggestions } + }, emits: [ 'submit', 'update:isValid', @@ -76,8 +86,8 @@ export default { expose: ['focus', 'clear'], props: { modelValue: { - required: true, - validator: value => (value === null || typeof value === 'string') + type: String as PropType<string | null>, + default: null }, isValid: { type: Boolean, @@ -132,7 +142,7 @@ export default { }, data: () => ({ // Retrieved matching suggestions - suggestions: {}, + suggestions: {} as Record<string, string>, /* * Total number of results. If this exceeds SEARCHABLE_SELECT_MAX_MATCHES, * the remainder is displayed as "And … more results". @@ -144,11 +154,11 @@ export default { // Suggestions visibility toggled: false, // Hovered suggestion using key bindings - current: null, + current: null as string | null, // Fallback when no isValid property is bound valid: true, // Timeout ID used to debounce suggestion fetching - suggestionTimeoutId: null, + suggestionTimeoutId: null as NodeJS.Timeout | null, // Set to true while the suggestions are loading loading: false }), @@ -172,15 +182,15 @@ export default { } }, methods: { - setValidInput (bool) { + setValidInput (bool: boolean) { // Update input validity bound prop or data if (this.isValid !== null) this.$emit('update:isValid', bool) else this.valid = bool }, - hlText (suggestion) { + hlText (suggestion: string) { return highlight(suggestion, this.input.split(/\s+/)) }, - navigate (event) { + navigate (event: KeyboardEvent) { /* * If the user presses Enter and there already is a valid input, trigger a submit event * to imply the user is trying to submit a form. @@ -194,12 +204,12 @@ export default { } this.toggled = true - const currentIndex = Object.keys(this.suggestions).indexOf(this.current) + const currentIndex = this.current === null ? -1 : Object.keys(this.suggestions).indexOf(this.current) let newIndex = currentIndex const max = this.suggestionsCount if (event.key === 'ArrowDown') newIndex++ else if (event.key === 'ArrowUp') newIndex = Math.max(-1, newIndex - 1) - else if (event.key === 'Enter' && currentIndex >= 0) { + else if (event.key === 'Enter' && this.current !== null && currentIndex >= 0) { this.select(this.current) return } else if (event.key === 'Escape') { @@ -222,7 +232,7 @@ export default { this.current = null if (this.toggled) this.resetSuggestionTimeout() }, - select (value) { + select (value: string) { this.toggled = false this.setValidInput(true) this.$emit('update:modelValue', value) @@ -239,7 +249,7 @@ export default { // this.getSuggestions might not always be async; just ensure it is a promise output = await Promise.resolve(this.getSuggestions(queryTerms)) } catch { - this.suggestions = [] + this.suggestions = {} return } finally { this.loading = false @@ -248,7 +258,7 @@ export default { // Input has changed while waiting for the response if (queryTerms !== this.input.trim()) return this.suggestions = suggestions - this.count = count || this.suggestions.length + this.count = count }, check () { // Reset the value as the input text has changed @@ -266,24 +276,14 @@ export default { this.setValidInput(true) } else this.setValidInput(false) }, - updateFromValue () { - if (!this.modelValue) return - const suggestion = this.suggestions[this.modelValue] - if (suggestion) { - this.input = suggestion - } else { - this.input = this.default - } - this.setValidInput(true) - }, focus () { - this.$refs.input.focus() + (this.$refs.input as HTMLInputElement).focus() }, clear () { this.input = '' this.$emit('update:modelValue', '') this.toggled = false - this.suggestions = [] + this.suggestions = {} } }, watch: { @@ -297,13 +297,27 @@ export default { }, modelValue: { immediate: true, - handler: 'updateFromValue' + async handler (newValue) { + if (!newValue) { + this.input = '' + this.setValidInput(this.allowEmpty) + return + } + + const suggestion = (await Promise.resolve(this.getSuggestions(newValue))).suggestions[newValue] + if (suggestion) { + this.input = suggestion + } else { + this.input = this.default + } + this.setValidInput(true) + } }, default () { this.input = this.default } } -} +}) </script> <style lang="scss" scoped> diff --git a/src/stores/model.ts b/src/stores/model.ts index 5abb3fefb6fe3d3cf03dc52bd79d8498adc31950..c41ed6cf89d530e95088f603d1f129dc63f8e178 100644 --- a/src/stores/model.ts +++ b/src/stores/model.ts @@ -21,12 +21,14 @@ import { useNotificationStore } from '.' interface State { models: { [id: UUID]: Model } modelVersions: { [id: UUID]: ModelVersion } + allModels: { [id: UUID]: Model } } export const useModelStore = defineStore('model', { state: (): State => ({ models: {}, - modelVersions: {} + modelVersions: {}, + allModels: {} }), actions: { async createModel (params: CreateModelPayload) { @@ -50,6 +52,25 @@ export const useModelStore = defineStore('model', { return resp }, + async listAllModels (page = 1) { + // Do not start fetching models if they have been retrieved already + if (page === 1 && this.allModels.length) return + const resp = await listModels({ page }) + + this.allModels = { + ...this.allModels, + ...Object.fromEntries(resp.results.map(model => [model.id, model])) + } + + if (!resp || !resp.number || page !== resp.number) { + // Avoid any loop + throw new Error('Pagination failed listing models') + } + + // Load other pages + if (resp.next) await this.listAllModels(page + 1) + }, + async listModelVersions (modelId: UUID, page = 1) { const resp = await listModelVersions(modelId, { page }) this.modelVersions = { diff --git a/src/types/workerConfiguration.ts b/src/types/workerConfiguration.ts index cd67555369ab1e69d5c8eadbf81c2b9afc10ae85..8186ca90317112176318822abc5018f3f063eabf 100644 --- a/src/types/workerConfiguration.ts +++ b/src/types/workerConfiguration.ts @@ -59,6 +59,8 @@ export type BooleanUserConfigurationField = _UserConfigurationField<'bool', bool export type DictUserConfigurationField = _UserConfigurationField<'dict', Record<string, string>> +export type ModelUserConfigurationField = _UserConfigurationField<'model', UUID> + export interface UserConfigurationFields { string: StringUserConfigurationField int: IntegerUserConfigurationField @@ -67,6 +69,7 @@ export interface UserConfigurationFields { dict: DictUserConfigurationField enum: EnumUserConfigurationField list: ListUserConfigurationField + model: ModelUserConfigurationField } -export type UserConfigurationField = UserConfigurationFields[keyof UserConfigurationFields] \ No newline at end of file +export type UserConfigurationField = UserConfigurationFields[keyof UserConfigurationFields] diff --git a/tests/unit/stores/model.spec.js b/tests/unit/stores/model.spec.js index dfbd4df5cdca6988e150a41cdd49b100519cdd7f..9e3a92192a3812dda4053e3acd3287d5bd4f7162 100644 --- a/tests/unit/stores/model.spec.js +++ b/tests/unit/stores/model.spec.js @@ -205,5 +205,53 @@ describe('model', () => { }) }) }) + + describe('listAllModels', () => { + it('lists all models (not paginated)', async () => { + const pages = [ + { + count: 4, + number: 1, + next: 'nextpage', + results: [ + { id: 'model1' }, + { id: 'model2' }, + { id: 'model3' } + ] + }, + { + count: 4, + number: 2, + next: null, + results: [ + { id: 'model4' } + ] + } + ] + mock.onGet('/models/', { params: { page: 1 } }).reply(200, pages[0]) + mock.onGet('/models/', { params: { page: 2 } }).reply(200, pages[1]) + + await store.listAllModels() + + assert.deepStrictEqual(mock.history.all.map(req => pick(req, ['method', 'url', 'params'])), [ + { + method: 'get', + url: '/models/', + params: { page: 1 } + }, + { + method: 'get', + url: '/models/', + params: { page: 2 } + } + ]) + assert.deepStrictEqual(store.allModels, { + model1: { id: 'model1' }, + model2: { id: 'model2' }, + model3: { id: 'model3' }, + model4: { id: 'model4' } + }) + }) + }) }) })