diff --git a/codegenerator/common/rust.py b/codegenerator/common/rust.py index 34ba88d..07def8f 100644 --- a/codegenerator/common/rust.py +++ b/codegenerator/common/rust.py @@ -604,6 +604,8 @@ class TypeManager: typ = self.primitive_type_mapping[ model.ConstraintInteger ]() + elif base_type is model.ConstraintNumber: + typ = self.primitive_type_mapping[model.ConstraintNumber]() elif base_type is model.PrimitiveBoolean: typ = self.primitive_type_mapping[model.PrimitiveBoolean]() @@ -808,28 +810,57 @@ class TypeManager: kinds.append(bck) def set_models(self, models): - """Process (translate) ADT models into Rust SDK style""" + """Process (translate) ADT models into Rust models""" self.models = models self.refs = {} self.ignored_models = [] # A dictionary of model names to references to assign unique names unique_models: dict[str, model.Reference] = {} + # iterate over all incoming models for model_ in models: + # convert ADT based model into rust saving the result under self.refs model_data_type = self.convert_model(model_) + # post process conversion results if not isinstance(model_data_type, BaseCompoundType): continue name = getattr(model_data_type, "name", None) if ( name and name in unique_models - and unique_models[name] != model_.reference + and unique_models[name].hash_ != model_.reference.hash_ ): - # There is already a model with this name. Try adding suffix from datatype name - new_name = name + model_data_type.__class__.__name__ + # There is already a model with this name. + if model_.reference and model_.reference.parent: + # Try adding parent_name as prefix + new_name = ( + "".join( + x.title() + for x in model_.reference.parent.name.split("_") + ) + + name + ) + else: + # Try adding suffix from datatype name + new_name = name + model_data_type.__class__.__name__ + logging.debug(f"rename {name} to {new_name}") + if new_name not in unique_models: # New name is still unused model_data_type.name = new_name unique_models[new_name] = model_.reference + # rename original model to the same naming scheme + other_model = unique_models.get(name) + if other_model and other_model.parent: + # Try adding parent_name as prefix + new_other_name = ( + "".join( + x.title() + for x in other_model.parent.name.split("_") + ) + + name + ) + other_model.name = new_other_name + unique_models[new_other_name] = other_model elif isinstance(model_data_type, Struct): # This is already an exceptional case (identity.mapping # with remote being oneOf with multiple structs) @@ -867,6 +898,13 @@ class TypeManager: raise RuntimeError( "Model name %s is already present" % new_name ) + elif ( + name + and name in unique_models + and unique_models[name].hash_ == model_.reference.hash_ + ): + # Ignore duplicated (or more precisely same) model + self.ignored_models.append(model_.reference) elif name: unique_models[name] = model_.reference diff --git a/codegenerator/model.py b/codegenerator/model.py index e4bc3af..9d95675 100644 --- a/codegenerator/model.py +++ b/codegenerator/model.py @@ -10,6 +10,9 @@ # License for the specific language governing permissions and limitations # under the License. # +# Reference.parent (Self) is only valid from py3.11. Till 3.11 is min we need to have this import +from __future__ import annotations + import copy import hashlib import json @@ -19,6 +22,7 @@ from typing import Type import typing as ty from pydantic import BaseModel +from pydantic import ConfigDict from codegenerator import common @@ -34,10 +38,13 @@ def dicthash_(data: dict[str, Any]) -> str: class Reference(BaseModel): """Reference of the complex type to the occurence instance""" + model_config = ConfigDict(arbitrary_types_allowed=True) + #: Name of the object that uses the type under reference name: str type: Type | None = None hash_: str | None = None + parent: Reference | None = None def __hash__(self): return hash((self.name, self.type, self.hash_)) @@ -188,7 +195,7 @@ class JsonSchemaParser: schema, results: list[ADT], name: str | None = None, - parent_name: str | None = None, + parent: Reference | None = None, min_ver: str | None = None, max_ver: str | None = None, ignore_read_only: bool | None = False, @@ -199,7 +206,7 @@ class JsonSchemaParser: schema, results, name=name, - parent_name=parent_name, + parent=parent, ignore_read_only=ignore_read_only, ) if isinstance(type_, list): @@ -207,7 +214,7 @@ class JsonSchemaParser: schema, results, name=name, - parent_name=parent_name, + parent=parent, ignore_read_only=ignore_read_only, ) if isinstance(type_, str): @@ -216,7 +223,7 @@ class JsonSchemaParser: schema, results, name=name, - parent_name=parent_name, + parent=parent, min_ver=min_ver, max_ver=max_ver, ignore_read_only=ignore_read_only, @@ -226,7 +233,7 @@ class JsonSchemaParser: schema, results, name=name, - parent_name=parent_name, + parent=parent, ignore_read_only=ignore_read_only, ) if type_ == "string": @@ -253,7 +260,7 @@ class JsonSchemaParser: schema, results, name=name, - parent_name=parent_name, + parent=parent, ignore_read_only=ignore_read_only, ) if "allOf" in schema: @@ -261,7 +268,7 @@ class JsonSchemaParser: schema, results, name=name, - parent_name=parent_name, + parent=parent, ignore_read_only=ignore_read_only, ) if not type_ and "properties" in schema: @@ -270,7 +277,7 @@ class JsonSchemaParser: schema, results, name=name, - parent_name=parent_name, + parent=parent, min_ver=min_ver, max_ver=max_ver, ignore_read_only=ignore_read_only, @@ -287,7 +294,7 @@ class JsonSchemaParser: schema, results: list[ADT], name: str | None = None, - parent_name: str | None = None, + parent: Reference | None = None, min_ver: str | None = None, max_ver: str | None = None, ignore_read_only: bool | None = False, @@ -321,6 +328,14 @@ class JsonSchemaParser: if properties: # `"type": "object", "properties": {...}}` obj = Struct() + if name: + obj.reference = Reference( + name=name, + type=obj.__class__, + hash_=dicthash_(schema), + parent=parent, + ) + for k, v in properties.items(): if k == "additionalProperties" and isinstance(v, bool): # Some schemas (in keystone) are Broken @@ -331,7 +346,7 @@ class JsonSchemaParser: v, results, name=k, - parent_name=name, + parent=obj.reference, min_ver=min_ver, max_ver=max_ver, ignore_read_only=ignore_read_only, @@ -409,7 +424,7 @@ class JsonSchemaParser: schema, results, name=name, - parent_name=parent_name, + parent=parent, ignore_read_only=ignore_read_only, ) elif "allOf" in schema: @@ -418,7 +433,7 @@ class JsonSchemaParser: schema, results, name=name, - parent_name=parent_name, + parent=parent, ignore_read_only=ignore_read_only, ) @@ -427,9 +442,12 @@ class JsonSchemaParser: if not obj: raise RuntimeError("Object %s is not supported", schema) - if name: + if name and not obj.reference: obj.reference = Reference( - name=name, type=obj.__class__, hash_=dicthash_(schema) + name=name, + type=obj.__class__, + hash_=dicthash_(schema), + parent=parent, ) if obj: @@ -443,8 +461,8 @@ class JsonSchemaParser: if x.reference ] ): - if obj.reference in [ - x.reference for x in results if x.reference + if obj.reference.__hash__() in [ + x.reference.__hash__() for x in results if x.reference ]: # This is already same object - we have luck and can # de-duplicate structures. It is at the moment the case in @@ -452,10 +470,11 @@ class JsonSchemaParser: # object present few times pass else: + logging.error(f"replace {obj.reference.name}") # Structure with the same name is already present. Prefix the # new one with the parent name - if parent_name and name: - new_name = parent_name + "_" + name + if parent and name: + new_name = parent.name + "_" + name if Reference( name=new_name, type=obj.reference.type @@ -471,7 +490,7 @@ class JsonSchemaParser: schema, results: list[ADT], name: str | None = None, - parent_name: str | None = None, + parent: Reference | None = None, ignore_read_only: bool | None = False, ): obj = OneOfType() @@ -494,7 +513,10 @@ class JsonSchemaParser: obj.kinds.append(kind_type) if name: obj.reference = Reference( - name=name, type=obj.__class__, hash_=dicthash_(schema) + name=name, + type=obj.__class__, + hash_=dicthash_(schema), + parent=parent, ) results.append(obj) return obj @@ -504,7 +526,7 @@ class JsonSchemaParser: schema, results: list[ADT], name: str | None = None, - parent_name: str | None = None, + parent: Reference | None = None, ignore_read_only: bool | None = False, ): if len(schema.get("type")) == 1: @@ -536,7 +558,10 @@ class JsonSchemaParser: obj.kinds.append(kind_type) if name: obj.reference = Reference( - name=name, type=obj.__class__, hash_=dicthash_(schema) + name=name, + type=obj.__class__, + hash_=dicthash_(schema), + parent=parent, ) results.append(obj) return obj @@ -546,7 +571,7 @@ class JsonSchemaParser: schema, results: list[ADT], name: str | None = None, - parent_name: str | None = None, + parent: Reference | None = None, ignore_read_only: bool | None = False, ): # todo: decide whether some constraints can be under items @@ -555,6 +580,7 @@ class JsonSchemaParser: results, name=name, ignore_read_only=ignore_read_only, + parent=parent, ) ref = getattr(item_type, "reference", None) if ref: @@ -563,7 +589,10 @@ class JsonSchemaParser: obj = Array(item_type=item_type) if name: obj.reference = Reference( - name=name, type=obj.__class__, hash_=dicthash_(schema) + name=name, + type=obj.__class__, + hash_=dicthash_(schema), + parent=parent, ) results.append(obj) return obj @@ -573,7 +602,7 @@ class JsonSchemaParser: schema, results: list[ADT], name: str | None = None, - parent_name: str | None = None, + parent: Reference | None = None, ignore_read_only: bool | None = False, ): # todo: decide whether some constraints can be under items @@ -585,12 +614,17 @@ class JsonSchemaParser: obj.base_types.append(ConstraintString) elif literal_type is int: obj.base_types.append(ConstraintInteger) + elif literal_type is float: + obj.base_types.append(ConstraintNumber) elif literal_type is bool: obj.base_types.append(PrimitiveBoolean) if name: obj.reference = Reference( - name=name, type=obj.__class__, hash_=dicthash_(schema) + name=name, + type=obj.__class__, + hash_=dicthash_(schema), + parent=parent, ) results.append(obj) return obj @@ -600,7 +634,7 @@ class JsonSchemaParser: schema, results: list[ADT], name: str | None = None, - parent_name: str | None = None, + parent: Reference | None = None, ignore_read_only: bool | None = False, ): sch = copy.deepcopy(schema) diff --git a/codegenerator/rust_cli.py b/codegenerator/rust_cli.py index 61d48b1..c971e76 100644 --- a/codegenerator/rust_cli.py +++ b/codegenerator/rust_cli.py @@ -215,6 +215,7 @@ class EnumGroupStruct(common_rust.Struct): sdk_enum_name: str is_group: bool = True is_required: bool = False + reference: model.Reference | None = None class StructFieldResponse(common_rust.StructField): @@ -493,11 +494,13 @@ class RequestTypeManager(common_rust.TypeManager): # On the SDK side where this method is not overriden there # would be a naming conflict resulting in `set_models` call # adding type name as a suffix. - sdk_enum_name = result.name + result.__class__.__name__ + # sdk_enum_name = result.name + result.__class__.__name__ + sdk_enum_name = self.get_model_name(type_model.reference) obj = EnumGroupStruct( name=self.get_model_name(type_model.reference), kinds={}, sdk_enum_name=sdk_enum_name, + reference=type_model.reference, ) field_class = obj.field_type_class_ if not type_model.reference: @@ -509,7 +512,7 @@ class RequestTypeManager(common_rust.TypeManager): field = field_class( local_name=f"{x.lower()}_{name}", remote_name=f"{v.data_type.name}::{x}", - sdk_parent_enum_variant=f"{sdk_enum_name}::{k}", + sdk_parent_enum_variant=f"{k}", data_type=BooleanFlag(), is_optional=False, is_nullable=False, @@ -518,7 +521,7 @@ class RequestTypeManager(common_rust.TypeManager): else: field = field_class( local_name=f"{name}", - remote_name=f"{sdk_enum_name}::{k}", + remote_name=f"{k}", data_type=v.data_type, is_optional=True, is_nullable=False, @@ -1042,7 +1045,10 @@ class RustCliGenerator(BaseGenerator): ) or param["in"] != "path": # Respect path params that appear in path and not path params param_ = openapi_parser.parse_parameter(param) - if param_.name == f"{resource_name}_id": + if param_.name in [ + f"{resource_name}_id", + f"{resource_name.replace('_', '')}_id", + ]: # for i.e. routers/{router_id} we want local_name to be `id` and not `router_id` param_.name = "id" operation_params.append(param_) diff --git a/codegenerator/rust_sdk.py b/codegenerator/rust_sdk.py index 9f5467f..b4db3a4 100644 --- a/codegenerator/rust_sdk.py +++ b/codegenerator/rust_sdk.py @@ -339,8 +339,11 @@ class RustSdkGenerator(BaseGenerator): ) or param["in"] != "path": # Respect path params that appear in path and not path params param_ = openapi_parser.parse_parameter(param) - if param_.name == f"{res_name}_id": - path = path.replace(f"{res_name}_id", "id") + if param_.name in [ + f"{res_name}_id", + f"{res_name.replace('_', '')}_id", + ]: + path = path.replace(param_.name, "id") # for i.e. routers/{router_id} we want local_name to be `id` and not `router_id` param_.name = "id" operation_params.append(param_) @@ -539,8 +542,11 @@ class RustSdkGenerator(BaseGenerator): if ("{" + param["name"] + "}") in path and param["in"] == "path": # Respect path params that appear in path and not in path params param_ = openapi_parser.parse_parameter(param) - if param_.name == f"{res_name}_id": - path = path.replace(f"{res_name}_id", "id") + if param_.name in [ + f"{res_name}_id", + f"{res_name.replace('_', '')}_id", + ]: + path = path.replace(param_.name, "id") # for i.e. routers/{router_id} we want local_name to be `id` and not `router_id` param_.name = "id" operation_path_params.append(param_) diff --git a/codegenerator/templates/rust_cli/set_body_parameters.j2 b/codegenerator/templates/rust_cli/set_body_parameters.j2 index d548f10..fbae51d 100644 --- a/codegenerator/templates/rust_cli/set_body_parameters.j2 +++ b/codegenerator/templates/rust_cli/set_body_parameters.j2 @@ -19,7 +19,7 @@ {%- for k, v in root_field.data_type.fields.items() %} {%- if v.is_optional %} if let Some(val) = &args.{{ v.local_name }} { - {{ macros.set_request_data_from_input(builder_name, v, "val") }} + {{ macros.set_request_data_from_input(type_manager, builder_name, v, "val") }} } {%- elif v.data_type.format is defined and v.data_type.format == "password" %} if let Some(val) = &args.{{ v.local_name }} { @@ -36,7 +36,7 @@ {{ builder_name }}.{{ v.remote_name }}(secret.to_string()); } {%- else %} - {{ macros.set_request_data_from_input(builder_name, v, "&args." + v.local_name) }} + {{ macros.set_request_data_from_input(type_manager, builder_name, v, "&args." + v.local_name) }} {%- endif %} {% endfor %} @@ -49,10 +49,10 @@ {%- if root_field.is_optional %} if let Some(arg) = &self.{{ root_field.local_name }} { - {{ macros.set_request_data_from_input("ep_builder", root_field, "arg") }} + {{ macros.set_request_data_from_input(type_manager, "ep_builder", root_field, "arg") }} } {%- else -%} - {{ macros.set_request_data_from_input("ep_builder", root_field, "&self." + root_field.local_name) }} + {{ macros.set_request_data_from_input(type_manager, "ep_builder", root_field, "&self." + root_field.local_name) }} {%- endif %} {%- endif %} diff --git a/codegenerator/templates/rust_cli/set_query_parameters.j2 b/codegenerator/templates/rust_cli/set_query_parameters.j2 index 1d77682..8864818 100644 --- a/codegenerator/templates/rust_cli/set_query_parameters.j2 +++ b/codegenerator/templates/rust_cli/set_query_parameters.j2 @@ -15,9 +15,9 @@ {%- endif %} {%- elif not v.is_required %} if let Some(val) = &self.query.{{ v.local_name }} { - {{ macros.set_request_data_from_input("ep_builder", v, "val")}} + {{ macros.set_request_data_from_input(type_manager, "ep_builder", v, "val")}} } {%- else %} - {{ macros.set_request_data_from_input("ep_builder", v, "&self.query." + v.local_name )}} + {{ macros.set_request_data_from_input(type_manager, "ep_builder", v, "&self.query." + v.local_name )}} {%- endif %} {%- endfor %} diff --git a/codegenerator/templates/rust_macros.j2 b/codegenerator/templates/rust_macros.j2 index ed854b6..fd93491 100644 --- a/codegenerator/templates/rust_macros.j2 +++ b/codegenerator/templates/rust_macros.j2 @@ -126,7 +126,7 @@ Some({{ val }}) {%- endmacro %} {#- Macros to render setting Request data from CLI input #} -{%- macro set_request_data_from_input(dst_var, param, val_var) %} +{%- macro set_request_data_from_input(manager, dst_var, param, val_var) %} {%- set is_nullable = param.is_nullable if param.is_nullable is defined else False %} {%- if param.type_hint in ["Option>", "Option>", "Option>"] %} @@ -173,7 +173,7 @@ Some({{ val }}) {%- if v.data_type.__class__.__name__ in ["Boolean", "BooleanFlag"] %} if {{ val_var | replace("&", "") }}.{{ v.local_name }} { {{ dst_var }}.{{ param.remote_name }}( - {{ sdk_mod_path[-1] }}::{{ v.sdk_parent_enum_variant }}( + {{ sdk_mod_path[-1] }}::{{ param.data_type.name }}::{{ v.sdk_parent_enum_variant }}( {{ sdk_mod_path[-1] }}::{{ v.remote_name }} ) ); @@ -181,7 +181,7 @@ Some({{ val }}) {%- elif v.data_type.__class__.__name__ == "ArrayInput" %} {% set original_type = v.data_type.original_item_type %} if let Some(data) = {{ val_var }}.{{ v.local_name }} { - {{ sdk_enum_array_setter(param, v, "data", dst_var) }} + {{ sdk_enum_array_setter(manager, param, v, "data", dst_var) }} } {%- endif %} {%- endfor %} @@ -192,10 +192,10 @@ Some({{ val }}) {%- for k, v in param.data_type.fields.items() %} {%- if v.is_optional %} if let Some(val) = &{{ val_var }}.{{ v.local_name }} { - {{ set_request_data_from_input(builder_name, v, "val") }} + {{ set_request_data_from_input(manager, builder_name, v, "val") }} } {%- else %} - {{ set_request_data_from_input(builder_name, v, "&" + val_var + "." + v.local_name) }} + {{ set_request_data_from_input(manager, builder_name, v, "&" + val_var + "." + v.local_name) }} {%- endif %} {%- endfor %} @@ -220,10 +220,10 @@ Some({{ val }}) {%- for k, v in param.data_type.item_type.fields.items() %} {%- if v.is_optional %} if let Some(val) = &l{{ param.local_name }}.{{ v.local_name }} { - {{ set_request_data_from_input(builder_name, v, "val") }} + {{ set_request_data_from_input(manager, builder_name, v, "val") }} } {%- else %} - {{ set_request_data_from_input(builder_name, v, "&l" + param.local_name + "." + v.local_name) }} + {{ set_request_data_from_input(manager, builder_name, v, "&l" + param.local_name + "." + v.local_name) }} {%- endif %} {%- endfor %} @@ -261,7 +261,7 @@ Some({{ val }}) {%- endif %} {%- endmacro %} -{%- macro sdk_enum_array_setter(param, field, val_var, dst_var) %} +{%- macro sdk_enum_array_setter(manager, param, field, val_var, dst_var) %} {#- Set sdk array from cli input -#} {%- set original_type = field.data_type.original_data_type %} {%- if field.data_type.item_type.__class__.__name__ == "JsonValue" and original_type.__class__.__name__ == "StructInput" %} @@ -272,7 +272,7 @@ Some({{ val }}) serde_json::from_value::<{{ sdk_mod_path[-1] }}::{{ original_type.name }}>(v.to_owned())) .collect(); {{ dst_var }}.{{ param.remote_name }}( - {{ sdk_mod_path[-1] }}::{{ field.remote_name }}({{ builder_name }}) + {{ sdk_mod_path[-1] }}::{{ param.data_type.name }}::{{ field.remote_name }}({{ builder_name }}) ); {%- else %} {#- Normal array #}