diff --git a/codegenerator/cli.py b/codegenerator/cli.py index e97df91..18fc799 100644 --- a/codegenerator/cli.py +++ b/codegenerator/cli.py @@ -33,6 +33,7 @@ from codegenerator.openapi_spec import OpenApiSchemaGenerator # from codegenerator.osc import OSCGenerator from codegenerator.rust_cli import RustCliGenerator from codegenerator.rust_tui import RustTuiGenerator +from codegenerator.rust_types import RustTypesGenerator from codegenerator.rust_sdk import RustSdkGenerator from codegenerator.types import Metadata @@ -159,6 +160,7 @@ def main(): "rust-sdk", "rust-cli", "rust-tui", + "rust-types", "openapi-spec", "jsonschema", "metadata", @@ -201,6 +203,7 @@ def main(): "rust-cli": RustCliGenerator(), "rust-tui": RustTuiGenerator(), "rust-sdk": RustSdkGenerator(), + "rust-types": RustTypesGenerator(), "openapi-spec": OpenApiSchemaGenerator(), "jsonschema": JsonSchemaGenerator(), "metadata": MetadataGenerator(), @@ -226,8 +229,13 @@ def main(): continue for op, op_data in res_data.operations.items(): logging.debug(f"Processing operation {op_data.operation_id}") - if args.target in op_data.targets: - op_args = op_data.targets[args.target] + metadata_target = ( + "rust-sdk" + if args.target in ["rust-sdk", "rust-types"] + else args.target + ) + if metadata_target in op_data.targets: + op_args = op_data.targets[metadata_target] if not op_args.service_type: op_args.service_type = res.split(".")[0] if not op_args.api_version: @@ -254,7 +262,7 @@ def main(): ): res_mods.append((mod_path, mod_name, path, class_name)) rust_sdk_extensions = res_data.extensions.get("rust-sdk") - if rust_sdk_extensions: + if rust_sdk_extensions and args.target != "rust-types": additional_modules = rust_sdk_extensions.setdefault( "additional_modules", [] ) @@ -273,7 +281,10 @@ def main(): ) ) - if args.target in ["rust-sdk", "rust-tui"] and not args.resource: + if ( + args.target in ["rust-sdk", "rust-tui", "rust-types"] + and not args.resource + ): resource_results: dict[str, dict] = {} for mod_path, mod_name, path, class_name in res_mods: mn = "/".join(mod_path) diff --git a/codegenerator/common/__init__.py b/codegenerator/common/__init__.py index 8c3c629..9544071 100644 --- a/codegenerator/common/__init__.py +++ b/codegenerator/common/__init__.py @@ -562,6 +562,13 @@ def get_rust_sdk_mod_path(service_type: str, api_version: str, path: str): return mod_path +def get_rust_types_mod_path(service_type: str, api_version: str, path: str): + """Construct mod path for rust types crate""" + mod_path = [service_type.replace("-", "_"), api_version] + mod_path.extend([x.lower() for x in get_resource_names_from_url(path)]) + return mod_path + + def get_rust_cli_mod_path(service_type: str, api_version: str, path: str): """Construct mod path for rust sdk""" mod_path = [service_type.replace("-", "_"), api_version] diff --git a/codegenerator/common/rust.py b/codegenerator/common/rust.py index 0e393ec..4e50516 100644 --- a/codegenerator/common/rust.py +++ b/codegenerator/common/rust.py @@ -26,6 +26,19 @@ from codegenerator import common CODEBLOCK_RE = re.compile(r"```(\w*)$") +BASIC_FIELDS = [ + "id", + "name", + "title", + "created_at", + "updated_at", + "uuid", + "state", + "status", + "operating_status", +] + + class Boolean(BasePrimitiveType): """Basic Boolean""" @@ -264,6 +277,20 @@ class Dictionary(BaseCombinedType): base_type: str = "dict" value_type: BasePrimitiveType | BaseCombinedType | BaseCompoundType + @property + def imports(self): + imports: set[str] = {"std::collections::HashMap"} + imports.update(self.value_type.imports) + return imports + + @property + def type_hint(self): + return f"HashMap" + + @property + def lifetimes(self): + return set() + class StructField(BaseModel): local_name: str @@ -338,6 +365,101 @@ class Struct(BaseCompoundType): return set() +class StructFieldResponse(StructField): + """Response Structure Field""" + + @property + def type_hint(self): + typ_hint = self.data_type.type_hint + if self.is_optional and not typ_hint.startswith("Option<"): + typ_hint = f"Option<{typ_hint}>" + return typ_hint + + @property + def serde_macros(self): + macros = set() + if self.local_name != self.remote_name: + macros.add(f'rename="{self.remote_name}"') + if len(macros) > 0: + return f"#[serde({', '.join(sorted(macros))})]" + return "" + + def get_structable_macros( + self, + struct: "StructResponse", + service_name: str, + resource_name: str, + operation_type: str, + ): + macros = set() + if self.is_optional or self.data_type.type_hint.startswith("Option<"): + macros.add("optional") + if self.local_name != self.remote_name: + macros.add(f'title="{self.remote_name}"') + # Fully Qualified Attribute Name + fqan: str = ".".join( + [service_name, resource_name, self.remote_name] + ).lower() + # Check the known alias of the field by FQAN + alias = common.FQAN_ALIAS_MAP.get(fqan) + if operation_type in ["list", "list_from_struct"]: + if ( + "id" in struct.fields.keys() + and not ( + self.local_name in BASIC_FIELDS or alias in BASIC_FIELDS + ) + ) or ( + "id" not in struct.fields.keys() + and (self.local_name not in list(struct.fields.keys())[-10:]) + and not ( + self.local_name in BASIC_FIELDS or alias in BASIC_FIELDS + ) + ): + # Only add "wide" flag if field is not in the basic fields AND + # there is at least "id" field existing in the struct OR the + # field is not in the first 10 + macros.add("wide") + if ( + self.local_name == "state" + and "status" not in struct.fields.keys() + ): + macros.add("status") + elif ( + self.local_name == "operating_status" + and "status" not in struct.fields.keys() + ): + macros.add("status") + if self.data_type.type_hint in [ + "Value", + "Option", + "Vec", + "Option>", + ]: + macros.add("pretty") + return f"#[structable({', '.join(sorted(macros))})]" + + +class StructResponse(Struct): + field_type_class_: Type[StructField] = StructFieldResponse + + @property + def imports(self): + imports: set[str] = {"serde::Deserialize", "serde::Serialize"} + for field in self.fields.values(): + imports.update(field.data_type.imports) + # In difference to the SDK and Input we do not currently handle + # additional_fields of the struct in response + # if self.additional_fields_type: + # imports.add("std::collections::BTreeMap") + # imports.update(self.additional_fields_type.imports) + return imports + + @property + def static_lifetime(self): + """Return Rust `<'lc>` lifetimes representation""" + return f"<{', '.join(self.lifetimes)}>" if self.lifetimes else "" + + class EnumKind(BaseModel): name: str description: str | None = None @@ -346,7 +468,11 @@ class EnumKind(BaseModel): @property def type_hint(self): if isinstance(self.data_type, Struct): - return self.data_type.name + self.data_type.static_lifetime + print(f"Getting type hint of {self.data_type}") + try: + return self.data_type.name + self.data_type.static_lifetime + except Exception as ex: + print(f"Error {ex}") return self.data_type.type_hint @property @@ -361,6 +487,14 @@ class Enum(BaseCompoundType): original_data_type: BaseCompoundType | BaseCompoundType | None = None _kind_type_class = EnumKind + @property + def derive_container_macros(self) -> str: + return "#[derive(Debug, Deserialize, Clone, Serialize)]" + + @property + def serde_container_macros(self) -> str: + return "#[serde(untagged)]" + @property def type_hint(self): return self.name + ( @@ -394,14 +528,18 @@ class StringEnum(BaseCompoundType): variants: dict[str, set[str]] = {} imports: set[str] = {"serde::Deserialize", "serde::Serialize"} lifetimes: set[str] = set() - derive_container_macros: str = ( - "#[derive(Debug, Deserialize, Clone, Serialize)]" - ) builder_container_macros: str | None = None - serde_container_macros: str | None = None # "#[serde(untagged)]" serde_macros: set[str] | None = None original_data_type: BaseCompoundType | BaseCompoundType | None = None + @property + def derive_container_macros(self) -> str: + return "#[derive(Debug, Deserialize, Clone, Serialize)]" + + @property + def serde_container_macros(self) -> str: + return "#[serde(untagged)]" + @property def type_hint(self): """Get type hint""" @@ -435,6 +573,36 @@ class StringEnum(BaseCompoundType): return "#[serde(" + ", ".join(sorted(macros)) + ")]" +class HashMapResponse(Dictionary): + """Wrapper around a simple dictionary to implement Display trait""" + + lifetimes: set[str] = set() + + @property + def type_hint(self): + return f"HashMapString{self.value_type.type_hint.replace('<', '').replace('>', '')}" + + @property + def imports(self): + imports = self.value_type.imports + imports.add("std::collections::HashMap") + return imports + + +class TupleStruct(Struct): + """Rust tuple struct without named fields""" + + base_type: str = "struct" + tuple_fields: list[StructField] = [] + + @property + def imports(self): + imports: set[str] = set() + for field in self.tuple_fields: + imports.update(field.data_type.imports) + return imports + + class RequestParameter(BaseModel): """OpenAPI request parameter in the Rust SDK form""" @@ -521,6 +689,8 @@ class TypeManager: #: List of the models to be ignored ignored_models: list[model.Reference] = [] + root_name: str | None = "Body" + def __init__(self): self.models = [] self.refs = {} @@ -672,7 +842,9 @@ class TypeManager: ) if not model_ref: - model_ref = model.Reference(name="Body", type=typ.__class__) + model_ref = model.Reference( + name=self.root_name, type=typ.__class__ + ) self.refs[model_ref] = typ return typ @@ -901,8 +1073,14 @@ class TypeManager: name = getattr(model_data_type, "name", None) if ( name - and name in unique_models - and unique_models[name].hash_ != model_.reference.hash_ + and model_.reference + and ( + ( + name in unique_models + and unique_models[name].hash_ != model_.reference.hash_ + ) + or name == self.root_name + ) ): # There is already a model with this name. if model_.reference and model_.reference.parent: @@ -975,6 +1153,7 @@ class TypeManager: elif ( name and name in unique_models + and model_.reference and unique_models[name].hash_ == model_.reference.hash_ # image.metadef.namespace have weird occurences of itself and model_.reference != unique_models[name] @@ -993,12 +1172,12 @@ class TypeManager: if ( k and isinstance(v, (Enum, Struct, StringEnum)) - and k.name != "Body" + and k.name != self.root_name ): yield v elif ( k - and k.name != "Body" + and k.name != self.root_name and isinstance(v, self.option_type_class) ): if isinstance(v.item_type, Enum): @@ -1007,7 +1186,7 @@ class TypeManager: def get_root_data_type(self): """Get TLA type""" for k, v in self.refs.items(): - if not k or (k.name == "Body" and isinstance(v, Struct)): + if not k or (k.name == self.root_name and isinstance(v, Struct)): if isinstance(v.fields, dict): # There might be tuple Struct (with # fields as list) @@ -1022,7 +1201,9 @@ class TypeManager: ) v.fields[field_names[0]].is_optional = False return v - elif not k or (k.name == "Body" and isinstance(v, Dictionary)): + elif not k or ( + k.name == self.root_name and isinstance(v, Dictionary) + ): # Response is a free style Dictionary return v # No root has been found, make a dummy one diff --git a/codegenerator/rust_cli.py b/codegenerator/rust_cli.py index 467e067..40316f1 100644 --- a/codegenerator/rust_cli.py +++ b/codegenerator/rust_cli.py @@ -645,7 +645,7 @@ class RequestTypeManager(common_rust.TypeManager): ) if not model_ref: model_ref = model.Reference( - name="Body", type=typ.__class__ + name=self.root_name, type=typ.__class__ ) if type_model.value_type.reference: self.ignored_models.append( @@ -999,7 +999,7 @@ class ResponseTypeManager(common_rust.TypeManager): common_rust.Array, ), ) - and k.name != "Body" + and k.name != self.root_name ): key = v.base_type + v.type_hint if key not in emited_data: @@ -1023,7 +1023,7 @@ class RustCliGenerator(BaseGenerator): :param *args: Path to the code to format """ for path in args: - subprocess.run(["rustfmt", "--edition", "2021", path]) + subprocess.run(["rustfmt", "--edition", "2024", path]) def get_parser(self, parser): parser.add_argument( @@ -1284,7 +1284,8 @@ class RustCliGenerator(BaseGenerator): ) response_type_manager.refs[ model.Reference( - name="Body", type=HashMapResponse + name=response_type_manager.root_name, + type=HashMapResponse, ) ] = root_dict @@ -1329,7 +1330,10 @@ class RustCliGenerator(BaseGenerator): tuple_struct = TupleStruct(name="Response") tuple_struct.tuple_fields.append(field) response_type_manager.refs[ - model.Reference(name="Body", type=TupleStruct) + model.Reference( + name=response_type_manager.root_name, + type=TupleStruct, + ) ] = tuple_struct elif ( response_def["type"] == "array" diff --git a/codegenerator/rust_sdk.py b/codegenerator/rust_sdk.py index e147f95..b74d522 100644 --- a/codegenerator/rust_sdk.py +++ b/codegenerator/rust_sdk.py @@ -345,7 +345,7 @@ class RustSdkGenerator(BaseGenerator): :param *args: Path to the code to format """ for path in args: - subprocess.run(["rustfmt", "--edition", "2021", path]) + subprocess.run(["rustfmt", "--edition", "2024", path]) def get_parser(self, parser): parser.add_argument( diff --git a/codegenerator/rust_tui.py b/codegenerator/rust_tui.py index edb08fb..91336dc 100644 --- a/codegenerator/rust_tui.py +++ b/codegenerator/rust_tui.py @@ -336,7 +336,7 @@ class TypeManager(common_rust.TypeManager): """Get all subtypes excluding TLA""" for k, v in self.refs.items(): if self.sdk_type_manager: - if k.name == "Body": + if k.name == self.root_name: sdk_type = self.sdk_type_manager.get_root_data_type() else: sdk_type = self.sdk_type_manager.refs[k] @@ -347,12 +347,12 @@ class TypeManager(common_rust.TypeManager): and isinstance( v, (common_rust.Enum, Struct, common_rust.StringEnum) ) - and k.name != "Body" + and k.name != self.root_name ): yield (v, sdk_type) elif ( k - and k.name != "Body" + and k.name != self.root_name and isinstance(v, self.option_type_class) ): if isinstance(v.item_type, common_rust.Enum): @@ -485,7 +485,7 @@ class ResponseTypeManager(common_rust.TypeManager): common_rust.Array, ), ) - and k.name != "Body" + and k.name != self.root_name ): key = v.base_type + v.type_hint if key not in emited_data: diff --git a/codegenerator/rust_types.py b/codegenerator/rust_types.py new file mode 100644 index 0000000..8ee29e5 --- /dev/null +++ b/codegenerator/rust_types.py @@ -0,0 +1,462 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. +# +import logging +from pathlib import Path +import re +import subprocess +from typing import Type, Any + +from codegenerator.base import BaseGenerator +from codegenerator.common import BasePrimitiveType +from codegenerator.common import BaseCombinedType +from codegenerator.common import BaseCompoundType +from codegenerator import common +from codegenerator import model +from codegenerator.common import BaseCompoundType +from codegenerator.common import rust as common_rust + + +class IntString(common.BasePrimitiveType): + """CLI Integer or String""" + + imports: set[str] = {"crate::common::IntString"} + type_hint: str = "IntString" + clap_macros: set[str] = set() + + +class NumString(common.BasePrimitiveType): + """CLI Number or String""" + + imports: set[str] = {"crate::common::NumString"} + type_hint: str = "NumString" + clap_macros: set[str] = set() + + +class BoolString(common.BasePrimitiveType): + """CLI Boolean or String""" + + imports: set[str] = {"crate::common::BoolString"} + type_hint: str = "BoolString" + clap_macros: set[str] = set() + + +class ResponseTypeManager(common_rust.TypeManager): + primitive_type_mapping = {} + data_type_mapping = {model.Struct: common_rust.StructResponse} + + def get_model_name(self, model_ref: model.Reference | None) -> str: + """Get the localized model type name + + In order to avoid collision between structures in request and + response we prefix all types with `Response` + :returns str: Type name + """ + if not model_ref: + return self.root_name or "Response" + return "".join( + x.capitalize() + for x in re.split(common.SPLIT_NAME_RE, model_ref.name) + ) + + def _simplify_oneof_combinations(self, type_model, kinds): + """Simplify certain known oneOf combinations""" + kinds_classes = [x["class"] for x in kinds] + if ( + common_rust.String in kinds_classes + and common_rust.Number in kinds_classes + ): + # oneOf [string, number] => NumString + kinds.clear() + kinds.append({"local": NumString(), "class": NumString}) + elif ( + common_rust.String in kinds_classes + and common_rust.Integer in kinds_classes + ): + # oneOf [string, integer] => NumString + kinds.clear() + kinds.append({"local": IntString(), "class": IntString}) + elif ( + common_rust.String in kinds_classes + and common_rust.Boolean in kinds_classes + ): + # oneOf [string, boolean] => String + kinds.clear() + kinds.append({"local": BoolString(), "class": BoolString}) + super()._simplify_oneof_combinations(type_model, kinds) + + def _get_struct_type(self, type_model: model.Struct) -> common_rust.Struct: + """Convert model.Struct into Rust `Struct`""" + struct_class = self.data_type_mapping[model.Struct] + mod = struct_class( + name=self.get_model_name(type_model.reference), + description=common_rust.sanitize_rust_docstrings( + type_model.description + ), + ) + field_class = mod.field_type_class_ + for field_name, field in type_model.fields.items(): + is_nullable: bool = False + field_data_type = self.convert_model(field.data_type) + if isinstance(field_data_type, self.option_type_class): + # Unwrap Option into "is_nullable" NOTE: but perhaps + # Option